@@ -473,16 +473,17 @@ class _GatherForwardSplitBackward(torch.autograd.Function):

     @staticmethod
     def backward(ctx, grad_output):
         return _split(grad_output, ctx.dim, ctx.process_group), None, None


 class HookParameter(torch.autograd.Function):
     """In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm"""

     @staticmethod
     def forward(ctx, input, weight, bias):
         ctx.save_for_backward(weight, bias)
         output = input
         return output

     @staticmethod
     def backward(ctx, grad_output):
         weight, bias = ctx.saved_tensors
@@ -491,13 +492,12 @@ class HookParameter(torch.autograd.Function):
         if bias is not None:
             bias = bias.view(bias.shape)
         return grad_output, None, None


 def hook_paramter_in_backward(input, weight=None, bias=None):
     return HookParameter.apply(input, weight, bias)


 def _reduce(input_, process_group):
     # skip if only one rank involved
     if dist.get_world_size(process_group) == 1:
@@ -522,7 +522,7 @@ def _split(input_, dim=-1, process_group=None):

     tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
     rank = dist.get_rank(process_group)
-    output = tensor_list[rank].contiguous()
+    output = tensor_list[rank].clone().contiguous()

     return output
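The one functional change in the last hunk is the added `.clone()`. Below is a minimal sketch, not part of the patch, of the aliasing behaviour it guards against, assuming standard `torch.split` semantics: the chunks returned by `torch.split` are views of the input, and when the selected chunk is already contiguous, `.contiguous()` returns the view unchanged, so each rank's "shard" would silently keep the whole pre-split buffer's storage alive. `.clone()` forces a private allocation that lets the large buffer be freed.

```python
import torch

# Sketch of the memory behaviour the added .clone() addresses.
# torch.split returns views of the full tensor; when the selected chunk is
# already contiguous, .contiguous() is a no-op and the shard keeps aliasing
# the entire buffer. (untyped_storage requires PyTorch >= 2.0.)
full = torch.randn(4, 1024)

shard_view = torch.split(full, 2, dim=0)[0].contiguous()
print(shard_view.data_ptr() == full.data_ptr())   # True: still aliases `full`
print(shard_view.untyped_storage().size())        # 16384 bytes, the whole buffer

shard_copy = torch.split(full, 2, dim=0)[0].clone().contiguous()
print(shard_copy.data_ptr() == full.data_ptr())   # False: owns its own memory
print(shard_copy.untyped_storage().size())        # 8192 bytes, just the shard
```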
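For context on the `HookParameter` helper shown above, here is a hedged usage sketch; the import path and the wrapper class are illustrative assumptions, not taken from the patch. The hook is a value-level no-op on its first argument, but its backward re-views `weight` and `bias`, which is what a `__torch_function__`-based wrapper such as Gemini intercepts to learn that both parameters took part in a computation that would otherwise bypass its tracking (e.g. a fused layer-norm kernel).

```python
import torch
import torch.nn.functional as F

# Illustrative import path; the helper is assumed to live next to the patched ops.
from colossalai.shardformer.layer._operation import hook_paramter_in_backward


class LayerNormWithHookSketch(torch.nn.Module):
    """Toy stand-in for FusedLayerNorm: the hook leaves `out` unchanged, but its
    backward views weight and bias so a '__torch_function__' wrapper (e.g. Gemini)
    can observe that both parameters were used."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(dim))
        self.bias = torch.nn.Parameter(torch.zeros(dim))
        self.eps = eps

    def forward(self, x):
        out = F.layer_norm(x, (x.shape[-1],), self.weight, self.bias, self.eps)
        # Value-identity; registers weight/bias in the autograd graph via the hook.
        return hook_paramter_in_backward(out, self.weight, self.bias)
```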