@@ -94,7 +94,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
         grad_output = grad_output.view(-1, grad_output.shape[-1])
         total_input = total_input.view(-1, total_input.shape[-1])
 
-        if fp8_communication and ctx.async_grad_allreduce:
+        if ctx.async_grad_allreduce and fp8_communication:
             _reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication)
         elif ctx.async_grad_allreduce:
             # Asynchronous all-reduce
@@ -117,12 +117,11 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
         ctx.async_grad_allreduce = async_grad_allreduce
-        ctx.fp8_communication = fp8_communication
         if bias is not None:
             output = F.linear(input_, weight, bias)
         else:
@@ -134,7 +133,6 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
     def backward(ctx, grad_output):
         input, weight, bias = ctx.saved_tensors
         use_bias = ctx.use_bias
-        fp8_communication = ctx.fp8_communication
 
         # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
         if use_bias:
@@ -150,10 +148,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
 
         if ctx.async_grad_allreduce:
             # Asynchronous all-reduce
-            if fp8_communication:
-                all_reduce_fp8(grad_input, group=ctx.process_group)
-            else:
-                handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
+            handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
             # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
             # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
 
@@ -172,7 +167,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
 
         grad_bias = grad_output.sum(dim=0) if use_bias else None
 
-        if ctx.async_grad_allreduce and not fp8_communication:
+        if ctx.async_grad_allreduce:
            handle.wait()
 
         return grad_input, grad_weight, grad_bias, None, None, None, None
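
Note on the two hunks above: with the FP8 branch removed, the backward always launches the all-reduce on `grad_input` with `async_op=True`, computes `grad_weight` and `grad_bias` while the collective is in flight, and only blocks on `handle.wait()` once `grad_input` is actually needed. The `CUDA_DEVICE_MAX_CONNECTIONS=1` comment is the load-bearing detail: with a single hardware work queue, the all-reduce is scheduled ahead of the weight-gradient matmul, which is what makes the overlap reliable. A minimal sketch of the pattern (illustrative names, process-group setup elided; not this file's exact code):

```python
import torch
import torch.distributed as dist

def overlapped_linear_backward(grad_output, total_input, weight, process_group):
    # Start summing grad_input across ranks without blocking the host thread.
    grad_input = grad_output.matmul(weight)
    handle = dist.all_reduce(grad_input, group=process_group, async_op=True)
    # Overlapped compute: grad_weight does not depend on the reduced grad_input.
    grad_weight = grad_output.t().matmul(total_input)
    handle.wait()  # block only at the point grad_input is consumed
    return grad_input, grad_weight
```
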
@@ -243,18 +238,16 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, input_, process_group, dim, fp8_communication=False):
+    def forward(ctx, input_, process_group, dim):
         ctx.process_group = process_group
         ctx.dim = dim
-        ctx.fp8_communication = fp8_communication
 
-        return _gather(input_, dim, process_group, fp8_communication)
+        return _gather(input_, dim, process_group)
 
     @staticmethod
     def backward(ctx, grad_output):
         dim = ctx.dim
         process_group = ctx.process_group
-        fp8_communication = ctx.fp8_communication
         # do reduce-scatter
         new_shape = list(grad_output.shape)
         assert (
@@ -266,10 +259,7 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function):
         ]
         output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device)
 
-        if fp8_communication:
-            reduce_scatter_fp8(output, grad_list, group=process_group)
-        else:
-            dist.reduce_scatter(output, grad_list, group=process_group)
+        dist.reduce_scatter(output, grad_list, group=process_group)
 
         return output, None, None, None
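
Note on the backward above: because `_GatherForwardReduceScatterBackward` all-gathers along `dim` in the forward pass, its backward must reduce-scatter `grad_output` back into per-rank shards; this hunk drops the FP8 dispatch so the plain `dist.reduce_scatter` path is always taken. A self-contained sketch of that backward step, mirroring the hunk's own `new_shape`/`grad_list` bookkeeping (the helper name is mine):

```python
import torch
import torch.distributed as dist

def reduce_scatter_grad_sketch(grad_output, process_group, dim):
    world_size = dist.get_world_size(process_group)
    # The gathered dim shrinks by world_size; each rank keeps one summed shard.
    new_shape = list(grad_output.shape)
    assert new_shape[dim] % world_size == 0, "gathered dim must split evenly"
    new_shape[dim] //= world_size
    grad_list = [t.contiguous() for t in torch.chunk(grad_output, world_size, dim=dim)]
    output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device)
    dist.reduce_scatter(output, grad_list, group=process_group)  # sum, then scatter
    return output
```
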
@@ -576,7 +566,6 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
         input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)]
         output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
         if fp8_communication:
-        # if False:
             reduce_scatter_fp8(output, input_list, group=process_group)
         else:
             dist.reduce_scatter(output, input_list, group=process_group)
@@ -588,8 +577,7 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
         dim = ctx.dim
         process_group = ctx.process_group
         fp8_communication = ctx.fp8_communication
 
-        return _gather(grad_output, dim, process_group, fp8_communication), None, None, None
-
+        return _gather(grad_output, dim, process_group, fp8_communication=fp8_communication), None, None, None
 
 class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
@@ -793,12 +781,12 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_comm=False):
+    def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_communication=False):
         ctx.process_group = process_group
         ctx.dim = dim
         ctx.grad_scale = grad_scale
 
-        return _gather(input_, dim, process_group, fp8_comm=fp8_comm, fp8_format="e4m3")
+        return _gather(input_, dim, process_group, fp8_communication=fp8_communication, fp8_format="e4m3")
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -829,11 +817,23 @@ class _AllToAll(torch.autograd.Function):
         # using all_to_all_single when batch size is 1
         if bsz == 1:
             return _all_to_all_single(
-                input_, world_size, process_group, scatter_dim, gather_dim, fp8_comm=fp8_comm, fp8_format="e5m2"
+                input_,
+                world_size,
+                process_group,
+                scatter_dim,
+                gather_dim,
+                fp8_communication=fp8_communication,
+                fp8_format="e5m2",
             )
         else:
             return _all_to_all(
-                input_, world_size, process_group, scatter_dim, gather_dim, fp8_comm=fp8_comm, fp8_format="e5m2"
+                input_,
+                world_size,
+                process_group,
+                scatter_dim,
+                gather_dim,
+                fp8_communication=fp8_communication,
+                fp8_format="e5m2",
             )
 
     @staticmethod
@@ -841,17 +841,29 @@ class _AllToAll(torch.autograd.Function):
         process_group = ctx.process_group
         scatter_dim = ctx.gather_dim
         gather_dim = ctx.scatter_dim
-        ctx.fp8_communication
+        fp8_communication = ctx.fp8_communication
         world_size = dist.get_world_size(process_group)
         bsz, _, _ = grad_output.shape
 
         if bsz == 1:
             return_grad = _all_to_all_single(
-                grad_output, world_size, process_group, scatter_dim, gather_dim, fp8_comm=fp8_comm, fp8_format="e5m2"
+                grad_output,
+                world_size,
+                process_group,
+                scatter_dim,
+                gather_dim,
+                fp8_communication=fp8_communication,
+                fp8_format="e5m2",
             )
         else:
             return_grad = _all_to_all(
-                grad_output, world_size, process_group, scatter_dim, gather_dim, fp8_comm=fp8_comm, fp8_format="e5m2"
+                grad_output,
+                world_size,
+                process_group,
+                scatter_dim,
+                gather_dim,
+                fp8_communication=fp8_communication,
+                fp8_format="e5m2",
             )
 
         return (return_grad, None, None, None, None)
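
Two things worth flagging in the backward hunk above. First, the removed `ctx.fp8_communication` line was a bare attribute access with no assignment, a no-op, so the flag never reached the collectives below; binding it to a local (and renaming `fp8_comm=` to `fp8_communication=`) makes the keyword refer to a name that actually exists. Second, the backward deliberately swaps `scatter_dim` and `gather_dim`: an all-to-all is undone by the opposite all-to-all. A small shape-only check of that inverse relationship (pure bookkeeping, no process group required):

```python
def all_to_all_shape(shape, world_size, scatter_dim, gather_dim):
    # An all-to-all splits one dim across ranks and concatenates another.
    out = list(shape)
    out[scatter_dim] //= world_size
    out[gather_dim] *= world_size
    return out

fwd = all_to_all_shape([1, 8, 16, 64], world_size=4, scatter_dim=2, gather_dim=1)
bwd = all_to_all_shape(fwd, world_size=4, scatter_dim=1, gather_dim=2)
assert bwd == [1, 8, 16, 64]  # swapping the dims restores the input shape
```
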
@@ -912,10 +924,7 @@ def _split(input_, dim=-1, process_group=None):
     return output
 
 
-from colossalai.params import to_cast
-
-
-def _gather(input_, dim=-1, process_group=None, fp8_comm=False, fp8_format="e4m3"):
+def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_format="e4m3"):
     # skip if only one rank involved
     world_size = dist.get_world_size(process_group)
     if world_size == 1:
@@ -926,13 +935,12 @@ def _gather(input_, dim=-1, process_group=None, fp8_comm=False, fp8_format="e4m3
 
     from colossalai.zero.low_level._utils import has_inf_or_nan
 
-    if fp8_comm:
-    # if False:
+    if fp8_communication:
         if has_inf_or_nan(input_):
             print("input has nan")
             exit(0)
         input_type = input_.dtype
         to_cast.append(input_)
         ret, scale = cast_to_fp8(input_, fp8_format="e5m2")
         if has_inf_or_nan(ret):
             import pdb
@@ -1012,8 +1020,8 @@ def _reduce_scatter(input_, dim=1, process_group=None):
     return output
 
 
-def _all_to_all(input_, world_size, group, scatter_dim, gather_dim, fp8_comm=False, fp8_format="e5m2"):
-    if fp8_comm:
+def _all_to_all(input_, world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"):
+    if fp8_communication:
         input_type = input_.dtype
         ret, scale = cast_to_fp8(input_, fp8_format=fp8_format)
         fp8_type = ret.dtype
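
For orientation in the FP8 branches above and below: `cast_to_fp8` returns the narrowed payload together with a per-tensor scale, the collective then moves the one-byte-per-element payload, and the receiver restores the original dtype using the scale. The real helpers are ColossalAI's quantization utilities; what follows is only a rough single-process approximation of the round trip, assuming PyTorch's native float8 dtypes (the `*_sketch` names are mine):

```python
import torch

def cast_to_fp8_sketch(t, fp8_format="e5m2"):
    # Scale into the fp8-representable range, then narrow to 8 bits.
    fp8_dtype = torch.float8_e5m2 if fp8_format == "e5m2" else torch.float8_e4m3fn
    scale = t.abs().max().clamp(min=1e-12) / torch.finfo(fp8_dtype).max
    return (t / scale).to(fp8_dtype), scale

def cast_from_fp8_sketch(t_fp8, scale, dtype):
    return t_fp8.to(dtype) * scale

x = torch.randn(4, 4)
ret, scale = cast_to_fp8_sketch(x, "e5m2")         # payload the all-to-all would move
x_back = cast_from_fp8_sketch(ret, scale, x.dtype)
print((x - x_back).abs().max())                    # small but nonzero: fp8 is lossy
```
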
@@ -1036,7 +1044,9 @@ def _all_to_all(input_, world_size, group, scatter_dim, gather_dim, fp8_comm=Fal
     return torch.cat(output_list, dim=gather_dim).contiguous()
 
 
-def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim, fp8_comm=False, fp8_format="e5m2"):
+def _all_to_all_single(
+    input_, seq_world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"
+):
     inp_shape = list(input_.shape)
     inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size
     if scatter_dim < 2:
@@ -1048,7 +1058,7 @@ def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim, f
         .contiguous()
     )
 
-    if fp8_comm:
+    if fp8_communication:
         input_type = input_t.dtype
         ret, scale = cast_to_fp8(input_t, fp8_format=fp8_format)
         fp8_type = ret.dtype
@@ -1085,10 +1095,8 @@ def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allre
     )
 
 
-def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
-    return LinearWithAsyncCommunication.apply(
-        input_, weight, bias, process_group, async_grad_allreduce, fp8_communication
-    )
+def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
+    return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
 
 
 def linear_gather_forward_reducescatter_backward(
@@ -1099,8 +1107,8 @@ def linear_gather_forward_reducescatter_backward(
     )
 
 
-def gather_forward_reducescatter_backward(input_, process_group, dim, fp8_communication=False):
-    return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim, fp8_communication)
+def gather_forward_reducescatter_backward(input_, process_group, dim):
+    return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim)
 
 
 def reducescatter_forward_gather_backward(input_, process_group, dim, fp8_communication=False):
@@ -1132,8 +1140,8 @@ def reduce_forward(input_, process_group, fp8_communication=False):
 
 
 def reduce_backward(input_, process_group, fp8_communication=False):
-    return _ReduceBackward.apply(input_, process_group, fp8_communication=fp8_communication)
+    return _ReduceBackward.apply(input_, process_group, fp8_communication)
 
 
-def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8_comm=False):
-    return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_comm)
+def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8_communication=False):
+    return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_communication)
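
Note on the `reduce_backward` change just above: `torch.autograd.Function.apply` does not accept keyword arguments (at least in the PyTorch versions this file targets), so the old `fp8_communication=fp8_communication` form raises a `TypeError` as soon as the wrapper runs; passing the flag positionally is a correctness fix, not a style choice. The same rule is why `all_to_all_comm` forwards everything positionally into `_AllToAll.apply`. A minimal repro with a hypothetical `_Identity` function:

```python
import torch

class _Identity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, flag=False):
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None

x = torch.ones(2, requires_grad=True)
y = _Identity.apply(x, True)  # fine: positional, like the fixed call sites
try:
    _Identity.apply(x, flag=True)  # mirrors the removed keyword form
except TypeError as err:
    print(err)  # e.g. "apply() takes no keyword arguments"
```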