refactor code for all2all to support output_splits

pull/567/head
Wenwen Qu 2024-01-09 15:37:26 +08:00
parent fe0c342f9d
commit dcfdab6aaf
1 changed file with 13 additions and 3 deletions

@@ -17,13 +17,23 @@ class _AllToAll(torch.autograd.Function):
         # TODO: replace with DS process group
         group: torch.distributed.ProcessGroup,
         inputs: Tensor,
+        input_splits=None,
+        output_splits=None,
     ) -> Tensor:  # type: ignore
         ctx.group = group
+        ctx.input_splits = input_splits
+        ctx.output_splits = output_splits
         inputs = inputs.contiguous()
-        output = torch.empty_like(inputs)
-        dist.all_to_all_single(output, inputs, group=group)
+        output = (
+            torch.empty_like(inputs)
+            if output_splits is None
+            else inputs.new_empty(size=[sum(output_splits)] + list(inputs.size()[1:]))
+        )
+        dist.all_to_all_single(
+            output, inputs, output_split_sizes=output_splits, input_split_sizes=input_splits, group=group
+        )
         return output
 
     @staticmethod
     def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
-        return (None, _AllToAll.apply(ctx.group, *grad_output))
+        return (None, _AllToAll.apply(ctx.group, *grad_output, ctx.output_splits, ctx.input_splits), None, None)
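
For context (not part of the commit), a minimal usage sketch of the extended _AllToAll in a variable-size all-to-all, e.g. for MoE token dispatch: each rank first exchanges its per-rank send counts so it can pass a matching output_splits list, then routes the actual rows. The helper name variable_all_to_all and the count-exchange step are assumptions for illustration; _AllToAll is the autograd function modified above. Swapping the splits in backward follows from the gradient of an all-to-all being an all-to-all along the reverse routing.

    import torch
    import torch.distributed as dist

    def variable_all_to_all(inputs, input_splits, group=None):
        # Hypothetical helper, not part of the commit.
        # input_splits[i] = number of rows this rank sends to rank i.

        # 1) Exchange per-rank send counts so every rank knows how many rows it will receive.
        send_counts = torch.tensor(input_splits, dtype=torch.long, device=inputs.device)
        recv_counts = torch.empty_like(send_counts)
        dist.all_to_all_single(recv_counts, send_counts, group=group)
        output_splits = recv_counts.tolist()

        # 2) Route the rows themselves. Forward uses (input_splits, output_splits);
        #    backward (per the diff) swaps them so gradients follow the reverse routing.
        return _AllToAll.apply(group, inputs, input_splits, output_splits)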