mirror of https://github.com/hpcaitech/ColossalAI

fix shardformer fp8 communication training degradation

parent 6a20f07b80
commit 5b969fd831
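Every hunk below applies the same fix: the FP8 communication helpers now take an explicit fp8_format argument, pinned to "e4m3" for forward-path collectives (activations, where extra mantissa bits matter) and "e5m2" for backward-path collectives (gradients, where extra exponent range matters). As a minimal sketch of the mechanism behind that choice, assuming PyTorch's native FP8 dtypes, the quantize/dequantize pair below shows how a format string would select the dtype; the helper names are illustrative, not the repo's actual FP8 utilities.

import torch

def fp8_cast_sketch(t: torch.Tensor, fp8_format: str = "e4m3"):
    # Illustrative only: map the format string to a native FP8 dtype, as the
    # patched helpers presumably do. e4m3 = 4 exponent / 3 mantissa bits,
    # e5m2 = 5 exponent / 2 mantissa bits.
    fp8_dtype = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    fp8_max = torch.finfo(fp8_dtype).max
    # Per-tensor scale so the largest magnitude lands at the FP8 maximum.
    scale = t.abs().max().clamp(min=1e-12) / fp8_max
    return (t / scale).to(fp8_dtype), scale

def fp8_uncast_sketch(t_fp8: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Dequantize back to working precision once the collective has finished.
    return t_fp8.to(dtype) * scale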
@@ -95,7 +95,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
         total_input = total_input.view(-1, total_input.shape[-1])
 
         if ctx.async_grad_allreduce and fp8_communication:
-            _reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication)
+            _reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication, fp8_format="e5m2")
         elif ctx.async_grad_allreduce:
             # Asynchronous all-reduce
             handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
@@ -566,7 +566,7 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
         input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)]
         output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
         if fp8_communication:
-            reduce_scatter_fp8(output, input_list, group=process_group)
+            reduce_scatter_fp8(output, input_list, group=process_group, fp8_format="e4m3")
         else:
             dist.reduce_scatter(output, input_list, group=process_group)
 
@@ -577,7 +577,12 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
         dim = ctx.dim
         process_group = ctx.process_group
         fp8_communication = ctx.fp8_communication
-        return _gather(grad_output, dim, process_group, fp8_communication=fp8_communication), None, None, None
+        return (
+            _gather(grad_output, dim, process_group, fp8_communication=fp8_communication, fp8_format="e5m2"),
+            None,
+            None,
+            None,
+        )
 
 
 class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
@@ -618,7 +623,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
             )
 
         else:
-            input_parallel = _gather(input_, dim, process_group, fp8_communication)
+            input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3")
 
         output = torch.matmul(input_parallel, weight)
 
@@ -641,7 +646,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
             bias = bias.view(bias.shape)
 
         if not overlap:
-            input_parallel = _gather(input_, dim, process_group, fp8_communication)
+            input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e5m2")
 
             total_input = input_parallel
             grad_input = grad_output.matmul(weight.T)
@@ -728,8 +733,13 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
         if ctx.grad_scale is not None:
             grad_output = grad_output * ctx.grad_scale
 
-        # to_cast.append(grad_output.cpu().detach().numpy())
-        return _gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication), None, None, None, None
+        return (
+            _gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication, fp8_format="e5m2"),
+            None,
+            None,
+            None,
+            None,
+        )
 
 
 class _ReduceForward(torch.autograd.Function):
@@ -743,7 +753,7 @@ class _ReduceForward(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, input_, process_group, fp8_communication=False):
-        return _reduce(input_, process_group, fp8_communication)
+        return _reduce(input_, process_group, fp8_communication, fp8_format="e4m3")
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -768,7 +778,7 @@ class _ReduceBackward(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
         fp8_communication = ctx.fp8_communication
-        return _reduce(grad_output, ctx.process_group, fp8_communication), None, None
+        return _reduce(grad_output, ctx.process_group, fp8_communication, fp8_format="e5m2"), None, None
 
 
 class _GatherForwardSplitBackward(torch.autograd.Function):
@@ -786,7 +796,7 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
         ctx.dim = dim
         ctx.grad_scale = grad_scale
 
-        return _gather(input_, dim, process_group, fp8_communication=fp8_communication)
+        return _gather(input_, dim, process_group, fp8_communication=fp8_communication, fp8_format="e4m3")
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -851,13 +861,13 @@ def hook_parameter_in_backward(input, weight=None, bias=None):
     return HookParameter.apply(input, weight, bias)
 
 
-def _reduce(input_, process_group, fp8_communication=False):
+def _reduce(input_, process_group, fp8_communication=False, fp8_format="e5m2"):
     # skip if only one rank involved
     if dist.get_world_size(process_group) == 1:
         return input_
     else:
         if fp8_communication:
-            all_reduce_fp8(input_, group=process_group)
+            all_reduce_fp8(input_, group=process_group, fp8_format=fp8_format)
         else:
             dist.all_reduce(input_, group=process_group)
         return input_
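For completeness, a toy call-site sketch of the patched _reduce, assuming an already-initialized process group; only _reduce and its new fp8_format keyword come from this diff, the scaffolding around them is illustrative.

import torch
import torch.distributed as dist

def demo_reduce(x: torch.Tensor) -> None:
    # Assumes dist.init_process_group(...) has already run on every rank.
    group = dist.group.WORLD
    # Forward-style reduction: e4m3, as _ReduceForward.forward now requests.
    _reduce(x.clone(), group, fp8_communication=True, fp8_format="e4m3")
    # Backward-style reduction: e5m2, as _ReduceBackward.backward now requests.
    _reduce(x.clone(), group, fp8_communication=True, fp8_format="e5m2")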