@@ -1,8 +1,9 @@
 from abc import ABC, abstractmethod
 from torch.fx.node import Node
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from typing import Dict, List
-from ..sharding_strategy import ShardingStrategy_V2, StrategiesVector, OperationData
+from ..sharding_strategy import ShardingStrategy_V2, StrategiesVector, OperationData, TrainCycleItem
 from ..strategy import StrategyGenerator_V2
 
 
@@ -28,13 +29,54 @@ class NodeHandler(ABC):
         self.device_mesh = device_mesh
         self.strategies_vector = strategies_vector
 
-    def register_strategy(self) -> StrategiesVector:
+    def update_resharding_cost(self, strategy: ShardingStrategy_V2) -> None:
+        """
+        Compute the resharding costs and save the costs in the ShardingStrategy object.
+        """
+        # TODO: test this function when other handlers are ready
+        resharding_costs = {}
+        shape_consistency_manager = ShapeConsistencyManager()
+        for node in self.predecessor_node:
+            node_name = str(node)
+
+            # get the sharding specs for this node generated
+            # in its own node handler
+            assert hasattr(node, 'strategies_vector'), \
+                f'The predecessor node {node_name} has no strategy vector to compute the resharding cost.'
+            prev_strategy_vector = node.strategies_vector
+            prev_sharding_specs = [prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector]
+
+            # get the current sharding spec generated by this node handler
+            op_data = strategy.get_op_data_by_name(node_name)
+            current_sharding_spec = strategy.sharding_specs[op_data]
+
+            # create the data structure to store costs
+            if op_data not in resharding_costs:
+                resharding_costs[op_data] = {}
+
+            # for each sharding spec generated by the predecessor's node handler,
+            # compute the resharding cost to switch to the sharding spec generated
+            # by the current node handler
+            for prev_sharding_spec in prev_sharding_specs:
+                fwd_cost = shape_consistency_manager.shape_consistency(prev_sharding_spec, current_sharding_spec)
+                bwd_cost = shape_consistency_manager.shape_consistency(current_sharding_spec, prev_sharding_spec)
+                resharding_cost = TrainCycleItem(fwd=fwd_cost, bwd=bwd_cost, total=fwd_cost + bwd_cost)
+                resharding_costs[op_data][prev_sharding_spec] = resharding_cost
+        strategy.resharding_costs = resharding_costs
+
+    def register_strategy(self, compute_resharding_cost: bool = False) -> StrategiesVector:
         """
         Register different sharding strategies for the current node.
         """
         strategy_generators = self.get_strategy_generator()
         for generator in strategy_generators:
             strategies = generator.generate()
+
+            # compute the resharding costs based on the previous node's
+            # strategies if requested
+            if compute_resharding_cost:
+                for strategy in strategies:
+                    self.update_resharding_cost(strategy)
             self.strategies_vector.extend(strategies)
 
         strategies_vector = map(self.post_process, self.strategies_vector)
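To see what `update_resharding_cost` actually stores, here is a minimal, self-contained sketch of the same bookkeeping pattern. Everything prefixed `Fake` or `mock_` is an illustrative stand-in, not a ColossalAI API: the real method keys the outer dict by `OperationData`, uses real `ShardingSpec` objects, and obtains costs from `ShapeConsistencyManager` rather than a toy cost function.

```python
from dataclasses import dataclass
from typing import Dict

@dataclass(frozen=True)
class FakeShardingSpec:
    """Stand-in for a sharding spec; frozen so it can be a dict key."""
    layout: str

@dataclass
class FakeTrainCycleItem:
    """Stand-in mirroring the fwd/bwd/total fields used in the diff."""
    fwd: float
    bwd: float
    total: float

def mock_shape_consistency(src: FakeShardingSpec, dst: FakeShardingSpec) -> float:
    """Toy cost model: free if layouts already match, otherwise a flat cost."""
    return 0.0 if src.layout == dst.layout else 10.0

# one predecessor offering two candidate layouts; the current node wants "row"
prev_specs = [FakeShardingSpec("row"), FakeShardingSpec("column")]
current_spec = FakeShardingSpec("row")

resharding_costs: Dict[FakeShardingSpec, FakeTrainCycleItem] = {}
for prev_spec in prev_specs:
    fwd = mock_shape_consistency(prev_spec, current_spec)  # reshard activation forward
    bwd = mock_shape_consistency(current_spec, prev_spec)  # reshard gradient backward
    resharding_costs[prev_spec] = FakeTrainCycleItem(fwd=fwd, bwd=bwd, total=fwd + bwd)

for spec, cost in resharding_costs.items():
    print(f"{spec.layout:>6} -> {current_spec.layout}: total={cost.total}")
# row -> row: total=0.0
# column -> row: total=20.0
```

Note the asymmetry the diff queries: the forward cost converts the predecessor's output layout into the layout this node expects, while the backward cost pays for the reverse conversion of the gradient, which is why both directions are computed and summed into `total`.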
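And a hedged sketch of the new `register_strategy` control flow with its optional annotation pass. The `Toy*` classes are placeholders standing in for `NodeHandler`, `StrategyGenerator_V2`, and `ShardingStrategy_V2`; only the generate-then-annotate-then-extend shape mirrors the diff.

```python
from typing import List

class ToyStrategy:
    """Stand-in for a sharding strategy: a name plus a cost slot."""
    def __init__(self, name: str):
        self.name = name
        self.resharding_costs = None

class ToyGenerator:
    """Stand-in generator producing a fixed strategy list."""
    def generate(self) -> List[ToyStrategy]:
        return [ToyStrategy("S0"), ToyStrategy("S1")]

class ToyHandler:
    def __init__(self):
        self.strategies_vector: List[ToyStrategy] = []

    def get_strategy_generator(self) -> List[ToyGenerator]:
        return [ToyGenerator()]

    def update_resharding_cost(self, strategy: ToyStrategy) -> None:
        # mutates the strategy in place and returns None, as in the diff
        strategy.resharding_costs = {"dummy": 0.0}

    def register_strategy(self, compute_resharding_cost: bool = False) -> List[ToyStrategy]:
        for generator in self.get_strategy_generator():
            strategies = generator.generate()
            if compute_resharding_cost:
                # iterate rather than map: mapping a None-returning method
                # would replace the strategies with a list of Nones
                for strategy in strategies:
                    self.update_resharding_cost(strategy)
            self.strategies_vector.extend(strategies)
        return self.strategies_vector

handler = ToyHandler()
for s in handler.register_strategy(compute_resharding_cost=True):
    print(s.name, s.resharding_costs)  # S0 {'dummy': 0.0} / S1 {'dummy': 0.0}
```

The explicit loop is the reason the diff above avoids `strategies = list(map(self.update_resharding_cost, strategies))`: since `update_resharding_cost` annotates each strategy in place and returns `None`, mapping it would discard the strategies before they reach `self.strategies_vector`.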