mirror of https://github.com/hpcaitech/ColossalAI
[amp] add gradient clipping for unit tests (#2283)
* [amp] add gradient clipping in unit tests
* fix bugs

branch: pull/2312/head
parent e00cedd181
commit 5d3a2be3af
@@ -147,6 +147,12 @@ class FP16Optimizer(Optimizer):
             f"==========================================",
             ranks=[0])
 
+    @property
+    def max_norm(self):
+        """Returns the maximum norm of gradient clipping.
+        """
+        return self._clip_grad_max_norm
+
     @property
     def grad_scaler(self):
         """Returns the gradient scaler.
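Note: the new max_norm property is a read-only accessor over the internal _clip_grad_max_norm attribute, mirroring the existing grad_scaler property. A minimal sketch of the intended read path (the stripped-down class below is hypothetical; the real FP16Optimizer takes many more constructor arguments):

    # Hypothetical stripped-down sketch of the accessor pattern used above.
    class FP16OptimizerSketch:

        def __init__(self, clip_grad_norm: float = 0.0):
            # 0.0 is assumed here to mean "no clipping"; set from the amp_config.
            self._clip_grad_max_norm = clip_grad_norm

        @property
        def max_norm(self):
            """Returns the maximum norm of gradient clipping."""
            return self._clip_grad_max_norm


    opt = FP16OptimizerSketch(clip_grad_norm=1.0)
    assert opt.max_norm == 1.0    # wrapper code can inspect, but not rebind, the norm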
@@ -1,17 +1,20 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-import torch
-import torch.nn as nn
-import torch.distributed as dist
-from torch import Tensor
-from typing import Any
-from torch.optim import Optimizer
-from torch.distributed import ReduceOp
-from colossalai.core import global_context as gpc
-from colossalai.context import ParallelMode
-from colossalai.nn.optimizer import ColossalaiOptimizer
-from ._fp16_optimizer import FP16Optimizer
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch import Tensor
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from torch.distributed import ReduceOp
+from torch.optim import Optimizer
+
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.nn.optimizer import ColossalaiOptimizer
+
+from ._fp16_optimizer import FP16Optimizer
 
 
 class NaiveAMPOptimizer(ColossalaiOptimizer):
@@ -40,7 +43,11 @@ class NaiveAMPOptimizer(ColossalaiOptimizer):
         return self.optim.step()
 
     def clip_grad_norm(self, model: nn.Module, max_norm: float):
-        pass
+        if self.optim.max_norm == max_norm:
+            return
+        raise RuntimeError("NaiveAMP optimizer has clipped gradients during optimizer.step(). "
+                           "If you have supplied clip_grad_norm in the amp_config, "
+                           "executing the method clip_grad_norm is not allowed.")
 
 
 class NaiveAMPModel(nn.Module):
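Note: clip_grad_norm was previously a silent no-op on this wrapper; clipping actually happens inside optimizer.step(). The new body makes the contract explicit: a call that matches the configured norm is tolerated, anything else is rejected. A small sketch of that contract, using hypothetical stub classes in place of the real wrappers:

    # Sketch of the guard's contract; both stub classes are hypothetical.
    class InnerOptimizerStub:
        max_norm = 1.0    # stands in for FP16Optimizer.max_norm from the amp_config


    class NaiveAMPOptimizerStub:

        def __init__(self):
            self.optim = InnerOptimizerStub()

        def clip_grad_norm(self, model, max_norm: float):
            if self.optim.max_norm == max_norm:
                return    # already configured: step() performs the clipping
            raise RuntimeError("conflicting max_norm; clipping is configured via amp_config")


    opt = NaiveAMPOptimizerStub()
    opt.clip_grad_norm(None, max_norm=1.0)    # no-op: matches the configured norm
    try:
        opt.clip_grad_norm(None, max_norm=2.0)
    except RuntimeError as exc:
        print(f"rejected as expected: {exc}")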
@@ -2,6 +2,7 @@ import torch
 import torch.distributed as dist
 from torch import Tensor
 from torch.distributed import ProcessGroup
+from torch.testing import assert_close
 
 
 def assert_equal(a: Tensor, b: Tensor):
@@ -12,12 +13,8 @@ def assert_not_equal(a: Tensor, b: Tensor):
     assert not torch.all(a == b), f'expected a and b to be not equal but they are, {a} vs {b}'
 
 
-def assert_close(a: Tensor, b: Tensor, rtol: float = 1e-5, atol: float = 1e-8):
-    assert torch.allclose(a, b, rtol=rtol, atol=atol), f'expected a and b to be close but they are not, {a} vs {b}'
-
-
 def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1e-3):
-    assert_close(a, b, rtol, atol)
+    assert_close(a, b, rtol=rtol, atol=atol)
 
 
 def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
@@ -30,4 +27,4 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
     for i in range(world_size - 1):
         a = tensor_list[i]
         b = tensor_list[i + 1]
-        assert torch.all(a == b), f'expected tensors on rank {i} and {i+1} to be equal but they are not, {a} vs {b}'
+        assert torch.all(a == b), f'expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}'
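Note: the hand-rolled assert_close helper is dropped in favor of torch.testing.assert_close (available since PyTorch 1.9), so assert_close_loose becomes a thin wrapper with looser tolerances. A quick usage sketch:

    # Sketch: the loose wrapper accepts differences the strict defaults reject.
    import torch
    from torch.testing import assert_close


    def assert_close_loose(a, b, rtol=1e-3, atol=1e-3):
        assert_close(a, b, rtol=rtol, atol=atol)


    a = torch.tensor([1.0000, 2.0000])
    b = torch.tensor([1.0005, 2.0005])
    assert_close_loose(a, b)    # passes at 1e-3 tolerances
    # assert_close(a, b)        # would raise under the strict float32 defaults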
@@ -1,18 +1,16 @@
+import copy
+from functools import partial
+
+import pytest
 import torch
-import colossalai
 import torch.multiprocessing as mp
-from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
-from tests.components_to_test.registry import non_distributed_component_funcs
 
+import colossalai
+from colossalai.amp import convert_to_apex_amp, convert_to_naive_amp
 from colossalai.testing import assert_close_loose, rerun_if_address_is_in_use
 from colossalai.utils import free_port
-from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
-
-import copy
-import pytest
-from functools import partial
-
+from tests.components_to_test.registry import non_distributed_component_funcs
 
 
 def check_equal(a, b):
     """
@@ -23,7 +21,7 @@ def check_equal(a, b):
 
 def run_naive_amp():
     """
-    In this test, we compare the naive fp16 optimizer implemented in colossalai 
+    In this test, we compare the naive fp16 optimizer implemented in colossalai
     and fp32 torch optimizer
     """
 
@@ -41,11 +39,12 @@ def run_naive_amp():
     apex_amp_model = copy.deepcopy(naive_amp_model)
 
     # create optimizer
-    naive_amp_optimizer = optim_class(naive_amp_model.parameters(), lr=1e-3)
-    apex_amp_optimizer = optim_class(apex_amp_model.parameters(), lr=1e-3)
+    # we use SGD here, since the correctness of gradient clipping can't be tested with Adam
+    naive_amp_optimizer = torch.optim.SGD(naive_amp_model.parameters(), lr=1e-3)
+    apex_amp_optimizer = torch.optim.SGD(apex_amp_model.parameters(), lr=1e-3)
 
     # inject naive and apex amp
-    naive_amp_config = dict(initial_scale=128)
+    naive_amp_config = dict(initial_scale=128, clip_grad_norm=1.0)
     naive_amp_model, naive_amp_optimizer = convert_to_naive_amp(naive_amp_model, naive_amp_optimizer,
                                                                 naive_amp_config)
     apex_amp_config = dict(opt_level='O2', loss_scale=128, keep_batchnorm_fp32=False)
@@ -62,13 +61,17 @@ def run_naive_amp():
     assert_close_loose(naive_amp_output, apex_amp_output)
 
     # backward
-    naive_amp_optimizer.backward(naive_amp_output.mean())
-    apex_amp_optimizer.backward(apex_amp_output.mean())
+    # use sum() to get big gradient
+    naive_amp_optimizer.backward(naive_amp_output.sum())
+    apex_amp_optimizer.backward(apex_amp_output.sum())
 
     # check grad
     for naive_amp_param, apex_amp_param in zip(naive_amp_model.parameters(), apex_amp_model.parameters()):
         assert_close_loose(naive_amp_param.grad, apex_amp_param.grad)
 
+    # clip gradient
+    apex_amp_optimizer.clip_grad_norm(model=apex_amp_model, max_norm=1.0)
+
     # step
     naive_amp_optimizer.step()
     apex_amp_optimizer.step()
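Note: the switch from optim_class to plain SGD, and from mean() to sum(), is what makes the clipping observable. Global-norm clipping rescales every gradient by one common factor, and Adam-style updates divide by a running sqrt of squared gradients, which largely cancels a uniform rescaling; SGD's update is directly proportional to the gradient, so a missing clip shows up immediately. A small self-contained sketch of the rescaling that clipping performs (values are illustrative):

    # Sketch: global-norm clipping rescales gradients by a single factor,
    # which SGD updates reflect directly.
    import torch

    w = torch.nn.Parameter(torch.tensor([3.0, 4.0]))
    w.grad = torch.tensor([30.0, 40.0])               # global grad norm = 50
    torch.nn.utils.clip_grad_norm_([w], max_norm=1.0)
    print(w.grad)                                     # scaled by 1/50 -> [0.6, 0.8]

    # An SGD step now moves w by lr * clipped_grad; without clipping the step
    # would have been 50x larger, so the two runs diverge measurably.
    torch.optim.SGD([w], lr=1e-3).step()
    print(w.data)                                     # [3.0, 4.0] - 1e-3 * [0.6, 0.8]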
@@ -1,14 +1,15 @@
+import copy
+from functools import partial
+
+import pytest
 import torch
-import colossalai
 import torch.multiprocessing as mp
-from tests.components_to_test.registry import non_distributed_component_funcs
 
+import colossalai
+from colossalai.amp import convert_to_apex_amp, convert_to_torch_amp
 from colossalai.testing import assert_close_loose, rerun_if_address_is_in_use
 from colossalai.utils import free_port
-from colossalai.amp import convert_to_torch_amp, convert_to_apex_amp
-
-import copy
-import pytest
-from functools import partial
+from tests.components_to_test.registry import non_distributed_component_funcs
+
 
 def run_torch_amp():
@@ -30,15 +31,16 @@ def run_torch_amp():
     apex_amp_model = copy.deepcopy(torch_amp_model)
 
     # create optimizer
-    torch_amp_optimizer = optim_class(torch_amp_model.parameters(), lr=1e-3)
-    apex_amp_optimizer = optim_class(apex_amp_model.parameters(), lr=1e-3)
+    # we use SGD here, since the correctness of gradient clipping can't be tested with Adam
+    torch_amp_optimizer = torch.optim.SGD(torch_amp_model.parameters(), lr=1e-3)
+    apex_amp_optimizer = torch.optim.SGD(apex_amp_model.parameters(), lr=1e-3)
 
     # inject torch and apex amp
-    torch_amp_config = dict(init_scale=1280, enabled=True)
+    torch_amp_config = dict(init_scale=128, enabled=True)
     torch_amp_model, torch_amp_optimizer, _ = convert_to_torch_amp(torch_amp_model,
                                                                    torch_amp_optimizer,
                                                                    amp_config=torch_amp_config)
-    apex_amp_config = dict(opt_level='O1', loss_scale=1280)
+    apex_amp_config = dict(opt_level='O1', loss_scale=128)
     apex_amp_model, apex_amp_optimizer = convert_to_apex_amp(apex_amp_model, apex_amp_optimizer, apex_amp_config)
 
     # create data
@@ -55,14 +57,19 @@ def run_torch_amp():
         assert_close_loose(torch_amp_param, apex_amp_param)
 
     # backward
-    torch_amp_optimizer.backward(torch_amp_output.mean())
-    apex_amp_optimizer.backward(apex_amp_output.mean())
+    # use sum() to get big gradient
+    torch_amp_optimizer.backward(torch_amp_output.sum())
+    apex_amp_optimizer.backward(apex_amp_output.sum())
 
     # check grad
+    # In apex amp, grad is not scaled before backward, but torch amp does
     for torch_amp_param, apex_amp_param in zip(torch_amp_model.parameters(), apex_amp_model.parameters()):
-        assert_close_loose(torch_amp_param.grad, apex_amp_param.grad)
+        assert_close_loose(torch_amp_param.grad, apex_amp_param.grad * apex_amp_config['loss_scale'])
+
+    # clip gradient
+    apex_amp_optimizer.clip_grad_norm(model=apex_amp_model, max_norm=1.0)
+    torch_amp_optimizer.clip_grad_norm(model=torch_amp_model, max_norm=1.0)
 
     # step
     torch_amp_optimizer.step()
     apex_amp_optimizer.step()
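Note: the loss_scale multiplier in the grad check reflects who owns unscaling. torch.cuda.amp keeps gradients multiplied by the loss scale until the scaler unscales them, while the apex-style path compared here already exposes unscaled gradients, hence torch_grad is compared against apex_grad * loss_scale. A minimal sketch of that behaviour with plain torch.cuda.amp (assumes a CUDA device; values are illustrative):

    # Sketch: torch.cuda.amp gradients stay multiplied by the loss scale
    # until GradScaler.unscale_ (or step) is called.
    import torch

    w = torch.ones(2, device='cuda', requires_grad=True)
    opt = torch.optim.SGD([w], lr=1e-3)
    scaler = torch.cuda.amp.GradScaler(init_scale=128)

    loss = (w * 3).sum()
    scaler.scale(loss).backward()
    print(w.grad)            # still scaled: 3 * 128 = 384 per element

    scaler.unscale_(opt)
    print(w.grad)            # unscaled: 3 per element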