Browse Source

update

pull/5190/head
Xuanlei Zhao 11 months ago
parent
commit
0bb317d9e6
  1. 23
      applications/ColossalMoE/train.py

23
applications/ColossalMoE/train.py

@ -10,6 +10,7 @@ from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
import colossalai
from colossalai.booster import Booster
@ -99,6 +100,10 @@ def parse_args():
# optim
parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
# lr scheduler
parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
# zero stage for all plugins
parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.")
@ -203,10 +208,25 @@ def main():
# Set optimizer
optimizer = HybridAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
# Set lr scheduler
lr_scheduler = CosineAnnealingWarmupLR(
optimizer=optimizer,
total_steps=args.num_epochs * len(dataloader),
warmup_steps=args.warmup_steps
if args.warmup_steps is not None
else int(args.num_epochs * len(dataloader) * 0.025),
eta_min=0.1 * args.lr,
)
# Set booster
booster = Booster(plugin=plugin, **booster_kwargs)
model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
dataloader=dataloader,
)
use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
pp_print_rank = is_pp_last_stage and (coordinator.local_rank == "0")
@ -253,6 +273,7 @@ def main():
pbar.set_postfix({"loss": loss.item()})
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Apply load balance

Loading…
Cancel
Save