fix the merge

pull/6023/head
wangbluo 2024-08-19 02:26:52 +00:00
parent 4cf79fa275
commit 02636c5bef
1 changed files with 3 additions and 3 deletions

View File

@ -767,12 +767,12 @@ class _ReduceForward(torch.autograd.Function):
""" """
@staticmethod @staticmethod
def forward(ctx, input_, process_group): def forward(ctx, input_, process_group, fp8_communication=False):
return _reduce(input_, process_group) return _reduce(input_, process_group, fp8_communication, fp8_format="e4m3")
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
return grad_output, None return grad_output, None, None
class _ReduceBackward(torch.autograd.Function): class _ReduceBackward(torch.autograd.Function):