@@ -14,6 +14,8 @@ try:
 except ImportError:
     _grad_accum_fusion_available = False
 
+from colossalai.quantization.fp8 import all_reduce_fp8, cast_from_fp8, cast_to_fp8, reduce_scatter_fp8
+
 
 class FusedLayerNormAffineFunction1D(torch.autograd.Function):
     r"""Layernorm
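The helpers imported above are the building blocks for every change in this diff. For reference, here is a minimal, self-contained sketch of the cast convention the patch relies on (an fp8 payload plus a per-tensor scale). The names cast_to_fp8_sketch / cast_from_fp8_sketch are illustrative stand-ins only; the real implementations in colossalai.quantization.fp8 may choose scales and formats differently.

# Illustrative stand-ins for cast_to_fp8 / cast_from_fp8, assuming a torch
# build with float8 dtypes (torch >= 2.1). Not the library's code.
import torch


def cast_to_fp8_sketch(t: torch.Tensor, fp8_format: str = "e4m3"):
    # per-tensor scaling: map the largest magnitude onto the fp8 dynamic range
    fp8_dtype = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    scale = t.abs().max().float().clamp(min=1e-12) / torch.finfo(fp8_dtype).max
    return (t / scale).to(fp8_dtype), scale


def cast_from_fp8_sketch(payload: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # undo the scaling and return to the original precision
    return (payload.to(torch.float32) * scale).to(dtype)


if __name__ == "__main__":
    x = torch.randn(4, 8, dtype=torch.bfloat16)
    payload, scale = cast_to_fp8_sketch(x, "e4m3")
    x_rt = cast_from_fp8_sketch(payload, scale, x.dtype)
    print((x - x_rt).abs().max())  # small but non-zero: fp8 is lossy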
@@ -59,11 +61,12 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
         ctx.async_grad_allreduce = async_grad_allreduce
+        ctx.fp8_communication = fp8_communication
 
         output = torch.matmul(input_, weight)
@@ -76,6 +79,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
     def backward(ctx, grad_output):
         input, weight, bias = ctx.saved_tensors
         use_bias = ctx.use_bias
+        fp8_communication = ctx.fp8_communication
 
         # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
         weight = weight.view(weight.shape)
@@ -90,7 +94,9 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
         grad_output = grad_output.view(-1, grad_output.shape[-1])
         total_input = total_input.view(-1, total_input.shape[-1])
 
-        if ctx.async_grad_allreduce:
+        if fp8_communication and ctx.async_grad_allreduce:
+            _reduce(grad_input, process_group=ctx.process_group, fp8_communication=fp8_communication)
+        elif ctx.async_grad_allreduce:
             # Asynchronous all-reduce
             handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
             # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
@@ -99,10 +105,10 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
         grad_weight = total_input.t().matmul(grad_output)
         grad_bias = grad_output.sum(dim=0) if use_bias else None
 
-        if ctx.async_grad_allreduce:
+        if ctx.async_grad_allreduce and not fp8_communication:
             handle.wait()
 
-        return grad_input, grad_weight, grad_bias, None, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None, None
 
 
 class LinearWithAsyncCommunication(torch.autograd.Function):
@@ -111,11 +117,12 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
         ctx.async_grad_allreduce = async_grad_allreduce
+        ctx.fp8_communication = fp8_communication
         if bias is not None:
             output = F.linear(input_, weight, bias)
         else:
@@ -127,6 +134,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
     def backward(ctx, grad_output):
         input, weight, bias = ctx.saved_tensors
         use_bias = ctx.use_bias
+        fp8_communication = ctx.fp8_communication
 
         # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
         if use_bias:
@@ -142,7 +150,10 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
         if ctx.async_grad_allreduce:
             # Asynchronous all-reduce
-            handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
+            if fp8_communication:
+                all_reduce_fp8(grad_input, group=ctx.process_group)
+            else:
+                handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
             # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
             # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
@@ -161,10 +172,10 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
         grad_bias = grad_output.sum(dim=0) if use_bias else None
 
-        if ctx.async_grad_allreduce:
+        if ctx.async_grad_allreduce and not fp8_communication:
             handle.wait()
 
-        return grad_input, grad_weight, grad_bias, None, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None, None
 
 
 def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False):
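Why the handle.wait() guard changes in both backward passes above: the bf16 path launches the all-reduce of grad_input asynchronously and overlaps it with the grad_weight GEMM, while the fp8 path as written here reduces synchronously, so there is no handle to wait on. A minimal sketch of that overlap pattern, assuming an initialized default process group; dist.all_reduce stands in for all_reduce_fp8 in the fp8 branch.

import torch
import torch.distributed as dist


def backward_overlap_sketch(grad_output, total_input, weight, fp8_communication=False):
    # grad wrt the input of the linear layer
    grad_input = grad_output.matmul(weight)
    handle = None
    if fp8_communication:
        # synchronous fp8-style all-reduce: nothing to wait on afterwards
        dist.all_reduce(grad_input)
    else:
        # async all-reduce overlaps with the grad_weight GEMM below
        handle = dist.all_reduce(grad_input, async_op=True)
    grad_weight = total_input.t().matmul(grad_output)
    if handle is not None:
        handle.wait()
    return grad_input, grad_weight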
@@ -232,17 +243,18 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, input_, process_group, dim):
+    def forward(ctx, input_, process_group, dim, fp8_communication=False):
         ctx.process_group = process_group
         ctx.dim = dim
+        ctx.fp8_communication = fp8_communication
 
-        return _gather(input_, dim, process_group)
+        return _gather(input_, dim, process_group, fp8_communication)
 
     @staticmethod
     def backward(ctx, grad_output):
         dim = ctx.dim
         process_group = ctx.process_group
+        fp8_communication = ctx.fp8_communication
         # do reduce-scatter
         new_shape = list(grad_output.shape)
         assert (
@@ -253,9 +265,13 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function):
             item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim)
         ]
         output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device)
-        dist.reduce_scatter(output, grad_list, group=process_group)
 
-        return output, None, None
+        if fp8_communication:
+            reduce_scatter_fp8(output, grad_list, group=process_group)
+        else:
+            dist.reduce_scatter(output, grad_list, group=process_group)
+
+        return output, None, None, None
 
 
 class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
@@ -546,9 +562,10 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, input_, process_group, dim):
+    def forward(ctx, input_, process_group, dim, fp8_communication=False):
         ctx.dim = dim
         ctx.process_group = process_group
+        ctx.fp8_communication = fp8_communication
 
         # do reduce-scatter
         new_shape = list(input_.shape)
@@ -558,7 +575,11 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
         new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
         input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)]
         output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
-        dist.reduce_scatter(output, input_list, group=process_group)
+        if fp8_communication:
+            reduce_scatter_fp8(output, input_list, group=process_group)
+        else:
+            dist.reduce_scatter(output, input_list, group=process_group)
 
         return output
@@ -566,8 +587,9 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
     def backward(ctx, grad_output):
         dim = ctx.dim
         process_group = ctx.process_group
+        fp8_communication = ctx.fp8_communication
 
-        return _gather(grad_output, dim, process_group), None, None
+        return _gather(grad_output, dim, process_group, fp8_communication), None, None, None
 
 
 class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
@@ -582,13 +604,16 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """
 
    @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring):
+    def forward(
+        ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
+    ):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
         ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
         ctx.dim = dim
         ctx.overlap = overlap
+        ctx.fp8_communication = fp8_communication
 
         if ring is True:
             input_to_gather = {}
@@ -605,7 +630,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
             )
 
         else:
-            input_parallel = _gather(input_, dim, process_group)
+            input_parallel = _gather(input_, dim, process_group, fp8_communication)
 
             output = torch.matmul(input_parallel, weight)
@@ -620,6 +645,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
         dim = ctx.dim
         process_group = ctx.process_group
         overlap = ctx.overlap
+        fp8_communication = ctx.fp8_communication
 
         # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
         weight = weight.view(weight.shape)
@@ -627,7 +653,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
             bias = bias.view(bias.shape)
 
         if not overlap:
-            input_parallel = _gather(input_, dim, process_group)
+            input_parallel = _gather(input_, dim, process_group, fp8_communication)
 
             total_input = input_parallel
             grad_input = grad_output.matmul(weight.T)
@@ -687,7 +713,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
             # wait until reduce-scatter finished
             reducescatter_handle.wait()
 
-        return output, grad_weight, grad_bias, None, None, None, None, None
+        return output, grad_weight, grad_bias, None, None, None, None, None, None
 
 
 class _SplitForwardGatherBackward(torch.autograd.Function):
@@ -702,17 +728,20 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, input_, dim, process_group, grad_scale=None):
+    def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_communication=False):
         ctx.process_group = process_group
         ctx.dim = dim
         ctx.grad_scale = grad_scale
+        ctx.fp8_communication = fp8_communication
+
         return _split(input_, dim, process_group)
 
     @staticmethod
     def backward(ctx, grad_output):
         if ctx.grad_scale is not None:
             grad_output = grad_output * ctx.grad_scale
-        return _gather(grad_output, ctx.dim, ctx.process_group), None, None, None
+
+        return _gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication, "e4m3"), None, None, None, None
 
 
 class _ReduceForward(torch.autograd.Function):
@@ -725,12 +754,12 @@ class _ReduceForward(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, input_, process_group):
-        return _reduce(input_, process_group)
+    def forward(ctx, input_, process_group, fp8_communication=False):
+        return _reduce(input_, process_group, fp8_communication)
 
     @staticmethod
     def backward(ctx, grad_output):
-        return grad_output, None
+        return grad_output, None, None
 
 
 class _ReduceBackward(torch.autograd.Function):
@@ -743,13 +772,15 @@ class _ReduceBackward(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, input_, process_group):
+    def forward(ctx, input_, process_group, fp8_communication=False):
         ctx.process_group = process_group
+        ctx.fp8_communication = fp8_communication
         return input_
 
     @staticmethod
     def backward(ctx, grad_output):
-        return _reduce(grad_output, ctx.process_group), None
+        fp8_communication = ctx.fp8_communication
+        return _reduce(grad_output, ctx.process_group, fp8_communication), None, None
 
 
 class _GatherForwardSplitBackward(torch.autograd.Function):
@@ -762,17 +793,18 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, input_, dim, process_group, grad_scale=None):
+    def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_comm=False):
         ctx.process_group = process_group
         ctx.dim = dim
         ctx.grad_scale = grad_scale
-        return _gather(input_, dim, process_group)
+
+        return _gather(input_, dim, process_group, fp8_comm=fp8_comm, fp8_format="e4m3")
 
     @staticmethod
     def backward(ctx, grad_output):
         if ctx.grad_scale is not None:
             grad_output = grad_output * ctx.grad_scale
-        return _split(grad_output, ctx.dim, ctx.process_group), None, None, None
+        return _split(grad_output, ctx.dim, ctx.process_group), None, None, None, None
 
 
 class _AllToAll(torch.autograd.Function):
@@ -786,26 +818,43 @@ class _AllToAll(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
+    def forward(ctx, input_, process_group, scatter_dim, gather_dim, fp8_communication):
         ctx.process_group = process_group
         ctx.scatter_dim = scatter_dim
         ctx.gather_dim = gather_dim
+        ctx.fp8_communication = fp8_communication
         world_size = dist.get_world_size(process_group)
         bsz, _, _ = input_.shape
 
         # using all_to_all_single when batch size is 1
         if bsz == 1:
-            return _all_to_all_single(input_, world_size, process_group, scatter_dim, gather_dim)
+            return _all_to_all_single(
+                input_, world_size, process_group, scatter_dim, gather_dim, fp8_comm=fp8_communication, fp8_format="e5m2"
+            )
         else:
-            return _all_to_all(input_, world_size, process_group, scatter_dim, gather_dim)
+            return _all_to_all(
+                input_, world_size, process_group, scatter_dim, gather_dim, fp8_comm=fp8_communication, fp8_format="e5m2"
+            )
 
     @staticmethod
-    def backward(ctx, *grad_output):
+    def backward(ctx, grad_output):
         process_group = ctx.process_group
         scatter_dim = ctx.gather_dim
         gather_dim = ctx.scatter_dim
-        return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim)
-        return (return_grad, None, None, None)
+        fp8_communication = ctx.fp8_communication
+        world_size = dist.get_world_size(process_group)
+        bsz, _, _ = grad_output.shape
+
+        if bsz == 1:
+            return_grad = _all_to_all_single(
+                grad_output, world_size, process_group, scatter_dim, gather_dim, fp8_comm=fp8_communication, fp8_format="e5m2"
+            )
+        else:
+            return_grad = _all_to_all(
+                grad_output, world_size, process_group, scatter_dim, gather_dim, fp8_comm=fp8_communication, fp8_format="e5m2"
+            )
+
+        return (return_grad, None, None, None, None)
 
 
 class HookParameter(torch.autograd.Function):
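A note on why backward swaps scatter_dim and gather_dim: the forward all-to-all turns a tensor that is sharded along gather_dim into one sharded along scatter_dim, and the gradient has to travel back through the inverse mapping. The sketch below only checks the shape algebra on a single process, so torch.chunk / torch.cat stand in for the real dist.all_to_all; it is an illustration of the layout change, not the communication itself.

import torch


def local_all_to_all_shape(x: torch.Tensor, world_size: int, scatter_dim: int, gather_dim: int) -> torch.Tensor:
    # shape-only stand-in for dist.all_to_all: split along scatter_dim,
    # re-concatenate along gather_dim
    chunks = torch.chunk(x, world_size, dim=scatter_dim)
    return torch.cat(chunks, dim=gather_dim)


if __name__ == "__main__":
    ws = 4
    x = torch.randn(2, 16, 32)  # [bsz, seq_local, heads * head_dim]
    y = local_all_to_all_shape(x, ws, scatter_dim=2, gather_dim=1)
    print(y.shape)  # torch.Size([2, 64, 8]): sequence gathered, heads scattered
    # the backward direction uses the swapped dims and restores the layout
    x_back = local_all_to_all_shape(y, ws, scatter_dim=1, gather_dim=2)
    print(x_back.shape)  # torch.Size([2, 16, 32])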
@@ -831,12 +880,15 @@ def hook_parameter_in_backward(input, weight=None, bias=None):
     return HookParameter.apply(input, weight, bias)
 
 
-def _reduce(input_, process_group):
+def _reduce(input_, process_group, fp8_communication=False):
     # skip if only one rank involved
     if dist.get_world_size(process_group) == 1:
         return input_
     else:
-        dist.all_reduce(input_, group=process_group)
+        if fp8_communication:
+            all_reduce_fp8(input_, group=process_group)
+        else:
+            dist.all_reduce(input_, group=process_group)
         return input_
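For intuition, here is one plausible decomposition of an fp8 all-reduce: quantize the local tensor, move the fp8 payload and its scale, and accumulate in the original precision. This is only a naive sketch to help read the rest of the diff; the real all_reduce_fp8 in colossalai.quantization.fp8 may use a different scheme (for example reduce-scatter plus all-gather) and different scaling.

import torch
import torch.distributed as dist


def naive_all_reduce_fp8(tensor: torch.Tensor, group=None) -> None:
    # assumes an initialized process group and a torch build with float8 dtypes
    world_size = dist.get_world_size(group)
    fp8 = torch.float8_e4m3fn
    scale = tensor.abs().max().float().clamp(min=1e-12) / torch.finfo(fp8).max
    payload = (tensor / scale).to(fp8).view(torch.uint8).contiguous()

    payloads = [torch.empty_like(payload) for _ in range(world_size)]
    scales = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(payloads, payload, group=group)
    dist.all_gather(scales, scale, group=group)

    # dequantize each rank's contribution with its own scale and sum in-place
    tensor.zero_()
    for p, s in zip(payloads, scales):
        tensor += (p.view(fp8).to(torch.float32) * s).to(tensor.dtype)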
@@ -860,19 +912,78 @@ def _split(input_, dim=-1, process_group=None):
     return output
 
 
-def _gather(input_, dim=-1, process_group=None):
+def _gather(input_, dim=-1, process_group=None, fp8_comm=False, fp8_format="e4m3"):
     # skip if only one rank involved
     world_size = dist.get_world_size(process_group)
     if world_size == 1:
         return input_
 
-    # all gather
-    input_ = input_.contiguous()
-    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
-    torch.distributed.all_gather(tensor_list, input_, group=process_group)
-
-    # concat
-    output = torch.cat(tensor_list, dim=dim).contiguous()
+    if fp8_comm:
+        # cast the local shard to fp8 and ship the per-rank scale alongside the payload
+        input_type = input_.dtype
+        ret, scale = cast_to_fp8(input_, fp8_format=fp8_format)
+        fp8_type = ret.dtype
+        input_ = ret.view(torch.uint8)
+        input_ = input_.contiguous()
+        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
+        scale = torch.tensor(scale, dtype=torch.float32).to(input_.device)
+        scale_list = [torch.ones(1, dtype=torch.float32, device=input_.device) for _ in range(world_size)]
+
+        torch.distributed.all_gather(tensor_list, input_, group=process_group)
+        torch.distributed.all_gather(scale_list, scale, group=process_group)
+
+        # cast every gathered shard back to the original dtype with its sender's scale
+        cast_tensor_list = []
+        for output, scale in zip(tensor_list, scale_list):
+            output = output.view(fp8_type)
+            output = cast_from_fp8(output, scale, input_type)
+            cast_tensor_list.append(output)
+
+        output = torch.cat(cast_tensor_list, dim=dim).contiguous()
+    else:
+        # all gather
+        input_ = input_.contiguous()
+        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
+        torch.distributed.all_gather(tensor_list, input_, group=process_group)
+
+        # concat
+        output = torch.cat(tensor_list, dim=dim).contiguous()
 
     return output
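The reason _gather moves one scale per rank rather than a single shared scale: each rank quantizes its own shard against its own maximum, so shards with very different magnitudes keep their own precision and must be dequantized with the scale of the rank that produced them. A single-process illustration below uses local stand-in casts (hypothetical quantize/dequantize helpers, not the library's); no collective is issued.

import torch

fp8 = torch.float8_e4m3fn
fp8_max = torch.finfo(fp8).max


def quantize(t):
    scale = t.abs().max().float().clamp(min=1e-12) / fp8_max
    return (t / scale).to(fp8), scale


def dequantize(payload, scale, dtype):
    return (payload.to(torch.float32) * scale).to(dtype)


if __name__ == "__main__":
    # two "ranks" with very different dynamic ranges
    shard_a = torch.randn(4, 8) * 1000.0
    shard_b = torch.randn(4, 8) * 0.001
    (pa, sa), (pb, sb) = quantize(shard_a), quantize(shard_b)
    # per-rank scales: both shards reconstruct reasonably well
    out = torch.cat([dequantize(pa, sa, torch.float32), dequantize(pb, sb, torch.float32)], dim=0)
    # one shared scale (rank a's): the small shard is crushed toward zero
    bad = torch.cat([dequantize(pa, sa, torch.float32), dequantize((shard_b / sa).to(fp8), sa, torch.float32)], dim=0)
    print((out[4:] - shard_b).abs().max(), (bad[4:] - shard_b).abs().max())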
@@ -901,14 +1012,31 @@ def _reduce_scatter(input_, dim=1, process_group=None):
     return output
 
 
-def _all_to_all(input_, world_size, group, scatter_dim, gather_dim):
-    input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
-    output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
-    dist.all_to_all(output_list, input_list, group=group)
+def _all_to_all(input_, world_size, group, scatter_dim, gather_dim, fp8_comm=False, fp8_format="e5m2"):
+    if fp8_comm:
+        input_type = input_.dtype
+        ret, scale = cast_to_fp8(input_, fp8_format=fp8_format)
+        fp8_type = ret.dtype
+        input_ = ret.view(torch.uint8)
+        input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
+        output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
+        scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(world_size)]
+        dist.all_to_all(output_list, input_list, group=group)
+        dist.all_gather(scale_list, scale, group=group)
+        cast_tensor_list = []
+        for output, scale in zip(output_list, scale_list):
+            output = output.view(fp8_type)
+            output = cast_from_fp8(output, scale, input_type)
+            cast_tensor_list.append(output)
+        output_list = cast_tensor_list
+    else:
+        input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
+        output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
+        dist.all_to_all(output_list, input_list, group=group)
     return torch.cat(output_list, dim=gather_dim).contiguous()
 
 
-def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim):
+def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim, fp8_comm=False, fp8_format="e5m2"):
     inp_shape = list(input_.shape)
     inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size
     if scatter_dim < 2:
@@ -920,8 +1048,24 @@ def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim):
             .contiguous()
         )
 
-    output = torch.empty_like(input_t)
-    dist.all_to_all_single(output, input_t, group=group)
+    if fp8_comm:
+        input_type = input_t.dtype
+        ret, scale = cast_to_fp8(input_t, fp8_format=fp8_format)
+        fp8_type = ret.dtype
+        input_t = ret.view(torch.uint8)
+        output = torch.empty_like(input_t)
+        scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(seq_world_size)]
+        dist.all_to_all_single(output, input_t, group=group)
+        dist.all_gather(scale_list, scale, group=group)
+        cast_tensor_list = []
+        for output_part, scale in zip(output, scale_list):
+            output_part = output_part.view(fp8_type)
+            output_part = cast_from_fp8(output_part, scale, input_type)
+            cast_tensor_list.append(output_part)
+        output = torch.stack(cast_tensor_list, dim=0)
+    else:
+        output = torch.empty_like(input_t)
+        dist.all_to_all_single(output, input_t, group=group)
 
     if scatter_dim < 2:
         output = output.transpose(0, 1).contiguous()
@@ -935,12 +1079,16 @@ def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim):
     ).contiguous()
 
 
-def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
-    return MatmulWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
+def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
+    return MatmulWithAsyncCommunication.apply(
+        input_, weight, bias, process_group, async_grad_allreduce, fp8_communication
+    )
 
 
-def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
-    return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
+def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
+    return LinearWithAsyncCommunication.apply(
+        input_, weight, bias, process_group, async_grad_allreduce, fp8_communication
+    )
 
 
 def linear_gather_forward_reducescatter_backward(
@@ -951,12 +1099,12 @@ def linear_gather_forward_reducescatter_backward(
     )
 
 
-def gather_forward_reducescatter_backward(input_, process_group, dim):
-    return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim)
+def gather_forward_reducescatter_backward(input_, process_group, dim, fp8_communication=False):
+    return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim, fp8_communication)
 
 
-def reducescatter_forward_gather_backward(input_, process_group, dim):
-    return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim)
+def reducescatter_forward_gather_backward(input_, process_group, dim, fp8_communication=False):
+    return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim, fp8_communication)
 
 
 def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1, ring=False):
@@ -964,28 +1112,28 @@ def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, proc
 def matmul_gather_forward_reducescatter_backward(
-    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
+    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False, fp8_communication=False
 ):
     return _MatmulWithGatherForwardReduceScatterBackward.apply(
-        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
+        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
     )
 
 
-def gather_forward_split_backward(input_, dim, process_group, grad_scale=None):
-    return _GatherForwardSplitBackward.apply(input_, dim, process_group, grad_scale)
+def gather_forward_split_backward(input_, dim, process_group, grad_scale=None, fp8_communication=False):
+    return _GatherForwardSplitBackward.apply(input_, dim, process_group, grad_scale, fp8_communication)
 
 
-def split_forward_gather_backward(input_, dim, process_group, grad_scale=None):
-    return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale)
+def split_forward_gather_backward(input_, dim, process_group, grad_scale=None, fp8_communication=False):
+    return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale, fp8_communication)
 
 
-def reduce_forward(input_, process_group):
-    return _ReduceForward.apply(input_, process_group)
+def reduce_forward(input_, process_group, fp8_communication=False):
+    return _ReduceForward.apply(input_, process_group, fp8_communication)
 
 
-def reduce_backward(input_, process_group):
-    return _ReduceBackward.apply(input_, process_group)
+def reduce_backward(input_, process_group, fp8_communication=False):
+    return _ReduceBackward.apply(input_, process_group, fp8_communication=fp8_communication)
 
 
-def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
-    return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
+def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8_comm=False):
+    return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_comm)
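Putting the new flags together, a usage sketch of the public wrappers with fp8 communication enabled. Assumptions: the file patched above is colossalai/shardformer/layer/_operation.py, torch.distributed is launched with torchrun on CUDA devices, and the fp8 helpers are importable; this is an illustration, not part of the patch.

import torch
import torch.distributed as dist

from colossalai.shardformer.layer._operation import (
    gather_forward_split_backward,
    split_forward_gather_backward,
)


def main():
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    # a tensor-parallel shard of an activation: [batch, seq, hidden_local]
    x = torch.randn(2, 16, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)

    # forward gathers along the hidden dim in fp8; backward splits (no communication)
    full = gather_forward_split_backward(x, dim=-1, process_group=None, fp8_communication=True)

    # forward splits (no communication); backward gathers the gradient in fp8
    shard = split_forward_gather_backward(full, dim=-1, process_group=None, fp8_communication=True)

    shard.sum().backward()
    print(x.grad.shape)


if __name__ == "__main__":
    main()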