diff --git a/colossalai/auto_parallel/solver/op_handler/node_handler.py b/colossalai/auto_parallel/solver/op_handler/node_handler.py
new file mode 100644
index 000000000..8deafeb55
--- /dev/null
+++ b/colossalai/auto_parallel/solver/op_handler/node_handler.py
@@ -0,0 +1,66 @@
+from abc import ABC, abstractmethod
+from torch.fx.node import Node
+from colossalai.device.device_mesh import DeviceMesh
+from typing import Dict, List
+from ..sharding_strategy import StrategiesVector, Operand, StrategyGenerator_V2
+
+
+class NodeHandler(ABC):
+    '''
+    The NodeHandler is an abstract class used to generate all possible sharding strategies for an operator node.
+
+    Args:
+        node (Node): the input node in the node argument list.
+        device_mesh (DeviceMesh): a logical view of a physical mesh.
+        strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
+    '''
+
+    def __init__(
+        self,
+        node: Node,
+        device_mesh: DeviceMesh,
+        strategies_vector: StrategiesVector,
+    ) -> None:
+        self.node = node
+        self.predecessor_node = list(node._input_nodes.keys())
+        self.successor_node = list(node.users.keys())
+        self.device_mesh = device_mesh
+        self.strategies_vector = strategies_vector
+        self.strategy_generator = self.register_strategy_generator()
+
+    def register_strategy(self) -> StrategiesVector:
+        """
+        Register different sharding strategies for the current node.
+        """
+        operand_mapping = self.get_operand_mapping()
+        for generator in self.strategy_generator:
+            strategies = generator.generate(operand_mapping)
+            self.strategies_vector.extend(strategies)
+        return self.strategies_vector
+
+    @abstractmethod
+    def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
+        """
+        Define which generators should be used by this NodeHandler object.
+        """
+        pass
+
+    @abstractmethod
+    def get_operand_mapping(self) -> Dict[str, Operand]:
+        """
+        Returns the mapping from each logical operand name to its physical operand.
+        A logical operand is defined by the strategy generator. For example, a matrix multiplication
+        operation has two operands "input" and "other". For an nn.Linear module, the physical operand for "input" is
+        the module input and the physical operand for "other" is the module weight.
+        Note that the operand name is specified by the StrategyGenerator object.
+
+        For example:
+
+            # for a linear layer
+            mapping = {
+                "input": Operand(name=str(self.node.args[0]), type=OperandType.ARG),
+                "other": Operand(name="weight", type=OperandType.PARAM),
+                "bias": Operand(name="bias", type=OperandType.PARAM)
+            }
+        """
+        pass
diff --git a/colossalai/auto_parallel/solver/op_handler/registry.py b/colossalai/auto_parallel/solver/op_handler/registry.py
new file mode 100644
index 000000000..51855e4bf
--- /dev/null
+++ b/colossalai/auto_parallel/solver/op_handler/registry.py
@@ -0,0 +1,25 @@
+class Registry:
+    # TODO: refactor the registry classes used in colossalai.registry, colossalai.fx and here
+
+    def __init__(self, name):
+        self.name = name
+        self.store = {}
+
+    def register(self, source):
+
+        def wrapper(func):
+            self.store[source] = func
+            return func
+
+        return wrapper
+
+    def get(self, source):
+        assert source in self.store
+        target = self.store[source]
+        return target
+
+    def has(self, source):
+        return source in self.store
+
+
+operator_registry = Registry('operator')
diff --git a/colossalai/auto_parallel/solver/sharding_strategy.py b/colossalai/auto_parallel/solver/sharding_strategy.py
index 598d59fb4..342f6ff37 100644
--- a/colossalai/auto_parallel/solver/sharding_strategy.py
+++ b/colossalai/auto_parallel/solver/sharding_strategy.py
@@ -1,4 +1,7 @@
 from dataclasses import dataclass
+from abc import ABC, abstractmethod
+from enum import Enum
+from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.sharding_spec import ShardingSpec
 from typing import Dict, List, Union, Tuple, Any
 from torch.fx.node import Node
@@ -37,6 +40,20 @@ class ShardingStrategy:
     input_shardings: List[ShardingSpec] = None
 
 
+class OperandType(Enum):
+    """
+    An operand can come from the argument list of an operator or the parameter list of a module.
+    """
+    ARG = 0
+    PARAM = 1
+
+
+@dataclass
+class Operand:
+    name: str
+    type: OperandType
+
+
 @dataclass
 class TrainCycleItem:
     """
@@ -44,9 +61,8 @@ class TrainCycleItem:
     in a training iteration.
 
     Args:
-        fwd (Any): the item for the forward pass
-        bwd (Any): the item for the backward pass
-        total (Any): the total value for the forward and backward pass
+        fwd (float): the item for the forward pass
+        bwd (float): the item for the backward pass
     """
     fwd: Any
     bwd: Any
@@ -74,8 +90,34 @@ class ShardingStrategy_V2:
     compute_cost: TrainCycleItem = None
     communication_cost: TrainCycleItem = None
     memory_cost: TrainCycleItem = None
-    input_sharding_specs: List[ShardingSpec] = None
-    input_resharding_costs: Dict[Node, List[float]] = None
+    input_sharding_specs: Dict[Operand, ShardingSpec] = None
+    input_resharding_costs: Dict[Operand, List[float]] = None
+
+
+class StrategyGenerator_V2(ABC):
+    """
+    StrategyGenerator is used to generate the same group of sharding strategies.
+
+    TODO: remove the original strategy_generator.py after refactoring
+    """
+
+    def __init__(self, device_mesh: DeviceMesh):
+        self.device_mesh = device_mesh
+
+    @abstractmethod
+    def generate(self, operand_mapping: Dict[str, Operand]) -> List[ShardingStrategy_V2]:
+        """
+        Generate all possible sharding strategies for the given operand mapping.
+        """
+        pass
+
+    @abstractmethod
+    def validate(self, *args, **kwargs) -> bool:
+        """
+        Validate if the operands are of the desired shape.
+        If True, this generator can be used for the current operation.
+        """
+        pass
 
 
 class StrategiesVector(list):
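
For context, the sketch below shows how a concrete handler might plug into the pieces added by this patch: a handler class registered through operator_registry that supplies a strategy generator and an operand mapping. This is a minimal, hypothetical illustration; LinearModuleHandler and LinearStrategyGenerator are placeholder names and are not part of this diff.

    # Hypothetical sketch (not part of this patch): a concrete NodeHandler wired
    # into the registry and strategy-generator abstractions added above.
    import torch

    from colossalai.auto_parallel.solver.op_handler.node_handler import NodeHandler
    from colossalai.auto_parallel.solver.op_handler.registry import operator_registry
    from colossalai.auto_parallel.solver.sharding_strategy import (Operand, OperandType, StrategyGenerator_V2)


    class LinearStrategyGenerator(StrategyGenerator_V2):
        # Placeholder generator: a real implementation would enumerate
        # ShardingStrategy_V2 objects over the device mesh dimensions.

        def generate(self, operand_mapping):
            return []

        def validate(self, *args, **kwargs):
            return True


    @operator_registry.register(torch.nn.Linear)
    class LinearModuleHandler(NodeHandler):

        def register_strategy_generator(self):
            # Called once in NodeHandler.__init__; each generator receives the
            # device mesh so it can enumerate shardings over its dimensions.
            return [LinearStrategyGenerator(self.device_mesh)]

        def get_operand_mapping(self):
            # Map the logical operand names expected by the generator to this
            # node's physical operands, following the docstring of
            # NodeHandler.get_operand_mapping.
            return {
                "input": Operand(name=str(self.node.args[0]), type=OperandType.ARG),
                "other": Operand(name="weight", type=OperandType.PARAM),
                "bias": Operand(name="bias", type=OperandType.PARAM),
            }

Dispatch code could then check operator_registry.has(type(module)) and call operator_registry.get(type(module)) to pick the handler class for a traced call_module node, after which register_strategy() populates the node's StrategiesVector.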