[autoparallel] standardize the code structure (#1469)

pull/1472/head
Frank Lee 2022-08-19 15:51:54 +08:00 committed by GitHub
parent 26a37b5cd5
commit 3a54e1c9b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 13 additions and 15 deletions

View File

View File

@ -1,3 +1,9 @@
from dataclasses import dataclass
from colossalai.tensor.sharding_spec import ShardingSpec
from typing import Dict, List
@dataclass
class ShardingStrategy:
'''
ShardingStrategy is a structure containing sharding strategies of inputs and output of this node
@ -15,21 +21,13 @@ class ShardingStrategy:
input_shardings(List(ShardingSpec)): The ShardingSpecs of the input nodes.
'''
def __init__(self,
name,
output_sharding_spec,
compute_cost=0,
communication_cost=0,
memory_cost=0,
resharding_costs=None,
input_shardings=None):
self.name = name
self.output_sharding_spec = output_sharding_spec
self.compute_cost = compute_cost
self.communication_cost = communication_cost
self.memory_cost = memory_cost
self.resharding_costs = resharding_costs
self.input_shardings = input_shardings
name: str
output_sharding_spec: ShardingSpec
compute_cost: float = 0.
communication_cost: float = 0.
memory_cost: float = 0.
resharding_costs: Dict[int, List[float]] = None
input_shardings: ShardingSpec = None
class StrategiesVector: