mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish colossalai/auto_parallel/tensor_shard/deprecated/op_handler/strategy_generator.py code style (#2695)
Co-authored-by: shenggan <csg19971016@gmail.com>pull/2708/head
parent
534f68c83c
commit
6427c406cf
|
@ -1,6 +1,7 @@
|
||||||
from dataclasses import dataclass
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Dict
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
from colossalai.device.device_mesh import DeviceMesh
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
|
||||||
__all__ = ['IntermediateStrategy', 'StrategyGenerator']
|
__all__ = ['IntermediateStrategy', 'StrategyGenerator']
|
||||||
|
@ -9,7 +10,7 @@ __all__ = ['IntermediateStrategy', 'StrategyGenerator']
|
||||||
@dataclass
|
@dataclass
|
||||||
class IntermediateStrategy:
|
class IntermediateStrategy:
|
||||||
"""
|
"""
|
||||||
IntermediateStrategy contains the subset of meta information for ShardingStrategy. It is
|
IntermediateStrategy contains the subset of meta information for ShardingStrategy. It is
|
||||||
to store the essential information regarding the tensor sharding and leave other meta information to OperatorHandler.
|
to store the essential information regarding the tensor sharding and leave other meta information to OperatorHandler.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -24,7 +25,7 @@ class IntermediateStrategy:
|
||||||
|
|
||||||
class StrategyGenerator(ABC):
|
class StrategyGenerator(ABC):
|
||||||
"""
|
"""
|
||||||
StrategyGenerator is used to generate the same group of sharding strategies.
|
StrategyGenerator is used to generate the same group of sharding strategies.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, device_mesh: DeviceMesh):
|
def __init__(self, device_mesh: DeviceMesh):
|
||||||
|
@ -39,7 +40,7 @@ class StrategyGenerator(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def validate(self, *args, **kwargs) -> bool:
|
def validate(self, *args, **kwargs) -> bool:
|
||||||
"""
|
"""
|
||||||
Validate if the operands are of desired shape.
|
Validate if the operands are of desired shape.
|
||||||
If True, means this generator can be used for the current operation.
|
If True, means this generator can be used for the current operation.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
Loading…
Reference in New Issue