fix: fix sft (#3568)

pull/3586/head
tingfeng cao 2 years ago committed by GitHub
parent 6e7e43c6fe
commit 7788e0b0a5

@@ -53,29 +53,25 @@ class SFTDataset(Dataset):
     def __init__(self, dataset, tokenizer: Callable, max_length: int = 512) -> None:
         super().__init__()
-        # self.prompts = []
         self.input_ids = []
         for data in tqdm(dataset, disable=not is_rank_0()):
-            prompt = data['prompt'] + data['completion'] + "<|endoftext|>"
+            prompt = data['prompt'] + data['completion'] + tokenizer.eos_token
             prompt_token = tokenizer(prompt,
                                      max_length=max_length,
                                      padding="max_length",
                                      truncation=True,
                                      return_tensors="pt")
-            # self.prompts.append(prompt_token)s
-            self.input_ids.append(prompt_token)
+            self.input_ids.append(prompt_token['input_ids'][0])
         self.labels = copy.deepcopy(self.input_ids)
 
     def __len__(self):
-        length = len(self.prompts)
+        length = len(self.input_ids)
         return length
 
     def __getitem__(self, idx):
-        # dict(input_ids=self.input_ids[i], labels=self.labels[i])
         return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
-        # return dict(self.prompts[idx], self.prompts[idx])
 
 
 def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, max_length: int) -> Dict:

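The dataset change above fixes two things: the end-of-text marker now comes from the tokenizer (tokenizer.eos_token) instead of a hard-coded "<|endoftext|>", and each stored sample is the 1-D input_ids tensor rather than the whole BatchEncoding object. A minimal standalone sketch of the intended shapes, assuming a Hugging Face tokenizer (the GPT-2 checkpoint and max_length are illustrative, not from this PR):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

prompt = "Q: 1+1? A: 2" + tokenizer.eos_token  # eos_token, not a hard-coded string
enc = tokenizer(prompt,
                max_length=512,
                padding="max_length",
                truncation=True,
                return_tensors="pt")

# enc is a BatchEncoding; enc['input_ids'] has shape (1, 512).
# Appending enc itself (the old code) hands the collator nested dicts,
# while enc['input_ids'][0] is the flat (512,) tensor it can stack.
sample = enc["input_ids"][0]
assert sample.shape == (512,)
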
@@ -96,7 +96,7 @@ class SFTTrainer(ABC):
                 loss = outputs.loss
                 prompt_logits = outputs.logits
 
-                if loss >= 2.5:
+                if loss >= 2.5 and is_rank_0():
                     logger.warning(f"batch_id:{batch_id}, abnormal loss: {loss}")
 
                 loss = loss / self.accimulation_steps
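
Both guards added in this commit rely on a rank check so that, in multi-GPU runs, only one process emits the warning or logs metrics. coati ships its own is_rank_0 helper; the usual torch.distributed pattern it corresponds to looks roughly like this (a sketch, not the project's exact implementation):

import torch.distributed as dist

def is_rank_0() -> bool:
    # Single-process runs have no process group; treat them as rank 0.
    return not dist.is_initialized() or dist.get_rank() == 0
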
@@ -110,12 +110,13 @@ class SFTTrainer(ABC):
                     self.strategy.optimizer_step(self.optimizer)
                     self.optimizer.zero_grad()
                     self.scheduler.step()
-                    wandb.log({
-                        "loss": total_loss / self.accimulation_steps,
-                        "lr": self.scheduler.get_last_lr()[0],
-                        "epoch": epoch,
-                        "batch_id": batch_id
-                    })
+                    if is_rank_0():
+                        wandb.log({
+                            "loss": total_loss / self.accimulation_steps,
+                            "lr": self.scheduler.get_last_lr()[0],
+                            "epoch": epoch,
+                            "batch_id": batch_id
+                        })
                     total_loss = 0
                     step_bar.update()

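The wandb call is now wrapped in the same rank-0 guard, so a world-size-N run produces one metrics stream instead of N duplicates. A sketch of the overall pattern, assuming wandb.init is likewise called only on rank 0 (the project name and values are illustrative):

import wandb

if is_rank_0():
    wandb.init(project="coati-sft")  # hypothetical project name

# ... inside the training loop ...
if is_rank_0():
    wandb.log({"loss": 1.23, "lr": 5e-6, "epoch": 0, "batch_id": 0})
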
@@ -111,7 +111,7 @@ def train(args):
                                           max_datasets_size=args.max_datasets_size,
                                           max_length=max_len)
         eval_dataset = None
     data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
     if dist.is_initialized() and dist.get_world_size() > 1:
         train_sampler = DistributedSampler(train_dataset,

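For reference, the sampler created above is normally passed to the DataLoader together with the collator, with shuffling left to the sampler rather than the loader; a sketch using the names from this snippet (the batch size is illustrative):

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

train_sampler = DistributedSampler(train_dataset, shuffle=True)
train_dataloader = DataLoader(train_dataset,
                              batch_size=4,
                              sampler=train_sampler,
                              collate_fn=data_collator)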