created 2 functions to gen text and train
This commit is contained in:
parent
416fc0dd44
commit
2fcda33cf8
1 changed files with 26 additions and 0 deletions
26
textgen.py
26
textgen.py
|
@ -0,0 +1,26 @@
|
||||||
|
from aitextgen.TokenDataset import TokenDataset
|
||||||
|
from aitextgen.tokenizers import train_tokenizer
|
||||||
|
from aitextgen.utils import GPT2ConfigCPU
|
||||||
|
from aitextgen import aitextgen
|
||||||
|
import json
|
||||||
|
|
||||||
|
with open('config.json', 'r') as file:
|
||||||
|
json_object = json.load(file)
|
||||||
|
|
||||||
|
file_name = json_object['file']
|
||||||
|
|
||||||
|
def generate_message():
|
||||||
|
ai = aitextgen(model_folder="trained_model",
|
||||||
|
tokenizer_file="aitextgen.tokenizer.json")
|
||||||
|
ai.generate()
|
||||||
|
|
||||||
|
def train_ai():
|
||||||
|
train_tokenizer(file_name)
|
||||||
|
tokenizer_file = "aitextgen.tokenizer.json"
|
||||||
|
config = GPT2ConfigCPU()
|
||||||
|
ai = aitextgen(tokenizer_file=tokenizer_file, config=config)
|
||||||
|
data = TokenDataset(file_name, tokenizer_file=tokenizer_file, block_size=64)
|
||||||
|
ai.train(data, batch_size=8, num_steps=50000, generate_every=5000, save_every=5000)
|
||||||
|
print("AI has been trained!")
|
||||||
|
|
||||||
|
print(generate_message())
|
Loading…
Add table
Reference in a new issue