import torch import torch.distributed as dist import torch.nn.functional as F try: import fused_mix_prec_layer_norm_cuda except: fused_mix_prec_layer_norm_cuda = None class FusedLayerNormAffineFunction1D(torch.autograd.Function): r"""Layernorm Args: input: input matrix. weight: weight matrix. bias: bias matrix. normalized_shape: input shape from an expected input of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]` If a single integer is used, it is treated as a singleton list, and this module will normalize over the last dimension which is expected to be of that specific size. eps: a value added to the denominator for numerical stability """ @staticmethod def forward(ctx, input, weight, bias, normalized_shape, eps): ctx.normalized_shape = normalized_shape ctx.eps = eps input_ = input.contiguous() weight_ = weight.contiguous() bias_ = bias.contiguous() output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_, bias_, ctx.eps) ctx.save_for_backward(input_, weight_, bias_, mean, invvar) return output @staticmethod def backward(ctx, grad_output): input_, weight_, bias_, mean, invvar = ctx.saved_tensors grad_input = grad_weight = grad_bias = None grad_input, grad_weight, grad_bias \ = fused_mix_prec_layer_norm_cuda.backward_affine( grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, weight_, bias_, ctx.eps) return grad_input, grad_weight, grad_bias, None, None class MatmulWithAsyncCommunication(torch.autograd.Function): """ Linear layer execution with asynchronous communication in backprop. """ @staticmethod def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce): ctx.save_for_backward(input_, weight) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_allreduce = async_grad_allreduce output = torch.matmul(input_, weight) if bias is not None: output = output + bias return output @staticmethod def backward(ctx, grad_output): input, weight = ctx.saved_tensors use_bias = ctx.use_bias total_input = input grad_input = grad_output.matmul(weight.T) grad_output = grad_output.contiguous() # Convert the tensor shapes to 2D for execution compatibility if len(grad_output.shape) > 2: grad_output = grad_output.view(-1, grad_output.shape[-1]) total_input = total_input.view(-1, total_input.shape[-1]) if ctx.async_grad_allreduce: # Asynchronous all-reduce handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) # Delay the start of weight gradient computation shortly (3us) to have # all-reduce scheduled first and have GPU resources allocated _ = torch.empty(1, device=grad_output.device) + 1 grad_weight = total_input.t().matmul(grad_output) grad_bias = grad_output.sum(dim=0) if use_bias else None if ctx.async_grad_allreduce: handle.wait() return grad_input, grad_weight, grad_bias, None, None, None class LinearWithAsyncCommunication(torch.autograd.Function): """ Linear layer execution with asynchronous communication in backprop. """ @staticmethod def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce): ctx.save_for_backward(input_, weight) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_allreduce = async_grad_allreduce if bias is not None: output = F.linear(input_, weight, bias) else: output = F.linear(input_, weight) return output @staticmethod def backward(ctx, grad_output): input, weight = ctx.saved_tensors use_bias = ctx.use_bias total_input = input grad_input = grad_output.matmul(weight) grad_output = grad_output.contiguous() # Convert the tensor shapes to 2D for execution compatibility if len(grad_output.shape) > 2: grad_output = grad_output.view(-1, grad_output.shape[-1]) total_input = total_input.view(-1, total_input.shape[-1]) if ctx.async_grad_allreduce: # Asynchronous all-reduce handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) # Delay the start of weight gradient computation shortly (3us) to have # all-reduce scheduled first and have GPU resources allocated _ = torch.empty(1, device=grad_output.device) + 1 grad_weight = grad_output.t().matmul(total_input) grad_bias = grad_output.sum(dim=0) if use_bias else None if ctx.async_grad_allreduce: handle.wait() return grad_input, grad_weight, grad_bias, None, None, None class _SplitForwardGatherBackward(torch.autograd.Function): """ Split the input and keep only the corresponding chuck to the rank. Args: input_ (`torch.Tensor`): input matrix. dim (int): the dimension to perform split and gather process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication """ @staticmethod def forward(ctx, input_, dim, process_group): ctx.process_group = process_group ctx.dim = dim return _split(input_, dim, process_group) @staticmethod def backward(ctx, grad_output): return _gather(grad_output, ctx.dim, ctx.process_group), None, None class _ReduceForward(torch.autograd.Function): """ All-reduce the input from the model parallel region. Args: input_: input matrix. parallel_mode: parallel mode. """ @staticmethod def forward(ctx, input_, process_group): return _reduce(input_, process_group) @staticmethod def backward(ctx, grad_output): return grad_output, None class _ReduceBackward(torch.autograd.Function): """ All-reduce the input from the model parallel region. Args: input_: input matrix. parallel_mode: parallel mode. """ @staticmethod def forward(ctx, input_, process_group): ctx.process_group = process_group return input_ @staticmethod def backward(ctx, grad_output): return _reduce(grad_output, ctx.process_group), None def _reduce(input_, process_group): # skip if only one rank involved if dist.get_world_size(process_group) == 1: return input_ else: dist.all_reduce(input_, group=process_group) return input_ def _split(input_, dim=-1, process_group=None): # skip if only one rank involved world_size = dist.get_world_size(process_group) if world_size == 1: return input_ # Split along last dimension. dim_size = input_.size(dim) assert dim_size % world_size == 0, \ f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \ f'cannot split tensor evenly' tensor_list = torch.split(input_, dim_size // world_size, dim=dim) rank = dist.get_rank(process_group) output = tensor_list[rank].contiguous() return output def _gather(input_, dim=-1, process_group=None): # skip if only one rank involved world_size = dist.get_world_size(process_group) if world_size == 1: return input_ # all gather rank = dist.get_rank(process_group) tensor_list = [torch.empty_like(input_) for _ in range(world_size)] tensor_list[rank] = input_ torch.distributed.all_gather(tensor_list, input_, group=process_group) # concat output = torch.cat(tensor_list, dim=dim).contiguous() return output class _GatherForwardSplitBackward(torch.autograd.Function): """Gather the input from model parallel region and concatenate. Args: input_: input matrix. parallel_mode: parallel mode. dim: dimension """ @staticmethod def forward(ctx, input_, dim, process_group): ctx.process_group = process_group ctx.dim = dim return _gather(input_, dim, process_group) @staticmethod def backward(ctx, grad_output): return _split(grad_output, ctx.dim, ctx.process_group), None, None def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): return MatmulWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) def gather_forward_split_backward(input_, dim, process_group): return _GatherForwardSplitBackward.apply(input_, dim, process_group) def split_forward_gather_backward(input_, dim, process_group): return _SplitForwardGatherBackward.apply(input_, dim, process_group) def reduce_forward(input_, process_group): return _ReduceForward.apply(input_, process_group) def reduce_backward(input_, process_group): return _ReduceBackward.apply(input_, process_group)