ColossalAI/examples/language/gpt/titans/train_gpt.py

122 lines
4.2 KiB
Python
Raw Normal View History

import argparse
2023-01-16 07:55:41 +00:00
import contextlib
import os
import torch
import torch.nn as nn
from dataset.webtext import WebtextDataset
2023-01-16 07:55:41 +00:00
from titans.model.gpt import GPTLMLoss
import colossalai
import colossalai.utils as utils
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.trainer import Trainer, hooks
from colossalai.legacy.zero.init_ctx import ZeroInitContext
2023-01-16 07:55:41 +00:00
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn import LinearWarmupLR
from colossalai.utils import is_using_pp
2023-01-16 07:55:41 +00:00
from colossalai.utils.timer import MultiTimer
def calc_local_model_size(model: torch.nn.Module):
    """Return the total number of elements over all parameters of ``model``.

    Args:
        model: module whose ``parameters()`` are counted.

    Returns:
        The sum of ``numel()`` over every parameter yielded by
        ``model.parameters()``; ``0`` when the module has no parameters.
    """
    return sum(param.numel() for param in model.parameters())
# Vocabulary size constant; it is not referenced by any code visible in this file.
# NOTE(review): 50257 presumably matches the GPT-2 tokenizer vocabulary - confirm.
VOCAB_SIZE = 50257
def main():
    """Train a GPT model with the legacy ColossalAI ``Trainer``.

    The function performs, in order:

    1. Parse the command-line flags and launch the distributed runtime, via
       ``colossalai.launch_from_torch`` when ``--from_torch`` is given and via
       ``colossalai.launch_from_slurm`` otherwise.
    2. Build the ``WebtextDataset`` data loader. The dataset path is read from
       the ``DATA`` environment variable unless ``--use_dummy_dataset`` is set,
       in which case ``None`` is passed as the path.
    3. Instantiate the model described by ``gpc.config.model``, inside a
       ``ZeroInitContext`` when the config has a ``zero`` attribute.
    4. Build the loss, optimizer and LR scheduler, hand them to
       ``colossalai.initialize`` and run ``Trainer.fit`` with logging hooks.

    Command-line flags:
        --config: config file forwarded to the ``colossalai.launch_*`` call.
        --host: host forwarded to ``colossalai.launch_from_slurm``.
        --from_torch: launch from torch instead of SLURM.
        --use_dummy_dataset: do not read the ``DATA`` environment variable.

    Raises:
        KeyError: if ``DATA`` is not set and ``--use_dummy_dataset`` is absent.
    """
    parser = argparse.ArgumentParser()
    # ``args.config`` and ``args.host`` are read below, so both options must be
    # declared; otherwise argparse never creates the attributes and the launch
    # call fails with AttributeError before training starts.
    parser.add_argument("--config", type=str, help="path to the config file")
    parser.add_argument("--host", type=str, help="host passed to colossalai.launch_from_slurm")
    parser.add_argument("--from_torch", default=False, action="store_true")
    parser.add_argument("--use_dummy_dataset", default=False, action="store_true")
    args = parser.parse_args()

    disable_existing_loggers()
    if args.from_torch:
        colossalai.launch_from_torch(config=args.config)
    else:
        colossalai.launch_from_slurm(config=args.config, host=args.host, port=29500, seed=42)
    logger = get_dist_logger()

    # With the dummy dataset no path is needed; otherwise DATA must be set.
    data_path = None if args.use_dummy_dataset else os.environ["DATA"]
    logger.info(f"Build data loader from path {data_path}", ranks=[0])

    train_ds = WebtextDataset(path=data_path, seq_len=gpc.config.SEQ_LEN)
    train_dataloader = utils.get_dataloader(
        train_ds, seed=42, batch_size=gpc.config.BATCH_SIZE, pin_memory=True, shuffle=True, drop_last=True
    )

    logger.info("Build model", ranks=[0])
    use_pipeline = is_using_pp()
    use_interleaved = hasattr(gpc.config.model, "num_chunks")
    use_zero3 = hasattr(gpc.config, "zero")

    # Only wrap model construction in ZeroInitContext when a zero config exists.
    ctx = contextlib.nullcontext()
    if use_zero3:
        ctx = ZeroInitContext(
            target_device=torch.cuda.current_device(),
            shard_strategy=gpc.config.zero.model_config.shard_strategy,
            shard_param=True,
        )

    with ctx:
        # NOTE: pop() removes "type" from the shared config, so the remaining
        # entries can be forwarded as keyword arguments to the model builder.
        model = gpc.config.model.pop("type")(**gpc.config.model)

    if use_pipeline and use_interleaved and not isinstance(model, nn.ModuleList):
        model = nn.ModuleList([model])

    # Under ZeroInitContext the parameter count is taken from the context;
    # otherwise it is counted directly from the model's parameters.
    if use_zero3:
        numel = ctx.model_numel_tensor.item()
    else:
        numel = calc_local_model_size(model)

    # Work estimate per step handed to ThroughputHook.
    # NOTE(review): the factor 8 is presumably a FLOPs-per-parameter-per-token
    # multiplier and 1024**4 the scaling to "tera" - confirm.
    tflop = (
        numel
        * gpc.config.BATCH_SIZE
        * gpc.config.SEQ_LEN
        * gpc.get_world_size(ParallelMode.MODEL)
        * gpc.get_world_size(ParallelMode.DATA)
        * 8
        / (1024**4)
    )

    # Use the loss configured as ``loss_fn`` if present, else the GPT LM loss.
    criterion = getattr(gpc.config, "loss_fn", None)
    if criterion is not None:
        criterion = criterion.type()
    else:
        criterion = GPTLMLoss()

    logger.info("Build optimizer", ranks=[0])
    optimizer = gpc.config.optimizer.pop("type")(model.parameters(), **gpc.config.optimizer)

    # The scheduler is stepped per epoch (see LRSchedulerHook below), hence
    # total_steps is the number of epochs.
    lr_scheduler = LinearWarmupLR(optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=5)

    engine, train_dataloader, _, lr_scheduler = colossalai.initialize(
        model, optimizer, criterion, train_dataloader=train_dataloader, lr_scheduler=lr_scheduler
    )

    global_batch_size = (
        gpc.config.BATCH_SIZE * gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1)
    )
    logger.info(f"Init done, global batch size = {global_batch_size}", ranks=[0])

    timer = MultiTimer()
    trainer = Trainer(engine=engine, logger=logger, timer=timer)

    hook_list = [
        hooks.LossHook(),
        hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True),
        hooks.LogMetricByEpochHook(logger),
        hooks.ThroughputHook(ignored_steps=10, tflop_per_step=tflop),
        hooks.LogMetricByStepHook(),
        hooks.LogMemoryByEpochHook(logger),
    ]

    trainer.fit(
        train_dataloader=train_dataloader,
        epochs=gpc.config.NUM_EPOCHS,
        test_interval=1,
        hooks=hook_list,
        display_progress=True,
        return_output_label=False,
    )
2023-01-16 07:55:41 +00:00
# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()