# (file metadata: 49 lines, 1.3 KiB, Python)
from aitextgen.TokenDataset import TokenDataset
from aitextgen.tokenizers import train_tokenizer
from aitextgen.utils import GPT2ConfigCPU
from aitextgen import aitextgen
import json
import sys
with open('config.json', 'r') as file:
|
|
json_object = json.load(file)
|
|
|
|
file_name = json_object['file']
|
|
|
|
# ? generate message using trained model
|
|
|
|
|
|
def generate_message(prompt=None):
|
|
# ai = aitextgen(prompt=prompt)
|
|
ai = aitextgen(model_folder="trained_model",
|
|
tokenizer_file="aitextgen.tokenizer.json", prompt=prompt)
|
|
return ai.generate()
|
|
|
|
# ? train model using text file
|
|
|
|
|
|
def train_ai():
|
|
train_tokenizer(file_name)
|
|
tokenizer_file = "aitextgen.tokenizer.json"
|
|
config = GPT2ConfigCPU()
|
|
ai = aitextgen(tokenizer_file=tokenizer_file, config=config)
|
|
data = TokenDataset(
|
|
file_name, tokenizer_file=tokenizer_file, block_size=64)
|
|
ai.train(data, batch_size=8, num_steps=50000,
|
|
generate_every=5000, save_every=5000)
|
|
|
|
|
|
match sys.argv[1]:
|
|
case "generate":
|
|
# ? send message to parent JS process
|
|
try:
|
|
prompt = ' '.join(map(str, sys.argv[2:]))
|
|
generate_message(prompt)
|
|
except IndexError:
|
|
generate_message()
|
|
finally:
|
|
sys.stdout.flush()
|
|
case "train":
|
|
train_ai()
|
|
case _:
|
|
raise NameError("Argument not provided.")
|