mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] refactored the data structure for sharding strategy (#1610)
parent
933b6c6367
commit
edb67cb378
|
@ -1,6 +1,6 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||||
from typing import Dict, List, Union, Tuple
|
from typing import Dict, List, Union, Tuple, Any
|
||||||
from torch.fx.node import Node
|
from torch.fx.node import Node
|
||||||
from .constants import *
|
from .constants import *
|
||||||
|
|
||||||
|
@ -37,6 +37,47 @@ class ShardingStrategy:
|
||||||
input_shardings: List[ShardingSpec] = None
|
input_shardings: List[ShardingSpec] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainCycleItem:
|
||||||
|
"""
|
||||||
|
TrainCycleItem is a dataclass to store the items which have different values for the forward and backward pass
|
||||||
|
in a training iteration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fwd (Any): the item for the forward pass
|
||||||
|
bwd (Any): the item for the backward pass
|
||||||
|
total (Any): the total value for the forward and backward pass
|
||||||
|
"""
|
||||||
|
fwd: Any
|
||||||
|
bwd: Any
|
||||||
|
total: Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ShardingStrategy_V2:
|
||||||
|
"""
|
||||||
|
ShardingStrategy is a dataclass to store the meta information on tensor sharding for a node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): express the sharding strategies in string, such as 'S0S1 = S0R x RS1'.
|
||||||
|
output_sharding_spec (ShardingSpec): ShardingSpec of the output node.
|
||||||
|
compute_cost (TrainCycleItem): Computation cost to complete this strategy. (default to None)
|
||||||
|
communication_cost (TrainCycleItem): Communication cost to complete this strategy. (default to None)
|
||||||
|
memory_cost (TrainCycleItem): Memory cost of the output node using this strategy. (default to None)
|
||||||
|
input_sharding_specs (List(ShardingSpec)): The ShardingSpecs of the input nodes.
|
||||||
|
input_resharding_costs (Dict[int, List[float]]): resharding_cost[i][j] means the cost of i-th argument in the output node argument list
|
||||||
|
with j-th strategy in its strategies_vector transforms to sharding spec wanted in this
|
||||||
|
strategy.(default to None)
|
||||||
|
"""
|
||||||
|
name: str
|
||||||
|
output_sharding_spec: ShardingSpec
|
||||||
|
compute_cost: TrainCycleItem = None
|
||||||
|
communication_cost: TrainCycleItem = None
|
||||||
|
memory_cost: TrainCycleItem = None
|
||||||
|
input_sharding_specs: List[ShardingSpec] = None
|
||||||
|
input_resharding_costs: Dict[Node, List[float]] = None
|
||||||
|
|
||||||
|
|
||||||
class StrategiesVector(list):
|
class StrategiesVector(list):
|
||||||
'''
|
'''
|
||||||
Each node in fx graph will have a corresponding StrategiesVector, to store all the possible
|
Each node in fx graph will have a corresponding StrategiesVector, to store all the possible
|
||||||
|
|
Loading…
Reference in New Issue