from aitextgen.TokenDataset import TokenDataset
from aitextgen.tokenizers import train_tokenizer
from aitextgen.utils import GPT2ConfigCPU
from aitextgen import aitextgen
import json
import sys

# Load the training-corpus path from the local JSON config file.
with open('config.json') as config_fp:
    config = json.load(config_fp)

# Path of the text file used for tokenizer training and the TokenDataset.
file_name = config['file']

# ? generate message using trained model


def generate_message(prompt):
    """Generate a message from the trained model, seeded with *prompt*.

    Loads the model saved in ``trained_model/`` along with its tokenizer,
    and returns the generated text as a string so the caller can print it.

    Args:
        prompt: Seed text the model continues from.

    Returns:
        The generated text (str).
    """
    # NOTE: `prompt` must be passed to the generation call, not the
    # constructor — the original code passed it to aitextgen(), where it
    # was not used, and returned None (ai.generate() only prints).
    ai = aitextgen(model_folder="trained_model",
                   tokenizer_file="aitextgen.tokenizer.json")
    # generate_one() returns the generated string instead of printing it.
    return ai.generate_one(prompt=prompt)

# ? train model using text file


def train_ai():
    """Train a small CPU-friendly GPT-2 model on the configured text file.

    Trains a byte-pair tokenizer on ``file_name``, builds a CPU-sized GPT-2
    config, wraps the corpus in a TokenDataset, and runs 50k training steps,
    generating a sample and checkpointing every 5k steps.
    """
    tokenizer_path = "aitextgen.tokenizer.json"
    train_tokenizer(file_name)

    model = aitextgen(tokenizer_file=tokenizer_path, config=GPT2ConfigCPU())
    dataset = TokenDataset(file_name,
                           tokenizer_file=tokenizer_path,
                           block_size=64)

    model.train(dataset,
                batch_size=8,
                num_steps=50000,
                generate_every=5000,
                save_every=5000)


match sys.argv[1]:
    case "generate":
        # ? send message to parent JS process
        print(generate_message(sys.argv[2]))
        sys.stdout.flush()
    case "train":
        train_ai()
    case _:
        raise NameError("Argument not provided.")