
[zero] solve hang

colossalchat
hxwang authored 5 months ago, committed by Hongxin Liu
commit 46c069b0db
  1. colossalai/booster/plugin/hybrid_parallel_plugin.py (12 changed lines)
  2. colossalai/booster/plugin/moe_hybrid_parallel_plugin.py (333 changed lines)
  3. colossalai/cluster/process_group_mesh.py (4 changed lines)
  4. colossalai/moe/_operation.py (3 changed lines)
  5. colossalai/shardformer/policies/mixtral.py (27 changed lines)
  6. colossalai/zero/low_level/bookkeeping/bucket_store.py (10 changed lines)
  7. colossalai/zero/low_level/bookkeeping/gradient_store.py (2 changed lines)
  8. colossalai/zero/low_level/low_level_optim.py (16 changed lines)
  9. tests/kit/model_zoo/transformers/mixtral.py (6 changed lines)
  10. tests/test_moe/test_moe_checkpoint.py (1 changed line)
  11. tests/test_moe/test_moe_zero_fwd_bwd_optim.py (37 changed lines)
  12. tests/test_shardformer/test_model/test_shard_mixtral.py (52 changed lines)

colossalai/booster/plugin/hybrid_parallel_plugin.py (12 changed lines)

@@ -1058,17 +1058,7 @@ class HybridParallelPlugin(PipelinePluginBase):
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
if dp_outside:
(
self.dp_axis,
self.pp_axis,
self.tp_axis,
self.sp_axis,
) = (
0,
1,
2,
3,
)
self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
else:
self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
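
Note: the collapsed one-line unpack keeps the axis indices aligned with the order of the sizes handed to ProcessGroupMesh. A minimal sketch of that correspondence, assuming torch.distributed is already initialized (sizes below are illustrative, not taken from this commit):

    from colossalai.cluster import ProcessGroupMesh

    # dp_outside=True: axis order is (dp, pp, tp, sp)
    dp_axis, pp_axis, tp_axis, sp_axis = 0, 1, 2, 3
    # the mesh is built with sizes in the same order as the axis indices
    pg_mesh = ProcessGroupMesh(2, 2, 2, 1)  # dp=2, pp=2, tp=2, sp=1 on 8 ranks
    # the data-parallel group of the current rank is taken along axis 0
    dp_group = pg_mesh.get_group_along_axis(dp_axis)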

colossalai/booster/plugin/moe_hybrid_parallel_plugin.py (333 changed lines)

@@ -1,9 +1,7 @@
import random
import warnings
from types import MethodType
from typing import Callable, Optional, OrderedDict, Tuple
import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
@@ -11,7 +9,6 @@ from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from colossalai.booster.plugin.hybrid_parallel_plugin import (
HybridParallelAMPOptimizer,
@@ -22,13 +19,8 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
reinitialize_optimizer,
)
from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.cluster.process_group_mesh import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.zero.low_level import LowLevelZeroOptimizer
@@ -39,6 +31,8 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
optimizer: Optimizer,
model: Module,
use_pipeline: bool,
dp_process_group: ProcessGroup, # the dp pg for comm
moe_dp_group: ProcessGroup, # the moe dp pg for comm
param_info: OrderedDict,
initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1,
@@ -54,30 +48,20 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
overlap_communication: bool = True,
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
tp_process_group: Optional[ProcessGroup] = None, # if using tp
pp_process_group: Optional[ProcessGroup] = None,
forced_dtype: Optional[torch.dtype] = None,
moe_extra_dp_process_group: Optional[ProcessGroup] = None,
):
self.param_info = param_info
self.stage_manager = model.stage_manager
self.shared_params = model.shared_params
self.dp_pg = dp_process_group
self.tp_pg = tp_process_group
self.pp_pg = pp_process_group
if use_pipeline:
reinitialize_optimizer(optimizer, model)
pg_param_list = {
dp_process_group: [],
moe_extra_dp_process_group: [],
dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
}
for param in model.parameters():
if is_moe_tensor(param):
pg_param_list[moe_extra_dp_process_group].append(param)
else:
pg_param_list[dp_process_group].append(param)
super().__init__(
optimizer=optimizer,
@@ -102,285 +86,43 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
class MoeHybridParallelPlugin(HybridParallelPlugin):
"""
Plugin for Moe Hybrid Parallel Training.
Tensor parallel, pipeline parallel and data parallel (DDP/ZeRO) can be picked and combined in this plugin.
The sizes of tp and pp should be passed in by the user; the size of dp is then computed automatically as dp_size = world_size / (tp_size * pp_size).
Example:
>>> from colossalai.booster import Booster
>>> from colossalai.booster.plugin import HybridParallelPlugin
>>> model, train_dataset, optimizer, criterion = ...
>>> plugin = HybridParallelPlugin(tp_size=2, pp_size=2)
>>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
>>> booster = Booster(plugin=plugin)
>>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
Args:
pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
precision (str, optional): Specifies the precision of parameters during training.
Auto-mixed precision will be used when this argument is set to 'fp16' or 'bf16'; otherwise the model is trained in 'fp32'.
Defaults to 'fp16'.
zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be chosen from [0, 1, 2].
When set to 0, ZeRO will not be used. Defaults to 0.
enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
Currently all the optimization methods include fused normalization, flash attention and JIT.
Defaults to False.
enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False.
enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Defaults to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
If ``num_microbatches`` is provided, this will be ignored. Defaults to None.
initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5.
growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000.
hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2.
max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32.
max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0.
broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training when using DDP. Defaults to True.
ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25.
find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False.
check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False.
gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False.
static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False.
zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12.
cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
use_ep_inside (bool, Optional): Whether to use ep inside dp (intra-node) for moe params.
TODO: add docstring
"""
def __init__(
self,
pp_size: int,
ep_size: int,
tp_size: int = 1,
sp_size: int = 1,
precision: str = "fp16",
zero_stage: int = 0,
enable_all_optimization: bool = False,
enable_fused_normalization: bool = False,
enable_flash_attention: bool = False,
enable_jit_fused: bool = False,
enable_sequence_parallelism: bool = False,
enable_sequence_overlap: bool = False,
num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0,
broadcast_buffers: bool = True,
ddp_bucket_cap_mb: int = 25,
find_unused_parameters: bool = False,
check_reduction: bool = False,
gradient_as_bucket_view: bool = False,
static_graph: bool = False,
zero_bucket_size_in_m: int = 12,
cpu_offload: bool = False,
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True,
use_ep_inside: bool = True,
custom_policy: Policy = None,
checkpoint_io: Optional[MoECheckpointIO] = None,
) -> None:
world_size = dist.get_world_size()
assert tp_size == 1, "Tensor parallel is not supported in MoE yet"
assert sp_size == 1 and enable_sequence_parallelism is False, "Sequence parallelism is not supported in MoE yet"
def __init__(self, ep_size: int, ep_tp_size: int = 1, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
assert (
world_size % (tp_size * pp_size) == 0
), f"world size {world_size} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
assert (
world_size % (tp_size * pp_size * ep_size) == 0
), f"world size {world_size} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}"
self.dp_size = world_size // (tp_size * pp_size)
self.tp_size = tp_size
self.pp_size = pp_size
self.ep_size = ep_size
self.sp_size = sp_size
self.precision = precision
self.zero_stage = zero_stage
self.cpu_offload = cpu_offload
self.enable_all_optimization = enable_all_optimization
self.enable_fused_normalization = enable_fused_normalization
self.enable_flash_attention = enable_flash_attention
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
self.checkpoint_io = checkpoint_io
logger = get_dist_logger()
# NOTE: Two process meshes: global dp for non-moe param; dp + ep for moe param
# See https://hpc-ai.com/blog/enhanced-moe-parallelism-open-source-moe-model-training-can-be-9-times-more-efficient
# we change pg mesh to (pp, dp, tp) for better moe performance
assert (
self.ep_size <= self.dp_size
), f"Not enough devices({self.dp_size}) for expert parallelism size({self.ep_size})."
self.moe_dp_size = self.dp_size // self.ep_size
self.use_ep_inside = use_ep_inside
if self.use_ep_inside:
logger.info(f"MoE Parallel use ep inside dp.", ranks=[0])
self.pp_axis, self.dp_axis, self.ep_axis, self.tp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, ep_size, tp_size)
else:
logger.info(f"MoE Parallel use ep outside dp.", ranks=[0])
warnings.warn("Using ep outside dp (cross-node) is strongly discouraged due to communication costs.")
self.pp_axis, self.dp_axis, self.ep_axis, self.tp_axis = 0, 2, 1, 3
self.pg_mesh = ProcessGroupMesh(self.pp_size, ep_size, self.moe_dp_size, tp_size)
self.moe_dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
self.ep_group = self.pg_mesh.get_group_along_axis(self.ep_axis)
logger.info(f"Non-MoE Parameter Parallel: pp {self.pp_size}, dp {self.dp_size}, tp {tp_size}", ranks=[0])
logger.info(
f"MoE Parallel: pp {self.pp_size}, ep {ep_size}, moe dp {self.moe_dp_size}, tp {tp_size}", ranks=[0]
)
self.tp_group = self.pg_mesh.get_group_along_axis(
self.tp_axis
) # TODO: support custom tp size for mixtral lm head
self.global_dp_group = self.pg_mesh.get_group_along_axis((self.dp_axis, self.ep_axis))
self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis)
# TODO: Currently moe only support partially sequence parallel
self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
self.custom_policy = custom_policy
self.stage_manager = None
self.schedule = None
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
assert (
num_microbatches is not None or microbatch_size is not None
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
self.stage_manager = PipelineStageManager(self.pg_mesh, self.pp_axis)
self.schedule = OneForwardOneBackwardSchedule(
self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
self.use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
if self.use_ddp:
warnings.warn(
f"Will have to check all params are used in pytorch DDP since not all experts are always activated"
)
self.ddp_config["find_unused_parameters"] = True
self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
pipeline_stage_manager=self.stage_manager,
enable_tensor_parallelism=self.tp_size > 1,
enable_all_optimization=self.enable_all_optimization,
enable_fused_normalization=self.enable_fused_normalization,
enable_flash_attention=self.enable_flash_attention,
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism,
enable_sequence_overlap=enable_sequence_overlap,
ep_group=self.ep_group,
)
self.amp_config = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
)
self.ddp_config = dict(
broadcast_buffers=broadcast_buffers,
bucket_cap_mb=ddp_bucket_cap_mb,
find_unused_parameters=find_unused_parameters,
check_reduction=check_reduction,
gradient_as_bucket_view=gradient_as_bucket_view,
static_graph=static_graph,
)
self.zero_config = dict(
reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
cpu_offload=cpu_offload,
partition_grad=(self.zero_stage == 2),
)
self.max_norm = max_norm
def prepare_dataloader(
self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
):
r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
if ep_tp_size != 1:
raise NotImplementedError
world_size = dist.get_world_size()
Args:
dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
seed (int, optional): Random worker seed for sampling, defaults to 1024.
add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
is not divisible by the batch size. If False and the size of dataset is not divisible by
the batch size, then the last batch will be smaller, defaults to False.
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
self.moe_dp_size = world_size // (ep_size * ep_tp_size)
self.ep_size = ep_size
self.moe_tp_size = ep_tp_size
Returns:
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
sampler = DistributedSampler(
dataset,
num_replicas=self.dp_size,
rank=dist.get_rank(self.global_dp_group),
shuffle=shuffle,
)
self.moe_pg_mesh = ProcessGroupMesh(self.moe_dp_size, self.ep_size, self.moe_tp_size)
self.moe_dp_axis, self.ep_axis, self.moe_tp_axis = 0, 1, 2
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
self.moe_dp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_dp_axis)
self.ep_group = self.moe_pg_mesh.get_group_along_axis(self.ep_axis)
self.moe_tp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_tp_axis)
return DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs,
)
# set ep_group after super init
# TODO do it in a better way
self.shard_config.ep_group = self.ep_group
def get_checkpoint_io(self) -> MoECheckpointIO:
if self.checkpoint_io is None:
self.checkpoint_io = MoECheckpointIO(
self.global_dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
)
else:
self.checkpoint_io = self.checkpoint_io(
self.global_dp_group,
self.pp_group,
self.tp_group,
ep_group=self.ep_group,
moe_dp_group=self.moe_dp_group,
zero_stage=self.zero_stage,
)
if hasattr(self.checkpoint_io, "moe_info"):
self.checkpoint_io.moe_info = self.moe_info
return self.checkpoint_io
return MoECheckpointIO(
self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
)
def configure(
self,
@@ -392,15 +134,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
param_info = get_param_info(optimizer)
if not isinstance(model, ModelWrapper):
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
model = HybridParallelModule(
module=model,
precision=self.precision,
shard_config=self.shard_config,
dp_group=self.global_dp_group,
dp_group=self.dp_group,
tp_group=self.tp_group,
sp_group=self.sp_group,
use_ddp=use_ddp, # TODO fix why this failed
use_ddp=self.use_ddp,
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
)
@@ -411,8 +152,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
reinitialize_optimizer(optimizer, model)
if self.zero_stage == 0:
# assert self.ep_size > 1
if self.precision in ["fp16", "bf16"]:
optimizer = HybridParallelAMPOptimizer(
optimizer,
@@ -435,10 +174,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
dp_process_group=self.global_dp_group,
tp_process_group=self.tp_group,
pp_process_group=self.pp_group,
moe_extra_dp_process_group=self.moe_dp_group,
dp_process_group=self.dp_group,
moe_dp_group=self.moe_dp_group,
verbose=True,
clip_grad_norm=self.max_norm,
**self.zero_config,
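
Since the class docstring is still marked "TODO: add docstring", here is a rough usage sketch of the new constructor, inferred from this diff: ep_size (and the not-yet-supported ep_tp_size) lead, and all remaining keywords are forwarded to HybridParallelPlugin. Internally the plugin now keeps two meshes, the parent's dp/pp/tp/sp mesh for dense parameters and a separate (moe_dp, ep, moe_tp) mesh whose groups are handed to the ZeRO optimizer for MoE parameters. Argument values below are illustrative:

    from colossalai.booster import Booster
    from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

    # model, optimizer, criterion, dataloader are user-defined, as in the parent plugin's example
    plugin = MoeHybridParallelPlugin(ep_size=2, tp_size=1, pp_size=1, zero_stage=1, precision="fp16")
    booster = Booster(plugin=plugin)
    model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)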

colossalai/cluster/process_group_mesh.py (4 changed lines)

@@ -137,7 +137,7 @@ class ProcessGroupMesh:
assert mode in ["raise", "wrap", "clip"]
return int(np.ravel_multi_index(coord, shape, mode))
def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup:
def _get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup:
"""Get the process group with the given ranks. It the process group doesn't exist, it will be created.
Args:
@@ -240,7 +240,7 @@ class ProcessGroupMesh:
for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
group = self.get_group(ranks_in_group, backend=backend)
group = self._get_group(ranks_in_group, backend=backend)
if self._rank in ranks_in_group:
target_group = group
return target_group
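
With get_group renamed to _get_group, assembling groups from raw rank lists becomes an internal detail; callers outside the mesh go through the axis-based helper instead. A small sketch of that public path, assuming an initialized distributed context (mesh shape and axis are illustrative):

    from colossalai.cluster import ProcessGroupMesh

    mesh = ProcessGroupMesh(2, 4)            # e.g. dp=2, tp=4 on 8 ranks
    tp_group = mesh.get_group_along_axis(1)  # ranks sharing this rank's dp coordinate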

colossalai/moe/_operation.py (3 changed lines)

@@ -393,4 +393,7 @@ def all_to_all_uneven(
group=None,
overlap: bool = False,
):
assert (
inputs.requires_grad
), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)
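
The new assertion is the guard behind the commit title: AllToAllUneven only runs its backward all-to-all if autograd reaches it, so an input with requires_grad=False can leave other ranks waiting in the collective. A minimal calling sketch, assuming an initialized 2-rank NCCL job (shapes and split sizes are illustrative):

    import torch
    from colossalai.moe._operation import all_to_all_uneven

    x = torch.randn(6, 8, device="cuda", requires_grad=True)  # must require grad, per the new assert
    out = all_to_all_uneven(x, [3, 3], [3, 3])  # send 3 rows to each rank, receive 3 from each
    out.sum().backward()                        # every rank reaches backward, so the collective pairs up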

colossalai/shardformer/policies/mixtral.py (27 changed lines)

@@ -101,20 +101,18 @@ class MixtralPolicy(Policy):
# )
if getattr(self.shard_config, "ep_group", None) is None:
raise ValueError("You must pass in ep_group via shard_config for expert parallel!")
# expert parallel
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="block_sparse_moe",
target_module=EPMixtralSparseMoeBlock,
kwargs={"ep_group": self.shard_config.ep_group},
)
],
policy=policy,
target_key=MixtralDecoderLayer,
)
# expert parallel
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="block_sparse_moe",
target_module=EPMixtralSparseMoeBlock,
kwargs={"ep_group": self.shard_config.ep_group},
)
],
policy=policy,
target_key=MixtralDecoderLayer,
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
@@ -144,6 +142,7 @@ class MixtralPolicy(Policy):
if self.shard_config.enable_flash_attention:
warnings.warn("Flash attention is natively supported in transformers, will ignore the flag.")
self.shard_config.enable_flash_attention = False
return policy

colossalai/zero/low_level/bookkeeping/bucket_store.py (10 changed lines)

@@ -100,7 +100,7 @@ class BucketStore(BaseStore):
return self._grad_in_bucket
def get_flatten_grad(self) -> Tensor:
def get_flatten_grad(self, dtype=None) -> Tensor:
"""Return the flattened gradients slices in the bucket, the data organization of the flattened tensor:
[grad0_rank0, grad1_rank0, ..., grad_0_rank1, grad1_rank1, ....]
@@ -110,8 +110,12 @@
flat_grad = []
for grad_list in self._grad_in_bucket.values():
flat_grad.append(_flatten_dense_tensors(grad_list))
flat_grad = _flatten_dense_tensors(flat_grad)
if len(grad_list) > 0:
flat_grad.append(_flatten_dense_tensors(grad_list))
if len(flat_grad) > 0:
flat_grad = _flatten_dense_tensors(flat_grad)
else:
flat_grad = torch.tensor([], device=self.comm_stream.device, dtype=dtype)
return flat_grad
def get_param_id_of_grad(self, grad: Tensor) -> int:
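
The layout described in the docstring, and the new empty-bucket fallback, can be reproduced with torch's flatten helper; this appears to pair with the removal of the early `continue` in `_run_reduction`, so a rank whose bucket is empty still joins the reduction with a zero-element tensor. Shapes and dtype below are illustrative:

    import torch
    from torch._utils import _flatten_dense_tensors

    # per-rank gradient slices as they sit in the bucket
    grads_rank0 = [torch.ones(2), torch.zeros(3)]
    grads_rank1 = [torch.full((2,), 2.0), torch.full((3,), 3.0)]

    # flatten each rank's list, then flatten across ranks:
    # [grad0_rank0, grad1_rank0, grad0_rank1, grad1_rank1]
    per_rank = [_flatten_dense_tensors(g) for g in (grads_rank0, grads_rank1) if len(g) > 0]
    flat = _flatten_dense_tensors(per_rank) if per_rank else torch.tensor([], dtype=torch.float16)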

colossalai/zero/low_level/bookkeeping/gradient_store.py (2 changed lines)

@@ -91,7 +91,7 @@ class GradientStore(BaseStore):
return grad_list
def get_working_grad_by_param_id(self, param_id) -> Tensor:
def get_working_grad_by_param_id(self, param_id) -> Optional[Tensor]:
"""
Return the working gradient for the specified parameter.

colossalai/zero/low_level/low_level_optim.py (16 changed lines)

@@ -301,12 +301,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
def _run_reduction(self):
for bucket_store in self.pg_to_bucket_store.values():
if bucket_store.num_elements_in_bucket() <= 0:
continue
bucket_store.build_grad_in_bucket()
flat_grads = bucket_store.get_flatten_grad()
flat_grads = bucket_store.get_flatten_grad(self._dtype)
flat_grads /= bucket_store.world_size
# ready to add other tensors to bucket
@@ -353,6 +350,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self, bucket_store: BucketStore, origin_grad_list: List, flat_grad_list: List, group_id: int
) -> None:
for rank, grad_list in enumerate(origin_grad_list):
if len(grad_list) == 0:
continue
sync_tensor(flat_grad_list[rank], grad_list)
for grad in grad_list:
param_id = bucket_store.get_param_id_of_grad(grad)
@@ -869,12 +868,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
def get_param_grad(self, working_param: nn.Parameter) -> Tensor:
grad_store = self.pid_to_grad_store[id(working_param)]
partial_grad = grad_store.get_working_grad_by_param_id(id(working_param))
if partial_grad is None:
grad = grad_store.get_working_grad_by_param_id(id(working_param))
if grad is None:
return None
tensor_list = [torch.empty_like(partial_grad) for _ in range(grad_store.world_size)]
dist.all_gather(tensor_list, partial_grad, group=grad_store.torch_pg)
grad_flat = torch.cat(tensor_list, dim=0)
grad_flat = torch.empty((grad_store.world_size, *grad.shape), dtype=grad.dtype, device=grad.device)
dist.all_gather_into_tensor(grad_flat, grad, group=grad_store.torch_pg)
return grad_flat[: working_param.numel()].reshape_as(working_param)
def get_working_grads_by_group_id(self, group_id: int) -> List[Tensor]:
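
get_param_grad now preallocates the gathered buffer and fills it with a single all_gather_into_tensor call instead of all_gather into a Python list followed by torch.cat. A standalone sketch of the pattern (the helper name and shapes are illustrative, not part of the patch):

    import torch
    import torch.distributed as dist

    def gather_shards(shard: torch.Tensor, world_size: int, group=None) -> torch.Tensor:
        # one preallocated output tensor, filled by a single collective
        out = torch.empty((world_size, *shard.shape), dtype=shard.dtype, device=shard.device)
        dist.all_gather_into_tensor(out, shard, group=group)
        return out.flatten()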

tests/kit/model_zoo/transformers/mixtral.py (6 changed lines)

@@ -19,7 +19,7 @@ def data_gen():
# tokenized_input = tokenizer([input], return_tensors="pt")
# input_ids = tokenized_input['input_ids']
# attention_mask = tokenized_input['attention_mask']
input_ids = torch.tensor([[1, 1984, 16020, 2076, 2487, 349, 21375, 4749]], dtype=torch.int64)
input_ids = torch.tensor([[1, 22, 55, 77, 532, 349, 43, 22]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, attention_mask=attention_mask)
@@ -43,7 +43,7 @@ def data_gen_for_sequence_classification():
output_transform_fn = lambda x: x
# define loss function
loss_fn_for_mixtral_model = lambda x: torch.nn.functional.mse_loss(x[0], torch.ones_like(x[0]))
loss_fn_for_mixtral_model = lambda x: x[0].mean()
loss_fn = lambda x: x.loss
loss_fn_for_seq_classification = lambda output: output.logits.mean()
@@ -52,7 +52,7 @@ config = MixtralConfig(
intermediate_size=256,
num_attention_heads=64,
num_hidden_layers=2,
vocab_size=50258,
vocab_size=1000,
output_router_logits=True,
)

tests/test_moe/test_moe_checkpoint.py (1 changed line)

@@ -141,7 +141,6 @@ def check_moe_checkpoint(test_config):
if dist.get_rank() == 0:
saved_model = model_cls.from_pretrained(model_dir).cuda()
check_model_equal(orig_model, saved_model)
# check_model_equal(model, saved_model)
saved_model.save_pretrained(hf_model_dir)
dist.barrier()
# check load model

tests/test_moe/test_moe_zero_fwd_bwd_optim.py (37 changed lines)

@@ -31,16 +31,17 @@ def split_grad(grad, world_size):
return splited_grad
@parameterize("dtype", [torch.float16, torch.bfloat16])
@parameterize("master_weights", [True, False])
@parameterize("stage", [1, 2])
def run_zero_with_original_model(world_size, master_weights: bool, dtype: torch.dtype, stage: int):
@parameterize("ep_size", [1, 2, 4])
def run_zero_with_original_model(stage: int, ep_size: int):
dtype = torch.float16
rank = torch.distributed.get_rank()
torch.cuda.set_device(dist.get_rank())
plugin = MoeHybridParallelPlugin(
tp_size=1,
pp_size=1,
ep_size=dist.get_world_size() // 2,
ep_size=ep_size,
)
seed_all(10086)
@@ -53,26 +54,30 @@ def run_zero_with_original_model(world_size, master_weights: bool, dtype: torch.dtype, stage: int):
orig_model = MixtralSparseMoeBlock(config).to(dtype).cuda()
ori_model = DDP(orig_model.cuda(), static_graph=True).cuda()
ori_model = DDP(
orig_model.cuda(),
process_group=plugin.dp_group,
find_unused_parameters=True, # important for torch ddp, not all experts are routed
).cuda()
zero_model = deepcopy(orig_model).to(dtype)
zero_model = EPMixtralSparseMoeBlock.from_native_module(zero_model, ep_group=plugin.ep_group)
zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
pg_param_list = {plugin.global_dp_group: [], plugin.moe_dp_group: []}
pg_param_list = {plugin.dp_group: [], plugin.moe_dp_group: []}
for p in zero_model.parameters():
if is_moe_tensor(p):
pg_param_list[plugin.moe_dp_group].append(p)
else:
pg_param_list[plugin.global_dp_group].append(p)
pg_param_list[plugin.dp_group].append(p)
zero_optimizer = LowLevelZeroOptimizer(
zero_optimizer,
pg_to_param_list=pg_param_list,
master_weights=master_weights,
master_weights=False,
initial_scale=1,
overlap_communication=False,
partition_grad=True,
overlap_communication=True,
partition_grad=stage == 2,
)
ori_optimizer = torch.optim.SGD(ori_model.parameters(), lr=1)
@@ -82,11 +87,11 @@ def run_zero_with_original_model(world_size, master_weights: bool, dtype: torch.dtype, stage: int):
for _ in range(2):
# zero-dp forward
input_data = torch.rand(1, tokens, hidden_size).cuda()
zero_output, zero_logits = zero_model(input_data.to(dtype))
input_data = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
zero_output, _ = zero_model(input_data.to(dtype))
# torch-ddp forward
ori_output, ori_logits = ori_model(input_data.to(dtype))
ori_output, _ = ori_model(input_data.to(dtype))
loose_close(zero_output, ori_output, dtype=dtype)
# zero-dp backward
@@ -115,14 +120,16 @@ def run_zero_with_original_model(world_size, master_weights: bool, dtype: torch.dtype, stage: int):
for n, p in zero_model.named_parameters():
loose_close(p.data, name_to_p[n].data, dtype=dtype)
print(f"{dist.get_rank()} test passed")
def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_zero_with_original_model(world_size=world_size)
run_zero_with_original_model()
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2, 4])
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_moe_zero_model(world_size):
spawn(run_dist, world_size)

tests/test_shardformer/test_model/test_shard_mixtral.py (52 changed lines)

@@ -25,13 +25,14 @@ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
# TODO: SGD failed for full dp
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.SGD
)
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.Adam
)
with torch.autograd.set_detect_anomaly(True):
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
@@ -73,6 +74,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# check grads
check_all_grad_tensors(grads_to_check)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
@@ -103,9 +107,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
verbose=False,
)
# check grads
check_all_grad_tensors(grads_to_check)
torch.cuda.empty_cache()
@@ -114,37 +115,22 @@
[
{
"tp_size": 1,
"pp_size": 4,
"pp_size": 2,
"num_microbatches": 2,
"ep_size": 1,
"num_microbatches": 4,
"zero_stage": 0,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp16",
"initial_scale": 1,
},
# {
"precision": "fp32",
}, # pp + ep
# {"tp_size": 1, "pp_size": 1, "ep_size": 1, "zero_stage": 1, "precision": "fp16"}, # full dp for moe and non-moe
# { # moe_dp = 2, non_moe_dp = 4
# "tp_size": 1,
# "pp_size": 1,
# "ep_size": 4,
# "num_microbatches": 2,
# "ep_size": 2,
# "zero_stage": 1,
# "enable_all_optimization": True,
# "use_lazy_init": False,
# "precision": "fp16",
# "initial_scale": 1,
# },
# {
# "tp_size": 1,
# "pp_size": 1,
# "ep_size": 4,
# "num_microbatches": 2,
# "zero_stage": 2,
# "enable_all_optimization": True,
# "use_lazy_init": False,
# "precision": "fp16",
# "initial_scale": 1,
# },
# }, # moe_dp = 1, non_moe_dp = 4
# {"tp_size": 1, "pp_size": 1, "ep_size": 4, "zero_stage": 1, "precision": "fp16"},
# {"tp_size": 1, "pp_size": 1, "ep_size": 1, "zero_stage": 0, "precision": "fp32"}, # full dp for moe and non-moe
],
)
def run_mixtral_test(test_config):
