# mirror of https://github.com/hpcaitech/ColossalAI
from abc import ABC, abstractmethod
from typing import Dict, List

from torch.fx.node import Node

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager

from ..sharding_strategy import OperationData, ShardingStrategy_V2, StrategiesVector, TrainCycleItem
from ..strategy import StrategyGenerator_V2


class NodeHandler(ABC):
    '''
    The NodeHandler is an abstract class used to generate all possible sharding strategies for an operator node.

    Args:
        node (Node): the node for which sharding strategies will be generated.
        device_mesh (DeviceMesh): a logical view of the physical device mesh.
        strategies_vector (StrategiesVector): all the strategies generated by this handler will be recorded in the strategies_vector.
    '''

    def __init__(
        self,
        node: Node,
        device_mesh: DeviceMesh,
        strategies_vector: StrategiesVector,
    ) -> None:
        self.node = node
        self.predecessor_node = list(node._input_nodes.keys())
        self.successor_node = list(node.users.keys())
        self.device_mesh = device_mesh
        self.strategies_vector = strategies_vector

    def update_resharding_cost(self, strategy: ShardingStrategy_V2) -> None:
        """
        Compute the resharding costs and save the costs in the ShardingStrategy object.
        """
        # TODO: test this function when other handlers are ready
        resharding_costs = {}
        shape_consistency_manager = ShapeConsistencyManager()

        for node in self.predecessor_node:
            node_name = str(node)

            # get the sharding specs for this node generated
            # in its own node handler
            assert hasattr(node, 'strategies_vector'), \
                f'The predecessor node {node_name} has no strategy vector to compute the resharding cost.'
            prev_strategy_vector = node.strategies_vector
            prev_sharding_specs = [
                prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector
            ]

            # get the current sharding spec generated by this node handler
            op_data = strategy.get_op_data_by_name(node_name)
            current_sharding_spec = strategy.sharding_specs[op_data]

            # create the data structure to store the costs
            if op_data not in resharding_costs:
                resharding_costs[op_data] = {}

            # for each sharding spec generated by the predecessor's node handler,
            # compute the resharding cost to switch to the sharding spec generated
            # by the current node handler
            for prev_sharding_spec in prev_sharding_specs:
                fwd_cost = shape_consistency_manager.shape_consistency(prev_sharding_spec, current_sharding_spec)
                bwd_cost = shape_consistency_manager.shape_consistency(current_sharding_spec, prev_sharding_spec)
                resharding_cost = TrainCycleItem(fwd=fwd_cost, bwd=bwd_cost, total=fwd_cost + bwd_cost)
                resharding_costs[op_data][prev_sharding_spec] = resharding_cost
        strategy.resharding_costs = resharding_costs
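
    # Illustration only (descriptive comment, not extra API): after
    # ``update_resharding_cost(strategy)`` returns, ``strategy.resharding_costs``
    # maps each predecessor operand to a per-spec cost table, roughly
    #
    #   {op_data: {prev_sharding_spec: TrainCycleItem(fwd=..., bwd=..., total=...)}}
    #
    # i.e. one forward/backward cost pair per (predecessor operand,
    # predecessor sharding spec) combination.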

    def register_strategy(self, compute_resharding_cost: bool = False) -> StrategiesVector:
        """
        Register different sharding strategies for the current node.
        """
        strategy_generators = self.get_strategy_generator()
        for generator in strategy_generators:
            strategies = generator.generate()

            # compute the resharding costs based on the previous node's
            # strategies if specified
            if compute_resharding_cost:
                # update_resharding_cost mutates each strategy in place and
                # returns None, so iterate rather than map (mapping it would
                # replace the strategies with None values)
                for strategy in strategies:
                    self.update_resharding_cost(strategy)
            self.strategies_vector.extend(strategies)

        strategies_vector = map(self.post_process, self.strategies_vector)
        self.strategies_vector = list(strategies_vector)
        return self.strategies_vector
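
    # Example call site (hypothetical): ``handler.register_strategy(compute_resharding_cost=True)``
    # both fills ``self.strategies_vector`` in place and returns it, so callers
    # may use either the return value or the vector passed to the constructor.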

    def post_process(self, strategy: ShardingStrategy_V2):
        # transform the generated strategy,
        # e.g. to process the sharding strategy for transposed weights
        return strategy

    @abstractmethod
    def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
        """
        Define which generators should be used by this NodeHandler object.
        """
        pass

    @abstractmethod
    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        """
        Returns the mapping from the logical operation data to its physical data.

        A logical operation data is a piece of data associated with an operation, which can be an input or an output. It is
        defined by the strategy generator. For example, a matrix multiplication operation has two operands, "input"
        and "other", and one result, "output". For an nn.Linear module, the physical operand for "input" is
        the module input, the physical operand for "other" is the module weight, and the physical result for "output"
        is the module output.
        Note that the operand names are specified by the StrategyGenerator object.

        For example:

            # for a linear layer
            mapping = {
                "input": OperationData(name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data),
                "other": OperationData(name="weight", type=OperationDataType.PARAM, data=self.named_parameters['weight']),
                "bias": OperationData(name="bias", type=OperationDataType.PARAM, data=self.named_parameters['bias']),
                "output": OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data),
            }
        """
        pass
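

# Illustrative sketch only, not part of the library: a minimal concrete
# NodeHandler for a unary element-wise op such as torch.relu. The generator
# name ``UnaryElementwiseGenerator`` is hypothetical; a real handler would
# return generators from the ``..strategy`` package, and ``OperationDataType``
# would be imported from ``..sharding_strategy``.
#
#   class ReLUHandler(NodeHandler):
#
#       def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
#           return [UnaryElementwiseGenerator(self.device_mesh)]
#
#       def get_operation_data_mapping(self) -> Dict[str, OperationData]:
#           physical_input = OperationData(name=str(self.node.args[0]),
#                                          type=OperationDataType.ARG,
#                                          data=self.node.args[0]._meta_data)
#           physical_output = OperationData(name=str(self.node),
#                                           type=OperationDataType.OUTPUT,
#                                           data=self.node._meta_data)
#           return {'input': physical_input, 'output': physical_output}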


class ModuleHandler(NodeHandler):

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # set attributes to access the module parameters for convenience
        assert self.node.graph.owning_module is not None, \
            'The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object.'
        module = self.node.graph.owning_module.get_submodule(self.node.target)

        # convert the named parameters from an iterable of (name, param) pairs to a dict
        named_parameters = dict(module.named_parameters(recurse=False))
        self.module = module
        self.named_parameters = named_parameters
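

# Illustrative usage sketch only, not part of the library. The handler class
# ``LinearModuleHandler`` is hypothetical, and the construction of the device
# mesh and the traced graph depends on the surrounding auto-parallel code.
#
#   gm = torch.fx.symbolic_trace(model)
#   for node in gm.graph.nodes:
#       if node.op == 'call_module':
#           strategies_vector = StrategiesVector(node)
#           handler = LinearModuleHandler(node, device_mesh, strategies_vector)
#           handler.register_strategy(compute_resharding_cost=True)
#           # attach the strategies so that successor handlers can compute
#           # resharding costs against this node
#           setattr(node, 'strategies_vector', strategies_vector)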