@@ -170,7 +170,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
         if ctx.async_grad_allreduce:
             handle.wait()
 
-        return grad_input, grad_weight, grad_bias, None, None, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None
 
 
 def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False):
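Note on this hunk and the next one: `torch.autograd.Function.backward` must return exactly one value per argument that `forward` received, so adding or removing a flag such as `fp8_communication` from the `forward` signature changes the number of trailing `None`s in the return statement. A minimal sketch of that rule with a toy, hypothetical function:

```python
import torch


class Scale(torch.autograd.Function):
    """Toy autograd.Function whose forward takes (input, factor, some_flag)."""

    @staticmethod
    def forward(ctx, input_, factor, some_flag):
        ctx.factor = factor
        return input_ * factor

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward argument: a gradient for input_,
        # and None for the non-tensor arguments factor and some_flag.
        return grad_output * ctx.factor, None, None


x = torch.randn(3, requires_grad=True)
Scale.apply(x, 2.0, True).sum().backward()
print(x.grad)  # tensor of 2.0s
```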
@@ -261,7 +261,7 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function):
         dist.reduce_scatter(output, grad_list, group=process_group)
 
-        return output, None, None, None
+        return output, None, None
 
 
 class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
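Why a gather in forward pairs with `dist.reduce_scatter` in backward: every rank's shard appears in every rank's gathered output, so the gradient chunks for one shard live on all ranks and must be summed and handed back to their owner, which is exactly what reduce-scatter computes. A single-process sketch of the shape bookkeeping (no process group, purely illustrative):

```python
import torch

world_size = 4
dim = -1

# Forward: each rank owns a (2, 3) shard; all-gather concatenates them along `dim`.
shards = [torch.randn(2, 3) for _ in range(world_size)]
gathered = torch.cat(shards, dim=dim)  # (2, 12), what every rank sees after the gather

# Backward: the gradient w.r.t. the gathered tensor is chunked along `dim`;
# rank r must receive the sum over ranks of chunk r, which is what
# dist.reduce_scatter(output, grad_list, ...) does in the hunk above.
grad_gathered = torch.randn_like(gathered)
grad_list = [g.contiguous() for g in torch.chunk(grad_gathered, world_size, dim=dim)]
print(gathered.shape, grad_list[0].shape)  # torch.Size([2, 12]) torch.Size([2, 3])
```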
@@ -729,7 +729,7 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
             grad_output = grad_output * ctx.grad_scale
 
         # to_cast.append(grad_output.cpu().detach().numpy())
-        return _gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication, "e4m3"), None, None, None, None
+        return _gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication), None, None, None, None
 
 
 class _ReduceForward(torch.autograd.Function):
@@ -786,7 +786,7 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
         ctx.dim = dim
         ctx.grad_scale = grad_scale
 
-        return _gather(input_, dim, process_group, fp8_communication=fp8_communication, fp8_format="e4m3")
+        return _gather(input_, dim, process_group, fp8_communication=fp8_communication)
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -806,67 +806,26 @@ class _AllToAll(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, input_, process_group, scatter_dim, gather_dim, fp8_communication):
+    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
         ctx.process_group = process_group
         ctx.scatter_dim = scatter_dim
         ctx.gather_dim = gather_dim
-        ctx.fp8_communication = fp8_communication
         world_size = dist.get_world_size(process_group)
         bsz, _, _ = input_.shape
 
         # using all_to_all_single when batch size is 1
         if bsz == 1:
-            return _all_to_all_single(
-                input_,
-                world_size,
-                process_group,
-                scatter_dim,
-                gather_dim,
-                fp8_communication=fp8_communication,
-                fp8_format="e5m2",
-            )
+            return _all_to_all_single(input_, world_size, process_group, scatter_dim, gather_dim)
         else:
-            return _all_to_all(
-                input_,
-                world_size,
-                process_group,
-                scatter_dim,
-                gather_dim,
-                fp8_communication=fp8_communication,
-                fp8_format="e5m2",
-            )
+            return _all_to_all(input_, world_size, process_group, scatter_dim, gather_dim)
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, *grad_output):
         process_group = ctx.process_group
         scatter_dim = ctx.gather_dim
         gather_dim = ctx.scatter_dim
-        fp8_communication = ctx.fp8_communication
-        world_size = dist.get_world_size(process_group)
-        bsz, _, _ = grad_output.shape
-
-        if bsz == 1:
-            return_grad = _all_to_all_single(
-                grad_output,
-                world_size,
-                process_group,
-                scatter_dim,
-                gather_dim,
-                fp8_communication=fp8_communication,
-                fp8_format="e5m2",
-            )
-        else:
-            return_grad = _all_to_all(
-                grad_output,
-                world_size,
-                process_group,
-                scatter_dim,
-                gather_dim,
-                fp8_communication=fp8_communication,
-                fp8_format="e5m2",
-            )
-
-        return (return_grad, None, None, None, None)
+        return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim)
+        return (return_grad, None, None, None)
 
 
 class HookParameter(torch.autograd.Function):
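On both sides of this hunk the underlying identity is the same: the adjoint of an all-to-all that scatters along `scatter_dim` and gathers along `gather_dim` is another all-to-all with the two dimensions swapped, hence `scatter_dim = ctx.gather_dim` and `gather_dim = ctx.scatter_dim` in `backward`. A single-process sketch of that inverse relationship using the same split-then-concat data movement (hypothetical shapes, no communication):

```python
import torch


def fake_all_to_all(x, world_size, scatter_dim, gather_dim):
    """Single-process stand-in for _all_to_all: split along scatter_dim,
    pretend each chunk was exchanged, then concatenate along gather_dim."""
    chunks = [t.contiguous() for t in torch.tensor_split(x, world_size, scatter_dim)]
    return torch.cat(chunks, dim=gather_dim)


world_size = 4
x = torch.randn(2, 8, 12)  # (batch, seq, hidden)

y = fake_all_to_all(x, world_size, scatter_dim=2, gather_dim=1)       # (2, 32, 3)
x_back = fake_all_to_all(y, world_size, scatter_dim=1, gather_dim=2)  # (2, 8, 12)
print(torch.equal(x, x_back))  # True: swapping the dims inverts the op
```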
@@ -924,41 +883,20 @@ def _split(input_, dim=-1, process_group=None):
     return output
 
 
-def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_format="e4m3"):
+def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_format="e5m2"):
     # skip if only one rank involved
     world_size = dist.get_world_size(process_group)
     if world_size == 1:
         return input_
 
     # all gather
-    import torch.distributed as dista
-
-    from colossalai.zero.low_level._utils import has_inf_or_nan
-
     if fp8_communication:
-        # if False:
-        if has_inf_or_nan(input_):
-            print("input has nan")
-            exit(0)
         input_type = input_.dtype
-        ret, scale = cast_to_fp8(input_, fp8_format="e5m2")
-        if has_inf_or_nan(ret):
-            import pdb
-
-            pdb.set_trace()
-            print("cast has nan")
-            # exit(0)
-        dista.barrier()
+        ret, scale = cast_to_fp8(input_, fp8_format=fp8_format)
         fp8_type = ret.dtype
         input_ = ret.view(torch.uint8)
         input_ = input_.contiguous()
         tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
-        scale = torch.tensor(scale, dtype=torch.float32).to(input_.device)
-        # import torch.distributed as dista
-        # if dista.get_rank()==0:
-        # import pdb
-        # pdb.set_trace()
-        # dista.barrier()
         scale_list = [torch.ones(1, dtype=torch.float32, device=input_.device) for _ in range(world_size)]
+        scale = torch.tensor(scale).to(input_.device)
@@ -969,24 +907,10 @@ def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_for
         for output, scale in zip(tensor_list, scale_list):
             output = output.view(fp8_type)
             output = cast_from_fp8(output, scale, input_type)
-            if has_inf_or_nan(output) and dista.get_rank() == 0:
-                print("casted_output has nan")
-                import pdb
-
-                pdb.set_trace()
-            dista.barrier()
-
             cast_tensor_list.append(output)
 
         output = torch.cat(cast_tensor_list, dim=dim).contiguous()
-
-        if has_inf_or_nan(output):
-            print("output has nan")
-            exit(0)
-            # import pdb
-            # pdb.set_trace()
-        dista.barrier()
     else:
         input_ = input_.contiguous()
         tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
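For reference, the fp8 branch above follows a cast, transport, uncast pattern: quantize with a per-tensor scale, reinterpret the payload as `uint8` so the collective does not need to understand the fp8 dtype, all-gather the payloads together with their scales, then dequantize each received piece. A minimal single-process sketch of the numeric round trip using PyTorch's native float8 dtypes (requires a recent PyTorch; an illustration of the idea, not ColossalAI's `cast_to_fp8`/`cast_from_fp8` implementation):

```python
import torch


def cast_to_fp8_demo(x: torch.Tensor, fp8_dtype=torch.float8_e4m3fn):
    # Per-tensor scale so the largest magnitude lands near the fp8 max (~448 for e4m3).
    amax = x.abs().max().clamp(min=1e-12)
    scale = 448.0 / amax
    return (x * scale).to(fp8_dtype), scale


def cast_from_fp8_demo(fp8: torch.Tensor, scale: torch.Tensor, dtype=torch.float32):
    return fp8.to(dtype) / scale


x = torch.randn(4, 8)
fp8, scale = cast_to_fp8_demo(x)

# Transport trick from the diff: ship the fp8 payload as raw uint8 bytes,
# then reinterpret it as fp8 again on the receiving side.
payload = fp8.view(torch.uint8)
received = payload.view(torch.float8_e4m3fn)

x_hat = cast_from_fp8_demo(received, scale)
print((x - x_hat).abs().max())  # small quantization error, not exact
```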
@@ -1020,33 +944,14 @@ def _reduce_scatter(input_, dim=1, process_group=None):
     return output
 
 
-def _all_to_all(input_, world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"):
-    if fp8_communication:
-        input_type = input_.dtype
-        ret, scale = cast_to_fp8(input_, fp8_format=fp8_format)
-        fp8_type = ret.dtype
-        input_ = ret.view(torch.uint8)
-        input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
-        output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
-        scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(world_size)]
-        dist.all_to_all(output_list, input_list, group=group)
-        dist.all_gather(scale_list, scale, group=group)
-        cast_tensor_list = []
-        for output, scale in zip(output_list, scale_list):
-            output = output.view(fp8_type)
-            output = cast_from_fp8(output, scale, input_type)
-            cast_tensor_list.append(output)
-        output_list = cast_tensor_list
-    else:
-        input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
-        output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
-        dist.all_to_all(output_list, input_list, group=group)
+def _all_to_all(input_, world_size, group, scatter_dim, gather_dim):
+    input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
+    output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
+    dist.all_to_all(output_list, input_list, group=group)
     return torch.cat(output_list, dim=gather_dim).contiguous()
 
 
-def _all_to_all_single(
-    input_, seq_world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"
-):
+def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim):
     inp_shape = list(input_.shape)
     inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size
     if scatter_dim < 2:
@@ -1058,24 +963,8 @@ def _all_to_all_single(
             .contiguous()
         )
 
-    if fp8_communication:
-        input_type = input_t.dtype
-        ret, scale = cast_to_fp8(input_t, fp8_format=fp8_format)
-        fp8_type = ret.dtype
-        input_t = ret.view(torch.uint8)
-        output = torch.empty_like(input_t)
-        scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(seq_world_size)]
-        dist.all_to_all_single(output, input_t, group=group)
-        dist.all_gather(scale_list, scale, group=group)
-        cast_tensor_list = []
-        for output_part, scale in zip(output, scale_list):
-            output_part = output_part.view(fp8_type)
-            output_part = cast_from_fp8(output_part, scale, input_type)
-            cast_tensor_list.append(output_part)
-        output = torch.stack(cast_tensor_list, dim=0)
-    else:
-        output = torch.empty_like(input_t)
-        dist.all_to_all_single(output, input_t, group=group)
+    output = torch.empty_like(input_t)
+    dist.all_to_all_single(output, input_t, group=group)
 
     if scatter_dim < 2:
         output = output.transpose(0, 1).contiguous()
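Context for this hunk: `dist.all_to_all_single` exchanges equal chunks along dim 0, so when `scatter_dim < 2` the input is first rearranged so the dimension being scattered leads, and the `transpose(0, 1)` after the collective undoes that staging. A single-process sketch of the staging step (hypothetical shapes, no communication):

```python
import torch

seq_world_size = 4
x = torch.randn(2, 8, 6)  # (batch, seq, hidden), scatter along seq (dim 1)

# Stage: move the scatter dim in front of the batch dim so that chunking
# along dim 0 yields one contiguous block per rank, as all_to_all_single expects.
staged = x.transpose(0, 1).contiguous()         # (8, 2, 6)
per_rank = staged.chunk(seq_world_size, dim=0)  # 4 blocks of (2, 2, 6)

# After the exchange, the receiving side transposes back to batch-first layout.
unstaged = staged.transpose(0, 1).contiguous()  # (2, 8, 6)
print(torch.equal(x, unstaged))  # True
print(per_rank[0].shape)
```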
@@ -1143,5 +1032,5 @@ def reduce_backward(input_, process_group, fp8_communication=False):
     return _ReduceBackward.apply(input_, process_group, fp8_communication)
 
 
-def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8_communication=False):
-    return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_communication)
+def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
+    return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
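For completeness, a usage sketch of the public wrapper on either side of this change; only the left-hand version accepts `fp8_communication`. The shapes, the 4-rank setup, and the import path are assumptions for illustration (the diff itself does not name the file), and the call pattern is the usual sequence-parallel one: scatter the hidden dimension, gather the sequence dimension.

```python
# Illustrative only: assumes torch.distributed is launched with 4 processes
# (e.g. `torchrun --nproc_per_node=4 demo.py`) and that this module lives at
# the assumed path colossalai/shardformer/layer/_operation.py.
import torch
import torch.distributed as dist

from colossalai.shardformer.layer._operation import all_to_all_comm

dist.init_process_group(backend="gloo")
world_size = dist.get_world_size()  # assumed to be 4 here

# (batch, local_seq, hidden): each rank holds one slice of the sequence.
x = torch.randn(2, 8, 4 * world_size)

# Scatter the hidden dim across ranks and gather the sequence dim:
# every rank ends up with shape (2, 8 * world_size, 4).
y = all_to_all_comm(x, process_group=None, scatter_dim=2, gather_dim=1)
print(dist.get_rank(), y.shape)
```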