fix: haru's sft being incompatible with the ColossalAI fine-tune script
This commit is contained in:
parent
5dbde00d27
commit
b79ac657a4
|
@@ -4,9 +4,9 @@ import accelerate
|
||||||
import tqdm
|
import tqdm
|
||||||
import time
|
import time
|
||||||
import argparse
|
import argparse
|
||||||
import wandb
|
# import wandb
|
||||||
|
|
||||||
from dataset import TokenizedDataset, FeedbackDataset, SFTDataset
|
from .dataset import TokenizedDataset, FeedbackDataset, SFTDataset
|
||||||
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
from transformers.modeling_outputs import CausalLMOutput
|
from transformers.modeling_outputs import CausalLMOutput
|
||||||
|
@@ -33,6 +33,7 @@ def sft_forward(
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
return_dict = True
|
return_dict = True
|
||||||
|
|
||||||
|
'''
|
||||||
outputs = self.transformer(
|
outputs = self.transformer(
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
@@ -48,6 +49,21 @@ def sft_forward(
|
||||||
sequence_output = outputs[0]
|
sequence_output = outputs[0]
|
||||||
|
|
||||||
logits = self.lm_head(sequence_output)
|
logits = self.lm_head(sequence_output)
|
||||||
|
'''
|
||||||
|
outputs = self(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
use_cache=False,
|
||||||
|
# token_type_ids=token_type_ids,
|
||||||
|
# position_ids=position_ids,
|
||||||
|
# head_mask=head_mask,
|
||||||
|
# inputs_embeds=inputs_embeds,
|
||||||
|
# output_attentions=output_attentions,
|
||||||
|
# output_hidden_states=output_hidden_states,
|
||||||
|
# return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
logits = outputs["logits"]
|
||||||
|
|
||||||
answer_logits = logits[:, start_positions[0]:end_positions[0]+1]
|
answer_logits = logits[:, start_positions[0]:end_positions[0]+1]
|
||||||
answer_input_ids = input_ids[:, start_positions[0]:end_positions[0]+1]
|
answer_input_ids = input_ids[:, start_positions[0]:end_positions[0]+1]
|
||||||
|
|
Loading…
Reference in New Issue