diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/strategy_generator.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/strategy_generator.py index 4e39fcd8e..5f6cc69ba 100644 --- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/strategy_generator.py @@ -1,6 +1,7 @@ -from dataclasses import dataclass from abc import ABC, abstractmethod -from typing import List, Dict +from dataclasses import dataclass +from typing import Dict, List + from colossalai.device.device_mesh import DeviceMesh __all__ = ['IntermediateStrategy', 'StrategyGenerator'] @@ -9,7 +10,7 @@ __all__ = ['IntermediateStrategy', 'StrategyGenerator'] @dataclass class IntermediateStrategy: """ - IntermediateStrategy contains the subset of meta information for ShardingStrategy. It is + IntermediateStrategy contains the subset of meta information for ShardingStrategy. It is to store the essential information regarding the tensor sharding and leave other meta information to OperatorHandler. Args: @@ -24,7 +25,7 @@ class IntermediateStrategy: class StrategyGenerator(ABC): """ - StrategyGenerator is used to generate the same group of sharding strategies. + StrategyGenerator is used to generate the same group of sharding strategies. """ def __init__(self, device_mesh: DeviceMesh): @@ -39,7 +40,7 @@ class StrategyGenerator(ABC): @abstractmethod def validate(self, *args, **kwargs) -> bool: """ - Validate if the operands are of desired shape. + Validate if the operands are of desired shape. If True, means this generator can be used for the current operation. """ pass