From ea088b5f75e9c9a79d67b370286da2a1508688c8 Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Wed, 10 Jan 2024 10:42:37 +0800
Subject: [PATCH] update train code

---
 applications/ColossalMoE/train_moe.py | 25 +++++--------------------
 1 file changed, 5 insertions(+), 20 deletions(-)

diff --git a/applications/ColossalMoE/train_moe.py b/applications/ColossalMoE/train_moe.py
index 1ce975e0a..b05d93fe0 100644
--- a/applications/ColossalMoE/train_moe.py
+++ b/applications/ColossalMoE/train_moe.py
@@ -78,13 +78,13 @@ def parse_args():
         choices=["fp32", "bf16", "fp16"],
         help="The mixed precision training.",
     )
-    parser.add_argument("--max_length", type=int, default=2048, help="Max sequence length.")
+    parser.add_argument("--max_length", type=int, default=4096, help="Max sequence length.")
     parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
     parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
 
     # optim
     parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
-    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
+    parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay to use.")
 
     # zero stage for all plugins
     parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.")
@@ -93,6 +93,7 @@ def parse_args():
     parser.add_argument("--dp_size", type=int, default=1, help="dp size for hybrid plugin")
     parser.add_argument("--ep_size", type=int, default=2, help="ep size for hybrid plugin")
     parser.add_argument("--microbatch_size", type=int, default=1, help="Microbatch size in pipeline for hybrid plugin")
+    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
 
     parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory")
     parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
@@ -158,6 +159,7 @@ def main():
     if args.plugin == "hybrid":
         plugin = MoeHybridParallelPlugin(
             pp_size=args.pp_size,
+            max_norm=args.grad_clip,
             microbatch_size=args.microbatch_size,
             **hybrid_dict,
         )
@@ -195,7 +197,7 @@ def main():
 
     tokenizer.pad_token_id = tokenizer.eos_token_id
 
-    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=4096)
+    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
     dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
     dataloader = plugin.prepare_dataloader(
         dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=data_collator
@@ -280,23 +282,6 @@ def main():
                             scalar_value=lr_scheduler.get_last_lr()[0],
                             global_step=step,
                         )
-
-                        # import torch.distributed as dist
-                        # new_loss = loss.clone().detach()
-                        # new_aux_loss = aux_loss.clone().detach()
-                        # dist.all_reduce(tensor=new_loss, op=dist.ReduceOp.SUM, group=booster.plugin.dp_group)
-                        # dist.all_reduce(tensor=new_aux_loss, op=dist.ReduceOp.SUM, group=booster.plugin.dp_group)
-                        # new_loss.div_(booster.plugin.dp_size)
-                        # new_aux_loss.div_(booster.plugin.dp_size)
-                        # if coordinator._local_rank == '0':
-                        #     pbar.set_postfix({"Loss": new_loss.item()})
-                        #     writer.add_scalar(tag="Loss", scalar_value=new_loss.item(), global_step=step)
-                        #     writer.add_scalar(tag="Aux Loss", scalar_value=new_aux_loss.item(), global_step=step)
-                        #     writer.add_scalar(
-                        #         tag="Learning Rate",
-                        #         scalar_value=lr_scheduler.get_last_lr()[0],
-                        #         global_step=step,
-                        #     )
             else:
                 # Forward pass
                 data = next(train_dataloader_iter)
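
Note for reviewers: the sketch below is a hypothetical, standalone illustration (not code from the patch) of how the new flags are meant to thread through the script. It mirrors the argument names in the diff but deliberately avoids importing ColossalAI, so the plugin and collator are only referenced in comments; treat it as an assumption-labelled outline, not the training script itself.

```python
# Hypothetical sketch: how --max_length and --grad_clip flow through after this patch.
# The real script passes these kwargs to MoeHybridParallelPlugin and
# DataCollatorForSupervisedDataset; here we only build and print the kwargs.
import argparse


def parse_args(argv=None):
    parser = argparse.ArgumentParser()
    # Flags touched by the patch (defaults match the new values in the diff).
    parser.add_argument("--max_length", type=int, default=4096, help="Max sequence length.")
    parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay to use.")
    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
    parser.add_argument("--pp_size", type=int, default=2, help="pp size for hybrid plugin")
    parser.add_argument("--microbatch_size", type=int, default=1, help="Microbatch size in pipeline")
    return parser.parse_args(argv)


def build_plugin_kwargs(args):
    # In train_moe.py these keyword arguments go to MoeHybridParallelPlugin;
    # --grad_clip is forwarded as max_norm so the plugin clips gradients.
    return dict(
        pp_size=args.pp_size,
        max_norm=args.grad_clip,
        microbatch_size=args.microbatch_size,
    )


if __name__ == "__main__":
    args = parse_args([])
    print(build_plugin_kwargs(args))
    # The collator now respects --max_length instead of a hard-coded 4096, i.e.
    # DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
```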