@@ -93,25 +93,27 @@ def train(args):
     elif 'alpaca' in args.dataset:
         train_dataset = AlpacaDataset(tokenizer=tokenizer, data_path=args.dataset)
         eval_dataset = None
         data_collator = AlpacaDataCollator(tokenizer=tokenizer)
 
     if dist.is_initialized() and dist.get_world_size() > 1:
-        sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True)
         logger.info("Using Distributed Sampler")
+        train_sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True)
+        if eval_dataset is not None:
+            eval_sampler = DistributedSampler(eval_dataset, shuffle=False, seed=42, drop_last=False)
     else:
-        sampler = None
+        train_sampler = None
+        eval_sampler = None
 
-    train_dataloader = DataLoader(train_dataset, shuffle=(sampler is None), sampler=sampler, batch_size=args.batch_size)
+    train_dataloader = DataLoader(train_dataset, shuffle=(train_sampler is None), sampler=train_sampler, batch_size=args.batch_size, collate_fn=data_collator)
     if eval_dataset is not None:
-        eval_dataloader = DataLoader(eval_dataset, batch_size=args.batch_size)
+        eval_dataloader = DataLoader(eval_dataset, shuffle=(eval_sampler is None), sampler=eval_sampler, batch_size=args.batch_size, collate_fn=data_collator)
     else:
         eval_dataloader = None
 
     trainer = SFTTrainer(model=model,
                          strategy=strategy,
                          optim=optim,
                          train_dataloader=train_dataloader,
                          eval_dataloader=eval_dataloader,
-                         sampler=sampler,
                          batch_size=args.batch_size,
                          max_epochs=args.max_epochs)
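
Net effect of the hunk above: sharding moves out of the trainer. Each split gets its own `DistributedSampler`, `shuffle` is passed only when no sampler is present (`DataLoader` rejects the combination), the collator is attached to both loaders, and `SFTTrainer` consequently drops its `sampler` argument. Below is a minimal self-contained sketch of the same wiring; `ToyDataset`, `collate`, and `build_loaders` are illustrative stand-ins, not part of this patch:

```python
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler


class ToyDataset(Dataset):
    """Stand-in for AlpacaDataset: 128 integer 'samples'."""

    def __init__(self, n=128):
        self.data = torch.arange(n)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def collate(batch):
    """Stand-in for AlpacaDataCollator."""
    return torch.stack(batch)


def build_loaders(train_ds, eval_ds, batch_size):
    if dist.is_initialized() and dist.get_world_size() > 1:
        # drop_last=True keeps per-rank batch counts identical during training.
        train_sampler = DistributedSampler(train_ds, shuffle=True, seed=42, drop_last=True)
        eval_sampler = (DistributedSampler(eval_ds, shuffle=False, drop_last=False)
                        if eval_ds is not None else None)
    else:
        train_sampler = None
        eval_sampler = None

    # DataLoader forbids shuffle=True together with a sampler, hence the guard.
    train_loader = DataLoader(train_ds, shuffle=(train_sampler is None),
                              sampler=train_sampler, batch_size=batch_size,
                              collate_fn=collate)
    eval_loader = None
    if eval_ds is not None:
        eval_loader = DataLoader(eval_ds, shuffle=False, sampler=eval_sampler,
                                 batch_size=batch_size, collate_fn=collate)
    return train_loader, eval_loader


if __name__ == '__main__':
    train_loader, _ = build_loaders(ToyDataset(), None, batch_size=8)
    print(sum(1 for _ in train_loader))  # 16 batches on a single process
```

One thing the diff leaves to the trainer: with a `DistributedSampler`, `train_sampler.set_epoch(epoch)` has to be called at the start of every epoch, or each epoch replays the same shuffle order.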
@@ -128,7 +130,7 @@ if __name__ == '__main__':
     parser.add_argument('--strategy',
                         choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                         default='naive')
-    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt'], default='bloom')
+    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
     parser.add_argument('--pretrain', type=str, default=None)
     parser.add_argument('--dataset', type=str, default='yizhongw/self_instruct')
     parser.add_argument('--save_path', type=str, default='sft_ckpt.pth')
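
The second hunk only widens the `--model` whitelist to accept `llama`. Because argparse validates `choices` at parse time, an unsupported name fails before any model construction starts. A standalone sketch, not the script's actual parser, which takes more arguments:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')

print(parser.parse_args([]).model)                    # 'bloom' (the default)
print(parser.parse_args(['--model', 'llama']).model)  # 'llama', newly accepted

# An unlisted value is rejected before any training code runs:
# parser.parse_args(['--model', 'llama2'])
# -> error: argument --model: invalid choice: 'llama2'
#    (choose from 'gpt2', 'bloom', 'opt', 'llama')
```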