From dcfdab6aaf48dfca363960897be472ca7d520dae Mon Sep 17 00:00:00 2001
From: Wenwen Qu
Date: Tue, 9 Jan 2024 15:37:26 +0800
Subject: [PATCH] refactor code for all2all to support output_splits

---
 internlm/moe/utils.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/internlm/moe/utils.py b/internlm/moe/utils.py
index cdb8aed..834e808 100644
--- a/internlm/moe/utils.py
+++ b/internlm/moe/utils.py
@@ -17,13 +17,23 @@ class _AllToAll(torch.autograd.Function):
         # TODO: replace with DS process group
         group: torch.distributed.ProcessGroup,
         inputs: Tensor,
+        input_splits=None,
+        output_splits=None,
     ) -> Tensor:  # type: ignore
         ctx.group = group
+        ctx.input_splits = input_splits
+        ctx.output_splits = output_splits
         inputs = inputs.contiguous()
-        output = torch.empty_like(inputs)
-        dist.all_to_all_single(output, inputs, group=group)
+        output = (
+            torch.empty_like(inputs)
+            if output_splits is None
+            else inputs.new_empty(size=[sum(output_splits)] + list(inputs.size()[1:]))
+        )
+        dist.all_to_all_single(
+            output, inputs, output_split_sizes=output_splits, input_split_sizes=input_splits, group=group
+        )
         return output
 
     @staticmethod
     def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
-        return (None, _AllToAll.apply(ctx.group, *grad_output))
+        return (None, _AllToAll.apply(ctx.group, *grad_output, ctx.output_splits, ctx.input_splits), None, None)
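
Usage note (not part of the patch): a minimal sketch of how the extended
_AllToAll.apply could be driven with uneven splits. The wrapper function, its
name, and the comments about exchanging split sizes are illustrative
assumptions; only the (group, inputs, input_splits, output_splits) argument
order comes from the patched forward() above.

    import torch

    from internlm.moe.utils import _AllToAll  # module path taken from the diff above

    def uneven_all_to_all(tokens: torch.Tensor, input_splits, output_splits, group=None):
        # tokens: [sum(input_splits), hidden] rows laid out in destination-rank order.
        # input_splits[i]  = number of rows this rank sends to rank i.
        # output_splits[i] = number of rows this rank receives from rank i
        #                    (typically exchanged beforehand, e.g. via a small
        #                    all_to_all on the per-rank split counts).
        # Passing None for both splits falls back to the original even all-to-all.
        return _AllToAll.apply(group, tokens, input_splits, output_splits)

In backward, the patch swaps ctx.output_splits and ctx.input_splits when
re-applying the all-to-all, so gradients travel the reverse of the forward
routing and land back with the original input split layout.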