from enum import Enum
from typing import Tuple, List, Optional

from colossalai.context.parallel_mode import ParallelMode


class ComputePattern(Enum):
    TP1DRow_Linear = 1
    TP1DCol_Linear = 2
    TP1DRow_Embedding = 3
    TP1DCol_Embedding = 4
    ZeRO = 5
    DP = 6


class ShardPattern(Enum):
    NA = 0
    Row = 1
    Col = 2


class ParallelAction(object):

    def __init__(self,
                 priority=0,
                 compute_pattern=ComputePattern.DP,
                 parallel_mode=ParallelMode.DATA,
                 gather_out=True) -> None:
        self.priority = priority
        self.compute_pattern = compute_pattern
        self.parallel_mode = parallel_mode
        self.gather_out = gather_out


class TensorSpec(object):
    """
    The specification of a tensor. It contains two aspects of information:
    first, how the tensor is distributed in the heterogeneous memory space;
    second, if the tensor is a model parameter, the spec contains the parallel
    computation pattern of the operator (layer). We have to consider the hybrid
    parallel mode.
    """
    # A list of parallel actions.
    # For example: on 8 GPUs, a hybrid parallel strategy is applied
    # using ZeRO with DP-degree = 4 and 1D Row TP with TP-degree = 2.
    # parallel_action_list = [
    #     ParallelAction(10, ComputePattern.ZeRO, gpc.get_group(ParallelMode.DATA)),
    #     ParallelAction(1, ComputePattern.TP1DRow_Linear, gpc.get_group(ParallelMode.PARALLEL_1D))
    # ]
    # When the ColoTensor is initialized,
    # we first split the tensor according to the ParallelAction of ZeRO,
    # then split it according to the ParallelAction of TP1DRow_Linear.
    # During Linear computation:
    # before the Linear op, we gather the tensors according to ZeRO;
    # we perform the Linear op according to the compute pattern of TP1DRow_Linear;
    # after the Linear op, we split the tensors according to ZeRO.

    def __init__(self,
                 parallel_action_list: Optional[List[ParallelAction]] = None,
                 shard_pattern: ShardPattern = ShardPattern.NA):
        # Avoid a mutable default argument; each spec owns its own action list.
        self._parallel_action_list = parallel_action_list if parallel_action_list is not None else []
        self._shard_pattern = shard_pattern
        self.sort()

    @property
    def parallel_action_list(self):
        return self._parallel_action_list

    @property
    def num_action(self):
        return len(self._parallel_action_list)

    @property
    def compute_patterns(self):
        return [parallel_action.compute_pattern for parallel_action in self._parallel_action_list]

    @property
    def shard_pattern(self):
        return self._shard_pattern

    def sort(self):
        # Order actions by ascending priority so they are applied deterministically.
        if len(self._parallel_action_list) > 0:
            self._parallel_action_list.sort(key=lambda parallel_action: parallel_action.priority)

    def get_action_by_compute_pattern(self, compute_pattern: ComputePattern):
        for parallel_action in self._parallel_action_list:
            if parallel_action.compute_pattern == compute_pattern:
                return parallel_action
        return None
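

# The following is a minimal usage sketch, not part of the module API. It shows how a
# TensorSpec could be assembled from ParallelActions for the hybrid strategy described
# in the class comment above (ZeRO on the data-parallel group plus 1D Row TP). The
# priorities and the shard pattern chosen here are illustrative assumptions.
if __name__ == "__main__":
    actions = [
        ParallelAction(priority=10, compute_pattern=ComputePattern.ZeRO,
                       parallel_mode=ParallelMode.DATA),
        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Linear,
                       parallel_mode=ParallelMode.PARALLEL_1D),
    ]
    spec = TensorSpec(parallel_action_list=actions, shard_pattern=ShardPattern.Row)
    # sort() orders actions by ascending priority, so TP1DRow_Linear precedes ZeRO.
    print(spec.compute_patterns)
    print(spec.num_action)
    zero_action = spec.get_action_by_compute_pattern(ComputePattern.ZeRO)
    print(zero_action.parallel_mode)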