diff --git a/prudp_endpoint.go b/prudp_endpoint.go index f070353..87c31bb 100644 --- a/prudp_endpoint.go +++ b/prudp_endpoint.go @@ -23,6 +23,7 @@ type PRUDPEndPoint struct { StreamID uint8 DefaultStreamSettings *StreamSettings Connections *MutexMap[string, *PRUDPConnection] + packetHandlers map[uint16]func(packet PRUDPPacketInterface) packetEventHandlers map[string][]func(packet PacketInterface) connectionEndedEventHandlers []func(connection *PRUDPConnection) errorEventHandlers []func(err *Error) @@ -39,6 +40,11 @@ func (pep *PRUDPEndPoint) RegisterServiceProtocol(protocol ServiceProtocol) { pep.OnData(protocol.HandlePacket) } +// RegisterCustomPacketHandler registers a custom handler for a given packet type. Used to override existing handlers or create new ones for custom packet types. +func (pep *PRUDPEndPoint) RegisterCustomPacketHandler(packetType uint16, handler func(packet PRUDPPacketInterface)) { + pep.packetHandlers[packetType] = handler +} + // OnData adds an event handler which is fired when a new DATA packet is received func (pep *PRUDPEndPoint) OnData(handler func(packet PacketInterface)) { pep.on("data", handler) @@ -119,17 +125,10 @@ func (pep *PRUDPEndPoint) processPacket(packet PRUDPPacketInterface, socket *Soc return } - switch packet.Type() { - case constants.SynPacket: - pep.handleSyn(packet) - case constants.ConnectPacket: - pep.handleConnect(packet) - case constants.DataPacket: - pep.handleData(packet) - case constants.DisconnectPacket: - pep.handleDisconnect(packet) - case constants.PingPacket: - pep.handlePing(packet) + if packetHandler, ok := pep.packetHandlers[packet.Type()]; ok { + packetHandler(packet) + } else { + logger.Warningf("Unhandled packet type %d", packet.Type()) } } @@ -721,14 +720,23 @@ func (pep *PRUDPEndPoint) EnableVerboseRMC(enable bool) { // NewPRUDPEndPoint returns a new PRUDPEndPoint for a server on the provided stream ID func NewPRUDPEndPoint(streamID uint8) *PRUDPEndPoint { - return &PRUDPEndPoint{ + pep := &PRUDPEndPoint{ StreamID: streamID, DefaultStreamSettings: NewStreamSettings(), Connections: NewMutexMap[string, *PRUDPConnection](), + packetHandlers: make(map[uint16]func(packet PRUDPPacketInterface)), packetEventHandlers: make(map[string][]func(PacketInterface)), connectionEndedEventHandlers: make([]func(connection *PRUDPConnection), 0), errorEventHandlers: make([]func(err *Error), 0), ConnectionIDCounter: NewCounter[uint32](0), IsSecureEndPoint: false, } + + pep.packetHandlers[constants.SynPacket] = pep.handleSyn + pep.packetHandlers[constants.ConnectPacket] = pep.handleConnect + pep.packetHandlers[constants.DataPacket] = pep.handleData + pep.packetHandlers[constants.DisconnectPacket] = pep.handleDisconnect + pep.packetHandlers[constants.PingPacket] = pep.handlePing + + return pep } diff --git a/prudp_packet_v0.go b/prudp_packet_v0.go index 447135f..47eb761 100644 --- a/prudp_packet_v0.go +++ b/prudp_packet_v0.go @@ -100,10 +100,6 @@ func (p *PRUDPPacketV0) decode() error { p.packetType = typeAndFlags & 0xF } - if p.packetType > constants.PingPacket { - return errors.New("Invalid PRUDPv0 packet type") - } - p.sessionID, err = p.readStream.ReadPrimitiveUInt8() if err != nil { return fmt.Errorf("Failed to read PRUDPv0 session ID. %s", err.Error()) diff --git a/prudp_packet_v1.go b/prudp_packet_v1.go index 9a34308..ec6f7e4 100644 --- a/prudp_packet_v1.go +++ b/prudp_packet_v1.go @@ -176,10 +176,6 @@ func (p *PRUDPPacketV1) decodeHeader() error { p.flags = typeAndFlags >> 4 p.packetType = typeAndFlags & 0xF - if p.packetType > constants.PingPacket { - return errors.New("Invalid PRUDPv1 packet type") - } - p.sessionID, err = p.readStream.ReadPrimitiveUInt8() if err != nil { return fmt.Errorf("Failed to read PRUDPv1 session ID. %s", err.Error())