mirror of https://github.com/hpcaitech/ColossalAI
35 lines
1.1 KiB
Python
35 lines
1.1 KiB
Python
|
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Optional

from colossalai.device.device_mesh import DeviceMesh
|
||
|
|
||
|
# Names exported by ``from <this module> import *``.
__all__ = ['IntermediateStrategy', 'StrategyGenerator']
|
||
|
|
||
|
|
||
|
@dataclass
class IntermediateStrategy:
    """
    IntermediateStrategy contains the subset of meta information for ShardingStrategy. It is used
    to store the essential information regarding the tensor sharding and leaves other meta
    information to OperatorHandler.

    Args:
        name (str): name of the sharding strategy.
        dim_partition_dict (Dict[str, Dict[int, List[int]]]): stores the tensor to dim partition
            dict mapping; each tensor is keyed by a string.
        all_reduce_axis (Optional[List[int]]): stores the dimensions which require an all-reduce
            operation. Defaults to None when not given.
    """
    name: str
    dim_partition_dict: Dict[str, Dict[int, List[int]]]
    # None (the default) is kept as-is; it is not replaced with an empty list.
    all_reduce_axis: Optional[List[int]] = None
|
||
|
|
||
|
|
||
|
class StrategyGenerator(ABC):
    """
    Abstract base class for generators that produce one group of sharding strategies.

    Args:
        device_mesh (DeviceMesh): the device mesh, stored on the instance as ``self.device_mesh``.
    """

    def __init__(self, device_mesh: DeviceMesh):
        self.device_mesh = device_mesh

    @abstractmethod
    def generate(self) -> List[IntermediateStrategy]:
        """Return the IntermediateStrategy objects of this group; subclasses must override this."""
|