toolbox/training/harubaru-convogpt/sft.py

import os
import torch
import accelerate
import tqdm
import time
import argparse
import wandb
from .dataset import TokenizedDataset, FeedbackDataset, SFTDataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.modeling_outputs import CausalLMOutput
from typing import Union, Optional
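
# This script fine-tunes a causal LM on (prompt, answer) examples loaded through
# SFTDataset: sft_forward() restricts the loss to the answer span, SFT_Trainer wraps
# the accelerate training loop (with wandb logging), and main() is the CLI entry point.
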
# Supervised Finetuning: Compute loss between model output and target using start_positions and end_positions
def sft_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    token_type_ids: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    start_positions: Optional[torch.LongTensor] = None,
    end_positions: Optional[torch.LongTensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[tuple, CausalLMOutput]:
    """Forward pass that computes the loss only over the answer span.

    `self` is the causal LM (the function can also be bound onto a model instance,
    see the example at the bottom of this file); the answer span runs from
    `start_positions[0]` through `end_positions[0]` inclusive.
    """
    # Fall back to dict-style outputs if the (possibly wrapped) model has no config.
    try:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    except AttributeError:
        return_dict = True
    '''
    outputs = self.transformer(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    sequence_output = outputs[0]
    logits = self.lm_head(sequence_output)
    '''
    outputs = self(
        input_ids,
        attention_mask=attention_mask,
        use_cache=False,
        # token_type_ids=token_type_ids,
        # position_ids=position_ids,
        # head_mask=head_mask,
        # inputs_embeds=inputs_embeds,
        # output_attentions=output_attentions,
        # output_hidden_states=output_hidden_states,
        # return_dict=return_dict,
    )
    logits = outputs["logits"]

    # Slice out the answer span. Only the first example's indices are used, so every
    # sequence in the batch is assumed to share the same span.
    answer_logits = logits[:, start_positions[0]:end_positions[0]+1]
    answer_input_ids = input_ids[:, start_positions[0]:end_positions[0]+1]

    # Standard causal-LM shift within the answer span: predict token t+1 from position t.
    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-1)
    shift_answer_logits = answer_logits[..., :-1, :].contiguous()
    shift_answer_labels = answer_input_ids[..., 1:].contiguous()
    answer_loss = loss_fct(shift_answer_logits.view(-1, answer_logits.size(-1)), shift_answer_labels.view(-1))
    loss = answer_loss

    if not return_dict:
        return (loss,) + outputs[2:]

    return CausalLMOutput(
        loss=loss,
        logits=logits,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


class SFT_Trainer:
    def __init__(
        self,
        accelerator: accelerate.Accelerator,
        model: AutoModelForCausalLM,
        tokenizer: AutoTokenizer,
        train_dataloader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        weight_dtype: Optional[torch.dtype],
        args: argparse.Namespace,
    ) -> None:
        self.accelerator = accelerator
        self.model = model
        self.tokenizer = tokenizer
        self.train_dataloader = train_dataloader
        self.optimizer = optimizer
        self.weight_dtype = weight_dtype
        self.args = args

        if accelerator.is_main_process:
            self.progress_bar = tqdm.tqdm(
                total=self.args.epochs * len(train_dataloader),
                desc="Total Steps",
                leave=False,
            )
            self.run = wandb.init(
                project="convogpt-sftlm",
                name=f'{self.args.model}-{self.args.epochs}-{self.args.batch_size}-{self.args.learning_rate}--{int(time.time())}',
                config=self.args,
            )

        self.global_step = 0
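
    # The tqdm progress bar and the wandb run only exist on the main process, so they
    # are only touched inside `is_main_process` guards.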
    def save_model(self) -> None:
        self.accelerator.wait_for_everyone()
        if self.accelerator.is_main_process:
            path = f'{self.args.output_dir}/{self.run.name}'
            os.makedirs(path, exist_ok=True)
            unwrapped_model = self.accelerator.unwrap_model(self.model)
            unwrapped_model.save_pretrained(path, save_function=self.accelerator.save)

    def step(self, batch: dict) -> dict:
        with self.accelerator.accumulate(self.model):
            input_ids = batch['input_ids']
            attention_mask = batch['attention_mask']
            start_positions = batch['start_positions']
            end_positions = batch['end_positions']
            try:
                outputs = sft_forward(
                    self.model,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    start_positions=start_positions,
                    end_positions=end_positions,
                )
                loss = outputs.loss
                self.accelerator.backward(loss)
                if self.accelerator.sync_gradients:
                    self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
                self.optimizer.step()
                self.optimizer.zero_grad()
            except RuntimeError as e:
                # Dump the offending batch and keep training instead of crashing.
                print(f"RuntimeError: {e}")
                print(f"input_ids: {input_ids}")
                print(f"attention_mask: {attention_mask}")
                print(f"start_positions: {start_positions}")
                print(f"end_positions: {end_positions}")
                print('Skipping batch...')
                loss = torch.tensor(float('nan'), device=self.accelerator.device)
        return {
            "train/loss": loss.detach().item(),
        }
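
    # Batches from the dataloader (built by collate_fn in main) are dicts holding padded
    # `input_ids`/`attention_mask` of shape [batch, seq_len] plus per-example
    # `start_positions`/`end_positions`; sft_forward applies the first example's span.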

    def train(self) -> None:
        self.model.train()
        for epoch in range(self.args.epochs):
            for _, batch in enumerate(self.train_dataloader):
                step_start = time.perf_counter()
                #print(f"####\n{self.tokenizer.decode(batch['input_ids'][0])}\n#{batch['start_positions'][0]}:{batch['end_positions'][0]}\n####")
                metrics = self.step(batch)
                step_end = time.perf_counter()

                if self.accelerator.is_main_process:
                    rank_samples_per_second = self.args.batch_size / (step_end - step_start)
                    world_samples_per_second = rank_samples_per_second * self.accelerator.num_processes
                    metrics.update({
                        "perf/rank_samples_per_second": rank_samples_per_second,
                        "perf/world_samples_per_second": world_samples_per_second,
                        "train/epoch": epoch,
                        "train/step": self.global_step,
                        "train/samples_seen": self.global_step * self.args.batch_size,
                    })
                    self.global_step += 1
                    self.progress_bar.update(1)
                    self.progress_bar.set_postfix(**metrics)
                    self.run.log(metrics, step=self.global_step)

                    if self.global_step % self.args.save_steps == 0:
                        self.save_model()

        self.accelerator.wait_for_everyone()
        self.save_model()


def main() -> None:
    parser = argparse.ArgumentParser(description="Supervised GPT finetuning")
    parser.add_argument("--model", type=str, default="hakurei/gpt-j-random-tinier", help="Model name")
    parser.add_argument("--dataset", type=str, default="train.jsonl", help="Training file")
    parser.add_argument("--output_dir", type=str, default="output", help="Output directory")
    parser.add_argument("--epochs", type=int, default=1, help="Number of epochs")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
    parser.add_argument("--save_steps", type=int, default=1000, help="Save model every x steps")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
    args = parser.parse_args()

    accelerator = accelerate.Accelerator()
    accelerate.utils.set_seed(42)

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    # GPT-style tokenizers ship without a pad token, so reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token
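
    # collate_fn pads each batch to its longest sequence: e.g. (illustrative) with
    # batch_size=2 and sequence lengths 5 and 3, `tokenizer.pad` returns input_ids and
    # attention_mask of shape [2, 5], with the attention mask set to 0 on the padding.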
    def collate_fn(batches):
        input_ids = [
            batch["input_ids"].squeeze(0) for batch in batches
        ]
        padded_tokens = tokenizer.pad(
            {"input_ids": input_ids}, return_tensors="pt", padding=True
        )
        start_positions = torch.stack(
            [batch["start_positions"] for batch in batches]
        )
        end_positions = torch.stack(
            [batch["end_positions"] for batch in batches]
        )
        return {
            "input_ids": padded_tokens["input_ids"],
            "attention_mask": padded_tokens["attention_mask"],
            "start_positions": start_positions,
            "end_positions": end_positions,
        }

    train_dataset = SFTDataset(args.dataset, tokenizer)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )

    model = AutoModelForCausalLM.from_pretrained(args.model)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)

    model, optimizer, train_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader
    )

    trainer = SFT_Trainer(
        accelerator=accelerator,
        model=model,
        tokenizer=tokenizer,
        train_dataloader=train_dataloader,
        optimizer=optimizer,
        weight_dtype=None,
        args=args,
    )
    trainer.train()
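

# Example invocation (illustrative only; assumes single-process defaults and that the
# relative `.dataset` import resolves in your package layout):
#   accelerate launch sft.py --model gpt2 --dataset train.jsonl --batch_size 4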
if __name__ == '__main__':
    """
    # Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained('distilgpt2')
    tokenizer = AutoTokenizer.from_pretrained('distilgpt2')

    # Add supervised finetuning forward method to model
    model.forward = sft_forward.__get__(model)

    # Create input tensors
    question = 'What is the capital of France?'
    answer = 'The capital of France is Paris.'
    question_tokens = tokenizer.encode(question, return_tensors='pt')
    answer_tokens = tokenizer.encode(answer, return_tensors='pt')
    input_ids = torch.cat([question_tokens, answer_tokens], dim=-1)
    start_positions = torch.tensor([len(question_tokens[0])])
    end_positions = torch.tensor([len(question_tokens[0]) + len(answer_tokens[0]) - 1])

    # Compute loss
    loss = model(input_ids, start_positions=start_positions, end_positions=end_positions).loss
    print(loss)
    """
    main()