# mirror of https://github.com/hpcaitech/ColossalAI
import warnings
from types import MethodType
from typing import Callable, Optional, OrderedDict, Tuple

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader

from colossalai.booster.plugin.hybrid_parallel_plugin import (
    HybridParallelAMPOptimizer,
    HybridParallelModule,
    HybridParallelNaiveOptimizer,
    HybridParallelPlugin,
    get_param_info,
    reinitialize_optimizer,
)
from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.cluster.process_group_mesh import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.zero.low_level import LowLevelZeroOptimizer


class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
    def __init__(
        self,
        optimizer: Optimizer,
        model: Module,
        use_pipeline: bool,
        force_overlap_comm: bool,  # force overlap comm
        dp_process_group: ProcessGroup,  # dp pg for comm
        moe_dp_group: ProcessGroup,  # moe dp pg for comm
        param_info: OrderedDict,
        initial_scale: int = 2**16,  # grad scaler config
        min_scale: int = 1,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000,
        hysteresis: int = 2,
        max_scale: int = 2**24,
        clip_grad_norm: float = 0.0,  # grad clipping
        verbose: bool = False,
        reduce_bucket_size: int = 1024 * 1024,  # communication
        communication_dtype: Optional[torch.dtype] = None,
        overlap_communication: bool = False,
        partition_grad: bool = False,  # stage 2 flag
        cpu_offload: bool = False,  # cpu offload
        forced_dtype: Optional[torch.dtype] = None,
    ):
        WARN_STR = (
            "Note that you need to make sure every expert is routed (i.e. every expert gets a backward pass), "
            "otherwise this might lead to a program hang or inconsistent results."
        )
        if not force_overlap_comm and (overlap_communication or partition_grad):
            raise RuntimeError(
                WARN_STR
                + " If you are not sure about this, set (overlap_communication=False and partition_grad=False) "
                "or force_overlap_comm=True."
            )

        if force_overlap_comm:
            overlap_communication = True
            warnings.warn(WARN_STR + " Please make sure of this.")

        self.param_info = param_info
        self.stage_manager = model.stage_manager
        self.shared_params = model.shared_params
        self.dp_pg = dp_process_group

        if use_pipeline:
            reinitialize_optimizer(optimizer, model)

        pg_param_list = {
            dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
            moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
        }
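        # NOTE: the dict above gives ZeRO two independent shard groups: non-MoE parameters are
        # sharded and reduced over the plain data-parallel group, while expert parameters
        # (identified by `is_moe_tensor`) stay within `moe_dp_group`. As a rough illustration with
        # hypothetical sizes, for world_size=8, ep_size=2 and moe_tp_size=1 an expert weight would
        # be sharded over the 4 ranks of its moe_dp_group rather than over all 8 ranks.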

        super().__init__(
            optimizer=optimizer,
            pg_to_param_list=pg_param_list,
            initial_scale=initial_scale,
            min_scale=min_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            hysteresis=hysteresis,
            max_scale=max_scale,
            clip_grad_norm=clip_grad_norm,
            verbose=verbose,
            reduce_bucket_size=reduce_bucket_size,
            communication_dtype=communication_dtype,
            overlap_communication=overlap_communication,
            partition_grad=partition_grad,
            cpu_offload=cpu_offload,
            forced_dtype=forced_dtype,
        )


class MoeHybridParallelPlugin(HybridParallelPlugin):
    """Plugin for MoE hybrid parallel training, extending HybridParallelPlugin with expert parallelism.

    Args:
        ep_size (int): The size of the expert parallel group.
        moe_tp_size (int, optional): Tensor parallel size for the MoE layers. Only 1 is currently
            supported. Defaults to 1.
        force_overlap_comm (bool, optional): Force communication overlap in the MoE ZeRO optimizer even
            though not every expert is guaranteed to receive gradients. Defaults to False.

    All remaining arguments are forwarded to HybridParallelPlugin.
    """

    def __init__(self, ep_size: int, moe_tp_size: int = 1, force_overlap_comm=False, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
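        # HybridParallelPlugin.__init__ has already set up the attributes used below
        # (e.g. self.dp_size, self.pp_size, self.zero_stage, self.ddp_config, self.shard_config,
        # self.logger), driven by whatever tp_size / pp_size / zero_stage / precision arguments
        # were forwarded through *args / **kwargs.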

        self.use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
        if self.use_ddp:
            warnings.warn(
                "Will have to check whether all params are used in PyTorch DDP, since not all experts are always activated"
            )
            self.ddp_config["find_unused_parameters"] = True

        if moe_tp_size != 1:
            raise NotImplementedError

        world_size = dist.get_world_size()

        self.moe_dp_size = world_size // (ep_size * moe_tp_size)
        self.ep_size = ep_size
        self.moe_tp_size = moe_tp_size

        self.moe_pg_mesh = ProcessGroupMesh(self.moe_dp_size, self.ep_size, self.moe_tp_size)
        self.moe_dp_axis, self.ep_axis, self.moe_tp_axis = 0, 1, 2
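        # Illustration with hypothetical sizes: for world_size=8, ep_size=2 and moe_tp_size=1 the
        # MoE mesh is (moe_dp_size, ep_size, moe_tp_size) = (4, 2, 1): axis 0 is MoE data
        # parallelism, axis 1 is expert parallelism, axis 2 is MoE tensor parallelism.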

        self.moe_dp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_dp_axis)
        self.ep_group = self.moe_pg_mesh.get_group_along_axis(self.ep_axis)
        self.moe_tp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_tp_axis)

        self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}", ranks=[0])

        # set ep_group after super init
        # TODO do it in a better way
        self.shard_config.ep_group = self.ep_group

        self.force_overlap_comm = force_overlap_comm

    def get_checkpoint_io(self) -> MoECheckpointIO:
        return MoECheckpointIO(
            self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
        )

    def configure(
        self,
        model: Module,
        optimizer: Optional[Optimizer] = None,
        criterion: Optional[Callable] = None,
        dataloader: Optional[DataLoader] = None,
        lr_scheduler: Optional[LRScheduler] = None,
    ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
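        # Overall flow: (1) shard and wrap the raw module into a HybridParallelModule,
        # (2) rebuild the optimizer's param groups if expert parallelism re-sharded parameters,
        # (3) wrap the optimizer (AMP, naive, or MoE-aware ZeRO) depending on precision and zero_stage.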
        param_info = get_param_info(optimizer)
        if not isinstance(model, ModelWrapper):
            model = HybridParallelModule(
                module=model,
                precision=self.precision,
                shard_config=self.shard_config,
                dp_group=self.dp_group,
                tp_group=self.tp_group,
                sp_group=self.sp_group,
                use_ddp=self.use_ddp,
                ddp_config=self.ddp_config,
                custom_policy=self.custom_policy,
            )
        if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
            if self.ep_size > 1:
                # if ep is enabled, the number of (moe) parameters changes since they are sharded among ep groups,
                # but the optimizer is not aware of ep, so we need to update the optimizer
                reinitialize_optimizer(optimizer, model)

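            # Choose the optimizer wrapper: without ZeRO, fall back to the standard hybrid-parallel
            # AMP / naive optimizers; with ZeRO (stage 1 or 2), use the MoE-aware ZeRO optimizer so
            # that expert parameters are sharded over the MoE data-parallel group.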
            if self.zero_stage == 0:
                if self.precision in ["fp16", "bf16"]:
                    optimizer = HybridParallelAMPOptimizer(
                        optimizer,
                        model,
                        use_pipeline=self.enable_pipeline_parallelism,
                        param_info=param_info,
                        precision=self.precision,
                        max_norm=self.max_norm,
                        **self.amp_config,
                    )
                else:
                    optimizer = HybridParallelNaiveOptimizer(
                        optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
                    )
            else:
                if not (self.dp_size > 1 or self.moe_dp_size > 1):
                    warnings.warn(
                        "Using the ZeRO optimizer when the data parallel size is 1 may introduce unnecessary overhead. "
                        "If you do not intend to use cpu_offload, please consider setting zero_stage=0."
                    )
                optimizer = MoeHybridParallelZeroOptimizer(
                    optimizer,
                    model,
                    use_pipeline=self.enable_pipeline_parallelism,
                    force_overlap_comm=self.force_overlap_comm,
                    param_info=param_info,
                    dp_process_group=self.dp_group,
                    moe_dp_group=self.moe_dp_group,
                    verbose=True,
                    clip_grad_norm=self.max_norm,
                    **self.zero_config,
                    **self.amp_config,
                )
            # inject update_master_params
            model.update_master_params = MethodType(optimizer.update_master_params, model)

        return model, optimizer, criterion, dataloader, lr_scheduler
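

# Usage sketch (hedged; assumes a launched distributed environment and placeholder `model`,
# `optimizer`, `criterion` and `dataset` objects; the argument values are illustrative only):
#
#   from colossalai.booster import Booster
#
#   plugin = MoeHybridParallelPlugin(ep_size=2, tp_size=1, pp_size=1, zero_stage=1, precision="bf16")
#   booster = Booster(plugin=plugin)
#   dataloader = plugin.prepare_dataloader(dataset, batch_size=8, shuffle=True)
#   model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)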