ColossalAI/colossalai/auto_parallel/solver/op_handler/strategy_generator.py

35 lines
1.1 KiB
Python
Raw Normal View History

from dataclasses import dataclass
from abc import ABC, abstractmethod
from typing import List, Dict
from colossalai.device.device_mesh import DeviceMesh
__all__ = ['IntermediateStrategy', 'StrategyGenerator']
@dataclass
class IntermediateStrategy:
"""
IntermediateStrategy contains the subset of meta information for ShardingStrategy. It is
to store the essential information regarding the tensor sharding and leave other meta information to OperatorHandler.
Args:
name (str): name of the sharding strategy.
dim_partition_dict (Dict[Dict]): stores the tensor to dim partition dict mapping.
all_reduce_dims (List[int]): stores the dimensions which require an all-reduce operation.
"""
name: str
dim_partition_dict: Dict[str, Dict[int, List[int]]]
all_reduce_axis: List[int] = None
class StrategyGenerator(ABC):
"""
StrategyGenerator is used to generate the same group of sharding strategies.
"""
def __init__(self, device_mesh: DeviceMesh):
self.device_mesh = device_mesh
@abstractmethod
def generate(self) -> List[IntermediateStrategy]:
pass