mirror of https://github.com/InternLM/InternLM
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/initialize

from typing import Callable, Iterable, Optional, Tuple

from torch import nn
from torch.nn.modules.loss import _Loss
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

from internlm.core.context import global_context as gpc
from internlm.core.engine import Engine
from internlm.core.gradient_handler import PipelineSharedModuleGradientHandler
from internlm.core.no_pipeline_scheduler import NonPipelineScheduler
from internlm.core.trainer import Trainer
from internlm.solver.beta2_scheduler import Beta2Scheduler
from internlm.solver.optimizer.hybrid_zero_optim import BaseOptimizer
from internlm.utils.common import get_current_device


def initialize_trainer(
    model: nn.Module,
    optimizer: Optimizer,
    criterion: Optional[_Loss] = None,
    train_dataloader: Optional[Iterable] = None,
    test_dataloader: Optional[Iterable] = None,
    lr_scheduler: Optional[_LRScheduler] = None,
    beta2_scheduler: Optional[Beta2Scheduler] = None,
) -> Tuple[Trainer, DataLoader, DataLoader, _LRScheduler]:
"""Core function to wrap the essential training components with our functionality based on the config which is
|
||
|
loaded into gpc.config.
|
||
|
|
||
|
Args:
|
||
|
model (:class:`torch.nn.Module` or Callbale): Your model instance or a function to build the model.
|
||
|
optimizer (:class:`BaseOptimizer`.
|
||
|
criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance.
|
||
|
train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training.
|
||
|
test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing.
|
||
|
lr_scheduler (:class:`torch.nn.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance, optional.
|
||
|
|
||
|
Returns:
|
||
|
Tuple (trainer, train_dataloader, test_dataloader, lr_scheduler):
|
||
|
A tuple of ``(trainer, train_dataloader, test_dataloader, lr_scheduler)``
|
||
|
where only ``trainer`` could not be None.
|
||
|
"""
|
||
|
|
||
|
    if isinstance(model, nn.Module):
        # first sync model across dp ranks
        model.to(get_current_device())
    elif isinstance(model, Callable):
        model = model().to(get_current_device())

    # clip grad norm
    clip_grad_norm = gpc.config.hybrid_zero_optimizer.get("clip_grad_norm", 0.0)

    assert isinstance(optimizer, BaseOptimizer), "optimizer must be an instance of BaseOptimizer"

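    # NOTE: gradient handlers perform extra gradient reductions after the backward pass;
    # PipelineSharedModuleGradientHandler all-reduces gradients of modules shared across
    # pipeline stages (e.g. tied embedding weights).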
    # gradient handler; only PipelineSharedModuleGradientHandler is supported for now
    gradient_handler_cfg = gpc.config.get("gradient_handler", [])
    gradient_handlers = []
    assert isinstance(gradient_handler_cfg, list), f"gradient_handler must be a list but got {type(gradient_handler_cfg)}"
    for config in gradient_handler_cfg:
        if isinstance(config, dict) and config.get("type") == "PipelineSharedModuleGradientHandler":
            handler = PipelineSharedModuleGradientHandler(model=model, optimizer=optimizer)
            gradient_handlers.append(handler)

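    # NOTE: without pipeline parallelism, NonPipelineScheduler drives the forward/backward
    # loop and accumulates gradients over `gradient_accumulation` micro-batches.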
    scheduler = NonPipelineScheduler(gradient_accumulation_size=gpc.config.data.gradient_accumulation)

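    # NOTE: the Engine bundles the model, optimizer, criterion, schedulers and gradient
    # handlers into a single object and carries the clip_grad_norm threshold used for
    # gradient clipping.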
    engine = Engine(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        beta2_scheduler=beta2_scheduler,
        criterion=criterion,
        gradient_handlers=gradient_handlers,
        clip_grad_norm=clip_grad_norm,
    )

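    # NOTE: the Trainer couples the engine with the scheduler and exposes the high-level
    # train/eval step API.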
    trainer = Trainer(engine, scheduler)

    return trainer, train_dataloader, test_dataloader, lr_scheduler
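

# Illustrative usage: a minimal sketch, assuming a config has already been loaded into
# ``gpc.config`` and that ``build_model``, ``build_optimizer``, ``train_loader`` and
# ``test_loader`` are hypothetical placeholders defined elsewhere (``build_optimizer``
# must return a BaseOptimizer instance):
#
#     model = build_model()              # an nn.Module, or pass the builder callable itself
#     optimizer = build_optimizer(model)
#     trainer, train_dl, test_dl, lr_sched = initialize_trainer(
#         model=model,
#         optimizer=optimizer,
#         criterion=nn.CrossEntropyLoss(),
#         train_dataloader=train_loader,
#         test_dataloader=test_loader,
#     )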