# toolbox/training/harubaru_convogpt/dataset.py

import os
import struct
import torch
import argparse
import numpy as np
import transformers
import json
import tqdm
from typing import Tuple


def decode(in_file: str, out_file: str, tokenizer: transformers.AutoTokenizer) -> int:
    """Decode a uint16 token dump back to text, appending the text to `out_file`."""
    mem = np.memmap(in_file, mode="r", dtype="uint16")
    tokens = len(mem)
    with open(out_file, "a") as f:
        for token in mem:
            f.write(tokenizer.decode([token]))
    return tokens


def encode(in_file: str, out_file: str, tokenizer: transformers.AutoTokenizer) -> int:
    """Tokenize a text file and write the token ids to `out_file` as raw uint16 values."""
    with open(in_file, "r", encoding="utf-8") as f:
        text = f.read()
    tokens = tokenizer.encode(text)
    with open(out_file, "wb") as f:
        for token in tokens:
            f.write(np.uint16(token))
    return len(tokens)
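

# Example usage of encode()/decode() (a rough sketch; "corpus.txt", "corpus.bin",
# "roundtrip.txt", and "gpt2" are illustrative names, not part of this repo):
#
#     tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
#     n_written = encode("corpus.txt", "corpus.bin", tokenizer)    # text -> uint16 ids
#     n_read = decode("corpus.bin", "roundtrip.txt", tokenizer)    # ids -> text (appends)
#     assert n_written == n_read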


class TokenizedDataset(torch.utils.data.Dataset):
    """
    Consumes a flat binary file containing 16-bit token serialization, aligned
    along `context_length` chunks.
    """
    def __init__(self, path: str, context_length: int = 2048):
        file_stat = os.stat(path)
        self.file = open(path, 'rb')
        # each token is stored as 2 bytes (uint16)
        self.length = file_stat.st_size // (2 * context_length)
        self.formatstr = '%sH' % context_length
        self.context_length = context_length
        length_mb = file_stat.st_size / 1024.0 / 1024.0
        num_tokens = self.length * context_length
        print(f"DATASET: {path}")
        print(f"DATASET SIZE: {length_mb:,.2f}mb, {num_tokens:,} tokens, "
              f"{self.length:,} contexts")

    def __len__(self) -> int:
        return self.length

    def load(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        self.seek(idx)
        input_ids = torch.tensor(
            struct.unpack(self.formatstr,
                          self.file.read(self.context_length * 2)))
        mask = torch.zeros(self.context_length)
        return input_ids, mask

    def seek(self, idx):
        self.file.seek(self.context_length * idx * 2)

    def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.load(idx)
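

# Example of feeding TokenizedDataset to a DataLoader (a sketch; the path and batch
# size are assumptions, and the open file handle means the default num_workers=0
# is assumed):
#
#     dataset = TokenizedDataset("corpus.bin", context_length=2048)
#     loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
#     input_ids, mask = next(iter(loader))   # input_ids: (4, 2048), mask: (4, 2048) zeros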


class FeedbackDataset(torch.utils.data.Dataset):
    def __init__(self, feedback_file: str, tokenizer: transformers.AutoTokenizer, max_length: int = 512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.feedback_file = feedback_file
        with open(feedback_file) as f:
            self.feedback = [json.loads(line) for line in f]

    def __len__(self):
        return len(self.feedback)

    def __getitem__(self, idx):
        feedback = self.feedback[idx]
        # keep only the last two lines of the prompt (final user turn + model prefix)
        feedback_input = '\n'.join(feedback["input"].split("\n")[-2:])
        feedback_str = f'{feedback_input} {feedback["output"].strip()}'
        seq = self.tokenizer(
            feedback_str,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt"
        )
        reward = torch.tensor([feedback["reward"]]).unsqueeze(0)
        return seq, reward
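

# Example usage of FeedbackDataset (a sketch; "feedback.jsonl" and "gpt2" are
# illustrative, and a GPT-2 style tokenizer needs a pad token set before
# padding="max_length" will work):
#
#     tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
#     tokenizer.pad_token = tokenizer.eos_token
#     ds = FeedbackDataset("feedback.jsonl", tokenizer, max_length=512)
#     seq, reward = ds[0]   # seq.input_ids: (1, 512), reward: (1, 1)
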
# sft file example
# {
# "input": "Anonymous: Hi, how are you?\nGPT:",
# "output": " I'm good, how are you?\n",
# "reward": 0.0
# }
class SFTDataset(torch.utils.data.Dataset):
    def __init__(self, sft_file: str, tokenizer: transformers.AutoTokenizer, max_length: int = 2048):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.sft_file = sft_file
        with open(sft_file) as f:
            self.sft = [json.loads(line) for line in f]
        # drop any examples that have a reward of 0
        self.sft = [sft for sft in self.sft if sft["reward"] != 0.0]
        # drop any examples whose prompt + output exceed max_length tokens
        kept = []
        for feedback in tqdm.tqdm(self.sft, desc="Validating SFT"):
            inputs = feedback["input"] + f' {feedback["output"].strip()}\n'
            if len(self.tokenizer(inputs).input_ids) > self.max_length:
                print(f"Removed {feedback['output']} due to length")
            else:
                kept.append(feedback)
        self.sft = kept

    def __len__(self):
        return len(self.sft)

    def __getitem__(self, idx):
        sft = self.sft[idx]
        sft_input_tokens = self.tokenizer(sft["input"], return_tensors="pt").input_ids
        sft_output_tokens = self.tokenizer(f' {sft["output"].strip()}\n', return_tensors="pt").input_ids
        input_ids = torch.cat([sft_input_tokens, sft_output_tokens], dim=-1)
        # start/end_positions bracket the output tokens within the concatenated sequence
        start_positions = torch.tensor([len(sft_input_tokens[0])])
        end_positions = torch.tensor([len(sft_input_tokens[0]) + len(sft_output_tokens[0]) - 1])
        return {
            "input_ids": input_ids,
            "start_positions": start_positions,
            "end_positions": end_positions,
        }
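

# Example usage of SFTDataset, following the sft file example above (a sketch;
# "sft.jsonl" is an illustrative path and `tokenizer` any HF tokenizer):
#
#     ds = SFTDataset("sft.jsonl", tokenizer, max_length=2048)
#     item = ds[0]
#     # item["input_ids"] holds prompt + completion; start/end_positions bracket
#     # the completion tokens within that sequence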


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Dataset Creator')
    parser.add_argument('--in_file', type=str, help='input file to use', required=True)
    parser.add_argument('--out_file', type=str, help='output file to use', required=True)
    parser.add_argument('--model', type=str, help='model tokenizer to use', required=True)
    args = parser.parse_args()

    encode(args.in_file, args.out_file, transformers.AutoTokenizer.from_pretrained(args.model))
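

# Example invocation (file names and model are illustrative):
#
#     python dataset.py --in_file corpus.txt --out_file corpus.bin --model gpt2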