diff --git a/applications/ColossalMoE/train_moe.py b/applications/ColossalMoE/train_moe.py index 243443cec..6bf9ca4e9 100644 --- a/applications/ColossalMoE/train_moe.py +++ b/applications/ColossalMoE/train_moe.py @@ -1,6 +1,7 @@ import argparse import os import json +from typing import Dict, Union, List import torch from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO @@ -8,7 +9,7 @@ from colossal_moe.models.mixtral_layer import replace_moe_layer from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy from colossal_moe.utils import load_checkpoint, load_model, move_to_cuda, save_checkpoint from tqdm import tqdm -from transformers import AutoTokenizer +from transformers import AutoTokenizer, PreTrainedModel from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM import torch.distributed as dist @@ -29,6 +30,35 @@ from colossal_moe.dataset.loader import ( DataCollatorForSupervisedDataset, ) + +def get_optimizer_grouped_parameters(model: Union[torch.nn.Module, PreTrainedModel], + weight_decay: float = 0.0, + disable_decay_names: List[str] = None + ) -> List[Dict[str, Union[List[torch.nn.Parameter], float]]]: + if disable_decay_names is None: + disable_decay_names = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [ + p for n, p in model.named_parameters() + if p.requires_grad is True and not any( + name.lower() in n.lower() for name in disable_decay_names + ) + ], + "weight_decay": weight_decay + }, + { + "params": [ + p for n, p in model.named_parameters() + if p.requires_grad is True and any( + name.lower() in n.lower() for name in disable_decay_names + ) + ], + "weight_decay": 0.0 + } + ] + return optimizer_grouped_parameters + @torch.no_grad() def get_global_loss(loss, booster): global_loss = loss.clone().detach() @@ -84,7 +114,7 @@ def parse_args(): # optim parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.") - parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay to use.") + parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to use.") # zero stage for all plugins parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.") @@ -194,7 +224,8 @@ def main(): # Prepare tokenizer and dataloader tokenizer = AutoTokenizer.from_pretrained(args.model_name) - tokenizer.pad_token_id = tokenizer.eos_token_id + tokenizer.padding_side = "right" + tokenizer.pad_token = tokenizer.unk_token data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length) @@ -205,10 +236,11 @@ def main(): # Set optimizer optimizer = HybridAdam( - model_params=model.parameters(), + model_params=get_optimizer_grouped_parameters( + model=model, weight_decay=args.weight_decay + ), lr=args.lr, betas=(0.9, 0.95), - weight_decay=args.weight_decay, adamw_mode=True, )