RetroArch/deps/game_ai_lib/RetroModel.h
Mathieu Poliquin 66e23fca79
New feature: Override player input with machine learning models (#17407)
* Add dummy game ai subsystem

* First working prototype of a machine learning model that can override player input

* Update README.md

* Update README.md

* Fix loading path on Windows

* Change ai override to player 2

* Added quick menu show game ai option

* Implemented Quick Menu entry for Game AI options

* Redirect debug logs to retroarch log system + properly support player override

* Added support to use framebuffer as input to the AI

* Added pixel format parameter to API

* Fix game name

* code clean-up of game_ai.cpp

* Update README.md - Windows Build

* Update README.md

* Update README.md

* Update README.md

* Update config.params.sh

turn off GAME_AI feature by default

* Fix compile error in menu_displaylist.c

* Add missing #define in menu_cbs_title.c

* Added new game_ai entry in griffin_cpp

* Remove GAME_AI entry in msg_hash_us.c

* Fix compile error in menu_displaylist.h

* Removed GAME AI references from README.md

* Fixes coding style + add GameAI lib API header

* Convert comment to legacy + remove unused code

* Additional coding style fixes to game_ai.cpp

* Fix indentation issues in game_ai.cpp

* Removed some debug code in game_ai.cpp

* Add game_ai_lib in deps

* Replace assert with retro_assert

* Update Makefile.common

* Converting game_ai from cpp to c. First step.

* Convert game_ai from CPP to C. STEP 2: add C function calls

* Convert game_ai from CPP to C. Final Step

* Added shutdown function for game ai lib

* Update game_ai_lib README

* Fix crash when loading/unloading multiple games
2025-01-21 13:05:43 +01:00

69 lines
1.6 KiB
C++

#pragma once
#include <torch/script.h>
#include <opencv2/opencv.hpp>
#include <bitset>
#include <string>
#include <filesystem>
#include <vector>
#include <queue>
class RetroModelFrameData
{
public:
RetroModelFrameData(): data(nullptr)
{
stack[0] = new cv::Mat;
stack[1] = new cv::Mat;
stack[2] = new cv::Mat;
stack[3] = new cv::Mat;
}
~RetroModelFrameData()
{
if(stack[0]) delete stack[0];
if(stack[1]) delete stack[1];
if(stack[2]) delete stack[2];
if(stack[3]) delete stack[3];
}
cv::Mat * PushNewFrameOnStack()
{
//push everything down
cv::Mat * tmp = stack[3];
stack[3] = stack[2];
stack[2] = stack[1];
stack[1] = stack[0];
stack[0] = tmp;
return stack[0];
}
void *data;
unsigned int width;
unsigned int height;
unsigned int pitch;
unsigned int format;
cv::Mat * stack[4];
};
// Only used by reference in this interface, so a declaration is sufficient.
class RetroModelFrameData;

/**
 * Abstract interface for a model that produces a vector of floats from either
 * a vector of floats or a RetroModelFrameData.
 *
 * Implementations are expected to be used through RetroModel pointers or
 * references, so the destructor is virtual: deleting a derived model through a
 * RetroModel* then runs the derived destructor instead of being undefined
 * behaviour.
 */
class RetroModel {
public:
    virtual ~RetroModel() = default;

    /**
     * Runs the model on a flat vector of floats.
     * @param output Receives the model result.
     * @param input  Values fed to the model.
     */
    virtual void Forward(std::vector<float> & output, const std::vector<float> & input)=0;

    /**
     * Runs the model on frame data.
     * @param output Receives the model result.
     * @param input  Frame data fed to the model; passed as non-const, so an
     *               implementation may modify it.
     */
    virtual void Forward(std::vector<float> & output, RetroModelFrameData & input)=0;
};
class RetroModelPytorch : public RetroModel {
public:
virtual void LoadModel(std::string);
virtual void Forward(std::vector<float> & output, const std::vector<float> & input);
virtual void Forward(std::vector<float> & output, RetroModelFrameData & input);
private:
torch::jit::script::Module module;
};