pull/5190/head
Xuanlei Zhao 2023-12-29 17:28:46 +08:00
parent ccad7014c6
commit 0bb317d9e6
1 changed file with 22 additions and 1 deletion

@@ -10,6 +10,7 @@ from torch.utils.data import Dataset
 from tqdm import tqdm
 from transformers import AutoTokenizer
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 
 import colossalai
 from colossalai.booster import Booster
@@ -100,6 +101,10 @@ def parse_args():
     parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
     parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
+    # lr scheduler
+    parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
+    parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
     # zero stage for all plugins
     parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.")
     # hybrid plugin
@@ -204,9 +209,24 @@ def main():
     # Set optimizer
     optimizer = HybridAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
 
+    # Set lr scheduler
+    lr_scheduler = CosineAnnealingWarmupLR(
+        optimizer=optimizer,
+        total_steps=args.num_epochs * len(dataloader),
+        warmup_steps=args.warmup_steps
+        if args.warmup_steps is not None
+        else int(args.num_epochs * len(dataloader) * 0.025),
+        eta_min=0.1 * args.lr,
+    )
+
     # Set booster
     booster = Booster(plugin=plugin, **booster_kwargs)
-    model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
+    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
+        model=model,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        dataloader=dataloader,
+    )
     use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
     is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
     pp_print_rank = is_pp_last_stage and (coordinator.local_rank == "0")
@@ -253,6 +273,7 @@ def main():
                 pbar.set_postfix({"loss": loss.item()})
                 optimizer.step()
+                lr_scheduler.step()
                 optimizer.zero_grad()
 
         # Apply load balance
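
For context, a minimal self-contained sketch of what the patch wires in: the warmup default (2.5% of all training steps when --warmup_steps is omitted) and the per-step call order optimizer.step() -> lr_scheduler.step() -> optimizer.zero_grad(). It uses torch.optim.lr_scheduler.CosineAnnealingLR as a stand-in for ColossalAI's CosineAnnealingWarmupLR; resolve_warmup_steps and the example numbers are illustrative, not part of the patch.

# Illustrative sketch (not from the patch); CosineAnnealingLR stands in for
# colossalai.nn.lr_scheduler.CosineAnnealingWarmupLR.
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR


def resolve_warmup_steps(num_epochs, steps_per_epoch, warmup_steps=None):
    # Mirrors the patch's default: warm up for 2.5% of all training steps
    # when --warmup_steps is not given, e.g. 1 epoch * 1000 steps -> 25.
    total_steps = num_epochs * steps_per_epoch
    return warmup_steps if warmup_steps is not None else int(total_steps * 0.025)


lr = 1e-5
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.0)
total_steps = 1 * 1000                      # num_epochs * len(dataloader)
warmup = resolve_warmup_steps(1, 1000)      # -> 25
scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=0.1 * lr)

for _ in range(3):                          # same per-step order as the training loop above
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    optimizer.step()
    scheduler.step()                        # advance the lr schedule once per optimizer step
    optimizer.zero_grad()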