@@ -211,43 +211,36 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
             handle.wait()
 
         else:
-            # create new stream for calculate the gradient
-            calculate_stream = torch.cuda.Stream()
-
-            # do all gather in default stream
             input_ = input_.contiguous()
             world_size = dist.get_world_size(process_group)
             tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
-            gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
-
-            # calculate gradient in calculate_stream
-            with torch.cuda.stream(calculate_stream):
-                # calculate
-                grad_input = grad_output.matmul(weight)
-                grad_output = grad_output.contiguous()
-                # Convert the tensor shapes to 2D for execution compatibility
-                if len(grad_output.shape) > 2:
-                    grad_output = grad_output.view(-1, grad_output.shape[-1])
-                grad_bias = grad_output.sum(dim=0) if use_bias else None
-
-                # prepare data
-                input_list = [
-                    item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
-                ]
-                output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
-
-            torch.cuda.current_stream().wait_stream(calculate_stream)
+
+            # do all gather in is async way
+            gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
+            # calculate gradient and prepare data asynchronously with all-gather
+            # calculate
+            grad_input = grad_output.matmul(weight)
+            grad_output = grad_output.contiguous()
+            # Convert the tensor shapes to 2D for execution compatibility
+            if len(grad_output.shape) > 2:
+                grad_output = grad_output.view(-1, grad_output.shape[-1])
+            grad_bias = grad_output.sum(dim=0) if use_bias else None
+            # prepare data
+            input_list = [
+                item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
+            ]
+            output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
+            # wait until all-gather finished
             gather_handle.wait()
 
+            # do reduce-scatter in async way
             reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
-            with torch.cuda.stream(calculate_stream):
-                input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
-                if len(input_parallel.shape) > 2:
-                    input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
-                print(grad_output.shape, input_parallel.shape)
-                grad_weight = grad_output.t().matmul(input_parallel)
-
-            torch.cuda.current_stream().wait_stream(calculate_stream)
+            input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
+            # calculate gradient
+            if len(input_parallel.shape) > 2:
+                input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
+            grad_weight = grad_output.t().matmul(input_parallel)
+            # wait until reduce-scatter finished
             reducescatter_handle.wait()
 
         return output, grad_weight, grad_bias, None, None, None, None
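
For context, a minimal sketch of the pattern the new branch relies on follows: launch the collective with async_op=True, do the math that does not depend on its result, then wait on the handle only when that result is actually needed. The helper name _overlapped_gather_backward and its signature are hypothetical (not the ColossalAI API); the sketch assumes an already-initialized process group whose backend supports all_gather and reduce_scatter (e.g. NCCL), that input_ is this rank's shard along dim, and that grad_output covers the full (already gathered) activation.

    import torch
    import torch.distributed as dist


    def _overlapped_gather_backward(grad_output, input_, weight, process_group, dim, use_bias):
        """Sketch only: overlap async collectives with the local gradient math."""
        input_ = input_.contiguous()
        world_size = dist.get_world_size(process_group)
        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]

        # 1) Kick off the all-gather of the input shards asynchronously.
        gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)

        # 2) While it is in flight, compute everything that does not need the gathered input.
        grad_input = grad_output.matmul(weight)            # (..., in_features)
        grad_output = grad_output.contiguous()
        if grad_output.dim() > 2:                          # flatten to 2D for the matmuls below
            grad_output = grad_output.view(-1, grad_output.shape[-1])
        grad_bias = grad_output.sum(dim=0) if use_bias else None

        input_list = [chunk.contiguous() for chunk in torch.chunk(grad_input, world_size, dim=dim)]
        output = torch.empty_like(input_)                  # this rank's grad-input shard

        # 3) The gathered input is needed next: wait, then start the reduce-scatter.
        gather_handle.wait()
        rs_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)

        # 4) Overlap the weight gradient with the reduce-scatter.
        input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
        if input_parallel.dim() > 2:
            input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
        grad_weight = grad_output.t().matmul(input_parallel)

        rs_handle.wait()
        return output, grad_weight, grad_bias

Compared with the removed code, the intent appears to be that the extra calculate_stream and the wait_stream() synchronizations are unnecessary: the async handles already let the collectives (which NCCL runs on its own internal stream) proceed while the matmuls queued afterwards execute, so waiting on the handle right before its dependent op is enough. The removed print(...) debug line also goes away.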