ColossalAI/colossalai/shardformer/layer/_operation.py

992 lines
37 KiB
Python

import torch
import torch.distributed as dist
import torch.nn.functional as F
try:
import fused_mix_prec_layer_norm_cuda
except:
fused_mix_prec_layer_norm_cuda = None
try:
import fused_weight_gradient_mlp_cuda
_grad_accum_fusion_available = True
except ImportError:
_grad_accum_fusion_available = False
class FusedLayerNormAffineFunction1D(torch.autograd.Function):
r"""Layernorm
Args:
input: input matrix.
weight: weight matrix.
bias: bias matrix.
normalized_shape: input shape from an expected input of size.
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps: a value added to the denominator for numerical stability
"""
@staticmethod
def forward(ctx, input, weight, bias, normalized_shape, eps):
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
weight_ = weight.contiguous()
bias_ = bias.contiguous()
output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
input_, ctx.normalized_shape, weight_, bias_, ctx.eps
)
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
return output
@staticmethod
def backward(ctx, grad_output):
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
grad_input, grad_weight, grad_bias = fused_mix_prec_layer_norm_cuda.backward_affine(
grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, weight_, bias_, ctx.eps
)
return grad_input, grad_weight, grad_bias, None, None
class MatmulWithAsyncCommunication(torch.autograd.Function):
"""
Linear layer execution with asynchronous communication in backprop.
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_allreduce = async_grad_allreduce
output = torch.matmul(input_, weight)
if bias is not None:
output = output + bias
return output
@staticmethod
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
weight = weight.view(weight.shape)
if bias is not None:
bias = bias.view(bias.shape)
total_input = input
grad_input = grad_output.matmul(weight.T)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
grad_weight = total_input.t().matmul(grad_output)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_allreduce:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None, None
class LinearWithAsyncCommunication(torch.autograd.Function):
"""
Linear layer execution with asynchronous communication in backprop.
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_allreduce = async_grad_allreduce
if bias is not None:
output = F.linear(input_, weight, bias)
else:
output = F.linear(input_, weight)
return output
@staticmethod
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
if use_bias:
bias.view(bias.shape)
total_input = input
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
else:
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_allreduce:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None, None
def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False):
# currently only support one single tensor as output
group_size = dist.get_world_size(process_group)
cur_rank = dist.get_rank(process_group)
# output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)]
# initialization of ring communication
recv_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0
send_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1
rank_map = list(dist.get_process_group_ranks(process_group))
recv_rank = rank_map[recv_rank]
send_rank = rank_map[send_rank]
recv_tensors = {}
send_tensors = {}
for k, v in input_to_gather.items():
recv_tensors[k] = torch.empty_like(v)
send_tensors[k] = v.clone()
def communicate_step():
comm_ops = []
for k in recv_tensors:
comm_ops.append(dist.P2POp(dist.irecv, recv_tensors[k], recv_rank, group=process_group))
comm_ops.append(dist.P2POp(dist.isend, send_tensors[k], send_rank, group=process_group))
return dist.batch_isend_irecv(comm_ops)
def switch_step():
for k in recv_tensors:
send_tensors[k], recv_tensors[k] = recv_tensors[k], send_tensors[k]
output_tensors = []
handles = communicate_step()
# first round: special case, retrive from local tensor
output_tensors.append(func(**input_to_gather, **input_local))
for i in range(group_size - 2):
for handle in handles:
handle.wait()
switch_step()
handles = communicate_step()
# actual computation
output_tensors.append(func(**send_tensors, **input_local))
# final round: special case, no need to send/recv again
for handle in handles:
handle.wait()
output_tensors.append(func(**recv_tensors, **input_local))
return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim)
class _GatherForwardReduceScatterBackward(torch.autograd.Function):
"""Gather input from sequence parallel in forward and reduce-scatter gradient in backward
Args:
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward.
"""
@staticmethod
def forward(ctx, input_, process_group, dim):
ctx.process_group = process_group
ctx.dim = dim
return _gather(input_, dim, process_group)
@staticmethod
def backward(ctx, grad_output):
dim = ctx.dim
process_group = ctx.process_group
# do reduce-scatter
new_shape = list(grad_output.shape)
assert (
new_shape[dim] % dist.get_world_size(process_group) == 0
), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
grad_list = [
item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device)
dist.reduce_scatter(output, grad_list, group=process_group)
return output, None, None
class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""Gather input from sequence parallel in forward and reduce-scatter gradient in backward
Args:
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
overlap (`bool`): Whether to overlap the all_gather op and gradient calculate in backward.
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True, ring=False):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim
ctx.overlap = overlap
if ring is True:
input_to_gather = {"input": input_}
input_local = {"weight": weight}
output = _ring_as_gather(
F.linear,
input_to_gather=input_to_gather,
input_local=input_local,
process_group=process_group,
)
if bias is not None:
output += bias
else:
input_parallel = _gather(input_, dim, process_group)
if bias is not None:
output = F.linear(input_parallel, weight, bias)
else:
output = F.linear(input_parallel, weight)
return output
@staticmethod
def backward(ctx, grad_output):
input_, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
overlap = ctx.overlap
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
if use_bias:
bias = bias.view(bias.shape)
if not overlap:
input_parallel = _gather(input_, dim, process_group)
total_input = input_parallel
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if ctx.async_grad_reduce_scatter:
# Asynchronous reduce-scatter
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(
input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
else:
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_reduce_scatter:
handle.wait()
else:
input_ = input_.contiguous()
world_size = dist.get_world_size(process_group)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
# do all gather in is async way
gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
# calculate gradient and prepare data asynchronously with all-gather
# calculate
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
grad_bias = grad_output.sum(dim=0) if use_bias else None
# prepare data
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
# wait until all-gather finished
gather_handle.wait()
# do reduce-scatter in async way
reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
# calculate gradient
if len(input_parallel.shape) > 2:
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(input_parallel, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(input_parallel, grad_output, grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(input_parallel)
else:
grad_weight = grad_output.t().matmul(input_parallel)
# grad_weight = grad_output.t().matmul(input_parallel)
# wait until reduce-scatter finished
reducescatter_handle.wait()
return output, grad_weight, grad_bias, None, None, None, None, None
def _ring_as_reducescatter(
func, input_to_reducescatter=None, input_local=None, process_group=None, reducescatter_dim=1
):
# currently only support one single tensor as output
group_size = dist.get_world_size(process_group)
cur_rank = dist.get_rank(process_group)
# initialization of ring communication
recv_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1
send_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0
rank_map = list(dist.get_process_group_ranks(process_group))
recv_rank = rank_map[recv_rank]
send_rank = rank_map[send_rank]
input_tensors = []
for _ in range(group_size):
input_tensors.append({})
for k, v in input_to_reducescatter.items():
input_shape = v.shape
assert input_shape[reducescatter_dim] % group_size == 0
_input_tensors = list(torch.split(v, input_shape[reducescatter_dim] // group_size, dim=reducescatter_dim))
for i in range(group_size):
input_tensors[i][k] = _input_tensors[i]
input_tensors = input_tensors[cur_rank:] + input_tensors[:cur_rank]
input_tensors.reverse()
output_tensor = func(**input_tensors[0], **input_local)
recv_tensor = torch.empty_like(output_tensor)
send_tensor = output_tensor.clone()
def communicate_step():
recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
return dist.batch_isend_irecv([recv_op, send_op])
handles = communicate_step()
# first round: special case, retrive from local tensor
for i in range(group_size - 2):
# actual computation
output_tensor = func(**input_tensors[i + 1], **input_local)
for handle in handles:
handle.wait()
output_tensor += recv_tensor
tmp_tensor = send_tensor
send_tensor = output_tensor
output_tensor = tmp_tensor
handles = communicate_step()
# final round: special case, no need to send/recv again
output_tensor = func(**input_tensors[-1], **input_local)
for handle in handles:
handle.wait()
output_tensor += recv_tensor
return output_tensor
class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
"""Reduce-scatter input from sequence parallel in forward and gather gradient in backward with ring
Args:
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward.
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, dim, ring):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.dim = dim
if ring is True:
input_to_reducescatter = {"input": input_}
input_local = {"weight": weight}
if bias is not None:
input_to_reducescatter["bias"] = bias
output = _ring_as_reducescatter(
F.linear,
input_to_reducescatter=input_to_reducescatter,
input_local=input_local,
process_group=process_group,
)
else:
if bias is not None:
partial_output = F.linear(input_, weight, bias)
else:
partial_output = F.linear(input_, weight)
output_shape = list(partial_output.shape)
assert (
output_shape[dim] % dist.get_world_size(process_group) == 0
), f"The dimension to split ({output_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
output_shape[dim] = output_shape[dim] // dist.get_world_size(process_group)
output_list = [
item.contiguous() for item in torch.chunk(partial_output, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(output_shape, dtype=partial_output.dtype, device=partial_output.device).contiguous()
dist.reduce_scatter(output, output_list, group=process_group)
return output
@staticmethod
def backward(ctx, grad_output):
input_, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
if use_bias:
bias = bias.view(bias.shape)
grad_output = _gather(grad_output, dim, process_group)
# TODO Need to fully optimize
total_input = input_
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
return grad_input, grad_weight, grad_bias, None, None, None
class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
"""Reduce-scatter input from sequence parallel in forward and gather gradient in backward
Args:
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
"""
@staticmethod
def forward(ctx, input_, process_group, dim):
ctx.dim = dim
ctx.process_group = process_group
# do reduce-scatter
new_shape = list(input_.shape)
assert (
new_shape[dim] % dist.get_world_size(process_group) == 0
), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)]
output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
dist.reduce_scatter(output, input_list, group=process_group)
return output
@staticmethod
def backward(ctx, grad_output):
dim = ctx.dim
process_group = ctx.process_group
return _gather(grad_output, dim, process_group), None, None
class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""
This class is designed for matmul operation with gather forward and reduce-scatter backward.
Args:
input_ (`torch.Tensor`): input matrix.
dim (int): the dimension to perform split and gather
process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim
ctx.overlap = overlap
if ring is True:
input_to_gather = {}
input_local = {}
input_to_gather["input"] = input_
input_local["other"] = weight
output = _ring_as_gather(
torch.matmul,
input_to_gather=input_to_gather,
input_local=input_local,
process_group=process_group,
gather_dim=dim,
)
else:
input_parallel = _gather(input_, dim, process_group)
output = torch.matmul(input_parallel, weight)
if bias is not None:
output = output + bias
return output
@staticmethod
def backward(ctx, grad_output):
input_, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
overlap = ctx.overlap
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
weight = weight.view(weight.shape)
if use_bias:
bias = bias.view(bias.shape)
if not overlap:
input_parallel = _gather(input_, dim, process_group)
total_input = input_parallel
grad_input = grad_output.matmul(weight.T)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if ctx.async_grad_reduce_scatter:
# Asynchronous reduce-scatter
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(
input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
grad_weight = total_input.t().matmul(grad_output)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_reduce_scatter:
handle.wait()
else:
world_size = dist.get_world_size(process_group)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
# do all gather in is async way
gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
# calculate gradient and prepare data asynchronously with all-gather
# calculate
grad_input = grad_output.matmul(weight.T)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
grad_bias = grad_output.sum(dim=0) if use_bias else None
# prepare data
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
# wait until all-gather finished
gather_handle.wait()
# do reduce-scatter in async way
reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
# calculate gradient
if len(input_parallel.shape) > 2:
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
grad_weight = input_parallel.t().matmul(grad_output)
# wait until reduce-scatter finished
reducescatter_handle.wait()
return output, grad_weight, grad_bias, None, None, None, None, None
class _SplitForwardGatherBackward(torch.autograd.Function):
"""
Split the input and keep only the corresponding chuck to the rank.
Args:
input_ (`torch.Tensor`): input matrix.
dim (int): the dimension to perform split and gather
process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication
"""
@staticmethod
def forward(ctx, input_, dim, process_group, grad_scale=None):
ctx.process_group = process_group
ctx.dim = dim
ctx.grad_scale = grad_scale
return _split(input_, dim, process_group)
@staticmethod
def backward(ctx, grad_output):
if ctx.grad_scale is not None:
grad_output = grad_output * ctx.grad_scale
return _gather(grad_output, ctx.dim, ctx.process_group), None, None, None
class _ReduceForward(torch.autograd.Function):
"""
All-reduce the input from the model parallel region.
Args:
input_: input matrix.
parallel_mode: parallel mode.
"""
@staticmethod
def forward(ctx, input_, process_group):
return _reduce(input_, process_group)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
class _ReduceBackward(torch.autograd.Function):
"""
All-reduce the input from the model parallel region.
Args:
input_: input matrix.
parallel_mode: parallel mode.
"""
@staticmethod
def forward(ctx, input_, process_group):
ctx.process_group = process_group
return input_
@staticmethod
def backward(ctx, grad_output):
return _reduce(grad_output, ctx.process_group), None
class _GatherForwardSplitBackward(torch.autograd.Function):
"""Gather the input from model parallel region and concatenate.
Args:
input_: input matrix.
parallel_mode: parallel mode.
dim: dimension
"""
@staticmethod
def forward(ctx, input_, dim, process_group, grad_scale=None):
ctx.process_group = process_group
ctx.dim = dim
ctx.grad_scale = grad_scale
return _gather(input_, dim, process_group)
@staticmethod
def backward(ctx, grad_output):
if ctx.grad_scale is not None:
grad_output = grad_output * ctx.grad_scale
return _split(grad_output, ctx.dim, ctx.process_group), None, None, None
class _AllToAll(torch.autograd.Function):
"""All-to-all communication.
Args:
input_: input matrix
process_group: communication group
scatter_dim: scatter dimension
gather_dim: gather dimension
"""
@staticmethod
def forward(ctx, input_, process_group, scatter_dim, gather_dim):
ctx.process_group = process_group
ctx.scatter_dim = scatter_dim
ctx.gather_dim = gather_dim
world_size = dist.get_world_size(process_group)
bsz, _, _ = input_.shape
# using all_to_all_single when batch size is 1
if bsz == 1:
return _all_to_all_single(input_, world_size, process_group, scatter_dim, gather_dim)
else:
return _all_to_all(input_, world_size, process_group, scatter_dim, gather_dim)
@staticmethod
def backward(ctx, *grad_output):
process_group = ctx.process_group
scatter_dim = ctx.gather_dim
gather_dim = ctx.scatter_dim
return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim)
return (return_grad, None, None, None)
class HookParameter(torch.autograd.Function):
"""In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm"""
@staticmethod
def forward(ctx, input, weight, bias):
ctx.save_for_backward(weight, bias)
output = input
return output
@staticmethod
def backward(ctx, grad_output):
weight, bias = ctx.saved_tensors
if weight is not None:
weight = weight.view(weight.shape)
if bias is not None:
bias = bias.view(bias.shape)
return grad_output, None, None
def hook_parameter_in_backward(input, weight=None, bias=None):
return HookParameter.apply(input, weight, bias)
def _reduce(input_, process_group):
# skip if only one rank involved
if dist.get_world_size(process_group) == 1:
return input_
else:
dist.all_reduce(input_, group=process_group)
return input_
def _split(input_, dim=-1, process_group=None):
# skip if only one rank involved
world_size = dist.get_world_size(process_group)
if world_size == 1:
return input_
# Split along last dimension.
dim_size = input_.size(dim)
assert dim_size % world_size == 0, (
f"The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), "
f"cannot split tensor evenly"
)
tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
rank = dist.get_rank(process_group)
output = tensor_list[rank].clone().contiguous()
return output
def _gather(input_, dim=-1, process_group=None):
# skip if only one rank involved
world_size = dist.get_world_size(process_group)
if world_size == 1:
return input_
# all gather
input_ = input_.contiguous()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
torch.distributed.all_gather(tensor_list, input_, group=process_group)
# concat
output = torch.cat(tensor_list, dim=dim).contiguous()
return output
def _reduce_scatter(input_, dim=1, process_group=None):
"""Do reduce-scatter operation.
Args:
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
dim (int): The dimension to perform reduce-scatter.
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
"""
world_size = dist.get_world_size(process_group)
if world_size == 1:
return input_
# reduce-scatter
new_shape = list(input_.shape)
assert (
new_shape[dim] % dist.get_world_size(process_group) == 0
), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
new_shape[dim] = new_shape[dim] // world_size
output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
dist.reduce_scatter(output, input_, group=process_group)
return output
def _all_to_all(input_, world_size, group, scatter_dim, gather_dim):
input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
dist.all_to_all(output_list, input_list, group=group)
return torch.cat(output_list, dim=gather_dim).contiguous()
def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim):
inp_shape = list(input_.shape)
inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size
if scatter_dim < 2:
input_t = input_.reshape([seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :]).contiguous()
else:
input_t = (
input_.reshape([-1, seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :])
.transpose(0, 1)
.contiguous()
)
output = torch.empty_like(input_t)
dist.all_to_all_single(output, input_t, group=group)
if scatter_dim < 2:
output = output.transpose(0, 1).contiguous()
return output.reshape(
inp_shape[:gather_dim]
+ [
inp_shape[gather_dim] * seq_world_size,
]
+ inp_shape[gather_dim + 1 :]
).contiguous()
def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
return MatmulWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
def linear_gather_forward_reducescatter_backward(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
):
return _LinearWithGatherForwardReduceScatterBackward.apply(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
)
def gather_forward_reducescatter_backward(input_, process_group, dim):
return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim)
def reducescatter_forward_gather_backward(input_, process_group, dim):
return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim)
def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1, ring=False):
return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim, ring)
def matmul_gather_forward_reducescatter_backward(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
):
return _MatmulWithGatherForwardReduceScatterBackward.apply(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
)
def gather_forward_split_backward(input_, dim, process_group, grad_scale=None):
return _GatherForwardSplitBackward.apply(input_, dim, process_group, grad_scale)
def split_forward_gather_backward(input_, dim, process_group, grad_scale=None):
return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale)
def reduce_forward(input_, process_group):
return _ReduceForward.apply(input_, process_group)
def reduce_backward(input_, process_group):
return _ReduceBackward.apply(input_, process_group)
def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)