[Feature] auto-cast optimizers to distributed version (#5746)

* auto-cast optimizers to distributed

* fix galore casting

* logger

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
parent 2fc85abf43
commit 5f8c0a0ac3

@@ -27,7 +27,7 @@ from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
-from colossalai.nn.optimizer import DistGaloreAwamW
+from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
@@ -1179,6 +1179,10 @@ class HybridParallelPlugin(PipelinePluginBase):
         # TODO: Support Galore + ZeRO
         zero_stage = self.zero_stage
         zero_config = deepcopy(self.zero_config)
+        # Replace with distributed implementation if exists
+        optimizer = cast_to_distributed(optimizer)
         if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
             warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
             zero_config["partition_grad"] = False

@@ -32,7 +32,7 @@ from colossalai.checkpoint_io.utils import (
 )
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
-from colossalai.nn.optimizer import DistGaloreAwamW
+from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.zero import LowLevelZeroOptimizer
@@ -437,6 +437,10 @@ class LowLevelZeroPlugin(DPPluginBase):
         zero_stage = self.stage
         zero_optim_kwargs = {**self.zero_optim_kwargs}
         dp_size = dist.get_world_size()
+        # Replace with the distributed implementation if exists
+        optimizer = cast_to_distributed(optimizer)
         if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and dp_size > 0:
             warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
             zero_optim_kwargs["partition_grad"] = False

@@ -1,5 +1,7 @@
 from galore_torch import GaLoreAdafactor, GaLoreAdamW
+from colossalai.logging import get_dist_logger
 from .came import CAME
 from .cpu_adam import CPUAdam
 from .distributed_adafactor import DistributedAdaFactor
@@ -34,3 +36,22 @@ __all__ = [
     "Adafactor",
     "DistributedAdaFactor",
 ]
+optim2DistOptim = {
+    GaLoreAdamW8bit: DistGaloreAwamW,
+    Lamb: DistributedLamb,
+    CAME: DistributedCAME,
+    Adafactor: DistributedAdaFactor,
+}
+_logger = get_dist_logger()
+
+
+def cast_to_distributed(optim):
+    if optim.__class__ in optim2DistOptim:
+        _logger.info(f"Converting optimizer {optim.__class__.__name__} to its distributed version.", ranks=[0])
+        if isinstance(optim, GaLoreAdamW8bit):
+            return optim2DistOptim[GaLoreAdamW8bit](optim.param_groups, args=optim.args)
+        return optim2DistOptim[optim.__class__](optim.param_groups)
+    return optim
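For illustration, a minimal sketch, not part of the diff, of calling the new helper directly; the plugins above do this for you, so it is mainly useful for seeing the mapping. The `nn.Linear` toy model and the hyperparameters are made up.

```python
# Sketch only: supported optimizer classes are rebuilt from their param_groups
# as the distributed counterpart; anything else passes through unchanged.
import torch
from colossalai.nn.optimizer import DistributedLamb, Lamb, cast_to_distributed

model = torch.nn.Linear(16, 16)

lamb = Lamb(model.parameters(), lr=1e-3)
lamb = cast_to_distributed(lamb)          # logs the conversion on rank 0
assert isinstance(lamb, DistributedLamb)  # process groups are set up later by the plugin

sgd = torch.optim.SGD(model.parameters(), lr=0.1)
assert cast_to_distributed(sgd) is sgd    # not in optim2DistOptim, returned as-is
```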

@@ -34,9 +34,6 @@ class DistributedCAME(DistributedOptim):
         betas=(0.9, 0.999, 0.9999),
         weight_decay=0.0,
     ):
-        assert lr > 0.0
-        assert all([0.0 <= beta <= 1.0 for beta in betas])
         defaults = dict(
             lr=lr,
             eps=eps,

@@ -43,12 +43,13 @@ class DistGaloreAwamW(DistributedOptim, Optimizer2State):
             Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
         is_paged (`bool`, defaults to `False`):
             Whether the optimizer is a paged optimizer (handle memory spike via CPU-GPU transfer) or not.
+        args (dict, optional): quantization-related arguments. If passed, will override all quantization args above.
     """

     def __init__(
         self,
         params,
-        lr=1e-3,
+        lr=1e-2,
         betas=(0.9, 0.999),
         eps=1e-8,
         weight_decay=1e-2,
@@ -57,6 +58,7 @@ class DistGaloreAwamW(DistributedOptim, Optimizer2State):
         percentile_clipping=100,
         block_wise=True,
         is_paged=False,
+        args=None,
     ):
         super().__init__(
             "adam",
@@ -65,13 +67,14 @@ class DistGaloreAwamW(DistributedOptim, Optimizer2State):
             betas,
             eps,
             weight_decay,
-            nbits,
-            None,
-            min_8bit_size,
-            percentile_clipping,
-            block_wise,
+            optim_bits=nbits,
+            args=args,
+            min_8bit_size=min_8bit_size,
+            percentile_clipping=percentile_clipping,
+            block_wise=block_wise,
             is_paged=is_paged,
         )
         self.tp_size = 1
         self.dp_size = 1
         self.is_dist = {}

@@ -184,6 +184,7 @@ class GaLoreAdamW8bit(Optimizer2State):
             Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
         is_paged (`bool`, defaults to `False`):
             Whether the optimizer is a paged optimizer (handle memory spike via CPU-GPU transfer) or not.
+        args (dict, optional): quantization-related arguments. If passed, will override all quantization args above.

     Example:
     """
@@ -200,6 +201,7 @@ class GaLoreAdamW8bit(Optimizer2State):
         percentile_clipping=100,
         block_wise=True,
         is_paged=False,
+        args=None,
     ):
         super().__init__(
             "adam",
@@ -208,11 +210,11 @@ class GaLoreAdamW8bit(Optimizer2State):
             betas,
             eps,
             weight_decay,
-            nbits,
-            None,
-            min_8bit_size,
-            percentile_clipping,
-            block_wise,
+            optim_bits=nbits,
+            args=args,
+            min_8bit_size=min_8bit_size,
+            percentile_clipping=percentile_clipping,
+            block_wise=block_wise,
             is_paged=is_paged,
         )
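A short sketch, not part of the diff, of how the new `args` hook travels through the cast for GaLore: quantization settings passed to `GaLoreAdamW8bit` are kept on `optim.args` by the bitsandbytes optimizer base class, and `cast_to_distributed` forwards that object to `DistGaloreAwamW`. The toy model and the hyperparameter values are illustrative assumptions.

```python
# Sketch only: GaLore quantization settings survive the auto-cast because
# cast_to_distributed passes optim.args along to DistGaloreAwamW.
import torch
from colossalai.nn.optimizer import DistGaloreAwamW, GaLoreAdamW8bit, cast_to_distributed
from colossalai.nn.optimizer.galore import get_galore_param_groups

model = torch.nn.Linear(64, 64)
optim = GaLoreAdamW8bit(
    get_galore_param_groups(model, weight_decay=0, rank=8),
    lr=1e-2,
    nbits=8,
    percentile_clipping=95,  # recorded on optim.args by the bitsandbytes base class
)

optim = cast_to_distributed(optim)  # i.e. DistGaloreAwamW(optim.param_groups, args=optim.args)
assert isinstance(optim, DistGaloreAwamW)
```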

@@ -9,7 +9,8 @@ Author: [Wenxuan Tan](https://github.com/Edenzzzz), [Junwen Duan](https://github
 - [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/pdf/1904.00962)
 ## Introduction
-Apart from the widely adopted Adam and SGD, many modern optimizers require layer-wise statistics to efficiently update parameters, and are thus not directly applicable to parallel settings where model layers are sharded across multiple devices. We provide optimized distributed implementations with minimal extra communications, and seamless integrations with Tensor Parallel, DDP and ZeRO using plugins.
+Apart from the widely adopted Adam and SGD, many modern optimizers require layer-wise statistics to update parameters, and thus aren't directly applicable to settings where model layers are sharded across multiple devices. We provide optimized distributed implementations with minimal extra communication, and seamless integration with the Tensor Parallel, DDP and ZeRO plugins, which automatically use the distributed optimizers with zero code change.
 ## Optimizers
 Adafactor is a first-order Adam variant that uses Non-negative Matrix Factorization (NMF) to reduce memory footprint. CAME improves on it by introducing a confidence matrix to correct NMF. GaLore further reduces memory by projecting gradients into a low-rank space and applying 8-bit block-wise quantization. Lamb allows huge batch sizes without losing accuracy via layer-wise adaptive updates bounded by the inverse of its Lipschitz constant.
@@ -21,7 +22,7 @@ Adafactor is a first-order Adam variant using Non-negative Matrix Factorization(
 {{ autodoc:colossalai.nn.optimizer.distributed_came.DistributedCAME }}
 ## Hands-On Practice
-We now demonstrate how to use Distributed Adafactor with booster API combining Tensor Parallel and ZeRO 2 with 4 GPUs.
+We now demonstrate how to use Distributed Adafactor with the booster API, combining Tensor Parallel and ZeRO 2 on 4 GPUs. **Note that even if you're not aware of the distributed optimizers, the plugins automatically cast yours to the distributed version for convenience.**
 ### step 1. Import libraries
 ```python

@@ -9,7 +9,7 @@ Author: Wenxuan Tan, Junwen Duan, Renjie Mao
 - [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/pdf/1904.00962)
 ## Introduction
-Apart from the widely adopted Adam and SGD, many modern optimizers require layer-wise statistics to update parameters efficiently, and thus cannot be applied directly to parallel settings where model layers are sharded across multiple devices. We provide optimized distributed implementations that integrate seamlessly with Tensor Parallel, DDP and ZeRO via plugins.
+Apart from the widely adopted Adam and SGD, many modern optimizers require layer-wise statistics to update parameters efficiently, and thus cannot be applied directly to parallel settings where model layers are sharded across multiple devices. We provide optimized distributed implementations that integrate seamlessly with the Tensor Parallel, DDP and ZeRO plugins.
 ## Optimizers
 Adafactor is a first-order Adam variant that uses Non-negative Matrix Factorization (NMF) to reduce memory footprint. CAME improves on it by introducing a confidence matrix to correct NMF. GaLore further reduces memory by projecting gradients into a low-rank space and applying 8-bit block-wise quantization. Lamb enables huge batch sizes without losing accuracy via layer-wise adaptive updates bounded by the inverse of the Lipschitz constant.
@@ -21,7 +21,7 @@ Adafactor is a first-order Adam variant using Non-negative Matrix Factorization (
 {{ autodoc:colossalai.nn.optimizer.distributed_came.DistributedCAME }}
 ## Usage
-We now demonstrate how to use Distributed Adafactor with booster API combining Tensor Parallel and ZeRO 2 with 4 GPUs.
+We now demonstrate how to use Distributed Adafactor with the booster API, combining Tensor Parallel and ZeRO 2. Even if you are not using a distributed optimizer, the plugin automatically converts your optimizer to the distributed version for convenience.
 ### step 1. Import libraries
 ```python
@@ -34,15 +34,13 @@ import torch
 ```
 ### step 2. Initialize the distributed environment
-We need to initialize distributed environment. For demo purpose, we use `colossal run --nproc_per_node 4`. You can refer to [Launch Colossal-AI](../basics/launch_colossalai.md)
+We first need to initialize the distributed environment. For demonstration purposes, we use `colossal run --nproc_per_node 4`. For more initialization options, please refer to [Launch Colossal-AI](../basics/launch_colossalai.md)
 ```python
 colossalai.launch_from_torch()
 ```
 ### step 3. Initialize the model and optimizer
-Build our model. We created an MLP using two Linear Layer.
 ```python
 configuration = LlamaConfig()
 model = LlamaModel(configuration).cuda()

@@ -552,7 +552,7 @@ def exam_bert_test_on_lowlevelzero_plugin(test_config):
         sharded_optimizer,
         criterion,
         booster,
-    ) = build_model_from_low_level_zero_plugin(model_fn, loss_fn, test_config, Adafactor, DistributedAdaFactor)
+    ) = build_model_from_low_level_zero_plugin(model_fn, loss_fn, test_config, Adafactor, Adafactor)
     org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_low_level_zero_plugin(
         org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster

@@ -416,7 +416,7 @@ def exam_bert_test_on_hybrid_plugin(test_config):
         sharded_optimizer,
         criterion,
         booster,
-    ) = build_model_from_hybrid_plugin(model_fn, loss_fn, test_config, CAME, DistributedCAME)
+    ) = build_model_from_hybrid_plugin(model_fn, loss_fn, test_config, CAME, CAME)
     org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
         org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster

@@ -306,8 +306,8 @@ def check_dist_galore(rank, world_size, port):
     global coordinator
     coordinator = DistCoordinator()
-    run_dist_galore_basic()
-    coordinator.print_on_master("Basic backward tests passed")
+    # run_dist_galore_basic()
+    # coordinator.print_on_master("Basic backward tests passed")
     coordinator.print_on_master("Skipping forward-backward tests due to SVD instability")
     # run_dist_galore_fwd_bwd()
@@ -319,7 +319,7 @@ def check_dist_galore(rank, world_size, port):
     )
     for config in test_config:
         try:
-            run_bert_test(test_config=config, optim_class=GaLoreAdamW8bit, sharded_optim_class=DistGaloreAwamW)
+            run_bert_test(test_config=config, optim_class=GaLoreAdamW8bit, sharded_optim_class=GaLoreAdamW8bit)
         except Exception as e:
             print(e)
     dist.barrier()

@@ -289,7 +289,7 @@ def check_dist_lamb(rank, world_size, port):
     run_dist_lamb_fwd_bwd()
     coordinator.print_on_master("Forward-backward tests passed")
-    run_bert_test(optim_class=Lamb, sharded_optim_class=DistributedLamb)
+    run_bert_test(optim_class=Lamb, sharded_optim_class=Lamb)
     print(f"rank {rank} tests passed :)")

@@ -17,7 +17,7 @@ from colossalai.booster.plugin import HybridParallelPlugin, LowLevelZeroPlugin
 from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
 from colossalai.checkpoint_io.utils import gather_distributed_param
 from colossalai.lazy import LazyInitContext
-from colossalai.nn.optimizer import DistGaloreAwamW
+from colossalai.nn.optimizer import GaLoreAdamW8bit
 from colossalai.nn.optimizer.galore import get_galore_param_groups
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
@@ -130,7 +130,7 @@ def build_model_from_hybrid_plugin(
     if use_lazy_init:
         ctx.materialize(org_model)
     org_model = org_model.cuda()
-    if sharded_optim_class == DistGaloreAwamW:
+    if optim_class == GaLoreAdamW8bit:
         # Disable clipping and block-wise quantization
         org_optimizer = optim_class(
             get_galore_param_groups(org_model, weight_decay=0, rank=4),
