|
|
|
@ -10,6 +10,7 @@ from torch.utils.data import Dataset
|
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
from transformers import AutoTokenizer
|
|
|
|
|
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
|
|
|
|
|
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
|
|
|
|
|
|
|
|
|
import colossalai
|
|
|
|
|
from colossalai.booster import Booster
|
|
|
|
@ -100,6 +101,10 @@ def parse_args():
|
|
|
|
|
parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
|
|
|
|
|
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
|
|
|
|
|
|
|
|
|
|
# lr scheduler
|
|
|
|
|
parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
|
|
|
|
|
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
|
|
|
|
|
|
|
|
|
|
# zero stage for all plugins
|
|
|
|
|
parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.")
|
|
|
|
|
# hybrid plugin
|
|
|
|
@ -204,9 +209,24 @@ def main():
|
|
|
|
|
# Set optimizer
|
|
|
|
|
optimizer = HybridAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
|
|
|
|
|
|
|
|
|
# Set lr scheduler
|
|
|
|
|
lr_scheduler = CosineAnnealingWarmupLR(
|
|
|
|
|
optimizer=optimizer,
|
|
|
|
|
total_steps=args.num_epochs * len(dataloader),
|
|
|
|
|
warmup_steps=args.warmup_steps
|
|
|
|
|
if args.warmup_steps is not None
|
|
|
|
|
else int(args.num_epochs * len(dataloader) * 0.025),
|
|
|
|
|
eta_min=0.1 * args.lr,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Set booster
|
|
|
|
|
booster = Booster(plugin=plugin, **booster_kwargs)
|
|
|
|
|
model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
|
|
|
|
|
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
|
|
|
|
|
model=model,
|
|
|
|
|
optimizer=optimizer,
|
|
|
|
|
lr_scheduler=lr_scheduler,
|
|
|
|
|
dataloader=dataloader,
|
|
|
|
|
)
|
|
|
|
|
use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
|
|
|
|
|
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
|
|
|
|
|
pp_print_rank = is_pp_last_stage and (coordinator.local_rank == "0")
|
|
|
|
@ -253,6 +273,7 @@ def main():
|
|
|
|
|
pbar.set_postfix({"loss": loss.item()})
|
|
|
|
|
|
|
|
|
|
optimizer.step()
|
|
|
|
|
lr_scheduler.step()
|
|
|
|
|
optimizer.zero_grad()
|
|
|
|
|
|
|
|
|
|
# Apply load balance
|
|
|
|
|