diff --git a/training/harubaru-convogpt/sft.py b/training/harubaru-convogpt/sft.py index 3ed583d..9c4e909 100644 --- a/training/harubaru-convogpt/sft.py +++ b/training/harubaru-convogpt/sft.py @@ -4,9 +4,9 @@ import accelerate import tqdm import time import argparse -import wandb +# import wandb -from dataset import TokenizedDataset, FeedbackDataset, SFTDataset +from .dataset import TokenizedDataset, FeedbackDataset, SFTDataset from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.modeling_outputs import CausalLMOutput @@ -33,6 +33,7 @@ def sft_forward( except AttributeError: return_dict = True + ''' outputs = self.transformer( input_ids, attention_mask=attention_mask, @@ -48,6 +49,21 @@ def sft_forward( sequence_output = outputs[0] logits = self.lm_head(sequence_output) + ''' + outputs = self( + input_ids, + attention_mask=attention_mask, + use_cache=False, + # token_type_ids=token_type_ids, + # position_ids=position_ids, + # head_mask=head_mask, + # inputs_embeds=inputs_embeds, + # output_attentions=output_attentions, + # output_hidden_states=output_hidden_states, + # return_dict=return_dict, + ) + + logits = outputs["logits"] answer_logits = logits[:, start_positions[0]:end_positions[0]+1] answer_input_ids = input_ids[:, start_positions[0]:end_positions[0]+1]