You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/colossalai/auto_parallel/tensor_shard/options.py

54 lines
1.5 KiB

from dataclasses import dataclass
from enum import Enum
__all__ = ["SolverOptions", "SolverPerference", "DataloaderOption", "ShardOption"]
class SolverPerference(Enum):
"""
This enum class is to define the solver preference.
"""
STANDARD = 0
DP = 1
TP = 2
class ShardOption(Enum):
"""
This enum class is to define the shard level required in node strategies.
Notes:
STANDARD: We do not add any extra shard requirements.
SHARD: We require the node to be shard using at least one device mesh axis.
SHARD_ONE_AXIS: We require the node to be shard using the last device mesh axis.
FULL_SHARD: We require the node to be shard using all device mesh axes.
TP_SHARD: We require the node to be shard using tensor parallel strategies on last device mesh axis.
TP_FULL_SHARD: We require the node to be shard using tensor parallel strategies on all device mesh axes.
"""
STANDARD = 0
SHARD = 1
SHARD_LAST_AXIS = 2
FULL_SHARD = 3
class DataloaderOption(Enum):
"""
This enum class is to define the dataloader option.
"""
REPLICATED = 0
DISTRIBUTED = 1
@dataclass
class SolverOptions:
"""
SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search.
"""
solver_perference: SolverPerference = SolverPerference.STANDARD
dataloader_option: DataloaderOption = DataloaderOption.REPLICATED
shard_option: ShardOption = ShardOption.STANDARD