from galore_torch import GaLoreAdafactor, GaLoreAdamW

from colossalai.logging import get_dist_logger

from .came import CAME
from .cpu_adam import CPUAdam
from .distributed_adafactor import DistributedAdaFactor
from .distributed_came import DistributedCAME
from .distributed_galore import DistGaloreAwamW
from .distributed_lamb import DistributedLamb
from .fused_adam import FusedAdam
from .fused_lamb import FusedLAMB
from .fused_sgd import FusedSGD
from .galore import GaLoreAdamW8bit
from .hybrid_adam import HybridAdam
from .lamb import Lamb
from .lars import Lars

from .adafactor import Adafactor  # noqa
|
__all__ = [
    "FusedLAMB",
    "FusedAdam",
    "FusedSGD",
    "Lamb",
    "Lars",
    "CPUAdam",
    "HybridAdam",
    "DistributedLamb",
    "DistGaloreAwamW",
    "GaLoreAdamW",
    "GaLoreAdafactor",
    "GaLoreAdamW8bit",
    "CAME",
    "DistributedCAME",
    "Adafactor",
    "DistributedAdaFactor",
]
|
# Map each single-device optimizer class to its distributed counterpart.
optim2DistOptim = {
    GaLoreAdamW8bit: DistGaloreAwamW,
    Lamb: DistributedLamb,
    CAME: DistributedCAME,
    Adafactor: DistributedAdaFactor,
}
|
|
def cast_to_distributed(optim):
    """Return the distributed counterpart of ``optim`` registered in ``optim2DistOptim``,
    or ``optim`` itself if no distributed version is available."""
    if optim.__class__ in optim2DistOptim:
        _logger = get_dist_logger()
        _logger.info(f"Converting optimizer {optim.__class__.__name__} to its distributed version.", ranks=[0])

        # GaLoreAdamW8bit also needs its original construction args forwarded.
        if isinstance(optim, GaLoreAdamW8bit):
            return optim2DistOptim[GaLoreAdamW8bit](optim.param_groups, args=optim.args)
        return optim2DistOptim[optim.__class__](optim.param_groups)

    return optim
|
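# Usage sketch (illustrative, not part of the original module): `cast_to_distributed`
# swaps a supported single-device optimizer for its distributed counterpart and leaves
# anything else untouched. `model` below is an assumed, pre-existing nn.Module.
#
#   from colossalai.nn.optimizer import HybridAdam, Lamb, cast_to_distributed
#
#   optim = Lamb(model.parameters(), lr=1e-3)
#   optim = cast_to_distributed(optim)  # now a DistributedLamb built from optim.param_groups
#
#   optim = HybridAdam(model.parameters(), lr=1e-3)
#   optim = cast_to_distributed(optim)  # no entry in optim2DistOptim, returned as-is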
|
|