Netplay: Rewrite/simplify serialization code

This commit is contained in:
Sour 2019-10-24 20:52:20 -04:00
parent 1c7a1060a6
commit e4eb6f997e
12 changed files with 84 additions and 200 deletions

View file

@ -8,13 +8,12 @@
class ForceDisconnectMessage : public NetMessage
{
private:
char* _disconnectMessage = nullptr;
uint32_t _messageLength = 0;
string _disconnectMessage;
protected:
virtual void ProtectedStreamState()
void Serialize(Serializer &s) override
{
StreamArray((void**)&_disconnectMessage, _messageLength);
s.Stream(_disconnectMessage);
}
public:
@ -22,7 +21,7 @@ public:
ForceDisconnectMessage(string message) : NetMessage(MessageType::ForceDisconnect)
{
CopyString(&_disconnectMessage, _messageLength, message);
_disconnectMessage = message;
}
string GetMessage()

View file

@ -167,8 +167,8 @@ void GameClientConnection::PushControllerState(uint8_t port, ControlDeviceState
void GameClientConnection::DisableControllers()
{
//Used to prevent deadlocks when client is trying to fill its buffer while the host changes the current game/settings/etc. (i.e situations where we need to call Console::Pause())
ClearInputData();
_enableControllers = false;
ClearInputData();
for(int i = 0; i < BaseControlDevice::PortCount; i++) {
_waitForInput[i].Signal();
}
@ -192,6 +192,10 @@ bool GameClientConnection::SetInput(BaseControlDevice *device)
}
LockHandler lock = _writeLock.AcquireSafe();
if(_shutdown || !_enableControllers || _inputSize[port] == 0) {
return true;
}
ControlDeviceState state = _inputData[port].front();
_inputData[port].pop_front();
_inputSize[port]--;

View file

@ -7,19 +7,15 @@
class GameInformationMessage : public NetMessage
{
private:
char* _romFilename = nullptr;
uint32_t _romFilenameLength = 0;
char _sha1Hash[40];
string _romFilename;
string _sha1Hash;
uint8_t _controllerPort = 0;
bool _paused = false;
protected:
virtual void ProtectedStreamState()
void Serialize(Serializer &s) override
{
StreamArray((void**)&_romFilename, _romFilenameLength);
StreamArray((void**)&_sha1Hash, 40);
Stream<uint8_t>(_controllerPort);
Stream<bool>(_paused);
s.Stream(_romFilename, _sha1Hash, _controllerPort, _paused);
}
public:
@ -27,8 +23,8 @@ public:
GameInformationMessage(string filepath, string sha1Hash, uint8_t port, bool paused) : NetMessage(MessageType::GameInformation)
{
CopyString(&_romFilename, _romFilenameLength, FolderUtilities::GetFilename(filepath, true));
memcpy(_sha1Hash, sha1Hash.c_str(), 40);
_romFilename = FolderUtilities::GetFilename(filepath, true);
_sha1Hash = sha1Hash;
_controllerPort = port;
_paused = paused;
}
@ -45,7 +41,7 @@ public:
string GetSha1Hash()
{
return string(_sha1Hash, _sha1Hash+40);
return _sha1Hash;
}
bool IsPaused()

View file

@ -9,20 +9,14 @@ private:
static constexpr int CurrentVersion = 100; //Use 100+ to distinguish from Mesen
uint32_t _emuVersion = 0;
uint32_t _protocolVersion = CurrentVersion;
char* _playerName = nullptr;
uint32_t _playerNameLength = 0;
char* _hashedPassword = nullptr;
uint32_t _hashedPasswordLength = 0;
string _playerName;
string _hashedPassword;
bool _spectator = false;
protected:
virtual void ProtectedStreamState()
void Serialize(Serializer &s) override
{
Stream<uint32_t>(_emuVersion);
Stream<uint32_t>(_protocolVersion);
StreamArray((void**)&_playerName, _playerNameLength);
StreamArray((void**)&_hashedPassword, _hashedPasswordLength);
Stream<bool>(_spectator);
s.Stream(_emuVersion, _protocolVersion, _playerName, _hashedPassword, _spectator);
}
public:
@ -32,14 +26,14 @@ public:
{
_emuVersion = emuVersion;
_protocolVersion = HandShakeMessage::CurrentVersion;
CopyString(&_playerName, _playerNameLength, playerName);
CopyString(&_hashedPassword, _hashedPasswordLength, hashedPassword);
_playerName = playerName;
_hashedPassword = hashedPassword;
_spectator = spectator;
}
string GetPlayerName()
{
return string(_playerName);
return _playerName;
}
bool IsValid(uint32_t emuVersion)

View file

@ -9,9 +9,9 @@ private:
ControlDeviceState _inputState;
protected:
virtual void ProtectedStreamState()
void Serialize(Serializer &s) override
{
StreamArray(_inputState.State);
s.StreamVector(_inputState.State);
}
public:

View file

@ -10,10 +10,10 @@ private:
ControlDeviceState _inputState;
protected:
virtual void ProtectedStreamState()
void Serialize(Serializer &s) override
{
Stream<uint8_t>(_portNumber);
StreamArray(_inputState.State);
s.Stream(_portNumber);
s.StreamVector(_inputState.State);
}
public:

View file

@ -1,115 +1,36 @@
#pragma once
#include "stdafx.h"
#include "MessageType.h"
#include "SaveStateManager.h"
#include "../Utilities/Socket.h"
#include "../Utilities/Serializer.h"
class NetMessage
{
protected:
MessageType _type;
bool _sending;
vector<uint8_t> _buffer;
uint32_t _position = 0;
vector<uint8_t*> _arraysToRelease;
vector<char*> _stringsToRelease;
template<typename T>
void Stream(T &value)
{
if(_sending) {
uint8_t* bytes = (uint8_t*)&value;
int typeSize = sizeof(T);
for(int i = 0; i < typeSize; i++) {
_buffer.push_back(bytes[i]);
}
} else {
memcpy(&value, _buffer.data()+_position, sizeof(T));
_position += sizeof(T);
}
}
void StreamArray(void* value, uint32_t length)
{
void* pointer = value;
uint32_t len = length;
StreamArray(&pointer, len);
}
void StreamArray(void** value, uint32_t &length)
{
if(_sending) {
Stream<uint32_t>(length);
uint8_t* bytes = (uint8_t*)(*value);
for(uint32_t i = 0, len = length; i < len; i++) {
_buffer.push_back(bytes[i]);
_position++;
}
} else {
Stream<uint32_t>(length);
if(*value == nullptr) {
*value = (void*)new uint8_t[length];
_arraysToRelease.push_back((uint8_t*)*value);
}
uint8_t* bytes = (uint8_t*)(*value);
for(uint32_t i = 0, len = length; i < len; i++) {
bytes[i] = _buffer[_position];
_position++;
}
}
}
void StreamArray(vector<uint8_t> &data)
{
uint32_t length = (uint32_t)data.size();
Stream<uint32_t>(length);
if(_sending) {
uint8_t* bytes = (uint8_t*)data.data();
for(uint32_t i = 0, len = length; i < len; i++) {
_buffer.push_back(bytes[i]);
_position++;
}
} else {
data.resize(length, 0);
uint8_t* bytes = (uint8_t*)data.data();
for(uint32_t i = 0, len = length; i < len; i++) {
bytes[i] = _buffer[_position];
_position++;
}
}
}
void StreamState()
{
Stream<MessageType>(_type);
ProtectedStreamState();
}
stringstream _receivedData;
NetMessage(MessageType type)
{
_type = type;
_sending = true;
}
NetMessage(void* buffer, uint32_t length)
{
_buffer.assign((uint8_t*)buffer, (uint8_t*)buffer + length);
_sending = false;
_type = (MessageType)((uint8_t*)buffer)[0];
_receivedData.write((char*)buffer + 1, length - 1);
}
public:
virtual ~NetMessage()
{
for(uint8_t *arrayPtr: _arraysToRelease) {
delete[] arrayPtr;
}
for(char *stringPtr: _stringsToRelease) {
delete[] stringPtr;
}
}
void Initialize()
{
StreamState();
Serializer s(_receivedData, SaveStateManager::FileFormatVersion);
Serialize(s);
}
MessageType GetType()
@ -119,21 +40,18 @@ public:
void Send(Socket &socket)
{
StreamState();
uint32_t messageLength = (uint32_t)_buffer.size();
Serializer s(SaveStateManager::FileFormatVersion);
Serialize(s);
_buffer.insert(_buffer.begin(), (char*)&messageLength, (char*)&messageLength + sizeof(messageLength));
socket.Send((char*)_buffer.data(), (int)_buffer.size(), 0);
}
stringstream out;
s.Save(out);
void CopyString(char** dest, uint32_t &length, string src)
{
length = (uint32_t)(src.length() + 1);
*dest = new char[length];
memcpy(*dest, src.c_str(), length);
_stringsToRelease.push_back(*dest);
string data = out.str();
uint32_t messageLength = (uint32_t)data.size() + 1;
data = string((char*)&messageLength, 4) + (char)_type + data;
socket.Send((char*)data.c_str(), (int)data.size(), 0);
}
protected:
virtual void ProtectedStreamState() = 0;
virtual void Serialize(Serializer &s) = 0;
};

View file

@ -8,42 +8,21 @@ private:
vector<PlayerInfo> _playerList;
protected:
virtual void ProtectedStreamState()
void Serialize(Serializer &s) override
{
constexpr uint32_t PlayerNameMaxLength = 50;
uint32_t nameLength = PlayerNameMaxLength + 1;
char playerName[PlayerNameMaxLength + 1];
uint8_t playerPort = 0;
bool isHost = false;
if(_sending) {
if(s.IsSaving()) {
uint32_t playerCount = (uint32_t)_playerList.size();
Stream<uint32_t>(playerCount);
s.Stream(playerCount);
for(uint32_t i = 0; i < playerCount; i++) {
memset(playerName, 0, nameLength);
memcpy(playerName, _playerList[i].Name.c_str(), std::min((uint32_t)_playerList[i].Name.size(), PlayerNameMaxLength));
playerPort = _playerList[i].ControllerPort;
StreamArray(playerName, nameLength);
Stream<uint8_t>(playerPort);
Stream<bool>(isHost);
s.Stream(_playerList[i].Name, _playerList[i].ControllerPort, _playerList[i].IsHost);
}
} else {
uint32_t playerCount;
Stream<uint32_t>(playerCount);
s.Stream(playerCount);
for(uint32_t i = 0; i < playerCount; i++) {
memset(playerName, 0, nameLength);
StreamArray(playerName, nameLength);
Stream<uint8_t>(playerPort);
Stream<bool>(isHost);
PlayerInfo playerInfo;
playerInfo.Name = playerName;
playerInfo.ControllerPort = playerPort;
playerInfo.IsHost = isHost;
s.Stream(playerInfo.Name, playerInfo.ControllerPort, playerInfo.IsHost);
_playerList.push_back(playerInfo);
}
}

View file

@ -10,9 +10,7 @@ class SaveStateMessage : public NetMessage
{
private:
vector<CheatCode> _activeCheats;
uint8_t* _stateData = nullptr;
uint32_t _dataSize = 0;
vector<uint8_t> _stateData;
ControllerType _controllerTypes[5];
ConsoleRegion _region;
@ -20,28 +18,13 @@ private:
uint32_t _ppuExtraScanlinesBeforeNmi;
uint32_t _gsuClockSpeed;
CheatCode* _cheats = nullptr;
uint32_t _cheatArraySize = 0;
protected:
virtual void ProtectedStreamState()
void Serialize(Serializer &s) override
{
StreamArray((void**)&_stateData, _dataSize);
Stream(_region);
Stream(_ppuExtraScanlinesAfterNmi);
Stream(_ppuExtraScanlinesBeforeNmi);
Stream(_gsuClockSpeed);
StreamArray(_controllerTypes, sizeof(ControllerType) * 5);
if(_sending) {
_cheats = _activeCheats.size() > 0 ? &_activeCheats[0] : nullptr;
_cheatArraySize = (uint32_t)_activeCheats.size() * sizeof(CheatCode);
StreamArray((void**)&_cheats, _cheatArraySize);
delete[] _stateData;
} else {
StreamArray((void**)&_cheats, _cheatArraySize);
}
s.StreamVector(_stateData);
s.Stream(_region, _ppuExtraScanlinesAfterNmi, _ppuExtraScanlinesBeforeNmi, _gsuClockSpeed);
s.StreamArray(_controllerTypes, 5);
s.StreamVector(_activeCheats);
}
public:
@ -68,22 +51,18 @@ public:
console->Unlock();
_dataSize = (uint32_t)state.tellp();
_stateData = new uint8_t[_dataSize];
state.read((char*)_stateData, _dataSize);
uint32_t dataSize = (uint32_t)state.tellp();
_stateData.resize(dataSize);
state.read((char*)_stateData.data(), dataSize);
}
void LoadState(shared_ptr<Console> console)
{
std::stringstream ss;
ss.write((char*)_stateData, _dataSize);
ss.write((char*)_stateData.data(), _stateData.size());
console->Deserialize(ss, SaveStateManager::FileFormatVersion);
vector<CheatCode> cheats;
for(uint32_t i = 0; i < _cheatArraySize / sizeof(CheatCode); i++) {
cheats.push_back(((CheatCode*)_cheats)[i]);
}
console->GetCheatManager()->SetCheats(cheats);
console->GetCheatManager()->SetCheats(_activeCheats);
EmulationConfig emuCfg = console->GetSettings()->GetEmulationConfig();
emuCfg.Region = _region;

View file

@ -8,9 +8,9 @@ private:
uint8_t _portNumber;
protected:
virtual void ProtectedStreamState()
void Serialize(Serializer &s) override
{
Stream<uint8_t>(_portNumber);
s.Stream(_portNumber);
}
public:

View file

@ -5,24 +5,23 @@
class ServerInformationMessage : public NetMessage
{
private:
char* _hashSalt = nullptr;
uint32_t _hashSaltLength = 0;
string _hashSalt;
protected:
virtual void ProtectedStreamState()
void Serialize(Serializer &s) override
{
StreamArray((void**)&_hashSalt, _hashSaltLength);
s.Stream(_hashSalt);
}
public:
ServerInformationMessage(void* buffer, uint32_t length) : NetMessage(buffer, length) {}
ServerInformationMessage(string hashSalt) : NetMessage(MessageType::ServerInformation)
{
CopyString(&_hashSalt, _hashSaltLength, hashSalt);
_hashSalt = hashSalt;
}
string GetHashSalt()
{
return string(_hashSalt);
return _hashSalt;
}
};

View file

@ -50,6 +50,7 @@ private:
template<typename T> void InternalStream(VectorInfo<T> &info);
template<typename T> void InternalStream(ValueInfo<T> &info);
template<typename T> void InternalStream(T &value);
template<> void InternalStream(string &str);
void RecursiveStream();
template<typename T, typename... T2> void RecursiveStream(T &value, T2&... args);
@ -151,6 +152,21 @@ void Serializer::InternalStream(ValueInfo<T> &info)
StreamElement<T>(*info.Value, info.DefaultValue);
}
template<>
void Serializer::InternalStream(string &str)
{
if(_saving) {
vector<uint8_t> stringData;
stringData.resize(str.size());
memcpy(stringData.data(), str.data(), str.size());
StreamVector(stringData);
} else {
vector<uint8_t> stringData;
StreamVector(stringData);
str = string(stringData.begin(), stringData.end());
}
}
template<typename T>
void Serializer::InternalStream(T &value)
{