update

Branch: feat/moe
Author: Tong Li, 11 months ago
Commit: 1d96a562bb

applications/ColossalMoE/train_moe.py (42 changed lines)

@@ -1,6 +1,7 @@
 import argparse
 import os
 import json
+from typing import Dict, Union, List
 
 import torch
 from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
@@ -8,7 +9,7 @@ from colossal_moe.models.mixtral_layer import replace_moe_layer
 from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
 from colossal_moe.utils import load_checkpoint, load_model, move_to_cuda, save_checkpoint
 from tqdm import tqdm
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, PreTrainedModel
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
 import torch.distributed as dist
 
@@ -29,6 +30,35 @@ from colossal_moe.dataset.loader import (
     DataCollatorForSupervisedDataset,
 )
 
+
+def get_optimizer_grouped_parameters(model: Union[torch.nn.Module, PreTrainedModel],
+                                     weight_decay: float = 0.0,
+                                     disable_decay_names: List[str] = None
+                                     ) -> List[Dict[str, Union[List[torch.nn.Parameter], float]]]:
+    if disable_decay_names is None:
+        disable_decay_names = ["bias", "LayerNorm.weight"]
+    optimizer_grouped_parameters = [
+        {
+            "params": [
+                p for n, p in model.named_parameters()
+                if p.requires_grad is True and not any(
+                    name.lower() in n.lower() for name in disable_decay_names
+                )
+            ],
+            "weight_decay": weight_decay
+        },
+        {
+            "params": [
+                p for n, p in model.named_parameters()
+                if p.requires_grad is True and any(
+                    name.lower() in n.lower() for name in disable_decay_names
+                )
+            ],
+            "weight_decay": 0.0
+        }
+    ]
+    return optimizer_grouped_parameters
+
 @torch.no_grad()
 def get_global_loss(loss, booster):
     global_loss = loss.clone().detach()
@@ -84,7 +114,7 @@ def parse_args():
     # optim
     parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
-    parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay to use.")
+    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to use.")
 
     # zero stage for all plugins
     parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.")
@@ -194,7 +224,8 @@ def main():
     # Prepare tokenizer and dataloader
     tokenizer = AutoTokenizer.from_pretrained(args.model_name)
-    tokenizer.pad_token_id = tokenizer.eos_token_id
+    tokenizer.padding_side = "right"
+    tokenizer.pad_token = tokenizer.unk_token
     data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
@@ -205,10 +236,11 @@ def main():
 
     # Set optimizer
     optimizer = HybridAdam(
-        model_params=model.parameters(),
+        model_params=get_optimizer_grouped_parameters(
+            model=model, weight_decay=args.weight_decay
+        ),
         lr=args.lr,
         betas=(0.9, 0.95),
-        weight_decay=args.weight_decay,
         adamw_mode=True,
     )
 
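
Taken together, the hunks above move weight decay from a single optimizer-wide setting into per-parameter groups: get_optimizer_grouped_parameters exempts any trainable parameter whose name contains "bias" or "LayerNorm.weight", the grouped list is passed to HybridAdam as model_params, and the optimizer-level weight_decay argument is dropped (per-group values take precedence in the standard torch.optim API). Below is a minimal, self-contained sketch of that grouping behavior; the TinyBlock toy module and the expected printout are illustrative assumptions, and the helper body is condensed into a loop rather than the two list comprehensions used in the commit.

from typing import Dict, List, Union

import torch


def get_optimizer_grouped_parameters(
    model: torch.nn.Module,
    weight_decay: float = 0.0,
    disable_decay_names: List[str] = None,
) -> List[Dict[str, Union[List[torch.nn.Parameter], float]]]:
    # Same grouping rule as the helper added in train_moe.py: trainable
    # parameters whose names contain an entry of disable_decay_names
    # (case-insensitive) get weight_decay=0.0; everything else keeps weight_decay.
    if disable_decay_names is None:
        disable_decay_names = ["bias", "LayerNorm.weight"]
    decay, no_decay = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if any(name.lower() in n.lower() for name in disable_decay_names):
            no_decay.append(p)
        else:
            decay.append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


class TinyBlock(torch.nn.Module):
    # Toy stand-in for a transformer block (illustrative assumption); attribute
    # names are chosen so "bias" and "LayerNorm.weight" appear in parameter names.
    def __init__(self):
        super().__init__()
        self.dense = torch.nn.Linear(8, 8)
        self.LayerNorm = torch.nn.LayerNorm(8)


if __name__ == "__main__":
    groups = get_optimizer_grouped_parameters(TinyBlock(), weight_decay=0.01)
    for g in groups:
        print(len(g["params"]), "params, weight_decay =", g["weight_decay"])
    # Expected output:
    #   1 params, weight_decay = 0.01   (dense.weight)
    #   3 params, weight_decay = 0.0    (dense.bias, LayerNorm.weight, LayerNorm.bias)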
