# toolbox/training/harubaru_convogpt/dataset.py

import argparse
import json
import os
import struct
from typing import Tuple

import numpy as np
import torch
import tqdm
import transformers

def decode(in_file: str, out_file: str, tokenizer: transformers.AutoTokenizer) -> int:
    """Decode a flat binary file of uint16 token ids, appending the text to out_file."""
    mem = np.memmap(in_file, mode="r", dtype="uint16")
    tokens = len(mem)
    with open(out_file, "a", encoding="utf-8") as f:
        for token in mem:
            f.write(tokenizer.decode([token]))
    return tokens

def encode(in_file: str, out_file: str, tokenizer: transformers.AutoTokenizer) -> int:
    """Tokenize a UTF-8 text file and serialize the ids as native-endian uint16."""
    with open(in_file, "r", encoding="utf-8") as f:
        text = f.read()
    tokens = tokenizer.encode(text)
    with open(out_file, "wb") as f:
        for token in tokens:
            # Assumes the vocabulary fits in 16 bits (ids < 65536).
            f.write(np.uint16(token))
    return len(tokens)

class TokenizedDataset(torch.utils.data.Dataset):
    """
    Consumes a flat binary file containing 16-bit token serialization, aligned
    along `context_length` chunks.
    """

    def __init__(self, path: str, context_length: int = 2048):
        file_stat = os.stat(path)
        self.file = open(path, 'rb')
        # Each token is 2 bytes; any trailing partial chunk is dropped.
        self.length = file_stat.st_size // (2 * context_length)
        self.formatstr = f'{context_length}H'
        self.context_length = context_length
        length_mb = file_stat.st_size / 1024.0 / 1024.0
        num_tokens = self.length * context_length
        print(f"DATASET: {path}")
        print(f"DATASET SIZE: {length_mb:,.2f} MiB, {num_tokens:,} tokens, "
              f"{self.length:,} contexts")
    def __len__(self) -> int:
        return self.length

    def load(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        self.seek(idx)
        input_ids = torch.tensor(
            struct.unpack(self.formatstr,
                          self.file.read(self.context_length * 2)))
        # An all-zero mask is returned alongside the token ids.
        mask = torch.zeros(self.context_length)
        return input_ids, mask

    def seek(self, idx):
        # Chunks are context_length tokens of 2 bytes each.
        self.file.seek(self.context_length * idx * 2)

    def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.load(idx)
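
# Usage sketch (illustrative, not part of the original module): encode() writes
# the flat uint16 binary that TokenizedDataset consumes, and a DataLoader can
# then stream fixed-length contexts from it. The file names and the "gpt2"
# tokenizer below are assumptions, not values used by this repo.
#
#   tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
#   encode("corpus.txt", "corpus.bin", tokenizer)
#   dataset = TokenizedDataset("corpus.bin", context_length=2048)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=8)
#   input_ids, mask = next(iter(loader))  # input_ids shape: (8, 2048)
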

class FeedbackDataset(torch.utils.data.Dataset):
    def __init__(self, feedback_file: str, tokenizer: transformers.AutoTokenizer, max_length: int = 512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.feedback_file = feedback_file
        # One JSON record per line (JSONL).
        with open(feedback_file) as f:
            self.feedback = [json.loads(line) for line in f]

    def __len__(self):
        return len(self.feedback)

    def __getitem__(self, idx):
        feedback = self.feedback[idx]
        # Keep only the last two lines of the conversation context.
        feedback_input = '\n'.join(feedback["input"].split("\n")[-2:])
        feedback_str = f'{feedback_input} {feedback["output"].strip()}'
        seq = self.tokenizer(
            feedback_str,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt"
        )
        reward = torch.tensor([feedback["reward"]]).unsqueeze(0)
        return seq, reward

# sft file example
# {
# "input": "Anonymous: Hi, how are you?\nGPT:",
# "output": " I'm good, how are you?\n",
# "reward": 0.0
# }
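
# A record like the example above could be produced and consumed as follows
# (hedged sketch; "sft.jsonl" is an assumed file name, and tokenizer is any
# transformers tokenizer instance). SFTDataset below drops records whose
# reward is 0.0, so a nonzero reward is used here.
#
#   with open("sft.jsonl", "a") as f:
#       record = {
#           "input": "Anonymous: Hi, how are you?\nGPT:",
#           "output": " I'm good, how are you?\n",
#           "reward": 1.0,
#       }
#       f.write(json.dumps(record) + "\n")
#
#   feedback_dataset = FeedbackDataset("sft.jsonl", tokenizer, max_length=512)
#   seq, reward = feedback_dataset[0]  # seq["input_ids"] shape: (1, 512)
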

class SFTDataset(torch.utils.data.Dataset):
    def __init__(self, sft_file: str, tokenizer: transformers.AutoTokenizer, max_length: int = 2048):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.sft_file = sft_file
        with open(sft_file) as f:
            self.sft = [json.loads(line) for line in f]
        # Drop any records that have a reward of 0.
        self.sft = [sft for sft in self.sft if sft["reward"] != 0.0]
        # Drop any records that have too many tokens. Filtering into a new
        # list avoids mutating self.sft while iterating over it, which would
        # silently skip the record after each removal.
        kept = []
        for feedback in tqdm.tqdm(self.sft, desc="Validating SFT"):
            inputs = feedback["input"] + f' {feedback["output"].strip()}\n'
            if len(self.tokenizer(inputs).input_ids) > self.max_length:
                print(f"Removed {feedback['output']} due to length")
            else:
                kept.append(feedback)
        self.sft = kept
    def __len__(self):
        return len(self.sft)

    def __getitem__(self, idx):
        sft = self.sft[idx]
        sft_input_tokens = self.tokenizer(sft["input"], return_tensors="pt").input_ids
        sft_output_tokens = self.tokenizer(f' {sft["output"].strip()}\n', return_tensors="pt").input_ids
        input_ids = torch.cat([sft_input_tokens, sft_output_tokens], dim=-1)
        # start/end positions mark the output span inside the concatenated ids.
        start_positions = torch.tensor([len(sft_input_tokens[0])])
        end_positions = torch.tensor([len(sft_input_tokens[0]) + len(sft_output_tokens[0]) - 1])
        return {
            "input_ids": input_ids,
            "start_positions": start_positions,
            "end_positions": end_positions,
        }
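
# Usage sketch (assumptions: "sft.jsonl" as above; records vary in length, so
# the default collate cannot stack them, and batch_size=1 sidesteps the need
# for a custom collate_fn):
#
#   sft_dataset = SFTDataset("sft.jsonl", tokenizer, max_length=2048)
#   loader = torch.utils.data.DataLoader(sft_dataset, batch_size=1)
#   batch = next(iter(loader))
#   # batch["input_ids"] has shape (1, 1, seq_len); start_positions and
#   # end_positions delimit the output span for loss computation.
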

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Dataset Creator')
    parser.add_argument('--in_file', type=str, help='input file to use', required=True)
    parser.add_argument('--out_file', type=str, help='output file to use', required=True)
    parser.add_argument('--model', type=str, help='model tokenizer to use', required=True)
    args = parser.parse_args()
    encode(args.in_file, args.out_file, transformers.AutoTokenizer.from_pretrained(args.model))
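
# Example invocation (assumed file names; --model takes any Hugging Face
# model id, whose tokenizer is used for encoding):
#
#   python dataset.py --in_file corpus.txt --out_file corpus.bin --model gpt2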