ColossalAI/colossalai/auto_parallel/solver/op_handler/node_handler.py

67 lines
2.6 KiB
Python
Raw Normal View History

from abc import ABC, abstractmethod
from torch.fx.node import Node
from colossalai.device.device_mesh import DeviceMesh
from typing import Dict, List
from ..sharding_strategy import StrategiesVector, Operand, StrategyGenerator_V2
class NodeHandler(ABC):
    """
    Abstract base class that generates every possible sharding strategy for an operator node.

    Subclasses declare which strategy generators apply to the node
    (``register_strategy_generator``) and how the generators' logical operand
    names map onto the node's physical operands (``get_operand_mapping``).
    ``register_strategy`` then runs every generator and collects the results.

    Args:
        node (Node): the torch.fx node whose sharding strategies are generated.
        device_mesh (DeviceMesh): a logical view of a physical device mesh.
        strategies_vector (StrategiesVector): container into which every
            strategy generated by this handler is recorded.
    """

    def __init__(
        self,
        node: Node,
        device_mesh: DeviceMesh,
        strategies_vector: StrategiesVector,
    ) -> None:
        self.node = node
        # Use the public ``all_input_nodes`` property instead of the private
        # ``_input_nodes`` dict; torch.fx defines it as ``list(self._input_nodes.keys())``.
        self.predecessor_node = list(node.all_input_nodes)
        self.successor_node = list(node.users.keys())
        self.device_mesh = device_mesh
        self.strategies_vector = strategies_vector
        # NOTE: this calls into the subclass during construction, so any state the
        # subclass's ``register_strategy_generator`` relies on must already be set
        # before ``super().__init__`` is invoked.
        self.strategy_generator = self.register_strategy_generator()

    def register_strategy(self) -> StrategiesVector:
        """
        Run every registered strategy generator and record its strategies.

        The operand mapping is computed once and shared by all generators; each
        generator's output is appended to ``self.strategies_vector`` in generator order.

        Returns:
            StrategiesVector: ``self.strategies_vector``, extended in place.
        """
        operand_mapping = self.get_operand_mapping()
        for generator in self.strategy_generator:
            strategies = generator.generate(operand_mapping)
            self.strategies_vector.extend(strategies)
        return self.strategies_vector

    @abstractmethod
    def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
        """
        Define which strategy generators this NodeHandler uses.

        Called once from ``__init__``; the returned list is stored as
        ``self.strategy_generator`` and iterated by ``register_strategy``.

        Returns:
            List[StrategyGenerator_V2]: the generators to run for this node.
        """

    @abstractmethod
    def get_operand_mapping(self) -> Dict[str, Operand]:
        """
        Return the mapping from logical operand names to physical operands.

        A logical operand is defined by the strategy generator. For example, a matrix
        multiplication has two logical operands, "input" and "other". For an
        nn.Linear module, the physical operand for "input" is the module input and
        the physical operand for "other" is the module weight. The operand names are
        dictated by the StrategyGenerator objects returned from
        ``register_strategy_generator``.

        Example::

            # for a linear layer
            mapping = {
                "input": Operand(name=str(self.node.args[0]), type=OperandType.ARG),
                "other": Operand(name="weight", type=OperandType.PARAM),
                "bias": Operand(name="bias", type=OperandType.PARAM)
            }

        Returns:
            Dict[str, Operand]: logical operand name -> physical operand.
        """