mirror of https://github.com/hpcaitech/ColossalAI
[gemini] support gradient accumulation (#4869)
* add test * fix no_sync bug in low level zero plugin * fix test * add argument for grad accum * add grad accum in backward hook for gemini * finish implementation, rewrite tests * fix test * skip stuck model in low level zero test * update doc * optimize communication & fix gradient checkpoint * modify doc * cleaning codes * update cpu adam fp16 casepull/4990/head
parent
a41cf88e9b
commit
21ba89cab6
|
@ -245,6 +245,7 @@ class GeminiPlugin(DPPluginBase):
|
|||
chunk_config_dict (dict, optional): chunk configuration dictionary.
|
||||
chunk_init_device (torch.device, optional): device to initialize the chunk.
|
||||
placement_policy (str, optional): "static" and "auto". Defaults to "static".
|
||||
enable_gradient_accumulation (bool, optional): Whether to enable gradient accumulation. When set to True, gradient will be stored after doing backward pass. Defaults to False.
|
||||
shard_param_frac (float, optional): fraction of parameters to be sharded. Only for "static" placement.
|
||||
If `shard_param_frac` is 1.0, it's equal to zero-3. If `shard_param_frac` is 0.0, it's equal to zero-2. Defaults to 1.0.
|
||||
offload_optim_frac (float, optional): fraction of optimizer states to be offloaded. Only for "static" placement.
|
||||
|
@ -257,7 +258,7 @@ class GeminiPlugin(DPPluginBase):
|
|||
warmup_non_model_data_ratio (float, optional): ratio of expected non-model data memory during warmup. Only for "auto" placement. Defaults to 0.8.
|
||||
steady_cuda_cap_ratio (float, optional): ratio of allowed cuda capacity for model data during steady state. Only for "auto" placement. Defaults to 0.9.
|
||||
precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'.
|
||||
master_weights (bool, optional): master weights. Defaults to True.
|
||||
master_weights (bool, optional): Whether to keep fp32 master parameter weights in optimizer. Defaults to True.
|
||||
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
|
||||
force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
|
||||
strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False.
|
||||
|
@ -291,6 +292,7 @@ class GeminiPlugin(DPPluginBase):
|
|||
chunk_config_dict: Optional[dict] = None,
|
||||
chunk_init_device: Optional[torch.device] = None,
|
||||
placement_policy: str = "static",
|
||||
enable_gradient_accumulation: bool = False,
|
||||
shard_param_frac: float = 1.0, # only for static placement
|
||||
offload_optim_frac: float = 0.0, # only for static placement
|
||||
offload_param_frac: float = 0.0, # only for static placement
|
||||
|
@ -323,6 +325,7 @@ class GeminiPlugin(DPPluginBase):
|
|||
chunk_config_dict=chunk_config_dict,
|
||||
chunk_init_device=(chunk_init_device or get_current_device()),
|
||||
placement_policy=placement_policy,
|
||||
enable_gradient_accumulation=enable_gradient_accumulation,
|
||||
shard_param_frac=shard_param_frac,
|
||||
offload_optim_frac=offload_optim_frac,
|
||||
offload_param_frac=offload_param_frac,
|
||||
|
|
|
@ -335,4 +335,4 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|||
|
||||
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
|
||||
assert isinstance(optimizer, LowLevelZeroOptimizer)
|
||||
return optimizer.optim.no_sync()
|
||||
return optimizer.no_sync()
|
||||
|
|
|
@ -434,6 +434,21 @@ class Chunk:
|
|||
if update_ptr:
|
||||
tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape)
|
||||
|
||||
def add_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
|
||||
"""
|
||||
Add data slice to the memory space indexed by the input tensor in the chunk.
|
||||
Only used when accumulating gradient chunks.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): the tensor used to retrieve meta information
|
||||
data_slice (torch.Tensor): the tensor to be added to the chunk
|
||||
"""
|
||||
# sanity check
|
||||
assert self.is_gathered
|
||||
|
||||
tensor_info = self.tensors_info[tensor]
|
||||
self.cuda_global_chunk[tensor_info.offset : tensor_info.end].add_(data_slice.data.flatten())
|
||||
|
||||
def get_valid_length(self) -> int:
|
||||
"""Get the valid length of the chunk's payload."""
|
||||
if self.keep_gathered:
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch
|
|||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils import free_storage, get_current_device
|
||||
|
||||
from .chunk import Chunk, ChunkFullError, TensorState
|
||||
|
||||
|
@ -255,3 +255,37 @@ class ChunkManager:
|
|||
self.accessed_chunks.add(grad_chunk)
|
||||
self.accessed_mem += grad_chunk.chunk_mem
|
||||
return grad_chunk
|
||||
|
||||
def rearrange_accumulated_grad_chunk(self, chunk: Chunk) -> Chunk:
|
||||
"""Rearrange gradients accumulated in chunk.grad_chunk, and getP prepared for gradient reduction."""
|
||||
|
||||
assert chunk.grad_chunk is not None
|
||||
|
||||
# Make a backup for gradient accumulated before.
|
||||
# Here backup gradients should be multiplied, since it will be divided after gradient reduction.
|
||||
if chunk.grad_chunk.is_gathered:
|
||||
accumulated_grad = chunk.grad_chunk.cuda_global_chunk.clone().detach().mul_(chunk.pg_size)
|
||||
accumulated_grad_gathered = True
|
||||
else:
|
||||
if chunk.grad_chunk.cuda_shard is not None:
|
||||
accumulated_grad = chunk.grad_chunk.cuda_shard.clone().detach().mul_(chunk.pg_size)
|
||||
else:
|
||||
accumulated_grad = (
|
||||
chunk.grad_chunk.cpu_shard.to(get_current_device()).clone().detach().mul_(chunk.pg_size)
|
||||
)
|
||||
accumulated_grad_gathered = False
|
||||
|
||||
# Reset grad_chunk, and chunk.grad_chunk will be accessed.
|
||||
grad_chunk = self.init_grad_chunk(chunk)
|
||||
grad_chunk.cuda_global_chunk.zero_()
|
||||
|
||||
# Add backup gradients to grad_chunk.
|
||||
if accumulated_grad_gathered:
|
||||
grad_chunk.cuda_global_chunk.add_(accumulated_grad)
|
||||
else:
|
||||
grad_chunk.cuda_global_chunk[grad_chunk.shard_begin : grad_chunk.shard_end].add_(accumulated_grad)
|
||||
|
||||
# Release accumulated_grad
|
||||
free_storage(accumulated_grad)
|
||||
|
||||
return grad_chunk
|
||||
|
|
|
@ -59,6 +59,7 @@ class GeminiDDP(ModelWrapper):
|
|||
chunk_config_dict: Optional[dict] = None,
|
||||
chunk_init_device: torch.device = torch.device("cpu"),
|
||||
placement_policy: str = "static",
|
||||
enable_gradient_accumulation: bool = False,
|
||||
shard_param_frac: float = 1.0, # only for static placement
|
||||
offload_optim_frac: float = 0.0, # only for static placement
|
||||
offload_param_frac: float = 0.0, # only for static placement
|
||||
|
@ -119,6 +120,11 @@ class GeminiDDP(ModelWrapper):
|
|||
self.reuse_fp16_chunk = master_weights
|
||||
self.master_weights = master_weights
|
||||
|
||||
self.enable_gradient_accumulation = enable_gradient_accumulation
|
||||
if self.enable_gradient_accumulation:
|
||||
self.reuse_fp16_chunk = False
|
||||
self.accumulating_grads = False # Whether model is accumulating gradients
|
||||
|
||||
self._logger = get_dist_logger()
|
||||
|
||||
if self.gemini_manager._premade_memstats_:
|
||||
|
@ -298,6 +304,8 @@ class GeminiDDP(ModelWrapper):
|
|||
f"{error_str}",
|
||||
)
|
||||
self._setup_grads_ptr()
|
||||
if self.enable_gradient_accumulation and not self.accumulating_grads:
|
||||
self.accumulating_grads = True # Turn on the state of gradient accumulation.
|
||||
self._logger.debug(
|
||||
f"comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}"
|
||||
)
|
||||
|
@ -327,7 +335,15 @@ class GeminiDDP(ModelWrapper):
|
|||
)
|
||||
grad_chunk = chunk
|
||||
if not self.reuse_fp16_chunk:
|
||||
if not self.accumulating_grads:
|
||||
grad_chunk = self.chunk_manager.init_grad_chunk(chunk)
|
||||
else:
|
||||
assert chunk.grad_chunk is not None
|
||||
if chunk.grad_chunk not in self.chunk_manager.accessed_chunks:
|
||||
grad_chunk = self.chunk_manager.rearrange_accumulated_grad_chunk(chunk)
|
||||
else:
|
||||
grad_chunk = chunk.grad_chunk
|
||||
|
||||
# hold -> compute -> hold after bwd
|
||||
grad_chunk.tensor_trans_state(p, TensorState.COMPUTE)
|
||||
grad_chunk.tensor_trans_state(p, TensorState.HOLD_AFTER_BWD)
|
||||
|
@ -336,7 +352,10 @@ class GeminiDDP(ModelWrapper):
|
|||
chunk.tensor_trans_state(p, TensorState.HOLD)
|
||||
|
||||
grad_chunk.tensor_trans_state(p, TensorState.READY_FOR_REDUCE)
|
||||
if not self.accumulating_grads:
|
||||
grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=self.reuse_fp16_chunk)
|
||||
else:
|
||||
grad_chunk.add_tensor_to_chunk_slice(p, grad)
|
||||
reduced = self.chunk_manager.reduce_chunk(grad_chunk)
|
||||
if reduced:
|
||||
if not self.reuse_fp16_chunk:
|
||||
|
@ -354,7 +373,7 @@ class GeminiDDP(ModelWrapper):
|
|||
if chunk.l2_norm_flag:
|
||||
grad_chunk.set_l2_norm()
|
||||
self.chunk_manager.move_chunk(grad_chunk, self.grads_device[p], force_copy=True)
|
||||
if not self.master_weights:
|
||||
if not (self.master_weights) or (self.enable_gradient_accumulation):
|
||||
self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
|
||||
return empty_grad
|
||||
|
||||
|
|
|
@ -263,6 +263,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||
self.zero_grad()
|
||||
if self.module.master_weights:
|
||||
self._update_fp16_params()
|
||||
self.module.accumulating_grads = False
|
||||
return ret
|
||||
|
||||
def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Gradient Accumulation
|
||||
|
||||
Author: [Mingyan Jiang](https://github.com/jiangmingyan)
|
||||
Author: [Mingyan Jiang](https://github.com/jiangmingyan), [Baizhou Zhang](https://github.com/Fridge003)
|
||||
|
||||
**Prerequisite**
|
||||
- [Training Booster](../basics/booster_api.md)
|
||||
|
@ -126,6 +126,7 @@ for idx, (img, label) in enumerate(train_dataloader):
|
|||
|
||||
```
|
||||
|
||||
|
||||
### Step 6. Invoke Training Scripts
|
||||
To verify gradient accumulation, we can just check the change of parameter values. When gradient accumulation is set, parameters are only updated in the last step. You can run the script using this command:
|
||||
```shell
|
||||
|
@ -142,4 +143,30 @@ iteration 2, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0
|
|||
iteration 3, first 10 elements of param: tensor([-0.0141, 0.0464, 0.0507, 0.0321, 0.0356, -0.0150, 0.0172, -0.0118, 0.0222, 0.0473], device='cuda:0', grad_fn=<SliceBackward0>)
|
||||
```
|
||||
|
||||
|
||||
## Gradient Accumulation on GeminiPlugin
|
||||
|
||||
Currently the plugins supporting `no_sync()` method include `TorchDDPPlugin` and `LowLevelZeroPlugin` set to stage 1. `GeminiPlugin` doesn't support `no_sync()` method, but it can enable synchronized gradient accumulation in a torch-like way.
|
||||
|
||||
To enable gradient accumulation feature, the argument `enable_gradient_accumulation` should be set to `True` when initializing `GeminiPlugin`. Following is the pseudocode snippet of enabling gradient accumulation for `GeminiPlugin`:
|
||||
<!--- doc-test-ignore-start -->
|
||||
```python
|
||||
...
|
||||
plugin = GeminiPlugin(..., enable_gradient_accumulation=True)
|
||||
booster = Booster(plugin=plugin)
|
||||
...
|
||||
|
||||
...
|
||||
for idx, (input, label) in enumerate(train_dataloader):
|
||||
output = gemini_model(input.cuda())
|
||||
train_loss = criterion(output, label.cuda())
|
||||
train_loss = train_loss / GRADIENT_ACCUMULATION
|
||||
booster.backward(train_loss, gemini_optimizer)
|
||||
|
||||
if idx % (GRADIENT_ACCUMULATION - 1) == 0:
|
||||
gemini_optimizer.step() # zero_grad is automatically done
|
||||
...
|
||||
```
|
||||
<!--- doc-test-ignore-end -->
|
||||
|
||||
<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 gradient_accumulation_with_booster.py -->
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# 梯度累积
|
||||
|
||||
作者: [Mingyan Jiang](https://github.com/jiangmingyan)
|
||||
作者: [Mingyan Jiang](https://github.com/jiangmingyan), [Baizhou Zhang](https://github.com/Fridge003)
|
||||
|
||||
**前置教程**
|
||||
- [训练中使用Booster](../basics/booster_api.md)
|
||||
|
@ -93,6 +93,7 @@ model, optimizer, criterion, train_dataloader, _ = booster.boost(model=model,
|
|||
dataloader=train_dataloader)
|
||||
```
|
||||
|
||||
|
||||
### 步骤 5. 使用booster训练
|
||||
使用booster构建一个普通的训练循环,验证梯度累积。 `param_by_iter` 记录分布训练的信息。
|
||||
```python
|
||||
|
@ -144,4 +145,29 @@ iteration 2, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0
|
|||
iteration 3, first 10 elements of param: tensor([-0.0141, 0.0464, 0.0507, 0.0321, 0.0356, -0.0150, 0.0172, -0.0118, 0.0222, 0.0473], device='cuda:0', grad_fn=<SliceBackward0>)
|
||||
```
|
||||
|
||||
## 在Gemini插件中使用梯度累积
|
||||
|
||||
目前支持`no_sync()`方法的插件包括 `TorchDDPPlugin` 和 `LowLevelZeroPlugin`(需要设置参数`stage`为1). `GeminiPlugin` 不支持 `no_sync()` 方法, 但是它可以通过和`pytorch`类似的方式来使用同步的梯度累积。
|
||||
|
||||
为了开启梯度累积功能,在初始化`GeminiPlugin`的时候需要将参数`enable_gradient_accumulation`设置为`True`。以下是 `GeminiPlugin` 进行梯度累积的伪代码片段:
|
||||
<!--- doc-test-ignore-start -->
|
||||
```python
|
||||
...
|
||||
plugin = GeminiPlugin(..., enable_gradient_accumulation=True)
|
||||
booster = Booster(plugin=plugin)
|
||||
...
|
||||
|
||||
...
|
||||
for idx, (input, label) in enumerate(train_dataloader):
|
||||
output = gemini_model(input.cuda())
|
||||
train_loss = criterion(output, label.cuda())
|
||||
train_loss = train_loss / GRADIENT_ACCUMULATION
|
||||
booster.backward(train_loss, gemini_optimizer)
|
||||
|
||||
if idx % (GRADIENT_ACCUMULATION - 1) == 0:
|
||||
gemini_optimizer.step() # zero_grad is automatically done
|
||||
...
|
||||
```
|
||||
<!--- doc-test-ignore-end -->
|
||||
|
||||
<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 gradient_accumulation_with_booster.py -->
|
||||
|
|
|
@ -52,7 +52,6 @@ def get_training_components():
|
|||
hidden_dropout_prob=0.0,
|
||||
attention_probs_dropout_prob=0.0,
|
||||
)
|
||||
print("building BertForSequenceClassification model")
|
||||
|
||||
# adapting huggingface BertForSequenceClassification for single unittest calling interface
|
||||
class ModelAdaptor(BertForSequenceClassification):
|
||||
|
|
|
@ -14,6 +14,8 @@ from tests.kit.model_zoo import model_zoo
|
|||
_AMP_ERR_MODELS = ["timm_convit", "deepfm_interactionarch"]
|
||||
# These models have no parameters
|
||||
_LOW_LEVEL_ZERO_ERR_MODELS = ["dlrm_interactionarch"]
|
||||
# These models will cause stuck, to be fixed
|
||||
_STUCK_MODELS = ["transformers_albert_for_multiple_choice"]
|
||||
|
||||
|
||||
def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
|
||||
|
@ -53,7 +55,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
|
|||
"""
|
||||
passed_models = []
|
||||
failed_info = {} # (model_name, error) pair
|
||||
ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS
|
||||
ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + _STUCK_MODELS
|
||||
skipped_models = []
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
|
||||
|
|
|
@ -0,0 +1,147 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from apex import amp
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.testing import assert_close
|
||||
|
||||
import colossalai
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils import set_seed
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.zero import GeminiDDP, GeminiOptimizer
|
||||
from colossalai.zero.gemini.chunk import search_chunk_configuration
|
||||
from tests.components_to_test import run_fwd
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
PLACEMENT_CONFIGS = [
|
||||
{"placement_policy": "static", "shard_param_frac": 0.0}, # zero2
|
||||
{"placement_policy": "static", "shard_param_frac": 1.0}, # zero3
|
||||
{"placement_policy": "static", "shard_param_frac": 0.5}, # zero3-half
|
||||
{"placement_policy": "auto"},
|
||||
]
|
||||
|
||||
|
||||
def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
|
||||
chunk_manager = model.chunk_manager
|
||||
grad_chunk_list = []
|
||||
device_list = []
|
||||
|
||||
# Access gradient chunks.
|
||||
for p in model.parameters():
|
||||
grad_chunk = chunk_manager.get_chunk(p).grad_chunk
|
||||
if grad_chunk not in grad_chunk_list:
|
||||
chunk_manager.access_chunk(grad_chunk)
|
||||
grad_chunk_list.append(grad_chunk)
|
||||
device_list.append(model.grads_device[p])
|
||||
|
||||
# Compare gradients.
|
||||
for p0, p1 in zip(model.parameters(), torch_model.parameters()):
|
||||
assert_close(p0, p1.grad, rtol=1e-3, atol=5e-5)
|
||||
|
||||
# Release gradient chunks and move them to gradient device.
|
||||
for grad_chunk, device in zip(grad_chunk_list, device_list):
|
||||
chunk_manager.release_chunk(grad_chunk)
|
||||
chunk_manager.move_chunk(grad_chunk, device, force_copy=True)
|
||||
|
||||
|
||||
@parameterize("placement_config", PLACEMENT_CONFIGS)
|
||||
@parameterize("keep_gathered", [False, True])
|
||||
@parameterize("model_name", ["gpt2", "bert"])
|
||||
@parameterize("use_grad_checkpoint", [False, True])
|
||||
@parameterize("master_weights", [False, True])
|
||||
def exam_gemini_grad_acc(
|
||||
placement_config, keep_gathered: bool, model_name: str, use_grad_checkpoint: bool, master_weights: bool
|
||||
):
|
||||
init_device = get_current_device()
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, _, _, criterion = get_components_func()
|
||||
|
||||
set_seed(42)
|
||||
gemini_model = model_builder(use_grad_checkpoint)
|
||||
|
||||
set_seed(42)
|
||||
torch_model = model_builder(use_grad_checkpoint).cuda()
|
||||
for torch_p, p in zip(torch_model.parameters(), gemini_model.parameters()):
|
||||
torch_p.data.copy_(p.data)
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
config_dict, *_ = search_chunk_configuration(gemini_model, search_range_m=1, search_interval=100)
|
||||
config_dict[world_size]["chunk_size"] = 5000
|
||||
config_dict[world_size]["keep_gathered"] = keep_gathered
|
||||
gemini_model = GeminiDDP(
|
||||
gemini_model,
|
||||
config_dict,
|
||||
init_device,
|
||||
pin_memory=True,
|
||||
enable_gradient_accumulation=True,
|
||||
master_weights=master_weights,
|
||||
**placement_config,
|
||||
)
|
||||
optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3)
|
||||
gemini_optim = GeminiOptimizer(optimizer, gemini_model, initial_scale=1)
|
||||
|
||||
rank = dist.get_rank()
|
||||
|
||||
# setting master_weights to False will cause overflow after optimizer.step()
|
||||
amp_config = dict(
|
||||
opt_level="O2", keep_batchnorm_fp32=False, loss_scale=1, min_loss_scale=1, max_loss_scale=1, master_weights=True
|
||||
)
|
||||
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
|
||||
torch_model, torch_optim = amp.initialize(torch_model, torch_optim, **amp_config)
|
||||
torch_model = DDP(torch_model, device_ids=[rank])
|
||||
|
||||
set_seed(rank)
|
||||
accum_iter = 4
|
||||
for i, (input_ids, label) in enumerate(train_dataloader):
|
||||
delay_unscale = False if (i + 1) % accum_iter == 0 else True
|
||||
input_ids, label = input_ids.cuda(), label.cuda()
|
||||
|
||||
set_seed(42 + rank)
|
||||
torch_loss = run_fwd(torch_model, input_ids, label, criterion)
|
||||
torch_loss = torch_loss / accum_iter
|
||||
with amp.scale_loss(torch_loss, torch_optim, delay_unscale=delay_unscale) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
|
||||
set_seed(42 + rank)
|
||||
gemini_loss = run_fwd(gemini_model, input_ids, label, criterion)
|
||||
gemini_loss = gemini_loss / accum_iter
|
||||
gemini_optim.backward(gemini_loss)
|
||||
|
||||
assert torch.allclose(torch_loss, gemini_loss, rtol=1e-3, atol=1e-5)
|
||||
|
||||
check_grad(gemini_model, torch_model)
|
||||
|
||||
if (i + 1) % accum_iter == 0:
|
||||
torch_optim.step()
|
||||
gemini_optim.step()
|
||||
torch_optim.zero_grad()
|
||||
|
||||
# check updated param
|
||||
torch_dict = torch_model.state_dict()
|
||||
gemini_dict = gemini_model.state_dict(only_rank_0=False)
|
||||
|
||||
for key, value in gemini_dict.items():
|
||||
torch_key = "module." + key
|
||||
torch_value = torch_dict[torch_key].to(value.device).to(value.dtype)
|
||||
assert_close(value, torch_value, rtol=1e-3, atol=2e-3)
|
||||
|
||||
if i == accum_iter:
|
||||
break
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
config = {}
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
exam_gemini_grad_acc()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_grad_accumulation():
|
||||
spawn(run_dist, 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_grad_accumulation()
|
Loading…
Reference in New Issue