fix shardformer fp8 communication training degradation

pull/5921/head
GuangyaoZhang 2024-07-18 07:16:36 +00:00
parent 6a20f07b80
commit 5b969fd831
1 changed files with 22 additions and 12 deletions

View File

@ -95,7 +95,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
total_input = total_input.view(-1, total_input.shape[-1])
if ctx.async_grad_allreduce and fp8_communication:
_reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication)
_reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication, fp8_format="e5m2")
elif ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
@ -566,7 +566,7 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)]
output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
if fp8_communication:
reduce_scatter_fp8(output, input_list, group=process_group)
reduce_scatter_fp8(output, input_list, group=process_group, fp8_format="e4m3")
else:
dist.reduce_scatter(output, input_list, group=process_group)
@ -577,7 +577,12 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
dim = ctx.dim
process_group = ctx.process_group
fp8_communication = ctx.fp8_communication
return _gather(grad_output, dim, process_group, fp8_communication=fp8_communication), None, None, None
return (
_gather(grad_output, dim, process_group, fp8_communication=fp8_communication, fp8_format="e5m2"),
None,
None,
None,
)
class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
@ -618,7 +623,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
)
else:
input_parallel = _gather(input_, dim, process_group, fp8_communication)
input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3")
output = torch.matmul(input_parallel, weight)
@ -641,7 +646,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
bias = bias.view(bias.shape)
if not overlap:
input_parallel = _gather(input_, dim, process_group, fp8_communication)
input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e5m2")
total_input = input_parallel
grad_input = grad_output.matmul(weight.T)
@ -728,8 +733,13 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
if ctx.grad_scale is not None:
grad_output = grad_output * ctx.grad_scale
# to_cast.append(grad_output.cpu().detach().numpy())
return _gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication), None, None, None, None
return (
_gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication, fp8_format="e5m2"),
None,
None,
None,
None,
)
class _ReduceForward(torch.autograd.Function):
@ -743,7 +753,7 @@ class _ReduceForward(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, process_group, fp8_communication=False):
return _reduce(input_, process_group, fp8_communication)
return _reduce(input_, process_group, fp8_communication, fp8_format="e4m3")
@staticmethod
def backward(ctx, grad_output):
@ -768,7 +778,7 @@ class _ReduceBackward(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
fp8_communication = ctx.fp8_communication
return _reduce(grad_output, ctx.process_group, fp8_communication), None, None
return _reduce(grad_output, ctx.process_group, fp8_communication, fp8_format="e5m2"), None, None
class _GatherForwardSplitBackward(torch.autograd.Function):
@ -786,7 +796,7 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
ctx.dim = dim
ctx.grad_scale = grad_scale
return _gather(input_, dim, process_group, fp8_communication=fp8_communication)
return _gather(input_, dim, process_group, fp8_communication=fp8_communication, fp8_format="e4m3")
@staticmethod
def backward(ctx, grad_output):
@ -851,13 +861,13 @@ def hook_parameter_in_backward(input, weight=None, bias=None):
return HookParameter.apply(input, weight, bias)
def _reduce(input_, process_group, fp8_communication=False):
def _reduce(input_, process_group, fp8_communication=False, fp8_format="e5m2"):
# skip if only one rank involved
if dist.get_world_size(process_group) == 1:
return input_
else:
if fp8_communication:
all_reduce_fp8(input_, group=process_group)
all_reduce_fp8(input_, group=process_group, fp8_format=fp8_format)
else:
dist.all_reduce(input_, group=process_group)
return input_