package nex

import (
	"crypto/rand"
	"fmt"
	"net"
	"runtime"
	"sync"
	"time"
)

// Server represents a PRUDP server.
//
// A Server holds a mutex and must not be copied after first use; always work
// with the *Server returned by NewServer.
type Server struct {
	socket *net.UDPConn

	// clients maps a remote address (addr.String()) to its client. Every
	// listenDatagram goroutine reads and writes it, so all accesses in this
	// file go through clientsMutex.
	clients      map[string]*Client
	clientsMutex sync.RWMutex

	// Event handlers registered through On, keyed by event name.
	genericEventHandles map[string][]func(PacketInterface)
	prudpV0EventHandles map[string][]func(*PacketV0)
	prudpV1EventHandles map[string][]func(*PacketV1)

	accessKey                 string
	prudpVersion              int
	nexVersion                int
	prudpProtocolMinorVersion int
	supportedFunctions        int
	fragmentSize              int16
	resendTimeout             float32
	pingTimeout               int
	kerberosPassword          string
	kerberosKeySize           int
	kerberosKeyDerivation     int
	kerberosTicketVersion     int
	connectionIDCounter       *Counter
}

// Listen starts a NEX server on the given UDP address and blocks until one of
// the reader goroutines stops. It panics if the address cannot be resolved or
// bound.
func (server *Server) Listen(address string) {
	protocol := "udp"

	udpAddress, err := net.ResolveUDPAddr(protocol, address)
	if err != nil {
		panic(err)
	}

	socket, err := net.ListenUDP(protocol, udpAddress)
	if err != nil {
		panic(err)
	}

	server.SetSocket(socket)

	// One reader goroutine per CPU; each signals quit when its read loop ends.
	quit := make(chan struct{})
	for i := 0; i < runtime.NumCPU(); i++ {
		go server.listenDatagram(quit)
	}

	logger.Success(fmt.Sprintf("PRUDP server listening on address - %s", udpAddress.String()))

	// "Listening" carries no packet; Emit hands generic handlers a nil packet.
	server.Emit("Listening", nil)

	<-quit
}

// listenDatagram reads and dispatches datagrams until handleSocketMessage
// returns an error (only socket read failures are returned; malformed packets
// are dropped). It then signals Listen through quit and panics with the error.
// Nothing recovers that panic, so it ends the process.
func (server *Server) listenDatagram(quit chan struct{}) {
	var err error
	for err == nil {
		err = server.handleSocketMessage()
	}

	quit <- struct{}{}
	panic(err)
}

// handleSocketMessage reads one datagram from the socket and parses it as a
// packet of the configured PRUDP version. It then dispatches the packet:
// acknowledging it when required, updating client state and emitting the
// matching events. Only socket read errors are returned; unparsable datagrams
// and incoming acknowledgements are consumed without error.
func (server *Server) handleSocketMessage() error {
	var buffer [64000]byte

	length, addr, err := server.Socket().ReadFromUDP(buffer[:])
	if err != nil {
		return err
	}

	client := server.lookupOrAddClient(addr)

	packet, err := server.newVersionedPacket(client, buffer[:length])
	if err != nil {
		// Drop malformed datagrams; returning the error would stop this
		// reader goroutine (see listenDatagram).
		return nil
	}

	client.IncreasePingTimeoutTime(server.PingTimeout())

	// Acknowledgements from the peer need no further processing here.
	if packet.HasFlag(FlagAck) || packet.HasFlag(FlagMultiAck) {
		return nil
	}

	// Acknowledge reliable packets, except a CONNECT that carries a payload.
	// NOTE(review): presumably higher-level code acks that one through
	// AcknowledgePacket with a response payload — confirm.
	if packet.HasFlag(FlagNeedsAck) && (packet.Type() != ConnectPacket || len(packet.Payload()) == 0) {
		go server.AcknowledgePacket(packet, nil)
	}

	switch packet.Type() {
	case SynPacket:
		client.Reset()
		client.SetConnected(true)
		client.StartTimeoutTimer()
		server.Emit("Syn", packet)
	case ConnectPacket:
		packet.Sender().SetClientConnectionSignature(packet.ConnectionSignature())
		server.Emit("Connect", packet)
	case DataPacket:
		server.Emit("Data", packet)
	case DisconnectPacket:
		server.Emit("Disconnect", packet)
		server.Kick(client)
	case PingPacket:
		// Pings are only surfaced as an event; no reply is sent here.
		server.Emit("Ping", packet)
	}

	server.Emit("Packet", packet)

	return nil
}

// lookupOrAddClient returns the client registered for addr. On the first
// datagram from addr it creates and registers a new client instead. The
// client is built outside the lock (so NewClient may freely call back into
// the server), and the map is re-checked before inserting. That way two
// goroutines racing on the same new address end up sharing one client.
func (server *Server) lookupOrAddClient(addr *net.UDPAddr) *Client {
	discriminator := addr.String()

	server.clientsMutex.RLock()
	client, ok := server.clients[discriminator]
	server.clientsMutex.RUnlock()
	if ok {
		return client
	}

	newClient := NewClient(addr, server)

	server.clientsMutex.Lock()
	defer server.clientsMutex.Unlock()

	if existing, ok := server.clients[discriminator]; ok {
		return existing
	}
	server.clients[discriminator] = newClient

	return newClient
}

// newVersionedPacket builds a packet for client from data. It uses PacketV0
// when the server PRUDP version is 0 and PacketV1 otherwise. Callers in this
// file pass nil data to start an outgoing packet whose fields they set
// afterwards.
func (server *Server) newVersionedPacket(client *Client, data []byte) (PacketInterface, error) {
	if server.PrudpVersion() == 0 {
		return NewPacketV0(client, data)
	}

	return NewPacketV1(client, data)
}

// On registers handler for the named event. The handler must be a
// func(PacketInterface), func(*PacketV0) or func(*PacketV1); handlers of any
// other type are silently ignored. Each registered handler runs in its own
// goroutine every time the event is emitted.
//
// NOTE(review): the handler maps are not locked, so register handlers before
// calling Listen.
func (server *Server) On(event string, handler interface{}) {
	// Store the handler, typed, in the map that matches its signature.
	switch handler := handler.(type) {
	case func(PacketInterface):
		server.genericEventHandles[event] = append(server.genericEventHandles[event], handler)
	case func(*PacketV0):
		server.prudpV0EventHandles[event] = append(server.prudpV0EventHandles[event], handler)
	case func(*PacketV1):
		server.prudpV1EventHandles[event] = append(server.prudpV1EventHandles[event], handler)
	}
}

// Emit runs every handler registered for event, each in its own goroutine.
// Generic handlers always run. They receive packet as a PacketInterface, or
// nil when packet is nil or does not implement it (e.g. the "Listening"
// event). Version-specific handlers run only when packet's dynamic type is
// *PacketV0 or *PacketV1 respectively.
func (server *Server) Emit(event string, packet interface{}) {
	// Comma-ok: a plain assertion on a nil interface would panic.
	genericPacket, _ := packet.(PacketInterface)
	for _, handler := range server.genericEventHandles[event] {
		go handler(genericPacket)
	}

	switch typedPacket := packet.(type) {
	case *PacketV0:
		for _, handler := range server.prudpV0EventHandles[event] {
			go handler(typedPacket)
		}
	case *PacketV1:
		for _, handler := range server.prudpV1EventHandles[event] {
			go handler(typedPacket)
		}
	}
}

// ClientConnected reports whether client is currently registered on the
// server.
func (server *Server) ClientConnected(client *Client) bool {
	discriminator := client.Address().String()

	server.clientsMutex.RLock()
	defer server.clientsMutex.RUnlock()

	_, connected := server.clients[discriminator]
	return connected
}

// Kick emits the "Kick" event for client, marks it disconnected and removes it
// from the server.
func (server *Server) Kick(client *Client) {
	// Server events expect a packet, even though this isn't really a packet
	// event, so an empty one for the client is passed. The construction error
	// is ignored as for the other nil-data packets built in this file.
	packet, _ := server.newVersionedPacket(client, nil)

	server.Emit("Kick", packet)
	client.SetConnected(false)

	discriminator := client.Address().String()

	server.clientsMutex.Lock()
	delete(server.clients, discriminator)
	server.clientsMutex.Unlock()
}

// SendPing sends a reliable ping packet (flagged as needing an ack) to the
// given client.
func (server *Server) SendPing(client *Client) {
	pingPacket, _ := server.newVersionedPacket(client, nil)

	pingPacket.SetSource(0xA1)
	pingPacket.SetDestination(0xAF)
	pingPacket.SetType(PingPacket)
	pingPacket.AddFlag(FlagNeedsAck)
	pingPacket.AddFlag(FlagReliable)

	server.Send(pingPacket)
}

// AcknowledgePacket acknowledges that the given packet was received. It sends
// an ack back to the packet's sender, optionally carrying payload. For PRUDP
// version 1 it additionally:
//   - fills in the handshake fields of SYN/CONNECT acks, including a fresh
//     random server connection signature for SYN;
//   - turns DATA acks into aggregate (multi-)acks.
func (server *Server) AcknowledgePacket(packet PacketInterface, payload []byte) {
	sender := packet.Sender()

	ackPacket, _ := server.newVersionedPacket(sender, nil)

	ackPacket.SetSource(packet.Destination())
	ackPacket.SetDestination(packet.Source())
	ackPacket.SetType(packet.Type())
	ackPacket.SetSequenceID(packet.SequenceID())
	ackPacket.SetFragmentID(packet.FragmentID())
	ackPacket.AddFlag(FlagAck)
	ackPacket.AddFlag(FlagHasSize)

	if payload != nil {
		ackPacket.SetPayload(payload)
	}

	if server.PrudpVersion() == 1 {
		request := packet.(*PacketV1)
		ack := ackPacket.(*PacketV1)

		ack.SetVersion(1)
		ack.SetSubstreamID(0)

		if request.Type() == SynPacket || request.Type() == ConnectPacket {
			ack.SetPRUDPProtocolMinorVersion(request.sender.PRUDPProtocolMinorVersion())
			ack.SetSupportedFunctions(request.sender.SupportedFunctions())
			ack.SetMaximumSubstreamID(0)
		}

		switch request.Type() {
		case SynPacket:
			// The signature is generated with crypto/rand so other hosts
			// cannot predict it. If secure randomness is unavailable, refuse
			// to acknowledge rather than hand out a predictable (all-zero)
			// signature.
			serverConnectionSignature := make([]byte, 16)
			if _, err := rand.Read(serverConnectionSignature); err != nil {
				return
			}

			ack.Sender().SetServerConnectionSignature(serverConnectionSignature)
			ack.SetConnectionSignature(serverConnectionSignature)
		case ConnectPacket:
			ack.SetConnectionSignature(make([]byte, 16))
			ack.SetInitialSequenceID(10000)
		case DataPacket:
			// Data is acknowledged with an aggregate (multi-)ack.
			ack.ClearFlag(FlagAck)
			ack.AddFlag(FlagMultiAck)

			payloadStream := NewStreamOut(server)

			// Newer aggregate-ack format, used from PRUDP minor version 2.
			// Only this single packet is acknowledged. For older minor
			// versions the payload stays empty.
			// NOTE(review): confirm the empty legacy payload is intended.
			if server.PRUDPProtocolMinorVersion() >= 2 {
				ack.SetSequenceID(0)
				ack.SetSubstreamID(1)

				payloadStream.WriteUInt8(0)                       // substream ID
				payloadStream.WriteUInt8(0)                       // number of additional sequence IDs
				payloadStream.WriteUInt16LE(request.SequenceID()) // sequence ID
			}

			ack.SetPayload(payloadStream.Bytes())
		}
	}

	data := ackPacket.Bytes()
	server.SendRaw(sender.Address(), data)
}

// Socket returns the underlying server UDP socket.
func (server *Server) Socket() *net.UDPConn {
	return server.socket
}

// SetSocket sets the underlying UDP socket.
func (server *Server) SetSocket(socket *net.UDPConn) {
	server.socket = socket
}

// PrudpVersion returns the server PRUDP version.
func (server *Server) PrudpVersion() int {
	return server.prudpVersion
}

// SetPrudpVersion sets the server PRUDP version.
func (server *Server) SetPrudpVersion(prudpVersion int) {
	server.prudpVersion = prudpVersion
}

// NexVersion returns the server NEX version.
func (server *Server) NexVersion() int {
	return server.nexVersion
}

// SetNexVersion sets the server NEX version.
func (server *Server) SetNexVersion(nexVersion int) {
	server.nexVersion = nexVersion
}

// PRUDPProtocolMinorVersion returns the server PRUDP minor version.
func (server *Server) PRUDPProtocolMinorVersion() int {
	return server.prudpProtocolMinorVersion
}

// SetPRUDPProtocolMinorVersion sets the server PRUDP minor version.
func (server *Server) SetPRUDPProtocolMinorVersion(prudpProtocolMinorVersion int) {
	server.prudpProtocolMinorVersion = prudpProtocolMinorVersion
}

// SupportedFunctions returns the PRUDP functions supported by the server.
func (server *Server) SupportedFunctions() int {
	return server.supportedFunctions
}

// SetSupportedFunctions sets the PRUDP functions supported by the server.
func (server *Server) SetSupportedFunctions(supportedFunctions int) {
	server.supportedFunctions = supportedFunctions
}

// AccessKey returns the server access key.
func (server *Server) AccessKey() string {
	return server.accessKey
}

// SetAccessKey sets the server access key.
func (server *Server) SetAccessKey(accessKey string) {
	server.accessKey = accessKey
}

// KerberosPassword returns the server Kerberos password.
func (server *Server) KerberosPassword() string {
	return server.kerberosPassword
}

// SetKerberosPassword sets the server Kerberos password.
func (server *Server) SetKerberosPassword(kerberosPassword string) {
	server.kerberosPassword = kerberosPassword
}

// KerberosKeySize returns the server Kerberos key size.
func (server *Server) KerberosKeySize() int {
	return server.kerberosKeySize
}

// SetKerberosKeySize sets the server Kerberos key size.
func (server *Server) SetKerberosKeySize(kerberosKeySize int) {
	server.kerberosKeySize = kerberosKeySize
}

// KerberosTicketVersion returns the server Kerberos ticket contents version.
func (server *Server) KerberosTicketVersion() int {
	return server.kerberosTicketVersion
}

// SetKerberosTicketVersion sets the server Kerberos ticket contents version.
func (server *Server) SetKerberosTicketVersion(ticketVersion int) {
	server.kerberosTicketVersion = ticketVersion
}

// PingTimeout returns the server ping timeout time in seconds.
func (server *Server) PingTimeout() int {
	return server.pingTimeout
}

// SetPingTimeout sets the server ping timeout time in seconds.
func (server *Server) SetPingTimeout(pingTimeout int) {
	server.pingTimeout = pingTimeout
}

// SetFragmentSize sets the maximum payload size, in bytes, of each fragment
// sent by Send. A non-positive size disables fragmentation.
func (server *Server) SetFragmentSize(fragmentSize int16) {
	server.fragmentSize = fragmentSize
}

// ConnectionIDCounter returns the server connection ID counter.
func (server *Server) ConnectionIDCounter() *Counter {
	return server.connectionIDCounter
}

// FindClientFromPID returns the registered client with the given PID, or nil
// if there is none.
func (server *Server) FindClientFromPID(pid uint32) *Client {
	server.clientsMutex.RLock()
	defer server.clientsMutex.RUnlock()

	for _, client := range server.clients {
		if client.pid == pid {
			return client
		}
	}

	return nil
}

// FindClientFromConnectionID returns the registered client with the given
// connection ID, or nil if there is none.
func (server *Server) FindClientFromConnectionID(rvcid uint32) *Client {
	server.clientsMutex.RLock()
	defer server.clientsMutex.RUnlock()

	for _, client := range server.clients {
		if client.connectionID == rvcid {
			return client
		}
	}

	return nil
}

// Send transmits packet to packet.Sender(), splitting its payload into
// fragments of fragmentSize bytes:
//   - every fragment but the last gets an increasing fragment ID starting at 1;
//   - the last fragment gets fragment ID 0. It is empty when the payload
//     length is an exact multiple of fragmentSize.
//
// A non-positive fragment size disables fragmentation.
//
// NOTE(review): a fixed half-second pause precedes every fragment, including
// the only fragment of a small packet — confirm this pacing is intended.
// NOTE(review): fragmentID is a uint8 and wraps to 0 (the "last fragment"
// marker) after 255 fragments.
func (server *Server) Send(packet PacketInterface) {
	data := packet.Payload()

	// Compute in int: converting len(data) to int16 overflows for payloads
	// of 32 KiB or more.
	fragmentSize := int(server.fragmentSize)
	if fragmentSize <= 0 {
		// Avoid a division-by-zero/infinite loop; send a single fragment.
		fragmentSize = len(data) + 1
	}

	var fragmentID uint8 = 1
	for {
		time.Sleep(time.Second / 2)

		if len(data) < fragmentSize {
			packet.SetPayload(data)
			server.SendFragment(packet, 0)
			return
		}

		packet.SetPayload(data[:fragmentSize])
		server.SendFragment(packet, fragmentID)
		data = data[fragmentSize:]
		fragmentID++
	}
}

// SendFragment stamps packet with fragmentID and the client's next outgoing
// sequence ID. It then encodes the packet and sends it to packet.Sender().
func (server *Server) SendFragment(packet PacketInterface, fragmentID uint8) {
	data := packet.Payload()
	client := packet.Sender()

	packet.SetFragmentID(fragmentID)
	// NOTE(review): re-setting the payload that was just read looks redundant
	// unless SetPayload has side effects; kept as-is.
	packet.SetPayload(data)
	packet.SetSequenceID(uint16(client.SequenceIDCounterOut().Increment()))

	encodedPacket := packet.Bytes()
	server.SendRaw(client.Address(), encodedPacket)
}

// SendRaw writes raw packet data to the given UDP address. Delivery is
// best-effort: write errors are discarded.
func (server *Server) SendRaw(conn *net.UDPAddr, data []byte) {
	// NOTE(review): consider logging the WriteToUDP error.
	_, _ = server.Socket().WriteToUDP(data, conn)
}

// NewServer returns a new NEX server with the defaults below:
//   - PRUDP version 1;
//   - fragment size 1300;
//   - ping timeout 5;
//   - Kerberos key size 32;
//   - connection ID counter starting from NewCounter(10).
func NewServer() *Server {
	server := &Server{
		genericEventHandles:   make(map[string][]func(PacketInterface)),
		prudpV0EventHandles:   make(map[string][]func(*PacketV0)),
		prudpV1EventHandles:   make(map[string][]func(*PacketV1)),
		clients:               make(map[string]*Client),
		prudpVersion:          1,
		fragmentSize:          1300,
		resendTimeout:         1.5,
		pingTimeout:           5,
		kerberosKeySize:       32,
		kerberosKeyDerivation: 0,
		connectionIDCounter:   NewCounter(10),
	}

	return server
}