diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index f24191c16..ac422a4da 100644 --- a/colossalai/moe/_operation.py +++ b/colossalai/moe/_operation.py @@ -440,114 +440,3 @@ def all_to_all_uneven( inputs.requires_grad ), "Input must require grad to assure that backward is executed, otherwise it might hang the program." return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap) - - -# =========================================================== -# This code section was modified from -# https://github.com/microsoft/DeepSpeed/blob/3d347276ce80e1a29e777c839d1d7fabe8e5f034/deepspeed/moe/mappings.py - -# Copyright (c) Microsoft Corporation. -# SPDX-License-Identifier: Apache-2.0 - -# DeepSpeed Team - -# The file has been adapted from the following Megatron-LM file: -# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/mpu/mappings.py -# Git commit hash: 9dc3c42a84aa656f583703cf8b6b4f79f712b796 -# We retain the following copyright from the original files: - -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# TODO: used when non-moe are tp but moe are not - - -def _gather_tokens(input_, dim: int, tp_group: ProcessGroup): - """Gather tensors and concatenate them along a dimension""" - - input_ = input_.contiguous() - # Size and dimension. - rank = tp_group.rank() - - tensor_list = [torch.empty_like(input_) for _ in range(tp_group.size())] - tensor_list[rank] = input_ - dist.all_gather(tensor_list, input_, group=tp_group) - - # Note: torch.cat already creates a contiguous tensor. - output = torch.cat(tensor_list, dim=dim).contiguous() - - return output - - -def _drop_tokens(input_, dim: int, tp_group: ProcessGroup): - """Divide a tensor among the tensor parallel ranks""" - - total_chunks = tp_group.size() - this_chunk = tp_group.rank() - assert ( - input_.shape[dim] % total_chunks == 0 - ), f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})" - chunk_size = input_.shape[dim] // total_chunks - - return torch.narrow(input_, dim, this_chunk * chunk_size, chunk_size) - - -class _GatherTokens(torch.autograd.Function): - """All gather tokens among the tensor parallel ranks""" - - @staticmethod - def forward(ctx, input_: torch.Tensor, dim: int, tp_group: ProcessGroup) -> torch.Tensor: - ctx.dim = dim - ctx.tp_group = tp_group - return _gather_tokens(input_, dim, tp_group) - - @staticmethod - def backward(ctx, grad_output): - return _drop_tokens(grad_output, ctx.dim, ctx.tp_group), None, None - - -class _DropTokens(torch.autograd.Function): - "Divide tokens equally among the tensor parallel ranks" - - @staticmethod - def forward(ctx, input_: torch.Tensor, dim: int, tp_group: ProcessGroup) -> torch.Tensor: - ctx.dim = dim - ctx.tp_group = tp_group - return _drop_tokens(input_, dim, tp_group) - - @staticmethod - def backward(ctx, input_: torch.Tensor) -> Tuple[torch.Tensor, None]: - return _gather_tokens(input_, ctx.dim, ctx.tp_group), None, None - - -def gather_tokens(input_, dim: int, tp_group: ProcessGroup): - if tp_group.size() == 1: - # no tensor parallelism for non-experts - return input_ - assert ( - input_.requires_grad - ), "Input must require grad to assure that backward is executed, otherwise it might hang the program." - return _GatherTokens.apply(input_, dim, tp_group) - - -def drop_tokens(input_, dim: int, tp_group: ProcessGroup): - if tp_group.size() == 1: - # no tensor parallelism for non-experts - return input_ - assert ( - input_.requires_grad - ), "Input must require grad to assure that backward is executed, otherwise it might hang the program." - return _DropTokens.apply(input_, dim, tp_group) - - -# ===========================================================