InternLM/internlm/initialize/initialize_trainer.py

85 lines
3.6 KiB
Python
Raw Normal View History

2023-07-06 04:55:23 +00:00
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/initialize
from typing import Callable, Iterable, Optional, Tuple
from torch import nn
from torch.nn.modules.loss import _Loss
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from internlm.core.context import global_context as gpc
from internlm.core.engine import Engine
from internlm.core.gradient_handler import PipelineSharedModuleGradientHandler
from internlm.core.no_pipeline_scheduler import NonPipelineScheduler
from internlm.core.trainer import Trainer
from internlm.solver.beta2_scheduler import Beta2Scheduler
from internlm.solver.optimizer.hybrid_zero_optim import BaseOptimizer
from internlm.utils.common import get_current_device
def initialize_trainer(
model: nn.Module,
optimizer: Optimizer,
criterion: Optional[_Loss] = None,
train_dataloader: Optional[Iterable] = None,
test_dataloader: Optional[Iterable] = None,
lr_scheduler: Optional[_LRScheduler] = None,
beta2_scheduler: Optional[Beta2Scheduler] = None,
) -> Tuple[Trainer, DataLoader, DataLoader, _LRScheduler]:
"""Core function to wrap the essential training components with our functionality based on the config which is
loaded into gpc.config.
Args:
model (:class:`torch.nn.Module` or Callbale): Your model instance or a function to build the model.
optimizer (:class:`BaseOptimizer`.
criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance.
train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training.
test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing.
lr_scheduler (:class:`torch.nn.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance, optional.
Returns:
Tuple (trainer, train_dataloader, test_dataloader, lr_scheduler):
A tuple of ``(trainer, train_dataloader, test_dataloader, lr_scheduler)``
where only ``trainer`` could not be None.
"""
if isinstance(model, nn.Module):
# first sync model across dp ranks
model.to(get_current_device())
elif isinstance(model, Callable):
model = model().to(get_current_device())
# clip grad norm
clip_grad_norm = gpc.config.hybrid_zero_optimizer.get("clip_grad_norm", 0.0)
assert isinstance(optimizer, BaseOptimizer), "optimizer must be instance of BaseOptimizer"
# gradient handler, only support PipelineSharedModuleGradientHandler now
gradient_handler_cfg = gpc.config.get("gradient_handler", [])
gradient_handlers = []
assert isinstance(gradient_handler_cfg, list), f"gradient_handler must be list but got {type(gradient_handler_cfg)}"
for config in gradient_handler_cfg:
if isinstance(config, dict) and config.get("type") == "PipelineSharedModuleGradientHandler":
handler = PipelineSharedModuleGradientHandler(model=model, optimizer=optimizer)
gradient_handlers.append(handler)
scheduler = NonPipelineScheduler(gradient_accumulation_size=gpc.config.data.gradient_accumulation)
engine = Engine(
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
beta2_scheduler=beta2_scheduler,
criterion=criterion,
gradient_handlers=gradient_handlers,
clip_grad_norm=clip_grad_norm,
)
trainer = Trainer(engine, scheduler)
return trainer, train_dataloader, test_dataloader, lr_scheduler