Browse Source

update

feat/moe
Tong Li 11 months ago
parent
commit
1d96a562bb
  1. 42
      applications/ColossalMoE/train_moe.py

42
applications/ColossalMoE/train_moe.py

@ -1,6 +1,7 @@
import argparse
import os
import json
from typing import Dict, Union, List
import torch
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
@ -8,7 +9,7 @@ from colossal_moe.models.mixtral_layer import replace_moe_layer
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from colossal_moe.utils import load_checkpoint, load_model, move_to_cuda, save_checkpoint
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers import AutoTokenizer, PreTrainedModel
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
import torch.distributed as dist
@ -29,6 +30,35 @@ from colossal_moe.dataset.loader import (
DataCollatorForSupervisedDataset,
)
def get_optimizer_grouped_parameters(model: Union[torch.nn.Module, "PreTrainedModel"],
                                     weight_decay: float = 0.0,
                                     disable_decay_names: Union[List[str], None] = None
                                     ) -> List[Dict[str, Union[List[torch.nn.Parameter], float]]]:
    """Split the trainable parameters of ``model`` into two optimizer groups.

    Parameters with ``requires_grad`` set to False are left out entirely.
    Every remaining parameter lands in exactly one of the two groups.

    Args:
        model: Module whose ``named_parameters()`` are partitioned.
        weight_decay: Weight decay assigned to the first (decayed) group.
        disable_decay_names: Substrings that exclude a parameter from weight
            decay. They are matched case-insensitively against the full
            parameter name. Defaults to ``["bias", "LayerNorm.weight"]``.

    Returns:
        A two-element list of parameter-group dicts, each with the keys
        ``"params"`` and ``"weight_decay"``. The first group holds the
        parameters that did not match any pattern and carries
        ``weight_decay``. The second holds the matching parameters and
        carries ``0.0``. Both groups are always present, even when empty,
        and keep the order of ``model.named_parameters()``.
    """
    if disable_decay_names is None:
        disable_decay_names = ["bias", "LayerNorm.weight"]
    # Lower-case the patterns once rather than once per parameter.
    no_decay_patterns = [name.lower() for name in disable_decay_names]

    decay_params: List[torch.nn.Parameter] = []
    no_decay_params: List[torch.nn.Parameter] = []
    # A single pass keeps the two groups mutually exclusive by construction,
    # instead of relying on two copies of the same predicate staying in sync.
    for param_name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        lowered_name = param_name.lower()
        if any(pattern in lowered_name for pattern in no_decay_patterns):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    return [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]
@torch.no_grad()
def get_global_loss(loss, booster):
global_loss = loss.clone().detach()
@ -84,7 +114,7 @@ def parse_args():
# optim
parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay to use.")
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to use.")
# zero stage for all plugins
parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.")
@ -194,7 +224,8 @@ def main():
# Prepare tokenizer and dataloader
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.unk_token
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
@ -205,10 +236,11 @@ def main():
# Set optimizer
optimizer = HybridAdam(
model_params=model.parameters(),
model_params=get_optimizer_grouped_parameters(
model=model, weight_decay=args.weight_decay
),
lr=args.lr,
betas=(0.9, 0.95),
weight_decay=args.weight_decay,
adamw_mode=True,
)

Loading…
Cancel
Save