@@ -102,7 +102,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
             grad_output = grad_output.view(-1, grad_output.shape[-1])
             total_input = total_input.view(-1, total_input.shape[-1])

-        if ctx.async_grad_allreduce and fp8_communication:
+        if fp8_communication or not ctx.async_grad_allreduce:
             _reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication, fp8_format="e5m2")
         elif ctx.async_grad_allreduce:
             # Asynchronous all-reduce
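For readers of this hunk: the rewritten condition routes the gradient all-reduce through the synchronous `_reduce` helper whenever fp8 communication is enabled or async all-reduce is disabled; only the remaining case launches the asynchronous all-reduce that overlaps with the weight-gradient GEMM. A minimal sketch of that overlap pattern with plain `torch.distributed` calls, illustrative only and assuming an already-initialized process group (not the class's actual backward):

```python
import torch
import torch.distributed as dist


def matmul_backward_sketch(grad_output, total_input, weight, process_group, async_grad_allreduce):
    """Illustrative overlap of the grad_input all-reduce with the weight-grad GEMM."""
    grad_input = grad_output.matmul(weight.T)

    handle = None
    if async_grad_allreduce:
        # Launch early so the collective overlaps with the GEMM below.
        handle = dist.all_reduce(grad_input, group=process_group, async_op=True)
    else:
        # Synchronous path (what the fp8 branch above does via _reduce).
        dist.all_reduce(grad_input, group=process_group)

    grad_output_2d = grad_output.reshape(-1, grad_output.shape[-1])
    total_input_2d = total_input.reshape(-1, total_input.shape[-1])
    grad_weight = total_input_2d.t().matmul(grad_output_2d)

    if handle is not None:
        handle.wait()
    return grad_input, grad_weight
```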
@@ -216,10 +216,12 @@ def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=
         for k in recv_tensors:
             send_tensors[k], recv_tensors[k] = recv_tensors[k], send_tensors[k]

+    input_tensors = []
     output_tensors = []

     handles = communicate_step()
     # first round: special case, retrive from local tensor
+    input_tensors.append(input_to_gather)
     output_tensors.append(func(**input_to_gather, **input_local))
     for i in range(group_size - 2):
         for handle in handles:
@@ -230,14 +232,25 @@ def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=
         handles = communicate_step()

         # actual computation
+        input_tensors.append(send_tensors)
         output_tensors.append(func(**send_tensors, **input_local))

     # final round: special case, no need to send/recv again
     for handle in handles:
         handle.wait()
+    input_tensors.append(send_tensors)
     output_tensors.append(func(**recv_tensors, **input_local))

-    return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim)
+    gathered_input = {}
+    for k in input_to_gather:
+        input_shards = [d[k] for d in input_tensors[group_size - cur_rank :] + input_tensors[: group_size - cur_rank]]
+        gathered_input[k] = torch.cat(input_shards, dim=gather_dim)
+
+    gathered_output = torch.cat(
+        output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim
+    )
+
+    return gathered_output, gathered_input


 class _GatherForwardReduceScatterBackward(torch.autograd.Function):
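The rotation `[group_size - cur_rank :] + [: group_size - cur_rank]`, now applied to the recorded input shards as well as the per-step outputs, puts the ring's results back into plain rank order before `torch.cat`. A single-process illustration of why that slicing works (no distributed setup; the shard labels are invented for the example):

```python
# Each ring step on rank `cur_rank` handles the shard owned by rank (cur_rank + i) % group_size.
group_size = 4

for cur_rank in range(group_size):
    # Order in which results are appended on this rank.
    step_results = [f"shard{(cur_rank + i) % group_size}" for i in range(group_size)]

    # The rotation used before torch.cat restores plain rank order 0..group_size-1.
    reordered = step_results[group_size - cur_rank :] + step_results[: group_size - cur_rank]
    assert reordered == [f"shard{k}" for k in range(group_size)]

print("rotation restores rank order on every rank")
```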
@@ -293,29 +306,30 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True, ring=False):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
         ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
         ctx.dim = dim
-        ctx.overlap = overlap

         if ring is True:
             input_to_gather = {"input": input_}
             input_local = {"weight": weight}

-            output = _ring_as_gather(
+            output, input_dict = _ring_as_gather(
                 F.linear,
                 input_to_gather=input_to_gather,
                 input_local=input_local,
                 process_group=process_group,
             )
+            ctx.gathered_input = input_dict["input"]

             if bias is not None:
                 output += bias
         else:
             input_parallel = _gather(input_, dim, process_group)
+            ctx.gathered_input = input_parallel
             if bias is not None:
                 output = F.linear(input_parallel, weight, bias)
             else:
@@ -329,14 +343,12 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
         use_bias = ctx.use_bias
         dim = ctx.dim
         process_group = ctx.process_group
-        overlap = ctx.overlap

         # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
         if use_bias:
             bias = bias.view(bias.shape)

-        if not overlap:
-            input_parallel = _gather(input_, dim, process_group)
+        input_parallel = ctx.gathered_input

         total_input = input_parallel
         grad_input = grad_output.matmul(weight)
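The two hunks above implement a cache-and-reuse pattern: forward gathers (or ring-gathers) the input once, stashes it as `ctx.gathered_input`, and backward reads it back instead of issuing a second all-gather. A self-contained, single-process sketch of the same idiom; the `* 2` transform is a made-up stand-in for the gather, everything else is plain `torch.autograd.Function` machinery:

```python
import torch


class CachedGatherLinear(torch.autograd.Function):
    """Sketch of the ctx-caching idiom: compute an expensive transform once in
    forward, stash it on ctx, and reuse it in backward instead of recomputing."""

    @staticmethod
    def forward(ctx, input_, weight):
        transformed = input_ * 2  # hypothetical stand-in for _gather / _ring_as_gather
        ctx.save_for_backward(weight)
        ctx.gathered_input = transformed  # mirrors ctx.gathered_input in the patch
        return transformed.matmul(weight.t())

    @staticmethod
    def backward(ctx, grad_output):
        (weight,) = ctx.saved_tensors
        transformed = ctx.gathered_input  # reused, no second "gather"
        grad_input = grad_output.matmul(weight) * 2  # chain rule through the *2 transform
        grad_weight = grad_output.t().matmul(transformed)
        return grad_input, grad_weight


x = torch.randn(4, 8, requires_grad=True, dtype=torch.double)
w = torch.randn(6, 8, requires_grad=True, dtype=torch.double)
assert torch.autograd.gradcheck(CachedGatherLinear.apply, (x, w))
```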
@@ -351,9 +363,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
             input_list = [
                 item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
             ]
-            output = torch.empty(
-                input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
-            ).contiguous()
+            output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
             handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
             # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
             # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
@@ -376,53 +386,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
         if ctx.async_grad_reduce_scatter:
             handle.wait()

-        else:
-            input_ = input_.contiguous()
-            world_size = dist.get_world_size(process_group)
-            tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
-
-            # do all gather in is async way
-            gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
-            # calculate gradient and prepare data asynchronously with all-gather
-            # calculate
-            grad_input = grad_output.matmul(weight)
-            grad_output = grad_output.contiguous()
-            # Convert the tensor shapes to 2D for execution compatibility
-            if len(grad_output.shape) > 2:
-                grad_output = grad_output.view(-1, grad_output.shape[-1])
-            grad_bias = grad_output.sum(dim=0) if use_bias else None
-            # prepare data
-            input_list = [
-                item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
-            ]
-            output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
-            # wait until all-gather finished
-            gather_handle.wait()
-
-            # do reduce-scatter in async way
-            reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
-            input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
-            # calculate gradient
-            if len(input_parallel.shape) > 2:
-                input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
-
-            if _grad_accum_fusion_available and weight.grad is not None:
-                grad = weight.grad
-                if grad.dtype == torch.float32:
-                    fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(input_parallel, grad_output, grad)
-                    grad_weight = None
-                elif grad.dtype == torch.float16:
-                    fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(input_parallel, grad_output, grad)
-                    grad_weight = None
-                else:
-                    grad_weight = grad_output.t().matmul(input_parallel)
-            else:
-                grad_weight = grad_output.t().matmul(input_parallel)
-                # grad_weight = grad_output.t().matmul(input_parallel)
-            # wait until reduce-scatter finished
-            reducescatter_handle.wait()
-
-        return output, grad_weight, grad_bias, None, None, None, None, None
+        return output, grad_weight, grad_bias, None, None, None, None


 def _ring_as_reducescatter(
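The deleted `else:` branch was the manual overlap path: it hid an asynchronous all-gather of the input behind the input-gradient GEMM, then hid the asynchronous reduce-scatter of `grad_input` behind the weight-gradient GEMM. With the patch always reusing `ctx.gathered_input`, the pattern only survives here for reference; a condensed sketch of it, assuming an initialized process group and shapes evenly divisible along `dim`:

```python
import torch
import torch.distributed as dist


def backward_with_overlap(input_, grad_output, weight, process_group, dim=1):
    """Condensed sketch of the removed branch: overlap collectives with the two GEMMs."""
    world_size = dist.get_world_size(process_group)

    # 1) async all-gather of the local input shard
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    gather_work = dist.all_gather(tensor_list, input_.contiguous(), group=process_group, async_op=True)

    # 2) input-gradient GEMM runs while the gather is in flight
    grad_input = grad_output.matmul(weight)
    input_list = [chunk.contiguous() for chunk in torch.chunk(grad_input, world_size, dim=dim)]
    output = torch.empty_like(input_)
    gather_work.wait()

    # 3) async reduce-scatter of the grad_input shards
    rs_work = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)

    # 4) weight-gradient GEMM runs while the reduce-scatter is in flight
    input_parallel = torch.cat(tensor_list, dim=dim).reshape(-1, input_.shape[-1])
    grad_output_2d = grad_output.reshape(-1, grad_output.shape[-1])
    grad_weight = grad_output_2d.t().matmul(input_parallel)

    rs_work.wait()
    return output, grad_weight
```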
@@ -553,7 +517,7 @@ class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
         # Convert the tensor shapes to 2D for execution compatibility
         if len(grad_output.shape) > 2:
             grad_output = grad_output.view(-1, grad_output.shape[-1])
-            total_input = total_input.view(-1, total_input.shape[-1])
+            total_input = total_input.reshape(-1, total_input.shape[-1])

         grad_weight = grad_output.t().matmul(total_input)
         grad_bias = grad_output.sum(dim=0) if use_bias else None
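The `.view` to `.reshape` swap above makes the flatten tolerant of a non-contiguous `total_input` (for example after a transpose or a `torch.cat`): `view` demands a layout-compatible contiguous tensor, while `reshape` falls back to a copy when needed. A standalone demonstration:

```python
import torch

t = torch.arange(6).reshape(2, 3).t()  # the transpose makes the tensor non-contiguous
assert not t.is_contiguous()

try:
    t.view(-1)
except RuntimeError as e:
    print("view failed:", e)

flat = t.reshape(-1)  # reshape copies when a zero-copy view is impossible
print(flat)  # tensor([0, 3, 1, 4, 2, 5])
```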
@@ -611,34 +575,30 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """

     @staticmethod
-    def forward(
-        ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
-    ):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
         ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
         ctx.dim = dim
-        ctx.overlap = overlap
         ctx.fp8_communication = fp8_communication

         if ring is True:
-            input_to_gather = {}
-            input_local = {}
-            input_to_gather["input"] = input_
-            input_local["other"] = weight
+            input_to_gather = {"input": input_}
+            input_local = {"other": weight}

-            output = _ring_as_gather(
+            output, input_dict = _ring_as_gather(
                 torch.matmul,
                 input_to_gather=input_to_gather,
                 input_local=input_local,
                 process_group=process_group,
                 gather_dim=dim,
             )
+            ctx.gathered_input = input_dict["input"]

         else:
             input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3")

+            ctx.gathered_input = input_parallel
             output = torch.matmul(input_parallel, weight)

         if bias is not None:
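The forward gather here communicates activations in the e4m3 fp8 format, while the gradient-side communication elsewhere in this file uses e5m2: e4m3 trades exponent range for an extra mantissa bit, e5m2 does the opposite, which suits the wider dynamic range of gradients. The trade-off is visible directly from the dtype metadata in PyTorch builds that expose float8 dtypes (assumed to be 2.1 or newer):

```python
import torch  # requires a PyTorch build with float8 dtypes (>= 2.1)

# e4m3: more mantissa bits -> finer precision, smaller range (activations/weights)
# e5m2: more exponent bits -> coarser precision, wider range (gradients)
print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0
print(torch.finfo(torch.float8_e5m2).max)    # 57344.0
```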
@@ -651,16 +611,13 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
         use_bias = ctx.use_bias
         dim = ctx.dim
         process_group = ctx.process_group
-        overlap = ctx.overlap
         fp8_communication = ctx.fp8_communication

         # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
         weight = weight.view(weight.shape)
         if use_bias:
             bias = bias.view(bias.shape)

-        if not overlap:
-            input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e5m2")
+        input_parallel = ctx.gathered_input

         total_input = input_parallel
         grad_input = grad_output.matmul(weight.T)
@@ -675,9 +632,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
             input_list = [
                 item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
             ]
-            output = torch.empty(
-                input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
-            ).contiguous()
+            output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
             handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
             # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
             # all-reduce scheduled first and have GPU resources allocated
@@ -688,39 +643,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
         if ctx.async_grad_reduce_scatter:
             handle.wait()

-        else:
-            world_size = dist.get_world_size(process_group)
-            tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
-
-            # do all gather in is async way
-            gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
-            # calculate gradient and prepare data asynchronously with all-gather
-            # calculate
-            grad_input = grad_output.matmul(weight.T)
-            grad_output = grad_output.contiguous()
-            # Convert the tensor shapes to 2D for execution compatibility
-            if len(grad_output.shape) > 2:
-                grad_output = grad_output.view(-1, grad_output.shape[-1])
-            grad_bias = grad_output.sum(dim=0) if use_bias else None
-            # prepare data
-            input_list = [
-                item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
-            ]
-            output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
-            # wait until all-gather finished
-            gather_handle.wait()
-
-            # do reduce-scatter in async way
-            reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
-            input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
-            # calculate gradient
-            if len(input_parallel.shape) > 2:
-                input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
-            grad_weight = input_parallel.t().matmul(grad_output)
-            # wait until reduce-scatter finished
-            reducescatter_handle.wait()
-
-        return output, grad_weight, grad_bias, None, None, None, None, None, None
+        return output, grad_weight, grad_bias, None, None, None, None, None


 class _SplitForwardGatherBackward(torch.autograd.Function):
@@ -1050,10 +973,10 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre


 def linear_gather_forward_reducescatter_backward(
-    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
+    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False
 ):
     return _LinearWithGatherForwardReduceScatterBackward.apply(
-        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
+        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring
     )

@@ -1070,10 +993,10 @@ def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, proc


 def matmul_gather_forward_reducescatter_backward(
-    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False, fp8_communication=False
+    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False, fp8_communication=False
 ):
     return _MatmulWithGatherForwardReduceScatterBackward.apply(
-        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
+        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication
     )

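Call sites of the two wrapper functions change accordingly: the positional `overlap` flag disappears. A hedged sketch of an updated caller; the import path and surrounding names such as `hidden_states` and `tp_group` are assumptions, not taken from this patch:

```python
# Assumed import path; adjust to wherever these wrappers live in your tree.
from colossalai.shardformer.layer._operation import matmul_gather_forward_reducescatter_backward


def row_parallel_matmul(hidden_states, weight, bias, tp_group):
    # Before this patch the call carried an extra positional `overlap` argument:
    #     matmul_gather_forward_reducescatter_backward(
    #         hidden_states, weight, bias, tp_group, True, 1, overlap, ring=False, fp8_communication=False)
    # After this patch the flag is simply dropped:
    return matmul_gather_forward_reducescatter_backward(
        hidden_states, weight, bias, tp_group, True, 1, ring=False, fp8_communication=False
    )
```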