mirror of https://github.com/hpcaitech/ColossalAI
[MoE/ZeRO] Moe refactor with zero refactor (#5821)
* [moe] removed openmoe-coupled code and rectify mixstral code (#5471) * [Feauture] MoE refractor; Intergration with Mixtral (#5682) * cherry pick from refractor-moe branch * tests passed * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * support ep + zero --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * add mixtral auto policy & move pipeline forward code to modeling folder * [moe refactor] modify kernel test without Route Class * [moe refactor] add moe tensor test path environment variable to github workflow * fix typos * fix moe test bug due to the code rebase * [moe refactor] fix moe zero test, and little bug in low level zero * fix typo * add moe tensor path to github workflow * remove some useless code * fix typo & unify global variable XX_AXIS logic without using -1 * fix typo & prettifier the code * remove print code & support zero 2 test * remove useless code * reanme function * fix typo * fix typo * Further improve the test code * remove print code * [moe refactor] change test model from fake moe model to mixtral moe layer and remove useless test * [moe refactor] skip some unit test which will be refactored later * [moe refactor] fix unit import error * [moe refactor] fix circular import issues * [moe refactor] remove debug code * [moe refactor] update github workflow * [moe/zero] refactor low level optimizer (#5767) * [zero] refactor low level optimizer * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Feature] MoE refactor with newest version of ZeRO (#5801) * [zero] remove redundant members in BucketStore (#5802) * [zero] align api with previous version * [Moe/Zero] Update MoeHybridParallelPlugin with refactored ZeRO and Fix Zero bug (#5819) * [moe refactor] update unit test with the refactored ZeRO and remove useless test * move moe checkpoint to checkpoint folder and exchange global axis to class member * update moe hybrid parallel plugin with newest version of zero & fix zero working/master params bug * fix zero unit test * Add an assertion to prevent users from using it incorrectly * [hotfix]Solve the compatibility issue of zero refactor (#5823) * [moe refactor] update unit test with the refactored ZeRO and remove useless test * move moe checkpoint to checkpoint folder and exchange global axis to class member * update moe hybrid parallel plugin with newest version of zero & fix zero working/master params bug * fix zero unit test * Add an assertion to prevent users from using it incorrectly * Modify function parameter names to resolve compatibility issues * [zero] fix missing hook removal (#5824) * [MoE] Resolve .github conflict (#5829) * [Fix/Example] Fix Llama Inference Loading Data Type (#5763) * [fix/example] fix llama inference loading dtype * revise loading dtype of benchmark llama3 * [release] update version (#5752) * [release] update version * [devops] update compatibility test * [devops] update compatibility test * [devops] update compatibility test * [devops] update compatibility test * [test] fix ddp plugin test * [test] fix gptj and rpc test * [devops] fix cuda ext compatibility * [inference] fix flash decoding test * [inference] fix flash decoding test * fix (#5765) * [test] Fix/fix testcase (#5770) * [fix] branch for fix testcase; * [fix] fix test_analyzer & 
test_auto_parallel; * [fix] remove local change about moe; * [fix] rm local change moe; * [Hotfix] Add missing init file in inference.executor (#5774) * [CI/tests] simplify some test case to reduce testing time (#5755) * [ci/tests] simplify some test case to reduce testing time * [ci/tests] continue to remove test case to reduce ci time cost * restore some test config * [ci/tests] continue to reduce ci time cost * [misc] update dockerfile (#5776) * [misc] update dockerfile * [misc] update dockerfile * [devops] fix docker ci (#5780) * [Inference]Add Streaming LLM (#5745) * Add Streaming LLM * add some parameters to llama_generation.py * verify streamingllm config * add test_streamingllm.py * modified according to the opinions of review * add Citation * change _block_tables tolist * [hotfix] fix llama flash attention forward (#5777) * [misc] Accelerate CI for zero and dist optim (#5758) * remove fp16 from lamb * remove d2h copy in checking states --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [Test/CI] remove test cases to reduce CI duration (#5753) * [test] smaller gpt2 test case * [test] reduce test cases: tests/test_zero/test_gemini/test_zeroddp_state_dict.py * [test] reduce test cases: tests/test_zero/test_gemini/test_grad_accum.py * [test] reduce test cases tests/test_zero/test_gemini/test_optim.py * Revert "[test] smaller gpt2 test case" Some tests might depend on the size of model (num of chunks) This reverts commit df705a5210
. * [test] reduce test cases: tests/test_checkpoint_io/test_gemini_checkpoint_io.py * [CI] smaller test model for two mwo the two modifid cases * [CI] hardcode gpt model for tests/test_zero/test_gemini/test_search.py since we need a fixed answer there * [hotfix] fix testcase in test_fx/test_tracer (#5779) * [fix] branch for fix testcase; * [fix] fix test_analyzer & test_auto_parallel; * [fix] remove local change about moe; * [fix] rm local change moe; * [fix] fix test_deepfm_model & test_dlrf_model; * [fix] fix test_hf_albert & test_hf_gpt; * [gemini] optimize reduce scatter d2h copy (#5760) * [gemini] optimize reduce scatter d2h copy * [fix] fix missing reduce variable * [refactor] remove legacy async reduce scatter code * [gemini] missing sync * Revert "[refactor] remove legacy async reduce scatter code" This reverts commit 58ad76d466
. * [gemini] further optimize with async all reduce * [fix] pass flag from manager to chunk * Allow building cuda extension without a device. (#5535) Added FORCE_CUDA environment variable support, to enable building extensions where a GPU device is not present but cuda libraries are. * [misc] fix dist logger (#5782) * [install]fix setup (#5786) * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [misc] update requirements (#5787) * [shardformer] fix import (#5788) * upgrade colossal-chat support tp_group>1, add sp for sft * upgrade ppo dpo rm script * run pre-commit * moupdate ci tests, st ci test cases passed, tp failed in generation for ppo, sp is buggy * fix training script * fix ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix transformers version * remove duplicated test * fix datasets version * remove models that require huggingface auth from ci * remove local data path * update ci * remove baichuan from template test due to transformer version conflict * merge * Refactor modeling by adding attention backend Signed-off-by: char-1ee <xingjianli59@gmail.com> * Fix tests and naming Signed-off-by: char-1ee <xingjianli59@gmail.com> * Pass inference model shard configs for module init Signed-off-by: char-1ee <xingjianli59@gmail.com> * Clean up Signed-off-by: char-1ee <xingjianli59@gmail.com> * replace the customized dataloader setup with the build-in one * replace the customized dataloader setup with the build-in one * Remove flash attention backend Signed-off-by: char-1ee <xingjianli59@gmail.com> * fix readme * Fix test import Signed-off-by: char-1ee <xingjianli59@gmail.com> * update sft trainning script * [Inference]refactor baichuan (#5791) * refactor baichuan * remove unused code and add TODO for lazyinit * [test] fix chatglm test kit (#5793) * [shardformer] fix modeling of bloom and falcon (#5796) * [test] fix qwen2 pytest distLarge (#5797) * [Inference] Fix flash-attn import and add model test (#5794) * Fix torch int32 dtype Signed-off-by: char-1ee <xingjianli59@gmail.com> * Fix flash-attn import Signed-off-by: char-1ee <xingjianli59@gmail.com> * Add generalized model test Signed-off-by: char-1ee <xingjianli59@gmail.com> * Remove exposed path to model Signed-off-by: char-1ee <xingjianli59@gmail.com> * Add default value for use_flash_attn Signed-off-by: char-1ee <xingjianli59@gmail.com> * Rename model test Signed-off-by: char-1ee <xingjianli59@gmail.com> --------- Signed-off-by: char-1ee <xingjianli59@gmail.com> * [Gemini] Use async stream to prefetch and h2d data moving (#5781) * use async stream to prefetch and h2d data moving * Remove redundant code * [gemini] quick fix on possible async operation (#5803) * [gemini] quick fix on possible async operation * [gemini] quick fix on possible async operation * [shardformer] upgrade transformers to 4.39.3 (#5815) * [shardformer]upgrade transformers for gpt2/gptj/whisper (#5807) * [shardformer] fix modeling of gpt2 and gptj * [shardformer] fix whisper modeling * [misc] update requirements --------- Co-authored-by: ver217 <lhx0217@gmail.com> * [shardformer]upgrade transformers for mistral (#5808) * upgrade transformers for mistral * fix * fix * [shardformer]upgrade transformers for llama (#5809) * update transformers fix * fix * fix * [inference] upgrade transformers (#5810) * update transformers fix * fix * fix * fix * fix 
* [gemini] update transformers for gemini (#5814) --------- Co-authored-by: ver217 <lhx0217@gmail.com> * Support 4d parallel + flash attention (#5789) * support tp + sp + pp * remove comments --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> --------- Signed-off-by: char-1ee <xingjianli59@gmail.com> Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: flybird11111 <1829166702@qq.com> Co-authored-by: duanjunwen <935724073@qq.com> Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu> Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: botbw <wang1570@e.ntu.edu.sg> Co-authored-by: Charles Coulombe <ccoulombe@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YeAnbang <anbangy2@outlook.com> Co-authored-by: char-1ee <xingjianli59@gmail.com> Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Co-authored-by: Guangyao Zhang <xjtu521@qq.com> * [zero] fix hook bug * [zero] add low level optimizer back (#5839) * [zero] fix param & refactor * [zero] add back original low level opt * [zero] remove moe related * [zero] pass zero tests * [zero] refactor * [chore] add del func back * [zero] comments and naming (#5840) * [zero] modify api (#5843) * [zero] modify api * [test] remove _grad_store access in tests * [test] fix (#5857) * [CI] skip openmoe CI check * [CI] fox pre-commit * [zero] remove redundant memebr init (#5862) * [misc] remove useless code, modify the pg mesh implementation * [misc] remove useless code, modify the pg mesh implementation * [misc] use tempfile * resolve conflict with main branch * [misc] use tempfile in test_moe_checkpoint.py * [misc] remove useless code, add assertion about sequence parallel, move logger into function * [misc] remove useless code --------- Signed-off-by: char-1ee <xingjianli59@gmail.com> Co-authored-by: Frank Lee <somerlee.9@gmail.com> Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu> Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: botbw <wang1570@e.ntu.edu.sg> Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: flybird11111 <1829166702@qq.com> Co-authored-by: duanjunwen <935724073@qq.com> Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Charles Coulombe <ccoulombe@users.noreply.github.com> Co-authored-by: YeAnbang <anbangy2@outlook.com> Co-authored-by: char-1ee <xingjianli59@gmail.com> Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
parent 773d9f964a
commit 416580b314
@@ -90,7 +90,7 @@ jobs:
runs-on: [self-hosted, gpu]
container:
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
options: --gpus all --rm -v /dev/shm -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
options: --gpus all --rm -v /dev/shm -v /data/scratch:/data/scratch
timeout-minutes: 90
defaults:
run:
@@ -165,6 +165,7 @@ jobs:
env:
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors

- name: Collate artifact
env:

@@ -13,7 +13,7 @@ jobs:
runs-on: [self-hosted, gpu]
container:
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
options: --gpus all --rm -v /dev/shm -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/
timeout-minutes: 90
steps:
- name: Check GPU Availability # ensure all GPUs have enough memory
@@ -69,6 +69,7 @@ jobs:
env:
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors

- name: Notify Lark
id: message-preparation

@@ -50,7 +50,7 @@ jobs:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
options: --gpus all --rm -v /dev/shm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/
timeout-minutes: 200
steps:
- name: Install dependencies
@@ -92,3 +92,4 @@ jobs:
DATA: /data/scratch/cifar-10
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors

@@ -41,7 +41,7 @@ jobs:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
options: --gpus all --rm -v /dev/shm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/
timeout-minutes: 200
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-test-${{ matrix.container }}
@@ -87,3 +87,4 @@ jobs:
DATA: /data/scratch/cifar-10
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors

@@ -38,7 +38,7 @@ jobs:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
options: --gpus all --rm -v /dev/shm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/
timeout-minutes: 200
steps:
- name: Install dependencies
@@ -85,6 +85,7 @@ jobs:
DATA: /data/scratch/cifar-10
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors

- name: Notify Lark
id: message-preparation
@@ -1,92 +0,0 @@
import torch
import torch.distributed as dist
import torch.nn.functional as F
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

from colossalai.lazy import LazyInitContext
from colossalai.moe import MOE_MANAGER
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.moe_tensor.api import set_moe_tensor_info


class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
    def __init__(self, config):
        super().__init__(config)
        self.setup_ep()

    def setup_ep(self):
        _, moe_info = MOE_MANAGER.get_info(self.num_experts)
        ep_group = moe_info.ep_group
        self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
        self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
        assert self.num_experts % self.ep_size == 0
        self.ep_group = ep_group
        self.num_experts_per_ep = self.num_experts // self.ep_size
        self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
        held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
        set_tensors_to_none(self.experts, exclude=set(held_experts))
        for p in self.experts.parameters():
            set_moe_tensor_info(p, moe_info)

    @staticmethod
    def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock":
        LazyInitContext.materialize(module)
        module.__class__ = EPMixtralSparseMoeBlock
        module.setup_ep()
        return module

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        selected_experts = selected_experts.t().reshape(-1)
        selected_experts_idx = selected_experts.argsort()
        dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
        input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
        output_split_sizes = torch.zeros_like(input_split_sizes)
        dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)

        input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
        output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
        output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
        # compute expert output
        output_states = MoeInGradScaler.apply(output_states, self.ep_size)
        if output_states.size(0) > 0:
            if self.num_experts_per_ep == 1:
                # no need to split
                expert = self.experts[self.expert_start_idx]
                output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
                output_states = expert.w2(output_states)
            else:
                output_states_splits = output_states.split(output_split_sizes.tolist())
                output_states_list = []
                for i, split_states in enumerate(output_states_splits):
                    if split_states.size(0) == 0:
                        continue
                    expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
                    split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
                    split_states = expert.w2(split_states)
                    output_states_list.append(split_states)
                output_states = torch.cat(output_states_list)
        output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
        dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
        recover_experts_idx = torch.empty_like(selected_experts_idx)
        recover_experts_idx[selected_experts_idx] = torch.arange(
            selected_experts_idx.size(0), device=selected_experts_idx.device
        )
        dispatch_states = dispatch_states[recover_experts_idx]
        k_hidden_states = dispatch_states.chunk(self.top_k)
        output_states = k_hidden_states[0] * routing_weights[:, 0, None]
        for i in range(1, self.top_k):
            output_states += k_hidden_states[i] * routing_weights[:, i, None]
        output_states = output_states.reshape(batch_size, sequence_length, hidden_dim)
        return output_states, router_logits
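Note: the deleted block above routes tokens by sorting them per target expert before the uneven all-to-all. A minimal single-process sketch of that routing bookkeeping (token duplication, sort-by-expert, split-size computation) is shown below; the names, shapes and sizes are illustrative assumptions, not part of the diff.

import torch

num_experts, top_k, ep_size = 4, 2, 2
tokens = torch.randn(6, 8)                       # (num_tokens, hidden_dim)
router_logits = torch.randn(6, num_experts)

weights = torch.softmax(router_logits, dim=-1, dtype=torch.float)
weights, selected = torch.topk(weights, top_k, dim=-1)    # (num_tokens, top_k)
weights = weights / weights.sum(dim=-1, keepdim=True)

# Flatten assignments and sort so tokens bound for the same expert sit contiguously.
selected = selected.t().reshape(-1)              # (top_k * num_tokens,)
order = selected.argsort()
dispatch = tokens.repeat(top_k, 1)[order]        # each token duplicated once per routed expert

# Per-expert counts become the uneven all-to-all split sizes in the real module.
input_split_sizes = selected.bincount(minlength=num_experts)
per_rank_splits = input_split_sizes.view(ep_size, num_experts // ep_size).sum(dim=-1).tolist()
print(per_rank_splits)                           # how many routed tokens each EP rank would receive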
@@ -2,8 +2,6 @@ import argparse

import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM

@@ -70,8 +68,6 @@ def main():
ep_size=ep_size,
zero_stage=1,
precision=args.precision,
custom_policy=MixtralForCausalLMPolicy(),
checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
enable_fused_normalization=args.use_layernorm_kernel,
enable_jit_fused=args.use_kernel,
)
@@ -1,5 +1,6 @@
NUM_GPU=2
MODEL="mistralai/Mixtral-8x7B-v0.1"
# MODEL="mistralai/Mixtral-8x7B-v0.1"
MODEL="mistralai/Mixtral-8x7B-Instruct-v0.1"

# ep
torchrun --standalone --nproc_per_node $NUM_GPU infer.py \
@@ -1,146 +0,0 @@
from copy import deepcopy

import pytest
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from torch.optim import Adam
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing.utils import spawn

tokens, n_experts = 7, 4
hidden_size = 8
top_k = 2


def check_model_equal(model1, model2):
    assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
    for p1, p2 in zip(model1.parameters(), model2.parameters()):
        assert torch.equal(p1.half(), p2.half())


def get_optimizer_snapshot(optim):
    state = {id(k): deepcopy(v) for k, v in optim.state.items()}
    param_groups = []
    for group in optim.param_groups:
        params = [id(p) for p in group["params"]]
        new_group = {"params": params}
        for k, v in group.items():
            if k != "params":
                new_group[k] = v
        param_groups.append(new_group)
    return {
        "state": state,
        "param_groups": param_groups,
    }


def check_optimizer_snapshot_equal(snapshot1, snapshot2):
    # check param_groups
    assert len(snapshot1["param_groups"]) == len(snapshot2["param_groups"])
    for group1, group2 in zip(snapshot1["param_groups"], snapshot2["param_groups"]):
        assert set(group1.keys()) == set(group2.keys())
        for k in group1.keys():
            assert group1[k] == group2[k]
    # check state
    assert set(snapshot1["state"].keys()) == set(
        snapshot2["state"].keys()
    ), f"{snapshot1['state'].keys()}, {snapshot2['state'].keys()}"
    for pid in snapshot1["state"].keys():
        state1, state2 = snapshot1["state"][pid], snapshot2["state"][pid]
        assert set(state1.keys()) == set(state2.keys())
        for k in state1.keys():
            if isinstance(state1[k], torch.Tensor):
                assert torch.equal(state1[k], state2[k]), f"{k}, {state1[k]}, {state2[k]}"
            else:
                assert state1[k] == state2[k]


def check_mixtral_moe_layer():
    torch.cuda.set_device(dist.get_rank())
    config = MixtralConfig(
        hidden_size=hidden_size,
        intermediate_size=hidden_size * 2,
        num_local_experts=n_experts,
        num_experts_per_tok=top_k,
        num_attention_heads=2,
        num_key_value_heads=2,
    )
    torch.manual_seed(0)
    input_ids = torch.randint(0, 100, (2, tokens)).cuda()
    orig_model = MixtralForCausalLM(config).cuda()
    model = deepcopy(orig_model)
    optimizer = Adam(model.parameters(), lr=1e-3)
    plugin = MoeHybridParallelPlugin(
        tp_size=1,
        pp_size=2,
        ep_size=2,
        custom_policy=MixtralForCausalLMPolicy(),
        checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
        microbatch_size=1,
        zero_stage=1,
    )
    booster = Booster(plugin=plugin)
    model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer)
    # initialize grads
    data_iter = iter(
        [{"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "labels": input_ids.clone()}]
    )
    booster.execute_pipeline(
        data_iter,
        model,
        lambda outputs, inputs: outputs.loss,
        optimizer,
    )

    # check save model
    booster.save_model(model, "mixtral_model", shard=True)
    dist.barrier()
    if dist.get_rank() == 0:
        saved_model = MixtralForCausalLM.from_pretrained("mixtral_model").cuda()
        check_model_equal(orig_model, saved_model)
        saved_model.save_pretrained("mixtral_hf_model")
    dist.barrier()

    # check load model
    new_model = MixtralForCausalLM(config).cuda()
    new_optimizer = Adam(new_model.parameters(), lr=1e-3)
    new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
    booster.load_model(new_model, "mixtral_hf_model")
    check_model_equal(model, new_model)

    # check save optimizer
    optimizer.step()
    for group in optimizer.param_groups:
        group["lr"] = 0.1
    snapshot = get_optimizer_snapshot(optimizer.unwrap())
    booster.save_optimizer(optimizer, "mixtral_optim", shard=True)
    dist.barrier()
    # reset optimizer state
    for state in optimizer.unwrap().state.values():
        for v in state.values():
            if isinstance(v, torch.Tensor):
                v.zero_()
    booster.load_optimizer(optimizer, "mixtral_optim")
    loaded_snapshot = get_optimizer_snapshot(optimizer.unwrap())
    check_optimizer_snapshot_equal(snapshot, loaded_snapshot)


def run_dist(rank: int, world_size: int, port: int):
    colossalai.launch(rank, world_size, "localhost", port)
    check_mixtral_moe_layer()


@pytest.mark.parametrize("world_size", [4])
def test_mixtral_moe_layer(world_size: int):
    spawn(run_dist, world_size)


if __name__ == "__main__":
    test_mixtral_moe_layer(4)
@@ -2,13 +2,11 @@ import argparse

import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from colossal_moe.utils import load_checkpoint, move_to_cuda, save_checkpoint
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralForCausalLM
from utils import load_checkpoint, move_to_cuda, save_checkpoint

import colossalai
from colossalai.booster import Booster
@@ -155,12 +153,10 @@ def main():
pp_size=args.pp_size,
ep_size=args.ep_size,
microbatch_size=args.microbatch_size,
custom_policy=MixtralForCausalLMPolicy(),
enable_fused_normalization=args.use_layernorm_kernel,
enable_jit_fused=args.use_kernel,
precision=args.precision,
zero_stage=args.zero_stage,
checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
)

else:
@@ -20,6 +20,7 @@ resp = llm(prompt='What do you call a three-ton kangaroo?', **gen_config)
print(resp) # super-heavyweight awesome-natured yawning Australian creature!

"""

import json
from typing import Any, Mapping
@@ -655,7 +655,6 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
self.param_info = param_info
self.stage_manager = model.stage_manager
self.shared_params = model.shared_params
self.dp_pg = dp_process_group
self.tp_pg = tp_process_group
self.pp_pg = pp_process_group
if use_pipeline:
@@ -718,7 +717,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
"""Retrieve all working gradients from different parameter groups."""
all_working_grads = []
for group_id in range(self.num_param_groups):
working_grads = self._grad_store.get_working_grads_by_group_id(group_id)
working_grads = self.get_working_grads_by_group_id(group_id)
all_working_grads.extend(working_grads)
return all_working_grads

@@ -726,7 +725,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
"""Identify gradients to be synchronized in the sequence parallelism."""
grads_to_sync = []
for grad in all_working_grads:
param_id_for_grad = self._grad_store.get_param_id_for_grad(grad)
param_id_for_grad = self.get_param_id_for_grad(grad)
param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value
if SeqParallelUtils.is_sp_partial_derived_param(param_for_grad):
grads_to_sync.append(grad)
@@ -739,7 +738,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
# Get all working gradients and gradients to be synchronized.
all_working_grads = _get_all_working_grads()
grads_to_sync = _get_grads_to_sync(all_working_grads)
if self._grad_store.require_grad_sync and grads_to_sync is not None:
if self.require_grad_sync and grads_to_sync is not None:
# Synchronize sequence parallelism gradients if required.
SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync)
else:
@@ -763,7 +762,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
# Call the superclass backward method to compute gradients.
super().backward(loss, retain_graph)

if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
# If gradient synchronization is required, sync sequence parallelism gradients.
self._sync_sp_grads()
else:
@@ -788,14 +787,14 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
# Call the superclass backward_by_grad method to compute gradients.
super().backward_by_grad(tensor, grad)

if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
# If gradient synchronization is required, sync sequence parallelism gradients.
self._sync_sp_grads()
else:
# If gradient synchronization is is not required, return.
return

def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
def _compute_grad_norm(self, dp_pg, gradients: List[Tensor], norm_type: int = 2) -> float:
r"""
Compute and return the gradient norm for gradient clipping.

@@ -811,7 +810,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
if len(gradients) == 0:
return 0.0

dp_size = get_world_size(self.dp_pg) if self.dp_pg is not None else 1
dp_size = get_world_size(dp_pg) if dp_pg is not None else 1
tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
norm_type = float(norm_type)
@@ -842,7 +841,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
# However, we still perform the 'all_reduce' operation for the sake of good coding practices.
# To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.'
if tp_size > 1:
param_id_for_grad = self._grad_store.get_param_id_for_grad(grad)
param_id_for_grad = self.get_param_id_for_grad(grad)
param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value

if not is_distributed_tensor(param_for_grad):
@@ -856,7 +855,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
for shared_param in self.shared_params:
if self.stage_manager.stage in shared_param:
stage_shared_param = shared_param[self.stage_manager.stage]
working_grad = self._grad_store.get_working_grad_by_param_id(id(stage_shared_param))
working_grad = self.get_working_grad_by_param_id(id(stage_shared_param))
if grad is working_grad:
grad_norm_exponentiated /= len(shared_param)

@@ -867,7 +866,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
)
if dp_size > 1:
# compute norm in dp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg)
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=dp_pg)
if tp_size > 1:
# compute norm in tp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
@@ -1309,7 +1308,7 @@ class HybridParallelPlugin(PipelinePluginBase):

# run with gradients accumulation
if model.require_grad_sync == False or (
isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer._grad_store.require_grad_sync == False
isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
):
return outputs
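Note: after this refactor the grad norm no longer reads a process group off the grad store; the data-parallel group is passed in explicitly. The sketch below shows the general pattern only (local sum of |g|^p, all-reduce over the given group, p-th root); the function and variable names are assumptions for illustration, not the repository's API.

from typing import List, Optional

import torch
import torch.distributed as dist


def compute_grad_norm(grads: List[torch.Tensor], dp_pg: Optional[dist.ProcessGroup], norm_type: float = 2.0) -> float:
    # Each rank only holds its own ZeRO shard, so the global norm is assembled from
    # local partial sums followed by an all-reduce over the data-parallel group.
    if len(grads) == 0:
        return 0.0
    local = torch.stack([g.detach().float().norm(norm_type) ** norm_type for g in grads]).sum()
    if dp_pg is not None and dist.get_world_size(dp_pg) > 1:
        dist.all_reduce(local, op=dist.ReduceOp.SUM, group=dp_pg)
    return local.item() ** (1.0 / norm_type)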
@@ -1,4 +1,5 @@
import random
import warnings
from types import MethodType
from typing import Callable, Optional, OrderedDict, Tuple

@@ -20,19 +21,19 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
get_param_info,
init_pipeline_optimizer,
)
from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.moe import MOE_MANAGER, MoECheckpointIO
from colossalai.logging import get_dist_logger
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.zero.low_level import LowLevelZeroOptimizer

PP_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2


class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
def __init__(
self,
optimizer: Optimizer,
@@ -67,8 +68,20 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
self.pp_pg = pp_process_group
if use_pipeline:
init_pipeline_optimizer(optimizer, model)

pg_param_list = {
dp_process_group: [],
moe_extra_dp_process_group: [],
}
for param in model.parameters():
if is_moe_tensor(param):
pg_param_list[moe_extra_dp_process_group].append(param)
else:
pg_param_list[dp_process_group].append(param)

super().__init__(
optimizer=optimizer,
pg_to_param_list=pg_param_list,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
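Note: the hunk above is the core of the new optimizer wiring — each parameter is assigned to one of two ZeRO buckets keyed by process group. A compact sketch of that bucketing idea follows; `is_moe_param` is a stand-in predicate (the real code uses colossalai's is_moe_tensor) and the helper name is illustrative.

import torch.nn as nn


def build_pg_param_list(model: nn.Module, dp_group, moe_dp_group, is_moe_param):
    # Expert (MoE) params are sharded over the smaller MoE data-parallel group,
    # all remaining params over the global data-parallel group.
    pg_param_list = {dp_group: [], moe_dp_group: []}
    for param in model.parameters():
        key = moe_dp_group if is_moe_param(param) else dp_group
        pg_param_list[key].append(param)
    return pg_param_list  # handed to the low-level ZeRO optimizer as pg_to_param_list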
@@ -83,9 +96,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
overlap_communication=overlap_communication,
partition_grad=partition_grad,
cpu_offload=cpu_offload,
dp_process_group=dp_process_group,
forced_dtype=forced_dtype,
moe_extra_dp_process_group=moe_extra_dp_process_group,
)


@@ -107,8 +118,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
>>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)

Args:
tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
precision (str, optional): Specifies the precision of parameters during training.
Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'.
Defaults to 'fp16'.
@@ -144,14 +155,15 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
use_ep_inside (bool, Optional): Whether to use ep inside dp (intra-node) for moe params.
"""

def __init__(
self,
tp_size: int,
pp_size: int,
ep_size: int,
extra_dp_size: int = 1,
tp_size: int = 1,
sp_size: int = 1,
precision: str = "fp16",
zero_stage: int = 0,
enable_all_optimization: bool = False,
@@ -184,32 +196,22 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
custom_policy: Policy = None,
checkpoint_io: Optional[MoECheckpointIO] = None,
) -> None:
assert (
dist.get_world_size() % (tp_size * pp_size) == 0
), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
world_size = dist.get_world_size()
assert tp_size == 1, "Tensor parallel is not supported in MoE yet"
assert sp_size == 1 and enable_sequence_parallelism is False, "Sequence parallelism it not supported in MoE yet"

if enable_sequence_parallelism:
assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism"
assert (
dist.get_world_size() % (tp_size * pp_size) == 0
), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
world_size % (tp_size * pp_size) == 0
), f"world size {world_size} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
assert (
dist.get_world_size() % (tp_size * pp_size * ep_size) == 0
), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}"
self.real_dp_size = dist.get_world_size() // (tp_size * pp_size * ep_size)
MOE_MANAGER.setup(
parallel="EP",
mode="fixed",
fixed_dp_size=self.real_dp_size,
fixed_ep_size=ep_size,
fixed_pp_size=pp_size,
use_ep_inside=use_ep_inside,
)
world_size % (tp_size * pp_size * ep_size) == 0
), f"world size {world_size} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}"

self.dp_size = world_size // (tp_size * pp_size)
self.tp_size = tp_size
self.pp_size = pp_size
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
self.ep_size = ep_size
self.moe_info = MOE_MANAGER.get_info(0)[1]
self.sp_size = sp_size
self.precision = precision
self.zero_stage = zero_stage
self.cpu_offload = cpu_offload
@@ -219,43 +221,57 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
self.checkpoint_io = checkpoint_io

logger = get_dist_logger()

# NOTE: Two process meshes: global dp for non-moe param; dp + ep for moe param
# See https://hpc-ai.com/blog/enhanced-moe-parallelism-open-source-moe-model-training-can-be-9-times-more-efficient
# we change pg mesh to (pp, dp, tp) for better moe performance
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size)
assert (
self.ep_size <= self.dp_size
), f"Not enough devices({self.dp_size}) for expert parallelism size({self.ep_size})."

# sync moe in outer dp group, and sync other param in global dp group
if extra_dp_size > 1:
ep_size = self.dp_size // extra_dp_size
if use_ep_inside:
self.pg_mesh_moe = ProcessGroupMesh(self.pp_size, extra_dp_size, ep_size)
self.moe_extra_dp_group = self.pg_mesh_moe.get_group_along_axis(1)
if dist.get_rank() == 0:
print(f"Zero Parallel: pp {self.pp_size}, outer_dp {extra_dp_size}, inner_dp {ep_size}")
else:
self.pg_mesh_moe = ProcessGroupMesh(self.pp_size, ep_size, extra_dp_size)
self.moe_extra_dp_group = self.pg_mesh_moe.get_group_along_axis(2)
if dist.get_rank() == 0:
print(f"Zero Parallel: pp {self.pp_size}, outer_dp {ep_size}, inner_dp {extra_dp_size}")
self.moe_dp_size = self.dp_size // self.ep_size
self.use_ep_inside = use_ep_inside
if self.use_ep_inside:
logger.info(f"MoE Parallel use ep inside dp.", ranks=[0])
self.pp_axis, self.dp_axis, self.ep_axis, self.tp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, ep_size, tp_size)
else:
self.moe_extra_dp_group = None
logger.info(f"MoE Parallel use ep outside dp.", ranks=[0])
warnings.warn("Using ep outside dp (cross-node) is strongly discouraged due to communication costs.")
self.pp_axis, self.dp_axis, self.ep_axis, self.tp_axis = 0, 2, 1, 3
self.pg_mesh = ProcessGroupMesh(self.pp_size, ep_size, self.moe_dp_size, tp_size)

self.moe_dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
self.ep_group = self.pg_mesh.get_group_along_axis(self.ep_axis)
logger.info(f"Non-MoE Parameter Parallel: pp {self.pp_size}, dp {self.dp_size}, tp {tp_size}", ranks=[0])
logger.info(
f"MoE Parallel: pp {self.pp_size}, ep {ep_size}, moe dp {self.moe_dp_size}, tp {tp_size}", ranks=[0]
)

self.tp_group = self.pg_mesh.get_group_along_axis(
self.tp_axis
) # TODO: support custom tp size for mixtral lm head
self.global_dp_group = self.pg_mesh.get_group_along_axis((self.dp_axis, self.ep_axis))
self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis)
# TODO: Currently moe only support partially sequence parallel
self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)

self.custom_policy = custom_policy
self.stage_manager = None
self.schedule = None
self.custom_policy = custom_policy

assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
assert (
num_microbatches is not None or microbatch_size is not None
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
self.stage_manager = PipelineStageManager(self.pg_mesh, self.pp_axis)
self.schedule = OneForwardOneBackwardSchedule(
self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
)
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
# TODO: Currently moe only support partially sequence parallel
self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)

self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
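Note: the hunk above replaces the old 3-axis mesh plus a separate MoE mesh with a single 4-axis mesh (pp, dp, ep, tp), where the non-MoE data-parallel group is the flattening of the dp and ep axes. The standalone sketch below only illustrates how ranks map to coordinates in such a mesh; the sizes and helper are illustrative assumptions and do not use colossalai's ProcessGroupMesh.

pp, moe_dp, ep, tp = 1, 2, 2, 1          # 4 ranks in total
shape = (pp, moe_dp, ep, tp)


def coord(rank, shape=shape):            # rank -> (pp, dp, ep, tp) coordinate, row-major
    out = []
    for s in reversed(shape):
        out.append(rank % s)
        rank //= s
    return tuple(reversed(out))


for r in range(pp * moe_dp * ep * tp):
    print(r, coord(r))

# Ranks differing only along the ep axis form an EP group (experts sharded across them);
# ranks differing only along the dp axis form a MoE-DP group (ZeRO for expert params);
# ranks sharing just (pp, tp) form the flattened global DP group used for non-MoE params.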
@@ -267,6 +283,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism,
enable_sequence_overlap=enable_sequence_overlap,
ep_group=self.ep_group,
)
self.amp_config = dict(
initial_scale=initial_scale,
@@ -323,7 +340,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
"""
_kwargs = kwargs.copy()
sampler = DistributedSampler(
dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
dataset,
num_replicas=self.dp_size,
rank=dist.get_rank(self.global_dp_group),
shuffle=shuffle,
)

# Deterministic dataloader
@@ -346,9 +366,20 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):

def get_checkpoint_io(self) -> MoECheckpointIO:
if self.checkpoint_io is None:
self.checkpoint_io = MoECheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
self.checkpoint_io = MoECheckpointIO(
self.global_dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
)
else:
self.checkpoint_io = self.checkpoint_io(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
self.checkpoint_io = self.checkpoint_io(
self.global_dp_group,
self.pp_group,
self.tp_group,
ep_group=self.ep_group,
moe_dp_group=self.moe_dp_group,
zero_stage=self.zero_stage,
)
if hasattr(self.checkpoint_io, "moe_info"):
self.checkpoint_io.moe_info = self.moe_info
return self.checkpoint_io

def configure(
@@ -366,7 +397,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
module=model,
precision=self.precision,
shard_config=self.shard_config,
dp_group=self.dp_group,
dp_group=self.global_dp_group,
tp_group=self.tp_group,
sp_group=self.sp_group,
use_ddp=use_ddp,
@@ -392,15 +423,15 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
else:
assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
optimizer = HybridParallelZeroOptimizer(
optimizer = MoeHybridParallelZeroOptimizer(
optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
dp_process_group=self.dp_group,
dp_process_group=self.global_dp_group,
tp_process_group=self.tp_group,
pp_process_group=self.pp_group,
moe_extra_dp_process_group=self.moe_extra_dp_group,
moe_extra_dp_process_group=self.moe_dp_group,
verbose=True,
clip_grad_norm=self.max_norm,
**self.zero_config,
@@ -2,5 +2,12 @@ from .checkpoint_io_base import CheckpointIO
from .general_checkpoint_io import GeneralCheckpointIO
from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
from .index_file import CheckpointIndexFile
from .moe_checkpoint import MoECheckpointIO

__all__ = ["CheckpointIO", "CheckpointIndexFile", "GeneralCheckpointIO", "HybridParallelCheckpointIO"]
__all__ = [
    "CheckpointIO",
    "CheckpointIndexFile",
    "GeneralCheckpointIO",
    "HybridParallelCheckpointIO",
    "MoECheckpointIO",
]
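Note: with the re-export above, MoECheckpointIO can be imported from the checkpoint_io package directly, as the new plugin does. A minimal usage sketch follows; the instantiation is commented out because the process groups are assumed to come from an already-initialized MoeHybridParallelPlugin.

from colossalai.checkpoint_io import MoECheckpointIO

# checkpoint_io = MoECheckpointIO(global_dp_group, pp_group, tp_group, ep_group, moe_dp_group, zero_stage=1)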
@@ -70,13 +70,13 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
verbose: bool = True,
) -> None:
super().__init__()
self.dp_group = dp_group
self.global_dp_group = dp_group
self.pp_group = pp_group
self.tp_group = tp_group
self.dp_rank = dist.get_rank(self.dp_group)
self.dp_rank = dist.get_rank(self.global_dp_group)
self.tp_rank = dist.get_rank(self.tp_group)
self.pp_rank = dist.get_rank(self.pp_group)
self.dp_size = dist.get_world_size(dp_group)
self.global_dp_size = dist.get_world_size(dp_group)
self.pp_size = dist.get_world_size(pp_group)
self.tp_size = dist.get_world_size(tp_group)
self.use_zero = zero_stage > 0
@@ -433,7 +433,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state_dict_shard = HybridParallelCheckpointIO._optimizer_sharder(
optimizer,
use_zero=self.use_zero,
dp_group=self.dp_group,
dp_group=self.global_dp_group,
tp_group=self.tp_group,
size_per_shard=size_per_shard,
)
@@ -727,7 +727,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state,
working_param,
original_shape=original_shape,
dp_group=self.dp_group,
dp_group=self.global_dp_group,
tp_group=self.tp_group,
use_zero=self.use_zero,
inplace=False,
@@ -932,12 +932,12 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):

# Shard state along data parallel group when using Zero.
if self.use_zero:
padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
padding_size = (self.global_dp_size - v.numel() % self.global_dp_size) % self.global_dp_size
with torch.no_grad():
v = v.flatten()
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
slice_size = v.numel() // self.dp_size
slice_size = v.numel() // self.global_dp_size
v = v.split(slice_size, dim=0)[self.dp_rank]

state_[k] = v.detach().clone().to(device)
@ -9,6 +9,7 @@ import torch
|
|||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.distributed.distributed_c10d import get_global_rank
|
||||
|
||||
from colossalai.checkpoint_io import CheckpointIndexFile
|
||||
from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
|
||||
|
@ -19,15 +20,16 @@ from colossalai.checkpoint_io.utils import (
|
|||
get_model_base_filenames,
|
||||
get_optimizer_base_filenames,
|
||||
load_shard_state_dict,
|
||||
load_state_dict,
|
||||
load_states_into_optimizer,
|
||||
save_config_file,
|
||||
save_param_groups,
|
||||
save_state_dict,
|
||||
save_state_dict_shards,
|
||||
search_tp_partition_dim,
|
||||
sharded_optimizer_loading_epilogue,
|
||||
)
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.moe import MOE_MANAGER
|
||||
from colossalai.tensor.moe_tensor.api import is_moe_tensor
|
||||
|
||||
try:
|
||||
|
@ -36,21 +38,30 @@ except ImportError:
|
|||
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
|
||||
|
||||
|
||||
class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
|
||||
class MoECheckpointIO(HybridParallelCheckpointIO):
|
||||
def __init__(
|
||||
self,
|
||||
dp_group: ProcessGroup,
|
||||
global_dp_group: ProcessGroup,
|
||||
pp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
ep_group: ProcessGroup,
|
||||
moe_dp_group: ProcessGroup,
|
||||
zero_stage: int,
|
||||
verbose: bool = True,
|
||||
) -> None:
|
||||
super().__init__(dp_group, pp_group, tp_group, zero_stage, verbose)
|
||||
moe_info = MOE_MANAGER.parallel_info_dict[MOE_MANAGER.ep_size]
|
||||
self.ep_group = moe_info.ep_group
|
||||
self.ep_size = moe_info.ep_size
|
||||
self.ep_rank = moe_info.ep_rank
|
||||
self.real_dp_rank = moe_info.dp_rank
|
||||
super().__init__(global_dp_group, pp_group, tp_group, zero_stage, verbose)
|
||||
self.global_dp_group = global_dp_group
|
||||
self.global_dp_rank = dist.get_rank(global_dp_group)
|
||||
self.global_dp_size = dist.get_world_size(global_dp_group)
|
||||
self.pp_group = pp_group
|
||||
self.tp_group = tp_group
|
||||
|
||||
self.moe_dp_group = moe_dp_group
|
||||
self.moe_dp_size = dist.get_world_size(moe_dp_group)
|
||||
self.moe_dp_rank = dist.get_rank(moe_dp_group)
|
||||
self.ep_group = ep_group
|
||||
self.ep_size = dist.get_world_size(ep_group)
|
||||
self.ep_rank = dist.get_rank(ep_group)
|
||||
|
||||
@staticmethod
|
||||
def _model_sharder(
|
||||
|
@ -134,7 +145,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
|
|||
|
||||
Path(checkpoint).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if self.real_dp_rank != 0:
|
||||
if self.moe_dp_rank != 0:
|
||||
dist.barrier()
|
||||
return
|
||||
|
||||
|
@ -144,7 +155,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
|
|||
|
||||
# Then collect the sharded parameters & buffers along tp_group.
|
||||
# Only devices with tp_rank == 0 are responsible for model saving.
|
||||
state_dict_shard = MixtralMoEHybridParallelCheckpointIO._model_sharder(
|
||||
state_dict_shard = MoECheckpointIO._model_sharder(
|
||||
model, size_per_shard=size_per_shard, param_name_pattern=ep_param_pattern
|
||||
)
|
||||
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
|
||||
|
@ -234,11 +245,12 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
|
|||
state: OrderedDict,
|
||||
param: torch.Tensor,
|
||||
original_shape: torch.Size,
|
||||
dp_group: ProcessGroup,
|
||||
global_dp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
use_zero: bool,
|
||||
inplace: bool,
|
||||
is_moe_param: bool,
|
||||
moe_dp_group: ProcessGroup = None,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
) -> OrderedDict:
|
||||
"""
|
||||
|
@ -248,7 +260,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
|
|||
state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero.
|
||||
param (torch.Tensor): The given parameter. It should be working_param when using Zero.
|
||||
original_shape (torch.Size): The size of parameter before sharding.
|
||||
dp_group (ProcessGroup): The process group of data parallel.
|
||||
global_dp_group (ProcessGroup): The process group of data parallel.
|
||||
tp_group (ProcessGroup): The process group of tensor parallel.
|
||||
use_zero (bool): Whether Zero is used.
|
||||
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
|
||||
|
@ -257,27 +269,47 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
|
|||
Returns:
|
||||
OrderedDict: The complete optimizer state of given parameter.
|
||||
"""
|
||||
dp_size = dist.get_world_size(dp_group)
|
||||
global_dp_size = dist.get_world_size(global_dp_group)
|
||||
tp_size = dist.get_world_size(tp_group)
|
||||
moe_dp_size = dist.get_world_size(moe_dp_group) if moe_dp_group is not None else 1
|
||||
current_shape = param.shape
|
||||
state_ = state if inplace else copy.deepcopy(state)
|
||||
|
||||
for k, v in state_.items():
|
||||
if isinstance(v, torch.Tensor) and k != "step":
|
||||
v = v.cuda()
|
||||
|
||||
# First gather Zero shards.
|
||||
if use_zero and not is_moe_param:
|
||||
v = v.cuda()
|
||||
gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
|
||||
dist.all_gather(gather_tensor, v, group=dp_group)
|
||||
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
|
||||
if use_zero and is_moe_param and moe_dp_size > 1:
|
||||
moe_dp_rank = dist.get_rank(moe_dp_group)
|
||||
dst = get_global_rank(moe_dp_group, 0)
|
||||
if moe_dp_rank == 0:
|
||||
gather_tensor = [torch.zeros_like(v) for _ in range(moe_dp_size)]
|
||||
dist.gather(v, gather_tensor, group=moe_dp_group, dst=dst)
|
||||
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
|
||||
else:
|
||||
dist.gather(v, group=moe_dp_group, dst=dst)
|
||||
|
||||
elif use_zero and not is_moe_param and global_dp_size > 1:
|
||||
dp_rank = dist.get_rank(global_dp_group)
|
||||
dst = get_global_rank(global_dp_group, 0)
|
||||
if dp_rank == 0:
|
||||
gather_tensor = [torch.zeros_like(v) for _ in range(global_dp_size)]
|
||||
dist.gather(v, gather_tensor, group=global_dp_group, dst=dst)
|
||||
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
|
||||
else:
|
||||
dist.gather(v, group=global_dp_group, dst=dst)
|
||||
|
||||
# Then gather TP shards.
|
||||
partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
|
||||
if partition_dim is not None:
|
||||
gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)]
|
||||
dist.all_gather(gather_tensor, v, group=tp_group)
|
||||
v = torch.cat(gather_tensor, dim=partition_dim)
|
||||
|
||||
tp_rank = dist.get_rank(tp_group)
|
||||
dst = get_global_rank(tp_group, 0)
|
||||
if tp_rank == 0:
|
||||
gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)]
|
||||
dist.gather(v, gather_tensor, group=tp_group, dst=dst)
|
||||
v = torch.cat(gather_tensor, dim=partition_dim)
|
||||
else:
|
||||
dist.gather(v, group=tp_group, dst=dst)
|
||||
state_[k] = v.detach().clone().to(device)
|
||||
|
||||
return state_
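For intuition, here is a minimal single-process sketch (no torch.distributed; the shard list stands in for the result of all_gather/gather) of the stack/view/truncate idiom used above to rebuild a full optimizer state from padded flat ZeRO shards.
import torch

def rebuild_from_flat_shards(shards, param):
    # Each shard is an equally sized flat slice of the padded, flattened state.
    flat = torch.stack(shards).view(-1)
    # Drop the padding appended before sharding, then restore the original shape.
    return flat[: param.numel()].reshape_as(param)

param = torch.randn(3, 5)                                   # 15 elements
padded = torch.nn.functional.pad(param.flatten(), [0, 1])   # pad to 16 -> 2 shards of 8
shards = list(padded.split(8))
assert torch.equal(rebuild_from_flat_shards(shards, param), param)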
|
||||
|
@ -286,8 +318,9 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
|
|||
def _optimizer_sharder(
|
||||
optimizer: OptimizerWrapper,
|
||||
use_zero: bool,
|
||||
dp_group: ProcessGroup,
|
||||
global_dp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
moe_dp_group: ProcessGroup,
|
||||
size_per_shard: int = 1024,
|
||||
only_moe_param: bool = False,
|
||||
):
|
||||
|
@ -296,7 +329,6 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
|
|||
state_dict_sharder = StateDictSharder(size_per_shard)
|
||||
param_info = optimizer.param_info
|
||||
master_to_working_map = optimizer.get_master_to_working_map()
|
||||
|
||||
for param, state in optimizer.optim.state.items():
|
||||
if param is None:
|
||||
continue
|
||||
|
@ -305,22 +337,23 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
|
|||
working_param = master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
|
||||
param_id = param_info["param2id"][id(working_param)]
|
||||
original_shape = param_info["param2shape"][id(working_param)]
|
||||
state_ = MixtralMoEHybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
|
||||
state_ = MoECheckpointIO.gather_from_sharded_optimizer_state(
|
||||
state,
|
||||
working_param,
|
||||
original_shape=original_shape,
|
||||
dp_group=dp_group,
|
||||
global_dp_group=global_dp_group,
|
||||
moe_dp_group=moe_dp_group,
|
||||
tp_group=tp_group,
|
||||
use_zero=use_zero,
|
||||
inplace=False,
|
||||
is_moe_param=is_moe_tensor(working_param),
|
||||
is_moe_param=is_moe_tensor(working_param), # TODO: Check correctness here
|
||||
)
|
||||
|
||||
if only_moe_param and not is_moe_tensor(working_param):
|
||||
continue
|
||||
|
||||
block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
|
||||
if block is not None:
|
||||
yield block, block_size
|
||||
|
@ -359,25 +392,28 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
|
|||
|
||||
Path(checkpoint).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Devices along the same dp_group share the same copies of states when zero is not used.
|
||||
# In this case only let the device with dp_rank == 0 save the states.
|
||||
if not self.use_zero and self.real_dp_rank != 0:
|
||||
# If optim states are not sharded, other ranks don't need to participate in gather.
|
||||
if not self.use_zero and self.moe_dp_rank != 0:
|
||||
dist.barrier()
|
||||
return
|
||||
|
||||
# Then collect the sharded states along dp_group(if using zero)/tp_group.
|
||||
# Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
|
||||
state_dict_shard = MixtralMoEHybridParallelCheckpointIO._optimizer_sharder(
|
||||
state_dict_shard = MoECheckpointIO._optimizer_sharder(
|
||||
optimizer,
|
||||
use_zero=self.use_zero,
|
||||
dp_group=self.dp_group,
|
||||
global_dp_group=self.global_dp_group,
|
||||
tp_group=self.tp_group,
|
||||
moe_dp_group=self.moe_dp_group,
|
||||
size_per_shard=size_per_shard,
|
||||
only_moe_param=self.ep_rank != 0,
|
||||
)
|
||||
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
control_saving = self.real_dp_rank == 0 and self.tp_rank == 0
|
||||
# e.g. dp_size = 4, moe_dp_size = 2, ep_size = 2 and use gather
|
||||
# rank 0 saves moe & non-moe params; rank 1 only saves moe params
|
||||
# ranks 2 & 3 save nothing
|
||||
control_saving = self.tp_rank == 0 and self.moe_dp_rank == 0
|
||||
|
||||
if self.pp_size == 1 and self.ep_size == 1:
|
||||
# When pipeline is not used, save the optimizer shards as in general checkpointIO
|
||||
|
@ -596,7 +632,6 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
|
|||
OrderedDict: The sharded optimizer state of the given parameter.
|
||||
"""
|
||||
state_ = state if inplace else copy.deepcopy(state)
|
||||
|
||||
for k, v in state_.items():
|
||||
if isinstance(v, torch.Tensor) and k != "step":
|
||||
# Shard state along tensor parallel group.
|
||||
|
@ -606,24 +641,218 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
|
|||
v = v.split(slice_size, dim=partition_dim)[self.tp_rank]
|
||||
|
||||
# Shard state along data parallel group when using Zero.
|
||||
if self.use_zero and not is_moe_param:
|
||||
padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
|
||||
if self.use_zero and not is_moe_param and self.global_dp_size > 1:
|
||||
padding_size = (self.global_dp_size - v.numel() % self.global_dp_size) % self.global_dp_size
|
||||
with torch.no_grad():
|
||||
v = v.flatten()
|
||||
if padding_size > 0:
|
||||
v = torch.nn.functional.pad(v, [0, padding_size])
|
||||
slice_size = v.numel() // self.dp_size
|
||||
v = v.split(slice_size, dim=0)[self.dp_rank]
|
||||
slice_size = v.numel() // self.global_dp_size
|
||||
v = v.split(slice_size, dim=0)[self.global_dp_rank]
|
||||
|
||||
elif self.use_zero and is_moe_param and self.moe_dp_size > 1:
|
||||
# LowLevelZeRO pads by global dp size for now.
|
||||
# TODO: update both to use moe dp size
|
||||
padding_size = (self.global_dp_size - v.numel() % self.global_dp_size) % self.global_dp_size
|
||||
with torch.no_grad():
|
||||
v = v.flatten()
|
||||
if padding_size > 0:
|
||||
v = torch.nn.functional.pad(v, [0, padding_size])
|
||||
slice_size = v.numel() // self.moe_dp_size
|
||||
v = v.split(slice_size, dim=0)[self.moe_dp_rank]
|
||||
|
||||
state_[k] = v.detach().clone().to(device)
|
||||
|
||||
return state_
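The padding rule above, (dp_size - numel % dp_size) % dp_size, makes the flattened state divisible by the data-parallel size so every rank gets an equal slice. A single-process sketch under that assumption (dp_size/dp_rank are plain ints standing in for the process-group queries):
import torch

def shard_for_rank(v: torch.Tensor, dp_size: int, dp_rank: int) -> torch.Tensor:
    # Pad the flattened tensor so its length is divisible by dp_size ...
    padding_size = (dp_size - v.numel() % dp_size) % dp_size
    flat = torch.nn.functional.pad(v.flatten(), [0, padding_size])
    # ... then each rank keeps one equally sized contiguous slice.
    slice_size = flat.numel() // dp_size
    return flat.split(slice_size, dim=0)[dp_rank]

v = torch.arange(10.0)  # 10 elements, dp_size = 4 -> 2 padded zeros
print([shard_for_rank(v, 4, r).tolist() for r in range(4)])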
|
||||
|
||||
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
|
||||
raise NotImplementedError
|
||||
"""Migration from MoEHybridParallelCheckpointIO. These functions mostly deals with unsharded saving,
|
||||
and can be safely deleted since large MoE models are often saved in shards.
|
||||
"""
|
||||
|
||||
# Copied from colossalai.moe
|
||||
def pre_save_model(self, model: nn.Module) -> dict:
|
||||
state_dict = model.state_dict()
|
||||
for name, param in model.named_parameters():
|
||||
if ".experts." in name and is_moe_tensor(param):
|
||||
ep_group = param.ep_group
|
||||
ep_rank = dist.get_rank(ep_group)
|
||||
ep_size = dist.get_world_size(ep_group)
|
||||
# TODO: check correctness here
|
||||
# dp_rank = get_dp_rank(param)
|
||||
dp_rank = dist.get_rank(self.global_dp_group)
|
||||
if dp_rank == 0:
|
||||
param = param.data.cuda()
|
||||
if ep_rank == 0:
|
||||
all_param = [torch.zeros_like(param) for _ in range(ep_size)]
|
||||
else:
|
||||
all_param = None
|
||||
# gather param from every ep rank
|
||||
# dist.all_gather(all_param, param, group=ep_group)
|
||||
dist.gather(param, all_param, group=ep_group)
|
||||
if ep_rank == 0:
|
||||
all_param = torch.cat(all_param, dim=0)
|
||||
state_dict[name] = all_param.cpu()
|
||||
|
||||
if self.pp_size > 1:
|
||||
if self.dp_rank == 0:
|
||||
out = [None for _ in range(self.pp_size)]
|
||||
dist.gather_object(state_dict, out, group=self.pp_group)
|
||||
if self.pp_rank == 0:
|
||||
new_state_dict = {}
|
||||
for o in out:
|
||||
new_state_dict.update(o)
|
||||
state_dict = new_state_dict
|
||||
dist.barrier()
|
||||
return state_dict
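The replacement of all_gather with gather above means only the destination rank allocates receive buffers, while the other ranks pass None. A sketch of that calling pattern, assuming an already-initialized default process group (e.g. launched via torchrun); gather_to_dst is a hypothetical helper name, not part of the codebase:
import torch
import torch.distributed as dist

def gather_to_dst(tensor: torch.Tensor, group, dst: int):
    # Gather `tensor` from every rank in `group` onto global rank `dst` only.
    world_size = dist.get_world_size(group)
    if dist.get_rank() == dst:
        gather_list = [torch.zeros_like(tensor) for _ in range(world_size)]
        dist.gather(tensor, gather_list, group=group, dst=dst)
        return gather_list  # only populated on the destination rank
    dist.gather(tensor, None, group=group, dst=dst)
    return None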
|
||||
|
||||
def save_unsharded_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
checkpoint: str,
|
||||
gather_dtensor: bool,
|
||||
use_safetensors: bool,
|
||||
):
|
||||
state_dict = self.pre_save_model(model)
|
||||
if dist.get_rank() == 0:
|
||||
torch.save(state_dict, checkpoint)
|
||||
dist.barrier()
|
||||
|
||||
# Copied from colossalai.moe
|
||||
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
|
||||
raise NotImplementedError
|
||||
"""
|
||||
Save optimizer state dict to a file with given path.
|
||||
|
||||
Args:
|
||||
optimizer (OptimizerWrapper): Optimizer to save sharded state_dict.
|
||||
checkpoint (str): Path to save optimizer state_dict.
|
||||
gather_dtensor (bool): Whether to gather_dtensor, not used.
|
||||
"""
|
||||
if self.coordinator.is_master():
|
||||
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
|
||||
|
||||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
|
||||
|
||||
# optimizer states of the parameters kept by the local device (i.e. its pipeline stage)
|
||||
local_states = dict()
|
||||
|
||||
for param, state in optimizer.optim.state.items():
|
||||
if param is None:
|
||||
continue
|
||||
|
||||
# working param is needed for obtaining correct param_id
|
||||
master_to_working_map = optimizer.get_master_to_working_map()
|
||||
if master_to_working_map is not None and id(param) in master_to_working_map:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
|
||||
# gather complete state from tp shards & dp shards
|
||||
param_id = optimizer.param_info["param2id"][id(working_param)]
|
||||
local_states[param_id] = self.pre_save_optim(
|
||||
state,
|
||||
working_param,
|
||||
inplace=False,
|
||||
device=torch.device("cuda"),
|
||||
)
|
||||
|
||||
if self.pp_size == 1:
|
||||
# When pipeline is not used, let master rank directly save the collected state_dict.
|
||||
state_dict = {"param_groups": optimizer.optim.param_groups, "state": local_states}
|
||||
if self.coordinator.is_master():
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors=False)
|
||||
else:
|
||||
# When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
|
||||
states_list = [None for _ in range(self.pp_size)]
|
||||
dist.barrier(self.pp_group)
|
||||
# dist.all_gather_object(states_list, local_states, self.pp_group)
|
||||
dist.gather_object(local_states, states_list, self.pp_group)
|
||||
|
||||
# Only the master rank does the saving.
|
||||
if self.coordinator.is_master():
|
||||
state_dict = {"param_groups": optimizer.optim.param_groups, "state": dict()}
|
||||
for _states in states_list:
|
||||
state_dict["state"].update(_states)
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors=False)
|
||||
dist.barrier()
|
||||
|
||||
# Copied from colossalai.moe
|
||||
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, strict: bool = False):
|
||||
raise NotImplementedError
|
||||
"""
|
||||
Load optimizer from a file with given path.
|
||||
|
||||
Args:
|
||||
optimizer (OptimizerWrapper): The optimizer to be loaded.
|
||||
checkpoint_index_file (str): Path to the checkpoint file.
|
||||
"""
|
||||
|
||||
def _get_param_id_from_optimizer_param(
|
||||
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
|
||||
):
|
||||
if master_to_working_map is not None and id(param) in master_to_working_map:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
if id(working_param) in optimizer.param_info["param2id"]:
|
||||
return optimizer.param_info["param2id"][id(working_param)]
|
||||
else:
|
||||
return None
|
||||
|
||||
if self.coordinator.is_master():
|
||||
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
|
||||
|
||||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
|
||||
|
||||
# Complete optimizer state_dict loaded from checkpoint, need to be processed later.
|
||||
state_dict = load_state_dict(checkpoint)
|
||||
|
||||
# Load param_groups.
|
||||
updated_groups = []
|
||||
saved_groups = state_dict["param_groups"]
|
||||
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
|
||||
new_pg = copy.deepcopy(saved_pg)
|
||||
new_pg["params"] = old_pg["params"] # Only keep the parameters kept by current pipeline stage.
|
||||
updated_groups.append(new_pg)
|
||||
|
||||
# ep extra group
|
||||
# if MOE_MANAGER.parallel == "EP":
|
||||
if self.ep_size > 1:
|
||||
new_pg = copy.deepcopy(saved_pg)
|
||||
new_pg["params"] = optimizer.optim.param_groups[-1][
|
||||
"params"
|
||||
] # Only keep the parameters kept by current pipeline stage.
|
||||
for param in new_pg["params"]:
|
||||
param.data = param.data.to(torch.float32)
|
||||
updated_groups.append(new_pg)
|
||||
optimizer.optim.__dict__.update({"param_groups": updated_groups})
|
||||
|
||||
# Load saved states to optimizer. First discard those states not belonging to current pipeline stage.
|
||||
master_to_working_map = optimizer.get_master_to_working_map()
|
||||
id_map = {}
|
||||
for pg in optimizer.optim.param_groups:
|
||||
for param in pg["params"]:
|
||||
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
|
||||
if param_id is not None:
|
||||
id_map[param_id] = param
|
||||
load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True)
|
||||
|
||||
# Then shard the loaded optimizer states if using tp/zero.
|
||||
for param, state in optimizer.optim.state.items():
|
||||
if param is None:
|
||||
continue
|
||||
device = param.device
|
||||
if master_to_working_map is not None and id(param) in master_to_working_map:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
original_shape = optimizer.param_info["param2shape"][id(working_param)]
|
||||
sharded_state = self.pre_load_optim(
|
||||
state,
|
||||
param,
|
||||
current_shape=working_param.shape,
|
||||
original_shape=original_shape,
|
||||
device=device,
|
||||
inplace=True,
|
||||
)
|
||||
optimizer.optim.state[param] = sharded_state
|
||||
sharded_optimizer_loading_epilogue(optimizer.optim)
|
||||
dist.barrier()
|
|
@ -242,6 +242,7 @@ def save_state_dict_shards(
|
|||
shard_filenames = []
|
||||
for idx, shard_pair in enumerate(sharded_state_dict):
|
||||
shard, current_size = shard_pair
|
||||
# Non-master ranks just iterate over the sharder so the gather collectives inside it can proceed; they do not save anything
|
||||
if not is_master:
|
||||
del shard
|
||||
continue
|
||||
|
|
|
@ -244,19 +244,25 @@ class ProcessGroupMesh:
|
|||
return target_group
|
||||
|
||||
def get_group_along_axis(
|
||||
self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
|
||||
self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
|
||||
) -> ProcessGroup:
|
||||
"""Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.
|
||||
|
||||
Args:
|
||||
axis (int): Axis along which the process groups are created.
|
||||
axis (int or list of int): Axes along which the process groups are created.
|
||||
indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
|
||||
backend (Optional[str], optional): Backend of the process group. Defaults to None.
|
||||
|
||||
Returns:
|
||||
ProcessGroup: The process group along the given axis which the current process belongs to.
|
||||
"""
|
||||
indices_at_axis = indices_at_axis or list(range(self._shape[axis]))
|
||||
indices_at_axis = indices_at_axis
|
||||
if indices_at_axis is None:
|
||||
if isinstance(axis, (list, tuple)):
|
||||
indices_at_axis = list(list(range(self._shape[ax])) for ax in axis)
|
||||
else:
|
||||
indices_at_axis = list(range(self._shape[axis]))
|
||||
|
||||
coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis)
|
||||
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
|
||||
if ranks_in_group not in self._ranks_to_group:
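A standalone sketch of the new defaulting behaviour for indices_at_axis when axis may be a single int or a list of ints; the shape tuple here is a stand-in for the mesh's internal _shape:
from typing import List, Optional, Union

def default_indices_at_axis(shape, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None):
    # When not provided, enumerate every index along each requested axis.
    if indices_at_axis is not None:
        return indices_at_axis
    if isinstance(axis, (list, tuple)):
        return [list(range(shape[ax])) for ax in axis]
    return list(range(shape[axis]))

# e.g. a (pp=2, dp=4, tp=2) mesh
print(default_indices_at_axis((2, 4, 2), 1))       # [0, 1, 2, 3]
print(default_indices_at_axis((2, 4, 2), [0, 2]))  # [[0, 1], [0, 1]]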
|
||||
|
|
|
@ -1,20 +1,5 @@
|
|||
from .checkpoint import MoECheckpointIO
|
||||
from .experts import MLPExperts
|
||||
from .layers import SparseMLP, apply_load_balance
|
||||
from .manager import MOE_MANAGER
|
||||
from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
|
||||
from .utils import NormalNoiseGenerator, UniformNoiseGenerator
|
||||
|
||||
__all__ = [
|
||||
"MLPExperts",
|
||||
"MoeRouter",
|
||||
"Top1Router",
|
||||
"Top2Router",
|
||||
"TopKRouter",
|
||||
"NormalNoiseGenerator",
|
||||
"UniformNoiseGenerator",
|
||||
"SparseMLP",
|
||||
"MoECheckpointIO",
|
||||
"MOE_MANAGER",
|
||||
"apply_load_balance",
|
||||
]
|
||||
|
|
|
@ -1,792 +0,0 @@
|
|||
import copy
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
from typing import Dict, Iterator, Optional, OrderedDict, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.checkpoint_io import CheckpointIndexFile, HybridParallelCheckpointIO
|
||||
from colossalai.checkpoint_io.utils import (
|
||||
StateDictSharder,
|
||||
gather_distributed_param,
|
||||
get_model_base_filenames,
|
||||
get_optimizer_base_filenames,
|
||||
is_safetensors_available,
|
||||
load_shard_state_dict,
|
||||
load_state_dict,
|
||||
load_state_dict_into_model,
|
||||
load_states_into_optimizer,
|
||||
save_config_file,
|
||||
save_param_groups,
|
||||
save_state_dict,
|
||||
save_state_dict_shards,
|
||||
sharded_optimizer_loading_epilogue,
|
||||
)
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.tensor.moe_tensor.api import (
|
||||
get_dp_group,
|
||||
get_dp_rank,
|
||||
get_dp_size,
|
||||
get_ep_group,
|
||||
get_ep_rank,
|
||||
get_ep_size,
|
||||
is_moe_tensor,
|
||||
)
|
||||
|
||||
|
||||
class MoECheckpointIO(HybridParallelCheckpointIO):
|
||||
def __init__(
|
||||
self,
|
||||
dp_group: ProcessGroup,
|
||||
pp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
zero_stage: int,
|
||||
) -> None:
|
||||
assert zero_stage in [
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
], f"zero_stage should be 0 or 1 or 2, got {zero_stage}"
|
||||
super().__init__(dp_group, pp_group, tp_group, zero_stage)
|
||||
self.parallel = MOE_MANAGER.parallel
|
||||
|
||||
def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict:
|
||||
"""
|
||||
Preprocess state_dict before loading and slice the state_dict of MOE tensors.
|
||||
"""
|
||||
for name, param in state_dict.items():
|
||||
if ".experts." in name:
|
||||
if name in dict(model.named_parameters()):
|
||||
model_param = dict(model.named_parameters())[name]
|
||||
if is_moe_tensor(model_param):
|
||||
ep_rank = get_ep_rank(model_param)
|
||||
ep_size = get_ep_size(model_param)
|
||||
expert_num = param.shape[0] // ep_size
|
||||
assert param.shape[0] % ep_size == 0
|
||||
param = param[ep_rank * expert_num : (ep_rank + 1) * expert_num]
|
||||
state_dict[name] = param
|
||||
dist.barrier()
|
||||
return state_dict
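To make the slicing above concrete: the checkpoint stores all experts stacked along dim 0, and every expert-parallel rank keeps only its contiguous block. A toy illustration (shapes are made up):
import torch

def slice_experts_for_rank(param: torch.Tensor, ep_rank: int, ep_size: int) -> torch.Tensor:
    assert param.shape[0] % ep_size == 0
    experts_per_rank = param.shape[0] // ep_size
    # Keep only the block of experts owned by this expert-parallel rank.
    return param[ep_rank * experts_per_rank : (ep_rank + 1) * experts_per_rank]

full = torch.randn(8, 16, 32)  # 8 experts stored in the checkpoint
local = slice_experts_for_rank(full, ep_rank=1, ep_size=4)
print(local.shape)  # torch.Size([2, 16, 32])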
|
||||
|
||||
def _model_sharder(
|
||||
self,
|
||||
state_dict: nn.Module,
|
||||
prefix: str = "",
|
||||
keep_vars: bool = False,
|
||||
size_per_shard: int = 1024,
|
||||
) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
# An internal method that breaks the model state_dict into shards of limited size.
|
||||
state_dict_sharder = StateDictSharder(size_per_shard)
|
||||
|
||||
for name, param in state_dict.items():
|
||||
if param is None:
|
||||
continue
|
||||
# Gather tensor pieces when using tensor parallel.
|
||||
param_ = gather_distributed_param(param, keep_vars=False)
|
||||
block, block_size = state_dict_sharder.append_param(prefix + name, param_)
|
||||
if block is not None:
|
||||
yield block, block_size
|
||||
|
||||
# Return the last block in sharder.
|
||||
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
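As a rough mental model of what StateDictSharder does for this generator, here is a simplified stand-in that packs tensors into blocks no larger than a byte budget; it is a sketch only, not the actual StateDictSharder implementation:
import torch
from collections import OrderedDict

def simple_sharder(state_dict, size_per_shard_mb: int = 1024):
    budget = size_per_shard_mb * 1024 * 1024
    block, block_size = OrderedDict(), 0
    for name, tensor in state_dict.items():
        nbytes = tensor.numel() * tensor.element_size()
        if block and block_size + nbytes > budget:
            yield block, block_size          # emit a full block
            block, block_size = OrderedDict(), 0
        block[name] = tensor
        block_size += nbytes
    yield block, block_size                  # the last, possibly partial block

sd = {"w1": torch.zeros(1024, 1024), "w2": torch.zeros(1024, 1024)}  # 4 MB each
for blk, size in simple_sharder(sd, size_per_shard_mb=5):
    print(list(blk), size)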
|
||||
|
||||
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool) -> None:
|
||||
state_dict = torch.load(checkpoint)
|
||||
state_dict = self.pre_load_model(model, state_dict)
|
||||
model.load_state_dict(state_dict, strict=strict if self.pp_size == 1 else False)
|
||||
|
||||
def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
|
||||
"""
|
||||
Load sharded model with the given path to index file of checkpoint folder.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The model to be loaded.
|
||||
checkpoint_index_file (str): Path to the index file of checkpointing folder.
|
||||
strict (bool, optional): For name matching during loading state_dict. Defaults to False.
|
||||
This argument should be manually set to False since params on the same device might be stored in different files.
|
||||
"""
|
||||
|
||||
# Check whether the checkpoint uses safetensors.
|
||||
use_safetensors = False
|
||||
if "safetensors" in checkpoint_index_file.name:
|
||||
use_safetensors = True
|
||||
|
||||
if use_safetensors and not is_safetensors_available():
|
||||
raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
|
||||
|
||||
# Read checkpoint index file.
|
||||
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
|
||||
ckpt_root_path = ckpt_index_file.root_path
|
||||
weight_map = ckpt_index_file.weight_map
|
||||
strict = False
|
||||
|
||||
# Load params & buffers to model.
|
||||
# Keep a record of loaded files so that file will not be repeatedly loaded.
|
||||
loaded_file = set()
|
||||
|
||||
def _load(name: str):
|
||||
if name not in weight_map:
|
||||
raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
|
||||
filename = weight_map[name]
|
||||
|
||||
# If this param/buffer has been loaded before, directly return.
|
||||
if filename in loaded_file:
|
||||
return
|
||||
|
||||
file_path = os.path.join(ckpt_root_path, filename)
|
||||
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
|
||||
state_dict = self.pre_load_model(model, state_dict)
|
||||
missing_keys = []
|
||||
|
||||
load_state_dict_into_model(
|
||||
model,
|
||||
state_dict,
|
||||
missing_keys=missing_keys,
|
||||
strict=strict,
|
||||
load_sub_module=True,
|
||||
)
|
||||
loaded_file.add(filename)
|
||||
|
||||
# Load parameters.
|
||||
for name, _ in model.named_parameters():
|
||||
_load(name)
|
||||
|
||||
if self.verbose:
|
||||
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
|
||||
|
||||
def pre_save_model(self, model: nn.Module) -> dict:
|
||||
state_dict = model.state_dict()
|
||||
for name, param in model.named_parameters():
|
||||
if ".experts." in name and is_moe_tensor(param):
|
||||
ep_group = get_ep_group(param)
|
||||
ep_rank = get_ep_rank(param)
|
||||
ep_size = get_ep_size(param)
|
||||
dp_rank = get_dp_rank(param)
|
||||
if dp_rank == 0:
|
||||
param = param.data.cuda()
|
||||
all_param = [torch.zeros_like(param) for _ in range(ep_size)]
|
||||
# gather param from every ep rank
|
||||
dist.all_gather(all_param, param, group=ep_group)
|
||||
if ep_rank == 0:
|
||||
all_param = torch.cat(all_param, dim=0)
|
||||
state_dict[name] = all_param.cpu()
|
||||
if self.pp_size > 1:
|
||||
if self.dp_rank == 0:
|
||||
out = [None for _ in range(self.pp_size)]
|
||||
dist.all_gather_object(out, state_dict, group=self.pp_group)
|
||||
if self.pp_rank == 0:
|
||||
new_state_dict = {}
|
||||
for o in out:
|
||||
new_state_dict.update(o)
|
||||
state_dict = new_state_dict
|
||||
dist.barrier()
|
||||
return state_dict
|
||||
|
||||
def save_unsharded_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
checkpoint: str,
|
||||
gather_dtensor: bool,
|
||||
use_safetensors: bool,
|
||||
):
|
||||
state_dict = self.pre_save_model(model)
|
||||
if dist.get_rank() == 0:
|
||||
torch.save(state_dict, checkpoint)
|
||||
dist.barrier()
|
||||
|
||||
def save_sharded_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
checkpoint: str,
|
||||
gather_dtensor: bool = True,
|
||||
prefix: Optional[str] = None,
|
||||
size_per_shard: int = 1024,
|
||||
use_safetensors: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Save sharded model checkpoint under the given checkpointing path.
|
||||
The following files will be created under the path:
|
||||
- An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
|
||||
- Multiple files that store state tensors of models.
|
||||
The filenames are in the form of "pytorch_model.<prefix>-000XX.bin"
|
||||
|
||||
Args:
|
||||
model (nn.Module): Model on local device to be saved.
|
||||
checkpoint (str): Checkpointing path which should be a directory path.
|
||||
gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
|
||||
prefix (str, optional): Prefix of saved files. Defaults to None.
|
||||
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
|
||||
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
|
||||
"""
|
||||
torch.cuda.empty_cache()
|
||||
if os.path.isfile(checkpoint):
|
||||
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
Path(checkpoint).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Then collect the sharded parameters & buffers along tp_group.
|
||||
# Only devices with tp_rank == 0 are responsible for model saving.
|
||||
state_dict = self.pre_save_model(model)
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
state_dict_shard = self._model_sharder(state_dict, size_per_shard=size_per_shard)
|
||||
|
||||
# Devices along the same dp_group share the same copies of model.
|
||||
# So only let the device with dp_rank == 0 save the model.
|
||||
if self.dp_rank != 0:
|
||||
return
|
||||
|
||||
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
control_saving = self.tp_rank == 0
|
||||
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=weights_name,
|
||||
is_master=control_saving,
|
||||
use_safetensors=use_safetensors,
|
||||
)
|
||||
if control_saving:
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
save_config_file(model, checkpoint)
|
||||
if self.verbose:
|
||||
logging.info(
|
||||
f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}."
|
||||
)
|
||||
dist.barrier()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# ========================================================
|
||||
# Abstract methods for optimizer loading/saving implementation
|
||||
# ========================================================
|
||||
|
||||
def pre_load_optim(
|
||||
self,
|
||||
state: OrderedDict,
|
||||
working_param,
|
||||
current_shape: torch.Size,
|
||||
original_shape: torch.Size,
|
||||
device: torch.device,
|
||||
inplace: bool,
|
||||
) -> OrderedDict:
|
||||
"""
|
||||
With complete optimizer states of a specific parameter loaded from checkpoint,
|
||||
slice out the sharded optimizer states kept by current device.
|
||||
|
||||
Args:
|
||||
state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint.
|
||||
current_shape (torch.Size): The size of parameter after sharding.
|
||||
original_shape (torch.Size): The size of parameter before sharding.
|
||||
device (torch.device): The destination device of loaded optimizer states.
|
||||
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
|
||||
|
||||
Returns:
|
||||
OrderedDict: The sharded optimizer state of the given parameter.
|
||||
"""
|
||||
state_ = state if inplace else copy.deepcopy(state)
|
||||
is_moe_tensor_flag = is_moe_tensor(working_param)
|
||||
if is_moe_tensor_flag:
|
||||
ep_rank = get_ep_rank(working_param)
|
||||
ep_size = get_ep_size(working_param)
|
||||
|
||||
for k, v in state_.items():
|
||||
if isinstance(v, torch.Tensor) and k != "step":
|
||||
if is_moe_tensor_flag:
|
||||
with torch.no_grad():
|
||||
expert_num = v.shape[0] // ep_size
|
||||
assert v.shape[0] % ep_size == 0
|
||||
v = v[ep_rank * expert_num : (ep_rank + 1) * expert_num]
|
||||
else:
|
||||
# Shard state along data parallel group when using Zero.
|
||||
padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
|
||||
with torch.no_grad():
|
||||
v = v.flatten()
|
||||
if padding_size > 0:
|
||||
v = torch.nn.functional.pad(v, [0, padding_size])
|
||||
slice_size = v.numel() // self.dp_size
|
||||
v = v.split(slice_size, dim=0)[self.dp_rank]
|
||||
|
||||
state_[k] = v.detach().clone().to(device)
|
||||
|
||||
return state_
|
||||
|
||||
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
|
||||
"""
|
||||
Load sharded optimizer with the given path to index file of checkpoint folder.
|
||||
|
||||
Args:
|
||||
optimizer (OptimizerWrapper): The optimizer to be loaded.
|
||||
checkpoint_index_file (str): Path to the index file of checkpointing folder.
|
||||
prefix (str): Not used.
|
||||
"""
|
||||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
|
||||
|
||||
def _get_param_id_from_optimizer_param(
|
||||
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, optimizer=None
|
||||
):
|
||||
if master_to_working_map is not None and id(param) in master_to_working_map:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
|
||||
working_param = optimizer.moe_master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
return optimizer.param_info["param2id"][id(working_param)]
|
||||
|
||||
# id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects.
|
||||
# When Zero is used, the mapped parameter objects should be fp32 master parameters.
|
||||
# IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info.
|
||||
id_map = {}
|
||||
master_to_working_map = optimizer.get_master_to_working_map()
|
||||
for pg in optimizer.optim.param_groups:
|
||||
for param in pg["params"]:
|
||||
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
|
||||
id_map[param_id] = param
|
||||
|
||||
# Read checkpoint index file.
|
||||
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
|
||||
ckpt_root_path = ckpt_index_file.root_path
|
||||
weight_map = ckpt_index_file.weight_map
|
||||
weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
|
||||
|
||||
# Load param_groups
|
||||
param_group_path = ckpt_index_file.get_param_group_filename()
|
||||
if param_group_path is None:
|
||||
raise RuntimeError(
|
||||
f"Invalid index file path {checkpoint_index_file} for an optimizer. \
|
||||
Lacking param group file under current directory."
|
||||
)
|
||||
saved_groups = torch.load(param_group_path)
|
||||
|
||||
updated_groups = []
|
||||
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
|
||||
# obtain updated param group
|
||||
new_pg = copy.deepcopy(saved_pg)
|
||||
new_pg["params"] = old_pg["params"] # The parameters in the same group shouldn't change.
|
||||
updated_groups.append(new_pg)
|
||||
# ep param group
|
||||
if len(optimizer.optim.param_groups) > len(saved_groups):
|
||||
new_pg = copy.deepcopy(saved_pg)
|
||||
new_pg["params"] = optimizer.optim.param_groups[-1]["params"]
|
||||
updated_groups.append(new_pg)
|
||||
optimizer.optim.__dict__.update({"param_groups": updated_groups})
|
||||
|
||||
# Load saved states to optimizer.
|
||||
# Keep a record of loaded files so that file will not be repeatedly loaded.
|
||||
loaded_file = set()
|
||||
for pg in optimizer.optim.param_groups:
|
||||
for param in pg["params"]:
|
||||
if param is None:
|
||||
continue
|
||||
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
|
||||
if param_id not in weight_map:
|
||||
continue
|
||||
filename = weight_map[param_id]
|
||||
|
||||
# If this param's states has been loaded before, directly return.
|
||||
if filename in loaded_file:
|
||||
continue
|
||||
|
||||
file_path = os.path.join(ckpt_root_path, filename)
|
||||
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
|
||||
|
||||
# Then shard the loaded optimizer states if using tp/zero.
|
||||
for pid, state in list(state_dict.items()):
|
||||
if pid in id_map:
|
||||
param = id_map[pid]
|
||||
if master_to_working_map is not None and id(param) in master_to_working_map:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
elif (
|
||||
hasattr(optimizer, "moe_master_to_working_map")
|
||||
and id(param) in optimizer.moe_master_to_working_map
|
||||
):
|
||||
working_param = optimizer.moe_master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
original_shape = optimizer.param_info["param2shape"][id(working_param)]
|
||||
sharded_state = self.pre_load_optim(
|
||||
state,
|
||||
working_param,
|
||||
current_shape=working_param.shape,
|
||||
original_shape=original_shape,
|
||||
device="cpu",
|
||||
inplace=True,
|
||||
)
|
||||
state_dict[pid] = sharded_state
|
||||
|
||||
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
|
||||
loaded_file.add(filename)
|
||||
|
||||
sharded_optimizer_loading_epilogue(optimizer.optim)
|
||||
if self.verbose and self.coordinator.is_master():
|
||||
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
|
||||
dist.barrier()
|
||||
|
||||
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
|
||||
"""
|
||||
Load optimizer from a file with given path.
|
||||
|
||||
Args:
|
||||
optimizer (OptimizerWrapper): The optimizer to be loaded.
|
||||
checkpoint_index_file (str): Path to the checkpoint file.
|
||||
"""
|
||||
|
||||
def _get_param_id_from_optimizer_param(
|
||||
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
|
||||
):
|
||||
if master_to_working_map is not None and id(param) in master_to_working_map:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
if id(working_param) in optimizer.param_info["param2id"]:
|
||||
return optimizer.param_info["param2id"][id(working_param)]
|
||||
else:
|
||||
return None
|
||||
|
||||
if self.coordinator.is_master():
|
||||
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
|
||||
|
||||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
|
||||
|
||||
# Complete optimizer state_dict loaded from checkpoint, need to be processed later.
|
||||
state_dict = load_state_dict(checkpoint)
|
||||
|
||||
# Load param_groups.
|
||||
updated_groups = []
|
||||
saved_groups = state_dict["param_groups"]
|
||||
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
|
||||
new_pg = copy.deepcopy(saved_pg)
|
||||
new_pg["params"] = old_pg["params"] # Only keep the parameters kept by current pipeline stage.
|
||||
updated_groups.append(new_pg)
|
||||
# ep extra group
|
||||
if MOE_MANAGER.parallel == "EP":
|
||||
new_pg = copy.deepcopy(saved_pg)
|
||||
new_pg["params"] = optimizer.optim.param_groups[-1][
|
||||
"params"
|
||||
] # Only keep the parameters kept by current pipeline stage.
|
||||
for param in new_pg["params"]:
|
||||
param.data = param.data.to(torch.float32)
|
||||
updated_groups.append(new_pg)
|
||||
optimizer.optim.__dict__.update({"param_groups": updated_groups})
|
||||
|
||||
# Load saved states to optimizer. First discard those states not belonging to current pipeline stage.
|
||||
master_to_working_map = optimizer.get_master_to_working_map()
|
||||
id_map = {}
|
||||
for pg in optimizer.optim.param_groups:
|
||||
for param in pg["params"]:
|
||||
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
|
||||
if param_id is not None:
|
||||
id_map[param_id] = param
|
||||
load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True)
|
||||
|
||||
# Then shard the loaded optimizer states if using tp/zero.
|
||||
for param, state in optimizer.optim.state.items():
|
||||
if param is None:
|
||||
continue
|
||||
device = param.device
|
||||
if master_to_working_map is not None and id(param) in master_to_working_map:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
original_shape = optimizer.param_info["param2shape"][id(working_param)]
|
||||
sharded_state = self.pre_load_optim(
|
||||
state,
|
||||
param,
|
||||
current_shape=working_param.shape,
|
||||
original_shape=original_shape,
|
||||
device=device,
|
||||
inplace=True,
|
||||
)
|
||||
optimizer.optim.state[param] = sharded_state
|
||||
sharded_optimizer_loading_epilogue(optimizer.optim)
|
||||
dist.barrier()
|
||||
|
||||
def pre_save_optim(
|
||||
self,
|
||||
state: OrderedDict,
|
||||
param: torch.Tensor,
|
||||
inplace: bool,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
) -> OrderedDict:
|
||||
"""
|
||||
With given parameter and its optimizer states, gather the complete optimizer state for saving.
|
||||
|
||||
Args:
|
||||
state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero.
|
||||
param (torch.Tensor): The given parameter. It should be working_param when using Zero.
|
||||
original_shape (torch.Size): The size of parameter before sharding.
|
||||
dp_group (ProcessGroup): The process group of data parallel.
|
||||
tp_group (ProcessGroup): The process group of tensor parallel.
|
||||
use_zero (bool): Whether Zero is used.
|
||||
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
|
||||
device (torch.device): The destination device of loaded optimizer states. Defaults to torch.device('cpu').
|
||||
|
||||
Returns:
|
||||
OrderedDict: The complete optimizer state of given parameter.
|
||||
"""
|
||||
if is_moe_tensor(param):
|
||||
moe_dp_group = get_dp_group(param)
|
||||
moe_dp_size = get_dp_size(param)
|
||||
moe_ep_group = get_ep_group(param)
|
||||
moe_ep_size = get_ep_size(param)
|
||||
state_ = state if inplace else copy.deepcopy(state)
|
||||
|
||||
for k, v in state_.items():
|
||||
if isinstance(v, torch.Tensor) and k != "step":
|
||||
# moe param
|
||||
if is_moe_tensor(param):
|
||||
# dp gather
|
||||
v = v.cuda()
|
||||
gather_tensor = [torch.zeros_like(v) for _ in range(moe_dp_size)]
|
||||
dist.all_gather(gather_tensor, v, group=moe_dp_group)
|
||||
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
|
||||
# ep gather
|
||||
gather_tensor = [torch.zeros_like(v) for _ in range(moe_ep_size)]
|
||||
dist.all_gather(gather_tensor, v, group=moe_ep_group)
|
||||
v = torch.cat(gather_tensor, dim=0)
|
||||
else:
|
||||
# global dp
|
||||
v = v.cuda()
|
||||
gather_tensor = [torch.zeros_like(v) for _ in range(dist.get_world_size(self.dp_group))]
|
||||
dist.all_gather(gather_tensor, v, group=self.dp_group)
|
||||
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
|
||||
|
||||
state_[k] = v.detach().clone().to(device)
|
||||
|
||||
return state_
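The MoE branch above rebuilds a state in two steps: first the flat shards from the MoE data-parallel group are merged back to the local experts' shape, then the expert blocks from the EP group are concatenated along dim 0. A single-process sketch of that composition (the lists stand in for all_gather results):
import torch

local_experts = torch.randn(3, 4)                    # experts held by this EP rank
flat = local_experts.flatten()                       # 12 elements, divisible by moe_dp_size = 2
dp_shards = list(flat.split(flat.numel() // 2))      # stand-in for all_gather over moe_dp_group
rebuilt = torch.stack(dp_shards).view(-1)[: local_experts.numel()].reshape_as(local_experts)

ep_shards = [rebuilt, torch.randn(3, 4)]             # stand-in for all_gather over ep_group
full_state = torch.cat(ep_shards, dim=0)             # all experts of the layer
print(full_state.shape)                              # torch.Size([6, 4])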
|
||||
|
||||
def _optimizer_sharder(
|
||||
self,
|
||||
optimizer: OptimizerWrapper,
|
||||
size_per_shard: int = 1024,
|
||||
):
|
||||
# An internal method that breaks the optimizer state_dict into shards of limited size.
|
||||
|
||||
state_dict_sharder = StateDictSharder(size_per_shard)
|
||||
param_info = optimizer.param_info
|
||||
master_to_working_map = optimizer.get_master_to_working_map()
|
||||
|
||||
for param, state in optimizer.optim.state.items():
|
||||
if param is None:
|
||||
continue
|
||||
|
||||
if master_to_working_map is not None and id(param) in master_to_working_map:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
|
||||
working_param = optimizer.moe_master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
|
||||
param_id = param_info["param2id"][id(working_param)]
|
||||
state_ = self.pre_save_optim(
|
||||
state,
|
||||
working_param,
|
||||
inplace=False,
|
||||
device=torch.device("cuda"),
|
||||
)
|
||||
|
||||
block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
|
||||
if block is not None:
|
||||
yield block, block_size
|
||||
|
||||
# Return the last block in sharder.
|
||||
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
|
||||
|
||||
def save_sharded_optimizer(
|
||||
self,
|
||||
optimizer: OptimizerWrapper,
|
||||
checkpoint: str,
|
||||
gather_dtensor: bool = True,
|
||||
prefix: Optional[str] = None,
|
||||
size_per_shard: int = 1024,
|
||||
):
|
||||
"""
|
||||
Save sharded optimizer checkpoint under the given checkpointing path.
|
||||
The following files will be created under the path:
|
||||
- An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
|
||||
- A group file (pytorch_optim_group.bin) recording information of param_groups
|
||||
- Multiple files that store state tensors of optimizers.
|
||||
If pipeline parallelism is used, the filenames are in the form of "pytorch_optim.<prefix>-stage-000XX-shard-000XX.bin".
|
||||
If pipeline parallelism is not used, "pytorch_optim.<prefix>-000XX.bin"
|
||||
|
||||
Args:
|
||||
optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
|
||||
checkpoint (str): Path to save optimizer state_dict
|
||||
gather_dtensor (bool): Whether to gather_dtensor, not used
|
||||
prefix (str): Prefix of saved files
|
||||
size_per_shard (int): Max file size of each file shard that store state tensors
|
||||
"""
|
||||
torch.cuda.empty_cache()
|
||||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
|
||||
if os.path.isfile(checkpoint):
|
||||
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
Path(checkpoint).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Devices along the same dp_group share the same copies of states when zero is not used.
|
||||
# In this case only let the device with dp_rank == 0 save the states.
|
||||
if not self.use_zero and self.dp_rank != 0:
|
||||
return
|
||||
|
||||
# Then collect the sharded states along dp_group(if using zero)/tp_group.
|
||||
# Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
|
||||
state_dict_shard = self._optimizer_sharder(
|
||||
optimizer,
|
||||
size_per_shard=size_per_shard,
|
||||
)
|
||||
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
control_saving = self.dp_rank == 0 and self.tp_rank == 0
|
||||
if self.pp_size == 1:
|
||||
# When pipeline is not used, save the optimizer shards as in general checkpointIO
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=control_saving,
|
||||
)
|
||||
|
||||
if control_saving:
|
||||
# Store param groups.
|
||||
index_file.append_meta_data("param_groups", param_group_file)
|
||||
group_file_path = os.path.join(checkpoint, param_group_file)
|
||||
save_param_groups(optimizer.param_info, group_file_path)
|
||||
# Store index file.
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
if self.verbose and self.coordinator.is_master():
|
||||
logging.info(
|
||||
f"The optimizer is going to be split to checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}."
|
||||
)
|
||||
|
||||
else:
|
||||
# When pipeline is used, each stage produces its own shard files and index files.
|
||||
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
|
||||
# After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
|
||||
|
||||
final_index_file_path = copy.deepcopy(save_index_file)
|
||||
tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
|
||||
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Manage filenames of sharded weights and index file for each pipeline stage.
|
||||
states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
|
||||
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
|
||||
save_index_file = os.path.join("tmp_index_files", save_index_file)
|
||||
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=control_saving,
|
||||
use_pp_format=True,
|
||||
)
|
||||
|
||||
if control_saving:
|
||||
assert (
|
||||
self.dp_rank == 0 and self.tp_rank == 0
|
||||
), "The saving process should have both dp_rank and tp_rank as 0."
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
else:
|
||||
return
|
||||
|
||||
dist.barrier(self.pp_group)
|
||||
|
||||
# The global master rank integrates the index files and cleans up the temporary folder.
|
||||
if self.pp_rank == 0:
|
||||
final_index_file = CheckpointIndexFile(checkpoint)
|
||||
final_index_file.append_meta_data("total_size", 0)
|
||||
|
||||
for filename in os.listdir(tmp_index_file_folder):
|
||||
stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
|
||||
final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
|
||||
for param_id, state_filename in stage_index_file.weight_map.items():
|
||||
final_index_file.append_weight_map(param_id, state_filename)
|
||||
|
||||
# Store param groups.
|
||||
final_index_file.append_meta_data("param_groups", param_group_file)
|
||||
group_file_path = os.path.join(checkpoint, param_group_file)
|
||||
save_param_groups(optimizer.param_info, group_file_path)
|
||||
|
||||
final_index_file.write_index_file(final_index_file_path)
|
||||
rmtree(tmp_index_file_folder)
|
||||
|
||||
if self.verbose and self.coordinator.is_master():
|
||||
logging.info(
|
||||
f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {final_index_file_path}."
|
||||
)
|
||||
torch.cuda.empty_cache()
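A framework-free sketch of the index-merging step performed by the pp_rank == 0 process above: each pipeline stage writes its own index JSON, and the final index is the union of the weight maps plus the summed total size. The real code goes through CheckpointIndexFile; the file layout below is assumed for illustration only.
import json
from pathlib import Path

def merge_stage_index_files(tmp_dir: str, final_path: str) -> None:
    final = {"metadata": {"total_size": 0}, "weight_map": {}}
    for stage_file in sorted(Path(tmp_dir).glob("*.json")):
        stage = json.loads(stage_file.read_text())
        final["metadata"]["total_size"] += stage["metadata"]["total_size"]
        final["weight_map"].update(stage["weight_map"])   # param_id -> shard filename
    Path(final_path).write_text(json.dumps(final, indent=2))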
|
||||
|
||||
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
|
||||
"""
|
||||
Save optimizer state dict to a file with given path.
|
||||
|
||||
Args:
|
||||
optimizer (OptimizerWrapper): Optimizer to save sharded state_dict.
|
||||
checkpoint (str): Path to save optimizer state_dict.
|
||||
gather_dtensor (bool): Whether to gather_dtensor, not used.
|
||||
"""
|
||||
if self.coordinator.is_master():
|
||||
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
|
||||
|
||||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
|
||||
|
||||
# optimizer states of the parameters kept by the local device (i.e. its pipeline stage)
|
||||
local_states = dict()
|
||||
|
||||
for param, state in optimizer.optim.state.items():
|
||||
if param is None:
|
||||
continue
|
||||
|
||||
# working param is needed for obtaining correct param_id
|
||||
master_to_working_map = optimizer.get_master_to_working_map()
|
||||
if master_to_working_map is not None and id(param) in master_to_working_map:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
|
||||
# gather complete state from tp shards & dp shards
|
||||
param_id = optimizer.param_info["param2id"][id(working_param)]
|
||||
local_states[param_id] = self.pre_save_optim(
|
||||
state,
|
||||
working_param,
|
||||
inplace=False,
|
||||
device=torch.device("cuda"),
|
||||
)
|
||||
|
||||
if self.pp_size == 1:
|
||||
# When pipeline is not used, let master rank directly save the collected state_dict.
|
||||
state_dict = {"param_groups": optimizer.optim.param_groups, "state": local_states}
|
||||
if self.coordinator.is_master():
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors=False)
|
||||
else:
|
||||
# When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
|
||||
states_list = [None for _ in range(self.pp_size)]
|
||||
dist.barrier(self.pp_group)
|
||||
dist.all_gather_object(states_list, local_states, self.pp_group)
|
||||
|
||||
# Only the master rank does the saving.
|
||||
if self.coordinator.is_master():
|
||||
state_dict = {"param_groups": optimizer.optim.param_groups, "state": dict()}
|
||||
for _states in states_list:
|
||||
state_dict["state"].update(_states)
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors=False)
|
||||
dist.barrier()
|
|
@ -7,8 +7,8 @@ from torch import Tensor, nn
|
|||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.moe.experts import MLPExperts
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.shardformer.layer.moe import MLPExperts
|
||||
from colossalai.zero.low_level import LowLevelZeroOptimizer
|
||||
|
||||
|
||||
|
@ -292,7 +292,7 @@ class LoadBalancer:
|
|||
exp_avg_ptr = optim.optim.state[working_weight_ptr]["exp_avg"]
|
||||
exp_avg_sq_ptr = optim.optim.state[working_weight_ptr]["exp_avg_sq"]
|
||||
else:
|
||||
master_weight_ptr = optim._param_store.working_to_master_param[id(weight)]
|
||||
master_weight_ptr = optim.working_to_master_param[id(weight)]
|
||||
working_weight_ptr = weight
|
||||
exp_avg_ptr = optim.optim.state[master_weight_ptr]["exp_avg"]
|
||||
exp_avg_sq_ptr = optim.optim.state[master_weight_ptr]["exp_avg_sq"]
|
||||
|
@ -344,7 +344,7 @@ class LoadBalancer:
|
|||
# gate optim should be obtained first
|
||||
gate_shape = self.gate.shape
|
||||
# get master weight and optim
|
||||
master_gate_weight = optim._param_store.working_to_master_param[id(self.gate)]
|
||||
master_gate_weight = optim.working_to_master_param[id(self.gate)]
|
||||
gate_exp_avg = optim.optim.state[master_gate_weight]["exp_avg"]
|
||||
gate_exp_avg_sq = optim.optim.state[master_gate_weight]["exp_avg_sq"]
|
||||
# gather
|
||||
|
|
|
@ -1,78 +0,0 @@
|
|||
import torch.nn as nn
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
|
||||
|
||||
class MoeCrossEntropyLoss(_Loss):
|
||||
r"""torch.nn.CrossEntropyLoss added with auxiliary loss.
|
||||
|
||||
Args:
|
||||
input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
|
||||
target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
|
||||
aux_weight (float, optional): Weight of the auxiliary loss in the total loss. Defaults to 0.01.
|
||||
|
||||
The ``args`` and ``kwargs`` should include parameters below:
|
||||
::
|
||||
|
||||
weight (Tensor, optional)
|
||||
size_average (bool, optional)
|
||||
ignore_index (int, optional)
|
||||
reduce (bool, optional)
|
||||
reduction (str, optional)
|
||||
label_smoothing (float, optional)
|
||||
|
||||
More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
|
||||
`Cross_entropy <https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html#torch.nn.functional.cross_entropy>`_.
|
||||
"""
|
||||
|
||||
def __init__(self, aux_weight: float = 0.01, *args, **kwargs):
|
||||
super().__init__()
|
||||
self.loss = nn.CrossEntropyLoss(*args, **kwargs)
|
||||
self.aux_weight = aux_weight
|
||||
|
||||
def forward(self, *args):
|
||||
"""
|
||||
The ``args`` should at least include parameters below:
|
||||
::
|
||||
|
||||
input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
|
||||
target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
|
||||
|
||||
More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
|
||||
`Cross_entropy <https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html#torch.nn.functional.cross_entropy>`_.
|
||||
"""
|
||||
main_loss = self.loss(*args)
|
||||
aux_loss = MOE_MANAGER.get_loss()
|
||||
return main_loss + self.aux_weight * aux_loss
|
||||
|
||||
|
||||
class MoeLoss(_Loss):
|
||||
"""A wrapper class for any loss module to add with auxiliary loss.
|
||||
|
||||
Args:
|
||||
aux_weight (float): Weight of auxiliary loss in total loss.
|
||||
loss_fn (``Callable``): Loss function.
|
||||
args (list): Args in loss function.
|
||||
kwargs (dict): Kwargs in loss function
|
||||
"""
|
||||
|
||||
def __init__(self, aux_weight: float, loss_fn, *args, **kwargs):
|
||||
super().__init__()
|
||||
self.loss_fn = loss_fn(*args, **kwargs)
|
||||
self.aux_weight = aux_weight
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
"""
|
||||
The ``args`` and ``kwargs`` should at least include parameters below:
|
||||
::
|
||||
|
||||
input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
|
||||
target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
|
||||
|
||||
Note:
|
||||
The ``args`` and ``kwargs`` may include different parameters varying with different loss function.
|
||||
"""
|
||||
main_loss = self.loss_fn(*args, **kwargs)
|
||||
aux_loss = MOE_MANAGER.get_loss()
|
||||
return main_loss + self.aux_weight * aux_loss
|
|
@ -1,466 +0,0 @@
|
|||
import math
|
||||
from abc import ABC
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.moe._operation import moe_cumsum
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
|
||||
|
||||
class MoeRouter(nn.Module, ABC):
|
||||
"""Base class for all MoE routers.
|
||||
Args:
|
||||
k_value (int): The value of top_k.
|
||||
capacity_factor_train (float): Capacity factor in routing of training.
|
||||
capacity_factor_eval (float): Capacity factor in routing of evaluation.
|
||||
min_capacity (int): The minimum number of the capacity of each expert.
|
||||
noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
|
||||
drop_tks (bool, optional): Whether to drop tokens in evaluation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
k_value: int,
|
||||
capacity_factor_train: float,
|
||||
capacity_factor_eval: float,
|
||||
min_capacity: int,
|
||||
noisy_func: Optional[Callable] = None,
|
||||
drop_tks: bool = True,
|
||||
use_kernel: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.k_value = k_value
|
||||
self.capacity_factor_train = capacity_factor_train
|
||||
self.capacity_factor_eval = capacity_factor_eval
|
||||
self.min_capacity = min_capacity
|
||||
self.noisy_func = noisy_func
|
||||
self.drop_tks = drop_tks
|
||||
self._aux_loss = None
|
||||
self._z_loss = None
|
||||
self.use_kernel = use_kernel
|
||||
|
||||
def get_capacity(self, num_tokens, num_experts, ep_group=None):
|
||||
if ep_group is not None:
|
||||
num_tokens_tensor = torch.tensor(num_tokens, device=get_accelerator().get_current_device())
|
||||
dist.all_reduce(num_tokens_tensor, group=ep_group)
|
||||
num_tokens = num_tokens_tensor.item() // dist.get_world_size(ep_group)
|
||||
capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
|
||||
capacity = math.floor(self.k_value * capacity_factor * num_tokens / num_experts)
|
||||
capacity += capacity % 2
|
||||
capacity = max(capacity, self.min_capacity)
|
||||
assert capacity > 0
|
||||
return int(capacity)
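The capacity rule above is easy to sanity-check in isolation: floor(k * capacity_factor * num_tokens / num_experts), rounded up to an even number and clamped by min_capacity. A standalone sketch with illustrative values (the EP all-reduce is omitted):

```
import math

def expected_capacity(k, capacity_factor, num_tokens, num_experts, min_capacity):
    # Same arithmetic as MoeRouter.get_capacity, without the ep_group all-reduce.
    capacity = math.floor(k * capacity_factor * num_tokens / num_experts)
    capacity += capacity % 2              # round odd capacities up to even
    return max(capacity, min_capacity)

# top-1 routing, 1024 tokens over 8 experts with factor 1.25 -> 160
print(expected_capacity(1, 1.25, 1024, 8, 4))
```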
|
||||
|
||||
def set_aux_loss(self, router_probs: torch.Tensor, expert_indices: torch.Tensor, num_experts: int) -> None:
|
||||
"""Computes auxiliary load balancing loss as in Switch Transformer.
|
||||
|
||||
See Switch Transformer (https://arxiv.org/abs/2101.03961). This function
|
||||
implements the loss function presented in equations (4) - (6). It aims to
|
||||
penalize those cases where the routing between experts is unbalanced.
|
||||
|
||||
Args:
|
||||
router_probs: Probability assigned to each expert per token. Shape:
|
||||
<float32>[num_groups, tokens_per_group, num_experts].
|
||||
expert_indices: <int>[num_groups, tokens_per_group, num_selected_experts]
|
||||
indices identifying the top num_selected_experts for a given token.
|
||||
"""
|
||||
assert self._aux_loss is None
|
||||
if router_probs.dim() == expert_indices.dim() == 2:
|
||||
router_probs = router_probs.unsqueeze(0)
|
||||
expert_indices = expert_indices.unsqueeze(0)
|
||||
assert (
|
||||
router_probs.dim() == expert_indices.dim() == 3
|
||||
), "router_probs must be 3D tensor and expert_indices must be 4D tensor"
|
||||
|
||||
# Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
|
||||
expert_mask = F.one_hot(expert_indices, num_experts)
|
||||
# For a given token, determine if it was routed to a given expert.
|
||||
# Shape: [num_groups, tokens_per_group, num_experts]
|
||||
expert_mask = expert_mask.max(dim=-2)[0]
|
||||
|
||||
tokens_per_group_and_expert = torch.mean(expert_mask.float(), dim=-2)
|
||||
router_prob_per_group_and_expert = torch.mean(router_probs.float(), dim=-2)
|
||||
aux_loss = num_experts**2 * torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert)
|
||||
self._aux_loss = aux_loss
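A quick numeric check of the formula above: with perfectly balanced top-1 routing (each expert receives the same fraction of tokens and the same mean probability), the loss evaluates to 1, which is its minimum. Shapes and values are illustrative:

```
import torch
import torch.nn.functional as F

num_groups, tokens, num_experts = 1, 8, 4
# route two tokens to each of the four experts with uniform probabilities
router_probs = torch.full((num_groups, tokens, num_experts), 1.0 / num_experts)
expert_indices = torch.arange(tokens).remainder(num_experts).view(1, tokens, 1)

expert_mask = F.one_hot(expert_indices, num_experts).max(dim=-2)[0]
tokens_per_expert = expert_mask.float().mean(dim=-2)    # 0.25 per expert
prob_per_expert = router_probs.mean(dim=-2)             # 0.25 per expert
aux = num_experts**2 * torch.mean(tokens_per_expert * prob_per_expert)
print(aux)  # tensor(1.)
```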
|
||||
|
||||
def set_z_loss(self, router_logits: torch.Tensor):
|
||||
"""Compute router z-loss.
|
||||
|
||||
The router z-loss was introduced in Designing Effective Sparse Expert Models
|
||||
(https://arxiv.org/abs/2202.08906). It encourages router logits to remain
|
||||
small in an effort to improve stability.
|
||||
|
||||
Args:
|
||||
router_logits: <float>[num_groups, tokens_per_group, num_experts] router logits.
|
||||
"""
|
||||
assert self._z_loss is None
|
||||
if router_logits.dim() == 2:
|
||||
router_logits = router_logits.unsqueeze(0)
|
||||
assert router_logits.dim() == 3, "router_logits must be 3D tensor"
|
||||
num_groups, tokens_per_group, _ = router_logits.shape
|
||||
log_z = torch.logsumexp(router_logits, dim=-1)
|
||||
z_loss = torch.sum(log_z**2, dtype=torch.float32) / (num_groups * tokens_per_group)
|
||||
self._z_loss = z_loss
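For reference, the z-loss above is just the mean squared log-partition of the router logits; a self-contained sketch with illustrative shapes:

```
import torch

num_groups, tokens, num_experts = 2, 16, 8
router_logits = torch.randn(num_groups, tokens, num_experts)

log_z = torch.logsumexp(router_logits, dim=-1)          # (num_groups, tokens)
z_loss = torch.sum(log_z**2, dtype=torch.float32) / (num_groups * tokens)
# same value that set_z_loss stores in self._z_loss; small logits => small loss
```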
|
||||
|
||||
def pop_router_loss(self) -> torch.Tensor:
|
||||
assert self._aux_loss is not None
|
||||
MOE_MANAGER.add_loss(self._aux_loss, self._z_loss)
|
||||
self._aux_loss = None
|
||||
self._z_loss = None
|
||||
|
||||
|
||||
class Top1Router(MoeRouter):
|
||||
"""Top1 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity)
|
||||
and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. A more detailed
|
||||
description can be found in Google's Switch Transformer paper (https://arxiv.org/abs/2101.03961).
|
||||
|
||||
Args:
|
||||
capacity_factor_train (float, optional): Capacity factor in routing of training.
|
||||
capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
|
||||
min_capacity (int, optional): The minimum capacity of each expert.
|
||||
select_policy (str, optional): The token selection policy, either "first" or "random".
|
||||
noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
|
||||
drop_tks (bool, optional): Whether to drop tokens in evaluation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
capacity_factor_train: float = 1.25,
|
||||
capacity_factor_eval: float = 2.0,
|
||||
min_capacity: int = 4,
|
||||
select_policy: str = "first",
|
||||
noisy_func: Optional[Callable] = None,
|
||||
drop_tks: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
k_value=1,
|
||||
capacity_factor_train=capacity_factor_train,
|
||||
capacity_factor_eval=capacity_factor_eval,
|
||||
min_capacity=min_capacity,
|
||||
noisy_func=noisy_func,
|
||||
drop_tks=drop_tks,
|
||||
)
|
||||
self.select_policy = select_policy
|
||||
assert select_policy in {"first", "random"}
|
||||
if select_policy == "random":
|
||||
self.uniform = torch.distributions.uniform.Uniform(
|
||||
low=torch.tensor(0.0, device=get_accelerator().get_current_device()),
|
||||
high=torch.tensor(1.0, device=get_accelerator().get_current_device()),
|
||||
).rsample
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs: torch.Tensor,
|
||||
use_kernel: bool = False,
|
||||
ep_group: Optional[ProcessGroup] = None,
|
||||
use_loss: bool = False,
|
||||
use_norm: bool = False,
|
||||
) -> Tuple:
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
|
||||
|
||||
Returns:
|
||||
1. use_kernel is False:
|
||||
The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity).
|
||||
The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity).
|
||||
2. use_kernel is True:
|
||||
...
|
||||
"""
|
||||
if self.noisy_func is not None and self.training:
|
||||
inputs = self.noisy_func(inputs)
|
||||
|
||||
assert inputs.dtype == torch.float
|
||||
probs = F.softmax(inputs, dim=-1)
|
||||
num_experts = probs.size(-1)
|
||||
num_tokens = inputs.size(0)
|
||||
capacity = self.get_capacity(num_tokens, num_experts, ep_group)
|
||||
|
||||
top1_idx = torch.argmax(inputs, dim=-1)
|
||||
mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
|
||||
|
||||
# calculate router loss
|
||||
self.set_aux_loss(probs, top1_idx.unsqueeze(-1), num_experts)
|
||||
self.set_z_loss(inputs)
|
||||
self.pop_router_loss()
|
||||
|
||||
if not self.training and not self.drop_tks and ep_group is not None:
|
||||
max_num = torch.max(torch.sum(mask, dim=0))
|
||||
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
|
||||
capacity = max_num.item()
|
||||
|
||||
if self.select_policy == "random":
|
||||
rand_mask = mask * self.uniform(mask.shape)
|
||||
_, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
|
||||
mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
|
||||
ranks = moe_cumsum(mask, use_kernel=self.use_kernel)
|
||||
elif self.select_policy == "first":
|
||||
ranks = moe_cumsum(mask, use_kernel=self.use_kernel)
|
||||
mask = mask * torch.lt(ranks, capacity)
|
||||
else:
|
||||
raise NotImplementedError("Not support such select policy yet.")
|
||||
|
||||
ranks = torch.sum(mask * ranks, dim=-1)
|
||||
used_capacity = mask.sum(dim=0)
|
||||
|
||||
if use_kernel:
|
||||
mask = torch.sum(mask, dim=-1)
|
||||
mask = torch.stack([mask], dim=0).to(torch.int32)
|
||||
dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
|
||||
return used_capacity, probs, mask, dest_idx, num_experts * capacity
|
||||
else:
|
||||
ranks = F.one_hot(ranks, num_classes=capacity)
|
||||
weight = mask * probs.type_as(inputs)
|
||||
combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
|
||||
sec_mask = combine_weights.bool()
|
||||
return used_capacity, combine_weights, sec_mask, probs
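A sketch of how the non-kernel outputs above are typically consumed (mirroring the dispatch/combine pattern SparseMLP uses later in this diff): the boolean mask scatters tokens into per-expert buffers, and the combine weights gather the expert outputs back per token. Shapes and values are illustrative:

```
import torch

s, e, c, h = 16, 4, 8, 32                    # tokens, experts, capacity, hidden
tokens = torch.randn(s, h)
combine_weights = torch.rand(s, e, c)        # stand-in for the router output
sec_mask = combine_weights > 0.5             # stand-in dispatch mask

# dispatch: (e, c, s) @ (s, h) -> (e, c, h), one buffer per expert
dispatch_data = torch.matmul(sec_mask.float().permute(1, 2, 0), tokens)

expert_output = dispatch_data                # placeholder for the expert MLP
# combine: weight each (expert, capacity) slot back into per-token outputs
combined = torch.einsum("sec,ech->sh", combine_weights, expert_output)
```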
|
||||
|
||||
|
||||
class Top2Router(MoeRouter):
|
||||
"""Top2 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity)
|
||||
and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. A more detailed
|
||||
description can be found in the ViT-MoE paper (https://arxiv.org/abs/2106.05974).
|
||||
|
||||
Args:
|
||||
capacity_factor_train (float, optional): Capacity factor in routing of training.
|
||||
capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
|
||||
min_capacity (int, optional): The minimum capacity of each expert.
|
||||
noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
|
||||
drop_tks (bool, optional): Whether to drop tokens in evaluation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
capacity_factor_train: float = 1.25,
|
||||
capacity_factor_eval: float = 2.0,
|
||||
min_capacity: int = 4,
|
||||
noisy_func: Optional[Callable] = None,
|
||||
drop_tks: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
k_value=2,
|
||||
capacity_factor_train=capacity_factor_train,
|
||||
capacity_factor_eval=capacity_factor_eval,
|
||||
min_capacity=min_capacity,
|
||||
noisy_func=noisy_func,
|
||||
drop_tks=drop_tks,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs: torch.Tensor,
|
||||
use_kernel: bool = False,
|
||||
ep_group: Optional[ProcessGroup] = None,
|
||||
use_norm: bool = False,
|
||||
use_loss: bool = True,
|
||||
) -> Tuple:
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
|
||||
|
||||
Returns:
|
||||
1. use_kernel is False:
|
||||
The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity).
|
||||
The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity).
|
||||
2. use_kernel is True:
|
||||
...
|
||||
"""
|
||||
if self.noisy_func is not None and self.training:
|
||||
inputs = self.noisy_func(inputs)
|
||||
|
||||
assert inputs.dtype == torch.float
|
||||
probs = F.softmax(inputs, dim=-1)
|
||||
if use_norm:
|
||||
routing_weights, _ = torch.topk(probs, 2, dim=-1)
|
||||
probs = probs / routing_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
num_experts = probs.size(-1)
|
||||
num_tokens = inputs.size(0)
|
||||
capacity = self.get_capacity(num_tokens, num_experts, ep_group)
|
||||
|
||||
top1_idx = torch.argmax(probs, dim=-1)
|
||||
mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
|
||||
logits_except1 = probs.masked_fill(mask1.bool(), float("-inf"))
|
||||
top2_idx = torch.argmax(logits_except1, dim=-1)
|
||||
mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)
|
||||
|
||||
cmask = mask1 + mask2 # loss: [s, e]
|
||||
cmask = cmask.float() / 2.0 # div 2 to normalize it to 1
|
||||
|
||||
# calculate loss
|
||||
if use_loss:
|
||||
expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
|
||||
self.set_aux_loss(probs, expert_indices, num_experts)
|
||||
self.set_z_loss(inputs)
|
||||
self.pop_router_loss()
|
||||
|
||||
if not self.training and not self.drop_tks and ep_group is not None:
|
||||
max_num = torch.max(torch.sum(cmask, dim=0))
|
||||
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
|
||||
capacity = max_num.item()
|
||||
|
||||
rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel) # rank1: [s, e]
|
||||
rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel)
|
||||
rank2 += torch.sum(mask1, dim=-2, keepdim=True)
|
||||
|
||||
mask1 *= torch.lt(rank1, capacity)
|
||||
mask2 *= torch.lt(rank2, capacity)
|
||||
used_capacity = mask1.sum(dim=0) + mask2.sum(dim=0)
|
||||
|
||||
rank1 = torch.sum(mask1 * rank1, dim=-1)
|
||||
rank2 = torch.sum(mask2 * rank2, dim=-1)
|
||||
|
||||
if use_kernel:
|
||||
mask1 = torch.sum(mask1, dim=-1)
|
||||
mask2 = torch.sum(mask2, dim=-1)
|
||||
|
||||
mask = torch.stack([mask1, mask2], dim=0).to(torch.int32)
|
||||
dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32)
|
||||
|
||||
return used_capacity, probs, mask, dest_idx, num_experts * capacity
|
||||
else:
|
||||
"""
|
||||
The following code is equivalent to:
|
||||
|
||||
```
|
||||
weight1 = mask1 * probs.type_as(inputs)
|
||||
weight2 = mask2 * probs.type_as(inputs)
|
||||
rank1_sc = F.one_hot(rank1, num_classes=capacity)
|
||||
rank2_sc = F.one_hot(rank2, num_classes=capacity)
|
||||
|
||||
cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)
|
||||
cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1)
|
||||
cb_weight = cb_weight1 + cb_weight2
|
||||
sec_mask = cb_weight.bool()
|
||||
```
|
||||
"""
|
||||
|
||||
weight1 = mask1 * probs.type_as(inputs)
|
||||
weight2 = mask2 * probs.type_as(inputs)
|
||||
|
||||
cb_weight = torch.zeros(inputs.shape + (capacity,), device=inputs.device)
|
||||
sec_mask = torch.zeros_like(cb_weight, dtype=torch.bool)
|
||||
indices = torch.arange(0, inputs.shape[0], device=inputs.device)
|
||||
cb_weight[indices, top1_idx[indices], rank1[indices]] += weight1[indices, top1_idx[indices]]
|
||||
cb_weight[indices, top2_idx[indices], rank2[indices]] += weight2[indices, top2_idx[indices]]
|
||||
sec_mask[indices, top1_idx[indices], rank1[indices]] |= mask1.bool()[indices, top1_idx[indices]]
|
||||
sec_mask[indices, top2_idx[indices], rank2[indices]] |= mask2.bool()[indices, top2_idx[indices]]
|
||||
|
||||
return used_capacity, cb_weight, sec_mask
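A plain-torch sketch of the top-2 bookkeeping above (argmax, masking out the first choice, and per-expert cumulative ranks), using torch.cumsum in place of moe_cumsum; capacity dropping and the distributed branches are omitted, and all values are illustrative:

```
import torch
import torch.nn.functional as F

logits = torch.randn(6, 4)                   # (tokens, experts)
probs = F.softmax(logits, dim=-1)

top1_idx = probs.argmax(dim=-1)
mask1 = F.one_hot(top1_idx, num_classes=4).int()
top2_idx = probs.masked_fill(mask1.bool(), float("-inf")).argmax(dim=-1)
mask2 = F.one_hot(top2_idx, num_classes=4).int()

# 0-based position of each token in its expert's queue (zero where unselected)
rank1 = torch.cumsum(mask1, dim=0) * mask1 - mask1
rank2 = torch.cumsum(mask2, dim=0) * mask2 - mask2
rank2 += mask1.sum(dim=0, keepdim=True) * mask2   # top-2 picks queue behind top-1
```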
|
||||
|
||||
|
||||
class TopKRouter(MoeRouter):
|
||||
"""Masked matmul router using tokens choose top-k experts assignment.
|
||||
|
||||
NOTE: this is modified from flaxformer.
|
||||
This router uses the same mechanism as in Switch Transformer
|
||||
(https://arxiv.org/abs/2101.03961) and V-MoE
|
||||
(https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are
|
||||
sorted by router_probs and then routed to their choice of expert until the
|
||||
expert's expert_capacity is reached. There is no guarantee that each token is
|
||||
processed by an expert, or that each expert receives at least one token.
|
||||
|
||||
Attributes:
|
||||
num_selected_experts: Maximum number of experts to which each token is
|
||||
routed. Tokens may be routed to fewer experts if particular experts are
|
||||
oversubscribed / reach capacity.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_selected_experts: int,
|
||||
capacity_factor_train: float = 1.25,
|
||||
capacity_factor_eval: float = 2.0,
|
||||
min_capacity: int = 4,
|
||||
noisy_func: Optional[Callable] = None,
|
||||
drop_tks: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, drop_tks
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
router_probs: torch.Tensor,
|
||||
expert_capacity: int,
|
||||
) -> Tuple:
|
||||
"""Computes masks for the top-k experts per token.
|
||||
|
||||
Args:
|
||||
router_probs: <float32>[num_groups, tokens_per_group, num_experts]
|
||||
probabilities used to determine the routing of tokens to the experts.
|
||||
|
||||
Returns:
|
||||
Dispatch and combine arrays for routing with masked matmuls.
|
||||
"""
|
||||
# TODO: FIXME: add parallel group
|
||||
num_groups, _, num_experts = router_probs.shape
|
||||
|
||||
# Top-k router probability and corresponding expert indices for each token.
|
||||
# Shape: [num_groups, tokens_per_group, num_selected_experts].
|
||||
expert_gate, expert_index = torch.topk(router_probs, self.k_value)
|
||||
|
||||
self.set_aux_loss(router_probs, expert_index, num_experts)
|
||||
self.pop_router_loss()
|
||||
|
||||
# Make num_selected_experts the leading axis to ensure that top-1 choices
|
||||
# have priority over top-2 choices, which have priority over top-3 choices,
|
||||
# etc.
|
||||
expert_index = torch.transpose(expert_index, 1, 2)
|
||||
# Shape: [num_groups, num_selected_experts * tokens_per_group]
|
||||
expert_index = expert_index.reshape(num_groups, -1)
|
||||
|
||||
# Create mask out of indices.
|
||||
# Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts].
|
||||
expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32)
|
||||
|
||||
# Experts have a fixed capacity that we cannot exceed. A token's priority
|
||||
# within the expert's buffer is given by the masked, cumulative capacity of
|
||||
# its target expert.
|
||||
# Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts].
|
||||
token_priority = torch.cumsum(expert_mask, dim=1) * expert_mask - 1
|
||||
# Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts].
|
||||
token_priority = token_priority.reshape((num_groups, self.k_value, -1, num_experts))
|
||||
# Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
|
||||
token_priority = torch.transpose(token_priority, 1, 2)
|
||||
# For each token, across all selected experts, select the only non-negative
|
||||
# (unmasked) priority. Now, for group G routing to expert E, token T has
|
||||
# non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E
|
||||
# is its targeted expert.
|
||||
# Shape: [num_groups, tokens_per_group, num_experts].
|
||||
token_priority = torch.max(token_priority, dim=2)[0]
|
||||
|
||||
# Token T can only be routed to expert E if its priority is positive and
|
||||
# less than the expert capacity. One-hot matrix will ignore indices outside
|
||||
# the range [0, expert_capacity).
|
||||
# Shape: [num_groups, tokens_per_group, num_experts, expert_capacity].
|
||||
valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity)
|
||||
token_priority = torch.masked_fill(token_priority, ~valid_mask, 0)
|
||||
dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool)
|
||||
valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, -1, expert_capacity)
|
||||
dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0)
|
||||
|
||||
# The combine array will be used for combining expert outputs, scaled by the
|
||||
# router probabilities. Shape: [num_groups, tokens_per_group, num_experts,
|
||||
# expert_capacity].
|
||||
combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask)
|
||||
|
||||
return combine_array, dispatch_mask
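The priority trick above is easier to see on a tiny example: a masked cumulative sum gives every token a 0-based slot in its target expert's buffer, and slots at or beyond expert_capacity are later zeroed by the one-hot/valid-mask step. Illustrative values:

```
import torch
import torch.nn.functional as F

num_experts, expert_capacity = 2, 2
# five tokens, four of which pick expert 0 ([num_groups=1, tokens])
expert_index = torch.tensor([[0, 0, 1, 0, 0]])
expert_mask = F.one_hot(expert_index, num_experts).int()        # (1, 5, 2)

token_priority = torch.cumsum(expert_mask, dim=1) * expert_mask - 1
print(token_priority[0, :, 0])   # tensor([ 0,  1, -1,  2,  3])
# slots 2 and 3 are >= expert_capacity, so those tokens would be dropped
```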
|
||||
|
||||
|
||||
def get_router_cls(top_k: int, grouped: bool = False) -> MoeRouter:
|
||||
if not grouped:
|
||||
if top_k == 1:
|
||||
return Top1Router
|
||||
elif top_k == 2:
|
||||
return Top2Router
|
||||
else:
|
||||
raise NotImplementedError("top_k > 2 is not supported yet")
|
||||
else:
|
||||
return TopKRouter
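A small usage sketch of the factory above, assuming the router classes defined earlier in this file are in scope; the hyper-parameters mirror the constructor defaults:

```
# pick a router class by k, then instantiate it with routing hyper-parameters
router_cls = get_router_cls(top_k=2)          # -> Top2Router
router = router_cls(
    capacity_factor_train=1.25,
    capacity_factor_eval=2.0,
    min_capacity=4,
    drop_tks=True,
)
```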
|
|
@@ -6,10 +6,11 @@ import torch
|
|||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.distributed.distributed_c10d import get_process_group_ranks
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.tensor.moe_tensor.api import get_dp_group, get_dp_group_ranks, get_ep_size, is_moe_tensor
|
||||
from colossalai.tensor.moe_tensor.api import is_moe_tensor
|
||||
|
||||
|
||||
class ForceFP32Parameter(torch.nn.Parameter):
|
||||
|
@@ -145,7 +146,7 @@ def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]
|
|||
if not is_moe_tensor(param):
|
||||
ep_size = 1 # set ep_size to 1 for dp parameters
|
||||
else:
|
||||
ep_size = get_ep_size(param)
|
||||
ep_size = dist.get_world_size(param.ep_group)
|
||||
if ep_size not in epsize_param_dict:
|
||||
epsize_param_dict[ep_size] = []
|
||||
epsize_param_dict[ep_size].append(param)
|
||||
|
@@ -170,8 +171,8 @@ def sync_moe_model_param(model: nn.Module):
|
|||
# When ep_size = world_size, communication is not needed
|
||||
if ep_size != 1 and ep_size != MOE_MANAGER.world_size:
|
||||
for param in param_dict[ep_size]:
|
||||
src_rank = get_dp_group_ranks(param)[0]
|
||||
dist.broadcast(param, src=src_rank, group=get_dp_group(param))
|
||||
src_rank = get_process_group_ranks(param.dp_group)[0]
|
||||
dist.broadcast(param, src=src_rank, group=param.dp_group)
|
||||
|
||||
|
||||
def set_moe_args(config: Any, args: dict):
|
||||
|
|
|
@@ -0,0 +1,3 @@
|
|||
from .experts import *
|
||||
from .layers import *
|
||||
from .routers import *
|
|
@@ -9,7 +9,7 @@ from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
|
|||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.moe.utils import get_activation
|
||||
from colossalai.shardformer.layer.utils import Randomizer
|
||||
from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size, set_moe_tensor_info
|
||||
from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size
|
||||
|
||||
if HAS_TRITON:
|
||||
from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
|
||||
|
@@ -35,7 +35,7 @@ class MLPExperts(nn.Module):
|
|||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
expert_parallel: Optional[str] = None,
|
||||
expert_parallel: Optional[str] = "EP",
|
||||
activation: Optional[Callable] = None,
|
||||
drop_rate: Optional[float] = 0,
|
||||
gated: Optional[bool] = False,
|
|
@@ -8,11 +8,9 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
|
||||
from colossalai.moe._operation import AllGather, AllToAll, HierarchicalAllToAll, MoeCombine, MoeDispatch, ReduceScatter
|
||||
from colossalai.moe.experts import MLPExperts
|
||||
from colossalai.moe.load_balance import LoadBalancer
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.moe.routers import MoeRouter, get_router_cls
|
||||
from colossalai.moe.utils import create_ep_hierarchical_group, get_noise_generator
|
||||
from colossalai.shardformer.layer.moe import MLPExperts
|
||||
from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_group_ranks, get_ep_size
|
||||
|
||||
|
||||
|
@@ -23,6 +21,7 @@ class SparseMLP(nn.Module):
|
|||
dim_model (int): Hidden dimension of training model
|
||||
num_experts (int): The number of experts
|
||||
top_k (int, optional): The number of experts each token is dispatched to
|
||||
parallel (str): parallel mode. Should be "EP", "TP" or None
|
||||
capacity_factor_train (float, optional): Capacity factor in routing during training
|
||||
capacity_factor_eval (float, optional): Capacity factor in routing during evaluation
|
||||
min_capacity (int, optional): The minimum capacity of each expert
|
||||
|
@@ -51,6 +50,7 @@ class SparseMLP(nn.Module):
|
|||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
router_top_k: int = 1,
|
||||
parallel: str = "EP",
|
||||
router_loss: bool = True,
|
||||
router_norm: bool = False,
|
||||
router_capacity_factor_train: float = 1.25,
|
||||
|
@@ -66,7 +66,7 @@ class SparseMLP(nn.Module):
|
|||
load_balance_group_swap_factor: float = 0.4,
|
||||
enable_kernel: bool = False,
|
||||
enable_comm_overlap: bool = False,
|
||||
enable_hierarchical_comm: bool = False,
|
||||
enable_hierarchical_comm: bool = True,
|
||||
return_gate_logits: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
@@ -77,7 +77,9 @@ class SparseMLP(nn.Module):
|
|||
self.return_gate_logits = return_gate_logits
|
||||
self.enable_kernel = enable_kernel
|
||||
self.enable_comm_overlap = enable_comm_overlap
|
||||
self.expert_parallel = MOE_MANAGER.get_parallel()
|
||||
# self.expert_parallel = MOE_MANAGER.get_parallel()
|
||||
assert parallel in ["EP", "TP", None], "parallel mode must be EP, TP or None"
|
||||
self.parallel = parallel
|
||||
self.router_loss = router_loss
|
||||
self.router_norm = router_norm
|
||||
|
||||
|
@@ -99,7 +101,7 @@ class SparseMLP(nn.Module):
|
|||
# moe experts
|
||||
self.experts = MLPExperts(
|
||||
num_experts=self.num_experts,
|
||||
expert_parallel=self.expert_parallel,
|
||||
expert_parallel=self.parallel,
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=self.intermediate_size,
|
||||
activation=mlp_activation,
|
||||
|
@@ -108,11 +110,12 @@ class SparseMLP(nn.Module):
|
|||
)
|
||||
|
||||
# get parallel settings
|
||||
if self.expert_parallel is not None:
|
||||
if self.parallel is not None:
|
||||
self.ep_group = get_ep_group(self.experts)
|
||||
self.ep_size = get_ep_size(self.experts)
|
||||
self.ep_hierarchical_group = None
|
||||
if enable_hierarchical_comm:
|
||||
# TODO: move to plugin
|
||||
self.ep_intra_src_rank, *self.ep_hierarchical_group = create_ep_hierarchical_group(
|
||||
get_ep_group_ranks(self.experts)
|
||||
)
|
||||
|
@@ -186,11 +189,11 @@ class SparseMLP(nn.Module):
|
|||
dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
|
||||
|
||||
# expert_output: (num_groups, num_experts, capacity, hidden_size)
|
||||
if self.expert_parallel == "EP":
|
||||
if self.parallel == "EP":
|
||||
expert_output = self._ep_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
|
||||
elif self.expert_parallel == "TP":
|
||||
elif self.parallel == "TP":
|
||||
expert_output = self._tp_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
|
||||
elif self.expert_parallel is None:
|
||||
elif self.parallel is None:
|
||||
expert_output = self._local_process(dispatch_data)
|
||||
else:
|
||||
raise NotImplementedError(
|
|
@@ -0,0 +1,161 @@
|
|||
import math
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
|
||||
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.moe.utils import get_activation
|
||||
from colossalai.shardformer.layer.utils import Randomizer
|
||||
from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size
|
||||
|
||||
if HAS_TRITON:
|
||||
from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
|
||||
|
||||
|
||||
class MLPExperts(nn.Module):
|
||||
"""
|
||||
MLPExperts holds the expert MLPs used by SparseMLP, optionally sharded with expert or tensor parallelism.
|
||||
|
||||
Args:
|
||||
num_experts (int): The number of experts
|
||||
hidden_size (int): The hidden size of MLP
|
||||
intermediate_size (int): The intermediate size of MLP
|
||||
expert_parallel (str, optional): The expert parallelism mode: "EP", "TP", or None.
|
||||
activation (optional): The activation function of MLP
|
||||
drop_rate (float, optional): The dropout rate of the MLP
|
||||
gated (bool, optional): Whether to use gated MLP
|
||||
use_kernel (bool, optional): Whether to use kernel optimization
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
expert_parallel: Optional[str] = "EP",
|
||||
activation: Optional[Callable] = None,
|
||||
drop_rate: Optional[float] = 0,
|
||||
gated: Optional[bool] = False,
|
||||
use_kernel: Optional[bool] = False,
|
||||
):
|
||||
super().__init__()
|
||||
assert expert_parallel in ["EP", "TP", None]
|
||||
self.expert_parallel = expert_parallel
|
||||
self.num_total_experts = num_experts
|
||||
self.gated = gated
|
||||
self.use_kernel = use_kernel
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
|
||||
# get expert parallel info
|
||||
if expert_parallel is not None:
|
||||
self.num_local_experts, self.moe_info = MOE_MANAGER.get_info(
|
||||
num_experts, use_tp=True if expert_parallel == "TP" else False
|
||||
)
|
||||
# get settings for different parallel
|
||||
self.ep_size = get_ep_size(self)
|
||||
if expert_parallel == "TP":
|
||||
intermediate_size = intermediate_size // self.ep_size
|
||||
num_experts = self.num_total_experts
|
||||
else:
|
||||
num_experts = self.num_local_experts
|
||||
else:
|
||||
self.num_local_experts = self.num_total_experts
|
||||
self.ep_size = 1
|
||||
|
||||
if gated:
|
||||
self.wi_gate = nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts, hidden_size, intermediate_size * 2 if activation == "swiglu" else intermediate_size
|
||||
)
|
||||
)
|
||||
self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
|
||||
else:
|
||||
self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
|
||||
self.wo = nn.Parameter(torch.empty(num_experts, intermediate_size, hidden_size))
|
||||
|
||||
self.act_name = activation
|
||||
self.act = get_activation(activation)
|
||||
self.drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
if expert_parallel is not None:
|
||||
for param in self.parameters():
|
||||
set_moe_tensor_info(param, self.moe_info)
|
||||
|
||||
# init param
|
||||
self.reset_parameters()
|
||||
|
||||
@torch.no_grad()
|
||||
def reset_parameters(self):
|
||||
# expert param should be different
|
||||
if self.expert_parallel is not None:
|
||||
seed_ctx = Randomizer(get_ep_rank(self)).fork_rng(enable_cpu=True)
|
||||
else:
|
||||
seed_ctx = Randomizer(42).fork_rng(enable_cpu=True)
|
||||
with seed_ctx:
|
||||
if self.gated:
|
||||
torch.nn.init.normal_(self.wi_gate, std=math.sqrt(0.1 / self.hidden_size))
|
||||
torch.nn.init.normal_(self.wi_up, std=math.sqrt(0.1 / self.hidden_size))
|
||||
else:
|
||||
torch.nn.init.normal_(self.wi, std=math.sqrt(0.1 / self.hidden_size))
|
||||
torch.nn.init.normal_(self.wo, std=math.sqrt(0.1 / self.intermediate_size))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
param_slice: Tuple[slice] = (slice(None),),
|
||||
use_sparse: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
forward: hidden_size --> intermediate_size --> hidden_size
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor of shape (num_groups, num_experts, capacity, hidden_size)
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The output tensor of shape (num_groups, num_experts, capacity, hidden_size)
|
||||
"""
|
||||
x = MoeInGradScaler.apply(x, self.ep_size)
|
||||
|
||||
e = x.size(1)
|
||||
h = x.size(-1)
|
||||
|
||||
x = x.transpose(0, 1)
|
||||
inshape = x.shape
|
||||
x = x.reshape(e, -1, h)
|
||||
|
||||
if self.use_kernel and use_sparse:
|
||||
seq_len = x.shape[1]
|
||||
with torch.no_grad():
|
||||
mask = x[:, :, 0] != 0.0
|
||||
mask = torch.sum(mask, dim=-1)
|
||||
x_list = []
|
||||
for i in range(e):
|
||||
x_list.append(x[i, : mask[i]])
|
||||
x = x_list
|
||||
|
||||
if self.gated:
|
||||
x_gate = [torch.mm(x[i], self.wi_gate[param_slice][i]) for i in range(e)]
|
||||
x_up = [torch.mm(x[i], self.wi_up[param_slice][i]) for i in range(e)]
|
||||
if self.use_kernel and HAS_TRITON and self.act_name == "swiglu":
|
||||
x = [LlamaActCombine.apply(x_gate[i], x_up[i]) for i in range(e)]
|
||||
else:
|
||||
x = [self.act(x_gate[i]) * x_up[i] for i in range(e)]
|
||||
else:
|
||||
x = [torch.mm(x[i], self.wi[param_slice][i]) for i in range(e)]
|
||||
x = [self.act(x[i]) for i in range(e)]
|
||||
x = [self.drop(x[i]) for i in range(e)]
|
||||
x = [torch.mm(x[i], self.wo[param_slice][i]) for i in range(e)]
|
||||
|
||||
if self.use_kernel and use_sparse:
|
||||
for i in range(e):
|
||||
x[i] = torch.nn.functional.pad(x[i], (0, 0, 0, seq_len - x[i].shape[0]), mode="constant", value=0)
|
||||
|
||||
x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0)
|
||||
x = x.reshape(inshape)
|
||||
x = x.transpose(0, 1).contiguous()
|
||||
x = MoeOutGradScaler.apply(x, self.ep_size)
|
||||
return x
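The forward above boils down to a per-expert batched matmul over a (num_experts, tokens, hidden) view of the input. A plain-torch sketch of that shape flow, with the sparse-kernel and gated branches omitted and relu standing in for the configured activation:

```
import torch

g, e, c, h, i = 2, 4, 8, 16, 32    # groups, experts, capacity, hidden, intermediate
x = torch.randn(g, e, c, h)
wi = torch.randn(e, h, i)          # stand-ins for self.wi / self.wo
wo = torch.randn(e, i, h)

x = x.transpose(0, 1).reshape(e, -1, h)       # (e, g*c, h): one batch per expert
x = torch.relu(torch.bmm(x, wi))              # (e, g*c, i)
x = torch.bmm(x, wo)                          # (e, g*c, h)
x = x.reshape(e, g, c, h).transpose(0, 1)     # back to (g, e, c, h)
```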
|
|
@@ -1,222 +1,108 @@
|
|||
from functools import partial
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn import CrossEntropyLoss, Module
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
# from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||
from transformers.models.mixtral.modeling_mixtral import (
|
||||
MixtralDecoderLayer,
|
||||
MixtralForCausalLM,
|
||||
MixtralModel,
|
||||
MixtralSparseMoeBlock,
|
||||
MoeCausalLMOutputWithPast,
|
||||
_prepare_4d_causal_attention_mask,
|
||||
load_balancing_loss_func,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
from transformers.utils import is_flash_attn_2_available, logging
|
||||
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
|
||||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from .mixtral_layer import EPMixtralSparseMoeBlock
|
||||
|
||||
__all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"]
|
||||
from colossalai.shardformer.shard.utils import set_tensors_to_none
|
||||
|
||||
|
||||
class MixtralPolicy(Policy):
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
|
||||
def __init__(self, config):
|
||||
self.moe_info = None
|
||||
super().__init__(config)
|
||||
|
||||
def preprocess(self):
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# Resize embedding
|
||||
vocab_size = self.model.config.vocab_size
|
||||
world_size = self.shard_config.tensor_parallel_size
|
||||
def setup_ep(self, ep_group: ProcessGroup):
|
||||
ep_group = ep_group
|
||||
self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
|
||||
self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
|
||||
assert self.num_experts % self.ep_size == 0
|
||||
self.ep_group = ep_group
|
||||
self.num_experts_per_ep = self.num_experts // self.ep_size
|
||||
self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
|
||||
held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
|
||||
set_tensors_to_none(self.experts, exclude=set(held_experts))
|
||||
for p in self.experts.parameters():
|
||||
p.ep_group = ep_group
|
||||
|
||||
if vocab_size % world_size != 0:
|
||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
@staticmethod
|
||||
def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock":
|
||||
LazyInitContext.materialize(module)
|
||||
module.__class__ = EPMixtralSparseMoeBlock
|
||||
# if "ep_group" in kwargs:
|
||||
assert "ep_group" in kwargs, "You should pass ep_group in SubModuleReplacementDescription via shard_config!!"
|
||||
module.setup_ep(kwargs["ep_group"])
|
||||
return module
|
||||
|
||||
return self.model
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
# router_logits: (batch * sequence_length, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
policy = {}
|
||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
|
||||
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
||||
# we cast back to the input dtype
|
||||
routing_weights = routing_weights.to(hidden_states.dtype)
|
||||
|
||||
if self.shard_config.enable_sequence_parallelism:
|
||||
self.shard_config.enable_sequence_parallelism = False
|
||||
raise NotImplementedError(
|
||||
"Mixtral dosen't support sequence parallelism now, will ignore the sequence parallelism flag."
|
||||
)
|
||||
selected_experts = selected_experts.t().reshape(-1)
|
||||
selected_experts_idx = selected_experts.argsort()
|
||||
dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
|
||||
input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
|
||||
output_split_sizes = torch.zeros_like(input_split_sizes)
|
||||
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
|
||||
|
||||
# expert parallel
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="block_sparse_moe",
|
||||
target_module=EPMixtralSparseMoeBlock,
|
||||
)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=MixtralDecoderLayer,
|
||||
)
|
||||
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="input_layernorm",
|
||||
target_module=FusedRMSNorm,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="post_attention_layernorm",
|
||||
target_module=FusedRMSNorm,
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=MixtralDecoderLayer,
|
||||
)
|
||||
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="norm",
|
||||
target_module=FusedRMSNorm,
|
||||
),
|
||||
policy=policy,
|
||||
target_key=MixtralModel,
|
||||
)
|
||||
|
||||
if self.shard_config.enable_flash_attention:
|
||||
raise NotImplementedError("Flash attention has already been replaced in mixtral.")
|
||||
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
|
||||
"""If under pipeline parallel setting, replacing the original forward method of huggingface
|
||||
with a customized forward method, and register this change in the policy."""
|
||||
if self.pipeline_stage_manager:
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if self.model.__class__.__name__ == "MixtralModel":
|
||||
module = self.model
|
||||
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
|
||||
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
|
||||
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
|
||||
# compute expert output
|
||||
output_states = MoeInGradScaler.apply(output_states, self.ep_size)
|
||||
if output_states.size(0) > 0:
|
||||
if self.num_experts_per_ep == 1:
|
||||
# no need to split
|
||||
expert = self.experts[self.expert_start_idx]
|
||||
output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
|
||||
output_states = expert.w2(output_states)
|
||||
else:
|
||||
module = self.model.model
|
||||
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
stage_index = stage_manager.get_stage_index(layers_per_stage)
|
||||
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
|
||||
self.append_or_create_method_replacement(
|
||||
description=method_replacement, policy=policy, target_key=model_cls
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
assert self.pipeline_stage_manager is not None
|
||||
|
||||
if self.model.__class__.__name__ == "MixtralModel":
|
||||
module = self.model
|
||||
else:
|
||||
module = self.model.model
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
held_layers = []
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embed_tokens)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.norm)
|
||||
|
||||
return held_layers
|
||||
|
||||
|
||||
class MixtralModelPolicy(MixtralPolicy):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
self.set_pipeline_forward(
|
||||
model_cls=MixtralModel,
|
||||
new_forward=MixtralPipelineForwards.mixtral_model_forward,
|
||||
policy=policy,
|
||||
)
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
held_layers = super().get_held_layers()
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
"""No shared params in llama model"""
|
||||
return []
|
||||
|
||||
|
||||
class MixtralForCausalLMPolicy(MixtralPolicy):
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# add a new item for causal lm
|
||||
new_item = {
|
||||
MixtralForCausalLM: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(gather_output=True),
|
||||
)
|
||||
]
|
||||
)
|
||||
}
|
||||
policy.update(new_item)
|
||||
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
self.set_pipeline_forward(
|
||||
model_cls=MixtralForCausalLM,
|
||||
new_forward=MixtralPipelineForwards.mixtral_for_causal_lm_forward,
|
||||
policy=policy,
|
||||
)
|
||||
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.lm_head)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
llama_model = self.model.model
|
||||
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
|
||||
if (
|
||||
id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
|
||||
and self.pipeline_stage_manager.num_stages > 1
|
||||
):
|
||||
# tie weights
|
||||
return [
|
||||
{
|
||||
0: llama_model.embed_tokens.weight,
|
||||
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
|
||||
}
|
||||
]
|
||||
return []
|
||||
output_states_splits = output_states.split(output_split_sizes.tolist())
|
||||
output_states_list = []
|
||||
for i, split_states in enumerate(output_states_splits):
|
||||
if split_states.size(0) == 0:
|
||||
continue
|
||||
expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
|
||||
split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
|
||||
split_states = expert.w2(split_states)
|
||||
output_states_list.append(split_states)
|
||||
output_states = torch.cat(output_states_list)
|
||||
output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
|
||||
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
|
||||
recover_experts_idx = torch.empty_like(selected_experts_idx)
|
||||
recover_experts_idx[selected_experts_idx] = torch.arange(
|
||||
selected_experts_idx.size(0), device=selected_experts_idx.device
|
||||
)
|
||||
dispatch_states = dispatch_states[recover_experts_idx]
|
||||
k_hidden_states = dispatch_states.chunk(self.top_k)
|
||||
output_states = k_hidden_states[0] * routing_weights[:, 0, None]
|
||||
for i in range(1, self.top_k):
|
||||
output_states += k_hidden_states[i] * routing_weights[:, i, None]
|
||||
output_states = output_states.reshape(batch_size, sequence_length, hidden_dim)
|
||||
return output_states, router_logits
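The expert-parallel forward above hinges on sorting token copies by their selected expert and exchanging per-expert counts before all_to_all_uneven. A single-process sketch of that bookkeeping (no communication; sizes and the random routing are illustrative):

```
import torch

num_experts, top_k, ep_size, tokens = 8, 2, 4, 6
hidden_states = torch.randn(tokens, 16)
# stand-in for the flattened top-k expert choices of each token
selected_experts = torch.randint(0, num_experts, (top_k, tokens)).reshape(-1)

# sort token copies so rows bound for the same expert are contiguous
order = selected_experts.argsort()
dispatch_states = hidden_states.repeat(top_k, 1)[order]

# rows per expert, then per EP rank (experts are partitioned contiguously,
# num_experts_per_ep = num_experts // ep_size = 2 here)
input_split_sizes = selected_experts.bincount(minlength=num_experts)
input_split_list = input_split_sizes.view(ep_size, -1).sum(dim=-1).tolist()
# the real forward feeds these splits to all_to_all_single / all_to_all_uneven
```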
|
||||
|
||||
|
||||
class MixtralPipelineForwards:
|
||||
|
@@ -332,7 +218,7 @@ class MixtralPipelineForwards:
|
|||
|
||||
# embed positions, for the first stage, hidden_states is the input embeddings,
|
||||
# for the other stages, hidden_states is the output of the previous stage
|
||||
if self._use_flash_attention_2:
|
||||
if is_flash_attn_2_available():
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
else:
|
|
@@ -176,6 +176,7 @@ _POLICY_LIST = {
|
|||
"transformers.models.falcon.modeling_falcon.FalconForQuestionAnswering": PolicyLocation(
|
||||
file_name="falcon", class_name="FalconForQuestionAnsweringPolicy"
|
||||
),
|
||||
# mistral
|
||||
"transformers.models.mistral.modeling_mistral.MistralModel": PolicyLocation(
|
||||
file_name="mistral", class_name="MistralModelPolicy"
|
||||
),
|
||||
|
@@ -185,6 +186,13 @@ _POLICY_LIST = {
|
|||
"transformers.models.mistral.modeling_mistral.MistralForSequenceClassification": PolicyLocation(
|
||||
file_name="mistral", class_name="MistralForSequenceClassificationPolicy"
|
||||
),
|
||||
# mixtral
|
||||
"transformers.models.mixtral.modeling_mixtral.MixtralModel": PolicyLocation(
|
||||
file_name="mixtral", class_name="MixtralModelPolicy"
|
||||
),
|
||||
"transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM": PolicyLocation(
|
||||
file_name="mixtral", class_name="MixtralForCausalLMPolicy"
|
||||
),
|
||||
# Qwen2
|
||||
"transformers.models.qwen2.modeling_qwen2.Qwen2Model": PolicyLocation(
|
||||
file_name="qwen2", class_name="Qwen2ModelPolicy"
|
||||
|
@@ -195,7 +203,7 @@ _POLICY_LIST = {
|
|||
"transformers.models.qwen2.modeling_qwen2.Qwen2ForSequenceClassification": PolicyLocation(
|
||||
file_name="qwen2", class_name="Qwen2ForSequenceClassificationPolicy"
|
||||
),
|
||||
# Command-R
|
||||
# command
|
||||
"transformers.models.cohere.modeling_cohere.CohereModel": PolicyLocation(
|
||||
file_name="command", class_name="CommandModelPolicy"
|
||||
),
|
||||
|
|
|
@@ -0,0 +1,210 @@
|
|||
from functools import partial
|
||||
from typing import Callable, Dict, List, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralForCausalLM, MixtralModel
|
||||
|
||||
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
|
||||
from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock, MixtralPipelineForwards
|
||||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"]
|
||||
|
||||
|
||||
class MixtralPolicy(Policy):
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
|
||||
def preprocess(self):
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# Resize embedding
|
||||
vocab_size = self.model.config.vocab_size
|
||||
world_size = self.shard_config.tensor_parallel_size
|
||||
|
||||
if vocab_size % world_size != 0:
|
||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
|
||||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
policy = {}
|
||||
|
||||
if self.shard_config.enable_sequence_parallelism:
|
||||
self.shard_config.enable_sequence_parallelism = False
|
||||
raise NotImplementedError(
|
||||
"Mixtral dosen't support sequence parallelism now, will ignore the sequence parallelism flag."
|
||||
)
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
|
||||
if getattr(self.shard_config, "ep_group", None) is None:
|
||||
raise ValueError("You must pass in ep_group via shard_config for expert parallel!")
|
||||
|
||||
# expert parallel
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="block_sparse_moe",
|
||||
target_module=EPMixtralSparseMoeBlock,
|
||||
kwargs={"ep_group": self.shard_config.ep_group},
|
||||
)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=MixtralDecoderLayer,
|
||||
)
|
||||
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="input_layernorm",
|
||||
target_module=FusedRMSNorm,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="post_attention_layernorm",
|
||||
target_module=FusedRMSNorm,
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=MixtralDecoderLayer,
|
||||
)
|
||||
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="norm",
|
||||
target_module=FusedRMSNorm,
|
||||
),
|
||||
policy=policy,
|
||||
target_key=MixtralModel,
|
||||
)
|
||||
|
||||
if self.shard_config.enable_flash_attention:
|
||||
raise NotImplementedError("Flash attention has already been replaced in mixtral.")
|
||||
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
|
||||
"""If under pipeline parallel setting, replacing the original forward method of huggingface
|
||||
with a customized forward method, and register this change in the policy."""
|
||||
if self.pipeline_stage_manager:
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if self.model.__class__.__name__ == "MixtralModel":
|
||||
module = self.model
|
||||
else:
|
||||
module = self.model.model
|
||||
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
stage_index = stage_manager.get_stage_index(layers_per_stage)
|
||||
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
|
||||
self.append_or_create_method_replacement(
|
||||
description=method_replacement, policy=policy, target_key=model_cls
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
assert self.pipeline_stage_manager is not None
|
||||
|
||||
if self.model.__class__.__name__ == "MixtralModel":
|
||||
module = self.model
|
||||
else:
|
||||
module = self.model.model
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
held_layers = []
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embed_tokens)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.norm)
|
||||
|
||||
return held_layers
|
||||
|
||||
|
||||
class MixtralModelPolicy(MixtralPolicy):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
self.set_pipeline_forward(
|
||||
model_cls=MixtralModel,
|
||||
new_forward=MixtralPipelineForwards.mixtral_model_forward,
|
||||
policy=policy,
|
||||
)
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
held_layers = super().get_held_layers()
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
"""No shared params in llama model"""
|
||||
return []
|
||||
|
||||
|
||||
class MixtralForCausalLMPolicy(MixtralPolicy):
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
# TODO: assign pg mesh from plugin to all modules
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# add a new item for causal lm
|
||||
new_item = {
|
||||
MixtralForCausalLM: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(gather_output=True),
|
||||
)
|
||||
]
|
||||
)
|
||||
}
|
||||
policy.update(new_item)
|
||||
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
self.set_pipeline_forward(
|
||||
model_cls=MixtralForCausalLM,
|
||||
new_forward=MixtralPipelineForwards.mixtral_for_causal_lm_forward,
|
||||
policy=policy,
|
||||
)
|
||||
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.lm_head)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
llama_model = self.model.model
|
||||
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
|
||||
if (
|
||||
id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
|
||||
and self.pipeline_stage_manager.num_stages > 1
|
||||
):
|
||||
# tie weights
|
||||
return [
|
||||
{
|
||||
0: llama_model.embed_tokens.weight,
|
||||
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
|
||||
}
|
||||
]
|
||||
return []
|
|
@@ -46,6 +46,7 @@ class ShardConfig:
|
|||
make_vocab_size_divisible_by: int = 64
|
||||
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
|
||||
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
|
||||
ep_group: Optional[ProcessGroup] = None
|
||||
# pipeline_parallel_size: int
|
||||
# data_parallel_size: int
|
||||
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
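A hedged sketch of how the new ep_group field might be wired up when sharding Mixtral by hand; the ShardFormer entry point and the toy process group below are assumptions, and torch.distributed must already be initialized:

```
import torch.distributed as dist

from colossalai.shardformer import ShardConfig, ShardFormer  # assumed entry point

ep_group = dist.new_group(ranks=list(range(dist.get_world_size())))  # toy EP group
shard_config = ShardConfig(
    enable_fused_normalization=True,
    ep_group=ep_group,  # read by MixtralPolicy and forwarded to EPMixtralSparseMoeBlock
)
shard_former = ShardFormer(shard_config=shard_config)
# model, _ = shard_former.optimize(model, policy=MixtralForCausalLMPolicy())
```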
|
||||
|
|
|
@@ -17,10 +17,10 @@ def is_moe_tensor(tensor: torch.Tensor) -> bool:
|
|||
Returns:
|
||||
bool: Whether the given tensor is a moe tensor.
|
||||
"""
|
||||
return hasattr(tensor, "moe_info")
|
||||
return hasattr(tensor, "ep_group")
|
||||
|
||||
|
||||
def set_moe_tensor_info(tensor: torch.Tensor, moe_info: MoeParallelInfo) -> None:
|
||||
def set_moe_tensor_ep_group(tensor: torch.Tensor, ep_group: ProcessGroup) -> None:
|
||||
"""
|
||||
Set moe info for the given tensor.
|
||||
|
||||
|
@@ -29,7 +29,7 @@ def set_moe_tensor_info(tensor: torch.Tensor, moe_info: MoeParallelInfo) -> None
|
|||
moe_info (dict): The moe info to be set.
|
||||
|
||||
"""
|
||||
tensor.__setattr__("moe_info", moe_info)
|
||||
tensor.__setattr__("ep_group", ep_group)
|
||||
|
||||
|
||||
def get_moe_info(ep_size: int, dp_size: int, pp_size: int, ep_inside: bool) -> MoeParallelInfo:
|
||||
|
@@ -58,7 +58,7 @@ def get_ep_group(tensor: torch.Tensor) -> ProcessGroup:
|
|||
Returns:
|
||||
torch.distributed.ProcessGroup: The expert parallel group of the given tensor.
|
||||
"""
|
||||
return tensor.moe_info.ep_group
|
||||
return tensor.ep_group
|
||||
|
||||
|
||||
def get_ep_size(tensor: torch.Tensor) -> int:
|
||||
|
@@ -71,7 +71,8 @@ def get_ep_size(tensor: torch.Tensor) -> int:
|
|||
Returns:
|
||||
int: The expert parallel size of the given tensor.
|
||||
"""
|
||||
return tensor.moe_info.ep_size
|
||||
assert getattr(tensor, "ep_group") is not None, "The tensor does not have expert parallel group."
|
||||
return dist.get_world_size(tensor.ep_group)
|
||||
|
||||
|
||||
def get_dp_size(tensor: torch.Tensor) -> int:
|
||||
|
|
|
@@ -1,6 +1,5 @@
|
|||
from .bucket_store import BucketStore
|
||||
from .gradient_store import GradientStore
|
||||
from .parameter_store import ParameterStore
|
||||
from .tensor_bucket import TensorBucket
|
||||
|
||||
__all__ = ["GradientStore", "ParameterStore", "BucketStore", "TensorBucket"]
|
||||
__all__ = ["GradientStore", "BucketStore", "TensorBucket"]
|
||||
|
|
|
@@ -1,12 +1,11 @@
from typing import Dict, Optional
from typing import Dict

import torch
import torch.distributed as dist
from torch import Tensor
from torch._utils import _flatten_dense_tensors
from torch.distributed import ProcessGroup

from colossalai.accelerator import get_accelerator
from colossalai.accelerator.api import get_accelerator

from .base_store import BaseStore

@@ -16,29 +15,11 @@ class BucketStore(BaseStore):
        self,
        torch_pg: ProcessGroup,
        reduce_bucket_size: int,
        overlap_communication: bool,
        communication_dtype: Optional[torch.dtype] = None,
        moe_extra_dp_process_group: ProcessGroup = None,
    ):
        super().__init__(torch_pg)
        self.reduce_bucket_size = reduce_bucket_size
        # communication params
        self._overlap_communication = overlap_communication
        self._communication_dtype = communication_dtype
        if self._overlap_communication:
            self.comm_stream = get_accelerator().Stream()
        self.zero_local_rank = dist.get_rank(group=self.torch_pg)
        self.zero_world_size = dist.get_world_size(group=self.torch_pg)
        # extra dp
        # This group is used to sync moe param, dp_world_size = moe_duplicates * extra_dp_size.
        # Non moe param will be sync by global dp pg, moe param will be sync by extra dp pg.
        # Moe param grad is be split as non moe param by global dp pg, and grad will be merged in step.
        # And moe working and master param are split by extra dp pg.
        self.moe_extra_dp_pg = moe_extra_dp_process_group
        if self.moe_extra_dp_pg is not None:
            self.moe_extra_dp_pg_size = dist.get_world_size(group=self.moe_extra_dp_pg)
            self.moe_extra_dp_pg_rank = dist.get_rank(group=self.moe_extra_dp_pg)
        self.reset_all()
        self.comm_stream = get_accelerator().Stream()

    def reset_all(self) -> None:
        # init

@@ -1,4 +1,4 @@
from typing import List
from typing import List, Optional

from torch import Tensor

@@ -6,7 +6,7 @@ from .base_store import BaseStore


class GradientStore(BaseStore):
    def __init__(self, *args, partition_grad: bool = False, require_grad_sync: bool = True):
    def __init__(self, *args, partition_grad: bool = False):
        super().__init__(*args)
        """
        self._grads_of_params mapping the parameter and its gradient slices

@@ -20,8 +20,6 @@ class GradientStore(BaseStore):
        self._grads_of_params = dict()
        # stage 2
        self._partition_grads = partition_grad
        # grad accumulation
        self.require_grad_sync = require_grad_sync
        self._working_index = 0 if partition_grad else self._local_rank
        # for zero2, it's `param_id: [grad_local_rank]`
        self.grad_to_param_mapping = dict()

@@ -107,8 +105,7 @@ class GradientStore(BaseStore):
        for group in self._grads_of_params.values():
            if param_id in group.keys():
                return group[param_id][self._working_index]

        raise KeyError(f"Working gradient for param_id {param_id} not found.")
        return None

    def reset_grads_by_group_id(self, group_id: int):
        self._grads_of_params[group_id] = dict()

@@ -116,7 +113,7 @@ class GradientStore(BaseStore):
    def reset_all_gradients(self):
        self._grads_of_params = dict()

    def get_param_id_for_grad(self, grad: Tensor) -> int:
    def get_param_id_for_grad(self, grad: Tensor) -> Optional[int]:
        """Return the id of a parameter which the gradient slice belongs to

        Args:

@@ -126,4 +123,4 @@ class GradientStore(BaseStore):
            int: the id of a parameter which the gradient slice belongs to
        """

        return self.grad_to_param_mapping[id(grad)]
        return self.grad_to_param_mapping.get(id(grad), None)

@@ -1,60 +0,0 @@
from typing import Dict

from torch import Tensor
from torch.distributed import ProcessGroup

from .base_store import BaseStore


class ParameterStore(BaseStore):
    def __init__(self, torch_pg: ProcessGroup):
        super().__init__(torch_pg)

        # record the padding size of each param
        self._padding_map = dict()

        # mapping working param and master param
        self.master_to_working_param = dict()
        self.working_to_master_param = dict()

    def record_param_padding_size(self, param: Tensor, padding_size: int):
        """Record the padding size of a param

        Args:
            param (Tensor): The parameter
            padding_size (int): The padding size of the parameter
        """

        self._padding_map[id(param)] = padding_size

    def get_param_padding_size(self, param: Tensor) -> int:
        """Return the padding size of the parameter

        Args:
            param (Tensor): The parameter

        Returns:
            int: the padding size of the parameter
        """

        return self._padding_map[id(param)]

    def link_master_and_working_param(self, master_param: Tensor, working_param: Tensor):
        """Mapping master parameter and working parameter

        Args:
            master_param (Tensor): The parameter copy in optimizer
            working_param (Tensor): The parameter of the model
        """

        self.master_to_working_param[id(master_param)] = working_param
        self.working_to_master_param[id(working_param)] = master_param

    def get_padding_map(self) -> Dict[int, Tensor]:
        """Return the padding map

        Returns:
            Dict[int, Tensor]: The padding map
        """

        return self._padding_map

File diff suppressed because it is too large

@@ -176,7 +176,7 @@ def main():
        use_ep_inside = False
        plugin = MoeHybridParallelPlugin(
            pp_size=1,
            extra_dp_size=args.extra_dp_size,
            ep_size=args.ep_size,
            use_ep_inside=use_ep_inside,
            **hybrid_dict,
        )

@ -50,9 +50,9 @@ try:
|
|||
except:
|
||||
HAS_FLASH_ATTN = False
|
||||
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
|
||||
from colossalai.moe.layers import SparseMLP
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.moe.utils import get_activation, set_moe_args
|
||||
from colossalai.shardformer.layer.moe import SparseMLP
|
||||
|
||||
if HAS_TRITON:
|
||||
from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
|
||||
|
@ -83,7 +83,7 @@ def set_openmoe_args(
|
|||
load_balance_group_swap_factor: float = 0.4,
|
||||
enable_kernel: bool = False,
|
||||
enable_comm_overlap: bool = False,
|
||||
enable_hierarchical_alltoall: bool = False,
|
||||
enable_hierarchical_alltoall: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
MoE related arguments.
|
||||
|
@ -465,7 +465,7 @@ class OpenMoeDecoderLayer(nn.Module):
|
|||
load_balance_beam_width=config.load_balance_beam_width,
|
||||
load_balance_group_swap_factor=config.load_balance_group_swap_factor,
|
||||
enable_kernel=config.enable_kernel,
|
||||
enable_comm_overlap=config.enable_comm_overlap,
|
||||
enable_hierarchical_comm=config.enable_hierarchical_alltoall,
|
||||
)
|
||||
self.pre_extra_mlp_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.extra_mlp = OpenMoeMLP(config)
|
||||
|
@ -903,7 +903,7 @@ class OpenMoeForCausalLM(OpenMoePreTrainedModel):
|
|||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
# reset moe loss
|
||||
MOE_MANAGER.reset_loss()
|
||||
MOE_MANAGER.reset_loss() # TODO: remove
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
|
@ -1027,7 +1027,7 @@ class OpenMoeForCausalLM(OpenMoePreTrainedModel):
|
|||
|
||||
def _calculate_router_loss(self, aux_loss: list = None, z_loss: list = None):
|
||||
if aux_loss is None or z_loss is None:
|
||||
aux_loss, z_loss = MOE_MANAGER.get_loss()
|
||||
aux_loss, z_loss = MOE_MANAGER.get_loss() # TODO: remove
|
||||
assert len(aux_loss) == len(z_loss) == self.config.num_hidden_layers // self.config.moe_layer_interval
|
||||
aux_loss = self.config.router_aux_loss_factor * sum(aux_loss) / len(aux_loss)
|
||||
z_loss = self.config.router_z_loss_factor * sum(z_loss) / len(z_loss)
|
||||
|
|
|
@@ -172,6 +172,7 @@ class OpenMoeForCausalLMPolicy(OpenMoePolicy):

        if self.shard_config.enable_tensor_parallelism:
            # add a new item for causal lm
            # TODO: recursively assign ep group for all modules
            new_item = {
                OpenMoeForCausalLM: ModulePolicyDescription(
                    sub_module_replacement=[

@ -1,37 +1,37 @@
|
|||
pip install -r requirements.txt
|
||||
# pip install -r requirements.txt
|
||||
|
||||
# inference
|
||||
python infer.py --model "test"
|
||||
# python infer.py --model "test"
|
||||
|
||||
# train
|
||||
torchrun --standalone --nproc_per_node 4 train.py \
|
||||
--num_epoch 1 \
|
||||
--model_name "test" \
|
||||
--plugin "ep" \
|
||||
--batch_size 1
|
||||
# torchrun --standalone --nproc_per_node 4 train.py \
|
||||
# --num_epoch 1 \
|
||||
# --model_name "test" \
|
||||
# --plugin "ep" \
|
||||
# --batch_size 1
|
||||
|
||||
torchrun --standalone --nproc_per_node 4 train.py \
|
||||
--num_epoch 1 \
|
||||
--model_name "test" \
|
||||
--plugin "ep_zero" \
|
||||
--batch_size 1 \
|
||||
--zero_stage 1 \
|
||||
--extra_dp_size 2 \
|
||||
# torchrun --standalone --nproc_per_node 4 train.py \
|
||||
# --num_epoch 1 \
|
||||
# --model_name "test" \
|
||||
# --plugin "ep_zero" \
|
||||
# --batch_size 1 \
|
||||
# --zero_stage 1 \
|
||||
# --extra_dp_size 2 \
|
||||
|
||||
torchrun --standalone --nproc_per_node 4 train.py \
|
||||
--num_epoch 1 \
|
||||
--model_name "test" \
|
||||
--plugin "ep_zero" \
|
||||
--batch_size 1 \
|
||||
--zero_stage 2 \
|
||||
--extra_dp_size 2 \
|
||||
# torchrun --standalone --nproc_per_node 4 train.py \
|
||||
# --num_epoch 1 \
|
||||
# --model_name "test" \
|
||||
# --plugin "ep_zero" \
|
||||
# --batch_size 1 \
|
||||
# --zero_stage 2 \
|
||||
# --extra_dp_size 2 \
|
||||
|
||||
torchrun --standalone --nproc_per_node 4 train.py \
|
||||
--model_name "test" \
|
||||
--plugin "hybrid" \
|
||||
--num_epoch 1 \
|
||||
--pp_size 2 \
|
||||
--dp_size 1 \
|
||||
--ep_size 2 \
|
||||
--zero_stage 1 \
|
||||
--batch_size 1
|
||||
# torchrun --standalone --nproc_per_node 4 train.py \
|
||||
# --model_name "test" \
|
||||
# --plugin "hybrid" \
|
||||
# --num_epoch 1 \
|
||||
# --pp_size 2 \
|
||||
# --dp_size 1 \
|
||||
# --ep_size 2 \
|
||||
# --zero_stage 1 \
|
||||
# --batch_size 1
|
||||
|
|
|
@ -19,10 +19,9 @@ from colossalai.accelerator import get_accelerator
|
|||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.moe.layers import apply_load_balance
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.moe.utils import skip_init
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.shardformer.layer.moe import apply_load_balance
|
||||
|
||||
|
||||
def move_to_cuda(batch, device):
|
||||
|
@ -221,48 +220,49 @@ def main():
|
|||
"precision": args.precision,
|
||||
"zero_stage": args.zero_stage,
|
||||
}
|
||||
mgr_dict = {}
|
||||
if args.plugin == "ep":
|
||||
dp_size = dist.get_world_size()
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
pp_size=1,
|
||||
ep_size=args.ep_size,
|
||||
**hybrid_dict,
|
||||
)
|
||||
MOE_MANAGER.setup(
|
||||
parallel="EP",
|
||||
max_ep_size=dp_size,
|
||||
**mgr_dict,
|
||||
)
|
||||
# MOE_MANAGER.setup(
|
||||
# parallel="EP",
|
||||
# max_ep_size=dp_size,
|
||||
# **mgr_dict,
|
||||
# )
|
||||
elif args.plugin == "ep_zero":
|
||||
dp_size = dist.get_world_size()
|
||||
use_ep_inside = False
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
pp_size=1,
|
||||
extra_dp_size=args.extra_dp_size,
|
||||
ep_size=dp_size // args.ep_size,
|
||||
use_ep_inside=use_ep_inside,
|
||||
**hybrid_dict,
|
||||
)
|
||||
MOE_MANAGER.setup(
|
||||
parallel="EP",
|
||||
max_ep_size=dp_size // args.extra_dp_size,
|
||||
use_ep_inside=use_ep_inside,
|
||||
**mgr_dict,
|
||||
)
|
||||
# MOE_MANAGER.setup(
|
||||
# parallel="EP",
|
||||
# max_ep_size=dp_size // args.extra_dp_size,
|
||||
# use_ep_inside=use_ep_inside,
|
||||
# **mgr_dict,
|
||||
# )
|
||||
elif args.plugin == "hybrid":
|
||||
dp_size = dist.get_world_size() // args.pp_size
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
pp_size=args.pp_size,
|
||||
ep_size=args.ep_size,
|
||||
microbatch_size=args.microbatch_size,
|
||||
**hybrid_dict,
|
||||
)
|
||||
MOE_MANAGER.setup(
|
||||
parallel="EP",
|
||||
mode="fixed",
|
||||
fixed_dp_size=args.dp_size,
|
||||
fixed_ep_size=args.ep_size,
|
||||
fixed_pp_size=args.pp_size,
|
||||
**mgr_dict,
|
||||
)
|
||||
# MOE_MANAGER.setup(
|
||||
# parallel="EP",
|
||||
# mode="fixed",
|
||||
# fixed_dp_size=args.dp_size,
|
||||
# fixed_ep_size=args.ep_size,
|
||||
# fixed_pp_size=args.pp_size,
|
||||
# **mgr_dict,
|
||||
# )
|
||||
else:
|
||||
raise ValueError(f"Invalid plugin {args.plugin}")
|
||||
coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
|
||||
|
|
|
@@ -59,10 +59,10 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
    # check master weight
    assert isinstance(new_optimizer, LowLevelZeroOptimizer)
    working_param_id_set = set(id(p) for p in new_model.parameters())
    for p_id, master_param in new_optimizer._param_store.working_to_master_param.items():
    for p_id, master_param in new_optimizer.working_to_master_param.items():
        assert p_id in working_param_id_set
        working_param = new_optimizer._param_store.master_to_working_param[id(master_param)]
        padding = new_optimizer._param_store.get_param_padding_size(working_param)
        working_param = new_optimizer.master_to_working_param[id(master_param)]
        padding = new_optimizer.get_param_padding_size(working_param)
        padded_param = torch.nn.functional.pad(working_param.data.view(-1), (0, padding))
        working_shard = padded_param.chunk(dist.get_world_size())[dist.get_rank()]
        assert torch.equal(

@@ -115,10 +115,10 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo
    # check master weight
    assert isinstance(new_optimizer, LowLevelZeroOptimizer)
    working_param_id_set = set(id(p) for p in new_model.parameters())
    for p_id, master_param in new_optimizer._param_store.working_to_master_param.items():
    for p_id, master_param in new_optimizer.working_to_master_param.items():
        assert p_id in working_param_id_set
        working_param = new_optimizer._param_store.master_to_working_param[id(master_param)]
        padding = new_optimizer._param_store.get_param_padding_size(working_param)
        working_param = new_optimizer.master_to_working_param[id(master_param)]
        padding = new_optimizer.get_param_padding_size(working_param)
        padded_param = torch.nn.functional.pad(working_param.data.view(-1), (0, padding))
        working_shard = padded_param.chunk(dist.get_world_size())[dist.get_rank()]
        assert torch.equal(

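These test changes reflect the ZeRO refactor folding the former `ParameterStore` into `LowLevelZeroOptimizer` itself: the master/working mappings and padding sizes are now read straight off the optimizer. A condensed sketch of the updated access pattern, assuming a boosted `LowLevelZeroOptimizer` named `optimizer` and its wrapped model `model`; the final comparison is schematic rather than the exact test assertion:

```python
import torch
import torch.distributed as dist

working_param_ids = {id(p) for p in model.parameters()}
for p_id, master_param in optimizer.working_to_master_param.items():
    assert p_id in working_param_ids
    # mappings that used to live on optimizer._param_store are now optimizer attributes
    working_param = optimizer.master_to_working_param[id(master_param)]
    padding = optimizer.get_param_padding_size(working_param)
    padded = torch.nn.functional.pad(working_param.data.view(-1), (0, padding))
    working_shard = padded.chunk(dist.get_world_size())[dist.get_rank()]
    # each rank's master copy should correspond to its shard of the padded working param
    assert working_shard.numel() == master_param.numel()
```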
@ -1,48 +1,37 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.testing import assert_close
|
||||
|
||||
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
|
||||
from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler
|
||||
from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce
|
||||
from colossalai.legacy.registry import GRADIENT_HANDLER
|
||||
from colossalai.moe import SparseMLP
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.moe.utils import get_moe_epsize_param_dict
|
||||
from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size
|
||||
|
||||
# from colossalai.shardformer.layer.moe import SparseMLP
|
||||
from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size, set_moe_tensor_ep_group
|
||||
|
||||
|
||||
def delete_moe_info(model):
|
||||
for _, param in model.named_parameters():
|
||||
if hasattr(param, "moe_info"):
|
||||
delattr(param, "moe_info")
|
||||
if hasattr(param, "ep_group"):
|
||||
delattr(param, "ep_group")
|
||||
|
||||
|
||||
class MoeModel(nn.Module):
|
||||
def __init__(self, enable_load_balance: bool = False):
|
||||
class TestSubModule(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.moe = SparseMLP(
|
||||
num_experts=8, hidden_size=16, intermediate_size=32, enable_load_balance=enable_load_balance
|
||||
)
|
||||
self.proj = nn.Linear(16, 4)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.moe(x)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
def __init__(self, ep_group: ProcessGroup = None):
|
||||
super().__init__()
|
||||
self.test_embed = nn.Linear(4, 16)
|
||||
self.test_transform = TestSubModule()
|
||||
self.test_embed = nn.Linear(4, 16, bias=False)
|
||||
self.w1 = torch.nn.Parameter(torch.randn(16, 8))
|
||||
if ep_group:
|
||||
set_moe_tensor_ep_group(self.w1, ep_group)
|
||||
|
||||
def forward(self, x):
|
||||
MOE_MANAGER.reset_loss()
|
||||
|
||||
x = self.test_embed(x)
|
||||
x = self.test_transform(x)
|
||||
x = torch.matmul(x, self.w1)
|
||||
|
||||
return x
|
||||
|
||||
|
@ -116,7 +105,7 @@ def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False)
|
|||
return y
|
||||
|
||||
|
||||
def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
|
||||
def sync_local_from_ep(local_model, ep_model, assert_grad_flag: bool = False) -> None:
|
||||
"""Sync the parameters of tp model from ep model
|
||||
|
||||
Args:
|
||||
|
@ -126,7 +115,6 @@ def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_
|
|||
for (local_name, local_param), (ep_name, ep_param) in zip(
|
||||
local_model.named_parameters(), ep_model.named_parameters()
|
||||
):
|
||||
assert local_name in ep_name, print(f"{local_name} != {ep_name}")
|
||||
if "experts" not in local_name:
|
||||
if assert_grad_flag:
|
||||
assert torch.allclose(local_param, ep_param), f"local_param: {local_param}, ep_param: {ep_param}"
|
||||
|
|
|
@ -5,8 +5,9 @@ import torch.nn as nn
|
|||
|
||||
import colossalai
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.moe import SparseMLP
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
|
||||
# from colossalai.shardformer.layer.moe.layers import SparseMLP
|
||||
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
|
||||
from tests.test_moe.moe_utils import MoeGradientHandler
|
||||
|
||||
|
@ -69,6 +70,7 @@ def run_test(rank, world_size, port):
|
|||
# MoE grad handler test passed
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="moe need to be refactored")
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_grad_handler():
|
||||
|
|
|
@ -1,98 +1,96 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import colossalai
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.moe import SparseMLP
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
BATCH_SIZE = 4
|
||||
# from colossalai.moe import SparseMLP
|
||||
from colossalai.moe._operation import MoeCombine, MoeDispatch, moe_cumsum
|
||||
|
||||
NUM_EXPERTS = 4
|
||||
BATCH_SIZE = 4
|
||||
SEQ_LEN = 4
|
||||
|
||||
MOE_TENSOR_PATH = os.getenv("MOE_TENSOR_PATH")
|
||||
|
||||
|
||||
def check_equal(tensor_a, tensor_b, atol=1e-06):
|
||||
assert torch.allclose(tensor_a, tensor_b, rtol=0, atol=atol) is True
|
||||
|
||||
|
||||
def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32, topk=1):
|
||||
# Here we do not need TF32, since it brings absolute error on results
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
def run_moe_cumsum():
|
||||
test_mask = torch.tensor(
|
||||
[
|
||||
[0, 1, 0, 0],
|
||||
[1, 0, 0, 0],
|
||||
[0, 1, 0, 0],
|
||||
[1, 0, 0, 0],
|
||||
],
|
||||
dtype=torch.int32,
|
||||
).to("cuda")
|
||||
out_no_kernel = moe_cumsum(test_mask, use_kernel=False)
|
||||
out_kernel = moe_cumsum(test_mask, use_kernel=True)
|
||||
print(out_no_kernel.dtype, out_kernel.dtype)
|
||||
check_equal(out_no_kernel.to(torch.int32), out_kernel)
|
||||
|
||||
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
local_rank = dist.get_rank()
|
||||
|
||||
MOE_MANAGER.setup(parallel="EP") # MOE environment initialization
|
||||
MOE_MANAGER.reset_loss()
|
||||
torch.manual_seed(rs + local_rank) # set each process has different random seed
|
||||
|
||||
# get randomized data
|
||||
def run_moe_dispatch_combine_fwd_bwd(data_type=torch.float32, hidden_size=128, num_experts=4):
|
||||
tokens = torch.randn(
|
||||
BATCH_SIZE, hidden_size, dtype=data_type, device=get_accelerator().get_current_device(), requires_grad=True
|
||||
)
|
||||
|
||||
layer = SparseMLP(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=hidden_size * 2,
|
||||
num_experts=NUM_EXPERTS,
|
||||
router_top_k=topk,
|
||||
router_capacity_factor_train=1.0,
|
||||
)
|
||||
layer = layer.to(get_accelerator().get_current_device())
|
||||
if data_type == torch.float16:
|
||||
layer = layer.half()
|
||||
# use kernel
|
||||
route_result_list_kernel = torch.load(f"{MOE_TENSOR_PATH}/True_4_{data_type}.pt")
|
||||
# dispatch
|
||||
dispatch_data_kernel = MoeDispatch.apply(tokens, *route_result_list_kernel[1:])
|
||||
dispatch_data_kernel = dispatch_data_kernel.reshape(num_experts, -1, hidden_size)
|
||||
# combine
|
||||
expert_output = dispatch_data_kernel.reshape(-1, hidden_size)
|
||||
ans_kernel = MoeCombine.apply(expert_output, *route_result_list_kernel)
|
||||
|
||||
# use matrix multiplication instead of COL_MOE_KERNEL in MOE dispatch and combine
|
||||
layer.enable_kernel = False
|
||||
old_out = layer(tokens)
|
||||
ech = old_out.shape
|
||||
grad = torch.randn(ech, device=get_accelerator().get_current_device())
|
||||
old_out.backward(grad) # get gradient
|
||||
# no kernel
|
||||
route_result_list_no_kernel = torch.load(f"{MOE_TENSOR_PATH}/False_2_{data_type}.pt")
|
||||
# dispatch
|
||||
sec_mask_f = route_result_list_no_kernel[1].type_as(tokens)
|
||||
dispatch_data_no_kernel = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
|
||||
# combine
|
||||
combine_weights = route_result_list_no_kernel[0].type_as(tokens)
|
||||
combine_weights = combine_weights.view(combine_weights.shape[0], -1)
|
||||
expert_output = expert_output.view(-1, expert_output.shape[-1])
|
||||
ans_no_kernel = torch.matmul(combine_weights, expert_output)
|
||||
|
||||
# save all results
|
||||
o_tk_grad = tokens.grad.data.clone()
|
||||
o_gt_grad = layer.gate_weight.grad.data.clone()
|
||||
# check fwd
|
||||
if data_type == torch.float32:
|
||||
check_equal(dispatch_data_kernel.reshape(dispatch_data_no_kernel.shape), dispatch_data_no_kernel)
|
||||
else:
|
||||
check_equal(dispatch_data_kernel.reshape(dispatch_data_no_kernel.shape), dispatch_data_no_kernel, 1e-2)
|
||||
|
||||
# reset all gradients
|
||||
if data_type == torch.float32:
|
||||
check_equal(ans_kernel, ans_no_kernel)
|
||||
else:
|
||||
check_equal(ans_kernel, ans_no_kernel, 1e-2)
|
||||
|
||||
# check bwd
|
||||
out_shape = ans_kernel.shape
|
||||
grad = torch.randn(out_shape, device=get_accelerator().get_current_device())
|
||||
|
||||
ans_kernel.backward(grad, retain_graph=True)
|
||||
grad_kernel = tokens.grad.data.clone()
|
||||
tokens.grad.zero_()
|
||||
layer.gate_weight.grad.zero_()
|
||||
|
||||
layer.enable_kernel = True
|
||||
new_out = layer(tokens) # get outputs through colossal kernel
|
||||
ans_no_kernel.backward(grad) # get gradient
|
||||
grad_no_kernel = tokens.grad.data.clone()
|
||||
tokens.grad.zero_()
|
||||
|
||||
if data_type == torch.float32:
|
||||
check_equal(old_out, new_out)
|
||||
check_equal(grad_no_kernel, grad_kernel)
|
||||
else:
|
||||
check_equal(old_out, new_out, 1e-2)
|
||||
# forward function passed
|
||||
|
||||
new_out.backward(grad) # get new type gradient
|
||||
n_tk_grad = tokens.grad.data.clone()
|
||||
n_gt_grad = layer.gate_weight.grad.data.clone()
|
||||
|
||||
if data_type == torch.float32:
|
||||
check_equal(o_tk_grad, n_tk_grad)
|
||||
else:
|
||||
check_equal(o_tk_grad, o_tk_grad, 1e-2)
|
||||
# tokens gradient is correct
|
||||
|
||||
if data_type == torch.float32:
|
||||
check_equal(o_gt_grad, n_gt_grad, 5e-05)
|
||||
else:
|
||||
check_equal(o_gt_grad, n_gt_grad, 2e-01)
|
||||
# bias gradient is correct
|
||||
check_equal(grad_no_kernel, grad_kernel, 1e-2)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("rs", [131])
|
||||
@pytest.mark.parametrize("hidden_size", [32, 144])
|
||||
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
|
||||
@pytest.mark.parametrize("topk", [1, 2])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_moe_kernel(rs, hidden_size, data_type, topk):
|
||||
spawn(run_routing, 4, rs=rs, hidden_size=hidden_size, data_type=data_type, topk=topk)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_moe_kernel(2, 256, torch.float16, 2)
|
||||
def test_moe_kernel(data_type):
|
||||
torch.manual_seed(1024)
|
||||
run_moe_cumsum()
|
||||
run_moe_dispatch_combine_fwd_bwd(data_type=data_type)
|
||||
|
|
|
@@ -3,13 +3,13 @@ from copy import deepcopy
import pytest
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_layer import EPMixtralSparseMoeBlock
from torch.testing import assert_close
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

import colossalai
from colossalai.moe import MOE_MANAGER
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock
from colossalai.testing.utils import spawn

tokens, n_experts = 7, 4

@@ -19,8 +19,11 @@ top_k = 2

def check_mixtral_moe_layer():
    torch.cuda.set_device(dist.get_rank())
    MOE_MANAGER.setup(
        parallel="EP", mode="fixed", fixed_dp_size=1, fixed_ep_size=dist.get_world_size(), fixed_pp_size=1
    plugin = MoeHybridParallelPlugin(
        precision="bf16",
        tp_size=1,
        pp_size=1,
        ep_size=dist.get_world_size(),
    )
    config = MixtralConfig(
        hidden_size=hidden_size,

@@ -33,7 +36,7 @@ def check_mixtral_moe_layer():
    x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
    orig_output, orig_logits = orig_model(x)
    model = deepcopy(orig_model)
    model = EPMixtralSparseMoeBlock.from_native_module(model)
    model = EPMixtralSparseMoeBlock.from_native_module(model, ep_group=plugin.ep_group)
    ep_output, ep_logits = model(x)
    assert_close(orig_logits, ep_logits)
    assert_close(orig_output, ep_output)

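The key change in this test is that `EPMixtralSparseMoeBlock.from_native_module` now takes the expert-parallel group explicitly, and that group comes from `MoeHybridParallelPlugin` rather than the removed `MOE_MANAGER`. A distilled sketch of that flow; sizes are illustrative, and it assumes `colossalai.launch` has already set up the process groups:

```python
from copy import deepcopy

import torch
import torch.distributed as dist
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock

plugin = MoeHybridParallelPlugin(precision="bf16", tp_size=1, pp_size=1, ep_size=dist.get_world_size())
config = MixtralConfig(hidden_size=8, intermediate_size=16, num_local_experts=4, num_experts_per_tok=2)

orig_block = MixtralSparseMoeBlock(config).cuda()
# experts are sharded over plugin.ep_group; non-expert weights stay replicated
ep_block = EPMixtralSparseMoeBlock.from_native_module(deepcopy(orig_block), ep_group=plugin.ep_group)

x = torch.rand(1, 7, config.hidden_size).cuda()
ep_output, ep_logits = ep_block(x)  # expected to match the dense block's output on every rank
```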
@ -1,201 +1,176 @@
|
|||
import importlib
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
from contextlib import nullcontext
|
||||
from copy import deepcopy
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from transformers.models.llama import LlamaConfig
|
||||
from torch.optim import Adam
|
||||
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
|
||||
|
||||
import colossalai
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
||||
from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.checkpoint_io import MoECheckpointIO
|
||||
from colossalai.tensor.moe_tensor.api import is_moe_tensor
|
||||
from colossalai.testing.utils import spawn
|
||||
|
||||
sys.path.append(
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
|
||||
"examples/language/openmoe",
|
||||
)
|
||||
)
|
||||
|
||||
OpenMoeForCausalLM = importlib.import_module("model.modeling_openmoe").OpenMoeForCausalLM
|
||||
set_openmoe_args = importlib.import_module("model.modeling_openmoe").set_openmoe_args
|
||||
OpenMoeForCausalLMPolicy = importlib.import_module("model.openmoe_policy").OpenMoeForCausalLMPolicy
|
||||
tokens, n_experts = 7, 4
|
||||
hidden_size = 8
|
||||
top_k = 2
|
||||
|
||||
|
||||
def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20):
|
||||
input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_accelerator().get_current_device())
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
def check_model_equal(model1, model2):
|
||||
assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
|
||||
for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
|
||||
if not torch.equal(p1.half(), p2.half()):
|
||||
print(f"Model parameter {name} is not equal. is_moe_tensor: {is_moe_tensor(p1)}")
|
||||
raise AssertionError(f"Model parameter {name} is not equal")
|
||||
|
||||
|
||||
def get_optimizer_snapshot(optim):
|
||||
state = {id(k): deepcopy(v) for k, v in optim.state.items()}
|
||||
param_groups = []
|
||||
for group in optim.param_groups:
|
||||
params = [id(p) for p in group["params"]]
|
||||
new_group = {"params": params}
|
||||
for k, v in group.items():
|
||||
if k != "params":
|
||||
new_group[k] = v
|
||||
param_groups.append(new_group)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"labels": input_ids,
|
||||
"state": state,
|
||||
"param_groups": param_groups,
|
||||
}
|
||||
|
||||
|
||||
def run_fwd_bwd(
|
||||
model, data, label, criterion, optimizer, enable_autocast=False, pipeline=False, booster=None, plugin=None
|
||||
):
|
||||
model.train()
|
||||
if pipeline:
|
||||
train_dataloader_iter = DummyDataloader(data_gen_fn, length=1)
|
||||
is_pp_last_stage = booster.plugin.stage_manager.is_last_stage()
|
||||
y = booster.execute_pipeline(
|
||||
train_dataloader_iter,
|
||||
model,
|
||||
lambda x, y: x.loss,
|
||||
optimizer,
|
||||
return_loss=True,
|
||||
)
|
||||
# Backward and optimize
|
||||
if is_pp_last_stage:
|
||||
loss = y["loss"]
|
||||
else:
|
||||
if criterion:
|
||||
y = model(data).logits
|
||||
loss = criterion(y)
|
||||
def check_optimizer_snapshot_equal(snapshot1, snapshot2, param2name, moe_dp_group=None):
|
||||
assert len(snapshot1["param_groups"]) == len(snapshot2["param_groups"])
|
||||
for group1, group2 in zip(snapshot1["param_groups"], snapshot2["param_groups"]):
|
||||
assert set(group1.keys()) == set(group2.keys())
|
||||
for k in group1.keys():
|
||||
assert group1[k] == group2[k]
|
||||
# check state
|
||||
assert set(snapshot1["state"].keys()) == set(
|
||||
snapshot2["state"].keys()
|
||||
), f"{snapshot1['state'].keys()}, {snapshot2['state'].keys()}"
|
||||
|
||||
passed = True
|
||||
count = 0
|
||||
for pid in snapshot1["state"].keys():
|
||||
state1, state2 = snapshot1["state"][pid], snapshot2["state"][pid]
|
||||
assert set(state1.keys()) == set(state2.keys())
|
||||
bug = False
|
||||
for k in state1.keys():
|
||||
if isinstance(state1[k], torch.Tensor):
|
||||
if not torch.equal(state1[k], state2[k]):
|
||||
bug = True
|
||||
count += 1
|
||||
else:
|
||||
assert state1[k] == state2[k]
|
||||
if bug:
|
||||
passed = False
|
||||
|
||||
if not passed:
|
||||
raise AssertionError(f"A total of {count} optim states are not equal")
|
||||
|
||||
|
||||
def check_mixtral_moe_layer():
|
||||
context = tempfile.TemporaryDirectory() if dist.get_rank() == 0 else nullcontext()
|
||||
with context as f:
|
||||
torch.cuda.set_device(dist.get_rank())
|
||||
if dist.get_rank() == 0:
|
||||
broadcast_objects = [f] # any picklable object
|
||||
else:
|
||||
loss = model(data, label)
|
||||
loss = loss.float()
|
||||
broadcast_objects = [None]
|
||||
dist.broadcast_object_list(broadcast_objects, src=0)
|
||||
|
||||
if optimizer is not None:
|
||||
optimizer.backward(loss)
|
||||
else:
|
||||
loss.backward()
|
||||
return y
|
||||
|
||||
|
||||
def get_config():
|
||||
config = LlamaConfig(
|
||||
vocab_size=300,
|
||||
hidden_size=16,
|
||||
intermediate_size=32,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=2,
|
||||
head_dim=4,
|
||||
dropout_rate=0.0,
|
||||
hidden_act="swiglu",
|
||||
)
|
||||
set_openmoe_args(config, num_experts=8, moe_layer_interval=1)
|
||||
return config
|
||||
|
||||
|
||||
def get_model(parallel):
|
||||
config = get_config()
|
||||
model = OpenMoeForCausalLM(config)
|
||||
optim = torch.optim.Adam(model.parameters())
|
||||
|
||||
if parallel == None:
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
precision="bf16",
|
||||
tp_size=1,
|
||||
pp_size=1,
|
||||
ep_size=1,
|
||||
zero_stage=2,
|
||||
custom_policy=OpenMoeForCausalLMPolicy(),
|
||||
config = MixtralConfig(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=hidden_size * 2,
|
||||
num_local_experts=n_experts,
|
||||
num_experts_per_tok=top_k,
|
||||
num_attention_heads=2,
|
||||
num_key_value_heads=2,
|
||||
)
|
||||
elif parallel == "ep":
|
||||
torch.manual_seed(0)
|
||||
input_ids = torch.randint(0, 100, (2, tokens)).cuda()
|
||||
orig_model = MixtralForCausalLM(config).cuda()
|
||||
model = deepcopy(orig_model)
|
||||
optimizer = Adam(model.parameters(), lr=1e-3)
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
precision="bf16",
|
||||
tp_size=1,
|
||||
pp_size=1,
|
||||
ep_size=dist.get_world_size(),
|
||||
zero_stage=2,
|
||||
custom_policy=OpenMoeForCausalLMPolicy(),
|
||||
)
|
||||
elif parallel == "ep_zero":
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
precision="bf16",
|
||||
tp_size=1,
|
||||
pp_size=1,
|
||||
ep_size=2,
|
||||
zero_stage=2,
|
||||
extra_dp_size=2,
|
||||
custom_policy=OpenMoeForCausalLMPolicy(),
|
||||
)
|
||||
elif parallel == "hybrid":
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
precision="bf16",
|
||||
tp_size=1,
|
||||
pp_size=2,
|
||||
ep_size=2,
|
||||
zero_stage=1,
|
||||
tp_size=1,
|
||||
checkpoint_io=MoECheckpointIO,
|
||||
microbatch_size=1,
|
||||
custom_policy=OpenMoeForCausalLMPolicy(),
|
||||
zero_stage=1,
|
||||
)
|
||||
booster = Booster(plugin=plugin)
|
||||
model, optim, _, _, _ = booster.boost(model=model, optimizer=optim)
|
||||
return model, booster, optim
|
||||
booster = Booster(plugin=plugin)
|
||||
model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer)
|
||||
# initialize grads
|
||||
data_iter = iter(
|
||||
[{"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "labels": input_ids.clone()}]
|
||||
)
|
||||
booster.execute_pipeline(
|
||||
data_iter,
|
||||
model,
|
||||
lambda outputs, inputs: outputs.loss,
|
||||
optimizer,
|
||||
)
|
||||
|
||||
tmpdirname = broadcast_objects[0]
|
||||
model_dir = os.path.join(tmpdirname, "mixtral_model")
|
||||
hf_model_dir = os.path.join(tmpdirname, "mixtral_hf_model")
|
||||
optim_dir = os.path.join(tmpdirname, "mixtral_optim")
|
||||
|
||||
booster.save_model(model, model_dir, shard=True)
|
||||
dist.barrier()
|
||||
if dist.get_rank() == 0:
|
||||
saved_model = MixtralForCausalLM.from_pretrained(model_dir).cuda()
|
||||
check_model_equal(orig_model, saved_model)
|
||||
# check_model_equal(model, saved_model)
|
||||
saved_model.save_pretrained(hf_model_dir)
|
||||
dist.barrier()
|
||||
# check load model
|
||||
new_model = MixtralForCausalLM(config).cuda()
|
||||
new_optimizer = Adam(new_model.parameters(), lr=1e-3)
|
||||
new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
|
||||
booster.load_model(new_model, hf_model_dir)
|
||||
check_model_equal(model, new_model)
|
||||
|
||||
# check save optimizer
|
||||
optimizer.step()
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = 0.1
|
||||
snapshot = get_optimizer_snapshot(optimizer.unwrap())
|
||||
booster.save_optimizer(optimizer, optim_dir, shard=True)
|
||||
dist.barrier()
|
||||
|
||||
# reset optimizer state
|
||||
for state in optimizer.unwrap().state.values():
|
||||
for v in state.values():
|
||||
if isinstance(v, torch.Tensor):
|
||||
v.zero_()
|
||||
booster.load_optimizer(optimizer, optim_dir)
|
||||
loaded_snapshot = get_optimizer_snapshot(optimizer.unwrap())
|
||||
check_optimizer_snapshot_equal(snapshot, loaded_snapshot, None, model)
|
||||
# Ensure rank 0 waits for all other ranks to finish
|
||||
dist.barrier()
|
||||
|
||||
|
||||
def _test_moe_checkpoint(rank, parallel):
|
||||
model1, booster1, optim1 = get_model(parallel)
|
||||
model2, booster2, optim2 = get_model(parallel)
|
||||
model3, booster3, optim3 = get_model(parallel)
|
||||
|
||||
# param ckpt
|
||||
# shard
|
||||
booster1.save_model(model1, "./tmp_ckpt1", shard=True, size_per_shard=1)
|
||||
booster2.load_model(model2, "./tmp_ckpt1")
|
||||
# unshard
|
||||
booster1.save_model(model1, "./tmp_ckpt1.pth")
|
||||
booster3.load_model(model3, "./tmp_ckpt1.pth")
|
||||
# check
|
||||
check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
|
||||
check_state_dict_equal(model1.state_dict(), model3.state_dict(), False)
|
||||
|
||||
# optim ckpt
|
||||
criterion = lambda x: x.mean()
|
||||
data = torch.randint(0, 4, (2, 4)).cuda()
|
||||
label = torch.randint(0, 4, (2,)).cuda()
|
||||
if parallel == "hybrid":
|
||||
kwargs = {"pipeline": True, "booster": booster1, "plugin": booster1.plugin}
|
||||
else:
|
||||
kwargs = {}
|
||||
run_fwd_bwd(model1, data, label, criterion, optim1, **kwargs)
|
||||
optim1.step()
|
||||
optim1.zero_grad()
|
||||
# shard
|
||||
booster1.save_optimizer(optim1, "./tmp_ckpt2", shard=True, size_per_shard=1)
|
||||
dist.barrier()
|
||||
booster2.load_optimizer(optim2, "./tmp_ckpt2")
|
||||
# unshard
|
||||
booster1.save_optimizer(optim1, "./tmp_ckpt2.pth")
|
||||
booster3.load_optimizer(optim3, "./tmp_ckpt2.pth")
|
||||
# check
|
||||
check_state_dict_equal(optim1.optim.state_dict(), optim2.optim.state_dict(), False)
|
||||
check_state_dict_equal(optim1.optim.state_dict(), optim3.optim.state_dict(), False)
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
shutil.rmtree("./tmp_ckpt1")
|
||||
shutil.rmtree("./tmp_ckpt2")
|
||||
os.remove("./tmp_ckpt1.pth")
|
||||
os.remove("./tmp_ckpt2.pth")
|
||||
def run_dist(rank: int, world_size: int, port: int):
|
||||
colossalai.launch(rank, world_size, "localhost", port)
|
||||
check_mixtral_moe_layer()
|
||||
|
||||
|
||||
def _run_dist(rank, world_size, port, parallel):
|
||||
colossalai.launch(
|
||||
config=dict(),
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host="localhost",
|
||||
port=port,
|
||||
backend="nccl",
|
||||
)
|
||||
_test_moe_checkpoint(rank, parallel)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="This is tested in ColossalMOE")
|
||||
@pytest.mark.dist
|
||||
# Test EP + ZeRO + PP
|
||||
@pytest.mark.parametrize("world_size", [4])
|
||||
@pytest.mark.parametrize("parallel", [None, "ep", "ep_zero", "hybrid"])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_moe_checkpoint(world_size, parallel):
|
||||
spawn(_run_dist, world_size, parallel=parallel)
|
||||
def test_mixtral_moe_layer(world_size: int):
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_moe_checkpoint(world_size=4, parallel="hybrid")
|
||||
test_mixtral_moe_layer(4)
|
||||
|
|
|
@ -8,15 +8,16 @@ import torch.distributed as dist
|
|||
|
||||
import colossalai
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.moe import SparseMLP
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.moe.utils import sync_moe_model_param
|
||||
|
||||
# from colossalai.shardformer.layer import SparseMLP
|
||||
from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor
|
||||
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
|
||||
from tests.test_moe.moe_utils import MoeGradientHandler
|
||||
|
||||
|
||||
def sync_tp_from_local(tp_model: SparseMLP, local_model: SparseMLP, assert_grad_flag: bool = False) -> None:
|
||||
def sync_tp_from_local(tp_model, local_model, assert_grad_flag: bool = False) -> None:
|
||||
"""Sync the parameters of tp model from local model
|
||||
|
||||
Args:
|
||||
|
@ -48,7 +49,7 @@ def sync_tp_from_local(tp_model: SparseMLP, local_model: SparseMLP, assert_grad_
|
|||
tp_param.data.copy_(local_param[tuple(tp_slice)].data)
|
||||
|
||||
|
||||
def sync_tp_from_ep(tp_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
|
||||
def sync_tp_from_ep(tp_model, ep_model, assert_grad_flag: bool = False) -> None:
|
||||
"""Sync the parameters of tp model from ep model
|
||||
|
||||
Args:
|
||||
|
@ -90,7 +91,7 @@ def sync_tp_from_ep(tp_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag:
|
|||
tp_param.data.copy_(new_tp_param.data)
|
||||
|
||||
|
||||
def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
|
||||
def sync_local_from_ep(local_model, ep_model, assert_grad_flag: bool = False) -> None:
|
||||
"""Sync the parameters of tp model from ep model
|
||||
|
||||
Args:
|
||||
|
@ -216,6 +217,7 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="moe need to be refactored")
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("num_experts", [4, 64])
|
||||
@pytest.mark.parametrize("batch_size", [16])
|
||||
|
|
|
@ -4,9 +4,10 @@ import torch.nn as nn
|
|||
|
||||
import colossalai
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.moe.experts import MLPExperts
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.moe.utils import sync_moe_model_param
|
||||
|
||||
# from colossalai.shardformer.layer.moe import MLPExperts
|
||||
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
|
||||
|
||||
HIDDEN_SIZE = 4
|
||||
|
@ -69,6 +70,7 @@ def _run_test(rank, world_size, port, expert_parallel):
|
|||
run_moe_init(expert_parallel)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="moe need to be refactored")
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("expert_parallel", ["EP", "TP"])
|
||||
@rerun_if_address_is_in_use()
|
||||
|
|
|
@ -86,6 +86,7 @@ def run_dist(rank, world_size, port):
|
|||
run_zero_optim_test(rank, world_size, stage=2)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="moe need to be refactored")
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [4])
|
||||
@rerun_if_address_is_in_use()
|
||||
|
|
|
@ -6,8 +6,9 @@ import colossalai
|
|||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import LowLevelZeroPlugin
|
||||
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
|
||||
from colossalai.moe.layers import apply_load_balance
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
|
||||
# from colossalai.shardformer.layer.moe import apply_load_balance
|
||||
from colossalai.tensor.moe_tensor.api import is_moe_tensor
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel
|
||||
|
@ -176,6 +177,7 @@ def run_dist(rank, world_size, port):
|
|||
run_hybrid_zero_optim_test(rank, world_size, stage=2)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="moe need to be refactored")
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [4])
|
||||
@rerun_if_address_is_in_use()
|
||||
|
|
|
@ -1,47 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from colossalai.moe.routers import MoeRouter, Top1Router, Top2Router, TopKRouter
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["router", "num_groups"],
|
||||
[
|
||||
(Top1Router(), 1),
|
||||
(Top2Router(), 1),
|
||||
# (TopKRouter(num_selected_experts=3), 4),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
["batch_size", "seq_len", "num_experts"],
|
||||
[
|
||||
(4, 5, 8),
|
||||
(3, 4, 4),
|
||||
],
|
||||
)
|
||||
def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_experts: int, num_groups: int):
|
||||
x = torch.randn((batch_size * seq_len, num_experts)).cuda()
|
||||
if num_groups > 1:
|
||||
x = x.expand(num_groups, -1, -1)
|
||||
|
||||
router.train()
|
||||
if isinstance(router, TopKRouter):
|
||||
combine_array, dispatch_mask = router(x, expert_capacity=2)
|
||||
else:
|
||||
combine_array, dispatch_mask = router(x)[1:3]
|
||||
assert combine_array.shape[:-1] == x.shape
|
||||
assert dispatch_mask.shape[:-1] == x.shape
|
||||
assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)
|
||||
|
||||
router.eval()
|
||||
if isinstance(router, TopKRouter):
|
||||
combine_array, dispatch_mask = router(x, expert_capacity=2)
|
||||
else:
|
||||
combine_array, dispatch_mask = router(x)[1:3]
|
||||
assert combine_array.shape[:-1] == x.shape
|
||||
assert dispatch_mask.shape[:-1] == x.shape
|
||||
assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_router_forward(Top2Router(), 4, 4, 4, 1)
|
|
@ -1,78 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import LowLevelZeroPlugin
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.testing.random import seed_all
|
||||
from tests.test_moe.moe_utils import MoeModel, delete_moe_info, run_fwd_bwd, sync_local_from_ep
|
||||
|
||||
|
||||
def run_zero_test(local_rank, stage=1):
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
|
||||
MOE_MANAGER.__init__()
|
||||
MOE_MANAGER.setup(parallel="EP")
|
||||
moe_model = MoeModel().bfloat16()
|
||||
moe_optimizer = torch.optim.Adam(moe_model.parameters())
|
||||
moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
|
||||
moe_booster = Booster(plugin=moe_plugin)
|
||||
moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer)
|
||||
|
||||
MOE_MANAGER.__init__()
|
||||
MOE_MANAGER.setup(parallel=None)
|
||||
zero_model = MoeModel().bfloat16()
|
||||
delete_moe_info(zero_model)
|
||||
zero_optimizer = torch.optim.Adam(zero_model.parameters())
|
||||
zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
|
||||
zero_booster = Booster(plugin=zero_plugin)
|
||||
zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer)
|
||||
sync_local_from_ep(zero_model, moe_model)
|
||||
|
||||
data = torch.randn(16, 4).bfloat16().cuda()
|
||||
label = torch.randint(0, 4, (16,)).cuda()
|
||||
|
||||
zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
|
||||
moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer)
|
||||
assert torch.allclose(zero_out, moe_out)
|
||||
|
||||
for (moe_name, moe_param), (zero_name, zero_param) in zip(
|
||||
moe_model.module.named_parameters(), zero_model.module.named_parameters()
|
||||
):
|
||||
assert moe_name == zero_name
|
||||
moe_grad_list = moe_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(moe_param))
|
||||
zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param))
|
||||
if hasattr(moe_param, "moe_info"):
|
||||
assert len(moe_grad_list) == 0
|
||||
if stage == 1:
|
||||
zero_grad = zero_grad_list[local_rank].view(moe_param.grad.shape)
|
||||
else:
|
||||
zero_grad = zero_grad_list[0].view(moe_param.grad.shape)
|
||||
assert torch.allclose(
|
||||
moe_param.grad, zero_grad, atol=1e-5
|
||||
), f"zero grad:\n{moe_param.grad}\ntorch grad:\n{zero_grad}\nmax diff: {(moe_param.grad - zero_grad).abs().max()}, mean diff: {(moe_param.grad - zero_grad).abs().mean()}"
|
||||
else:
|
||||
assert len(moe_grad_list) > 0
|
||||
assert len(moe_grad_list) == len(zero_grad_list)
|
||||
for moe_grad, zero_grad in zip(moe_grad_list, zero_grad_list):
|
||||
assert torch.allclose(moe_grad, zero_grad)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port, stage):
|
||||
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
seed_all(42 + rank)
|
||||
run_zero_test(rank, stage=stage)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [2])
|
||||
@pytest.mark.parametrize("stage", [1, 2])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_moe_zero_model(world_size, stage):
|
||||
spawn(run_dist, world_size, stage=stage)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_moe_zero_model(world_size=2, stage=1)
|
|
@ -0,0 +1,132 @@
|
|||
from copy import deepcopy
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
||||
from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock
|
||||
from colossalai.tensor.moe_tensor.api import is_moe_tensor
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.testing.random import seed_all
|
||||
from colossalai.zero import LowLevelZeroOptimizer
|
||||
from tests.test_moe.moe_utils import loose_close
|
||||
|
||||
tokens, n_experts = 7, 4
|
||||
hidden_size = 8
|
||||
top_k = 2
|
||||
|
||||
|
||||
def split_grad(grad, world_size):
|
||||
with torch.no_grad():
|
||||
grad = grad.clone().detach().flatten()
|
||||
padding_size = (world_size - grad.numel() % world_size) % world_size
|
||||
if padding_size > 0:
|
||||
grad = torch.nn.functional.pad(grad, [0, padding_size])
|
||||
splited_grad = grad.split(grad.numel() // world_size)
|
||||
return splited_grad
|
||||
|
||||
|
||||
@parameterize("dtype", [torch.float16, torch.bfloat16])
|
||||
@parameterize("master_weights", [True, False])
|
||||
@parameterize("stage", [1, 2])
|
||||
def run_zero_with_original_model(world_size, master_weights: bool, dtype: torch.dtype, stage: int):
|
||||
rank = torch.distributed.get_rank()
|
||||
torch.cuda.set_device(dist.get_rank())
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
tp_size=1,
|
||||
pp_size=1,
|
||||
ep_size=dist.get_world_size() // 2,
|
||||
)
|
||||
|
||||
seed_all(10086)
|
||||
config = MixtralConfig(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=hidden_size * 2,
|
||||
num_local_experts=n_experts,
|
||||
num_experts_per_tok=top_k,
|
||||
)
|
||||
|
||||
orig_model = MixtralSparseMoeBlock(config).to(dtype).cuda()
|
||||
|
||||
ori_model = DDP(orig_model.cuda(), static_graph=True).cuda()
|
||||
|
||||
zero_model = deepcopy(orig_model).to(dtype)
|
||||
zero_model = EPMixtralSparseMoeBlock.from_native_module(zero_model, ep_group=plugin.ep_group)
|
||||
|
||||
zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
|
||||
pg_param_list = {plugin.global_dp_group: [], plugin.moe_dp_group: []}
|
||||
for p in zero_model.parameters():
|
||||
if is_moe_tensor(p):
|
||||
pg_param_list[plugin.moe_dp_group].append(p)
|
||||
else:
|
||||
pg_param_list[plugin.global_dp_group].append(p)
|
||||
|
||||
zero_optimizer = LowLevelZeroOptimizer(
|
||||
zero_optimizer,
|
||||
pg_to_param_list=pg_param_list,
|
||||
master_weights=master_weights,
|
||||
initial_scale=1,
|
||||
overlap_communication=False,
|
||||
partition_grad=True,
|
||||
)
|
||||
|
||||
ori_optimizer = torch.optim.SGD(ori_model.parameters(), lr=1)
|
||||
|
||||
# create
|
||||
seed_all(1453 + rank)
|
||||
|
||||
for _ in range(2):
|
||||
# zero-dp forward
|
||||
input_data = torch.rand(1, tokens, hidden_size).cuda()
|
||||
zero_output, zero_logits = zero_model(input_data.to(dtype))
|
||||
|
||||
# torch-ddp forward
|
||||
ori_output, ori_logits = ori_model(input_data.to(dtype))
|
||||
loose_close(zero_output, ori_output, dtype=dtype)
|
||||
|
||||
# zero-dp backward
|
||||
zero_optimizer.backward(zero_output.mean().float())
|
||||
|
||||
# torch-ddp backward
|
||||
ori_output.mean().backward()
|
||||
|
||||
# check grad
|
||||
name_to_p = {n: p for n, p in ori_model.module.named_parameters()}
|
||||
for n, p in zero_model.named_parameters():
|
||||
zero_grad = zero_optimizer.get_param_grad(p)
|
||||
if name_to_p[n].grad is None:
|
||||
assert zero_grad is None
|
||||
continue
|
||||
|
||||
loose_close(zero_grad, name_to_p[n].grad, dtype=dtype)
|
||||
|
||||
# zero-dp step
|
||||
zero_optimizer.step()
|
||||
|
||||
# original model step
|
||||
ori_optimizer.step()
|
||||
|
||||
# check updated param
|
||||
for n, p in zero_model.named_parameters():
|
||||
loose_close(p.data, name_to_p[n].data, dtype=dtype)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
run_zero_with_original_model(world_size=world_size)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [2, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_moe_zero_model(world_size):
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_moe_zero_model(world_size=4)
|
|
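The new ZeRO test above relies on the refactored `LowLevelZeroOptimizer` accepting a `pg_to_param_list` mapping, so expert parameters are reduced over the MoE data-parallel group while all other parameters use the global group. A minimal sketch of that wiring, assuming a constructed `MoeHybridParallelPlugin` named `plugin` and a module `zero_model` whose expert weights are already tagged as MoE tensors:

```python
import torch

from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.zero import LowLevelZeroOptimizer

# route expert params to the MoE DP group and everything else to the global DP group
pg_param_list = {plugin.global_dp_group: [], plugin.moe_dp_group: []}
for p in zero_model.parameters():
    group = plugin.moe_dp_group if is_moe_tensor(p) else plugin.global_dp_group
    pg_param_list[group].append(p)

optimizer = LowLevelZeroOptimizer(
    torch.optim.SGD(zero_model.parameters(), lr=1.0),
    pg_to_param_list=pg_param_list,
    master_weights=True,
    initial_scale=1,
    overlap_communication=False,
    partition_grad=True,  # ZeRO stage 2 style gradient partitioning
)
```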
@@ -1,83 +0,0 @@
import pytest
import torch

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import MoeModel, delete_moe_info, loose_close, run_fwd_bwd, sync_local_from_ep


def run_zero_test(local_rank, stage=1):
    criterion = torch.nn.CrossEntropyLoss()

    MOE_MANAGER.__init__()
    MOE_MANAGER.setup(parallel="EP")
    moe_model = MoeModel().bfloat16()
    moe_optimizer = torch.optim.Adam(moe_model.parameters(), lr=1.0)
    moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
    moe_booster = Booster(plugin=moe_plugin)
    moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer)

    MOE_MANAGER.__init__()
    MOE_MANAGER.setup(parallel=None)
    zero_model = MoeModel().bfloat16()
    delete_moe_info(zero_model)
    sync_local_from_ep(zero_model, moe_model)
    zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1.0)
    zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
    zero_booster = Booster(plugin=zero_plugin)
    zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer)

    for (moe_name, moe_param), (zero_name, zero_param) in zip(
        moe_model.named_parameters(), zero_model.named_parameters()
    ):
        if ".experts." in moe_name:
            continue
        assert moe_name == zero_name
        assert torch.allclose(
            moe_param.data, zero_param.data
        ), f"{moe_name}\ntorch_param {moe_param.data}\nzero_param {zero_param.data}"

    for _ in range(1):
        data = torch.randn(2, 4).bfloat16().cuda()
        label = torch.randint(0, 4, (2,)).cuda()

        moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer)
        zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
        assert torch.allclose(zero_out, moe_out)
        moe_optimizer.step()
        zero_optimizer.step()

        for (moe_name, moe_param), (zero_name, zero_param) in zip(
            moe_model.named_parameters(), zero_model.named_parameters()
        ):
            assert moe_name == zero_name
            if is_moe_tensor(moe_param):
                param_size = moe_param.shape[0]
                zero_param = zero_param[local_rank * param_size : (local_rank + 1) * param_size]
            loose_close(moe_param.data, zero_param.data, dtype=moe_param.dtype)

        moe_optimizer.zero_grad()
        zero_optimizer.zero_grad()


def run_dist(rank, world_size, port, stage):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    seed_all(42 + rank)
    run_zero_test(rank, stage=stage)


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("stage", [1, 2])
@rerun_if_address_is_in_use()
def test_moe_zero_optim(world_size, stage):
    spawn(run_dist, world_size, stage=stage)


if __name__ == "__main__":
    test_moe_zero_optim(world_size=2, stage=1)

@@ -234,7 +234,7 @@ def check_dist_grad(sharded_optimizer, org_model, sharded_model, weight_layer_fo
         if org_name in weight_layer_for_check:
             org_grad = org_param.grad
             group_id = dist.get_rank(sharded_optimizer.optim.dp_group)
-            dist_grad = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(group_id, id(sharded_param))
+            dist_grad = sharded_optimizer.get_partitioned_gradients_by_param_id(group_id, id(sharded_param))

             # dist_grad concat then reshape to org_grad shape
             if dist_grad:

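The same API change recurs in most of the hunks below: the refactored low-level ZeRO optimizer no longer exposes its internal _grad_store, _param_store, and _bucket_store objects, so test code now calls the corresponding accessors directly on the optimizer. A minimal sketch of the new call sites (the wrapper functions are ours, for illustration; only the optimizer attributes shown here appear in this diff):

from typing import List

import torch

from colossalai.zero import LowLevelZeroOptimizer


def partitioned_grads_of(optimizer: LowLevelZeroOptimizer, working_param: torch.Tensor, group_id: int = 0) -> List[torch.Tensor]:
    # Before the refactor:
    #   optimizer._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param))
    # After the refactor the accessor sits on the optimizer itself.
    return optimizer.get_partitioned_gradients_by_param_id(group_id, id(working_param))


def working_param_of(optimizer: LowLevelZeroOptimizer, master_param: torch.Tensor) -> torch.Tensor:
    # Likewise, master_to_working_param moved from _param_store onto the optimizer.
    return optimizer.master_to_working_param[id(master_param)]
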
@@ -316,7 +316,7 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
         dp_process_group=dp_group,
         verbose=True,
     )
-    shard_to_param = dist_optim._param_store.master_to_working_param  # {id(): param tensor} but flattened
+    shard_to_param = dist_optim.master_to_working_param  # {id(): param tensor} but flattened
     dist_optim.optim.setup_distributed(
         tp_group=tp_group,
         dp_group=dp_group,

@@ -200,7 +200,7 @@ def exam_dist_came_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
         dp_process_group=dp_group,
         verbose=True,
     )
-    shard_to_param = dist_optim._param_store.master_to_working_param  # {id(): param tensor} but flattened
+    shard_to_param = dist_optim.master_to_working_param  # {id(): param tensor} but flattened
     dist_optim.optim.setup_distributed(
         tp_group=tp_group,
         dp_group=dp_group,

@@ -229,7 +229,7 @@ def run_dist_lamb_fwd_bwd(
             dp_process_group=dp_group,
             verbose=True,
         )
-        shard_to_param = optim._param_store.master_to_working_param
+        shard_to_param = optim.master_to_working_param
         optim.optim.setup_distributed(tp_group, dp_group, shard_to_param, is_zero=True)
     else:
         optim.setup_distributed(tp_group)

@@ -32,6 +32,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,

     stage_manager = booster.plugin.stage_manager
     tp_group = booster.plugin.tp_group
+    dp_group = booster.plugin.dp_group

     bert = unwrap_model(org_model, "BertModel", "bert")
     sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")

@@ -53,8 +54,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     device = origin_norm.device
     norm_groups = []
     for group_id in range(sharded_optimizer.num_param_groups):
-        working_grads = sharded_optimizer._grad_store.get_working_grads_by_group_id(group_id)
-        norm_group = sharded_optimizer._compute_grad_norm(gradients=working_grads)
+        working_grads = sharded_optimizer.get_working_grads_by_group_id(group_id)
+        norm_group = sharded_optimizer._compute_grad_norm(dp_group, gradients=working_grads)
         norm_groups.append(norm_group)
     total_norm = 0.0
     for norm in norm_groups:

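Besides the store removal, _compute_grad_norm now receives the data-parallel group explicitly. That follows from the ZeRO layout: each rank only holds its shard of the working gradients, so the global norm has to be reduced across the dp group before taking the square root. A generic sketch of that computation (not the repo's implementation; it assumes a non-empty list of local gradient shards):

import torch
import torch.distributed as dist


def global_grad_norm_sketch(local_grads, dp_group):
    # Sum of squares over the shards held by this rank ...
    total_sq = torch.zeros(1, device=local_grads[0].device)
    for g in local_grads:
        total_sq += g.detach().float().pow(2).sum()
    # ... all-reduced across the data-parallel group before taking the root.
    dist.all_reduce(total_sq, op=dist.ReduceOp.SUM, group=dp_group)
    return total_sq.sqrt().item()
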
@@ -62,10 +62,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
     ):
         for p1, p2 in zip(command_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
-            working_p = sharded_optimizer._param_store.master_to_working_param[id(p2)]
-            grads = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p))
+            working_p = sharded_optimizer.master_to_working_param[id(p2)]
+            grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p))
             grad_index = (
-                0 if sharded_optimizer._grad_store._partition_grads else sharded_optimizer._bucket_store.zero_local_rank
+                0 if sharded_optimizer._partition_grads else sharded_optimizer.pid_to_bucket_store[id(p2)].local_rank
             )
             grad = grads[grad_index]
             sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]

@@ -62,10 +62,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
     ):
         for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
-            working_p = sharded_optimizer._param_store.master_to_working_param[id(p2)]
-            grads = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p))
+            working_p = sharded_optimizer.master_to_working_param[id(p2)]
+            grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p))
             grad_index = (
-                0 if sharded_optimizer._grad_store._partition_grads else sharded_optimizer._bucket_store.zero_local_rank
+                0
+                if sharded_optimizer._partition_grads
+                else sharded_optimizer.pid_to_bucket_store[id(working_p)].local_rank
             )
             grad = grads[grad_index]
             sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]

@@ -0,0 +1,61 @@
import pytest
import torch
import torch.nn as nn

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.zero import LowLevelZeroOptimizer


class MlpModel(nn.Module):
    def __init__(self):
        super(MlpModel, self).__init__()
        self.linear1 = nn.Linear(123, 253)

    def forward(self, x):
        x = self.linear1(x)
        return x


DEL_CALLED = False


class TestLowLevelZeroOptimizer(LowLevelZeroOptimizer):
    def __del__(self):
        super().__del__()
        global DEL_CALLED
        DEL_CALLED = True


def exam_mem_leak(world_size):
    """
    In this test, we test whether del will be called after the optimizer
    is out of scope.
    """
    # create models
    zero_model = MlpModel().cuda()

    # we only test stage 1 here
    # in `check_sharded_param_consistency.py`, we will test whether
    # level 1 and 2 will produce exactly the same results
    zero_optimizer = TestLowLevelZeroOptimizer(torch.optim.SGD(zero_model.parameters(), lr=1))

    del zero_optimizer

    assert DEL_CALLED


def run_dist(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")

    exam_mem_leak(world_size=world_size)


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_zero_1_2():
    spawn(run_dist, 2)


if __name__ == "__main__":
    test_zero_1_2()

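This new test guards against resource leaks in the ZeRO wrapper: it subclasses LowLevelZeroOptimizer, drops the instance, and asserts that __del__ actually ran once the wrapper went out of scope, so any cleanup done there (for example removing hooks registered on the parameters) is not silently skipped. A generic sketch of the register-then-release lifecycle being checked (not ColossalAI's implementation; plain torch.Tensor.register_hook is used purely for illustration):

import torch


class HookedOptimizerSketch:
    """Illustration only: register hooks on parameters, release them on deletion."""

    def __init__(self, params):
        self._hook_handles = []
        for p in params:
            if p.requires_grad:
                # register_hook returns a RemovableHandle that must be kept to undo the registration.
                self._hook_handles.append(p.register_hook(lambda grad: grad))

    def __del__(self):
        # Removing the handles detaches the hooks from the parameters; forgetting this
        # is the kind of leak the DEL_CALLED assertion above indirectly protects against.
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()
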
@@ -91,10 +91,13 @@ def exam_zero_1_2():
     zero2_optimizer.backward(zero2_output.mean().float())

     # check grad
-    z1g_list = zero1_optimizer._grad_store.get_working_grads_by_group_id(0)
-    z2g_list = zero2_optimizer._grad_store.get_working_grads_by_group_id(0)
-    for z1g, z2g in zip(z1g_list, z2g_list):
-        assert torch.equal(z1g, z2g)
+    for p1, p2 in zip(zero1_model.parameters(), zero2_model.parameters()):
+        g1 = zero1_optimizer.get_param_grad(p1)
+        g2 = zero2_optimizer.get_param_grad(p2)
+        if g1 is None or g2 is None:
+            assert g1 is None and g2 is None
+            continue
+        assert torch.allclose(g1, g2)

     # step
     zero1_optimizer.step()

@@ -102,7 +105,7 @@
     # check updated param
     for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
-        assert torch.equal(z1p.data, z2p.data)
+        assert torch.allclose(z1p, z2p)


 @parameterize("dtype", [torch.float16, torch.bfloat16])

@@ -120,7 +123,7 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
     seed_all(1453)

     # create models
-    torch_model = MlpModel().cuda()
+    torch_model = MlpModel().cuda().to(dtype)
     zero_model = copy.deepcopy(torch_model).to(dtype)

     torch_model = DDP(torch_model.cuda(), static_graph=True).cuda()

@@ -142,39 +145,41 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
     torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)

     seed_all(1453 + local_rank)
-    # create
-    input_data = torch.rand(32, 123).cuda()

-    # zero-dp forward
-    zero_output = zero_model(input_data.to(dtype))
+    for _ in range(2):
+        # create
+        input_data = torch.rand(32, 123).cuda().to(dtype)

-    # torch-ddp forward
-    torch_output = torch_model(input_data)
-    loose_close(zero_output, torch_output, dtype=dtype)
+        # zero-dp forward
+        zero_output = zero_model(input_data)

-    # zero-dp backward
-    zero_optimizer.backward(zero_output.mean().float())
+        # torch-ddp forward
+        torch_output = torch_model(input_data)
+        loose_close(zero_output, torch_output, dtype=dtype)

-    # torch-ddp backward
-    torch_output.mean().backward()
+        # zero-dp backward
+        zero_optimizer.backward(zero_output.mean())

-    # check grad
-    for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
-        if p.grad is not None:
-            zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(z1p))
-            torch_grad_list = split_ddp_grad(p.grad, world_size)
-            for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list):
-                loose_close(zero_grad, torch_grad, dtype=dtype)
+        # torch-ddp backward
+        torch_output.mean().backward()

-    # zero-dp step
-    zero_optimizer.step()
+        # check grad
+        for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
+            zero_grad = zero_optimizer.get_param_grad(z1p)
+            if p.grad is None:
+                assert zero_grad is None
+                continue
+            loose_close(p.grad, zero_grad, dtype=dtype)

-    # torch ddp step
-    torch_optimizer.step()
+        # zero-dp step
+        zero_optimizer.step()

-    # check updated param
-    for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
-        loose_close(p.data, z1p.data, dtype=dtype)
+        # torch ddp step
+        torch_optimizer.step()
+
+        # check updated param
+        for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
+            loose_close(p, z1p, dtype=dtype)


 def run_dist(rank, world_size, port):