|
|
|
@@ -10,6 +10,7 @@ from torch.utils.data import Dataset
 from tqdm import tqdm
 from transformers import AutoTokenizer
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 
 import colossalai
 from colossalai.booster import Booster
|
|
|
@@ -99,6 +100,10 @@ def parse_args():
     # optim
     parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
     parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
 
+    # lr scheduler
+    parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
+    parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
+
     # zero stage for all plugins
     parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.")
|
|
|
@@ -203,10 +208,25 @@ def main():
 
     # Set optimizer
     optimizer = HybridAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
 
+    # Set lr scheduler
+    lr_scheduler = CosineAnnealingWarmupLR(
+        optimizer=optimizer,
+        total_steps=args.num_epochs * len(dataloader),
+        warmup_steps=args.warmup_steps
+        if args.warmup_steps is not None
+        else int(args.num_epochs * len(dataloader) * 0.025),
+        eta_min=0.1 * args.lr,
+    )
+
     # Set booster
     booster = Booster(plugin=plugin, **booster_kwargs)
-    model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
+    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
+        model=model,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        dataloader=dataloader,
+    )
     use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
     is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
     pp_print_rank = is_pp_last_stage and (coordinator.local_rank == "0")
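For reference, a self-contained sketch of how CosineAnnealingWarmupLR behaves with the arguments used above: the learning rate ramps up over the warmup steps and then follows a cosine curve down toward eta_min. The toy model, optimizer, and step counts below are placeholders for illustration, not part of this change:

import torch
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR

# Toy setup just to drive the scheduler; all values are made up.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

total_steps = 1000  # plays the role of num_epochs * len(dataloader)
lr_scheduler = CosineAnnealingWarmupLR(
    optimizer=optimizer,
    total_steps=total_steps,
    warmup_steps=int(total_steps * 0.025),  # same 2.5% fallback as above
    eta_min=0.1 * 1e-5,                     # decay floor, matching the diff
)

for step in range(total_steps):
    optimizer.step()  # no gradients needed just to trace the schedule
    lr_scheduler.step()
    if step in (0, 24, 499, 999):
        print(step, optimizer.param_groups[0]["lr"])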
|
|
|
@@ -253,6 +273,7 @@ def main():
                     pbar.set_postfix({"loss": loss.item()})
 
                 optimizer.step()
+                lr_scheduler.step()
                 optimizer.zero_grad()
 
                 # Apply load balance
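Note that total_steps is counted in optimizer steps (num_epochs * len(dataloader)), so the scheduler is advanced once per batch, immediately after optimizer.step(), rather than once per epoch. A placeholder outline of that ordering (the names stand in for the boosted objects and this is not the script's actual loop body):

for epoch in range(num_epochs):
    for batch in dataloader:
        loss = compute_loss(batch)  # hypothetical helper for forward + loss
        loss.backward()
        optimizer.step()
        lr_scheduler.step()   # one scheduler tick per optimizer step
        optimizer.zero_grad()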
|
|
|
|