mirror of https://github.com/hpcaitech/ColossalAI
[utils] Impl clip_grad_norm for ColoTensor and ZeroOptimizer (#1442)
parent
74bee5f7e8
commit
821c6172e2
|
@ -1,11 +1,13 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
import os
|
||||
from pprint import pp
|
||||
import random
|
||||
import socket
|
||||
from pathlib import Path
|
||||
from typing import Callable, List, Union
|
||||
from typing import Callable, List, Union, Dict, Optional
|
||||
import functools
|
||||
|
||||
import torch
|
||||
from torch._six import inf
|
||||
from torch.nn.parameter import Parameter
|
||||
|
@ -22,9 +24,11 @@ from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PAR
|
|||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.global_variables import tensor_parallel_env as env
|
||||
|
||||
from .multi_tensor_apply import multi_tensor_applier
|
||||
|
||||
from colossalai.tensor import ColoParameter, ProcessGroup
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
def print_rank_0(msg: str, logger=None):
|
||||
"""Print messages and save logs(optional). This is executed only if you are the rank-0 gpu.
|
||||
|
@ -162,6 +166,121 @@ def _get_tensor_norm(norm: Union[float, torch.Tensor], move_to_cuda) -> torch.Te
|
|||
# ======== Gradient Clipping =========
|
||||
|
||||
|
||||
def _compute_local_lp(params: List[ColoParameter], norm_type: float) -> float:
|
||||
if len(params) == 0:
|
||||
return 0.0
|
||||
grads = [p.grad for p in params]
|
||||
use_cuda_kernel = grads[0].device.type == 'cuda'
|
||||
if norm_type == inf:
|
||||
local_lp = max([g.abs().max() for g in grads])
|
||||
elif norm_type == 2.0 and use_cuda_kernel:
|
||||
local_lp = _calc_l2_norm(grads)**norm_type
|
||||
else:
|
||||
local_lp = _calc_lp(grads, norm_type)
|
||||
if isinstance(local_lp, torch.Tensor):
|
||||
return local_lp.item()
|
||||
return local_lp
|
||||
|
||||
|
||||
def _compute_buckets_lp(params: List[ColoParameter], norm_type: float) -> float:
|
||||
if len(params) == 0:
|
||||
return 0.0
|
||||
buckets: Dict[Optional[ProcessGroup], List[ColoParameter]] = defaultdict(list)
|
||||
for p in params:
|
||||
if p.is_replicate():
|
||||
buckets[None].append(p)
|
||||
else:
|
||||
buckets[p.get_process_group().tp_process_group()].append(p)
|
||||
total_lp = 0.0
|
||||
for group, bucket in buckets.items():
|
||||
local_lp = _compute_local_lp(bucket, norm_type)
|
||||
if group is not None:
|
||||
local_lp_tensor = torch.tensor([local_lp], device=torch.cuda.current_device())
|
||||
if norm_type == inf:
|
||||
dist.all_reduce(local_lp_tensor, op=dist.ReduceOp.MAX, group=group)
|
||||
else:
|
||||
dist.all_reduce(local_lp_tensor, group=group)
|
||||
local_lp = local_lp_tensor.item()
|
||||
if norm_type == inf:
|
||||
total_lp = max(total_lp, local_lp)
|
||||
else:
|
||||
total_lp += local_lp
|
||||
return total_lp
|
||||
|
||||
|
||||
def _compute_pp_grad_lp(total_lp: float, norm_type: float) -> float:
|
||||
if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
|
||||
total_lp_tensor = torch.tensor([total_lp], device=torch.cuda.current_device())
|
||||
if norm_type == inf:
|
||||
dist.all_reduce(total_lp_tensor, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.PIPELINE))
|
||||
else:
|
||||
dist.all_reduce(total_lp_tensor, group=gpc.get_group(ParallelMode.PIPELINE))
|
||||
total_lp = total_lp_tensor.item()
|
||||
return total_lp
|
||||
|
||||
|
||||
def _compute_grad_lp(parameters, norm_type: float = 2.0) -> float:
|
||||
if isinstance(parameters, torch.Tensor):
|
||||
parameters = [parameters]
|
||||
grad_dtype = None
|
||||
cpu_grad_params: List[ColoParameter] = []
|
||||
cuda_grad_params: List[ColoParameter] = []
|
||||
for p in parameters:
|
||||
if p.grad is None:
|
||||
continue
|
||||
assert isinstance(p, ColoParameter)
|
||||
if grad_dtype is None:
|
||||
grad_dtype = p.grad.dtype
|
||||
assert p.grad.dtype == grad_dtype, f'Expected all grads are {grad_dtype}, got {p.grad.dtype}'
|
||||
if p.grad.device.type == 'cuda':
|
||||
cuda_grad_params.append(p)
|
||||
else:
|
||||
cpu_grad_params.append(p)
|
||||
norm_type = float(norm_type)
|
||||
cpu_lp = _compute_buckets_lp(cpu_grad_params, norm_type)
|
||||
cuda_lp = _compute_buckets_lp(cuda_grad_params, norm_type)
|
||||
if norm_type == inf:
|
||||
total_lp = max(cpu_lp, cuda_lp)
|
||||
else:
|
||||
total_lp = cpu_lp + cuda_lp
|
||||
return _compute_pp_grad_lp(total_lp, norm_type)
|
||||
|
||||
|
||||
def compute_grad_norm(parameters, norm_type: float = 2.0) -> float:
|
||||
norm_type = float(norm_type)
|
||||
total_norm = _compute_grad_lp(parameters, norm_type)
|
||||
if norm_type != inf:
|
||||
total_norm = total_norm**(1 / norm_type)
|
||||
return total_norm
|
||||
|
||||
|
||||
def _clip_grad_norm(parameters, max_norm: float, total_norm: float) -> None:
|
||||
clip_coef = max_norm / (total_norm + 1e-6)
|
||||
if clip_coef < 1.0:
|
||||
cuda_grads: List[torch.Tensor] = []
|
||||
cpu_grads: List[torch.Tensor] = []
|
||||
if isinstance(parameters, torch.Tensor):
|
||||
parameters = [parameters]
|
||||
for p in parameters:
|
||||
if p.grad is None:
|
||||
continue
|
||||
if p.grad.device.type == 'cuda':
|
||||
cuda_grads.append(p.grad.detach())
|
||||
else:
|
||||
cpu_grads.append(p.grad.detach())
|
||||
if len(cuda_grads) > 0:
|
||||
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
multi_tensor_applier(colossal_C.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads], clip_coef)
|
||||
for g in cpu_grads:
|
||||
g.mul_(clip_coef)
|
||||
|
||||
|
||||
def clip_grad_norm(parameters, max_norm: float, norm_type: float = 2.0) -> float:
|
||||
total_norm = compute_grad_norm(parameters, norm_type)
|
||||
_clip_grad_norm(parameters, max_norm, total_norm)
|
||||
return total_norm
|
||||
|
||||
|
||||
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
|
||||
"""Clips gradient norm of an iterable of parameters whose gradients are in fp32.
|
||||
|
||||
|
|
|
@ -8,9 +8,11 @@ from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
|
|||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
||||
from colossalai.utils import get_current_device, disposable
|
||||
from colossalai.utils.common import _compute_grad_lp, compute_grad_norm, _clip_grad_norm
|
||||
from collections import defaultdict, abc as container_abcs
|
||||
from copy import deepcopy
|
||||
from itertools import chain
|
||||
from torch._six import inf
|
||||
|
||||
|
||||
class OptimState(Enum):
|
||||
|
@ -143,11 +145,38 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
self._update_fp16_params()
|
||||
return ret
|
||||
|
||||
def clip_grad_norm(self, model: torch.nn.Module, max_norm: float):
|
||||
def compute_grad_norm(self, norm_type: float = 2.0) -> float:
|
||||
norm_type = float(norm_type)
|
||||
if not self.chunk_manager.enable_distributed_storage:
|
||||
return compute_grad_norm(self.module.parameters(), norm_type)
|
||||
|
||||
non_distributed_params = []
|
||||
distributed_params = []
|
||||
for p in self.module.parameters():
|
||||
if getattr(p, '_ddp_to_ignore', False):
|
||||
non_distributed_params.append(p)
|
||||
else:
|
||||
distributed_params.append(p)
|
||||
non_distributed_norm = _compute_grad_lp(non_distributed_params, norm_type)
|
||||
distributed_norm_tensor = torch.tensor([_compute_grad_lp(distributed_params, norm_type)],
|
||||
device=get_current_device())
|
||||
if norm_type == inf:
|
||||
dist.all_reduce(distributed_norm_tensor,
|
||||
op=dist.ReduceOp.MAX,
|
||||
group=self.chunk_manager.process_group.dp_process_group())
|
||||
total_norm = max(non_distributed_norm, distributed_norm_tensor.item())
|
||||
else:
|
||||
dist.all_reduce(distributed_norm_tensor, group=self.chunk_manager.process_group.dp_process_group())
|
||||
total_norm = non_distributed_norm + distributed_norm_tensor.item()
|
||||
total_norm = total_norm**(1 / norm_type)
|
||||
return total_norm
|
||||
|
||||
def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):
|
||||
if self.optim_state == OptimState.SCALED:
|
||||
self._unscale_grads()
|
||||
# TODO(ver217): fix zero clip grad norm
|
||||
return super().clip_grad_norm(model, max_norm)
|
||||
total_norm = self.compute_grad_norm(norm_type)
|
||||
_clip_grad_norm(self.module.parameters(), max_norm, total_norm)
|
||||
return total_norm
|
||||
|
||||
def backward(self, loss: torch.Tensor):
|
||||
loss = self.loss_scale * loss
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
from colossalai.tensor import distspec, ColoTensorSpec, ProcessGroup
|
||||
from colossalai.tensor.colo_parameter import ColoParameter
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.utils import free_port, get_current_device
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from functools import partial
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils.common import clip_grad_norm
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
|
||||
def close(num: float, other: float, rtol: float = 1e-5, atol: float = 1e-8):
|
||||
return abs(num - other) <= atol + rtol * other
|
||||
|
||||
|
||||
def shard_param(p: ColoParameter) -> None:
|
||||
pg = p.get_process_group()
|
||||
p._redistribute(distspec.shard([0], [pg.tp_world_size()]))
|
||||
p.grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()].clone().detach()
|
||||
|
||||
|
||||
def check_grad_equal(p: Parameter, colo_p: ColoParameter) -> None:
|
||||
pg = colo_p.get_process_group()
|
||||
if p.shape != colo_p.shape:
|
||||
grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()]
|
||||
else:
|
||||
grad = p.grad
|
||||
assert torch.allclose(grad, colo_p.grad), f'diff: {torch.abs(grad - colo_p.grad)}'
|
||||
|
||||
|
||||
@parameterize('dtype', [torch.float])
|
||||
@parameterize('device', ['mixed', 'cuda', 'cpu'])
|
||||
@parameterize('norm_type', [2.0, 3.0, float('inf')])
|
||||
def run_grad_clip_norm(world_size: int, dtype: torch.dtype, device: str, norm_type: float):
|
||||
print(f'{world_size}, {dtype}, {device}, {norm_type}')
|
||||
cuda_device = get_current_device()
|
||||
devices = [cuda_device] * 4
|
||||
if device == 'cpu':
|
||||
devices = [torch.device('cpu')] * 4
|
||||
elif device == 'mixed':
|
||||
devices = [cuda_device] * 2 + [torch.device('cpu')] * 2
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
params = [Parameter(torch.empty(4, 4, dtype=dtype, device=devices[i])) for i in range(4)]
|
||||
colo_params = [
|
||||
ColoParameter(torch.empty(4, 4, dtype=dtype, device=devices[i]), spec=ColoTensorSpec(pg)) for i in range(4)
|
||||
]
|
||||
for p, colo_p in zip(params, colo_params):
|
||||
grad = torch.rand_like(p)
|
||||
p.grad = grad
|
||||
colo_p.grad = grad.clone().detach()
|
||||
shard_param(colo_params[0])
|
||||
shard_param(colo_params[2])
|
||||
torch_norm = clip_grad_norm_(params, 1.0, norm_type=norm_type)
|
||||
colo_norm = clip_grad_norm(colo_params, 1.0, norm_type=norm_type)
|
||||
assert close(torch_norm, colo_norm), f'diff: {abs(torch_norm-colo_norm)}'
|
||||
for p, colo_p in zip(params, colo_params):
|
||||
check_grad_equal(p, colo_p)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_grad_clip_norm(world_size=world_size)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 2])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_zero_clip_grad(world_size: int):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_zero_clip_grad(2)
|
Loading…
Reference in New Issue