2023-11-10 02:49:50 +00:00
from dataclasses import dataclass , field
2023-11-19 13:05:05 +00:00
from typing import Any , Dict , Optional
2023-06-19 05:53:17 +00:00
2023-06-30 01:58:08 +00:00
import torch . distributed as dist
from torch . distributed import ProcessGroup
2023-07-05 06:16:55 +00:00
from colossalai . pipeline . stage_manager import PipelineStageManager
2023-09-19 06:20:26 +00:00
# Public API of this module; the entry must match the class name exactly
# (no surrounding whitespace) for `from <module> import *` to resolve it.
__all__ = ["ShardConfig"]
2023-05-24 08:01:26 +00:00
2023-05-22 07:02:17 +00:00
@dataclass
class ShardConfig :
2023-06-09 06:36:54 +00:00
r """
The config for sharding the huggingface model
Args :
2023-09-15 02:56:39 +00:00
tensor_parallel_process_group ( Optional [ ProcessGroup ] ) : The process group of tensor parallelism , it ' s necessary when using tensor parallel. Defaults to None, which is the global process group.
pipeline_stage_manager ( Optional [ PipelineStageManager ] ) : If using pipeline parallelism , it ' s necessary to specify a pipeline stage manager for inter-process communication in pipeline parallelism. Defaults to None, which means not using pipeline parallelism.
enable_tensor_parallelism ( bool ) : Whether to use tensor parallelism . Defaults to True .
enable_fused_normalization ( bool ) : Whether to use fused layernorm . Defaults to False .
enable_flash_attention ( bool , optional ) : Whether to switch on flash attention . Defaults to False .
enable_jit_fused ( bool , optional ) : Whether to switch on JIT fused operators . Defaults to False .
enable_sequence_parallelism ( bool ) : Whether to turn on sequence parallelism , which partitions non - tensor - parallel regions along the sequence dimension . Defaults to False .
2024-01-04 08:21:55 +00:00
enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlaps the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
enable_all_optimization ( bool ) : Whether to turn on all optimization tools including ' fused normalization ' , ' flash attention ' , ' JIT fused operators ' , ' sequence parallelism ' and ' sequence overlap ' . Defaults to False .
2023-05-22 07:02:17 +00:00
"""
2023-07-05 06:16:55 +00:00
tensor_parallel_process_group : Optional [ ProcessGroup ] = None
pipeline_stage_manager : Optional [ PipelineStageManager ] = None
2023-07-04 01:57:03 +00:00
enable_tensor_parallelism : bool = True
2023-06-30 01:32:37 +00:00
enable_fused_normalization : bool = False
2023-08-07 08:41:07 +00:00
enable_flash_attention : bool = False
enable_jit_fused : bool = False
2023-09-15 02:56:39 +00:00
enable_all_optimization : bool = False
2023-09-22 03:02:50 +00:00
enable_sequence_parallelism : bool = False
enable_sequence_overlap : bool = False
2023-11-19 13:05:05 +00:00
extra_kwargs : Dict [ str , Any ] = field ( default_factory = dict )
2023-09-22 03:02:50 +00:00
# pipeline_parallel_size: int
# data_parallel_size: int
2023-06-19 05:53:17 +00:00
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
2023-06-30 02:56:29 +00:00
@property
def tensor_parallel_size ( self ) :
return self . _tensor_parallel_size
2023-06-19 05:53:17 +00:00
def __post_init__ ( self ) :
2023-08-28 09:16:40 +00:00
if not self . enable_tensor_parallelism and self . enable_sequence_parallelism :
raise ValueError (
2023-09-19 06:20:26 +00:00
" enable_sequence_parallelism can only be set to True when enable_tensor_parallelism is True "
)
2023-08-28 09:16:40 +00:00
if not self . enable_sequence_parallelism and self . enable_sequence_overlap :
raise ValueError ( " enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True " )
2023-07-04 01:57:03 +00:00
if not self . enable_tensor_parallelism :
self . _tensor_parallel_size = 1
else :
# get the parallel size
self . _tensor_parallel_size = dist . get_world_size ( self . tensor_parallel_process_group )
2023-06-30 02:56:29 +00:00
# turn on all optimization if all_optimization is set to True
if self . enable_all_optimization :
self . _turn_on_all_optimization ( )
def _turn_on_all_optimization ( self ) :
"""
Turn on all optimization .
"""
# you can add all the optimization flag here
2023-07-03 07:29:11 +00:00
self . enable_fused_normalization = True
2023-08-07 08:41:07 +00:00
self . enable_flash_attention = True
self . enable_jit_fused = True
2023-08-18 07:34:18 +00:00
self . enable_sequence_parallelism = True
2023-08-28 09:16:40 +00:00
self . enable_sequence_overlap = True
2023-09-11 17:22:56 +00:00
def _infer ( self ) :
"""
Set default params for inference .
"""
2023-10-27 08:19:54 +00:00
# assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now"