check arg passed
This commit is contained in:
parent
8c6ad6d5a3
commit
1b4b769b7a
1 changed files with 21 additions and 8 deletions
29
textgen.py
29
textgen.py
|
@ -2,7 +2,8 @@ from aitextgen.TokenDataset import TokenDataset
|
|||
from aitextgen.tokenizers import train_tokenizer
|
||||
from aitextgen.utils import GPT2ConfigCPU
|
||||
from aitextgen import aitextgen
|
||||
import json, sys
|
||||
import json
|
||||
import sys
|
||||
|
||||
with open('config.json', 'r') as file:
|
||||
json_object = json.load(file)
|
||||
|
@ -10,21 +11,33 @@ with open('config.json', 'r') as file:
|
|||
file_name = json_object['file']
|
||||
|
||||
# ? generate message using trained model
|
||||
|
||||
|
||||
def generate_message():
|
||||
ai = aitextgen(model_folder="trained_model",
|
||||
tokenizer_file="aitextgen.tokenizer.json")
|
||||
tokenizer_file="aitextgen.tokenizer.json")
|
||||
ai.generate()
|
||||
|
||||
# ? train model using text file
|
||||
|
||||
|
||||
def train_ai():
|
||||
train_tokenizer(file_name)
|
||||
tokenizer_file = "aitextgen.tokenizer.json"
|
||||
config = GPT2ConfigCPU()
|
||||
ai = aitextgen(tokenizer_file=tokenizer_file, config=config)
|
||||
data = TokenDataset(file_name, tokenizer_file=tokenizer_file, block_size=64)
|
||||
ai.train(data, batch_size=8, num_steps=50000, generate_every=5000, save_every=5000)
|
||||
print("AI has been trained!")
|
||||
data = TokenDataset(
|
||||
file_name, tokenizer_file=tokenizer_file, block_size=64)
|
||||
ai.train(data, batch_size=8, num_steps=50000,
|
||||
generate_every=5000, save_every=5000)
|
||||
|
||||
# ? send message to parent JS process
|
||||
print(generate_message())
|
||||
sys.stdout.flush()
|
||||
|
||||
match sys.argv[1]:
|
||||
case "generate":
|
||||
# ? send message to parent JS process
|
||||
print(generate_message())
|
||||
sys.stdout.flush()
|
||||
case "train":
|
||||
train_ai()
|
||||
case _:
|
||||
raise NameError("Argument not provided.")
|
||||
|
|
Loading…
Add table
Reference in a new issue