diff --git a/colossalai/auto_parallel/solver/op_handler/node_handler.py b/colossalai/auto_parallel/solver/op_handler/node_handler.py
index c539c5a2b..e137f6fad 100644
--- a/colossalai/auto_parallel/solver/op_handler/node_handler.py
+++ b/colossalai/auto_parallel/solver/op_handler/node_handler.py
@@ -1,8 +1,9 @@
 from abc import ABC, abstractmethod
 from torch.fx.node import Node
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from typing import Dict, List
-from ..sharding_strategy import ShardingStrategy_V2, StrategiesVector, OperationData
+from ..sharding_strategy import ShardingStrategy_V2, StrategiesVector, OperationData, TrainCycleItem
 from ..strategy import StrategyGenerator_V2
 
@@ -28,13 +29,54 @@ class NodeHandler(ABC):
         self.device_mesh = device_mesh
         self.strategies_vector = strategies_vector
 
-    def register_strategy(self) -> StrategiesVector:
+    def update_resharding_cost(self, strategy: ShardingStrategy_V2) -> None:
+        """
+        Compute the resharding costs and save the costs in the ShardingStrategy object.
+        """
+        # TODO: test this function when other handlers are ready
+        resharding_costs = {}
+        shape_consistency_manager = ShapeConsistencyManager()
+        for node in self.predecessor_node:
+            node_name = str(node)
+
+            # get the sharding specs for this node generated
+            # in its own node handler
+            assert hasattr(node, 'strategies_vector'), \
+                f'The predecessor node {node_name} has no strategy vector to compute the resharding cost.'
+            prev_strategy_vector = node.strategies_vector
+            prev_sharding_specs = [prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector]
+
+            # get the current sharding spec generated by this node handler
+            op_data = strategy.get_op_data_by_name(node_name)
+            current_sharding_spec = strategy.sharding_specs[op_data]
+
+            # create the data structure to store costs
+            if op_data not in resharding_costs:
+                resharding_costs[op_data] = {}
+
+            # for each sharding spec generated by the predecessor's node handler,
+            # compute the resharding cost to switch to the sharding spec generated
+            # by the current node handler
+            for prev_sharding_spec in prev_sharding_specs:
+                fwd_cost = shape_consistency_manager.shape_consistency(prev_sharding_spec, current_sharding_spec)
+                bwd_cost = shape_consistency_manager.shape_consistency(current_sharding_spec, prev_sharding_spec)
+                resharding_cost = TrainCycleItem(fwd=fwd_cost, bwd=bwd_cost, total=fwd_cost + bwd_cost)
+                resharding_costs[op_data][prev_sharding_spec] = resharding_cost
+        strategy.resharding_costs = resharding_costs
+
+    def register_strategy(self, compute_resharding_cost: bool = False) -> StrategiesVector:
         """
         Register different sharding strategies for the current node.
""" strategy_generators = self.get_strategy_generator() for generator in strategy_generators: strategies = generator.generate() + + # compute the resharding costs based on the previous node + # strategies if specified + if compute_resharding_cost: + strategies = list(map(self.update_resharding_cost, strategies)) self.strategies_vector.extend(strategies) strategies_vector = map(self.post_process, self.strategies_vector) diff --git a/colossalai/auto_parallel/solver/sharding_strategy.py b/colossalai/auto_parallel/solver/sharding_strategy.py index 4c1a390ce..e73a7281e 100644 --- a/colossalai/auto_parallel/solver/sharding_strategy.py +++ b/colossalai/auto_parallel/solver/sharding_strategy.py @@ -129,6 +129,7 @@ class ShardingStrategy_V2: memory_cost: TrainCycleItem = None input_resharding_costs: Dict[OperationData, List[float]] = None communication_actions: Dict[OperationData, CommSpec] = None + resharding_costs: Dict[OperationData, Dict[ShardingSpec, TrainCycleItem]] = None @property def input_sharding_specs(self) -> Dict[OperationData, ShardingSpec]: @@ -153,6 +154,18 @@ class ShardingStrategy_V2: specs = {k: v for k, v in self.sharding_specs.items() if k.type == operation_data_type} return specs + def get_op_data_by_name(self, name: str): + for op_data in self.sharding_specs.keys(): + if op_data.name == name: + return op_data + raise KeyError(f"Could not find the OperationData with name {name}") + + def get_sharding_spec_by_name(self, name: str): + for op_data, sharding_spec in self.sharding_specs.items(): + if op_data.name == name: + return sharding_spec + raise KeyError(f"Could not find the ShardingSpec for OperationData with name {name}") + class StrategiesVector(list): '''