mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
50 lines
1.5 KiB
50 lines
1.5 KiB
2 years ago
|
from dataclasses import dataclass
|
||
|
from enum import Enum
|
||
|
|
||
|
# Public API of this module: the names re-exported by ``from <module> import *``.
__all__ = [
    'SolverOptions',
    'SolverPerference',
    'DataloaderOption',
    'ShardOption',
]
|
||
|
|
||
|
|
||
|
class SolverPerference(Enum):
    """
    Enumerates the preferences that can be given to the solver.

    NOTE: the class name keeps its existing spelling ("Perference") because
    it is the public name listed in ``__all__``; renaming it would break
    importers.
    """
    STANDARD = 0
    DP = 1    # presumably "data parallel" -- TODO confirm against the solver
    TP = 2    # presumably "tensor parallel" -- TODO confirm against the solver
|
||
|
|
||
|
|
||
|
class ShardOption(Enum):
    """
    Enumerates the shard level required in node strategies.

    Notes:
        STANDARD: We do not add any extra shard requirements.
        SHARD: We require the node to be sharded using at least one device mesh axis.
        SHARD_LAST_AXIS: We require the node to be sharded using the last device mesh axis.
        FULL_SHARD: We require the node to be sharded using all device mesh axes.

    NOTE(review): earlier documentation of this enum also described TP_SHARD
    (tensor parallel strategies on the last device mesh axis) and
    TP_FULL_SHARD (tensor parallel strategies on all device mesh axes).
    Neither is defined as a member here, so they are not part of the usable
    options.
    """
    STANDARD = 0
    SHARD = 1
    SHARD_LAST_AXIS = 2
    FULL_SHARD = 3
|
||
|
|
||
|
|
||
|
class DataloaderOption(Enum):
    """
    Enumerates the available dataloader options.
    """
    REPLICATED = 0
    DISTRIBUTED = 1
|
||
|
|
||
|
|
||
|
@dataclass
class SolverOptions:
    """
    SolverOptions is a dataclass used to configure the preferences for the
    parallel execution plan search.

    Attributes:
        solver_perference: the solver preference; defaults to
            ``SolverPerference.STANDARD``. The field name keeps its existing
            spelling so that keyword callers continue to work.
        dataloader_option: the dataloader option; defaults to
            ``DataloaderOption.REPLICATED``.
        shard_option: the shard level required in node strategies; defaults
            to ``ShardOption.STANDARD`` (no extra shard requirements).
    """
    solver_perference: SolverPerference = SolverPerference.STANDARD
    dataloader_option: DataloaderOption = DataloaderOption.REPLICATED
    shard_option: ShardOption = ShardOption.STANDARD
|