#!/usr/bin/env python # -*- encoding: utf-8 -*- import os import random import socket from pathlib import Path from typing import Callable, List, Union import functools import torch from torch._six import inf from torch.nn.parameter import Parameter try: import colossal_C except: pass from contextlib import contextmanager import torch.distributed as dist from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES) from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.global_variables import tensor_parallel_env as env from .multi_tensor_apply import multi_tensor_applier def print_rank_0(msg: str, logger=None): """Print messages and save logs(optional). This is executed only if you are the rank-0 gpu. Args: msg (str): A string message to output. logger (:class:`colossalai.logging.DistributedLogger`, optional): The logger to record the message, defaults to None. """ if gpc.get_global_rank() == 0: if logger is None: print(msg, flush=True) else: logger.info(msg) def ensure_path_exists(filename: str): # ensure the path exists dirpath = os.path.dirname(filename) if not os.path.exists(dirpath): Path(dirpath).mkdir(parents=True, exist_ok=True) def free_port(): while True: try: sock = socket.socket() sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) port = random.randint(20000, 65000) sock.bind(('localhost', port)) sock.close() return port except Exception: continue def sync_model_param(model, parallel_mode): r"""Make sure data parameters are consistent during Data Parallel Mode. Args: model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency. parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel mode to be checked. Note: The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found in `parallel_mode `_ """ if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1: for param in model.parameters(): ranks = gpc.get_ranks_in_group(parallel_mode) dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode)) def is_dp_rank_0(): return not gpc.is_initialized(ParallelMode.DATA) or gpc.is_first_rank(ParallelMode.DATA) def is_tp_rank_0(): return not gpc.is_initialized(ParallelMode.TENSOR) or gpc.is_first_rank(ParallelMode.TENSOR) def is_no_pp_or_last_stage(): return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE) def is_using_ddp(): return gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1 def is_using_pp(): return gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1 def is_using_sequence(): return gpc.is_initialized(ParallelMode.SEQUENCE) and gpc.get_world_size(ParallelMode.SEQUENCE) > 1 @contextmanager def conditional_context(context_manager, enable=True): if enable: with context_manager: yield else: yield class model_branch_context(object): def __enter__(self): self.env_status = env.save() def __exit__(self, *exc_info): env.load(**self.env_status) def is_model_parallel_parameter(p): return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL) def _calc_l2_norm(grads): norm = 0.0 if len(grads) > 0: dummy_overflow_buf = torch.cuda.IntTensor([0]) norm, _ = multi_tensor_applier( colossal_C.multi_tensor_l2norm, dummy_overflow_buf, [grads], False # no per-parameter norm ) return norm def _calc_lp(grads, norm_type): norm = 0.0 for grad in grads: grad_norm = torch.norm(grad, norm_type) norm += grad_norm**norm_type return norm def _move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]: if torch.is_tensor(norm) and norm.device.type != 'cuda': norm = norm.to(torch.cuda.current_device()) return norm def _get_tensor_norm(norm: Union[float, torch.Tensor], move_to_cuda) -> torch.Tensor: if isinstance(norm, float): norm = torch.Tensor([norm]) if move_to_cuda: norm = norm.to(torch.cuda.current_device()) return norm # ======== Gradient Clipping ========= def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): """Clips gradient norm of an iterable of parameters whose gradients are in fp32. This is adapted from :func:`torch.nn.utils.clip_grad.clip_grad_norm_` and added functionality to handle model parallel parameters. Note: the gradients are modified in place. Args: parameters (Iterable[:class:`torch.tensor`] or :class:`torch.tensor`): An iterable of Tensors or a single Tensor that will have gradients normalized. max_norm (Union[float, int]): Max norm of the gradients. norm_type (Union[float, int, 'inf']): Type of the used p-norm. Can be ``'inf'`` for infinity norm. Returns: float: Total norm of the parameters. """ if isinstance(parameters, torch.Tensor): parameters = [parameters] # Filter parameters based on: # - grad should not be none # - parameter should not be shared # - should not be a replica due to tensor model parallelism params: List[Parameter] = [] has_zero_shared_param: bool = False for param in parameters: if param.grad is not None: # Make sure the grads are in fp32 assert param.grad.dtype == torch.float, \ f'expected gradient to be dtype torch.float, but got {param.grad.type()}' if hasattr(param, 'colo_attr') and param.colo_attr.sharded_data_tensor.is_sharded: has_zero_shared_param = True params.append(param) if len(params) == 0: enable_cuda_kernels = False else: enable_cuda_kernels = params[0].grad.device.type == 'cuda' # Norm parameters. max_norm = float(max_norm) norm_type = float(norm_type) # Parameters can be on CPU or CUDA # If parameters are on CPU, disable CUDA kernerls # Calculate norm. if norm_type == inf: total_norm = max(p.grad.data.abs().max() for p in params) total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) # Take max across all model-parallel GPUs. if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1: dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL), async_op=False) if has_zero_shared_param: dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.DATA), async_op=False) total_norm = total_norm_cuda[0].item() else: tensor_parallel_grads = [] no_tensor_parallel_grads = [] zero_sharded_grads = [] for p in params: if is_model_parallel_parameter(p): reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS))**(1 / norm_type) tensor_parallel_grads.append(p.grad.data / reductor) elif hasattr(p, 'colo_attr') and p.colo_attr.sharded_data_tensor.is_sharded: zero_sharded_grads.append(p.grad.data) else: no_tensor_parallel_grads.append(p.grad.data) if norm_type == 2.0 and enable_cuda_kernels: tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads)**norm_type no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads)**norm_type zero_sharded_norm = _calc_l2_norm(zero_sharded_grads)**norm_type else: tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type) no_tensor_parallel_norm = _calc_lp(no_tensor_parallel_grads, norm_type) zero_sharded_norm = _calc_lp(zero_sharded_grads, norm_type) # If norm is type of float, then we convert them into torch.Tensor. tensor_parallel_norm = _get_tensor_norm(tensor_parallel_norm, enable_cuda_kernels) no_tensor_parallel_norm = _get_tensor_norm(no_tensor_parallel_norm, enable_cuda_kernels) zero_sharded_norm = _get_tensor_norm(zero_sharded_norm, enable_cuda_kernels) # If grads are on CPU, the norms is also on CPU. Cast them to CUDA tensors if not enable_cuda_kernels: tensor_parallel_norm = _move_norm_to_cuda(tensor_parallel_norm) no_tensor_parallel_norm = _move_norm_to_cuda(no_tensor_parallel_norm) zero_sharded_norm = _move_norm_to_cuda(zero_sharded_norm) # Sum across all model-parallel GPUs. if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0: dist.all_reduce(tensor_parallel_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR)) # Sum across all zero sharded GPUs if len(zero_sharded_grads) > 0: dist.all_reduce(zero_sharded_norm, group=gpc.get_group(ParallelMode.DATA)) no_tensor_parallel_norm += zero_sharded_norm total_norm = tensor_parallel_norm + no_tensor_parallel_norm if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1: dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE)) total_norm = total_norm**(1.0 / norm_type) if torch.is_tensor(total_norm): total_norm = total_norm.item() # Scale. clip_coeff = max_norm / (total_norm + 1.0e-6) if clip_coeff < 1.0: if enable_cuda_kernels: grads = [p.grad.detach() for p in params] dummy_overflow_buf = torch.cuda.IntTensor([0]) multi_tensor_applier(colossal_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff) else: for p in params: p.grad.detach().mul_(clip_coeff) return total_norm def count_zeros_fp32(parameters): if isinstance(parameters, torch.Tensor): parameters = [parameters] # Filter parameters based on: # - grad should not be none # - parameter should not be shared # - should not be a replica due to tensor model parallelism total_num_zeros = 0.0 for param in parameters: grad_not_none = param.grad is not None is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) if grad_not_none and is_not_tp_duplicate: grad = param.grad.detach() num_zeros = grad.numel() - torch.count_nonzero(grad) total_num_zeros = num_zeros + total_num_zeros total_num_zeros = torch.IntTensor([int(total_num_zeros)]).cuda() # Sum across all model-parallel GPUs. ops = [] ops.append( dist.all_reduce(total_num_zeros, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR), async_op=True)) if gpc.is_initialized(ParallelMode.PIPELINE): ops.append( dist.all_reduce(total_num_zeros, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE), async_op=True)) for req in ops: req.wait() total_num_zeros = total_num_zeros.item() return total_num_zeros def copy_tensor_parallel_attributes(src_tensor, dst_tensor): for attr in TENSOR_PARALLEL_ATTRIBUTES: if hasattr(src_tensor, attr): val = getattr(src_tensor, attr) setattr(dst_tensor, attr, val) def param_is_not_tensor_parallel_duplicate(param): return (hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL)) or (gpc.get_local_rank( ParallelMode.TENSOR) == 0) @contextmanager def switch_virtual_pipeline_parallel_rank(rank): prev_rank = gpc.virtual_pipeline_parallel_rank try: gpc.set_virtual_pipeline_parallel_rank(rank) yield finally: gpc.set_virtual_pipeline_parallel_rank(prev_rank) def disposable(func: Callable) -> Callable: executed = False @functools.wraps(func) def wrapper(*args, **kwargs): nonlocal executed if not executed: executed = True return func(*args, **kwargs) return wrapper