refactor make the code actually somewhat bearable
This commit is contained in:
parent
32a6ba2c35
commit
da638441b7
4 changed files with 76 additions and 91 deletions
30
generate.py
Normal file
30
generate.py
Normal file
|
@@ -0,0 +1,30 @@
|
||||||
|
from aitextgen.TokenDataset import TokenDataset
|
||||||
|
from aitextgen.tokenizers import train_tokenizer
|
||||||
|
from aitextgen.utils import GPT2ConfigCPU
|
||||||
|
from aitextgen import aitextgen
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
|
||||||
|
with open('config.json', 'r') as file:
|
||||||
|
json_object = json.load(file)
|
||||||
|
|
||||||
|
file_name = json_object['file']
|
||||||
|
|
||||||
|
def generate_message(prompt=None):
|
||||||
|
ai = aitextgen(model_folder="trained_model",
|
||||||
|
tokenizer_file="aitextgen.tokenizer.json", prompt=prompt)
|
||||||
|
return ai.generate()
|
||||||
|
|
||||||
|
match sys.argv[1]:
|
||||||
|
case "prompt":
|
||||||
|
try:
|
||||||
|
msg = generate_message(' '.join(map(str, sys.argv[1:])))
|
||||||
|
print(msg)
|
||||||
|
except IndexError:
|
||||||
|
print(generate_message())
|
||||||
|
finally:
|
||||||
|
sys.stdout.flush()
|
||||||
|
|
||||||
|
case _:
|
||||||
|
msg = generate_message()
|
||||||
|
print(msg)
|
66
index.js
66
index.js
|
@@ -2,75 +2,59 @@ import config from './config.json' assert {type: "json"};
|
||||||
import { MatrixClient, SimpleFsStorageProvider, AutojoinRoomsMixin } from "matrix-bot-sdk";
import fs from "fs";
import { PythonShell } from 'python-shell';
// import { type } from 'os';

const storage = new SimpleFsStorageProvider("storage.json");
const client = new MatrixClient(config.homeserver, config.token, storage);
const pyFile = "generate.py";

AutojoinRoomsMixin.setupOnClient(client);
client.start().then(() => console.log(`Client has started!\n`));

const messageCounters = new Map(); // room ID, message count

// Run generate.py with the given argv and post its stdout to the room.
// Extracted because the "speak", "prompt" and frequency paths previously
// duplicated this PythonShell invocation verbatim.
function generateAndSend(roomId, args) {
    PythonShell.run(pyFile, { args: args }, (err, message) => {
        if (err) throw err;
        client.sendText(roomId, message.toString());
        console.log("Message sent!");
    });
}

client.on("room.message", (roomId, event) => {
    if (!event["content"] || event["sender"] === config.user) return; // ? ignore if message sent by bot itself or is empty

    let messageArray = event["content"]["body"].split(" ");

    if (messageArray[0] !== config.prefix) {
        // ? message does not start with the prefix: log it for training
        messageCounters.set(roomId, (messageCounters.get(roomId) ?? 0) + 1);
        console.log(`COUNTER:\t${messageCounters.get(roomId)}\t${roomId}\t${event["content"]["body"]}`);
        fs.appendFile(config.file, event["content"]["body"] + "\n", function (err) {
            if (err) throw err;
        });
        // BUG FIX: the early `return` previously skipped the frequency
        // check entirely, so "send a message every config.frequency
        // messages" could never fire — the bot only spoke when commanded.
        if (!(messageCounters.get(roomId) % config.frequency)) {
            console.log("Generating message...");
            generateAndSend(roomId, ["", ""]);
        }
        return;
    }

    messageArray.shift(); // ? remove bot's prefix from array
    // Pull the command word off so the remainder is pure prompt text —
    // previously the literal "prompt" word was joined into the seed text.
    const command = (messageArray.shift() ?? "").toLowerCase();

    if (command === "speak") {
        console.log("Generating message...");
        generateAndSend(roomId, ["", ""]);
    } else if (command === "prompt") {
        console.log("prompted to generate...");
        generateAndSend(roomId, ['prompt', messageArray.join(" ")]); // ? send prompted message to room
    } else {
        console.log("Invalid command");
    }
});
|
|
||||||
|
|
||||||
|
|
49
textgen.py
49
textgen.py
|
@@ -1,49 +0,0 @@
|
||||||
from aitextgen.TokenDataset import TokenDataset
|
|
||||||
from aitextgen.tokenizers import train_tokenizer
|
|
||||||
from aitextgen.utils import GPT2ConfigCPU
|
|
||||||
from aitextgen import aitextgen
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
|
|
||||||
with open('config.json', 'r') as file:
|
|
||||||
json_object = json.load(file)
|
|
||||||
|
|
||||||
file_name = json_object['file']
|
|
||||||
|
|
||||||
# ? generate message using trained model
|
|
||||||
|
|
||||||
|
|
||||||
def generate_message(prompt=None):
|
|
||||||
# ai = aitextgen(prompt=prompt)
|
|
||||||
ai = aitextgen(model_folder="trained_model",
|
|
||||||
tokenizer_file="aitextgen.tokenizer.json", prompt=prompt)
|
|
||||||
return ai.generate()
|
|
||||||
|
|
||||||
# ? train model using text file
|
|
||||||
|
|
||||||
|
|
||||||
def train_ai():
|
|
||||||
train_tokenizer(file_name)
|
|
||||||
tokenizer_file = "aitextgen.tokenizer.json"
|
|
||||||
config = GPT2ConfigCPU()
|
|
||||||
ai = aitextgen(tokenizer_file=tokenizer_file, config=config)
|
|
||||||
data = TokenDataset(
|
|
||||||
file_name, tokenizer_file=tokenizer_file, block_size=64)
|
|
||||||
ai.train(data, batch_size=8, num_steps=50000,
|
|
||||||
generate_every=5000, save_every=5000)
|
|
||||||
|
|
||||||
|
|
||||||
match sys.argv[1]:
|
|
||||||
case "generate":
|
|
||||||
# ? send message to parent JS process
|
|
||||||
try:
|
|
||||||
prompt = ' '.join(map(str, sys.argv[2:]))
|
|
||||||
generate_message(prompt)
|
|
||||||
except IndexError:
|
|
||||||
generate_message()
|
|
||||||
finally:
|
|
||||||
sys.stdout.flush()
|
|
||||||
case "train":
|
|
||||||
train_ai()
|
|
||||||
case _:
|
|
||||||
raise NameError("Argument not provided.")
|
|
20
train.py
Normal file
20
train.py
Normal file
|
@@ -0,0 +1,20 @@
|
||||||
|
# Train the chat model on the message log the bot has collected.
# Split out of the old textgen.py by this commit; run standalone.
from aitextgen.TokenDataset import TokenDataset
from aitextgen.tokenizers import train_tokenizer
from aitextgen.utils import GPT2ConfigCPU
from aitextgen import aitextgen
import json

# The bot's config names the file that index.js appends chat lines to.
with open('config.json', 'r') as file:
    json_object = json.load(file)

file_name = json_object['file']

# Fit a tokenizer to the chat log; it is saved as
# aitextgen.tokenizer.json in the working directory.
train_tokenizer(file_name)
tokenizer_file = "aitextgen.tokenizer.json"
config = GPT2ConfigCPU()  # small GPT-2 configuration suitable for CPU-only training

# Build the model, tokenize the corpus, and train. Checkpoints and sample
# generations happen every 5000 steps over a 50000-step run.
ai = aitextgen(tokenizer_file=tokenizer_file, config=config)
data = TokenDataset(
    file_name, tokenizer_file=tokenizer_file, block_size=64)
ai.train(data, batch_size=8, num_steps=50000,
         generate_every=5000, save_every=5000)
|
Loading…
Add table
Reference in a new issue