mirror of https://github.com/hpcaitech/ColossalAI
[zero] solve hang
parent 0fad23c691
commit 46c069b0db
@@ -1058,17 +1058,7 @@ class HybridParallelPlugin(PipelinePluginBase):
        self.enable_jit_fused = enable_jit_fused
        self.enable_sequence_parallelism = enable_sequence_parallelism
        if dp_outside:
            (
                self.dp_axis,
                self.pp_axis,
                self.tp_axis,
                self.sp_axis,
            ) = (
                0,
                1,
                2,
                3,
            )
            self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
            self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
        else:
            self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
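The hunk above keeps the dp-outside axis order (dp, pp, tp, sp) and collapses the tuple assignment into a single line. As a side note, the sketch below (not part of the commit; toy sizes, plain NumPy instead of colossalai.cluster.ProcessGroupMesh) shows how such an axis order decides which ranks land in the same group along each axis.

```python
# Illustrative sketch only, with assumed toy sizes: a 4-D mesh laid out as (dp, pp, tp, sp).
# Ranks that differ only along one axis form the group for that axis; inner axes vary
# fastest, so placing an axis further inside keeps its groups on neighbouring ranks.
import numpy as np

dp_size, pp_size, tp_size, sp_size = 2, 2, 2, 1
mesh = np.arange(dp_size * pp_size * tp_size * sp_size).reshape(dp_size, pp_size, tp_size, sp_size)

tp_groups = [mesh[d, p, :, s].tolist() for d in range(dp_size) for p in range(pp_size) for s in range(sp_size)]
dp_groups = [mesh[:, p, t, s].tolist() for p in range(pp_size) for t in range(tp_size) for s in range(sp_size)]
print(tp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]] -- contiguous ranks
print(dp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]] -- strided ranks
```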
@@ -1,9 +1,7 @@
import random
import warnings
from types import MethodType
from typing import Callable, Optional, OrderedDict, Tuple

import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
@@ -11,7 +9,6 @@ from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from colossalai.booster.plugin.hybrid_parallel_plugin import (
    HybridParallelAMPOptimizer,
@@ -22,13 +19,8 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
    reinitialize_optimizer,
)
from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.cluster.process_group_mesh import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.zero.low_level import LowLevelZeroOptimizer
@@ -39,6 +31,8 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
        optimizer: Optimizer,
        model: Module,
        use_pipeline: bool,
        dp_process_group: ProcessGroup, # the dp pg for comm
        moe_dp_group: ProcessGroup, # the moe dp pg for gomm
        param_info: OrderedDict,
        initial_scale: int = 2**16, # grad scaler config
        min_scale: int = 1,
@@ -54,30 +48,20 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
        overlap_communication: bool = True,
        partition_grad: bool = False, # stage 2 flag
        cpu_offload: bool = False, # cpu offload
        dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
        tp_process_group: Optional[ProcessGroup] = None, # if using tp
        pp_process_group: Optional[ProcessGroup] = None,
        forced_dtype: Optional[torch.dtype] = None,
        moe_extra_dp_process_group: Optional[ProcessGroup] = None,
    ):
        self.param_info = param_info
        self.stage_manager = model.stage_manager
        self.shared_params = model.shared_params
        self.dp_pg = dp_process_group
        self.tp_pg = tp_process_group
        self.pp_pg = pp_process_group

        if use_pipeline:
            reinitialize_optimizer(optimizer, model)

        pg_param_list = {
            dp_process_group: [],
            moe_extra_dp_process_group: [],
            dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
            moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
        }
        for param in model.parameters():
            if is_moe_tensor(param):
                pg_param_list[moe_extra_dp_process_group].append(param)
            else:
                pg_param_list[dp_process_group].append(param)

        super().__init__(
            optimizer=optimizer,
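The key change in the hunk above is that the ZeRO optimizer is now handed an explicit mapping from each process group to the parameters it reduces: non-MoE parameters go to the plain data-parallel group, expert parameters to the smaller moe-dp group. A minimal sketch of that grouping, assuming a model whose expert weights are recognised by colossalai.tensor.moe_tensor.api.is_moe_tensor and two already-created process groups:

```python
# Sketch only: mirrors the pg_param_list construction above. `model`, `dp_group` and
# `moe_dp_group` are assumed to exist; is_moe_tensor is the predicate imported in this file.
from colossalai.tensor.moe_tensor.api import is_moe_tensor

def build_pg_param_list(model, dp_group, moe_dp_group):
    return {
        dp_group: [p for p in model.parameters() if not is_moe_tensor(p)],
        moe_dp_group: [p for p in model.parameters() if is_moe_tensor(p)],
    }
```

Keeping the same two keys on every rank matters: if some rank dropped one of the groups, the gradient reductions issued over that group by its peers would never be matched, which is presumably the kind of mismatch the commit title refers to.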
@@ -102,285 +86,43 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):

class MoeHybridParallelPlugin(HybridParallelPlugin):
    """
    Plugin for Moe Hybrid Parallel Training.
    Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin.
    The size of tp and pp should be passed in by user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size).

    Example:
        >>> from colossalai.booster import Booster
        >>> from colossalai.booster.plugin import HybridParallelPlugin

        >>> model, train_dataset, optimizer, criterion = ...
        >>> plugin = HybridParallelPlugin(tp_size=2, pp_size=2)

        >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
        >>> booster = Booster(plugin=plugin)
        >>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)

    Args:
        pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
        tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
        precision (str, optional): Specifies the precision of parameters during training.
            Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'.
            Defaults to 'fp16'.
        zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be choosed from [0, 1, 2].
            When set to 0, ZeRO will not be used. Defaults to 0.
        enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
            Currently all the optimization methods include fused normalization, flash attention and JIT.
            Defaults to False.
        enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False.
        enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
        enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
        enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
        enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
        num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
        microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
            Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
            If ``num_microbatches`` is provided, this will be ignored. Defaults to None.
        initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
        min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
        growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
        backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5.
        growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000.
        hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2.
        max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32.
        max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0.
        broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training when using DDP. Defaults to True.
        ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25.
        find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False.
        check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False.
        gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False.
        static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False.
        zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12.
        cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
        communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
        overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
        use_ep_inside (bool, Optional): Whether to use ep inside dp (intra-node) for moe params.
    TODO: add docstring
    """

    def __init__(
        self,
        pp_size: int,
        ep_size: int,
        tp_size: int = 1,
        sp_size: int = 1,
        precision: str = "fp16",
        zero_stage: int = 0,
        enable_all_optimization: bool = False,
        enable_fused_normalization: bool = False,
        enable_flash_attention: bool = False,
        enable_jit_fused: bool = False,
        enable_sequence_parallelism: bool = False,
        enable_sequence_overlap: bool = False,
        num_microbatches: Optional[int] = None,
        microbatch_size: Optional[int] = None,
        initial_scale: float = 2**16,
        min_scale: float = 1,
        growth_factor: float = 2,
        backoff_factor: float = 0.5,
        growth_interval: int = 1000,
        hysteresis: int = 2,
        max_scale: float = 2**32,
        max_norm: float = 0,
        broadcast_buffers: bool = True,
        ddp_bucket_cap_mb: int = 25,
        find_unused_parameters: bool = False,
        check_reduction: bool = False,
        gradient_as_bucket_view: bool = False,
        static_graph: bool = False,
        zero_bucket_size_in_m: int = 12,
        cpu_offload: bool = False,
        communication_dtype: Optional[torch.dtype] = None,
        overlap_communication: bool = True,
        use_ep_inside: bool = True,
        custom_policy: Policy = None,
        checkpoint_io: Optional[MoECheckpointIO] = None,
    ) -> None:
        world_size = dist.get_world_size()
        assert tp_size == 1, "Tensor parallel is not supported in MoE yet"
        assert sp_size == 1 and enable_sequence_parallelism is False, "Sequence parallelism it not supported in MoE yet"
    def __init__(self, ep_size: int, ep_tp_size: int = 1, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        assert (
            world_size % (tp_size * pp_size) == 0
        ), f"world size {world_size} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
        assert (
            world_size % (tp_size * pp_size * ep_size) == 0
        ), f"world size {world_size} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}"

        self.dp_size = world_size // (tp_size * pp_size)
        self.tp_size = tp_size
        self.pp_size = pp_size
        self.ep_size = ep_size
        self.sp_size = sp_size
        self.precision = precision
        self.zero_stage = zero_stage
        self.cpu_offload = cpu_offload
        self.enable_all_optimization = enable_all_optimization
        self.enable_fused_normalization = enable_fused_normalization
        self.enable_flash_attention = enable_flash_attention
        self.enable_jit_fused = enable_jit_fused
        self.enable_sequence_parallelism = enable_sequence_parallelism
        self.checkpoint_io = checkpoint_io

        logger = get_dist_logger()

        # NOTE: Two process meshes: global dp for non-moe param; dp + ep for moe param
        # See https://hpc-ai.com/blog/enhanced-moe-parallelism-open-source-moe-model-training-can-be-9-times-more-efficient
        # we change pg mesh to (pp, dp, tp) for better moe performance
        assert (
            self.ep_size <= self.dp_size
        ), f"Not enough devices({self.dp_size}) for expert parallelism size({self.ep_size})."

        self.moe_dp_size = self.dp_size // self.ep_size
        self.use_ep_inside = use_ep_inside
        if self.use_ep_inside:
            logger.info(f"MoE Parallel use ep inside dp.", ranks=[0])
            self.pp_axis, self.dp_axis, self.ep_axis, self.tp_axis = 0, 1, 2, 3
            self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, ep_size, tp_size)
        else:
            logger.info(f"MoE Parallel use ep outside dp.", ranks=[0])
            warnings.warn("Using ep outside dp (cross-node) is strongly discouraged due to communication costs.")
            self.pp_axis, self.dp_axis, self.ep_axis, self.tp_axis = 0, 2, 1, 3
            self.pg_mesh = ProcessGroupMesh(self.pp_size, ep_size, self.moe_dp_size, tp_size)

        self.moe_dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
        self.ep_group = self.pg_mesh.get_group_along_axis(self.ep_axis)
        logger.info(f"Non-MoE Parameter Parallel: pp {self.pp_size}, dp {self.dp_size}, tp {tp_size}", ranks=[0])
        logger.info(
            f"MoE Parallel: pp {self.pp_size}, ep {ep_size}, moe dp {self.moe_dp_size}, tp {tp_size}", ranks=[0]
        )

        self.tp_group = self.pg_mesh.get_group_along_axis(
            self.tp_axis
        ) # TODO: support custom tp size for mixtral lm head
        self.global_dp_group = self.pg_mesh.get_group_along_axis((self.dp_axis, self.ep_axis))
        self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis)
        # TODO: Currently moe only support partially sequence parallel
        self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)

        self.custom_policy = custom_policy
        self.stage_manager = None
        self.schedule = None

        assert zero_stage in (0, 1, 2)
        if self.pp_size > 1:
            assert (
                num_microbatches is not None or microbatch_size is not None
            ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
            assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
            self.stage_manager = PipelineStageManager(self.pg_mesh, self.pp_axis)
            self.schedule = OneForwardOneBackwardSchedule(
                self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
        self.use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
        if self.use_ddp:
            warnings.warn(
                f"Will have to check all params are used in pytorch DDP since not all experts are always activated"
            )
            self.ddp_config["find_unused_parameters"] = True

        self.shard_config = ShardConfig(
            tensor_parallel_process_group=self.tp_group,
            pipeline_stage_manager=self.stage_manager,
            enable_tensor_parallelism=self.tp_size > 1,
            enable_all_optimization=self.enable_all_optimization,
            enable_fused_normalization=self.enable_fused_normalization,
            enable_flash_attention=self.enable_flash_attention,
            enable_jit_fused=self.enable_jit_fused,
            enable_sequence_parallelism=enable_sequence_parallelism,
            enable_sequence_overlap=enable_sequence_overlap,
            ep_group=self.ep_group,
        )
        self.amp_config = dict(
            initial_scale=initial_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            hysteresis=hysteresis,
            min_scale=min_scale,
            max_scale=max_scale,
        )
        if ep_tp_size != 1:
            raise NotImplementedError

        self.ddp_config = dict(
            broadcast_buffers=broadcast_buffers,
            bucket_cap_mb=ddp_bucket_cap_mb,
            find_unused_parameters=find_unused_parameters,
            check_reduction=check_reduction,
            gradient_as_bucket_view=gradient_as_bucket_view,
            static_graph=static_graph,
        )
        world_size = dist.get_world_size()

        self.zero_config = dict(
            reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
            communication_dtype=communication_dtype,
            overlap_communication=overlap_communication,
            cpu_offload=cpu_offload,
            partition_grad=(self.zero_stage == 2),
        )
        self.moe_dp_size = world_size // (ep_size * ep_tp_size)
        self.ep_size = ep_size
        self.moe_tp_size = ep_tp_size

        self.max_norm = max_norm
        self.moe_pg_mesh = ProcessGroupMesh(self.moe_dp_size, self.ep_size, self.moe_tp_size)
        self.moe_dp_axis, self.ep_axis, self.moe_tp_axis = 0, 1, 2

    def prepare_dataloader(
        self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
    ):
        r"""
        Prepare a dataloader for distributed training. The dataloader will be wrapped by
        `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
        self.moe_dp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_dp_axis)
        self.ep_group = self.moe_pg_mesh.get_group_along_axis(self.ep_axis)
        self.moe_tp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_tp_axis)


        Args:
            dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
            shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
            seed (int, optional): Random worker seed for sampling, defaults to 1024.
            add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
            drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
                is not divisible by the batch size. If False and the size of dataset is not divisible by
                the batch size, then the last batch will be smaller, defaults to False.
            pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
            num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
            kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
                `DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.

        Returns:
            :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
        """
        _kwargs = kwargs.copy()
        sampler = DistributedSampler(
            dataset,
            num_replicas=self.dp_size,
            rank=dist.get_rank(self.global_dp_group),
            shuffle=shuffle,
        )

        # Deterministic dataloader
        def seed_worker(worker_id):
            worker_seed = seed
            np.random.seed(worker_seed)
            torch.manual_seed(worker_seed)
            random.seed(worker_seed)

        return DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            worker_init_fn=seed_worker,
            drop_last=drop_last,
            pin_memory=pin_memory,
            num_workers=num_workers,
            **_kwargs,
        )
        # set ep_group after super init
        # TODO do it in a better way
        self.shard_config.ep_group = self.ep_group

    def get_checkpoint_io(self) -> MoECheckpointIO:
        if self.checkpoint_io is None:
            self.checkpoint_io = MoECheckpointIO(
                self.global_dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
            )
        else:
            self.checkpoint_io = self.checkpoint_io(
                self.global_dp_group,
                self.pp_group,
                self.tp_group,
                ep_group=self.ep_group,
                moe_dp_group=self.moe_dp_group,
                zero_stage=self.zero_stage,
            )
            if hasattr(self.checkpoint_io, "moe_info"):
                self.checkpoint_io.moe_info = self.moe_info
        return self.checkpoint_io
        return MoECheckpointIO(
            self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
        )

    def configure(
        self,
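The NOTE in the hunk above describes two process meshes: the full data-parallel group for non-MoE parameters and a smaller (moe dp, ep, moe tp) mesh for expert parameters. A plain-arithmetic sketch of the size bookkeeping (values are illustrative, not taken from the commit):

```python
# Size bookkeeping implied by the hunk above, with assumed toy values.
world_size = 8
tp_size, pp_size = 1, 1
ep_size, ep_tp_size = 4, 1

dp_size = world_size // (tp_size * pp_size)         # 8: replicas of every non-MoE parameter
moe_dp_size = world_size // (ep_size * ep_tp_size)  # 2: replicas of each expert shard

assert world_size % (ep_size * ep_tp_size) == 0
print(dp_size, moe_dp_size)  # 8 2
```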
@@ -392,15 +134,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
    ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
        param_info = get_param_info(optimizer)
        if not isinstance(model, ModelWrapper):
            use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
            model = HybridParallelModule(
                module=model,
                precision=self.precision,
                shard_config=self.shard_config,
                dp_group=self.global_dp_group,
                dp_group=self.dp_group,
                tp_group=self.tp_group,
                sp_group=self.sp_group,
                use_ddp=use_ddp, # TODO fix why this failed
                use_ddp=self.use_ddp,
                ddp_config=self.ddp_config,
                custom_policy=self.custom_policy,
            )
@@ -411,8 +152,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
            reinitialize_optimizer(optimizer, model)

        if self.zero_stage == 0:
            # assert self.ep_size > 1

            if self.precision in ["fp16", "bf16"]:
                optimizer = HybridParallelAMPOptimizer(
                    optimizer,
@@ -435,10 +174,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                    model,
                    use_pipeline=self.enable_pipeline_parallelism,
                    param_info=param_info,
                    dp_process_group=self.global_dp_group,
                    tp_process_group=self.tp_group,
                    pp_process_group=self.pp_group,
                    moe_extra_dp_process_group=self.moe_dp_group,
                    dp_process_group=self.dp_group,
                    moe_dp_group=self.moe_dp_group,
                    verbose=True,
                    clip_grad_norm=self.max_norm,
                    **self.zero_config,
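For reference, a hedged end-to-end usage sketch of the plugin after this change, pieced together from the docstring example and the test further below; the import path and keyword combination are assumptions, and model/dataset/optimizer/criterion are placeholders.

```python
# Hedged sketch, not a verified recipe: ep_size/ep_tp_size come from the new __init__
# above, the remaining keywords are forwarded to HybridParallelPlugin via **kwargs.
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin  # path assumed

plugin = MoeHybridParallelPlugin(ep_size=2, tp_size=1, pp_size=1, zero_stage=1, precision="fp16")
booster = Booster(plugin=plugin)
train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
```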
@@ -137,7 +137,7 @@ class ProcessGroupMesh:
        assert mode in ["raise", "wrap", "clip"]
        return int(np.ravel_multi_index(coord, shape, mode))

    def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup:
    def _get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup:
        """Get the process group with the given ranks. It the process group doesn't exist, it will be created.

        Args:
@@ -240,7 +240,7 @@ class ProcessGroupMesh:
        for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
            coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
            ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
            group = self.get_group(ranks_in_group, backend=backend)
            group = self._get_group(ranks_in_group, backend=backend)
            if self._rank in ranks_in_group:
                target_group = group
        return target_group
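Renaming get_group to _get_group signals that group creation is an internal, collective step: every rank must call torch.distributed.new_group for every group in the mesh, even the ones it does not belong to, and the result should be cached. A standalone sketch of that pattern (an assumed helper, not the library code):

```python
# Assumed helper illustrating the cache-and-create-on-all-ranks pattern behind _get_group.
from typing import Dict, Tuple
import torch.distributed as dist

_group_cache: Dict[Tuple[int, ...], dist.ProcessGroup] = {}

def get_or_create_group(ranks: Tuple[int, ...]) -> dist.ProcessGroup:
    if ranks not in _group_cache:
        # new_group() is collective: all ranks must execute it with the same ranks list.
        _group_cache[ranks] = dist.new_group(list(ranks))
    return _group_cache[ranks]
```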
@@ -393,4 +393,7 @@ def all_to_all_uneven(
    group=None,
    overlap: bool = False,
):
    assert (
        inputs.requires_grad
    ), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
    return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)
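The new assertion is one of the actual hang fixes: AllToAllUneven issues a collective in its backward, so if `inputs` does not require grad on some rank, autograd never runs that backward there and the remaining ranks block forever waiting for the backward all-to-all. A hedged usage sketch (split sizes are illustrative and the import path is assumed):

```python
# Hedged usage sketch for the guard above: the tensor handed to all_to_all_uneven must
# participate in autograd on every rank so the backward collective is actually issued.
import torch
from colossalai.moe._operation import all_to_all_uneven  # import path assumed

hidden = torch.randn(6, 16, device="cuda", requires_grad=True)
out = all_to_all_uneven(hidden, [2, 4], [3, 3])  # input/output split sizes are illustrative
out.sum().backward()  # every rank reaches AllToAllUneven.backward
```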
@@ -101,20 +101,18 @@ class MixtralPolicy(Policy):
        # )

        if getattr(self.shard_config, "ep_group", None) is None:
            raise ValueError("You must pass in ep_group via shard_config for expert parallel!")

        # expert parallel
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="block_sparse_moe",
                    target_module=EPMixtralSparseMoeBlock,
                    kwargs={"ep_group": self.shard_config.ep_group},
                )
            ],
            policy=policy,
            target_key=MixtralDecoderLayer,
        )
        # expert parallel
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="block_sparse_moe",
                    target_module=EPMixtralSparseMoeBlock,
                    kwargs={"ep_group": self.shard_config.ep_group},
                )
            ],
            policy=policy,
            target_key=MixtralDecoderLayer,
        )

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
@@ -144,6 +142,7 @@ class MixtralPolicy(Policy):

        if self.shard_config.enable_flash_attention:
            warnings.warn("Flash attention is natively supported in transformers, will ignore the flag.")
            self.shard_config.enable_flash_attention = False

        return policy
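The new ValueError above is the policy-side half of a contract set up in the plugin hunk earlier (`self.shard_config.ep_group = self.ep_group` after `super().__init__`): expert-parallel sharding is refused unless the shard config carries an expert-parallel group. A tiny sketch of the guard pattern with hypothetical names:

```python
# Hypothetical names; only the guard pattern mirrors the policy change above.
from dataclasses import dataclass
from typing import Optional
import torch.distributed as dist

@dataclass
class ToyShardConfig:
    ep_group: Optional[dist.ProcessGroup] = None  # filled in by the plugin before sharding

def require_ep_group(cfg: ToyShardConfig) -> dist.ProcessGroup:
    if cfg.ep_group is None:
        raise ValueError("You must pass in ep_group via shard_config for expert parallel!")
    return cfg.ep_group
```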
@@ -100,7 +100,7 @@ class BucketStore(BaseStore):

        return self._grad_in_bucket

    def get_flatten_grad(self) -> Tensor:
    def get_flatten_grad(self, dtype=None) -> Tensor:
        """Return the flattened gradients slices in the bucket, the data organization of the flattened tensor:
        [grad0_rank0, grad1_rank0, ..., grad_0_rank1, grad1_rank1, ....]

@@ -110,8 +110,12 @@ class BucketStore(BaseStore):

        flat_grad = []
        for grad_list in self._grad_in_bucket.values():
            flat_grad.append(_flatten_dense_tensors(grad_list))
        flat_grad = _flatten_dense_tensors(flat_grad)
            if len(grad_list) > 0:
                flat_grad.append(_flatten_dense_tensors(grad_list))
        if len(flat_grad) > 0:
            flat_grad = _flatten_dense_tensors(flat_grad)
        else:
            flat_grad = torch.tensor([], device=self.comm_stream.device, dtype=dtype)
        return flat_grad

    def get_param_id_of_grad(self, grad: Tensor) -> int:
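With expert parallelism a rank can reach a reduction step while one of its buckets holds no gradients at all (none of its local experts received tokens); `_flatten_dense_tensors` on an empty list fails, and skipping the rank instead would desynchronise the collective. The new branch returns an empty tensor of the requested dtype so the rank still participates. A standalone sketch of the same guard (names and device handling simplified):

```python
# Standalone sketch of the guard added above; grads_per_rank mimics _grad_in_bucket.
from typing import Dict, List
import torch
from torch._utils import _flatten_dense_tensors

def flatten_bucket(grads_per_rank: Dict[int, List[torch.Tensor]], dtype: torch.dtype) -> torch.Tensor:
    chunks = [_flatten_dense_tensors(g) for g in grads_per_rank.values() if len(g) > 0]
    if chunks:
        return _flatten_dense_tensors(chunks)
    return torch.tensor([], dtype=dtype)  # empty, but keeps this rank in the reduce call
```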
@@ -91,7 +91,7 @@ class GradientStore(BaseStore):

        return grad_list

    def get_working_grad_by_param_id(self, param_id) -> Tensor:
    def get_working_grad_by_param_id(self, param_id) -> Optional[Tensor]:
        """
        Return the working gradient for the specified parameter.
@@ -301,12 +301,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

    def _run_reduction(self):
        for bucket_store in self.pg_to_bucket_store.values():
            if bucket_store.num_elements_in_bucket() <= 0:
                continue

            bucket_store.build_grad_in_bucket()

            flat_grads = bucket_store.get_flatten_grad()
            flat_grads = bucket_store.get_flatten_grad(self._dtype)
            flat_grads /= bucket_store.world_size

            # ready to add other tensors to bucket
@@ -353,6 +350,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
        self, bucket_store: BucketStore, origin_grad_list: List, flat_grad_list: List, group_id: int
    ) -> None:
        for rank, grad_list in enumerate(origin_grad_list):
            if len(grad_list) == 0:
                continue
            sync_tensor(flat_grad_list[rank], grad_list)
            for grad in grad_list:
                param_id = bucket_store.get_param_id_of_grad(grad)
@@ -869,12 +868,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

    def get_param_grad(self, working_param: nn.Parameter) -> Tensor:
        grad_store = self.pid_to_grad_store[id(working_param)]
        partial_grad = grad_store.get_working_grad_by_param_id(id(working_param))
        if partial_grad is None:
        grad = grad_store.get_working_grad_by_param_id(id(working_param))
        if grad is None:
            return None
        tensor_list = [torch.empty_like(partial_grad) for _ in range(grad_store.world_size)]
        dist.all_gather(tensor_list, partial_grad, group=grad_store.torch_pg)
        grad_flat = torch.cat(tensor_list, dim=0)
        grad_flat = torch.empty((grad_store.world_size, *grad.shape), dtype=grad.dtype, device=grad.device)
        dist.all_gather_into_tensor(grad_flat, grad, group=grad_store.torch_pg)
        return grad_flat[: working_param.numel()].reshape_as(working_param)

    def get_working_grads_by_group_id(self, group_id: int) -> List[Tensor]:
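Two details above tie back to the commit title. First, _run_reduction no longer skips process groups whose bucket is empty, so every rank issues the same sequence of reductions. Second, get_param_grad now gathers into a preallocated buffer with all_gather_into_tensor instead of a Python list plus cat. A hedged sketch contrasting the two gather styles (both are real torch.distributed calls; `group` and `shard` are placeholders):

```python
# Sketch of the gather change above; requires an initialized process group.
import torch
import torch.distributed as dist

def gather_shard(shard: torch.Tensor, group) -> torch.Tensor:
    world_size = dist.get_world_size(group)

    # old style: list of per-rank tensors, then concatenate
    pieces = [torch.empty_like(shard) for _ in range(world_size)]
    dist.all_gather(pieces, shard, group=group)
    flat_old = torch.cat(pieces, dim=0)

    # new style: gather straight into one preallocated buffer
    flat_new = torch.empty((world_size, *shard.shape), dtype=shard.dtype, device=shard.device)
    dist.all_gather_into_tensor(flat_new, shard, group=group)

    assert torch.equal(flat_old, flat_new.reshape_as(flat_old))
    return flat_new.flatten()
```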
@@ -19,7 +19,7 @@ def data_gen():
    # tokenized_input = tokenizer([input], return_tensors="pt")
    # input_ids = tokenized_input['input_ids']
    # attention_mask = tokenized_input['attention_mask']
    input_ids = torch.tensor([[1, 1984, 16020, 2076, 2487, 349, 21375, 4749]], dtype=torch.int64)
    input_ids = torch.tensor([[1, 22, 55, 77, 532, 349, 43, 22]], dtype=torch.int64)
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
    return dict(input_ids=input_ids, attention_mask=attention_mask)

@@ -43,7 +43,7 @@ def data_gen_for_sequence_classification():
output_transform_fn = lambda x: x

# define loss function
loss_fn_for_mixtral_model = lambda x: torch.nn.functional.mse_loss(x[0], torch.ones_like(x[0]))
loss_fn_for_mixtral_model = lambda x: x[0].mean()
loss_fn = lambda x: x.loss
loss_fn_for_seq_classification = lambda output: output.logits.mean()

@@ -52,7 +52,7 @@ config = MixtralConfig(
    intermediate_size=256,
    num_attention_heads=64,
    num_hidden_layers=2,
    vocab_size=50258,
    vocab_size=1000,
    output_router_logits=True,
)
@@ -141,7 +141,6 @@ def check_moe_checkpoint(test_config):
    if dist.get_rank() == 0:
        saved_model = model_cls.from_pretrained(model_dir).cuda()
        check_model_equal(orig_model, saved_model)
        # check_model_equal(model, saved_model)
        saved_model.save_pretrained(hf_model_dir)
    dist.barrier()
    # check load model
@@ -31,16 +31,17 @@ def split_grad(grad, world_size):
    return splited_grad


@parameterize("dtype", [torch.float16, torch.bfloat16])
@parameterize("master_weights", [True, False])
@parameterize("stage", [1, 2])
def run_zero_with_original_model(world_size, master_weights: bool, dtype: torch.dtype, stage: int):
@parameterize("ep_size", [1, 2, 4])
def run_zero_with_original_model(stage: int, ep_size: int):
    dtype = torch.float16

    rank = torch.distributed.get_rank()
    torch.cuda.set_device(dist.get_rank())
    plugin = MoeHybridParallelPlugin(
        tp_size=1,
        pp_size=1,
        ep_size=dist.get_world_size() // 2,
        ep_size=ep_size,
    )

    seed_all(10086)
@@ -53,26 +54,30 @@ def run_zero_with_original_model(world_size, master_weights: bool, dtype: torch.

    orig_model = MixtralSparseMoeBlock(config).to(dtype).cuda()

    ori_model = DDP(orig_model.cuda(), static_graph=True).cuda()
    ori_model = DDP(
        orig_model.cuda(),
        process_group=plugin.dp_group,
        find_unused_parameters=True, # important for torch ddp, not all experts are routed
    ).cuda()

    zero_model = deepcopy(orig_model).to(dtype)
    zero_model = EPMixtralSparseMoeBlock.from_native_module(zero_model, ep_group=plugin.ep_group)

    zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
    pg_param_list = {plugin.global_dp_group: [], plugin.moe_dp_group: []}
    pg_param_list = {plugin.dp_group: [], plugin.moe_dp_group: []}
    for p in zero_model.parameters():
        if is_moe_tensor(p):
            pg_param_list[plugin.moe_dp_group].append(p)
        else:
            pg_param_list[plugin.global_dp_group].append(p)
            pg_param_list[plugin.dp_group].append(p)

    zero_optimizer = LowLevelZeroOptimizer(
        zero_optimizer,
        pg_to_param_list=pg_param_list,
        master_weights=master_weights,
        master_weights=False,
        initial_scale=1,
        overlap_communication=False,
        partition_grad=True,
        overlap_communication=True,
        partition_grad=stage == 2,
    )

    ori_optimizer = torch.optim.SGD(ori_model.parameters(), lr=1)
@@ -82,11 +87,11 @@ def run_zero_with_original_model(world_size, master_weights: bool, dtype: torch.

    for _ in range(2):
        # zero-dp forward
        input_data = torch.rand(1, tokens, hidden_size).cuda()
        zero_output, zero_logits = zero_model(input_data.to(dtype))
        input_data = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
        zero_output, _ = zero_model(input_data.to(dtype))

        # torch-ddp forward
        ori_output, ori_logits = ori_model(input_data.to(dtype))
        ori_output, _ = ori_model(input_data.to(dtype))
        loose_close(zero_output, ori_output, dtype=dtype)

        # zero-dp backward
@@ -115,14 +120,16 @@ def run_zero_with_original_model(world_size, master_weights: bool, dtype: torch.
    for n, p in zero_model.named_parameters():
        loose_close(p.data, name_to_p[n].data, dtype=dtype)

    print(f"{dist.get_rank()} test passed")


def run_dist(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_zero_with_original_model(world_size=world_size)
    run_zero_with_original_model()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2, 4])
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_moe_zero_model(world_size):
    spawn(run_dist, world_size)
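The DDP baseline in the test above is built with find_unused_parameters=True for the same reason the plugin forces it: with token routing, some experts may receive no tokens in a step, so their parameters produce no gradient and the default DDP reducer would wait on reductions that never arrive. A hedged sketch of that wrapping (moe_block and dp_group are placeholders):

```python
# Hedged sketch of the DDP baseline pattern used in the test above.
from torch.nn.parallel import DistributedDataParallel as DDP

ddp_model = DDP(
    moe_block.cuda(),             # any nn.Module whose experts may be skipped in a step
    process_group=dp_group,       # existing data-parallel process group
    find_unused_parameters=True,  # experts that routed no tokens still get marked ready
)
```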
@@ -25,13 +25,14 @@ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"


def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
    # TODO: SGD failed for full dp
    org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
        model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.SGD
    )

    org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
        org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
        model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.Adam
    )
    with torch.autograd.set_detect_anomaly(True):
        org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
            org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
        )

    stage_manager = booster.plugin.stage_manager
    tp_group = booster.plugin.tp_group
@@ -73,6 +74,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    grads_to_check.update(col_layer_grads)
    grads_to_check.update(row_layer_grads)

    # check grads
    check_all_grad_tensors(grads_to_check)

    # optimizer executes step
    org_optimizer.step()
    sharded_optimizer.step()
@@ -103,9 +107,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
        verbose=False,
    )

    # check grads
    check_all_grad_tensors(grads_to_check)

    torch.cuda.empty_cache()


@@ -114,37 +115,22 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    [
        {
            "tp_size": 1,
            "pp_size": 4,
            "pp_size": 2,
            "num_microbatches": 2,
            "ep_size": 1,
            "num_microbatches": 4,
            "zero_stage": 0,
            "enable_all_optimization": True,
            "use_lazy_init": False,
            "precision": "fp16",
            "initial_scale": 1,
        },
        # {
            "precision": "fp32",
        }, # pp + ep
        # {"tp_size": 1, "pp_size": 1, "ep_size": 1, "zero_stage": 1, "precision": "fp16"}, # full dp for moe and non-moe
        # { # moe_dp = 2, non_moe_dp = 4
        #     "tp_size": 1,
        #     "pp_size": 1,
        #     "ep_size": 4,
        #     "num_microbatches": 2,
        #     "ep_size": 2,
        #     "zero_stage": 1,
        #     "enable_all_optimization": True,
        #     "use_lazy_init": False,
        #     "precision": "fp16",
        #     "initial_scale": 1,
        # },
        # {
        #     "tp_size": 1,
        #     "pp_size": 1,
        #     "ep_size": 4,
        #     "num_microbatches": 2,
        #     "zero_stage": 2,
        #     "enable_all_optimization": True,
        #     "use_lazy_init": False,
        #     "precision": "fp16",
        #     "initial_scale": 1,
        # },
        # }, # moe_dp = 1, non_moe_dp = 4
        # {"tp_size": 1, "pp_size": 1, "ep_size": 4, "zero_stage": 1, "precision": "fp16"},
        # {"tp_size": 1, "pp_size": 1, "ep_size": 1, "zero_stage": 0, "precision": "fp32"}, # full dp for moe and non-moe
    ],
)
def run_mixtral_test(test_config):