mirror of https://github.com/hpcaitech/ColossalAI
fix
parent 53823118f2
commit f7acfa1bd5
@@ -767,11 +767,14 @@ class _ReduceForward(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, process_group, fp8_communication=False):
+    def forward(ctx, input_, process_group, grad_scale=None, fp8_communication=False):
+        ctx.grad_scale = grad_scale
         return _reduce(input_, process_group, fp8_communication, fp8_format="e4m3")

     @staticmethod
     def backward(ctx, grad_output):
-        return grad_output, None, None
+        if ctx.grad_scale is not None:
+            grad_output = grad_output * ctx.grad_scale
+        return grad_output, None, None, None
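The pattern in this hunk is easier to see outside the ColossalAI sources. The sketch below is a minimal, self-contained stand-in, not the library's code: a toy ScaledIdentity plays the role of _ReduceForward, with a plain clone replacing the _reduce collective so it runs on a single process. It shows why adding grad_scale to forward also means backward has to rescale the incoming gradient and return one value per forward argument.

import torch


class ScaledIdentity(torch.autograd.Function):
    # Toy stand-in for _ReduceForward: the all-reduce is replaced by a clone,
    # so no distributed setup is needed to run this example.

    @staticmethod
    def forward(ctx, input_, grad_scale=None, fp8_communication=False):
        # Non-tensor arguments are stashed on ctx for use in backward.
        ctx.grad_scale = grad_scale
        return input_.clone()

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.grad_scale is not None:
            grad_output = grad_output * ctx.grad_scale
        # backward must return one value per forward() argument (after ctx);
        # non-tensor arguments get None, so inserting grad_scale into forward()
        # also adds a None to this tuple.
        return grad_output, None, None


x = torch.ones(3, requires_grad=True)
ScaledIdentity.apply(x, 0.5).sum().backward()
print(x.grad)  # tensor([0.5000, 0.5000, 0.5000])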
@@ -555,7 +555,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
         else:
             if self.seq_parallel_mode is None:
                 output_parallel = torch.matmul(input_, self.weight)
-                output = reduce_forward(output_parallel, self.process_group, self.fp8_communication)
+                output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
             elif self.seq_parallel_mode == "split_gather":
                 output_parallel = torch.matmul(input_, self.weight)
                 output = reducescatter_forward_gather_backward(
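The call-site change above follows directly from the new forward signature: once grad_scale occupies the third positional slot, passing self.fp8_communication positionally would silently bind it to grad_scale instead of enabling fp8 all-reduce. The sketch below illustrates that pitfall; it assumes reduce_forward is a thin wrapper that mirrors _ReduceForward.forward's argument order, and the stub only echoes its arguments rather than calling the real function.

def reduce_forward(input_, process_group, grad_scale=None, fp8_communication=False):
    # Stand-in stub: the real wrapper would dispatch to _ReduceForward.apply(...).
    return input_, grad_scale, fp8_communication


# Old positional call: the flag lands in grad_scale, fp8 stays disabled.
_, scale, use_fp8 = reduce_forward("x", "tp_group", True)
print(scale, use_fp8)  # True False

# Keyword call from the fix: grad_scale stays None, fp8 is enabled.
_, scale, use_fp8 = reduce_forward("x", "tp_group", fp8_communication=True)
print(scale, use_fp8)  # None True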