mirror of https://github.com/InternLM/InternLM
refactor code for all2all to support output_splits
parent fe0c342f9d
commit dcfdab6aaf
@@ -17,13 +17,23 @@ class _AllToAll(torch.autograd.Function):
         # TODO: replace with DS process group
         group: torch.distributed.ProcessGroup,
         inputs: Tensor,
+        input_splits=None,
+        output_splits=None,
     ) -> Tensor:  # type: ignore
         ctx.group = group
+        ctx.input_splits = input_splits
+        ctx.output_splits = output_splits
         inputs = inputs.contiguous()
-        output = torch.empty_like(inputs)
-        dist.all_to_all_single(output, inputs, group=group)
+        output = (
+            torch.empty_like(inputs)
+            if output_splits is None
+            else inputs.new_empty(size=[sum(output_splits)] + list(inputs.size()[1:]))
+        )
+        dist.all_to_all_single(
+            output, inputs, output_split_sizes=output_splits, input_split_sizes=input_splits, group=group
+        )
         return output
 
     @staticmethod
     def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
-        return (None, _AllToAll.apply(ctx.group, *grad_output))
+        return (None, _AllToAll.apply(ctx.group, *grad_output, ctx.output_splits, ctx.input_splits), None, None)