nex-go/prudp_packet_v1.go
2024-04-07 23:40:51 +01:00

401 lines
11 KiB
Go

package nex
import (
"bytes"
"crypto/hmac"
"crypto/md5"
"encoding/binary"
"errors"
"fmt"
"net"
"github.com/PretendoNetwork/nex-go/v2/constants"
)
// PRUDPPacketV1 represents a PRUDPv1 packet.
//
// It embeds the version-independent PRUDPPacket fields and adds the
// header/option values that only exist in the PRUDPv1 wire format.
type PRUDPPacketV1 struct {
	PRUDPPacket
	// Length in bytes of the options section. Read from the header when
	// decoding; recomputed from the encoded options in Bytes.
	optionsLength uint8
	// Payload length as read from the header; decode uses it to size the payload read.
	payloadLength uint16
	// Low 8 bits of option 0 (SYN/CONNECT packets).
	minorVersion uint32
	// Option 0 value shifted right by 8 bits (SYN/CONNECT packets).
	supportedFunctions uint32
	// Option 4 value (SYN/CONNECT packets).
	maximumSubstreamID uint8
	// Option 3 value (CONNECT packets only).
	initialUnreliableSequenceID uint16
}
// Copy returns a deep copy of the packet as a new PRUDPPacketV1.
//
// The signature, connection signature and payload byte slices are duplicated
// (nil stays nil), and the message is duplicated via its own Copy method.
// The server and the sender PRUDPConnection pointers are shared with the
// original packet.
func (p *PRUDPPacketV1) Copy() PRUDPPacketInterface {
	// cloneBytes duplicates src into a fresh backing array, leaving nil as nil.
	cloneBytes := func(src []byte) []byte {
		if src == nil {
			return nil
		}
		return append([]byte(nil), src...)
	}

	// With a nil read stream nothing is decoded, so no error can be returned.
	dup, _ := NewPRUDPPacketV1(p.server, p.sender, nil)

	// Version-independent PRUDP fields.
	dup.server = p.server
	dup.sourceVirtualPort = p.sourceVirtualPort
	dup.destinationVirtualPort = p.destinationVirtualPort
	dup.packetType = p.packetType
	dup.flags = p.flags
	dup.sessionID = p.sessionID
	dup.substreamID = p.substreamID
	dup.sequenceID = p.sequenceID
	dup.fragmentID = p.fragmentID
	dup.signature = cloneBytes(p.signature)
	dup.connectionSignature = cloneBytes(p.connectionSignature)
	dup.payload = cloneBytes(p.payload)

	if p.message != nil {
		dup.message = p.message.Copy()
	}

	// PRUDPv1-specific fields.
	dup.optionsLength = p.optionsLength
	dup.payloadLength = p.payloadLength
	dup.minorVersion = p.minorVersion
	dup.supportedFunctions = p.supportedFunctions
	dup.maximumSubstreamID = p.maximumSubstreamID
	dup.initialUnreliableSequenceID = p.initialUnreliableSequenceID

	return dup
}
// Version reports the PRUDP protocol version of this packet type, which is
// always 1 for PRUDPPacketV1.
func (p *PRUDPPacketV1) Version() int {
	const prudpVersion = 1
	return prudpVersion
}
// decode parses the packet's wire data from p.readStream.
//
// The layout is:
//   - 2-byte magic (0xEA 0xD0)
//   - 12-byte header
//   - 16-byte packet signature
//   - options section (optionsLength bytes)
//   - payload (payloadLength bytes)
//
// Each section's length is checked against the remaining stream data before
// it is read. A truncated or malformed datagram therefore produces an error
// instead of a short read or a panic. Errors from the header and options
// decoders are wrapped with %w so callers can unwrap them.
func (p *PRUDPPacketV1) decode() error {
	if p.readStream.Remaining() < 2 {
		return errors.New("Failed to read PRUDPv1 magic. Not have enough data")
	}

	magic := p.readStream.ReadBytesNext(2)
	if !bytes.Equal(magic, []byte{0xEA, 0xD0}) {
		return fmt.Errorf("Invalid PRUDPv1 magic. Expected 0xEAD0, got 0x%x", magic)
	}

	if err := p.decodeHeader(); err != nil {
		return fmt.Errorf("Failed to decode PRUDPv1 header. %w", err)
	}

	if p.readStream.Remaining() < 16 {
		return errors.New("Failed to read PRUDPv1 signature. Not have enough data")
	}
	p.signature = p.readStream.ReadBytesNext(16)

	// optionsLength and payloadLength come straight from the (untrusted)
	// header, so verify the stream actually holds that much data.
	if uint64(p.readStream.Remaining()) < uint64(p.optionsLength) {
		return errors.New("Failed to read PRUDPv1 options. Not have enough data")
	}
	if err := p.decodeOptions(); err != nil {
		return fmt.Errorf("Failed to decode PRUDPv1 options. %w", err)
	}

	if uint64(p.readStream.Remaining()) < uint64(p.payloadLength) {
		return errors.New("Failed to read PRUDPv1 payload. Not have enough data")
	}
	p.payload = p.readStream.ReadBytesNext(int64(p.payloadLength))

	return nil
}
// Bytes encodes a PRUDPv1 packet into a byte slice.
//
// Output layout: magic (0xEA 0xD0), 12-byte header, 16-byte packet
// signature, options, payload. As a side effect, optionsLength and
// payloadLength are updated to match what is actually encoded.
func (p *PRUDPPacketV1) Bytes() []byte {
	options := p.encodeOptions()
	p.optionsLength = uint8(len(options))

	// encodeHeader writes len(p.payload); keep the field consistent with it.
	p.payloadLength = uint16(len(p.payload))

	header := p.encodeHeader()

	// The packet signature occupies a fixed 16-byte slot. Zero-pad a missing
	// or short signature (and truncate an oversized one) so it can never
	// shift the options and payload that follow it.
	signature := make([]byte, 16)
	copy(signature, p.signature)

	stream := NewByteStreamOut(p.server.LibraryVersions, p.server.ByteStreamSettings)

	stream.Grow(2)
	stream.WriteBytesNext([]byte{0xEA, 0xD0})

	stream.Grow(12)
	stream.WriteBytesNext(header)

	stream.Grow(16)
	stream.WriteBytesNext(signature)

	stream.Grow(int64(p.optionsLength))
	stream.WriteBytesNext(options)

	stream.Grow(int64(len(p.payload)))
	stream.WriteBytesNext(p.payload)

	return stream.Bytes()
}
// decodeHeader parses the fixed 12-byte PRUDPv1 header that follows the magic.
//
// Fields, in order:
//   - version (must be 1)
//   - options length and payload length
//   - source and destination virtual ports
//   - packed type/flags
//   - session ID, substream ID and sequence ID
//
// The full header length is checked up front, so the per-field read errors
// below are defensive. Underlying read errors are wrapped with %w.
func (p *PRUDPPacketV1) decodeHeader() error {
	if p.readStream.Remaining() < 12 {
		return errors.New("Failed to read PRUDPv1 header. Not have enough data")
	}

	version, err := p.readStream.ReadPrimitiveUInt8()
	if err != nil {
		return fmt.Errorf("Failed to decode PRUDPv1 version. %w", err)
	}

	if version != 1 {
		return fmt.Errorf("Invalid PRUDPv1 version. Expected 1, got %d", version)
	}

	p.optionsLength, err = p.readStream.ReadPrimitiveUInt8()
	if err != nil {
		return fmt.Errorf("Failed to decode PRUDPv1 options length. %w", err)
	}

	p.payloadLength, err = p.readStream.ReadPrimitiveUInt16LE()
	if err != nil {
		return fmt.Errorf("Failed to decode PRUDPv1 payload length. %w", err)
	}

	source, err := p.readStream.ReadPrimitiveUInt8()
	if err != nil {
		return fmt.Errorf("Failed to read PRUDPv1 source. %w", err)
	}

	p.sourceVirtualPort = VirtualPort(source)

	destination, err := p.readStream.ReadPrimitiveUInt8()
	if err != nil {
		return fmt.Errorf("Failed to read PRUDPv1 destination. %w", err)
	}

	p.destinationVirtualPort = VirtualPort(destination)

	// Packet type is the low 4 bits; flags are the bits above it.
	// TODO - Does QRV also encode it this way in PRUDPv1?
	typeAndFlags, err := p.readStream.ReadPrimitiveUInt16LE()
	if err != nil {
		return fmt.Errorf("Failed to read PRUDPv1 type and flags. %w", err)
	}

	p.flags = typeAndFlags >> 4
	p.packetType = typeAndFlags & 0xF

	p.sessionID, err = p.readStream.ReadPrimitiveUInt8()
	if err != nil {
		return fmt.Errorf("Failed to read PRUDPv1 session ID. %w", err)
	}

	p.substreamID, err = p.readStream.ReadPrimitiveUInt8()
	if err != nil {
		return fmt.Errorf("Failed to read PRUDPv1 substream ID. %w", err)
	}

	p.sequenceID, err = p.readStream.ReadPrimitiveUInt16LE()
	if err != nil {
		return fmt.Errorf("Failed to read PRUDPv1 sequence ID. %w", err)
	}

	return nil
}
// encodeHeader serializes the 12-byte PRUDPv1 header (without the magic).
//
// Fields, in order:
//   - version (1)
//   - options length and payload length
//   - source and destination virtual ports
//   - packed type/flags
//   - session ID, substream ID and sequence ID
//
// The payload length is taken from len(p.payload); the options length is
// taken from the optionsLength field.
func (p *PRUDPPacketV1) encodeHeader() []byte {
	const prudpVersion uint8 = 1

	// Packet type occupies the low 4 bits, flags the bits above it.
	// TODO - Does QRV also encode it this way in PRUDPv1?
	typeAndFlags := p.packetType | (p.flags << 4)
	payloadSize := uint16(len(p.payload))

	out := NewByteStreamOut(p.server.LibraryVersions, p.server.ByteStreamSettings)

	out.WritePrimitiveUInt8(prudpVersion)
	out.WritePrimitiveUInt8(p.optionsLength)
	out.WritePrimitiveUInt16LE(payloadSize)
	out.WritePrimitiveUInt8(uint8(p.sourceVirtualPort))
	out.WritePrimitiveUInt8(uint8(p.destinationVirtualPort))
	out.WritePrimitiveUInt16LE(typeAndFlags)
	out.WritePrimitiveUInt8(p.sessionID)
	out.WritePrimitiveUInt8(p.substreamID)
	out.WritePrimitiveUInt16LE(p.sequenceID)

	return out.Bytes()
}
// decodeOptions parses the optionsLength-byte options section that follows
// the packet signature. Each option is encoded as
// <id uint8><size uint8><value [size]byte>.
//
// Options handled:
//   - SYN and CONNECT: option 0 (supported functions and minor version),
//     option 1 (16-byte connection signature) and option 4 (maximum substream ID).
//   - CONNECT only: option 3 (initial unreliable sequence ID).
//   - DATA: option 2 (fragment ID).
//
// Every option's declared size is honored: exactly that many value bytes are
// consumed. Options that are unknown, or that are not meaningful for this
// packet type, are skipped instead of having their value bytes misread as
// the next option's ID. A known option whose declared size is too small to
// hold its value is rejected with an error.
func (p *PRUDPPacketV1) decodeOptions() error {
	// optionsLength comes from the header and is not trusted.
	if uint64(p.readStream.Remaining()) < uint64(p.optionsLength) {
		return errors.New("Failed to read PRUDPv1 options. Not have enough data")
	}

	data := p.readStream.ReadBytesNext(int64(p.optionsLength))
	optionsStream := NewByteStreamIn(data, p.server.LibraryVersions, p.server.ByteStreamSettings)

	isHandshake := p.packetType == constants.SynPacket || p.packetType == constants.ConnectPacket

	for optionsStream.Remaining() > 0 {
		optionID, err := optionsStream.ReadPrimitiveUInt8()
		if err != nil {
			return err
		}

		optionSize, err := optionsStream.ReadPrimitiveUInt8()
		if err != nil {
			return err
		}

		if uint64(optionsStream.Remaining()) < uint64(optionSize) {
			return fmt.Errorf("Failed to read PRUDPv1 option %d. Expected %d bytes, got %d", optionID, optionSize, optionsStream.Remaining())
		}

		// Always consume exactly the declared size so the next iteration
		// starts on the next option's ID byte.
		value := optionsStream.ReadBytesNext(int64(optionSize))

		// Minimum value size for each option this packet type understands.
		var minSize int
		switch {
		case isHandshake && optionID == 0:
			minSize = 4
		case isHandshake && optionID == 1:
			minSize = 16
		case isHandshake && optionID == 4:
			minSize = 1
		case p.packetType == constants.ConnectPacket && optionID == 3:
			minSize = 2
		case p.packetType == constants.DataPacket && optionID == 2:
			minSize = 1
		default:
			// Unknown option, or not meaningful for this packet type: skip it.
			continue
		}

		if len(value) < minSize {
			return fmt.Errorf("Invalid PRUDPv1 option %d. Expected at least %d bytes, got %d", optionID, minSize, len(value))
		}

		switch optionID {
		case 0:
			// Minor version in the low 8 bits, supported functions above it.
			combined := binary.LittleEndian.Uint32(value)
			p.minorVersion = combined & 0xFF
			p.supportedFunctions = combined >> 8
		case 1:
			// Detach from the read buffer.
			p.connectionSignature = append([]byte(nil), value[:16]...)
		case 2:
			p.fragmentID = value[0]
		case 3:
			p.initialUnreliableSequenceID = binary.LittleEndian.Uint16(value)
		case 4:
			p.maximumSubstreamID = value[0]
		}
	}

	return nil
}
// encodeOptions serializes the PRUDPv1 options section for this packet's
// type. Each option is written as <id uint8><size uint8><value>.
//
// Options written:
//   - SYN and CONNECT: option 0 (minor version | supported functions << 8),
//     option 1 (16-byte connection signature), option 3 (CONNECT only) and
//     option 4 (maximum substream ID).
//   - DATA: option 2 (fragment ID).
//   - Any other packet type: nothing, giving an empty section.
func (p *PRUDPPacketV1) encodeOptions() []byte {
	optionsStream := NewByteStreamOut(p.server.LibraryVersions, p.server.ByteStreamSettings)

	if p.packetType == constants.SynPacket || p.packetType == constants.ConnectPacket {
		optionsStream.WritePrimitiveUInt8(0)
		optionsStream.WritePrimitiveUInt8(4)
		optionsStream.WritePrimitiveUInt32LE(p.minorVersion | (p.supportedFunctions << 8))

		// Option 1 declares a fixed 16-byte value. Zero-pad a missing or
		// short connection signature (and truncate an oversized one) so
		// the bytes written always match the declared size and the
		// following options stay aligned.
		connectionSignature := make([]byte, 16)
		copy(connectionSignature, p.connectionSignature)

		optionsStream.WritePrimitiveUInt8(1)
		optionsStream.WritePrimitiveUInt8(16)
		optionsStream.Grow(16)
		optionsStream.WriteBytesNext(connectionSignature)

		// * Encoded here for NintendoClients compatibility.
		// * The order of these options should not matter,
		// * however when NintendoClients calculates the
		// * signature it does NOT use the original options
		// * section, and instead re-encodes the data in a
		// * specific order. Due to how this section is
		// * parsed, though, order REALLY doesn't matter.
		// * NintendoClients expects option 3 before 4, though
		if p.packetType == constants.ConnectPacket {
			optionsStream.WritePrimitiveUInt8(3)
			optionsStream.WritePrimitiveUInt8(2)
			optionsStream.WritePrimitiveUInt16LE(p.initialUnreliableSequenceID)
		}

		optionsStream.WritePrimitiveUInt8(4)
		optionsStream.WritePrimitiveUInt8(1)
		optionsStream.WritePrimitiveUInt8(p.maximumSubstreamID)
	}

	if p.packetType == constants.DataPacket {
		optionsStream.WritePrimitiveUInt8(2)
		optionsStream.WritePrimitiveUInt8(1)
		optionsStream.WritePrimitiveUInt8(p.fragmentID)
	}

	return optionsStream.Bytes()
}
// calculateConnectionSignature computes the connection signature for addr
// using the server's configured PRUDPv1 ConnectionSignatureCalculator.
func (p *PRUDPPacketV1) calculateConnectionSignature(addr net.Addr) ([]byte, error) {
	calculator := p.server.PRUDPV1Settings.ConnectionSignatureCalculator
	return calculator(p, addr)
}
// calculateSignature computes this packet's signature from sessionKey and
// connectionSignature, using the server's configured PRUDPv1 SignatureCalculator.
func (p *PRUDPPacketV1) calculateSignature(sessionKey, connectionSignature []byte) []byte {
	calculator := p.server.PRUDPV1Settings.SignatureCalculator
	return calculator(p, sessionKey, connectionSignature)
}
// NewPRUDPPacketV1 creates a PRUDPv1 packet bound to server and sent by
// connection.
//
// If readStream is nil, an empty packet is returned. Otherwise the packet is
// decoded from readStream immediately; on a decode failure it returns a nil
// packet and an error describing the failure.
func NewPRUDPPacketV1(server *PRUDPServer, connection *PRUDPConnection, readStream *ByteStreamIn) (*PRUDPPacketV1, error) {
	packet := &PRUDPPacketV1{}
	packet.sender = connection
	packet.readStream = readStream
	packet.version = 1

	// Must be set before decoding: the decoders read the server's stream settings.
	packet.server = server

	if readStream == nil {
		return packet, nil
	}

	if err := packet.decode(); err != nil {
		return nil, fmt.Errorf("Failed to decode PRUDPv1 packet. %s", err.Error())
	}

	return packet, nil
}
// NewPRUDPPacketsV1 decodes consecutive PRUDPv1 packets from readStream until
// no data remains.
//
// On success it returns every decoded packet (an empty, non-nil slice if
// none). If a packet fails to decode, it returns the packets decoded before
// the failure together with the error.
func NewPRUDPPacketsV1(server *PRUDPServer, connection *PRUDPConnection, readStream *ByteStreamIn) ([]PRUDPPacketInterface, error) {
	decoded := []PRUDPPacketInterface{}

	for readStream.Remaining() > 0 {
		next, err := NewPRUDPPacketV1(server, connection, readStream)
		if err != nil {
			return decoded, err
		}

		decoded = append(decoded, next)
	}

	return decoded, nil
}
// defaultPRUDPv1ConnectionSignature is the default PRUDPv1 connection
// signature calculator. It returns an HMAC-MD5, keyed with the server's
// PRUDPv1ConnectionSignatureKey, over the peer's IP address followed by its
// port as a big-endian uint16.
//
// IPv4 peers, including IPv4-mapped IPv6 addresses, are hashed in their
// 4-byte form. Other IPv6 peers are hashed with the full 16-byte address, so
// distinct IPv6 peers sharing a port get distinct signatures. Only non-nil
// *net.UDPAddr values are supported; anything else returns an error.
func defaultPRUDPv1ConnectionSignature(packet *PRUDPPacketV1, addr net.Addr) ([]byte, error) {
	udpAddr, ok := addr.(*net.UDPAddr)
	if !ok {
		return nil, fmt.Errorf("Unsupported network type: %T", addr)
	}

	if udpAddr == nil {
		return nil, errors.New("Invalid network address: nil *net.UDPAddr")
	}

	ip := udpAddr.IP.To4()
	if ip == nil {
		// Not representable as IPv4; fall back to the raw (IPv6) address.
		ip = udpAddr.IP
	}

	// Build into a fresh buffer. Appending directly to ip could write into the
	// caller's address backing array when it has spare capacity.
	data := make([]byte, len(ip)+2)
	n := copy(data, ip)
	binary.BigEndian.PutUint16(data[n:], uint16(udpAddr.Port))

	mac := hmac.New(md5.New, packet.server.PRUDPv1ConnectionSignatureKey)
	mac.Write(data) // hash.Hash.Write never returns an error

	return mac.Sum(nil), nil
}
// defaultPRUDPv1CalculateSignature is the default PRUDPv1 packet signature
// calculator. It returns an HMAC-MD5 keyed with MD5(access key).
//
// The MAC covers, in order:
//   - the encoded header without its first 4 bytes (version, options length
//     and payload length)
//   - the session key
//   - the little-endian uint32 byte sum of the access key
//   - the connection signature (treated as empty for CONNECT packets when
//     LegacyConnectionSignature is enabled)
//   - the re-encoded options
//   - the payload
func defaultPRUDPv1CalculateSignature(packet *PRUDPPacketV1, sessionKey, connectionSignature []byte) []byte {
	server := packet.server
	accessKey := []byte(server.AccessKey)

	options := packet.encodeOptions()
	header := packet.encodeHeader()

	var accessKeySum [4]byte
	binary.LittleEndian.PutUint32(accessKeySum[:], sum[byte, uint32](accessKey))

	hmacKey := md5.Sum(accessKey)

	signedConnectionSignature := connectionSignature
	if packet.packetType == constants.ConnectPacket && server.PRUDPV1Settings.LegacyConnectionSignature {
		signedConnectionSignature = []byte{}
	}

	signedParts := [][]byte{
		header[4:],
		sessionKey,
		accessKeySum[:],
		signedConnectionSignature,
		options,
		packet.payload,
	}

	mac := hmac.New(md5.New, hmacKey[:])
	for _, part := range signedParts {
		mac.Write(part)
	}

	return mac.Sum(nil)
}