fix: haru's sft being incompatible with the ColossalAI fine-tune script

11b 2022-12-26 20:42:48 -03:00
parent 5dbde00d27
commit b79ac657a4
1 changed file with 18 additions and 2 deletions


@@ -4,9 +4,9 @@ import accelerate
 import tqdm
 import time
 import argparse
-import wandb
+# import wandb
-from dataset import TokenizedDataset, FeedbackDataset, SFTDataset
+from .dataset import TokenizedDataset, FeedbackDataset, SFTDataset
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers.modeling_outputs import CausalLMOutput
@@ -33,6 +33,7 @@ def sft_forward(
     except AttributeError:
         return_dict = True
+    '''
     outputs = self.transformer(
         input_ids,
         attention_mask=attention_mask,
@@ -48,6 +49,21 @@ def sft_forward(
     sequence_output = outputs[0]
     logits = self.lm_head(sequence_output)
+    '''
+    outputs = self(
+        input_ids,
+        attention_mask=attention_mask,
+        use_cache=False,
+        # token_type_ids=token_type_ids,
+        # position_ids=position_ids,
+        # head_mask=head_mask,
+        # inputs_embeds=inputs_embeds,
+        # output_attentions=output_attentions,
+        # output_hidden_states=output_hidden_states,
+        # return_dict=return_dict,
+    )
+    logits = outputs["logits"]
     answer_logits = logits[:, start_positions[0]:end_positions[0]+1]
     answer_input_ids = input_ids[:, start_positions[0]:end_positions[0]+1]
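For reference, a minimal standalone sketch of the new forward path (not part of the commit): it assumes a stock GPT-2 checkpoint from transformers, and the span indices are hypothetical stand-ins for the start_positions/end_positions arguments that sft_forward receives. Calling the model object itself, rather than self.transformer followed by self.lm_head, routes the computation through the model's top-level forward, which is the method the ColossalAI fine-tune script wraps.

# Minimal sketch, not from the commit: a stock GPT-2 checkpoint stands in
# for haru's SFT model, and the span indices below are made up.
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

enc = tokenizer("Question: 2+2? Answer: 4", return_tensors="pt")
input_ids = enc["input_ids"]
attention_mask = enc["attention_mask"]

# Call the model itself instead of model.transformer + model.lm_head, so the
# top-level forward (the method external wrappers patch) does the work.
outputs = model(input_ids, attention_mask=attention_mask, use_cache=False)
logits = outputs["logits"]

# Slice out the answer span the same way the patched sft_forward does.
start_positions, end_positions = [5], [7]  # hypothetical span boundaries
answer_logits = logits[:, start_positions[0]:end_positions[0] + 1]
answer_input_ids = input_ids[:, start_positions[0]:end_positions[0] + 1]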