diff --git a/colossalai/auto_parallel/solver/sharding_strategy.py b/colossalai/auto_parallel/solver/sharding_strategy.py
index 8e34f6e18..598d59fb4 100644
--- a/colossalai/auto_parallel/solver/sharding_strategy.py
+++ b/colossalai/auto_parallel/solver/sharding_strategy.py
@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from colossalai.tensor.sharding_spec import ShardingSpec
-from typing import Dict, List, Union, Tuple
+from typing import Dict, List, Union, Tuple, Any
 from torch.fx.node import Node
 from .constants import *
 
@@ -37,6 +37,47 @@ class ShardingStrategy:
     input_shardings: List[ShardingSpec] = None
 
 
+@dataclass
+class TrainCycleItem:
+    """
+    TrainCycleItem is a dataclass to store items that take different values in the forward and backward passes
+    of a training iteration.
+
+    Args:
+        fwd (Any): the item for the forward pass
+        bwd (Any): the item for the backward pass
+        total (Any): the total value over the forward and backward passes
+    """
+    fwd: Any
+    bwd: Any
+    total: Any
+
+
+@dataclass
+class ShardingStrategy_V2:
+    """
+    ShardingStrategy_V2 is a dataclass to store the meta information on tensor sharding for a node.
+
+    Args:
+        name (str): expresses the sharding strategy as a string, such as 'S0S1 = S0R x RS1'.
+        output_sharding_spec (ShardingSpec): ShardingSpec of the output node.
+        compute_cost (TrainCycleItem): computation cost to complete this strategy. (defaults to None)
+        communication_cost (TrainCycleItem): communication cost to complete this strategy. (defaults to None)
+        memory_cost (TrainCycleItem): memory cost of the output node using this strategy. (defaults to None)
+        input_sharding_specs (List[ShardingSpec]): the ShardingSpecs of the input nodes. (defaults to None)
+        input_resharding_costs (Dict[Node, List[float]]): input_resharding_costs[node][j] is the cost of resharding
+                                                  the given input node from the j-th strategy in its strategies_vector
+                                                  to the sharding spec required by this strategy. (defaults to None)
+    """
+    name: str
+    output_sharding_spec: ShardingSpec
+    compute_cost: TrainCycleItem = None
+    communication_cost: TrainCycleItem = None
+    memory_cost: TrainCycleItem = None
+    input_sharding_specs: List[ShardingSpec] = None
+    input_resharding_costs: Dict[Node, List[float]] = None
+
+
 class StrategiesVector(list):
     '''
     Each node in fx graph will have a corresponding StrategiesVector, to store all the possible