fix: haru's sft being incompatible with the ColossalAI fine-tune script

11b 2022-12-26 20:42:48 -03:00
parent 5dbde00d27
commit b79ac657a4
1 changed file with 18 additions and 2 deletions


@@ -4,9 +4,9 @@ import accelerate
 import tqdm
 import time
 import argparse
-import wandb
-from dataset import TokenizedDataset, FeedbackDataset, SFTDataset
+# import wandb
+from .dataset import TokenizedDataset, FeedbackDataset, SFTDataset
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers.modeling_outputs import CausalLMOutput
@@ -33,6 +33,7 @@ def sft_forward(
     except AttributeError:
         return_dict = True
+    '''
     outputs = self.transformer(
         input_ids,
         attention_mask=attention_mask,
@@ -48,6 +49,21 @@
     sequence_output = outputs[0]
     logits = self.lm_head(sequence_output)
+    '''
+    outputs = self(
+        input_ids,
+        attention_mask=attention_mask,
+        use_cache=False,
+        # token_type_ids=token_type_ids,
+        # position_ids=position_ids,
+        # head_mask=head_mask,
+        # inputs_embeds=inputs_embeds,
+        # output_attentions=output_attentions,
+        # output_hidden_states=output_hidden_states,
+        # return_dict=return_dict,
+    )
+    logits = outputs["logits"]
     answer_logits = logits[:, start_positions[0]:end_positions[0]+1]
     answer_input_ids = input_ids[:, start_positions[0]:end_positions[0]+1]
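
For context, a minimal sketch of what the patched sft_forward amounts to after this change. The method binding, the loss computation, and the model name below are assumptions for illustration, not part of the commit; what the diff itself establishes is that the forward pass now goes through self(...) (the model's top-level __call__) instead of reaching into self.transformer and self.lm_head, which keeps it working once ColossalAI wraps the model.

import types
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM

def sft_forward(self, input_ids, attention_mask=None,
                start_positions=None, end_positions=None):
    # As in the commit: call the model's own forward rather than its
    # internal submodules, so a ColossalAI-wrapped model still works.
    outputs = self(input_ids, attention_mask=attention_mask, use_cache=False)
    logits = outputs["logits"]
    # Per the diff, the answer span is sliced using the first example's
    # start/end positions across the whole batch.
    answer_logits = logits[:, start_positions[0]:end_positions[0] + 1]
    answer_input_ids = input_ids[:, start_positions[0]:end_positions[0] + 1]
    # Assumed: a standard shifted cross-entropy loss over the answer span.
    shift_logits = answer_logits[:, :-1, :].contiguous()
    shift_labels = answer_input_ids[:, 1:].contiguous()
    return F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)),
                           shift_labels.view(-1))

model = AutoModelForCausalLM.from_pretrained("gpt2")  # assumed base model
# Bound as an extra method rather than replacing model.forward, so that
# self(...) above dispatches to the normal forward without recursing.
model.sft_forward = types.MethodType(sft_forward, model)

# Example usage with assumed shapes: a single 12-token sequence whose
# answer span covers tokens 5..9 inclusive.
ids = torch.randint(0, model.config.vocab_size, (1, 12))
loss = model.sft_forward(ids,
                         attention_mask=torch.ones_like(ids),
                         start_positions=torch.tensor([5]),
                         end_positions=torch.tensor([9]))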