@@ -81,7 +81,6 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
         handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
         # Delay the start of weight gradient computation shortly (3us) to have
         # all-reduce scheduled first and have GPU resources allocated
-        _ = torch.empty(1, device=grad_output.device) + 1
         grad_weight = grad_output.t().matmul(total_input)
         grad_bias = grad_output.sum(dim=0) if use_bias else None
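
For context, here is a minimal, self-contained sketch of the overlap pattern this hunk touches: the grad_input all-reduce is launched asynchronously so the weight-gradient GEMM can execute while the communication is in flight, and the deleted line was only a tiny dummy kernel meant to give the all-reduce a scheduling head start. The standalone backward_with_overlap helper and the parallel_group argument below are illustrative assumptions, not the actual autograd.Function API used in the diff.

# Sketch of async all-reduce / GEMM overlap in a linear layer's backward pass.
# Assumes torch.distributed is already initialized; names are illustrative.
import torch
import torch.distributed as dist


def backward_with_overlap(grad_output, total_input, weight, use_bias, parallel_group):
    """Compute linear-layer gradients while the grad_input all-reduce runs."""
    # grad_output: (N, out_features), weight: (out_features, in_features),
    # total_input: (N, in_features)
    grad_input = grad_output.matmul(weight)

    # Launch the all-reduce asynchronously so the GEMMs below can run on the
    # GPU while the communication is in flight.
    handle = dist.all_reduce(grad_input, group=parallel_group, async_op=True)

    # The line removed in the diff above: a dummy kernel whose only purpose is
    # to delay the weight-gradient GEMM by a few microseconds so the
    # all-reduce is scheduled first and gets GPU resources allocated.
    # _ = torch.empty(1, device=grad_output.device) + 1

    # Overlapped compute: these run concurrently with the all-reduce.
    grad_weight = grad_output.t().matmul(total_input)
    grad_bias = grad_output.sum(dim=0) if use_bias else None

    # Block until the communication finishes before grad_input is consumed.
    handle.wait()
    return grad_input, grad_weight, grad_bias

The design point of the removed line is purely about launch ordering: without it, the GEMM kernels can be enqueued before the all-reduce and monopolize the SMs, shrinking the communication/computation overlap; with it, the all-reduce is scheduled first at the cost of one trivial kernel launch.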