[amp] add gradient clipping for unit tests (#2283)

* [amp] add gradient clipping in unit tests

* fix bugs
HELSON 2023-01-04 11:59:56 +08:00 committed by GitHub
parent e00cedd181
commit 5d3a2be3af
5 changed files with 64 additions and 44 deletions


@@ -147,6 +147,12 @@ class FP16Optimizer(Optimizer):
                           f"==========================================",
                           ranks=[0])

+    @property
+    def max_norm(self):
+        """Returns the maximum norm of gradient clipping.
+        """
+        return self._clip_grad_max_norm
+
     @property
     def grad_scaler(self):
         """Returns the gradient scaler.


@@ -1,17 +1,20 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
-import torch
-import torch.nn as nn
-import torch.distributed as dist
-from torch import Tensor
 from typing import Any
-from torch.optim import Optimizer
-from torch.distributed import ReduceOp
-from colossalai.core import global_context as gpc
-from colossalai.context import ParallelMode
-from colossalai.nn.optimizer import ColossalaiOptimizer
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch import Tensor
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from torch.distributed import ReduceOp
+from torch.optim import Optimizer
+
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.nn.optimizer import ColossalaiOptimizer

 from ._fp16_optimizer import FP16Optimizer
@@ -40,7 +43,11 @@ class NaiveAMPOptimizer(ColossalaiOptimizer):
         return self.optim.step()

     def clip_grad_norm(self, model: nn.Module, max_norm: float):
-        pass
+        if self.optim.max_norm == max_norm:
+            return
+        raise RuntimeError("NaiveAMP optimizer has clipped gradients during optimizer.step(). "
+                           "If you have supplied clip_grad_norm in the amp_config, "
+                           "executing the method clip_grad_norm is not allowed.")


 class NaiveAMPModel(nn.Module):
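
With the two hunks above, clipping for naive AMP is configured once through amp_config and carried out inside optimizer.step(); calling clip_grad_norm afterwards is only accepted when the requested norm matches the configured one. A minimal sketch of that calling pattern, assuming a single-process colossalai launch and a toy nn.Linear model (both illustrative, not part of the commit):

import torch
import torch.nn as nn

import colossalai
from colossalai.amp import convert_to_naive_amp
from colossalai.utils import free_port

# single-process launch, as done in the unit tests
colossalai.launch(config=dict(), rank=0, world_size=1, host='localhost', port=free_port())

# toy model and shapes are illustrative only
model = nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# clipping is configured here and performed inside optimizer.step()
amp_config = dict(initial_scale=128, clip_grad_norm=1.0)
model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)

loss = model(torch.randn(8, 16, device='cuda')).sum()
optimizer.backward(loss)

optimizer.clip_grad_norm(model=model, max_norm=1.0)    # returns early: matches the configured norm
# optimizer.clip_grad_norm(model=model, max_norm=0.5)  # would raise the RuntimeError above
optimizer.step()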


@@ -2,6 +2,7 @@ import torch
 import torch.distributed as dist
 from torch import Tensor
 from torch.distributed import ProcessGroup
+from torch.testing import assert_close


 def assert_equal(a: Tensor, b: Tensor):
@@ -12,12 +13,8 @@ def assert_not_equal(a: Tensor, b: Tensor):
     assert not torch.all(a == b), f'expected a and b to be not equal but they are, {a} vs {b}'


-def assert_close(a: Tensor, b: Tensor, rtol: float = 1e-5, atol: float = 1e-8):
-    assert torch.allclose(a, b, rtol=rtol, atol=atol), f'expected a and b to be close but they are not, {a} vs {b}'
-
-
 def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1e-3):
-    assert_close(a, b, rtol, atol)
+    assert_close(a, b, rtol=rtol, atol=atol)


 def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
@@ -30,4 +27,4 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
     for i in range(world_size - 1):
         a = tensor_list[i]
         b = tensor_list[i + 1]
-        assert torch.all(a == b), f'expected tensors on rank {i} and {i+1} to be equal but they are not, {a} vs {b}'
+        assert torch.all(a == b), f'expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}'
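
The hand-rolled assert_close helper is dropped in favour of torch.testing.assert_close, and assert_close_loose now forwards its tolerances as keyword arguments. A small usage sketch (the tensor values are chosen only for illustration):

import torch
from colossalai.testing import assert_close_loose

a = torch.tensor([1.0000, 2.0000])
b = torch.tensor([1.0004, 2.0003])

# within rtol=1e-3 / atol=1e-3, so this passes
assert_close_loose(a, b)

# the default tolerances of torch.testing.assert_close are much tighter
# (roughly rtol=1.3e-6, atol=1e-5 for float32), so this would raise:
# torch.testing.assert_close(a, b)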


@@ -1,18 +1,16 @@
+import copy
+from functools import partial
+
+import pytest
 import torch
-import colossalai
 import torch.multiprocessing as mp
-from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
-from tests.components_to_test.registry import non_distributed_component_funcs
+
+import colossalai
+from colossalai.amp import convert_to_apex_amp, convert_to_naive_amp
 from colossalai.testing import assert_close_loose, rerun_if_address_is_in_use
 from colossalai.utils import free_port
-from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
-from tests.components_to_test.registry import non_distributed_component_funcs
-import copy
-import pytest
-from functools import partial
+from tests.components_to_test.registry import non_distributed_component_funcs


 def check_equal(a, b):
     """
@@ -23,7 +21,7 @@ def check_equal(a, b):


 def run_naive_amp():
     """
     In this test, we compare the naive fp16 optimizer implemented in colossalai
     and fp32 torch optimizer
     """
@@ -41,11 +39,12 @@ def run_naive_amp():
         apex_amp_model = copy.deepcopy(naive_amp_model)

         # create optimizer
-        naive_amp_optimizer = optim_class(naive_amp_model.parameters(), lr=1e-3)
-        apex_amp_optimizer = optim_class(apex_amp_model.parameters(), lr=1e-3)
+        # we use SGD here, since the correctness of gradient clipping can't be tested with Adam
+        naive_amp_optimizer = torch.optim.SGD(naive_amp_model.parameters(), lr=1e-3)
+        apex_amp_optimizer = torch.optim.SGD(apex_amp_model.parameters(), lr=1e-3)

         # inject naive and apex amp
-        naive_amp_config = dict(initial_scale=128)
+        naive_amp_config = dict(initial_scale=128, clip_grad_norm=1.0)
         naive_amp_model, naive_amp_optimizer = convert_to_naive_amp(naive_amp_model, naive_amp_optimizer,
                                                                     naive_amp_config)
         apex_amp_config = dict(opt_level='O2', loss_scale=128, keep_batchnorm_fp32=False)
@@ -62,13 +61,17 @@ def run_naive_amp():
         assert_close_loose(naive_amp_output, apex_amp_output)

         # backward
-        naive_amp_optimizer.backward(naive_amp_output.mean())
-        apex_amp_optimizer.backward(apex_amp_output.mean())
+        # use sum() to get big gradient
+        naive_amp_optimizer.backward(naive_amp_output.sum())
+        apex_amp_optimizer.backward(apex_amp_output.sum())

         # check grad
         for naive_amp_param, apex_amp_param in zip(naive_amp_model.parameters(), apex_amp_model.parameters()):
             assert_close_loose(naive_amp_param.grad, apex_amp_param.grad)

+        # clip gradient
+        apex_amp_optimizer.clip_grad_norm(model=apex_amp_model, max_norm=1.0)
+
         # step
         naive_amp_optimizer.step()
         apex_amp_optimizer.step()
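
The switch from the registry's optim_class (Adam, per the new comment) to SGD, together with using sum() instead of mean(), is what makes clipping observable: an SGD update scales linearly with the clipped gradient, whereas Adam's second-moment normalisation largely cancels a uniform rescaling. A standalone PyTorch illustration of that point (not part of the test suite):

import torch

# one parameter with a deliberately large gradient (norm = 20)
p = torch.nn.Parameter(torch.ones(4))
p.grad = torch.full((4,), 10.0)

# clip to max_norm=1.0: every gradient element is rescaled by 1/20
torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)
torch.optim.SGD([p], lr=1e-3).step()

# with SGD each element moves by lr * 0.5 = 5e-4 instead of lr * 10 = 1e-2,
# so an incorrect clipping implementation shows up directly in the parameters
print(p.data)   # tensor([0.9995, 0.9995, 0.9995, 0.9995])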


@@ -1,14 +1,15 @@
+import copy
+from functools import partial
+
+import pytest
 import torch
-import colossalai
 import torch.multiprocessing as mp
-from tests.components_to_test.registry import non_distributed_component_funcs
+
+import colossalai
+from colossalai.amp import convert_to_apex_amp, convert_to_torch_amp
 from colossalai.testing import assert_close_loose, rerun_if_address_is_in_use
 from colossalai.utils import free_port
-from colossalai.amp import convert_to_torch_amp, convert_to_apex_amp
-import copy
-import pytest
-from functools import partial
+from tests.components_to_test.registry import non_distributed_component_funcs


 def run_torch_amp():
@@ -30,15 +31,16 @@ def run_torch_amp():
         apex_amp_model = copy.deepcopy(torch_amp_model)

         # create optimizer
-        torch_amp_optimizer = optim_class(torch_amp_model.parameters(), lr=1e-3)
-        apex_amp_optimizer = optim_class(apex_amp_model.parameters(), lr=1e-3)
+        # we use SGD here, since the correctness of gradient clipping can't be tested with Adam
+        torch_amp_optimizer = torch.optim.SGD(torch_amp_model.parameters(), lr=1e-3)
+        apex_amp_optimizer = torch.optim.SGD(apex_amp_model.parameters(), lr=1e-3)

         # inject torch and apex amp
-        torch_amp_config = dict(init_scale=1280, enabled=True)
+        torch_amp_config = dict(init_scale=128, enabled=True)
         torch_amp_model, torch_amp_optimizer, _ = convert_to_torch_amp(torch_amp_model,
                                                                        torch_amp_optimizer,
                                                                        amp_config=torch_amp_config)
-        apex_amp_config = dict(opt_level='O1', loss_scale=1280)
+        apex_amp_config = dict(opt_level='O1', loss_scale=128)
         apex_amp_model, apex_amp_optimizer = convert_to_apex_amp(apex_amp_model, apex_amp_optimizer, apex_amp_config)

         # create data
@@ -55,14 +57,19 @@ def run_torch_amp():
         assert_close_loose(torch_amp_param, apex_amp_param)

         # backward
-        torch_amp_optimizer.backward(torch_amp_output.mean())
-        apex_amp_optimizer.backward(apex_amp_output.mean())
+        # use sum() to get big gradient
+        torch_amp_optimizer.backward(torch_amp_output.sum())
+        apex_amp_optimizer.backward(apex_amp_output.sum())

         # check grad
         # In apex amp, grad is not scaled before backward, but torch amp does
         for torch_amp_param, apex_amp_param in zip(torch_amp_model.parameters(), apex_amp_model.parameters()):
             assert_close_loose(torch_amp_param.grad, apex_amp_param.grad * apex_amp_config['loss_scale'])

+        # clip gradient
+        apex_amp_optimizer.clip_grad_norm(model=apex_amp_model, max_norm=1.0)
+        torch_amp_optimizer.clip_grad_norm(model=torch_amp_model, max_norm=1.0)
+
         # step
         torch_amp_optimizer.step()
         apex_amp_optimizer.step()
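
The context comment explains why the grad check multiplies the apex grads by loss_scale: with torch amp the gradients seen right after backward are still multiplied by the loss scale, while apex O1 hands back unscaled gradients. A standalone sketch of the torch-side behaviour using the raw torch.cuda.amp API (requires a CUDA device; illustrative only, not the colossalai wrapper):

import torch

model = torch.nn.Linear(4, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(init_scale=128)

with torch.cuda.amp.autocast():
    loss = model(torch.randn(2, 4, device='cuda')).sum()

scaler.scale(loss).backward()
print(model.weight.grad.abs().max())    # gradients are still multiplied by 128 here

scaler.unscale_(optimizer)              # divides the gradients by the loss scale in place
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)                  # skips the step if inf/nan gradients were found
scaler.update()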