mirror of https://github.com/hpcaitech/ColossalAI
Fix shardformer FP8 communication training degradation
parent 6a20f07b80
commit 5b969fd831
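Throughout this patch the forward-path collectives (the forward _reduce, reduce_scatter_fp8, and forward _gather calls) pass fp8_format="e4m3", while the backward-path collectives that carry gradients pass fp8_format="e5m2". This follows the usual FP8 training convention: e4m3 keeps more mantissa bits and thus more precision for activations, while e5m2 trades precision for a wider exponent range that better covers the dynamic range of gradients. A minimal sketch of that tradeoff, assuming PyTorch >= 2.1 with native float8 dtypes (not part of this patch):

import torch

# e4m3 carries 3 mantissa bits but tops out at 448; e5m2 carries only 2
# mantissa bits but reaches 57344, so it tolerates large gradient spikes.
print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0
print(torch.finfo(torch.float8_e5m2).max)    # 57344.0

# Round-tripping a tensor through each format shows the precision gap.
x = torch.randn(1024)
err_e4m3 = (x - x.to(torch.float8_e4m3fn).to(torch.float32)).abs().mean()
err_e5m2 = (x - x.to(torch.float8_e5m2).to(torch.float32)).abs().mean()
print(err_e4m3 < err_e5m2)  # True: e4m3 quantization error is smaller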
@@ -95,7 +95,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
         total_input = total_input.view(-1, total_input.shape[-1])
 
         if ctx.async_grad_allreduce and fp8_communication:
-            _reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication)
+            _reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication, fp8_format="e5m2")
         elif ctx.async_grad_allreduce:
             # Asynchronous all-reduce
             handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
@@ -566,7 +566,7 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
         input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)]
         output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
         if fp8_communication:
-            reduce_scatter_fp8(output, input_list, group=process_group)
+            reduce_scatter_fp8(output, input_list, group=process_group, fp8_format="e4m3")
         else:
             dist.reduce_scatter(output, input_list, group=process_group)
 
@@ -577,7 +577,12 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
         dim = ctx.dim
         process_group = ctx.process_group
         fp8_communication = ctx.fp8_communication
-        return _gather(grad_output, dim, process_group, fp8_communication=fp8_communication), None, None, None
+        return (
+            _gather(grad_output, dim, process_group, fp8_communication=fp8_communication, fp8_format="e5m2"),
+            None,
+            None,
+            None,
+        )
 
 
 class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
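The same asymmetry shows up inside each autograd.Function: the hunks above quantize the reduce-scattered activations with e4m3 in forward and the gathered gradients with e5m2 in backward. A toy torch.autograd.Function (illustrative only, not code from this patch, and without the actual communication) that mirrors the forward/backward format split:

import torch

class ToyFp8RoundTrip(torch.autograd.Function):
    """Quantize activations with e4m3 on forward and gradients with e5m2
    on backward, mimicking the fp8_format choices in this commit."""

    @staticmethod
    def forward(ctx, input_):
        # Forward tensors favor precision -> e4m3.
        return input_.to(torch.float8_e4m3fn).to(input_.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        # Gradients favor dynamic range -> e5m2.
        return grad_output.to(torch.float8_e5m2).to(grad_output.dtype)

# Usage: x = torch.randn(8, requires_grad=True); ToyFp8RoundTrip.apply(x).sum().backward()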
@@ -618,7 +623,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
             )
 
         else:
-            input_parallel = _gather(input_, dim, process_group, fp8_communication)
+            input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3")
 
             output = torch.matmul(input_parallel, weight)
 
@@ -641,7 +646,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
             bias = bias.view(bias.shape)
 
         if not overlap:
-            input_parallel = _gather(input_, dim, process_group, fp8_communication)
+            input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e5m2")
 
             total_input = input_parallel
             grad_input = grad_output.matmul(weight.T)
@@ -728,8 +733,13 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
         if ctx.grad_scale is not None:
             grad_output = grad_output * ctx.grad_scale
 
         # to_cast.append(grad_output.cpu().detach().numpy())
-        return _gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication), None, None, None, None
+        return (
+            _gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication, fp8_format="e5m2"),
+            None,
+            None,
+            None,
+            None,
+        )
 
 
 class _ReduceForward(torch.autograd.Function):
@@ -743,7 +753,7 @@ class _ReduceForward(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, input_, process_group, fp8_communication=False):
-        return _reduce(input_, process_group, fp8_communication)
+        return _reduce(input_, process_group, fp8_communication, fp8_format="e4m3")
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -768,7 +778,7 @@ class _ReduceBackward(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
         fp8_communication = ctx.fp8_communication
-        return _reduce(grad_output, ctx.process_group, fp8_communication), None, None
+        return _reduce(grad_output, ctx.process_group, fp8_communication, fp8_format="e5m2"), None, None
 
 
 class _GatherForwardSplitBackward(torch.autograd.Function):
@@ -786,7 +796,7 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
         ctx.dim = dim
         ctx.grad_scale = grad_scale
 
-        return _gather(input_, dim, process_group, fp8_communication=fp8_communication)
+        return _gather(input_, dim, process_group, fp8_communication=fp8_communication, fp8_format="e4m3")
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -851,13 +861,13 @@ def hook_parameter_in_backward(input, weight=None, bias=None):
     return HookParameter.apply(input, weight, bias)
 
 
-def _reduce(input_, process_group, fp8_communication=False):
+def _reduce(input_, process_group, fp8_communication=False, fp8_format="e5m2"):
     # skip if only one rank involved
     if dist.get_world_size(process_group) == 1:
         return input_
     else:
         if fp8_communication:
-            all_reduce_fp8(input_, group=process_group)
+            all_reduce_fp8(input_, group=process_group, fp8_format=fp8_format)
         else:
             dist.all_reduce(input_, group=process_group)
         return input_
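For intuition about what the fp8_format argument controls, the sketch below is a hypothetical, naive stand-in for a format-aware fp8 all-reduce. It is not ColossalAI's all_reduce_fp8: it omits the per-tensor scaling a real implementation needs, and it moves the fp8 payload as uint8 bytes via all-gather because NCCL cannot reduce float8 directly.

import torch
import torch.distributed as dist

def naive_all_reduce_fp8(tensor: torch.Tensor, group=None, fp8_format: str = "e5m2") -> None:
    # Pick the wire format: e4m3 for precision, e5m2 for gradient range.
    fp8_dtype = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    world_size = dist.get_world_size(group)

    # Quantize locally, then ship the raw bytes with an all-gather.
    payload = tensor.to(fp8_dtype).view(torch.uint8)
    gathered = [torch.empty_like(payload) for _ in range(world_size)]
    dist.all_gather(gathered, payload, group=group)

    # Dequantize every rank's contribution and accumulate in full precision.
    total = torch.stack([g.view(fp8_dtype).to(tensor.dtype) for g in gathered]).sum(dim=0)
    tensor.copy_(total)

# Run under torchrun with an initialized process group; the call pattern then
# mirrors the patch, e.g. a gradient reduction would pass fp8_format="e5m2".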