fix: haru's sft being incompatible with the ColossalAI fine-tune script
parent 5dbde00d27
commit b79ac657a4
@@ -4,9 +4,9 @@ import accelerate
 import tqdm
 import time
 import argparse
-import wandb
+# import wandb

-from dataset import TokenizedDataset, FeedbackDataset, SFTDataset
+from .dataset import TokenizedDataset, FeedbackDataset, SFTDataset

 from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers.modeling_outputs import CausalLMOutput
@@ -33,6 +33,7 @@ def sft_forward(
     except AttributeError:
         return_dict = True

+    '''
     outputs = self.transformer(
         input_ids,
         attention_mask=attention_mask,
@@ -48,6 +49,21 @@ def sft_forward(
     sequence_output = outputs[0]

     logits = self.lm_head(sequence_output)
+    '''
+    outputs = self(
+        input_ids,
+        attention_mask=attention_mask,
+        use_cache=False,
+        # token_type_ids=token_type_ids,
+        # position_ids=position_ids,
+        # head_mask=head_mask,
+        # inputs_embeds=inputs_embeds,
+        # output_attentions=output_attentions,
+        # output_hidden_states=output_hidden_states,
+        # return_dict=return_dict,
+    )
+
+    logits = outputs["logits"]

     answer_logits = logits[:, start_positions[0]:end_positions[0]+1]
     answer_input_ids = input_ids[:, start_positions[0]:end_positions[0]+1]
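In short, the patched sft_forward no longer reaches into self.transformer and self.lm_head; it runs the model's regular forward call and reads the logits from the returned output dict, which is presumably what restores compatibility with the ColossalAI fine-tune script. The other two changes are minor: wandb is commented out and the dataset import becomes package-relative. Below is a minimal sketch, not the repository's code, of how such a helper could be bound to a plain Hugging Face model and called; the "gpt2" checkpoint, the prompt, the start/end positions, and the simplified return value are illustrative assumptions.

# Minimal sketch (assumptions noted above): bind sft_forward as an extra
# method and route everything through the model's own __call__.
import types

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def sft_forward(self, input_ids, attention_mask=None,
                start_positions=None, end_positions=None):
    # Go through the model's own __call__ rather than self.transformer +
    # self.lm_head, so the same code works when a trainer wraps the model.
    outputs = self(input_ids, attention_mask=attention_mask, use_cache=False)
    logits = outputs["logits"]

    # Keep only the answer span, as in the patched code above.
    answer_logits = logits[:, start_positions[0]:end_positions[0] + 1]
    answer_input_ids = input_ids[:, start_positions[0]:end_positions[0] + 1]
    # The real helper goes on to compute a loss; returning the slices keeps
    # this sketch short.
    return answer_logits, answer_input_ids


tokenizer = AutoTokenizer.from_pretrained("gpt2")    # checkpoint is illustrative
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.sft_forward = types.MethodType(sft_forward, model)

batch = tokenizer("Q: What is 2+2? A: 4", return_tensors="pt")
answer_logits, answer_ids = model.sft_forward(
    batch["input_ids"],
    attention_mask=batch["attention_mask"],
    start_positions=torch.tensor([5]),    # example span; real values come from the dataset
    end_positions=torch.tensor([7]),
)
print(answer_logits.shape, answer_ids.shape)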