mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] standardize the code structure (#1469)
parent
26a37b5cd5
commit
3a54e1c9b7
|
@ -1,3 +1,9 @@
|
|||
from dataclasses import dataclass
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShardingStrategy:
|
||||
'''
|
||||
ShardingStrategy is a structure containing sharding strategies of inputs and output of this node
|
||||
|
@ -15,21 +21,13 @@ class ShardingStrategy:
|
|||
input_shardings(List(ShardingSpec)): The ShardingSpecs of the input nodes.
|
||||
'''
|
||||
|
||||
def __init__(self,
|
||||
name,
|
||||
output_sharding_spec,
|
||||
compute_cost=0,
|
||||
communication_cost=0,
|
||||
memory_cost=0,
|
||||
resharding_costs=None,
|
||||
input_shardings=None):
|
||||
self.name = name
|
||||
self.output_sharding_spec = output_sharding_spec
|
||||
self.compute_cost = compute_cost
|
||||
self.communication_cost = communication_cost
|
||||
self.memory_cost = memory_cost
|
||||
self.resharding_costs = resharding_costs
|
||||
self.input_shardings = input_shardings
|
||||
name: str
|
||||
output_sharding_spec: ShardingSpec
|
||||
compute_cost: float = 0.
|
||||
communication_cost: float = 0.
|
||||
memory_cost: float = 0.
|
||||
resharding_costs: Dict[int, List[float]] = None
|
||||
input_shardings: ShardingSpec = None
|
||||
|
||||
|
||||
class StrategiesVector:
|
||||
|
|
Loading…
Reference in New Issue