[utils] Impl clip_grad_norm for ColoTensor and ZeroOptimizer (#1442)

pull/1439/head
ver217 2022-08-11 22:58:58 +08:00 committed by GitHub
parent 74bee5f7e8
commit 821c6172e2
3 changed files with 232 additions and 5 deletions
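For orientation, a condensed single-process usage sketch of the new utility, modeled on the test added in this commit; the world size, port and tensor shapes are illustrative, and a CUDA device plus the NCCL backend are assumed:

import colossalai
import torch
from colossalai.tensor import ColoTensorSpec, ProcessGroup
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.utils import free_port, get_current_device
from colossalai.utils.common import clip_grad_norm

# Single-process launch so that ProcessGroup and the utility's collectives can be created.
colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
pg = ProcessGroup(tp_degree=1)
params = [ColoParameter(torch.empty(4, 4, device=get_current_device()), spec=ColoTensorSpec(pg)) for _ in range(2)]
for p in params:
    p.grad = torch.rand(4, 4, device=get_current_device())
# Computes the global gradient norm, clips the gradients in place if the norm exceeds
# max_norm, and returns the pre-clipping norm.
total_norm = clip_grad_norm(params, max_norm=1.0, norm_type=2.0)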

View File

@@ -1,11 +1,13 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
from pprint import pp
import random
import socket
from pathlib import Path
from typing import Callable, List, Union
from typing import Callable, List, Union, Dict, Optional
import functools
import torch
from torch._six import inf
from torch.nn.parameter import Parameter
@@ -22,9 +24,11 @@ from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PAR
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from .multi_tensor_apply import multi_tensor_applier
from colossalai.tensor import ColoParameter, ProcessGroup
from collections import defaultdict


def print_rank_0(msg: str, logger=None):
"""Print messages and save logs(optional). This is executed only if you are the rank-0 gpu.
@@ -162,6 +166,121 @@ def _get_tensor_norm(norm: Union[float, torch.Tensor], move_to_cuda) -> torch.Te
# ======== Gradient Clipping =========


def _compute_local_lp(params: List[ColoParameter], norm_type: float) -> float:
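    # Sum of |g|**norm_type over this rank's gradients (max |g| for the inf-norm); the fused
    # CUDA L2 kernel handles the common norm_type == 2.0 case when the grads live on GPU.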
if len(params) == 0:
return 0.0
grads = [p.grad for p in params]
use_cuda_kernel = grads[0].device.type == 'cuda'
if norm_type == inf:
local_lp = max([g.abs().max() for g in grads])
elif norm_type == 2.0 and use_cuda_kernel:
local_lp = _calc_l2_norm(grads)**norm_type
else:
local_lp = _calc_lp(grads, norm_type)
if isinstance(local_lp, torch.Tensor):
return local_lp.item()
    return local_lp


def _compute_buckets_lp(params: List[ColoParameter], norm_type: float) -> float:
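    # Bucket parameters by tensor-parallel process group (replicated params share the None
    # bucket), all-reduce each sharded bucket over its TP group, then combine the buckets.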
if len(params) == 0:
return 0.0
buckets: Dict[Optional[ProcessGroup], List[ColoParameter]] = defaultdict(list)
for p in params:
if p.is_replicate():
buckets[None].append(p)
else:
buckets[p.get_process_group().tp_process_group()].append(p)
total_lp = 0.0
for group, bucket in buckets.items():
local_lp = _compute_local_lp(bucket, norm_type)
if group is not None:
local_lp_tensor = torch.tensor([local_lp], device=torch.cuda.current_device())
if norm_type == inf:
dist.all_reduce(local_lp_tensor, op=dist.ReduceOp.MAX, group=group)
else:
dist.all_reduce(local_lp_tensor, group=group)
local_lp = local_lp_tensor.item()
if norm_type == inf:
total_lp = max(total_lp, local_lp)
else:
total_lp += local_lp
    return total_lp


def _compute_pp_grad_lp(total_lp: float, norm_type: float) -> float:
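    # Reduce the accumulated value across pipeline stages: MAX for the inf-norm, SUM otherwise.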
if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
total_lp_tensor = torch.tensor([total_lp], device=torch.cuda.current_device())
if norm_type == inf:
dist.all_reduce(total_lp_tensor, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.PIPELINE))
else:
dist.all_reduce(total_lp_tensor, group=gpc.get_group(ParallelMode.PIPELINE))
total_lp = total_lp_tensor.item()
    return total_lp


def _compute_grad_lp(parameters, norm_type: float = 2.0) -> float:
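    # Split params that have gradients into CPU and CUDA groups (mixed devices are allowed,
    # but all grads must share one dtype), combine both groups, then reduce over the PP axis.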
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
grad_dtype = None
cpu_grad_params: List[ColoParameter] = []
cuda_grad_params: List[ColoParameter] = []
for p in parameters:
if p.grad is None:
continue
assert isinstance(p, ColoParameter)
if grad_dtype is None:
grad_dtype = p.grad.dtype
assert p.grad.dtype == grad_dtype, f'Expected all grads are {grad_dtype}, got {p.grad.dtype}'
if p.grad.device.type == 'cuda':
cuda_grad_params.append(p)
else:
cpu_grad_params.append(p)
norm_type = float(norm_type)
cpu_lp = _compute_buckets_lp(cpu_grad_params, norm_type)
cuda_lp = _compute_buckets_lp(cuda_grad_params, norm_type)
if norm_type == inf:
total_lp = max(cpu_lp, cuda_lp)
else:
total_lp = cpu_lp + cuda_lp
    return _compute_pp_grad_lp(total_lp, norm_type)


def compute_grad_norm(parameters, norm_type: float = 2.0) -> float:
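    # Global gradient norm across devices, TP groups and PP stages; the p-th root is taken
    # exactly once, here at the end.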
norm_type = float(norm_type)
total_norm = _compute_grad_lp(parameters, norm_type)
if norm_type != inf:
total_norm = total_norm**(1 / norm_type)
    return total_norm


def _clip_grad_norm(parameters, max_norm: float, total_norm: float) -> None:
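    # Scale gradients in place by clip_coef = max_norm / (total_norm + 1e-6) when clip_coef < 1,
    # using the fused multi-tensor kernel for CUDA grads and mul_ for CPU grads.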
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1.0:
cuda_grads: List[torch.Tensor] = []
cpu_grads: List[torch.Tensor] = []
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
for p in parameters:
if p.grad is None:
continue
if p.grad.device.type == 'cuda':
cuda_grads.append(p.grad.detach())
else:
cpu_grads.append(p.grad.detach())
if len(cuda_grads) > 0:
dummy_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(colossal_C.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads], clip_coef)
for g in cpu_grads:
            g.mul_(clip_coef)


def clip_grad_norm(parameters, max_norm: float, norm_type: float = 2.0) -> float:
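    # Convenience wrapper: compute the global norm, clip gradients in place, return the norm.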
total_norm = compute_grad_norm(parameters, norm_type)
_clip_grad_norm(parameters, max_norm, total_norm)
    return total_norm


def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
"""Clips gradient norm of an iterable of parameters whose gradients are in fp32.

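As an aside on the reduction rule used by the helpers above: every bucket contributes sum(|g|**p) (or max |g| for the inf-norm), the partial results are summed (or maxed) across buckets, devices, TP groups and pipeline stages, and the p-th root is taken exactly once in compute_grad_norm. A minimal single-process check of that identity; the tensor shapes and names below are illustrative:

import torch

def combine(partial_lps, norm_type):
    # Combine per-bucket contributions the same way compute_grad_norm does:
    # max for the inf-norm, sum followed by a single p-th root otherwise.
    if norm_type == float('inf'):
        return max(partial_lps)
    return sum(partial_lps)**(1 / norm_type)

norm_type = 2.0
grads = [torch.randn(4, 4) for _ in range(3)]
partial_lps = [g.abs().pow(norm_type).sum().item() for g in grads]
reference = torch.cat([g.flatten() for g in grads]).norm(p=norm_type).item()
assert abs(combine(partial_lps, norm_type) - reference) < 1e-4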
View File

@@ -8,9 +8,11 @@ from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import get_current_device, disposable
from colossalai.utils.common import _compute_grad_lp, compute_grad_norm, _clip_grad_norm
from collections import defaultdict, abc as container_abcs
from copy import deepcopy
from itertools import chain
from torch._six import inf


class OptimState(Enum):
@@ -143,11 +145,38 @@ class ZeroOptimizer(ColossalaiOptimizer):
self._update_fp16_params()
        return ret

def clip_grad_norm(self, model: torch.nn.Module, max_norm: float):
def compute_grad_norm(self, norm_type: float = 2.0) -> float:
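        # Without chunk-distributed storage the plain utility over module.parameters() suffices;
        # otherwise the contribution of chunk-managed params is all-reduced over the DP group,
        # while params tagged _ddp_to_ignore are accounted for locally.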
norm_type = float(norm_type)
if not self.chunk_manager.enable_distributed_storage:
return compute_grad_norm(self.module.parameters(), norm_type)
non_distributed_params = []
distributed_params = []
for p in self.module.parameters():
if getattr(p, '_ddp_to_ignore', False):
non_distributed_params.append(p)
else:
distributed_params.append(p)
non_distributed_norm = _compute_grad_lp(non_distributed_params, norm_type)
distributed_norm_tensor = torch.tensor([_compute_grad_lp(distributed_params, norm_type)],
device=get_current_device())
if norm_type == inf:
dist.all_reduce(distributed_norm_tensor,
op=dist.ReduceOp.MAX,
group=self.chunk_manager.process_group.dp_process_group())
total_norm = max(non_distributed_norm, distributed_norm_tensor.item())
else:
dist.all_reduce(distributed_norm_tensor, group=self.chunk_manager.process_group.dp_process_group())
total_norm = non_distributed_norm + distributed_norm_tensor.item()
total_norm = total_norm**(1 / norm_type)
        return total_norm

def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):
if self.optim_state == OptimState.SCALED:
self._unscale_grads()
# TODO(ver217): fix zero clip grad norm
return super().clip_grad_norm(model, max_norm)
total_norm = self.compute_grad_norm(norm_type)
_clip_grad_norm(self.module.parameters(), max_norm, total_norm)
        return total_norm

def backward(self, loss: torch.Tensor):
loss = self.loss_scale * loss

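For context, a hedged sketch of where the new ZeroOptimizer methods sit in a training step. Constructing model (a ZeRO-wrapped module) and optimizer (a ZeroOptimizer) is elided, and every name below other than backward, clip_grad_norm and step is an assumption rather than part of this commit:

def train_step(model, optimizer, criterion, data, label, max_norm: float = 1.0):
    optimizer.zero_grad()
    loss = criterion(model(data), label)
    optimizer.backward(loss)    # ZeroOptimizer applies loss scaling here
    # Unscales the grads if needed, computes the global norm and clips in place.
    grad_norm = optimizer.clip_grad_norm(model, max_norm)
    optimizer.step()
    return grad_norm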
View File

@@ -0,0 +1,79 @@
from colossalai.tensor import distspec, ColoTensorSpec, ProcessGroup
from colossalai.tensor.colo_parameter import ColoParameter
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port, get_current_device
from torch.nn.utils import clip_grad_norm_
from functools import partial
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils.common import clip_grad_norm
from torch.nn.parameter import Parameter


def close(num: float, other: float, rtol: float = 1e-5, atol: float = 1e-8):
    return abs(num - other) <= atol + rtol * other


def shard_param(p: ColoParameter) -> None:
pg = p.get_process_group()
p._redistribute(distspec.shard([0], [pg.tp_world_size()]))
    p.grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()].clone().detach()


def check_grad_equal(p: Parameter, colo_p: ColoParameter) -> None:
pg = colo_p.get_process_group()
if p.shape != colo_p.shape:
grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()]
else:
grad = p.grad
    assert torch.allclose(grad, colo_p.grad), f'diff: {torch.abs(grad - colo_p.grad)}'


@parameterize('dtype', [torch.float])
@parameterize('device', ['mixed', 'cuda', 'cpu'])
@parameterize('norm_type', [2.0, 3.0, float('inf')])
def run_grad_clip_norm(world_size: int, dtype: torch.dtype, device: str, norm_type: float):
print(f'{world_size}, {dtype}, {device}, {norm_type}')
cuda_device = get_current_device()
devices = [cuda_device] * 4
if device == 'cpu':
devices = [torch.device('cpu')] * 4
elif device == 'mixed':
devices = [cuda_device] * 2 + [torch.device('cpu')] * 2
pg = ProcessGroup(tp_degree=world_size)
params = [Parameter(torch.empty(4, 4, dtype=dtype, device=devices[i])) for i in range(4)]
colo_params = [
ColoParameter(torch.empty(4, 4, dtype=dtype, device=devices[i]), spec=ColoTensorSpec(pg)) for i in range(4)
]
for p, colo_p in zip(params, colo_params):
grad = torch.rand_like(p)
p.grad = grad
colo_p.grad = grad.clone().detach()
shard_param(colo_params[0])
shard_param(colo_params[2])
torch_norm = clip_grad_norm_(params, 1.0, norm_type=norm_type)
colo_norm = clip_grad_norm(colo_params, 1.0, norm_type=norm_type)
assert close(torch_norm, colo_norm), f'diff: {abs(torch_norm-colo_norm)}'
for p, colo_p in zip(params, colo_params):
        check_grad_equal(p, colo_p)


def run_dist(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_grad_clip_norm(world_size=world_size)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_zero_clip_grad(world_size: int):
run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
test_zero_clip_grad(2)