import warnings
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

import torch.distributed as dist
from torch.distributed import ProcessGroup

from colossalai.pipeline.stage_manager import PipelineStageManager

from .grad_ckpt_config import GradientCheckpointConfig

__all__ = ["ShardConfig"]
SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]


@dataclass
class ShardConfig:
    r"""
    The config for sharding the huggingface model.

    Args:
        tensor_parallel_process_group (Optional[ProcessGroup]): The process group for tensor parallelism; it is required when tensor parallelism is used. Defaults to None, which is the global process group.
        sequence_parallel_process_group (Optional[ProcessGroup]): The process group for sequence parallelism; it is required when sequence parallelism is used. Defaults to None, which is the global process group.
        pipeline_stage_manager (Optional[PipelineStageManager]): If using pipeline parallelism, a pipeline stage manager must be specified for inter-process communication. Defaults to None, which means pipeline parallelism is not used.
        enable_tensor_parallelism (bool): Whether to use tensor parallelism. Defaults to True.
        enable_fused_normalization (bool): Whether to use fused layernorm. Defaults to False.
        enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False.
        enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False.
        enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
        sequence_parallelism_mode (Optional[str]): The sequence parallelism mode, one of "split_gather", "ring" or "all_to_all". Defaults to None, which resolves to "split_gather" when sequence parallelism is enabled.
        enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlaps the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
        parallel_output (bool): Whether to keep the model output sharded across ranks instead of gathering it. Defaults to True.
        gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None.
        extra_kwargs (Dict[str, Any]): Extra keyword arguments used by model-specific sharding policies. Defaults to an empty dict.
        enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
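
    Example:
        A minimal usage sketch (assumes ``torch.distributed`` has already been
        initialized, so the default/global process groups can be queried):

        >>> shard_config = ShardConfig(
        ...     enable_tensor_parallelism=True,
        ...     enable_flash_attention=True,
        ...     enable_sequence_parallelism=True,
        ...     sequence_parallelism_mode="split_gather",
        ... )
        >>> tp_size = shard_config.tensor_parallel_size  # world size of the global process group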
"""

    tensor_parallel_process_group: Optional[ProcessGroup] = None
    sequence_parallel_process_group: Optional[ProcessGroup] = None
    pipeline_stage_manager: Optional[PipelineStageManager] = None
    enable_tensor_parallelism: bool = True
    enable_all_optimization: bool = False
    enable_fused_normalization: bool = False
    enable_flash_attention: bool = False
    enable_jit_fused: bool = False
    enable_sequence_parallelism: bool = False
    sequence_parallelism_mode: Optional[str] = None
    enable_sequence_overlap: bool = False
    parallel_output: bool = True
    gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
    extra_kwargs: Dict[str, Any] = field(default_factory=dict)

    # TODO: padding vocab
    # make_vocab_size_divisible_by: int = 128
    # pipeline_parallel_size: int
    # data_parallel_size: int
    # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']

    @property
    def tensor_parallel_size(self):
        return self._tensor_parallel_size

    @property
    def sequence_parallel_size(self):
        return self._sequence_parallel_size

    def __post_init__(self):
        # Turn on all optimizations if enable_all_optimization is set to True.
        if self.enable_all_optimization:
            self._turn_on_all_optimization()

        if self.enable_sequence_parallelism:
            self.sequence_parallelism_mode = (
                "split_gather" if self.sequence_parallelism_mode is None else self.sequence_parallelism_mode
            )
            assert (
                self.sequence_parallelism_mode in SUPPORT_SP_MODE
            ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}"
            if self.sequence_parallelism_mode in ["split_gather", "ring"]:
                assert (
                    self.enable_tensor_parallelism
                ), f"Sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is True"
            elif self.sequence_parallelism_mode in ["all_to_all"]:
                assert (
                    not self.enable_tensor_parallelism
                ), f"Sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is False"
                if self.enable_sequence_overlap:
                    self.enable_sequence_overlap = False
                    warnings.warn(
                        f"The enable_sequence_overlap flag will be ignored in sequence parallelism mode {self.sequence_parallelism_mode}"
                    )
        else:
            if self.sequence_parallelism_mode:
                self.sequence_parallelism_mode = None
                warnings.warn(
                    "The sequence_parallelism_mode will be ignored when enable_sequence_parallelism is False"
                )
            assert (
                not self.enable_sequence_overlap
            ), "enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True"

        # Get the tensor parallel size.
        if not self.enable_tensor_parallelism:
            self._tensor_parallel_size = 1
        else:
            self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)

        # Get the sequence parallel size.
        if not self.enable_sequence_parallelism:
            self._sequence_parallel_size = 1
        else:
            self._sequence_parallel_size = dist.get_world_size(self.sequence_parallel_process_group)

    def _turn_on_all_optimization(self):
        """
        Turn on all optimizations.
        """
        # You can add all the optimization flags here.
        self.enable_fused_normalization = True
        self.enable_flash_attention = True
        self.enable_jit_fused = True
        # Sequence parallelism is not enabled automatically: it can cause non-in-place param sharding
        # when used without ZeRO, and it may also slow down training when the sequence length is small.
        # Please enable it manually.
        # self.enable_sequence_parallelism = True
        # self.enable_sequence_overlap = True

    def _infer(self):
        """
        Set default parameters for inference.
        """
        # assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now"