RetroArch/deps/game_ai_lib/RetroModel.cpp
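/*
 * RetroModel.cpp
 *
 * PyTorch (TorchScript) backend for the GameAI library: loads a trained
 * model and runs inference either on a flat observation vector or on a
 * stack of downsampled framebuffer frames, producing the action values
 * used to override player input.
 */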

#include "RetroModel.h"

#include <iostream>

//=======================================================
// RetroModelPytorch::LoadModel
//=======================================================
void RetroModelPytorch::LoadModel(std::string path)
{
   try {
      /* Deserialize the TorchScript module from file */
      this->module = torch::jit::load(path);
      std::cerr << "Loaded model: " << path << std::endl;
   }
   catch (const c10::Error& e) {
      std::cerr << "Error loading the model: " << path << " (" << e.what() << ")" << std::endl;
      return;
   }
}
//=======================================================
// RetroModelPytorch::Forward
//=======================================================
void RetroModelPytorch::Forward(std::vector<float> & output, const std::vector<float> & input)
{
   std::vector<torch::jit::IValue> inputs;

   /* Copy the observation vector into a 1 x N float tensor */
   at::Tensor tmp = torch::zeros({1, (int64_t)input.size()});
   for (size_t i = 0; i < input.size(); i++)
   {
      tmp[0][i] = input[i];
   }
   inputs.push_back(tmp);

   /* Run the model; the first element of the returned tuple holds the action values */
   at::Tensor result = module.forward(inputs).toTuple()->elements()[0].toTensor();
   for (size_t i = 0; i < output.size(); i++)
   {
      output[i] = result[0][i].item<float>();
   }
}
//=======================================================
// RetroModelPytorch::Forward
//=======================================================
void RetroModelPytorch::Forward(std::vector<float> & output, RetroModelFrameData & input)
{
   std::vector<torch::jit::IValue> inputs;

   /* Wrap the raw framebuffer (16 bits per pixel) in a cv::Mat header, no copy */
   cv::Mat image(cv::Size(input.width, input.height), CV_8UC2, input.data);
   cv::Mat gray;
   cv::Mat result;

   /* Add a new frame on the stack */
   cv::Mat * newFrame = input.PushNewFrameOnStack();

   /* Convert RGB565 to greyscale and downsample to 84x84 */
   cv::cvtColor(image, gray, cv::COLOR_BGR5652GRAY);
   cv::resize(gray, result, cv::Size(84, 84), 0, 0, cv::INTER_AREA);
   result.copyTo(*newFrame);

   /*cv::namedWindow("Display Image", cv::WINDOW_NORMAL);
   cv::imshow("Display Image", result);
   cv::waitKey(0);*/

   /* Stack the four buffered 84x84 frames into a 1x4x84x84 tensor
      (stack[0] maps to channel 3, stack[3] to channel 0) */
   at::Tensor tmp = torch::ones({1, 4, 84, 84});
   for (auto i : {0, 1, 2, 3})
   {
      if (input.stack[i]->data)
         tmp[0][3 - i] = torch::from_blob(input.stack[i]->data, { result.rows, result.cols }, at::kByte);
   }
   inputs.push_back(tmp);

   /* Execute the model and turn its output into a tensor */
   torch::jit::IValue ret = module.forward(inputs);
   at::Tensor actions = ret.toTuple()->elements()[0].toTensor();
   for (size_t i = 0; i < output.size(); i++)
   {
      output[i] = actions[0][i].item<float>();
   }
}
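
/*
 * Illustrative usage sketch only (not part of the library). The file name
 * and the size of the output vector are placeholders; the real caller lives
 * in the GameAI library integration code.
 *
 *    RetroModelPytorch model;
 *    model.LoadModel("model.pt");
 *
 *    std::vector<float> actions(16, 0.0f);  // one slot per supported button
 *    model.Forward(actions, frameData);     // frameData: RetroModelFrameData
 *                                           // wrapping the current framebuffer
 */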