[amp] add gradient clipping for unit tests (#2283)

* [amp] add gradient clipping in unit tests

* fix bugs
pull/2312/head
HELSON 2023-01-04 11:59:56 +08:00 committed by GitHub
parent e00cedd181
commit 5d3a2be3af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 64 additions and 44 deletions

View File

@ -147,6 +147,12 @@ class FP16Optimizer(Optimizer):
f"==========================================",
ranks=[0])
@property
def max_norm(self):
"""Returns the maximum norm of gradient clipping.
"""
return self._clip_grad_max_norm
@property
def grad_scaler(self):
"""Returns the gradient scaler.

View File

@ -1,17 +1,20 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
import torch.distributed as dist
from torch import Tensor
from typing import Any
from torch.optim import Optimizer
from torch.distributed import ReduceOp
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.nn.optimizer import ColossalaiOptimizer
import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ReduceOp
from torch.optim import Optimizer
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.optimizer import ColossalaiOptimizer
from ._fp16_optimizer import FP16Optimizer
@ -40,7 +43,11 @@ class NaiveAMPOptimizer(ColossalaiOptimizer):
return self.optim.step()
def clip_grad_norm(self, model: nn.Module, max_norm: float):
pass
if self.optim.max_norm == max_norm:
return
raise RuntimeError("NaiveAMP optimizer has clipped gradients during optimizer.step(). "
"If you have supplied clip_grad_norm in the amp_config, "
"executing the method clip_grad_norm is not allowed.")
class NaiveAMPModel(nn.Module):

View File

@ -2,6 +2,7 @@ import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.testing import assert_close
def assert_equal(a: Tensor, b: Tensor):
@ -12,12 +13,8 @@ def assert_not_equal(a: Tensor, b: Tensor):
assert not torch.all(a == b), f'expected a and b to be not equal but they are, {a} vs {b}'
def assert_close(a: Tensor, b: Tensor, rtol: float = 1e-5, atol: float = 1e-8):
assert torch.allclose(a, b, rtol=rtol, atol=atol), f'expected a and b to be close but they are not, {a} vs {b}'
def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1e-3):
assert_close(a, b, rtol, atol)
assert_close(a, b, rtol=rtol, atol=atol)
def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
@ -30,4 +27,4 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
for i in range(world_size - 1):
a = tensor_list[i]
b = tensor_list[i + 1]
assert torch.all(a == b), f'expected tensors on rank {i} and {i+1} to be equal but they are not, {a} vs {b}'
assert torch.all(a == b), f'expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}'

View File

@ -1,18 +1,16 @@
import copy
from functools import partial
import pytest
import torch
import colossalai
import torch.multiprocessing as mp
from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
from tests.components_to_test.registry import non_distributed_component_funcs
import colossalai
from colossalai.amp import convert_to_apex_amp, convert_to_naive_amp
from colossalai.testing import assert_close_loose, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
from tests.components_to_test.registry import non_distributed_component_funcs
import copy
import pytest
from functools import partial
def check_equal(a, b):
"""
@ -23,7 +21,7 @@ def check_equal(a, b):
def run_naive_amp():
"""
In this test, we compare the naive fp16 optimizer implemented in colossalai
In this test, we compare the naive fp16 optimizer implemented in colossalai
and fp32 torch optimizer
"""
@ -41,11 +39,12 @@ def run_naive_amp():
apex_amp_model = copy.deepcopy(naive_amp_model)
# create optimizer
naive_amp_optimizer = optim_class(naive_amp_model.parameters(), lr=1e-3)
apex_amp_optimizer = optim_class(apex_amp_model.parameters(), lr=1e-3)
# we use SGD here, since the correctness of gradient clipping can't be tested with Adam
naive_amp_optimizer = torch.optim.SGD(naive_amp_model.parameters(), lr=1e-3)
apex_amp_optimizer = torch.optim.SGD(apex_amp_model.parameters(), lr=1e-3)
# inject naive and apex amp
naive_amp_config = dict(initial_scale=128)
naive_amp_config = dict(initial_scale=128, clip_grad_norm=1.0)
naive_amp_model, naive_amp_optimizer = convert_to_naive_amp(naive_amp_model, naive_amp_optimizer,
naive_amp_config)
apex_amp_config = dict(opt_level='O2', loss_scale=128, keep_batchnorm_fp32=False)
@ -62,13 +61,17 @@ def run_naive_amp():
assert_close_loose(naive_amp_output, apex_amp_output)
# backward
naive_amp_optimizer.backward(naive_amp_output.mean())
apex_amp_optimizer.backward(apex_amp_output.mean())
# use sum() to get big gradient
naive_amp_optimizer.backward(naive_amp_output.sum())
apex_amp_optimizer.backward(apex_amp_output.sum())
# check grad
for naive_amp_param, apex_amp_param in zip(naive_amp_model.parameters(), apex_amp_model.parameters()):
assert_close_loose(naive_amp_param.grad, apex_amp_param.grad)
# clip gradient
apex_amp_optimizer.clip_grad_norm(model=apex_amp_model, max_norm=1.0)
# step
naive_amp_optimizer.step()
apex_amp_optimizer.step()

View File

@ -1,14 +1,15 @@
import copy
from functools import partial
import pytest
import torch
import colossalai
import torch.multiprocessing as mp
from tests.components_to_test.registry import non_distributed_component_funcs
import colossalai
from colossalai.amp import convert_to_apex_amp, convert_to_torch_amp
from colossalai.testing import assert_close_loose, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.amp import convert_to_torch_amp, convert_to_apex_amp
import copy
import pytest
from functools import partial
from tests.components_to_test.registry import non_distributed_component_funcs
def run_torch_amp():
@ -30,15 +31,16 @@ def run_torch_amp():
apex_amp_model = copy.deepcopy(torch_amp_model)
# create optimizer
torch_amp_optimizer = optim_class(torch_amp_model.parameters(), lr=1e-3)
apex_amp_optimizer = optim_class(apex_amp_model.parameters(), lr=1e-3)
# we use SGD here, since the correctness of gradient clipping can't be tested with Adam
torch_amp_optimizer = torch.optim.SGD(torch_amp_model.parameters(), lr=1e-3)
apex_amp_optimizer = torch.optim.SGD(apex_amp_model.parameters(), lr=1e-3)
# inject torch and apex amp
torch_amp_config = dict(init_scale=1280, enabled=True)
torch_amp_config = dict(init_scale=128, enabled=True)
torch_amp_model, torch_amp_optimizer, _ = convert_to_torch_amp(torch_amp_model,
torch_amp_optimizer,
amp_config=torch_amp_config)
apex_amp_config = dict(opt_level='O1', loss_scale=1280)
apex_amp_config = dict(opt_level='O1', loss_scale=128)
apex_amp_model, apex_amp_optimizer = convert_to_apex_amp(apex_amp_model, apex_amp_optimizer, apex_amp_config)
# create data
@ -55,14 +57,19 @@ def run_torch_amp():
assert_close_loose(torch_amp_param, apex_amp_param)
# backward
torch_amp_optimizer.backward(torch_amp_output.mean())
apex_amp_optimizer.backward(apex_amp_output.mean())
# use sum() to get big gradient
torch_amp_optimizer.backward(torch_amp_output.sum())
apex_amp_optimizer.backward(apex_amp_output.sum())
# check grad
# In apex amp, grad is not scaled before backward, but torch amp does
for torch_amp_param, apex_amp_param in zip(torch_amp_model.parameters(), apex_amp_model.parameters()):
assert_close_loose(torch_amp_param.grad, apex_amp_param.grad * apex_amp_config['loss_scale'])
# clip gradient
apex_amp_optimizer.clip_grad_norm(model=apex_amp_model, max_norm=1.0)
torch_amp_optimizer.clip_grad_norm(model=torch_amp_model, max_norm=1.0)
# step
torch_amp_optimizer.step()
apex_amp_optimizer.step()