[autoparallel] refactored the data structure for sharding strategy (#1610)

pull/1611/head
Frank Lee 2022-09-20 11:20:54 +08:00 committed by GitHub
parent 933b6c6367
commit edb67cb378
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 42 additions and 1 deletions

View File

@ -1,6 +1,6 @@
from dataclasses import dataclass
from colossalai.tensor.sharding_spec import ShardingSpec
from typing import Dict, List, Union, Tuple
from typing import Dict, List, Union, Tuple, Any
from torch.fx.node import Node
from .constants import *
@ -37,6 +37,47 @@ class ShardingStrategy:
input_shardings: List[ShardingSpec] = None
@dataclass
class TrainCycleItem:
"""
TrainCycleItem is a dataclass to store the items which have different values for the forward and backward pass
in a training iteration.
Args:
fwd (Any): the item for the forward pass
bwd (Any): the item for the backward pass
total (Any): the total value for the forward and backward pass
"""
fwd: Any
bwd: Any
total: Any
@dataclass
class ShardingStrategy_V2:
"""
ShardingStrategy is a dataclass to store the meta information on tensor sharding for a node.
Args:
name (str): express the sharding strategies in string, such as 'S0S1 = S0R x RS1'.
output_sharding_spec (ShardingSpec): ShardingSpec of the output node.
compute_cost (TrainCycleItem): Computation cost to complete this strategy. (default to None)
communication_cost (TrainCycleItem): Communication cost to complete this strategy. (default to None)
memory_cost (TrainCycleItem): Memory cost of the output node using this strategy. (default to None)
input_sharding_specs (List(ShardingSpec)): The ShardingSpecs of the input nodes.
input_resharding_costs (Dict[int, List[float]]): resharding_cost[i][j] means the cost of i-th argument in the output node argument list
with j-th strategy in its strategies_vector transforms to sharding spec wanted in this
strategy.(default to None)
"""
name: str
output_sharding_spec: ShardingSpec
compute_cost: TrainCycleItem = None
communication_cost: TrainCycleItem = None
memory_cost: TrainCycleItem = None
input_sharding_specs: List[ShardingSpec] = None
input_resharding_costs: Dict[Node, List[float]] = None
class StrategiesVector(list):
'''
Each node in fx graph will have a corresponding StrategiesVector, to store all the possible