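"""Pretrain a GPT language model on the Webtext dataset with Colossal-AI's legacy Trainer.

The run is driven by a Colossal-AI config file (model, optimizer, batch size, sequence
length, and optional ZeRO / pipeline settings) and can be launched either via torchrun
(--from_torch) or via SLURM.
"""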
import argparse
import contextlib
import os

import torch
import torch.nn as nn
from dataset.webtext import WebtextDataset
from titans.model.gpt import GPTLMLoss

import colossalai
import colossalai.utils as utils
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.trainer import Trainer, hooks
from colossalai.legacy.zero.init_ctx import ZeroInitContext
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn import LinearWarmupLR
from colossalai.utils import is_using_pp
from colossalai.utils.timer import MultiTimer


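# Count the parameter elements held by this rank (the locally materialized part of the model).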
def calc_local_model_size(model: torch.nn.Module):
    numel_per_device = 0
    for p in model.parameters():
        numel_per_device += p.numel()
    return numel_per_device


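# GPT-2 BPE vocabulary size.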
VOCAB_SIZE = 50257


def main():
    parser = argparse.ArgumentParser()
    # --config and --host are consumed by the colossalai.launch_* calls below.
    parser.add_argument("--config", type=str, required=True, help="path to the Colossal-AI config file")
    parser.add_argument("--host", type=str, default=None, help="master node address (SLURM launch only)")
    parser.add_argument("--from_torch", default=False, action="store_true")
    parser.add_argument("--use_dummy_dataset", default=False, action="store_true")
    args = parser.parse_args()
    disable_existing_loggers()
    if args.from_torch:
        colossalai.launch_from_torch(config=args.config)
    else:
        colossalai.launch_from_slurm(config=args.config, host=args.host, port=29500, seed=42)
    logger = get_dist_logger()

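    # Read the Webtext path from the DATA environment variable unless a dummy dataset
    # was requested via --use_dummy_dataset.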
    data_path = None if args.use_dummy_dataset else os.environ["DATA"]
    logger.info(f"Build data loader from path {data_path}", ranks=[0])

    train_ds = WebtextDataset(path=data_path, seq_len=gpc.config.SEQ_LEN)
    train_dataloader = utils.get_dataloader(
        train_ds, seed=42, batch_size=gpc.config.BATCH_SIZE, pin_memory=True, shuffle=True, drop_last=True
    )

    logger.info("Build model", ranks=[0])
    use_pipeline = is_using_pp()
    use_interleaved = hasattr(gpc.config.model, "num_chunks")
    use_zero3 = hasattr(gpc.config, "zero")
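    # If a ZeRO config is present, build the model inside ZeroInitContext so parameters
    # are sharded across devices as they are created; otherwise use a no-op context.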
    ctx = contextlib.nullcontext()
    if use_zero3:
        ctx = ZeroInitContext(
            target_device=torch.cuda.current_device(),
            shard_strategy=gpc.config.zero.model_config.shard_strategy,
            shard_param=True,
        )
    with ctx:
        model = gpc.config.model.pop("type")(**gpc.config.model)
    # Interleaved pipeline schedules expect the model chunks to be wrapped in an nn.ModuleList.
    if use_pipeline and use_interleaved and not isinstance(model, nn.ModuleList):
        model = nn.ModuleList([model])

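    # Parameter count used for the throughput estimate: ZeroInitContext tracks it during
    # sharded initialization; otherwise count the locally held parameters directly.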
    if use_zero3:
        numel = ctx.model_numel_tensor.item()
    else:
        numel = calc_local_model_size(model)

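    # Rough TFLOPs per step: about 8 FLOPs per parameter per token (forward, backward,
    # and, presumably, activation recomputation), scaled to all tokens and parameters
    # across the model- and data-parallel groups.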
    tflop = (
        numel
        * gpc.config.BATCH_SIZE
        * gpc.config.SEQ_LEN
        * gpc.get_world_size(ParallelMode.MODEL)
        * gpc.get_world_size(ParallelMode.DATA)
        * 8
        / (1024**4)
    )

    # Use the loss function from the config if one is given, otherwise fall back to the
    # standard GPT language-modeling loss.
    criterion = getattr(gpc.config, "loss_fn", None)
    if criterion is not None:
        criterion = criterion.type()
    else:
        criterion = GPTLMLoss()
    logger.info("Build optimizer", ranks=[0])
    optimizer = gpc.config.optimizer.pop("type")(model.parameters(), **gpc.config.optimizer)
    lr_scheduler = LinearWarmupLR(optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=5)
    engine, train_dataloader, _, lr_scheduler = colossalai.initialize(
        model, optimizer, criterion, train_dataloader=train_dataloader, lr_scheduler=lr_scheduler
    )
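    # Effective global batch size = per-rank batch size * data-parallel world size
    # * gradient accumulation steps (if configured).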
    global_batch_size = (
        gpc.config.BATCH_SIZE * gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1)
    )
    logger.info(f"Init done, global batch size = {global_batch_size}", ranks=[0])
    timer = MultiTimer()
    trainer = Trainer(engine=engine, logger=logger, timer=timer)
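    # Trainer hooks: loss tracking, LR scheduling, metric and memory logging, and
    # throughput measurement (ignoring the first 10 warm-up steps).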
    hook_list = [
        hooks.LossHook(),
        hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True),
        hooks.LogMetricByEpochHook(logger),
        hooks.ThroughputHook(ignored_steps=10, tflop_per_step=tflop),
        hooks.LogMetricByStepHook(),
        hooks.LogMemoryByEpochHook(logger),
        # hooks.LogMemoryByEpochHook(logger),
        # hooks.LogTimingByEpochHook(timer, logger),
    ]
    trainer.fit(
        train_dataloader=train_dataloader,
        epochs=gpc.config.NUM_EPOCHS,
        test_interval=1,
        hooks=hook_list,
        display_progress=True,
        return_output_label=False,
    )


if __name__ == "__main__":
    main()