from aitextgen.TokenDataset import TokenDataset
from aitextgen.tokenizers import train_tokenizer
from aitextgen.utils import GPT2ConfigCPU
from aitextgen import aitextgen
import json
import sys

# Load the run configuration from config.json in the current working directory.
# JSON text is UTF-8 by specification, so decode it explicitly instead of
# relying on the platform's default locale encoding (e.g. cp1252 on Windows).
with open('config.json', 'r', encoding='utf-8') as file:
    json_object = json.load(file)

# Value of the required "file" key (json_object['file'] raises KeyError if it is
# missing). NOTE(review): not referenced by the generation code in this script;
# presumably the training-data file name — confirm before relying on it.
file_name = json_object['file']

def generate_message(prompt=None):
    """Generate one piece of text with the locally trained aitextgen model.

    The model is loaded from the ``trained_model`` folder together with the
    ``aitextgen.tokenizer.json`` tokenizer file (both relative to the current
    working directory) on every call.

    Args:
        prompt: Optional text to seed generation with. ``None`` or an empty
            string means unconditional generation.

    Returns:
        The generated text as a string.
    """
    ai = aitextgen(model_folder="trained_model",
                   tokenizer_file="aitextgen.tokenizer.json")
    # The prompt is an argument of generate(), not of the constructor (where
    # it was previously passed and silently ignored). generate() expects a
    # string prompt, with "" meaning "no prompt", hence `prompt or ""`.
    # Without return_as_list=True, generate() prints the text itself and
    # returns None, so request the list and hand back its single entry.
    return ai.generate(n=1, prompt=prompt or "", return_as_list=True)[0]

match sys.argv[1]: 
    case "prompt":
        try:
            msg = generate_message(' '.join(map(str, sys.argv[1:])))
            print(msg)
        except IndexError:
            print(generate_message())
        finally:
            sys.stdout.flush()

    case _:
        msg = generate_message()
        print(msg)