diff --git a/applications/ColossalMoE/train.py b/applications/ColossalMoE/train.py
index b2bdd4252..701a495c2 100644
--- a/applications/ColossalMoE/train.py
+++ b/applications/ColossalMoE/train.py
@@ -10,6 +10,7 @@ from torch.utils.data import Dataset
 from tqdm import tqdm
 from transformers import AutoTokenizer
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 
 import colossalai
 from colossalai.booster import Booster
@@ -99,6 +100,10 @@ def parse_args():
     # optim
     parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
     parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
+
+    # lr scheduler
+    parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
+    parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
 
     # zero stage for all plugins
     parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.")
@@ -203,10 +208,25 @@ def main():
 
     # Set optimizer
     optimizer = HybridAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+
+    # Set lr scheduler
+    lr_scheduler = CosineAnnealingWarmupLR(
+        optimizer=optimizer,
+        total_steps=args.num_epochs * len(dataloader),
+        warmup_steps=args.warmup_steps
+        if args.warmup_steps is not None
+        else int(args.num_epochs * len(dataloader) * 0.025),
+        eta_min=0.1 * args.lr,
+    )
 
     # Set booster
     booster = Booster(plugin=plugin, **booster_kwargs)
-    model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
+    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
+        model=model,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        dataloader=dataloader,
+    )
     use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
     is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
     pp_print_rank = is_pp_last_stage and (coordinator.local_rank == "0")
@@ -253,6 +273,7 @@ def main():
                     pbar.set_postfix({"loss": loss.item()})
 
                 optimizer.step()
+                lr_scheduler.step()
                 optimizer.zero_grad()
 
                 # Apply load balance
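
For reference, a minimal sketch of the scheduler wiring this diff introduces, stripped of the Booster/plugin machinery. The model, optimizer, and step counts below are placeholders for illustration, not the values `train.py` uses; the constructor arguments mirror the ones in the diff, where the warmup length defaults to 2.5% of the total training steps when `--warmup_steps` is not given.

```python
import torch

from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR

# Placeholder model and sizes for illustration only; train.py derives these
# from the Mixtral model, args.num_epochs, and len(dataloader).
num_epochs, steps_per_epoch, lr = 1, 1000, 1e-5
model = torch.nn.Linear(16, 16)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
total_steps = num_epochs * steps_per_epoch
lr_scheduler = CosineAnnealingWarmupLR(
    optimizer=optimizer,
    total_steps=total_steps,
    # Same default as the diff: warm up for 2.5% of all training steps.
    warmup_steps=int(total_steps * 0.025),
    eta_min=0.1 * lr,  # anneal down to 10% of the base learning rate
)

for _ in range(total_steps):
    loss = model(torch.randn(4, 16)).sum()
    loss.backward()
    optimizer.step()
    lr_scheduler.step()  # advance the warmup/cosine schedule once per optimizer step
    optimizer.zero_grad()
```

In the actual script the scheduler is additionally passed through `booster.boost(...)`, so the object stepped in the training loop is the booster-wrapped scheduler rather than the bare one constructed above.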