diff --git a/Core/SaveState.cpp b/Core/SaveState.cpp index 659066612e..16708ed54c 100644 --- a/Core/SaveState.cpp +++ b/Core/SaveState.cpp @@ -103,9 +103,14 @@ namespace SaveState } struct StateBuffer { + void Clear() { + zstd_compressed.clear(); + decompressed_size = 0; + compressed_size = 0; + } std::vector zstd_compressed; - size_t decompressed_size; - size_t compressed_size; + size_t decompressed_size = 0; + size_t compressed_size = 0; }; // This ring buffer of states is for rewind save states, which are kept in RAM. @@ -196,6 +201,8 @@ namespace SaveState }); } + const bool USE_XOR = false; + void Compress(StateBuffer *result, std::vector &state, const std::vector &base) { std::lock_guard guard(lock_); @@ -205,26 +212,36 @@ namespace SaveState double start_time = time_now_d(); std::vector compressed; - compressed.reserve(512 * 1024); - for (size_t i = 0; i < state.size(); i += BLOCK_SIZE) - { - int blockSize = std::min(BLOCK_SIZE, (int)(state.size() - i)); - if (i + blockSize > base.size() || memcmp(&state[i], &base[i], blockSize) != 0) - { - compressed.push_back(1); - compressed.insert(compressed.end(), state.begin() + i, state.begin() + i + blockSize); + if (USE_XOR) { + compressed.resize(state.size()); + for (size_t i = 0; i < state.size(); i++) { + if (i >= base.size()) { + compressed[i] = state[i]; + } else { + compressed[i] = base[i] ^ state[i]; + } + } + } else { + compressed.reserve(512 * 1024); + for (size_t i = 0; i < state.size(); i += BLOCK_SIZE) + { + int blockSize = std::min(BLOCK_SIZE, (int)(state.size() - i)); + if (i + blockSize > base.size() || memcmp(&state[i], &base[i], blockSize) != 0) + { + compressed.push_back(1); + compressed.insert(compressed.end(), state.begin() + i, state.begin() + i + blockSize); + } else { + compressed.push_back(0); + } } - else - compressed.push_back(0); } double taken_s = time_now_d() - start_time; DEBUG_LOG(SAVESTATE, "Rewind: Compressed save from %d bytes to %d in %0.2f ms.", (int)state.size(), (int)compressed.size(), taken_s * 1000.0); - // Temporarily allocate a buffer to do decompression in. + // Temporarily allocate a buffer to do compression in. size_t compressCapacity = ZSTD_compressBound(compressed.size()); u8 *compress_buf = (u8 *)malloc(compressCapacity); - result->compressed_size = ZSTD_compress(compress_buf, compressCapacity, compressed.data(), compressed.size(), 0); if (result->compressed_size) { result->zstd_compressed = std::vector(result->compressed_size, 0); @@ -247,25 +264,35 @@ namespace SaveState std::vector compressed = std::vector(buffer.decompressed_size, 0); ZSTD_decompress(&compressed[0], compressed.size(), buffer.zstd_compressed.data(), buffer.zstd_compressed.size()); - for (size_t i = 0; i < compressed.size(); ) - { - if (compressed[i] == 0) - { - ++i; - int blockSize = std::min(BLOCK_SIZE, (int)(base.size() - result.size())); - result.insert(result.end(), basePos, basePos + blockSize); - basePos += blockSize; + if (USE_XOR) { + result.resize(compressed.size()); + for (size_t i = 0; i < compressed.size(); i++) { + if (i < base.size()) { + result[i] = compressed[i] ^ base[i]; + } else { + result[i] = compressed[i]; + } } - else - { - ++i; - int blockSize = std::min(BLOCK_SIZE, (int)(compressed.size() - i)); - result.insert(result.end(), compressed.begin() + i, compressed.begin() + i + blockSize); - i += blockSize; - // This check is to avoid advancing basePos out of range, which MSVC catches. - // When this happens, we're at the end of decoding anyway. - if (base.end() - basePos >= blockSize) { + } else { + for (size_t i = 0; i < compressed.size(); ) { + if (compressed[i] == 0) { + ++i; + int blockSize = std::min(BLOCK_SIZE, (int)(base.size() - result.size())); + _dbg_assert_(blockSize >= 0); + result.insert(result.end(), basePos, basePos + blockSize); basePos += blockSize; + } else { + ++i; + int blockSize = std::min(BLOCK_SIZE, (int)(compressed.size() - i)); + if (blockSize > 0) { + result.insert(result.end(), compressed.begin() + i, compressed.begin() + i + blockSize); + i += blockSize; + // This check is to avoid advancing basePos out of range, which MSVC catches. + // When this happens, we're at the end of decoding anyway. + if (base.end() - basePos >= blockSize) { + basePos += blockSize; + } + } } } } @@ -286,7 +313,7 @@ namespace SaveState baseMapping_.clear(); baseMapping_.resize(size_); for (auto &s : states_) { - s.clear(); + s.Clear(); } buffer_.clear(); base_ = -1;