|
|
@ -1,6 +1,7 @@ |
|
|
|
import argparse |
|
|
|
import argparse |
|
|
|
import os |
|
|
|
import os |
|
|
|
import json |
|
|
|
import json |
|
|
|
|
|
|
|
from typing import Dict, Union, List |
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
import torch |
|
|
|
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO |
|
|
|
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO |
|
|
@ -8,7 +9,7 @@ from colossal_moe.models.mixtral_layer import replace_moe_layer |
|
|
|
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy |
|
|
|
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy |
|
|
|
from colossal_moe.utils import load_checkpoint, load_model, move_to_cuda, save_checkpoint |
|
|
|
from colossal_moe.utils import load_checkpoint, load_model, move_to_cuda, save_checkpoint |
|
|
|
from tqdm import tqdm |
|
|
|
from tqdm import tqdm |
|
|
|
from transformers import AutoTokenizer |
|
|
|
from transformers import AutoTokenizer, PreTrainedModel |
|
|
|
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM |
|
|
|
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM |
|
|
|
import torch.distributed as dist |
|
|
|
import torch.distributed as dist |
|
|
|
|
|
|
|
|
|
|
@ -29,6 +30,35 @@ from colossal_moe.dataset.loader import ( |
|
|
|
DataCollatorForSupervisedDataset, |
|
|
|
DataCollatorForSupervisedDataset, |
|
|
|
) |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_optimizer_grouped_parameters(model: Union[torch.nn.Module, PreTrainedModel], |
|
|
|
|
|
|
|
weight_decay: float = 0.0, |
|
|
|
|
|
|
|
disable_decay_names: List[str] = None |
|
|
|
|
|
|
|
) -> List[Dict[str, Union[List[torch.nn.Parameter], float]]]: |
|
|
|
|
|
|
|
if disable_decay_names is None: |
|
|
|
|
|
|
|
disable_decay_names = ["bias", "LayerNorm.weight"] |
|
|
|
|
|
|
|
optimizer_grouped_parameters = [ |
|
|
|
|
|
|
|
{ |
|
|
|
|
|
|
|
"params": [ |
|
|
|
|
|
|
|
p for n, p in model.named_parameters() |
|
|
|
|
|
|
|
if p.requires_grad is True and not any( |
|
|
|
|
|
|
|
name.lower() in n.lower() for name in disable_decay_names |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
], |
|
|
|
|
|
|
|
"weight_decay": weight_decay |
|
|
|
|
|
|
|
}, |
|
|
|
|
|
|
|
{ |
|
|
|
|
|
|
|
"params": [ |
|
|
|
|
|
|
|
p for n, p in model.named_parameters() |
|
|
|
|
|
|
|
if p.requires_grad is True and any( |
|
|
|
|
|
|
|
name.lower() in n.lower() for name in disable_decay_names |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
], |
|
|
|
|
|
|
|
"weight_decay": 0.0 |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
] |
|
|
|
|
|
|
|
return optimizer_grouped_parameters |
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
|
|
@torch.no_grad() |
|
|
|
def get_global_loss(loss, booster): |
|
|
|
def get_global_loss(loss, booster): |
|
|
|
global_loss = loss.clone().detach() |
|
|
|
global_loss = loss.clone().detach() |
|
|
@ -84,7 +114,7 @@ def parse_args(): |
|
|
|
|
|
|
|
|
|
|
|
# optim |
|
|
|
# optim |
|
|
|
parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.") |
|
|
|
parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.") |
|
|
|
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay to use.") |
|
|
|
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to use.") |
|
|
|
|
|
|
|
|
|
|
|
# zero stage for all plugins |
|
|
|
# zero stage for all plugins |
|
|
|
parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.") |
|
|
|
parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.") |
|
|
@ -194,7 +224,8 @@ def main(): |
|
|
|
|
|
|
|
|
|
|
|
# Prepare tokenizer and dataloader |
|
|
|
# Prepare tokenizer and dataloader |
|
|
|
tokenizer = AutoTokenizer.from_pretrained(args.model_name) |
|
|
|
tokenizer = AutoTokenizer.from_pretrained(args.model_name) |
|
|
|
tokenizer.pad_token_id = tokenizer.eos_token_id |
|
|
|
tokenizer.padding_side = "right" |
|
|
|
|
|
|
|
tokenizer.pad_token = tokenizer.unk_token |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length) |
|
|
|
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length) |
|
|
@ -205,10 +236,11 @@ def main(): |
|
|
|
|
|
|
|
|
|
|
|
# Set optimizer |
|
|
|
# Set optimizer |
|
|
|
optimizer = HybridAdam( |
|
|
|
optimizer = HybridAdam( |
|
|
|
model_params=model.parameters(), |
|
|
|
model_params=get_optimizer_grouped_parameters( |
|
|
|
|
|
|
|
model=model, weight_decay=args.weight_decay |
|
|
|
|
|
|
|
), |
|
|
|
lr=args.lr, |
|
|
|
lr=args.lr, |
|
|
|
betas=(0.9, 0.95), |
|
|
|
betas=(0.9, 0.95), |
|
|
|
weight_decay=args.weight_decay, |
|
|
|
|
|
|
|
adamw_mode=True, |
|
|
|
adamw_mode=True, |
|
|
|
) |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|