pull/6023/head
wangbluo 2024-08-20 05:07:58 +00:00
parent 53823118f2
commit f7acfa1bd5
2 changed files with 5 additions and 2 deletions

View File

@ -767,11 +767,14 @@ class _ReduceForward(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, process_group, fp8_communication=False):
def forward(ctx, input_, process_group, grad_scale=None, fp8_communication=False):
ctx.grad_scale = grad_scale
return _reduce(input_, process_group, fp8_communication, fp8_format="e4m3")
@staticmethod
def backward(ctx, grad_output):
if ctx.grad_scale is not None:
grad_output = grad_output * ctx.grad_scale
return grad_output, None, None

View File

@ -555,7 +555,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
else:
if self.seq_parallel_mode is None:
output_parallel = torch.matmul(input_, self.weight)
output = reduce_forward(output_parallel, self.process_group, self.fp8_communication)
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
elif self.seq_parallel_mode == "split_gather":
output_parallel = torch.matmul(input_, self.weight)
output = reducescatter_forward_gather_backward(