[gemini] support gradient accumulation (#4869)

* add test

* fix no_sync bug in low level zero plugin

* fix test

* add argument for grad accum

* add grad accum in backward hook for gemini

* finish implementation, rewrite tests

* fix test

* skip stuck model in low level zero test

* update doc

* optimize communication & fix gradient checkpoint

* modify doc

* cleaning codes

* update cpu adam fp16 case
pull/4990/head
Baizhou Zhang 2023-10-17 14:07:21 +08:00 committed by GitHub
parent a41cf88e9b
commit 21ba89cab6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 283 additions and 10 deletions

View File

@ -245,6 +245,7 @@ class GeminiPlugin(DPPluginBase):
chunk_config_dict (dict, optional): chunk configuration dictionary.
chunk_init_device (torch.device, optional): device to initialize the chunk.
placement_policy (str, optional): "static" and "auto". Defaults to "static".
enable_gradient_accumulation (bool, optional): Whether to enable gradient accumulation. When set to True, gradient will be stored after doing backward pass. Defaults to False.
shard_param_frac (float, optional): fraction of parameters to be sharded. Only for "static" placement.
If `shard_param_frac` is 1.0, it's equal to zero-3. If `shard_param_frac` is 0.0, it's equal to zero-2. Defaults to 1.0.
offload_optim_frac (float, optional): fraction of optimizer states to be offloaded. Only for "static" placement.
@ -257,7 +258,7 @@ class GeminiPlugin(DPPluginBase):
warmup_non_model_data_ratio (float, optional): ratio of expected non-model data memory during warmup. Only for "auto" placement. Defaults to 0.8.
steady_cuda_cap_ratio (float, optional): ratio of allowed cuda capacity for model data during steady state. Only for "auto" placement. Defaults to 0.9.
precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'.
master_weights (bool, optional): master weights. Defaults to True.
master_weights (bool, optional): Whether to keep fp32 master parameter weights in optimizer. Defaults to True.
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False.
@ -291,6 +292,7 @@ class GeminiPlugin(DPPluginBase):
chunk_config_dict: Optional[dict] = None,
chunk_init_device: Optional[torch.device] = None,
placement_policy: str = "static",
enable_gradient_accumulation: bool = False,
shard_param_frac: float = 1.0, # only for static placement
offload_optim_frac: float = 0.0, # only for static placement
offload_param_frac: float = 0.0, # only for static placement
@ -323,6 +325,7 @@ class GeminiPlugin(DPPluginBase):
chunk_config_dict=chunk_config_dict,
chunk_init_device=(chunk_init_device or get_current_device()),
placement_policy=placement_policy,
enable_gradient_accumulation=enable_gradient_accumulation,
shard_param_frac=shard_param_frac,
offload_optim_frac=offload_optim_frac,
offload_param_frac=offload_param_frac,

View File

@ -335,4 +335,4 @@ class LowLevelZeroPlugin(DPPluginBase):
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
assert isinstance(optimizer, LowLevelZeroOptimizer)
return optimizer.optim.no_sync()
return optimizer.no_sync()

View File

@ -434,6 +434,21 @@ class Chunk:
if update_ptr:
tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape)
def add_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
"""
Add data slice to the memory space indexed by the input tensor in the chunk.
Only used when accumulating gradient chunks.
Args:
tensor (torch.Tensor): the tensor used to retrieve meta information
data_slice (torch.Tensor): the tensor to be added to the chunk
"""
# sanity check
assert self.is_gathered
tensor_info = self.tensors_info[tensor]
self.cuda_global_chunk[tensor_info.offset : tensor_info.end].add_(data_slice.data.flatten())
def get_valid_length(self) -> int:
"""Get the valid length of the chunk's payload."""
if self.keep_gathered:

View File

@ -5,7 +5,7 @@ import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from colossalai.utils import get_current_device
from colossalai.utils import free_storage, get_current_device
from .chunk import Chunk, ChunkFullError, TensorState
@ -255,3 +255,37 @@ class ChunkManager:
self.accessed_chunks.add(grad_chunk)
self.accessed_mem += grad_chunk.chunk_mem
return grad_chunk
def rearrange_accumulated_grad_chunk(self, chunk: Chunk) -> Chunk:
"""Rearrange gradients accumulated in chunk.grad_chunk, and getP prepared for gradient reduction."""
assert chunk.grad_chunk is not None
# Make a backup for gradient accumulated before.
# Here backup gradients should be multiplied, since it will be divided after gradient reduction.
if chunk.grad_chunk.is_gathered:
accumulated_grad = chunk.grad_chunk.cuda_global_chunk.clone().detach().mul_(chunk.pg_size)
accumulated_grad_gathered = True
else:
if chunk.grad_chunk.cuda_shard is not None:
accumulated_grad = chunk.grad_chunk.cuda_shard.clone().detach().mul_(chunk.pg_size)
else:
accumulated_grad = (
chunk.grad_chunk.cpu_shard.to(get_current_device()).clone().detach().mul_(chunk.pg_size)
)
accumulated_grad_gathered = False
# Reset grad_chunk, and chunk.grad_chunk will be accessed.
grad_chunk = self.init_grad_chunk(chunk)
grad_chunk.cuda_global_chunk.zero_()
# Add backup gradients to grad_chunk.
if accumulated_grad_gathered:
grad_chunk.cuda_global_chunk.add_(accumulated_grad)
else:
grad_chunk.cuda_global_chunk[grad_chunk.shard_begin : grad_chunk.shard_end].add_(accumulated_grad)
# Release accumulated_grad
free_storage(accumulated_grad)
return grad_chunk

View File

@ -59,6 +59,7 @@ class GeminiDDP(ModelWrapper):
chunk_config_dict: Optional[dict] = None,
chunk_init_device: torch.device = torch.device("cpu"),
placement_policy: str = "static",
enable_gradient_accumulation: bool = False,
shard_param_frac: float = 1.0, # only for static placement
offload_optim_frac: float = 0.0, # only for static placement
offload_param_frac: float = 0.0, # only for static placement
@ -119,6 +120,11 @@ class GeminiDDP(ModelWrapper):
self.reuse_fp16_chunk = master_weights
self.master_weights = master_weights
self.enable_gradient_accumulation = enable_gradient_accumulation
if self.enable_gradient_accumulation:
self.reuse_fp16_chunk = False
self.accumulating_grads = False # Whether model is accumulating gradients
self._logger = get_dist_logger()
if self.gemini_manager._premade_memstats_:
@ -298,6 +304,8 @@ class GeminiDDP(ModelWrapper):
f"{error_str}",
)
self._setup_grads_ptr()
if self.enable_gradient_accumulation and not self.accumulating_grads:
self.accumulating_grads = True # Turn on the state of gradient accumulation.
self._logger.debug(
f"comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}"
)
@ -327,7 +335,15 @@ class GeminiDDP(ModelWrapper):
)
grad_chunk = chunk
if not self.reuse_fp16_chunk:
grad_chunk = self.chunk_manager.init_grad_chunk(chunk)
if not self.accumulating_grads:
grad_chunk = self.chunk_manager.init_grad_chunk(chunk)
else:
assert chunk.grad_chunk is not None
if chunk.grad_chunk not in self.chunk_manager.accessed_chunks:
grad_chunk = self.chunk_manager.rearrange_accumulated_grad_chunk(chunk)
else:
grad_chunk = chunk.grad_chunk
# hold -> compute -> hold after bwd
grad_chunk.tensor_trans_state(p, TensorState.COMPUTE)
grad_chunk.tensor_trans_state(p, TensorState.HOLD_AFTER_BWD)
@ -336,7 +352,10 @@ class GeminiDDP(ModelWrapper):
chunk.tensor_trans_state(p, TensorState.HOLD)
grad_chunk.tensor_trans_state(p, TensorState.READY_FOR_REDUCE)
grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=self.reuse_fp16_chunk)
if not self.accumulating_grads:
grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=self.reuse_fp16_chunk)
else:
grad_chunk.add_tensor_to_chunk_slice(p, grad)
reduced = self.chunk_manager.reduce_chunk(grad_chunk)
if reduced:
if not self.reuse_fp16_chunk:
@ -354,7 +373,7 @@ class GeminiDDP(ModelWrapper):
if chunk.l2_norm_flag:
grad_chunk.set_l2_norm()
self.chunk_manager.move_chunk(grad_chunk, self.grads_device[p], force_copy=True)
if not self.master_weights:
if not (self.master_weights) or (self.enable_gradient_accumulation):
self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
return empty_grad

View File

@ -263,6 +263,7 @@ class GeminiOptimizer(OptimizerWrapper):
self.zero_grad()
if self.module.master_weights:
self._update_fp16_params()
self.module.accumulating_grads = False
return ret
def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):

View File

@ -1,6 +1,6 @@
# Gradient Accumulation
Author: [Mingyan Jiang](https://github.com/jiangmingyan)
Author: [Mingyan Jiang](https://github.com/jiangmingyan), [Baizhou Zhang](https://github.com/Fridge003)
**Prerequisite**
- [Training Booster](../basics/booster_api.md)
@ -126,6 +126,7 @@ for idx, (img, label) in enumerate(train_dataloader):
```
### Step 6. Invoke Training Scripts
To verify gradient accumulation, we can just check the change of parameter values. When gradient accumulation is set, parameters are only updated in the last step. You can run the script using this command:
```shell
@ -142,4 +143,30 @@ iteration 2, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0
iteration 3, first 10 elements of param: tensor([-0.0141, 0.0464, 0.0507, 0.0321, 0.0356, -0.0150, 0.0172, -0.0118, 0.0222, 0.0473], device='cuda:0', grad_fn=<SliceBackward0>)
```
## Gradient Accumulation on GeminiPlugin
Currently the plugins supporting `no_sync()` method include `TorchDDPPlugin` and `LowLevelZeroPlugin` set to stage 1. `GeminiPlugin` doesn't support `no_sync()` method, but it can enable synchronized gradient accumulation in a torch-like way.
To enable gradient accumulation feature, the argument `enable_gradient_accumulation` should be set to `True` when initializing `GeminiPlugin`. Following is the pseudocode snippet of enabling gradient accumulation for `GeminiPlugin`:
<!--- doc-test-ignore-start -->
```python
...
plugin = GeminiPlugin(..., enable_gradient_accumulation=True)
booster = Booster(plugin=plugin)
...
...
for idx, (input, label) in enumerate(train_dataloader):
output = gemini_model(input.cuda())
train_loss = criterion(output, label.cuda())
train_loss = train_loss / GRADIENT_ACCUMULATION
booster.backward(train_loss, gemini_optimizer)
if idx % (GRADIENT_ACCUMULATION - 1) == 0:
gemini_optimizer.step() # zero_grad is automatically done
...
```
<!--- doc-test-ignore-end -->
<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 gradient_accumulation_with_booster.py -->

View File

@ -1,6 +1,6 @@
# 梯度累积
作者: [Mingyan Jiang](https://github.com/jiangmingyan)
作者: [Mingyan Jiang](https://github.com/jiangmingyan), [Baizhou Zhang](https://github.com/Fridge003)
**前置教程**
- [训练中使用Booster](../basics/booster_api.md)
@ -93,6 +93,7 @@ model, optimizer, criterion, train_dataloader, _ = booster.boost(model=model,
dataloader=train_dataloader)
```
### 步骤 5. 使用booster训练
使用booster构建一个普通的训练循环验证梯度累积。 `param_by_iter` 记录分布训练的信息。
```python
@ -144,4 +145,29 @@ iteration 2, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0
iteration 3, first 10 elements of param: tensor([-0.0141, 0.0464, 0.0507, 0.0321, 0.0356, -0.0150, 0.0172, -0.0118, 0.0222, 0.0473], device='cuda:0', grad_fn=<SliceBackward0>)
```
## 在Gemini插件中使用梯度累积
目前支持`no_sync()`方法的插件包括 `TorchDDPPlugin``LowLevelZeroPlugin`(需要设置参数`stage`为1. `GeminiPlugin` 不支持 `no_sync()` 方法, 但是它可以通过和`pytorch`类似的方式来使用同步的梯度累积。
为了开启梯度累积功能,在初始化`GeminiPlugin`的时候需要将参数`enable_gradient_accumulation`设置为`True`。以下是 `GeminiPlugin` 进行梯度累积的伪代码片段:
<!--- doc-test-ignore-start -->
```python
...
plugin = GeminiPlugin(..., enable_gradient_accumulation=True)
booster = Booster(plugin=plugin)
...
...
for idx, (input, label) in enumerate(train_dataloader):
output = gemini_model(input.cuda())
train_loss = criterion(output, label.cuda())
train_loss = train_loss / GRADIENT_ACCUMULATION
booster.backward(train_loss, gemini_optimizer)
if idx % (GRADIENT_ACCUMULATION - 1) == 0:
gemini_optimizer.step() # zero_grad is automatically done
...
```
<!--- doc-test-ignore-end -->
<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 gradient_accumulation_with_booster.py -->

View File

@ -52,7 +52,6 @@ def get_training_components():
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
)
print("building BertForSequenceClassification model")
# adapting huggingface BertForSequenceClassification for single unittest calling interface
class ModelAdaptor(BertForSequenceClassification):

View File

@ -14,6 +14,8 @@ from tests.kit.model_zoo import model_zoo
_AMP_ERR_MODELS = ["timm_convit", "deepfm_interactionarch"]
# These models have no parameters
_LOW_LEVEL_ZERO_ERR_MODELS = ["dlrm_interactionarch"]
# These models will cause stuck, to be fixed
_STUCK_MODELS = ["transformers_albert_for_multiple_choice"]
def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
@ -53,7 +55,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
"""
passed_models = []
failed_info = {} # (model_name, error) pair
ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS
ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + _STUCK_MODELS
skipped_models = []
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():

View File

@ -0,0 +1,147 @@
import pytest
import torch
import torch.distributed as dist
from apex import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
from colossalai.utils.cuda import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test import run_fwd
from tests.components_to_test.registry import non_distributed_component_funcs
PLACEMENT_CONFIGS = [
{"placement_policy": "static", "shard_param_frac": 0.0}, # zero2
{"placement_policy": "static", "shard_param_frac": 1.0}, # zero3
{"placement_policy": "static", "shard_param_frac": 0.5}, # zero3-half
{"placement_policy": "auto"},
]
def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
chunk_manager = model.chunk_manager
grad_chunk_list = []
device_list = []
# Access gradient chunks.
for p in model.parameters():
grad_chunk = chunk_manager.get_chunk(p).grad_chunk
if grad_chunk not in grad_chunk_list:
chunk_manager.access_chunk(grad_chunk)
grad_chunk_list.append(grad_chunk)
device_list.append(model.grads_device[p])
# Compare gradients.
for p0, p1 in zip(model.parameters(), torch_model.parameters()):
assert_close(p0, p1.grad, rtol=1e-3, atol=5e-5)
# Release gradient chunks and move them to gradient device.
for grad_chunk, device in zip(grad_chunk_list, device_list):
chunk_manager.release_chunk(grad_chunk)
chunk_manager.move_chunk(grad_chunk, device, force_copy=True)
@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("keep_gathered", [False, True])
@parameterize("model_name", ["gpt2", "bert"])
@parameterize("use_grad_checkpoint", [False, True])
@parameterize("master_weights", [False, True])
def exam_gemini_grad_acc(
placement_config, keep_gathered: bool, model_name: str, use_grad_checkpoint: bool, master_weights: bool
):
init_device = get_current_device()
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, _, _, criterion = get_components_func()
set_seed(42)
gemini_model = model_builder(use_grad_checkpoint)
set_seed(42)
torch_model = model_builder(use_grad_checkpoint).cuda()
for torch_p, p in zip(torch_model.parameters(), gemini_model.parameters()):
torch_p.data.copy_(p.data)
world_size = torch.distributed.get_world_size()
config_dict, *_ = search_chunk_configuration(gemini_model, search_range_m=1, search_interval=100)
config_dict[world_size]["chunk_size"] = 5000
config_dict[world_size]["keep_gathered"] = keep_gathered
gemini_model = GeminiDDP(
gemini_model,
config_dict,
init_device,
pin_memory=True,
enable_gradient_accumulation=True,
master_weights=master_weights,
**placement_config,
)
optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3)
gemini_optim = GeminiOptimizer(optimizer, gemini_model, initial_scale=1)
rank = dist.get_rank()
# setting master_weights to False will cause overflow after optimizer.step()
amp_config = dict(
opt_level="O2", keep_batchnorm_fp32=False, loss_scale=1, min_loss_scale=1, max_loss_scale=1, master_weights=True
)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = amp.initialize(torch_model, torch_optim, **amp_config)
torch_model = DDP(torch_model, device_ids=[rank])
set_seed(rank)
accum_iter = 4
for i, (input_ids, label) in enumerate(train_dataloader):
delay_unscale = False if (i + 1) % accum_iter == 0 else True
input_ids, label = input_ids.cuda(), label.cuda()
set_seed(42 + rank)
torch_loss = run_fwd(torch_model, input_ids, label, criterion)
torch_loss = torch_loss / accum_iter
with amp.scale_loss(torch_loss, torch_optim, delay_unscale=delay_unscale) as scaled_loss:
scaled_loss.backward()
set_seed(42 + rank)
gemini_loss = run_fwd(gemini_model, input_ids, label, criterion)
gemini_loss = gemini_loss / accum_iter
gemini_optim.backward(gemini_loss)
assert torch.allclose(torch_loss, gemini_loss, rtol=1e-3, atol=1e-5)
check_grad(gemini_model, torch_model)
if (i + 1) % accum_iter == 0:
torch_optim.step()
gemini_optim.step()
torch_optim.zero_grad()
# check updated param
torch_dict = torch_model.state_dict()
gemini_dict = gemini_model.state_dict(only_rank_0=False)
for key, value in gemini_dict.items():
torch_key = "module." + key
torch_value = torch_dict[torch_key].to(value.device).to(value.dtype)
assert_close(value, torch_value, rtol=1e-3, atol=2e-3)
if i == accum_iter:
break
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
exam_gemini_grad_acc()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_grad_accumulation():
spawn(run_dist, 2)
if __name__ == "__main__":
test_grad_accumulation()