mirror of https://github.com/hpcaitech/ColossalAI

commit 1d96a562bb ("update"), parent dac240563c
@@ -1,6 +1,7 @@
 import argparse
 import os
 import json
+from typing import Dict, Union, List
 
 import torch
 from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
@@ -8,7 +9,7 @@ from colossal_moe.models.mixtral_layer import replace_moe_layer
 from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
 from colossal_moe.utils import load_checkpoint, load_model, move_to_cuda, save_checkpoint
 from tqdm import tqdm
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, PreTrainedModel
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
 import torch.distributed as dist
 
@@ -29,6 +30,35 @@ from colossal_moe.dataset.loader import (
     DataCollatorForSupervisedDataset,
 )
 
 
+def get_optimizer_grouped_parameters(model: Union[torch.nn.Module, PreTrainedModel],
+                                     weight_decay: float = 0.0,
+                                     disable_decay_names: List[str] = None
+                                     ) -> List[Dict[str, Union[List[torch.nn.Parameter], float]]]:
+    if disable_decay_names is None:
+        disable_decay_names = ["bias", "LayerNorm.weight"]
+    optimizer_grouped_parameters = [
+        {
+            "params": [
+                p for n, p in model.named_parameters()
+                if p.requires_grad is True and not any(
+                    name.lower() in n.lower() for name in disable_decay_names
+                )
+            ],
+            "weight_decay": weight_decay
+        },
+        {
+            "params": [
+                p for n, p in model.named_parameters()
+                if p.requires_grad is True and any(
+                    name.lower() in n.lower() for name in disable_decay_names
+                )
+            ],
+            "weight_decay": 0.0
+        }
+    ]
+    return optimizer_grouped_parameters
+
+
 @torch.no_grad()
 def get_global_loss(loss, booster):
     global_loss = loss.clone().detach()
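The new helper follows the common AdamW convention of splitting trainable parameters into a decayed group and a no-decay group (biases and LayerNorm weights, matched case-insensitively by substring). A minimal usage sketch, assuming get_optimizer_grouped_parameters from the patch above is in scope and using a hypothetical toy module rather than the Mixtral model wired up later in the script:

import torch
from torch import nn

class TinyBlock(nn.Module):
    """Stand-in module; attribute names chosen so the substring matching fires."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)            # proj.weight decays; proj.bias does not
        self.input_layernorm = nn.LayerNorm(16)  # "layernorm.weight" matches "LayerNorm.weight"

groups = get_optimizer_grouped_parameters(TinyBlock(), weight_decay=0.01)
# groups[0]: proj.weight, with weight_decay=0.01
# groups[1]: proj.bias, input_layernorm.weight, input_layernorm.bias, with weight_decay=0.0
optimizer = torch.optim.AdamW(groups, lr=1e-5, betas=(0.9, 0.95))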
@@ -84,7 +114,7 @@ def parse_args():
 
     # optim
     parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
-    parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay to use.")
+    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to use.")
 
     # zero stage for all plugins
     parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.")
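The default weight decay drops from 0.1 to 0.01, presumably to align with the usual AdamW default (torch.optim.AdamW also defaults to 0.01).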
@@ -194,7 +224,8 @@ def main():
 
     # Prepare tokenizer and dataloader
     tokenizer = AutoTokenizer.from_pretrained(args.model_name)
-    tokenizer.pad_token_id = tokenizer.eos_token_id
+    tokenizer.padding_side = "right"
+    tokenizer.pad_token = tokenizer.unk_token
 
 
     data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
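Llama-style tokenizers such as Mixtral's ship without a pad token. The patch switches from reusing eos as pad to reusing unk; a common motivation is that when pad equals eos, the eos token that ends each sequence can be masked out as padding and never learned. A sketch of the resulting behavior, assuming the public mistralai/Mixtral-8x7B-v0.1 checkpoint name:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")  # assumed model name
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.unk_token  # pad id != eos id, so real eos tokens keep their labels
batch = tokenizer(["short", "a somewhat longer example"], padding=True, return_tensors="pt")
print(batch["attention_mask"])  # padding positions are marked with 0s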
@@ -205,10 +236,11 @@ def main():
 
     # Set optimizer
     optimizer = HybridAdam(
-        model_params=model.parameters(),
+        model_params=get_optimizer_grouped_parameters(
+            model=model, weight_decay=args.weight_decay
+        ),
         lr=args.lr,
         betas=(0.9, 0.95),
-        weight_decay=args.weight_decay,
         adamw_mode=True,
     )
 
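With each parameter group now carrying its own weight_decay, the optimizer-level weight_decay kwarg is dropped: in the torch.optim param-group convention, which ColossalAI's HybridAdam appears to follow, a per-group entry overrides the constructor default. A small sketch with torch.optim.AdamW illustrating the override:

import torch

groups = [
    {"params": [torch.nn.Parameter(torch.zeros(4))], "weight_decay": 0.01},
    {"params": [torch.nn.Parameter(torch.zeros(4))], "weight_decay": 0.0},
]
opt = torch.optim.AdamW(groups, lr=1e-5)  # no optimizer-level weight_decay needed
assert [g["weight_decay"] for g in opt.param_groups] == [0.01, 0.0]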