mirror of https://github.com/hpcaitech/ColossalAI
commit 0bb317d9e6 (parent ccad7014c6): update
@@ -10,6 +10,7 @@ from torch.utils.data import Dataset
 from tqdm import tqdm
 from transformers import AutoTokenizer
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 
 import colossalai
 from colossalai.booster import Booster
@@ -100,6 +101,10 @@ def parse_args():
     parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
     parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
 
+    # lr scheduler
+    parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
+    parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
+
     # zero stage for all plugins
     parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.")
     # hybrid plugin
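For context, a hedged invocation sketch using the new flags (the script name, process count, and the other flag values below are assumptions for illustration, not part of this diff):

    colossalai run --nproc_per_node 8 train.py --num_epochs 2 --warmup_steps 100 --lr 1e-5 --zero_stage 2

Leaving --warmup_steps unset defers to the fallback computed in main() below (2.5% of the total training steps).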
@@ -204,9 +209,24 @@ def main():
     # Set optimizer
     optimizer = HybridAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
 
+    # Set lr scheduler
+    lr_scheduler = CosineAnnealingWarmupLR(
+        optimizer=optimizer,
+        total_steps=args.num_epochs * len(dataloader),
+        warmup_steps=args.warmup_steps
+        if args.warmup_steps is not None
+        else int(args.num_epochs * len(dataloader) * 0.025),
+        eta_min=0.1 * args.lr,
+    )
+
     # Set booster
     booster = Booster(plugin=plugin, **booster_kwargs)
-    model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
+    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
+        model=model,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        dataloader=dataloader,
+    )
     use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
     is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
     pp_print_rank = is_pp_last_stage and (coordinator.local_rank == "0")
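A quick worked sketch of what the new defaults resolve to (the dataloader length here is an assumption for illustration, not taken from the diff):

    # Hedged example: values implied by the scheduler hunk above, assuming a 1,000-batch dataloader.
    steps_per_epoch = 1000                       # illustrative len(dataloader)
    num_epochs = 1                               # --num_epochs default
    lr = 1e-5                                    # --lr default

    total_steps = num_epochs * steps_per_epoch   # 1000 scheduler steps
    warmup_steps = int(total_steps * 0.025)      # 25 steps when --warmup_steps is unset
    eta_min = 0.1 * lr                           # 1e-06, the floor the cosine schedule anneals to

Passing lr_scheduler through booster.boost also means the fifth return value is now the (possibly wrapped) scheduler rather than being discarded.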
@@ -253,6 +273,7 @@ def main():
                 pbar.set_postfix({"loss": loss.item()})
 
                 optimizer.step()
+                lr_scheduler.step()
                 optimizer.zero_grad()
 
         # Apply load balance
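Note on ordering: lr_scheduler.step() is called once per optimization step, after optimizer.step() and before optimizer.zero_grad(), which matches total_steps = args.num_epochs * len(dataloader) above; stepping the scheduler once per epoch instead would advance the warmup/cosine schedule far too slowly and leave most of it unused.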