mirror of https://github.com/hpcaitech/ColossalAI
67 lines
2.6 KiB
Python
67 lines
2.6 KiB
Python
from abc import ABC, abstractmethod
|
|
from torch.fx.node import Node
|
|
from colossalai.device.device_mesh import DeviceMesh
|
|
from typing import Dict, List
|
|
from ..sharding_strategy import StrategiesVector, Operand, StrategyGenerator_V2
|
|
|
|
|
|
class NodeHandler(ABC):
    """
    The NodeHandler is an abstract class used to generate every possible sharding strategy for an operator node.

    Args:
        node (Node): the fx node whose sharding strategies are generated by this handler.
        device_mesh (DeviceMesh): A logical view of a physical mesh.
        strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
    """

    def __init__(
        self,
        node: Node,
        device_mesh: DeviceMesh,
        strategies_vector: StrategiesVector,
    ) -> None:
        self.node = node
        # Use the public ``all_input_nodes`` accessor instead of the private
        # ``_input_nodes`` dict; it returns the same input nodes in the same order.
        self.predecessor_node = list(node.all_input_nodes)
        self.successor_node = list(node.users.keys())
        self.device_mesh = device_mesh
        self.strategies_vector = strategies_vector
        # Invoked after the attributes above are set, so a subclass's
        # ``register_strategy_generator`` may read them.
        self.strategy_generator = self.register_strategy_generator()

    def register_strategy(self) -> StrategiesVector:
        """
        Register different sharding strategies for the current node.

        Each generator returned by ``register_strategy_generator`` is run against the
        operand mapping returned by ``get_operand_mapping``, and the strategies it
        produces are appended to ``self.strategies_vector``. Calling this method more
        than once appends the strategies again.

        Returns:
            StrategiesVector: ``self.strategies_vector`` after the new strategies have been appended.
        """
        operand_mapping = self.get_operand_mapping()
        for generator in self.strategy_generator:
            strategies = generator.generate(operand_mapping)
            self.strategies_vector.extend(strategies)
        return self.strategies_vector

    @abstractmethod
    def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
        """
        Define which generators should be used by this NodeHandler object.

        Returns:
            List[StrategyGenerator_V2]: the generators that ``register_strategy`` iterates over,
            each of which is called with the operand mapping of this node.
        """

    @abstractmethod
    def get_operand_mapping(self) -> Dict[str, Operand]:
        """
        Return the mapping from each logical operand name to its physical operand.

        A logical operand is defined by the strategy generator. For example, a matrix multiplication
        operation has two operands "input" and "other". For a nn.Linear module, the physical operand for "input" is
        the module input and the physical operand for "other" is the module weight.
        Note that the operand names are specified by the StrategyGenerator object.

        For example:

            # for a linear layer
            mapping = {
                "input": Operand(name=str(self.node.args[0]), type=OperandType.ARG),
                "other": Operand(name="weight", type=OperandType.PARAM),
                "bias": Operand(name="bias", type=OperandType.PARAM)
            }

        Returns:
            Dict[str, Operand]: logical operand name -> physical operand, passed to every
            generator's ``generate`` call in ``register_strategy``.
        """
|