mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] refactored the data structure for sharding strategy (#1610)
parent
933b6c6367
commit
edb67cb378
|
@ -1,6 +1,6 @@
|
|||
from dataclasses import dataclass
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from typing import Dict, List, Union, Tuple
|
||||
from typing import Dict, List, Union, Tuple, Any
|
||||
from torch.fx.node import Node
|
||||
from .constants import *
|
||||
|
||||
|
@ -37,6 +37,47 @@ class ShardingStrategy:
|
|||
input_shardings: List[ShardingSpec] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainCycleItem:
|
||||
"""
|
||||
TrainCycleItem is a dataclass to store the items which have different values for the forward and backward pass
|
||||
in a training iteration.
|
||||
|
||||
Args:
|
||||
fwd (Any): the item for the forward pass
|
||||
bwd (Any): the item for the backward pass
|
||||
total (Any): the total value for the forward and backward pass
|
||||
"""
|
||||
fwd: Any
|
||||
bwd: Any
|
||||
total: Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShardingStrategy_V2:
|
||||
"""
|
||||
ShardingStrategy is a dataclass to store the meta information on tensor sharding for a node.
|
||||
|
||||
Args:
|
||||
name (str): express the sharding strategies in string, such as 'S0S1 = S0R x RS1'.
|
||||
output_sharding_spec (ShardingSpec): ShardingSpec of the output node.
|
||||
compute_cost (TrainCycleItem): Computation cost to complete this strategy. (default to None)
|
||||
communication_cost (TrainCycleItem): Communication cost to complete this strategy. (default to None)
|
||||
memory_cost (TrainCycleItem): Memory cost of the output node using this strategy. (default to None)
|
||||
input_sharding_specs (List(ShardingSpec)): The ShardingSpecs of the input nodes.
|
||||
input_resharding_costs (Dict[int, List[float]]): resharding_cost[i][j] means the cost of i-th argument in the output node argument list
|
||||
with j-th strategy in its strategies_vector transforms to sharding spec wanted in this
|
||||
strategy.(default to None)
|
||||
"""
|
||||
name: str
|
||||
output_sharding_spec: ShardingSpec
|
||||
compute_cost: TrainCycleItem = None
|
||||
communication_cost: TrainCycleItem = None
|
||||
memory_cost: TrainCycleItem = None
|
||||
input_sharding_specs: List[ShardingSpec] = None
|
||||
input_resharding_costs: Dict[Node, List[float]] = None
|
||||
|
||||
|
||||
class StrategiesVector(list):
|
||||
'''
|
||||
Each node in fx graph will have a corresponding StrategiesVector, to store all the possible
|
||||
|
|
Loading…
Reference in New Issue