mirror of https://github.com/hpcaitech/ColossalAI
[doc] add MoeHybridParallelPlugin docstring
parent 7bedd03739
commit 65daa87627
@@ -101,9 +101,71 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
class MoeHybridParallelPlugin(HybridParallelPlugin):
"""
Modified from colossalai.booster.plugin.hybrid_parallel_plugin.HybridParallelPlugin
Plugin for MoE hybrid parallel training, which is similar to HybridParallelPlugin.
The extra argument compared to HybridParallelPlugin is ``ep_size`` (see Args below).
Tensor parallelism, pipeline parallelism, and data parallelism (DDP/ZeRO) can be picked and combined in this plugin.
The sizes of tp and pp should be passed in by the user; the size of dp is then calculated automatically as dp_size = world_size / (tp_size * pp_size).
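As a quick illustrative check of that formula (the sizes below are assumptions, not defaults):

```python
# Hypothetical sizes, only to illustrate dp_size = world_size / (tp_size * pp_size).
world_size = 8            # total number of ranks (assumed)
tp_size, pp_size = 2, 2   # as in the example below
dp_size = world_size // (tp_size * pp_size)
assert dp_size == 2
```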
```python
from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin

model, train_dataset, optimizer, criterion = ...
plugin = MoeHybridParallelPlugin(tp_size=2, pp_size=2, ep_size=2)

train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
booster = Booster(plugin=plugin)
model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
```
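A more fully configured variant is sketched below. The specific values (parallel sizes, zero_stage, precision, num_microbatches) are illustrative assumptions, not recommended settings; they only exercise arguments documented in Args below.

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin

model, train_dataset, optimizer, criterion = ...

# Illustrative sizes only: tp_size * pp_size must divide the world size.
plugin = MoeHybridParallelPlugin(
    tp_size=1,
    pp_size=2,
    ep_size=2,
    zero_stage=1,                # ZeRO-1 over the data-parallel group
    precision="bf16",            # mixed precision instead of the default 'fp16'
    num_microbatches=4,          # either this or microbatch_size is needed when pp_size > 1
    enable_fused_normalization=True,
)

train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
booster = Booster(plugin=plugin)
model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
```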
Args:
    tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
    pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
    ep_size (int): The size of expert parallelism.
    sp_size (int): The size of sequence parallelism.
    precision (str, optional): Specifies the precision of parameters during training.
        Auto-mixed precision will be used when this argument is set to 'fp16' or 'bf16'; otherwise the model is trained with 'fp32'.
        Defaults to 'fp16'.
    zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be chosen from [0, 1, 2].
        When set to 0, ZeRO will not be used. Defaults to 0.
    enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
        The optimization methods currently include fused normalization, flash attention, and JIT.
        Defaults to False.
    enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False.
    enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
    enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Defaults to False.
    enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
    sequence_parallelism_mode (str): The sequence parallelism mode. Can only be chosen from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather".
    enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
    parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Defaults to True.
    num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
    microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
        Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
        If ``num_microbatches`` is provided, this will be ignored. Defaults to None.
    initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
    min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
    growth_factor (float, optional): The multiplication factor for increasing the loss scale when using AMP. Defaults to 2.
    backoff_factor (float, optional): The multiplication factor for decreasing the loss scale when using AMP. Defaults to 0.5.
    growth_interval (int, optional): The number of steps to increase the loss scale when no overflow occurs when using AMP. Defaults to 1000.
    hysteresis (int, optional): The number of overflows before decreasing the loss scale when using AMP. Defaults to 2.
    max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32.
    max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0.
    broadcast_buffers (bool, optional): Whether to broadcast buffers at the beginning of training when using DDP. Defaults to True.
    ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25.
    find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False.
    check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False.
    gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False.
    static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False.
    zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12.
    cpu_offload (bool, optional): Whether to enable CPU offloading when using ZeRO. Defaults to False.
    communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of the parameters will be used. Defaults to None.
    overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
    custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
    pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'.
    num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
    gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
    enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
    make_vocab_size_divisible_by (int, optional): Used when padding the vocabulary size to make it choose a faster kernel. Defaults to 64.
    overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism.
"""
def __init__(
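For reference, the AMP loss-scaling and gradient-clipping arguments documented above are plain constructor arguments; a minimal sketch with purely illustrative values (assumed for this example, not tuned recommendations) might look like:

```python
from colossalai.booster.plugin import MoeHybridParallelPlugin

# Every number below is an illustrative assumption, not a recommendation.
plugin = MoeHybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    ep_size=2,
    precision="fp16",        # AMP is active for 'fp16'/'bf16'
    initial_scale=2**16,     # starting loss scale
    growth_factor=2.0,       # scale increase factor after growth_interval overflow-free steps
    backoff_factor=0.5,      # scale decrease factor after hysteresis overflows
    growth_interval=1000,
    hysteresis=2,
    max_scale=2**32,
    max_norm=1.0,            # gradient clipping threshold (assumed value)
)
```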