import argparse
import contextlib
import os

import torch
import torch.nn as nn
from dataset.webtext import WebtextDataset
from titans.model.gpt import GPTLMLoss

import colossalai
import colossalai.utils as utils
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.trainer import Trainer, hooks
from colossalai.legacy.zero.init_ctx import ZeroInitContext
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn import LinearWarmupLR
from colossalai.utils import is_using_pp
from colossalai.utils.timer import MultiTimer


def calc_local_model_size(model: torch.nn.Module) -> int:
    """Return the total number of parameter elements in ``model``.

    Only the parameters visible through ``model.parameters()`` on the calling
    process are counted.
    """
    return sum(p.numel() for p in model.parameters())


VOCAB_SIZE = 50257


def _parse_args() -> argparse.Namespace:
    """Parse the command line options read by :func:`main`.

    ``--config`` and ``--host`` are registered here because ``main`` reads
    ``args.config`` and ``args.host``; without them a plain
    ``argparse.ArgumentParser`` yields a namespace lacking those attributes
    and the script fails with ``AttributeError`` before launching.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="path to the config file")
    parser.add_argument("--host", type=str, default=None, help="master address (used for the SLURM launch)")
    parser.add_argument("--from_torch", default=False, action="store_true")
    parser.add_argument("--use_dummy_dataset", default=False, action="store_true")
    args = parser.parse_args()

    # The host is only forwarded to launch_from_slurm, so only demand it there.
    if not args.from_torch and args.host is None:
        parser.error("--host is required unless --from_torch is given")
    return args


def main():
    """Launch the distributed context, build the training objects and run the trainer.

    The dataset, model, loss, optimizer and schedule are all driven by
    ``gpc.config`` (loaded from ``--config``). The training data path is read
    from the ``DATA`` environment variable unless ``--use_dummy_dataset`` is
    given, in which case ``None`` is passed as the dataset path.

    Raises:
        KeyError: if ``DATA`` is unset and ``--use_dummy_dataset`` was not given.
    """
    args = _parse_args()
    disable_existing_loggers()

    if args.from_torch:
        colossalai.launch_from_torch(config=args.config)
    else:
        colossalai.launch_from_slurm(config=args.config, host=args.host, port=29500, seed=42)

    logger = get_dist_logger()

    data_path = None if args.use_dummy_dataset else os.environ["DATA"]
    logger.info(f"Build data loader from path {data_path}", ranks=[0])

    train_ds = WebtextDataset(path=data_path, seq_len=gpc.config.SEQ_LEN)
    train_dataloader = utils.get_dataloader(
        train_ds, seed=42, batch_size=gpc.config.BATCH_SIZE, pin_memory=True, shuffle=True, drop_last=True
    )

    logger.info("Build model", ranks=[0])
    use_pipeline = is_using_pp()
    # Must be evaluated before "type" is popped from the model config below.
    use_interleaved = hasattr(gpc.config.model, "num_chunks")
    use_zero3 = hasattr(gpc.config, "zero")

    ctx = contextlib.nullcontext()
    if use_zero3:
        ctx = ZeroInitContext(
            target_device=torch.cuda.current_device(),
            shard_strategy=gpc.config.zero.model_config.shard_strategy,
            shard_param=True,
        )
    with ctx:
        # The "type" entry is the model factory; the remaining entries are its kwargs.
        model_factory = gpc.config.model.pop("type")
        model = model_factory(**gpc.config.model)

    if use_pipeline and use_interleaved and not isinstance(model, nn.ModuleList):
        model = nn.ModuleList([model])

    if use_zero3:
        # Under the zero init context the parameter count is taken from the
        # context's own counter rather than from the local module.
        numel = ctx.model_numel_tensor.item()
    else:
        numel = calc_local_model_size(model)

    # Per-step work estimate handed to ThroughputHook, scaled by 1024**4.
    # NOTE(review): the factor 8 presumably accounts for forward + backward
    # (+ recomputation) passes -- confirm.
    tflop = (
        numel
        * gpc.config.BATCH_SIZE
        * gpc.config.SEQ_LEN
        * gpc.get_world_size(ParallelMode.MODEL)
        * gpc.get_world_size(ParallelMode.DATA)
        * 8
        / (1024**4)
    )

    # Use the loss configured as `loss_fn` when present, otherwise the GPT LM loss.
    criterion = getattr(gpc.config, "loss_fn", None)
    if criterion is not None:
        criterion = criterion.type()
    else:
        criterion = GPTLMLoss()

    logger.info("Build optimizer", ranks=[0])
    optimizer = gpc.config.optimizer.pop("type")(model.parameters(), **gpc.config.optimizer)

    # The scheduler is stepped once per epoch (see LRSchedulerHook(by_epoch=True)
    # below), hence total_steps is the epoch count.
    lr_scheduler = LinearWarmupLR(optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=5)

    # NOTE(review): this file imports from colossalai.legacy but calls the
    # top-level colossalai.initialize -- confirm it exists in the installed version.
    engine, train_dataloader, _, lr_scheduler = colossalai.initialize(
        model, optimizer, criterion, train_dataloader=train_dataloader, lr_scheduler=lr_scheduler
    )

    global_batch_size = (
        gpc.config.BATCH_SIZE * gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1)
    )
    logger.info(f"Init done, global batch size = {global_batch_size}", ranks=[0])

    timer = MultiTimer()
    trainer = Trainer(engine=engine, logger=logger, timer=timer)

    hook_list = [
        hooks.LossHook(),
        hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True),
        hooks.LogMetricByEpochHook(logger),
        hooks.ThroughputHook(ignored_steps=10, tflop_per_step=tflop),
        hooks.LogMetricByStepHook(),
        hooks.LogMemoryByEpochHook(logger),
        # Optional: enable to also log per-epoch timing.
        # hooks.LogTimingByEpochHook(timer, logger),
    ]

    trainer.fit(
        train_dataloader=train_dataloader,
        epochs=gpc.config.NUM_EPOCHS,
        test_interval=1,
        hooks=hook_list,
        display_progress=True,
        return_output_label=False,
    )


if __name__ == "__main__":
    main()