ColossalAI/tests/test_utils/test_norm_gradient_clipping.py

import pytest
import torch
from torch.nn.parameter import Parameter
from torch.nn.utils import clip_grad_norm_

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.tensor import ColoTensorSpec, ProcessGroup, distspec
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from colossalai.utils.common import clip_grad_norm
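
# Scalar counterpart of torch.allclose, used to compare the total norms returned
# by the two clipping implementations.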
def close(num: float, other: float, rtol: float = 1e-5, atol: float = 1e-8):
    return abs(num - other) <= atol + rtol * other
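
# Shard a ColoParameter and its gradient along dim 0 across the tensor-parallel
# group, keeping only the local chunk on each rank.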
def shard_param(p: ColoParameter) -> None:
    pg = p.get_process_group()
    p._redistribute(distspec.ShardSpec([0], [pg.tp_world_size()]))
    p.grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()].clone().detach()
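
# Compare gradients: if the ColoParameter is sharded, compare it against the matching
# dim-0 slice of the plain Parameter's gradient, otherwise compare them directly.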
def check_grad_equal(p: Parameter, colo_p: ColoParameter) -> None:
    pg = colo_p.get_process_group()
    if p.shape != colo_p.shape:
        grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()]
    else:
        grad = p.grad
    assert torch.allclose(grad, colo_p.grad), f'diff: {torch.abs(grad - colo_p.grad)}'
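
# Create 4 plain Parameters and 4 ColoParameters with identical random gradients
# (placed on CPU, CUDA, or a mix of both), shard two of the ColoParameters, clip
# both sets to max_norm=1.0, and check that the returned total norm and the
# clipped gradients agree.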
@parameterize('dtype', [torch.float])
@parameterize('device', ['mixed', 'cuda', 'cpu'])
@parameterize('norm_type', [2.0, 3.0, float('inf')])
def run_grad_clip_norm(world_size: int, dtype: torch.dtype, device: str, norm_type: float):
    print(f'{world_size}, {dtype}, {device}, {norm_type}')
    cuda_device = get_current_device()
    devices = [cuda_device] * 4
    if device == 'cpu':
        devices = [torch.device('cpu')] * 4
    elif device == 'mixed':
        devices = [cuda_device] * 2 + [torch.device('cpu')] * 2
    pg = ProcessGroup(tp_degree=world_size)
    params = [Parameter(torch.empty(4, 4, dtype=dtype, device=devices[i])) for i in range(4)]
    colo_params = [
        ColoParameter(torch.empty(4, 4, dtype=dtype, device=devices[i]), spec=ColoTensorSpec(pg)) for i in range(4)
    ]
    for p, colo_p in zip(params, colo_params):
        grad = torch.rand_like(p)
        p.grad = grad
        colo_p.grad = grad.clone().detach()
    shard_param(colo_params[0])
    shard_param(colo_params[2])
    torch_norm = clip_grad_norm_(params, 1.0, norm_type=norm_type)
    colo_norm = clip_grad_norm(colo_params, 1.0, norm_type=norm_type)
    assert close(torch_norm, colo_norm), f'diff: {abs(torch_norm - colo_norm)}'
    for p, colo_p in zip(params, colo_params):
        check_grad_equal(p, colo_p)
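
# Per-process entry point: launch the distributed environment, then run the
# parameterized clipping check for every (dtype, device, norm_type) combination.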
def run_dist(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_grad_clip_norm(world_size=world_size)
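
# Spawn `world_size` processes and run the distributed gradient clipping test end to end.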
@pytest.mark.skip("this needs to be updated")
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_zero_clip_grad(world_size: int):
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_zero_clip_grad(2)