[MoE/ZeRO] MoE refactor with ZeRO refactor (#5821)

* [moe] remove openmoe-coupled code and rectify mixtral code (#5471)

* [Feature] MoE refactor; Integration with Mixtral (#5682)

* cherry pick from refractor-moe branch

* tests passed

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* support ep + zero
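
A minimal sketch of the ep + zero idea, not code from this PR: the ZeRO optimizer receives two parameter lists keyed by process group, so MoE expert parameters are sharded over the smaller moe-dp group while every other parameter is sharded over the global dp group (the same pg_param_list construction appears in the MoeHybridParallelPlugin diff further below; the process-group handles here are placeholders).

from colossalai.tensor.moe_tensor.api import is_moe_tensor

def build_pg_param_list(model, dp_process_group, moe_dp_process_group):
    # MoE expert weights go to the (smaller) moe-dp group,
    # all remaining parameters to the global dp group.
    pg_param_list = {dp_process_group: [], moe_dp_process_group: []}
    for param in model.parameters():
        key = moe_dp_process_group if is_moe_tensor(param) else dp_process_group
        pg_param_list[key].append(param)
    return pg_param_list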

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* add mixtral auto policy & move pipeline forward code to modeling folder

* [moe refactor] modify kernel test without Route Class

* [moe refactor] add moe tensor test path environment variable to github workflow
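
For context, a hedged illustration (not code from this PR) of how a MoE tensor test can consume this variable; the default mirrors the /data/scratch/moe_tensors value added to the workflows below, and the lookup itself is illustrative.

import os

# The CI workflows export MOE_TENSOR_PATH (see the workflow diffs below);
# fall back to the same scratch location when running locally.
MOE_TENSOR_PATH = os.environ.get("MOE_TENSOR_PATH", "/data/scratch/moe_tensors")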

* fix typos

* fix moe test bug due to the code rebase

* [moe refactor] fix moe zero test, and little bug in low level zero

* fix typo

* add moe tensor path to github workflow

* remove some useless code

* fix typo & unify global variable XX_AXIS logic without using -1

* fix typo & prettify the code

* remove print statements & support ZeRO-2 test

* remove useless code

* rename function

* fix typo

* fix typo

* Further improve the test code

* remove print code

* [moe refactor] change test model from fake moe model to mixtral moe layer and remove useless test

* [moe refactor] skip some unit test which will be refactored later

* [moe refactor] fix unit import error

* [moe refactor] fix circular import issues

* [moe refactor] remove debug code

* [moe refactor] update github workflow

* [moe/zero] refactor low level optimizer (#5767)

* [zero] refactor low level optimizer

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] MoE refactor with newest version of ZeRO (#5801)

* [zero] remove redundant members in BucketStore (#5802)

* [zero] align api with previous version

* [Moe/Zero] Update MoeHybridParallelPlugin with refactored ZeRO and Fix Zero bug (#5819)

* [moe refactor] update unit test with the refactored ZeRO and remove useless test

* move moe checkpoint to checkpoint folder and exchange global axis to class member

* update moe hybrid parallel plugin with newest version of zero & fix zero working/master params bug

* fix zero unit test

* Add an assertion to prevent users from using it incorrectly

* [hotfix]Solve the compatibility issue of zero refactor (#5823)

* [moe refactor] update unit test with the refactored ZeRO and remove useless test

* move moe checkpoint to checkpoint folder and exchange global axis to class member

* update moe hybrid parallel plugin with newest version of zero & fix zero working/master params bug

* fix zero unit test

* Add an assertion to prevent users from using it incorrectly

* Modify function parameter names to resolve compatibility issues

* [zero] fix missing hook removal (#5824)

* [MoE] Resolve .github conflict (#5829)

* [Fix/Example] Fix Llama Inference Loading Data Type (#5763)

* [fix/example] fix llama inference loading dtype

* revise loading dtype of benchmark llama3

* [release] update version (#5752)

* [release] update version

* [devops] update compatibility test

* [devops] update compatibility test

* [devops] update compatibility test

* [devops] update compatibility test

* [test] fix ddp plugin test

* [test] fix gptj and rpc test

* [devops] fix cuda ext compatibility

* [inference] fix flash decoding test

* [inference] fix flash decoding test

* fix (#5765)

* [test] Fix/fix testcase (#5770)

* [fix] branch for fix testcase;

* [fix] fix test_analyzer & test_auto_parallel;

* [fix] remove local change about moe;

* [fix] rm local change moe;

* [Hotfix] Add missing init file in inference.executor (#5774)

* [CI/tests] simplify some test case to reduce testing time (#5755)

* [ci/tests] simplify some test case to reduce testing time

* [ci/tests] continue to remove test case to reduce ci time cost

* restore some test config

* [ci/tests] continue to reduce ci time cost

* [misc] update dockerfile (#5776)

* [misc] update dockerfile

* [misc] update dockerfile

* [devops] fix docker ci (#5780)

* [Inference]Add Streaming LLM (#5745)

* Add Streaming LLM

* add some parameters to llama_generation.py

* verify streamingllm config

* add test_streamingllm.py

* modified according to review comments

* add Citation

* change _block_tables to use tolist()

* [hotfix] fix llama flash attention forward (#5777)

* [misc] Accelerate CI for zero and dist optim (#5758)

* remove fp16 from lamb

* remove d2h copy in checking states

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Test/CI] remove test cases to reduce CI duration (#5753)

* [test] smaller gpt2 test case

* [test] reduce test cases: tests/test_zero/test_gemini/test_zeroddp_state_dict.py

* [test] reduce test cases: tests/test_zero/test_gemini/test_grad_accum.py

* [test] reduce test cases tests/test_zero/test_gemini/test_optim.py

* Revert "[test] smaller gpt2 test case"

Some tests might depend on the size of the model (number of chunks)

This reverts commit df705a5210.

* [test] reduce test cases: tests/test_checkpoint_io/test_gemini_checkpoint_io.py

* [CI] smaller test model for the two modified cases

* [CI] hardcode gpt model for tests/test_zero/test_gemini/test_search.py since we need a fixed answer there

* [hotfix] fix testcase in test_fx/test_tracer (#5779)

* [fix] branch for fix testcase;

* [fix] fix test_analyzer & test_auto_parallel;

* [fix] remove local change about moe;

* [fix] rm local change moe;

* [fix] fix test_deepfm_model & test_dlrf_model;

* [fix] fix test_hf_albert & test_hf_gpt;

* [gemini] optimize reduce scatter d2h copy (#5760)

* [gemini] optimize reduce scatter d2h copy

* [fix] fix missing reduce variable

* [refactor] remove legacy async reduce scatter code

* [gemini] missing sync

* Revert "[refactor] remove legacy async reduce scatter code"

This reverts commit 58ad76d466.

* [gemini] further optimize with async all reduce

* [fix] pass flag from manager to chunk

* Allow building cuda extension without a device. (#5535)

Added FORCE_CUDA environment variable support to enable building extensions where a GPU device is not present but the CUDA libraries are.
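
A hedged sketch of the mechanism, not the repository's actual setup.py: the environment variable overrides the runtime GPU check so the CUDA extension is still compiled when only the CUDA toolkit is installed.

import os

import torch

def should_build_cuda_ext() -> bool:
    # FORCE_CUDA=1 skips the runtime device check, which is useful in
    # docker/CI image builds where the CUDA libraries are present but
    # no GPU is attached at build time.
    if os.environ.get("FORCE_CUDA", "0") == "1":
        return True
    return torch.cuda.is_available()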

* [misc] fix dist logger (#5782)

* [install]fix setup (#5786)

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [misc] update requirements (#5787)

* [shardformer] fix import (#5788)

* upgrade colossal-chat to support tp_group > 1, add sp for sft

* upgrade ppo dpo rm script

* run pre-commit

* update ci tests; sft ci test cases passed, tp failed in generation for ppo, sp is buggy

* fix training script

* fix ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix transformers version

* remove duplicated test

* fix datasets version

* remove models that require huggingface auth from ci

* remove local data path

* update ci

* remove baichuan from template test due to transformer version conflict

* merge

* Refactor modeling by adding attention backend

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Fix tests and naming

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Pass inference model shard configs for module init

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Clean up

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* replace the customized dataloader setup with the built-in one

* replace the customized dataloader setup with the built-in one

* Remove flash attention backend

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* fix readme

* Fix test import

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* update sft training script

* [Inference]refactor baichuan (#5791)

* refactor baichuan

* remove unused code and add TODO for lazyinit

* [test] fix chatglm test kit (#5793)

* [shardformer] fix modeling of bloom and falcon (#5796)

* [test] fix qwen2 pytest distLarge (#5797)

* [Inference] Fix flash-attn import and add model test (#5794)

* Fix torch int32 dtype

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Fix flash-attn import

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Add generalized model test

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Remove exposed path to model

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Add default value for use_flash_attn

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Rename model test

Signed-off-by: char-1ee <xingjianli59@gmail.com>

---------

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* [Gemini] Use async stream to prefetch and h2d data moving (#5781)

* use async stream to prefetch and h2d data moving

* Remove redundant code

* [gemini] quick fix on possible async operation (#5803)

* [gemini] quick fix on possible async operation

* [gemini] quick fix on possible async operation

* [shardformer] upgrade transformers to 4.39.3 (#5815)

* [shardformer]upgrade transformers for gpt2/gptj/whisper (#5807)

* [shardformer] fix modeling of gpt2 and gptj

* [shardformer] fix whisper modeling

* [misc] update requirements

---------

Co-authored-by: ver217 <lhx0217@gmail.com>

* [shardformer]upgrade transformers for mistral (#5808)

* upgrade transformers for mistral

* fix

* fix

* [shardformer]upgrade transformers for llama (#5809)

* update transformers

fix

* fix

* fix

* [inference] upgrade transformers (#5810)

* update transformers

fix

* fix

* fix

* fix

* fix

* [gemini] update transformers for gemini (#5814)

---------

Co-authored-by: ver217 <lhx0217@gmail.com>

* Support 4d parallel + flash attention (#5789)

* support tp + sp + pp

* remove comments

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

---------

Signed-off-by: char-1ee <xingjianli59@gmail.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: flybird11111 <1829166702@qq.com>
Co-authored-by: duanjunwen <935724073@qq.com>
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: botbw <wang1570@e.ntu.edu.sg>
Co-authored-by: Charles Coulombe <ccoulombe@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: char-1ee <xingjianli59@gmail.com>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>

* [zero] fix hook bug

* [zero] add low level optimizer back (#5839)

* [zero] fix param & refactor

* [zero] add back original low level opt

* [zero] remove moe related

* [zero] pass zero tests

* [zero] refactor

* [chore] add del func back

* [zero] comments and naming (#5840)

* [zero] modify api (#5843)

* [zero] modify api

* [test] remove _grad_store access in tests

* [test] fix (#5857)

* [CI] skip openmoe CI check

* [CI] fix pre-commit

* [zero] remove redundant member init (#5862)

* [misc] remove useless code, modify the pg mesh implementation

* [misc] remove useless code, modify the pg mesh implementation

* [misc] use tempfile

* resolve conflict with main branch

* [misc] use tempfile in test_moe_checkpoint.py
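
A minimal sketch of the pattern, with the booster and model objects passed in as in the checkpoint test shown in the diff below; the helper name is illustrative:

import tempfile

def roundtrip_checkpoint(booster, model, new_model):
    # Save and reload inside a scratch directory that is cleaned up automatically,
    # instead of writing fixed paths such as "mixtral_model" into the working tree.
    with tempfile.TemporaryDirectory() as tmpdir:
        booster.save_model(model, tmpdir, shard=True)
        booster.load_model(new_model, tmpdir)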

* [misc] remove useless code, add assertion about sequence parallel, move logger into function

* [misc] remove useless code

---------

Signed-off-by: char-1ee <xingjianli59@gmail.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: botbw <wang1570@e.ntu.edu.sg>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: flybird11111 <1829166702@qq.com>
Co-authored-by: duanjunwen <935724073@qq.com>
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Charles Coulombe <ccoulombe@users.noreply.github.com>
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: char-1ee <xingjianli59@gmail.com>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
pull/5874/head
Haze188 2024-06-28 14:00:08 +08:00 committed by GitHub
parent 773d9f964a
commit 416580b314
69 changed files with 1780 additions and 3076 deletions


@ -90,7 +90,7 @@ jobs:
runs-on: [self-hosted, gpu]
container:
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
options: --gpus all --rm -v /dev/shm -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
options: --gpus all --rm -v /dev/shm -v /data/scratch:/data/scratch
timeout-minutes: 90
defaults:
run:
@ -165,6 +165,7 @@ jobs:
env:
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors
- name: Collate artifact
env:


@ -13,7 +13,7 @@ jobs:
runs-on: [self-hosted, gpu]
container:
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
options: --gpus all --rm -v /dev/shm -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/
timeout-minutes: 90
steps:
- name: Check GPU Availability # ensure all GPUs have enough memory
@ -69,6 +69,7 @@ jobs:
env:
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors
- name: Notify Lark
id: message-preparation


@ -50,7 +50,7 @@ jobs:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
options: --gpus all --rm -v /dev/shm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/
timeout-minutes: 200
steps:
- name: Install dependencies
@ -92,3 +92,4 @@ jobs:
DATA: /data/scratch/cifar-10
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors


@ -41,7 +41,7 @@ jobs:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
options: --gpus all --rm -v /dev/shm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/
timeout-minutes: 200
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-test-${{ matrix.container }}
@ -87,3 +87,4 @@ jobs:
DATA: /data/scratch/cifar-10
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors


@ -38,7 +38,7 @@ jobs:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
options: --gpus all --rm -v /dev/shm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/
timeout-minutes: 200
steps:
- name: Install dependencies
@ -85,6 +85,7 @@ jobs:
DATA: /data/scratch/cifar-10
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors
- name: Notify Lark
id: message-preparation


@ -1,92 +0,0 @@
import torch
import torch.distributed as dist
import torch.nn.functional as F
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
from colossalai.lazy import LazyInitContext
from colossalai.moe import MOE_MANAGER
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.moe_tensor.api import set_moe_tensor_info
class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
def __init__(self, config):
super().__init__(config)
self.setup_ep()
def setup_ep(self):
_, moe_info = MOE_MANAGER.get_info(self.num_experts)
ep_group = moe_info.ep_group
self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
assert self.num_experts % self.ep_size == 0
self.ep_group = ep_group
self.num_experts_per_ep = self.num_experts // self.ep_size
self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
set_tensors_to_none(self.experts, exclude=set(held_experts))
for p in self.experts.parameters():
set_moe_tensor_info(p, moe_info)
@staticmethod
def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock":
LazyInitContext.materialize(module)
module.__class__ = EPMixtralSparseMoeBlock
module.setup_ep()
return module
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
selected_experts = selected_experts.t().reshape(-1)
selected_experts_idx = selected_experts.argsort()
dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
output_split_sizes = torch.zeros_like(input_split_sizes)
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
# compute expert output
output_states = MoeInGradScaler.apply(output_states, self.ep_size)
if output_states.size(0) > 0:
if self.num_experts_per_ep == 1:
# no need to split
expert = self.experts[self.expert_start_idx]
output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
output_states = expert.w2(output_states)
else:
output_states_splits = output_states.split(output_split_sizes.tolist())
output_states_list = []
for i, split_states in enumerate(output_states_splits):
if split_states.size(0) == 0:
continue
expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
split_states = expert.w2(split_states)
output_states_list.append(split_states)
output_states = torch.cat(output_states_list)
output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
recover_experts_idx = torch.empty_like(selected_experts_idx)
recover_experts_idx[selected_experts_idx] = torch.arange(
selected_experts_idx.size(0), device=selected_experts_idx.device
)
dispatch_states = dispatch_states[recover_experts_idx]
k_hidden_states = dispatch_states.chunk(self.top_k)
output_states = k_hidden_states[0] * routing_weights[:, 0, None]
for i in range(1, self.top_k):
output_states += k_hidden_states[i] * routing_weights[:, i, None]
output_states = output_states.reshape(batch_size, sequence_length, hidden_dim)
return output_states, router_logits


@ -2,8 +2,6 @@ import argparse
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
@ -70,8 +68,6 @@ def main():
ep_size=ep_size,
zero_stage=1,
precision=args.precision,
custom_policy=MixtralForCausalLMPolicy(),
checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
enable_fused_normalization=args.use_layernorm_kernel,
enable_jit_fused=args.use_kernel,
)


@ -1,5 +1,6 @@
NUM_GPU=2
MODEL="mistralai/Mixtral-8x7B-v0.1"
# MODEL="mistralai/Mixtral-8x7B-v0.1"
MODEL="mistralai/Mixtral-8x7B-Instruct-v0.1"
# ep
torchrun --standalone --nproc_per_node $NUM_GPU infer.py \


@ -1,146 +0,0 @@
from copy import deepcopy
import pytest
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from torch.optim import Adam
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing.utils import spawn
tokens, n_experts = 7, 4
hidden_size = 8
top_k = 2
def check_model_equal(model1, model2):
assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
for p1, p2 in zip(model1.parameters(), model2.parameters()):
assert torch.equal(p1.half(), p2.half())
def get_optimizer_snapshot(optim):
state = {id(k): deepcopy(v) for k, v in optim.state.items()}
param_groups = []
for group in optim.param_groups:
params = [id(p) for p in group["params"]]
new_group = {"params": params}
for k, v in group.items():
if k != "params":
new_group[k] = v
param_groups.append(new_group)
return {
"state": state,
"param_groups": param_groups,
}
def check_optimizer_snapshot_equal(snapshot1, snapshot2):
# check param_groups
assert len(snapshot1["param_groups"]) == len(snapshot2["param_groups"])
for group1, group2 in zip(snapshot1["param_groups"], snapshot2["param_groups"]):
assert set(group1.keys()) == set(group2.keys())
for k in group1.keys():
assert group1[k] == group2[k]
# check state
assert set(snapshot1["state"].keys()) == set(
snapshot2["state"].keys()
), f"{snapshot1['state'].keys()}, {snapshot2['state'].keys()}"
for pid in snapshot1["state"].keys():
state1, state2 = snapshot1["state"][pid], snapshot2["state"][pid]
assert set(state1.keys()) == set(state2.keys())
for k in state1.keys():
if isinstance(state1[k], torch.Tensor):
assert torch.equal(state1[k], state2[k]), f"{k}, {state1[k]}, {state2[k]}"
else:
assert state1[k] == state2[k]
def check_mixtral_moe_layer():
torch.cuda.set_device(dist.get_rank())
config = MixtralConfig(
hidden_size=hidden_size,
intermediate_size=hidden_size * 2,
num_local_experts=n_experts,
num_experts_per_tok=top_k,
num_attention_heads=2,
num_key_value_heads=2,
)
torch.manual_seed(0)
input_ids = torch.randint(0, 100, (2, tokens)).cuda()
orig_model = MixtralForCausalLM(config).cuda()
model = deepcopy(orig_model)
optimizer = Adam(model.parameters(), lr=1e-3)
plugin = MoeHybridParallelPlugin(
tp_size=1,
pp_size=2,
ep_size=2,
custom_policy=MixtralForCausalLMPolicy(),
checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
microbatch_size=1,
zero_stage=1,
)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer)
# initialize grads
data_iter = iter(
[{"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "labels": input_ids.clone()}]
)
booster.execute_pipeline(
data_iter,
model,
lambda outputs, inputs: outputs.loss,
optimizer,
)
# check save model
booster.save_model(model, "mixtral_model", shard=True)
dist.barrier()
if dist.get_rank() == 0:
saved_model = MixtralForCausalLM.from_pretrained("mixtral_model").cuda()
check_model_equal(orig_model, saved_model)
saved_model.save_pretrained("mixtral_hf_model")
dist.barrier()
# check load model
new_model = MixtralForCausalLM(config).cuda()
new_optimizer = Adam(new_model.parameters(), lr=1e-3)
new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
booster.load_model(new_model, "mixtral_hf_model")
check_model_equal(model, new_model)
# check save optimizer
optimizer.step()
for group in optimizer.param_groups:
group["lr"] = 0.1
snapshot = get_optimizer_snapshot(optimizer.unwrap())
booster.save_optimizer(optimizer, "mixtral_optim", shard=True)
dist.barrier()
# reset optimizer state
for state in optimizer.unwrap().state.values():
for v in state.values():
if isinstance(v, torch.Tensor):
v.zero_()
booster.load_optimizer(optimizer, "mixtral_optim")
loaded_snapshot = get_optimizer_snapshot(optimizer.unwrap())
check_optimizer_snapshot_equal(snapshot, loaded_snapshot)
def run_dist(rank: int, world_size: int, port: int):
colossalai.launch(rank, world_size, "localhost", port)
check_mixtral_moe_layer()
@pytest.mark.parametrize("world_size", [4])
def test_mixtral_moe_layer(world_size: int):
spawn(run_dist, world_size)
if __name__ == "__main__":
test_mixtral_moe_layer(4)


@ -2,13 +2,11 @@ import argparse
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from colossal_moe.utils import load_checkpoint, move_to_cuda, save_checkpoint
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralForCausalLM
from utils import load_checkpoint, move_to_cuda, save_checkpoint
import colossalai
from colossalai.booster import Booster
@ -155,12 +153,10 @@ def main():
pp_size=args.pp_size,
ep_size=args.ep_size,
microbatch_size=args.microbatch_size,
custom_policy=MixtralForCausalLMPolicy(),
enable_fused_normalization=args.use_layernorm_kernel,
enable_jit_fused=args.use_kernel,
precision=args.precision,
zero_stage=args.zero_stage,
checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
)
else:


@ -20,6 +20,7 @@ resp = llm(prompt='What do you call a three-ton kangaroo?', **gen_config)
print(resp) # super-heavyweight awesome-natured yawning Australian creature!
"""
import json
from typing import Any, Mapping


@ -655,7 +655,6 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
self.param_info = param_info
self.stage_manager = model.stage_manager
self.shared_params = model.shared_params
self.dp_pg = dp_process_group
self.tp_pg = tp_process_group
self.pp_pg = pp_process_group
if use_pipeline:
@ -718,7 +717,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
"""Retrieve all working gradients from different parameter groups."""
all_working_grads = []
for group_id in range(self.num_param_groups):
working_grads = self._grad_store.get_working_grads_by_group_id(group_id)
working_grads = self.get_working_grads_by_group_id(group_id)
all_working_grads.extend(working_grads)
return all_working_grads
@ -726,7 +725,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
"""Identify gradients to be synchronized in the sequence parallelism."""
grads_to_sync = []
for grad in all_working_grads:
param_id_for_grad = self._grad_store.get_param_id_for_grad(grad)
param_id_for_grad = self.get_param_id_for_grad(grad)
param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value
if SeqParallelUtils.is_sp_partial_derived_param(param_for_grad):
grads_to_sync.append(grad)
@ -739,7 +738,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
# Get all working gradients and gradients to be synchronized.
all_working_grads = _get_all_working_grads()
grads_to_sync = _get_grads_to_sync(all_working_grads)
if self._grad_store.require_grad_sync and grads_to_sync is not None:
if self.require_grad_sync and grads_to_sync is not None:
# Synchronize sequence parallelism gradients if required.
SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync)
else:
@ -763,7 +762,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
# Call the superclass backward method to compute gradients.
super().backward(loss, retain_graph)
if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
# If gradient synchronization is required, sync sequence parallelism gradients.
self._sync_sp_grads()
else:
@ -788,14 +787,14 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
# Call the superclass backward_by_grad method to compute gradients.
super().backward_by_grad(tensor, grad)
if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
# If gradient synchronization is required, sync sequence parallelism gradients.
self._sync_sp_grads()
else:
# If gradient synchronization is not required, return.
return
def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
def _compute_grad_norm(self, dp_pg, gradients: List[Tensor], norm_type: int = 2) -> float:
r"""
Compute and return the gradient norm for gradient clipping.
@ -811,7 +810,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
if len(gradients) == 0:
return 0.0
dp_size = get_world_size(self.dp_pg) if self.dp_pg is not None else 1
dp_size = get_world_size(dp_pg) if dp_pg is not None else 1
tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
norm_type = float(norm_type)
@ -842,7 +841,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
# However, we still perform the 'all_reduce' operation for the sake of good coding practices.
# To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.'
if tp_size > 1:
param_id_for_grad = self._grad_store.get_param_id_for_grad(grad)
param_id_for_grad = self.get_param_id_for_grad(grad)
param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value
if not is_distributed_tensor(param_for_grad):
@ -856,7 +855,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
for shared_param in self.shared_params:
if self.stage_manager.stage in shared_param:
stage_shared_param = shared_param[self.stage_manager.stage]
working_grad = self._grad_store.get_working_grad_by_param_id(id(stage_shared_param))
working_grad = self.get_working_grad_by_param_id(id(stage_shared_param))
if grad is working_grad:
grad_norm_exponentiated /= len(shared_param)
@ -867,7 +866,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
)
if dp_size > 1:
# compute norm in dp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg)
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=dp_pg)
if tp_size > 1:
# compute norm in tp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
@ -1309,7 +1308,7 @@ class HybridParallelPlugin(PipelinePluginBase):
# run with gradients accumulation
if model.require_grad_sync == False or (
isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer._grad_store.require_grad_sync == False
isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
):
return outputs


@ -1,4 +1,5 @@
import random
import warnings
from types import MethodType
from typing import Callable, Optional, OrderedDict, Tuple
@ -20,19 +21,19 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
get_param_info,
init_pipeline_optimizer,
)
from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.moe import MOE_MANAGER, MoECheckpointIO
from colossalai.logging import get_dist_logger
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.zero.low_level import LowLevelZeroOptimizer
PP_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2
class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
def __init__(
self,
optimizer: Optimizer,
@ -67,8 +68,20 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
self.pp_pg = pp_process_group
if use_pipeline:
init_pipeline_optimizer(optimizer, model)
pg_param_list = {
dp_process_group: [],
moe_extra_dp_process_group: [],
}
for param in model.parameters():
if is_moe_tensor(param):
pg_param_list[moe_extra_dp_process_group].append(param)
else:
pg_param_list[dp_process_group].append(param)
super().__init__(
optimizer=optimizer,
pg_to_param_list=pg_param_list,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
@ -83,9 +96,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
overlap_communication=overlap_communication,
partition_grad=partition_grad,
cpu_offload=cpu_offload,
dp_process_group=dp_process_group,
forced_dtype=forced_dtype,
moe_extra_dp_process_group=moe_extra_dp_process_group,
)
@ -107,8 +118,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
>>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
Args:
tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
precision (str, optional): Specifies the precision of parameters during training.
Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'.
Defaults to 'fp16'.
@ -144,14 +155,15 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
use_ep_inside (bool, Optional): Whether to use ep inside dp (intra-node) for moe params.
"""
def __init__(
self,
tp_size: int,
pp_size: int,
ep_size: int,
extra_dp_size: int = 1,
tp_size: int = 1,
sp_size: int = 1,
precision: str = "fp16",
zero_stage: int = 0,
enable_all_optimization: bool = False,
@ -184,32 +196,22 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
custom_policy: Policy = None,
checkpoint_io: Optional[MoECheckpointIO] = None,
) -> None:
assert (
dist.get_world_size() % (tp_size * pp_size) == 0
), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
world_size = dist.get_world_size()
assert tp_size == 1, "Tensor parallel is not supported in MoE yet"
assert sp_size == 1 and enable_sequence_parallelism is False, "Sequence parallelism is not supported in MoE yet"
if enable_sequence_parallelism:
assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism"
assert (
dist.get_world_size() % (tp_size * pp_size) == 0
), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
world_size % (tp_size * pp_size) == 0
), f"world size {world_size} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
assert (
dist.get_world_size() % (tp_size * pp_size * ep_size) == 0
), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}"
self.real_dp_size = dist.get_world_size() // (tp_size * pp_size * ep_size)
MOE_MANAGER.setup(
parallel="EP",
mode="fixed",
fixed_dp_size=self.real_dp_size,
fixed_ep_size=ep_size,
fixed_pp_size=pp_size,
use_ep_inside=use_ep_inside,
)
world_size % (tp_size * pp_size * ep_size) == 0
), f"world size {world_size} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}"
self.dp_size = world_size // (tp_size * pp_size)
self.tp_size = tp_size
self.pp_size = pp_size
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
self.ep_size = ep_size
self.moe_info = MOE_MANAGER.get_info(0)[1]
self.sp_size = sp_size
self.precision = precision
self.zero_stage = zero_stage
self.cpu_offload = cpu_offload
@ -219,43 +221,57 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
self.checkpoint_io = checkpoint_io
logger = get_dist_logger()
# NOTE: Two process meshes: global dp for non-moe param; dp + ep for moe param
# See https://hpc-ai.com/blog/enhanced-moe-parallelism-open-source-moe-model-training-can-be-9-times-more-efficient
# we change pg mesh to (pp, dp, tp) for better moe performance
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size)
assert (
self.ep_size <= self.dp_size
), f"Not enough devices({self.dp_size}) for expert parallelism size({self.ep_size})."
# sync moe in outer dp group, and sync other param in global dp group
if extra_dp_size > 1:
ep_size = self.dp_size // extra_dp_size
if use_ep_inside:
self.pg_mesh_moe = ProcessGroupMesh(self.pp_size, extra_dp_size, ep_size)
self.moe_extra_dp_group = self.pg_mesh_moe.get_group_along_axis(1)
if dist.get_rank() == 0:
print(f"Zero Parallel: pp {self.pp_size}, outer_dp {extra_dp_size}, inner_dp {ep_size}")
self.moe_dp_size = self.dp_size // self.ep_size
self.use_ep_inside = use_ep_inside
if self.use_ep_inside:
logger.info(f"MoE Parallel use ep inside dp.", ranks=[0])
self.pp_axis, self.dp_axis, self.ep_axis, self.tp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, ep_size, tp_size)
else:
self.pg_mesh_moe = ProcessGroupMesh(self.pp_size, ep_size, extra_dp_size)
self.moe_extra_dp_group = self.pg_mesh_moe.get_group_along_axis(2)
if dist.get_rank() == 0:
print(f"Zero Parallel: pp {self.pp_size}, outer_dp {ep_size}, inner_dp {extra_dp_size}")
else:
self.moe_extra_dp_group = None
logger.info(f"MoE Parallel use ep outside dp.", ranks=[0])
warnings.warn("Using ep outside dp (cross-node) is strongly discouraged due to communication costs.")
self.pp_axis, self.dp_axis, self.ep_axis, self.tp_axis = 0, 2, 1, 3
self.pg_mesh = ProcessGroupMesh(self.pp_size, ep_size, self.moe_dp_size, tp_size)
self.moe_dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
self.ep_group = self.pg_mesh.get_group_along_axis(self.ep_axis)
logger.info(f"Non-MoE Parameter Parallel: pp {self.pp_size}, dp {self.dp_size}, tp {tp_size}", ranks=[0])
logger.info(
f"MoE Parallel: pp {self.pp_size}, ep {ep_size}, moe dp {self.moe_dp_size}, tp {tp_size}", ranks=[0]
)
self.tp_group = self.pg_mesh.get_group_along_axis(
self.tp_axis
) # TODO: support custom tp size for mixtral lm head
self.global_dp_group = self.pg_mesh.get_group_along_axis((self.dp_axis, self.ep_axis))
self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis)
# TODO: Currently moe only support partially sequence parallel
self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
self.custom_policy = custom_policy
self.stage_manager = None
self.schedule = None
self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
assert (
num_microbatches is not None or microbatch_size is not None
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
self.stage_manager = PipelineStageManager(self.pg_mesh, self.pp_axis)
self.schedule = OneForwardOneBackwardSchedule(
self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
)
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
# TODO: Currently moe only support partially sequence parallel
self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
@ -267,6 +283,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism,
enable_sequence_overlap=enable_sequence_overlap,
ep_group=self.ep_group,
)
self.amp_config = dict(
initial_scale=initial_scale,
@ -323,7 +340,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
"""
_kwargs = kwargs.copy()
sampler = DistributedSampler(
dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
dataset,
num_replicas=self.dp_size,
rank=dist.get_rank(self.global_dp_group),
shuffle=shuffle,
)
# Deterministic dataloader
@ -346,9 +366,20 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
def get_checkpoint_io(self) -> MoECheckpointIO:
if self.checkpoint_io is None:
self.checkpoint_io = MoECheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
self.checkpoint_io = MoECheckpointIO(
self.global_dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
)
else:
self.checkpoint_io = self.checkpoint_io(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
self.checkpoint_io = self.checkpoint_io(
self.global_dp_group,
self.pp_group,
self.tp_group,
ep_group=self.ep_group,
moe_dp_group=self.moe_dp_group,
zero_stage=self.zero_stage,
)
if hasattr(self.checkpoint_io, "moe_info"):
self.checkpoint_io.moe_info = self.moe_info
return self.checkpoint_io
def configure(
@ -366,7 +397,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
module=model,
precision=self.precision,
shard_config=self.shard_config,
dp_group=self.dp_group,
dp_group=self.global_dp_group,
tp_group=self.tp_group,
sp_group=self.sp_group,
use_ddp=use_ddp,
@ -392,15 +423,15 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
else:
assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
optimizer = HybridParallelZeroOptimizer(
optimizer = MoeHybridParallelZeroOptimizer(
optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
dp_process_group=self.dp_group,
dp_process_group=self.global_dp_group,
tp_process_group=self.tp_group,
pp_process_group=self.pp_group,
moe_extra_dp_process_group=self.moe_extra_dp_group,
moe_extra_dp_process_group=self.moe_dp_group,
verbose=True,
clip_grad_norm=self.max_norm,
**self.zero_config,


@ -2,5 +2,12 @@ from .checkpoint_io_base import CheckpointIO
from .general_checkpoint_io import GeneralCheckpointIO
from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
from .index_file import CheckpointIndexFile
from .moe_checkpoint import MoECheckpointIO
__all__ = ["CheckpointIO", "CheckpointIndexFile", "GeneralCheckpointIO", "HybridParallelCheckpointIO"]
__all__ = [
"CheckpointIO",
"CheckpointIndexFile",
"GeneralCheckpointIO",
"HybridParallelCheckpointIO",
"MoECheckpointIO",
]


@ -70,13 +70,13 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
verbose: bool = True,
) -> None:
super().__init__()
self.dp_group = dp_group
self.global_dp_group = dp_group
self.pp_group = pp_group
self.tp_group = tp_group
self.dp_rank = dist.get_rank(self.dp_group)
self.dp_rank = dist.get_rank(self.global_dp_group)
self.tp_rank = dist.get_rank(self.tp_group)
self.pp_rank = dist.get_rank(self.pp_group)
self.dp_size = dist.get_world_size(dp_group)
self.global_dp_size = dist.get_world_size(dp_group)
self.pp_size = dist.get_world_size(pp_group)
self.tp_size = dist.get_world_size(tp_group)
self.use_zero = zero_stage > 0
@ -433,7 +433,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state_dict_shard = HybridParallelCheckpointIO._optimizer_sharder(
optimizer,
use_zero=self.use_zero,
dp_group=self.dp_group,
dp_group=self.global_dp_group,
tp_group=self.tp_group,
size_per_shard=size_per_shard,
)
@ -727,7 +727,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state,
working_param,
original_shape=original_shape,
dp_group=self.dp_group,
dp_group=self.global_dp_group,
tp_group=self.tp_group,
use_zero=self.use_zero,
inplace=False,
@ -932,12 +932,12 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# Shard state along data parallel group when using Zero.
if self.use_zero:
padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
padding_size = (self.global_dp_size - v.numel() % self.global_dp_size) % self.global_dp_size
with torch.no_grad():
v = v.flatten()
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
slice_size = v.numel() // self.dp_size
slice_size = v.numel() // self.global_dp_size
v = v.split(slice_size, dim=0)[self.dp_rank]
state_[k] = v.detach().clone().to(device)


@ -9,6 +9,7 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import get_global_rank
from colossalai.checkpoint_io import CheckpointIndexFile
from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
@ -19,15 +20,16 @@ from colossalai.checkpoint_io.utils import (
get_model_base_filenames,
get_optimizer_base_filenames,
load_shard_state_dict,
load_state_dict,
load_states_into_optimizer,
save_config_file,
save_param_groups,
save_state_dict,
save_state_dict_shards,
search_tp_partition_dim,
sharded_optimizer_loading_epilogue,
)
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.moe import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import is_moe_tensor
try:
@ -36,21 +38,30 @@ except ImportError:
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
class MoECheckpointIO(HybridParallelCheckpointIO):
def __init__(
self,
dp_group: ProcessGroup,
global_dp_group: ProcessGroup,
pp_group: ProcessGroup,
tp_group: ProcessGroup,
ep_group: ProcessGroup,
moe_dp_group: ProcessGroup,
zero_stage: int,
verbose: bool = True,
) -> None:
super().__init__(dp_group, pp_group, tp_group, zero_stage, verbose)
moe_info = MOE_MANAGER.parallel_info_dict[MOE_MANAGER.ep_size]
self.ep_group = moe_info.ep_group
self.ep_size = moe_info.ep_size
self.ep_rank = moe_info.ep_rank
self.real_dp_rank = moe_info.dp_rank
super().__init__(global_dp_group, pp_group, tp_group, zero_stage, verbose)
self.global_dp_group = global_dp_group
self.global_dp_rank = dist.get_rank(global_dp_group)
self.global_dp_size = dist.get_world_size(global_dp_group)
self.pp_group = pp_group
self.tp_group = tp_group
self.moe_dp_group = moe_dp_group
self.moe_dp_size = dist.get_world_size(moe_dp_group)
self.moe_dp_rank = dist.get_rank(moe_dp_group)
self.ep_group = ep_group
self.ep_size = dist.get_world_size(ep_group)
self.ep_rank = dist.get_rank(ep_group)
@staticmethod
def _model_sharder(
@ -134,7 +145,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
Path(checkpoint).mkdir(parents=True, exist_ok=True)
if self.real_dp_rank != 0:
if self.moe_dp_rank != 0:
dist.barrier()
return
@ -144,7 +155,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
# Then collect the sharded parameters & buffers along tp_group.
# Only devices with tp_rank == 0 are responsible for model saving.
state_dict_shard = MixtralMoEHybridParallelCheckpointIO._model_sharder(
state_dict_shard = MoECheckpointIO._model_sharder(
model, size_per_shard=size_per_shard, param_name_pattern=ep_param_pattern
)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
@ -234,11 +245,12 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
state: OrderedDict,
param: torch.Tensor,
original_shape: torch.Size,
dp_group: ProcessGroup,
global_dp_group: ProcessGroup,
tp_group: ProcessGroup,
use_zero: bool,
inplace: bool,
is_moe_param: bool,
moe_dp_group: ProcessGroup = None,
device: torch.device = torch.device("cpu"),
) -> OrderedDict:
"""
@ -248,7 +260,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero.
param (torch.Tensor): The given parameter. It should be working_param when using Zero.
original_shape (torch.Size): The size of parameter before sharding.
dp_group (ProcessGroup): The process group of data parallel.
global_dp_group (ProcessGroup): The process group of data parallel.
tp_group (ProcessGroup): The process group of tensor parallel.
use_zero (bool): Whether Zero is used.
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
@ -257,27 +269,47 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
Returns:
OrderedDict: The complete optimizer state of given parameter.
"""
dp_size = dist.get_world_size(dp_group)
global_dp_size = dist.get_world_size(global_dp_group)
tp_size = dist.get_world_size(tp_group)
moe_dp_size = dist.get_world_size(moe_dp_group) if moe_dp_group is not None else 1
current_shape = param.shape
state_ = state if inplace else copy.deepcopy(state)
for k, v in state_.items():
if isinstance(v, torch.Tensor) and k != "step":
# First gather Zero shards.
if use_zero and not is_moe_param:
v = v.cuda()
gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
dist.all_gather(gather_tensor, v, group=dp_group)
# First gather Zero shards.
if use_zero and is_moe_param and moe_dp_size > 1:
moe_dp_rank = dist.get_rank(moe_dp_group)
dst = get_global_rank(moe_dp_group, 0)
if moe_dp_rank == 0:
gather_tensor = [torch.zeros_like(v) for _ in range(moe_dp_size)]
dist.gather(v, gather_tensor, group=moe_dp_group, dst=dst)
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
else:
dist.gather(v, group=moe_dp_group, dst=dst)
elif use_zero and not is_moe_param and global_dp_size > 1:
dp_rank = dist.get_rank(global_dp_group)
dst = get_global_rank(global_dp_group, 0)
if dp_rank == 0:
gather_tensor = [torch.zeros_like(v) for _ in range(global_dp_size)]
dist.gather(v, gather_tensor, group=global_dp_group, dst=dst)
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
else:
dist.gather(v, group=global_dp_group, dst=dst)
# Then gather TP shards.
partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
if partition_dim is not None:
tp_rank = dist.get_rank(tp_group)
dst = get_global_rank(tp_group, 0)
if tp_rank == 0:
gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)]
dist.all_gather(gather_tensor, v, group=tp_group)
dist.gather(v, gather_tensor, group=tp_group, dst=dst)
v = torch.cat(gather_tensor, dim=partition_dim)
else:
dist.gather(v, group=tp_group, dst=dst)
state_[k] = v.detach().clone().to(device)
return state_
@ -286,8 +318,9 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
def _optimizer_sharder(
optimizer: OptimizerWrapper,
use_zero: bool,
dp_group: ProcessGroup,
global_dp_group: ProcessGroup,
tp_group: ProcessGroup,
moe_dp_group: ProcessGroup,
size_per_shard: int = 1024,
only_moe_param: bool = False,
):
@ -296,7 +329,6 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
state_dict_sharder = StateDictSharder(size_per_shard)
param_info = optimizer.param_info
master_to_working_map = optimizer.get_master_to_working_map()
for param, state in optimizer.optim.state.items():
if param is None:
continue
@ -305,22 +337,23 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
working_param = master_to_working_map[id(param)]
else:
working_param = param
param_id = param_info["param2id"][id(working_param)]
original_shape = param_info["param2shape"][id(working_param)]
state_ = MixtralMoEHybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
state_ = MoECheckpointIO.gather_from_sharded_optimizer_state(
state,
working_param,
original_shape=original_shape,
dp_group=dp_group,
global_dp_group=global_dp_group,
moe_dp_group=moe_dp_group,
tp_group=tp_group,
use_zero=use_zero,
inplace=False,
is_moe_param=is_moe_tensor(working_param),
is_moe_param=is_moe_tensor(working_param), # TODO: Check correctness here
)
if only_moe_param and not is_moe_tensor(working_param):
continue
block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
if block is not None:
yield block, block_size
@ -359,25 +392,28 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
Path(checkpoint).mkdir(parents=True, exist_ok=True)
# Devices along the same dp_group share the same copies of states when zero is not used.
# In this case only let the device with dp_rank == 0 save the model.
if not self.use_zero and self.real_dp_rank != 0:
# If optim states are not sharded, other ranks don't need to participate in gather.
if not self.use_zero and self.moe_dp_rank != 0:
dist.barrier()
return
# Then collect the sharded states along dp_group(if using zero)/tp_group.
# Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
state_dict_shard = MixtralMoEHybridParallelCheckpointIO._optimizer_sharder(
state_dict_shard = MoECheckpointIO._optimizer_sharder(
optimizer,
use_zero=self.use_zero,
dp_group=self.dp_group,
global_dp_group=self.global_dp_group,
tp_group=self.tp_group,
moe_dp_group=self.moe_dp_group,
size_per_shard=size_per_shard,
only_moe_param=self.ep_rank != 0,
)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
index_file = CheckpointIndexFile(checkpoint)
control_saving = self.real_dp_rank == 0 and self.tp_rank == 0
# e.g. dp_size = 4, moe_dp_size = 2, ep_size = 2 and use gather
# rank 0 saves moe & non-moe params; rank 1 only saves moe params
# ranks 2 & 3 save nothing
control_saving = self.tp_rank == 0 and self.moe_dp_rank == 0
if self.pp_size == 1 and self.ep_size == 1:
# When pipeline is not used, save the optimizer shards as in general checkpointIO
@ -596,7 +632,6 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
OrderedDict: The sharded optimizer state of the given parameter.
"""
state_ = state if inplace else copy.deepcopy(state)
for k, v in state_.items():
if isinstance(v, torch.Tensor) and k != "step":
# Shard state along tensor parallel group.
@ -606,24 +641,218 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
v = v.split(slice_size, dim=partition_dim)[self.tp_rank]
# Shard state along data parallel group when using Zero.
if self.use_zero and not is_moe_param:
padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
if self.use_zero and not is_moe_param and self.global_dp_size > 1:
padding_size = (self.global_dp_size - v.numel() % self.global_dp_size) % self.global_dp_size
with torch.no_grad():
v = v.flatten()
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
slice_size = v.numel() // self.dp_size
v = v.split(slice_size, dim=0)[self.dp_rank]
slice_size = v.numel() // self.global_dp_size
v = v.split(slice_size, dim=0)[self.global_dp_rank]
elif self.use_zero and is_moe_param and self.moe_dp_size > 1:
# LowLevelZeRO pads by global dp size for now.
# TODO: update both to use moe dp size
padding_size = (self.global_dp_size - v.numel() % self.global_dp_size) % self.global_dp_size
with torch.no_grad():
v = v.flatten()
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
slice_size = v.numel() // self.moe_dp_size
v = v.split(slice_size, dim=0)[self.moe_dp_rank]
state_[k] = v.detach().clone().to(device)
return state_
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
raise NotImplementedError
"""Migration from MoEHybridParallelCheckpointIO. These functions mostly deals with unsharded saving,
and can be savely deleted since large MoE models are often saved in shards.
"""
# Copied from colossalai.moe
def pre_save_model(self, model: nn.Module) -> dict:
state_dict = model.state_dict()
for name, param in model.named_parameters():
if ".experts." in name and is_moe_tensor(param):
ep_group = param.ep_group
ep_rank = dist.get_rank(ep_group)
ep_size = dist.get_world_size(ep_group)
# TODO: check correctness here
# dp_rank = get_dp_rank(param)
dp_rank = dist.get_rank(self.global_dp_group)
if dp_rank == 0:
param = param.data.cuda()
if ep_rank == 0:
all_param = [torch.zeros_like(param) for _ in range(ep_size)]
else:
all_param = None
# gather param from every ep rank
# dist.all_gather(all_param, param, group=ep_group)
dist.gather(param, all_param, group=ep_group)
if ep_rank == 0:
all_param = torch.cat(all_param, dim=0)
state_dict[name] = all_param.cpu()
if self.pp_size > 1:
if self.dp_rank == 0:
out = [None for _ in range(self.pp_size)]
dist.gather_object(state_dict, out, group=self.pp_group)
if self.pp_rank == 0:
new_state_dict = {}
for o in out:
new_state_dict.update(o)
state_dict = new_state_dict
dist.barrier()
return state_dict
def save_unsharded_model(
self,
model: nn.Module,
checkpoint: str,
gather_dtensor: bool,
use_safetensors: bool,
):
state_dict = self.pre_save_model(model)
if dist.get_rank() == 0:
torch.save(state_dict, checkpoint)
dist.barrier()
# Copied from colossalai.moe
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
raise NotImplementedError
"""
Save optimizer state dict to a file with given path.
Args:
optimizer (OptimizerWrapper): Optimizer to save sharded state_dict.
checkpoint (str): Path to save optimizer state_dict.
gather_dtensor (bool): Whether to gather_dtensor, not used.
"""
if self.coordinator.is_master():
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
# optimizer states of parameters kept by the local device (i.e., the local pipeline stage)
local_states = dict()
for param, state in optimizer.optim.state.items():
if param is None:
continue
# working param is needed for obtaining correct param_id
master_to_working_map = optimizer.get_master_to_working_map()
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
else:
working_param = param
# gather complete state from tp shards & dp shards
param_id = optimizer.param_info["param2id"][id(working_param)]
local_states[param_id] = self.pre_save_optim(
state,
working_param,
inplace=False,
device=torch.device("cuda"),
)
if self.pp_size == 1:
# When pipeline is not used, let master rank directly save the collected state_dict.
state_dict = {"param_groups": optimizer.optim.param_groups, "state": local_states}
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors=False)
else:
# When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
states_list = [None for _ in range(self.pp_size)]
dist.barrier(self.pp_group)
# dist.all_gather_object(states_list, local_states, self.pp_group)
dist.gather_object(local_states, states_list, self.pp_group)
# Only the master rank does the saving.
if self.coordinator.is_master():
state_dict = {"param_groups": optimizer.optim.param_groups, "state": dict()}
for _states in states_list:
state_dict["state"].update(_states)
save_state_dict(state_dict, checkpoint, use_safetensors=False)
dist.barrier()
# Copied from colossalai.moe
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, strict: bool = False):
raise NotImplementedError
"""
Load optimizer from a file with given path.
Args:
optimizer (OptimizerWrapper): The optimizer to be loaded.
checkpoint_index_file (str): Path to the checkpoint file.
"""
def _get_param_id_from_optimizer_param(
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
):
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
else:
working_param = param
if id(working_param) in optimizer.param_info["param2id"]:
return optimizer.param_info["param2id"][id(working_param)]
else:
return None
if self.coordinator.is_master():
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
# Complete optimizer state_dict loaded from checkpoint, need to be processed later.
state_dict = load_state_dict(checkpoint)
# Load param_groups.
updated_groups = []
saved_groups = state_dict["param_groups"]
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = old_pg["params"] # Only keep the parameters kept by current pipeline stage.
updated_groups.append(new_pg)
# ep extra group
# if MOE_MANAGER.parallel == "EP":
if self.ep_size > 1:
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = optimizer.optim.param_groups[-1][
"params"
] # Only keep the parameters kept by current pipeline stage.
for param in new_pg["params"]:
param.data = param.data.to(torch.float32)
updated_groups.append(new_pg)
optimizer.optim.__dict__.update({"param_groups": updated_groups})
# Load saved states to optimizer. First discard those states not belonging to current pipeline stage.
master_to_working_map = optimizer.get_master_to_working_map()
id_map = {}
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
if param_id is not None:
id_map[param_id] = param
load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True)
# Then shard the loaded optimizer states if using tp/zero.
for param, state in optimizer.optim.state.items():
if param is None:
continue
device = param.device
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info["param2shape"][id(working_param)]
sharded_state = self.pre_load_optim(
state,
param,
current_shape=working_param.shape,
original_shape=original_shape,
device=device,
inplace=True,
)
optimizer.optim.state[param] = sharded_state
sharded_optimizer_loading_epilogue(optimizer.optim)
dist.barrier()

View File

@@ -242,6 +242,7 @@ def save_state_dict_shards(
shard_filenames = []
for idx, shard_pair in enumerate(sharded_state_dict):
shard, current_size = shard_pair
# Non-master ranks still loop over the sharder so that the gathers inside it run, but discard the shards
if not is_master:
del shard
continue
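A minimal sketch, with made-up names not taken from this PR, of why non-master ranks still iterate the sharder even though they discard every shard: the generator body is where the collective gathers run, so every rank has to drive it.
```
# Hypothetical illustration only; fake_sharder stands in for the real model/optimizer sharder.
def fake_sharder():
    for param_id in range(3):
        # In the real sharder, collective ops (e.g. dist.all_gather) run here on every rank.
        yield {param_id: "gathered state"}, 1024

is_master = False  # pretend this is a non-master rank
for shard, shard_size in fake_sharder():
    if not is_master:
        del shard      # the collectives already ran; simply drop the gathered data
        continue
    # the master rank would record the shard in the index file and write it to disk here
```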

View File

@@ -244,19 +244,25 @@ class ProcessGroupMesh:
return target_group
def get_group_along_axis(
self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
) -> ProcessGroup:
"""Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.
Args:
axis (int): Axis along which the process groups are created.
axis (int or list of int): Axes along which the process groups are created.
indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
backend (Optional[str], optional): Backend of the process group. Defaults to None.
Returns:
ProcessGroup: The process group along the given axis which the current process belongs to.
"""
indices_at_axis = indices_at_axis or list(range(self._shape[axis]))
indices_at_axis = indices_at_axis
if indices_at_axis is None:
if isinstance(axis, (list, tuple)):
indices_at_axis = list(list(range(self._shape[ax])) for ax in axis)
else:
indices_at_axis = list(range(self._shape[axis]))
coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis)
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
if ranks_in_group not in self._ranks_to_group:
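A minimal usage sketch of the widened `axis` argument, assuming `torch.distributed` is already initialized with 8 ranks; the 2x2x2 mesh and the axis meanings below are illustrative only, not taken from this PR.
```
import torch.distributed as dist
from colossalai.cluster import ProcessGroupMesh

# Assumes dist.init_process_group(...) has already run with world_size == 8.
mesh = ProcessGroupMesh(2, 2, 2)                  # illustrative axes: 0 = dp, 1 = ep, 2 = tp

tp_group = mesh.get_group_along_axis(2)           # single axis, same behaviour as before
dp_ep_group = mesh.get_group_along_axis([0, 1])   # new: a list of axes yields one merged group

print(dist.get_world_size(tp_group))              # 2
print(dist.get_world_size(dp_ep_group))           # 4
```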

View File

@@ -1,20 +1,5 @@
from .checkpoint import MoECheckpointIO
from .experts import MLPExperts
from .layers import SparseMLP, apply_load_balance
from .manager import MOE_MANAGER
from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
from .utils import NormalNoiseGenerator, UniformNoiseGenerator
__all__ = [
"MLPExperts",
"MoeRouter",
"Top1Router",
"Top2Router",
"TopKRouter",
"NormalNoiseGenerator",
"UniformNoiseGenerator",
"SparseMLP",
"MoECheckpointIO",
"MOE_MANAGER",
"apply_load_balance",
]

View File

@@ -1,792 +0,0 @@
import copy
import logging
import os
from pathlib import Path
from shutil import rmtree
from typing import Dict, Iterator, Optional, OrderedDict, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from colossalai.checkpoint_io import CheckpointIndexFile, HybridParallelCheckpointIO
from colossalai.checkpoint_io.utils import (
StateDictSharder,
gather_distributed_param,
get_model_base_filenames,
get_optimizer_base_filenames,
is_safetensors_available,
load_shard_state_dict,
load_state_dict,
load_state_dict_into_model,
load_states_into_optimizer,
save_config_file,
save_param_groups,
save_state_dict,
save_state_dict_shards,
sharded_optimizer_loading_epilogue,
)
from colossalai.interface import OptimizerWrapper
from colossalai.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import (
get_dp_group,
get_dp_rank,
get_dp_size,
get_ep_group,
get_ep_rank,
get_ep_size,
is_moe_tensor,
)
class MoECheckpointIO(HybridParallelCheckpointIO):
def __init__(
self,
dp_group: ProcessGroup,
pp_group: ProcessGroup,
tp_group: ProcessGroup,
zero_stage: int,
) -> None:
assert zero_stage in [
0,
1,
2,
], f"zero_stage should be 0 or 1 or 2, got {zero_stage}"
super().__init__(dp_group, pp_group, tp_group, zero_stage)
self.parallel = MOE_MANAGER.parallel
def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict:
"""
Preprocess state_dict before loading and slice the state_dict of MOE tensors.
"""
for name, param in state_dict.items():
if ".experts." in name:
if name in dict(model.named_parameters()):
model_param = dict(model.named_parameters())[name]
if is_moe_tensor(model_param):
ep_rank = get_ep_rank(model_param)
ep_size = get_ep_size(model_param)
expert_num = param.shape[0] // ep_size
assert param.shape[0] % ep_size == 0
param = param[ep_rank * expert_num : (ep_rank + 1) * expert_num]
state_dict[name] = param
dist.barrier()
return state_dict
def _model_sharder(
self,
state_dict: nn.Module,
prefix: str = "",
keep_vars: bool = False,
size_per_shard: int = 1024,
) -> Iterator[Tuple[OrderedDict, int]]:
# An internal method that breaks the model state_dict into shards within a limited size.
state_dict_sharder = StateDictSharder(size_per_shard)
for name, param in state_dict.items():
if param is None:
continue
# Gather tensor pieces when using tensor parallel.
param_ = gather_distributed_param(param, keep_vars=False)
block, block_size = state_dict_sharder.append_param(prefix + name, param_)
if block is not None:
yield block, block_size
# Return the last block in sharder.
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool) -> None:
state_dict = torch.load(checkpoint)
state_dict = self.pre_load_model(model, state_dict)
model.load_state_dict(state_dict, strict=strict if self.pp_size == 1 else False)
def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
"""
Load sharded model with the given path to index file of checkpoint folder.
Args:
model (nn.Module): The model to be loaded.
checkpoint_index_file (str): Path to the index file of checkpointing folder.
strict (bool, optional): For name matching during loading state_dict. Defaults to False.
This argument should be manually set to False since params on same device might be stored in different files.
"""
# Check whether the checkpoint uses safetensors.
use_safetensors = False
if "safetensors" in checkpoint_index_file.name:
use_safetensors = True
if use_safetensors and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
ckpt_root_path = ckpt_index_file.root_path
weight_map = ckpt_index_file.weight_map
strict = False
# Load params & buffers to model.
# Keep a record of loaded files so that file will not be repeatedly loaded.
loaded_file = set()
def _load(name: str):
if name not in weight_map:
raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
filename = weight_map[name]
# If this param/buffer has been loaded before, directly return.
if filename in loaded_file:
return
file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
state_dict = self.pre_load_model(model, state_dict)
missing_keys = []
load_state_dict_into_model(
model,
state_dict,
missing_keys=missing_keys,
strict=strict,
load_sub_module=True,
)
loaded_file.add(filename)
# Load parameters.
for name, _ in model.named_parameters():
_load(name)
if self.verbose:
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
def pre_save_model(self, model: nn.Module) -> dict:
state_dict = model.state_dict()
for name, param in model.named_parameters():
if ".experts." in name and is_moe_tensor(param):
ep_group = get_ep_group(param)
ep_rank = get_ep_rank(param)
ep_size = get_ep_size(param)
dp_rank = get_dp_rank(param)
if dp_rank == 0:
param = param.data.cuda()
all_param = [torch.zeros_like(param) for _ in range(ep_size)]
# gather param from every ep rank
dist.all_gather(all_param, param, group=ep_group)
if ep_rank == 0:
all_param = torch.cat(all_param, dim=0)
state_dict[name] = all_param.cpu()
if self.pp_size > 1:
if self.dp_rank == 0:
out = [None for _ in range(self.pp_size)]
dist.all_gather_object(out, state_dict, group=self.pp_group)
if self.pp_rank == 0:
new_state_dict = {}
for o in out:
new_state_dict.update(o)
state_dict = new_state_dict
dist.barrier()
return state_dict
def save_unsharded_model(
self,
model: nn.Module,
checkpoint: str,
gather_dtensor: bool,
use_safetensors: bool,
):
state_dict = self.pre_save_model(model)
if dist.get_rank() == 0:
torch.save(state_dict, checkpoint)
dist.barrier()
def save_sharded_model(
self,
model: nn.Module,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
) -> None:
"""
Save sharded model checkpoint under the given checkpointing path.
The following files will be created under the path:
- An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
- Multiple files that store state tensors of models.
The filenames are in the form of "pytorch_model.<prefix>-000XX.bin"
Args:
model (nn.Module): Model on local device to be saved.
checkpoint (str): Checkpointing path which should be a directory path.
gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
prefix (str, optional): Prefix of file to save. Defaults to None.
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
"""
torch.cuda.empty_cache()
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
Path(checkpoint).mkdir(parents=True, exist_ok=True)
# Then collect the sharded parameters & buffers along tp_group.
# Only devices with tp_rank == 0 are responsible for model saving.
state_dict = self.pre_save_model(model)
if dist.get_rank() == 0:
state_dict_shard = self._model_sharder(state_dict, size_per_shard=size_per_shard)
# Devices along the same dp_group share the same copies of model.
# So only let the device with dp_rank == 0 save the model.
if self.dp_rank != 0:
return
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint)
control_saving = self.tp_rank == 0
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose:
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
dist.barrier()
torch.cuda.empty_cache()
# ========================================================
# Abstract methods for optimizer loading/saving implementation
# ========================================================
def pre_load_optim(
self,
state: OrderedDict,
working_param,
current_shape: torch.Size,
original_shape: torch.Size,
device: torch.device,
inplace: bool,
) -> OrderedDict:
"""
With complete optimizer states of a specific parameter loaded from checkpoint,
slice out the sharded optimizer states kept by current device.
Args:
state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint.
current_shape (torch.Size): The size of parameter after sharding.
original_shape (torch.Size): The size of parameter before sharding.
device (torch.device): The destination device of loaded optimizer states.
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
Returns:
OrderedDict: The sharded optimizer state of the given parameter.
"""
state_ = state if inplace else copy.deepcopy(state)
is_moe_tensor_flag = is_moe_tensor(working_param)
if is_moe_tensor_flag:
ep_rank = get_ep_rank(working_param)
ep_size = get_ep_size(working_param)
for k, v in state_.items():
if isinstance(v, torch.Tensor) and k != "step":
if is_moe_tensor_flag:
with torch.no_grad():
expert_num = v.shape[0] // ep_size
assert v.shape[0] % ep_size == 0
v = v[ep_rank * expert_num : (ep_rank + 1) * expert_num]
else:
# Shard state along data parallel group when using Zero.
padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
with torch.no_grad():
v = v.flatten()
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
slice_size = v.numel() // self.dp_size
v = v.split(slice_size, dim=0)[self.dp_rank]
state_[k] = v.detach().clone().to(device)
return state_
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
"""
Load sharded optimizer with the given path to index file of checkpoint folder.
Args:
optimizer (OptimizerWrapper): The optimizer to be loaded.
checkpoint_index_file (str): Path to the index file of checkpointing folder.
prefix (str): Not used.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
def _get_param_id_from_optimizer_param(
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, optimizer=None
):
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
working_param = optimizer.moe_master_to_working_map[id(param)]
else:
working_param = param
return optimizer.param_info["param2id"][id(working_param)]
# id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects.
# When Zero is used, the mapped parameter objects should be fp32 master parameters.
# IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info.
id_map = {}
master_to_working_map = optimizer.get_master_to_working_map()
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
id_map[param_id] = param
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
ckpt_root_path = ckpt_index_file.root_path
weight_map = ckpt_index_file.weight_map
weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
# Load param_groups
param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None:
raise RuntimeError(
f"Invalid index file path {checkpoint_index_file} for an optimizer. \
Lacking param group file under current directory."
)
saved_groups = torch.load(param_group_path)
updated_groups = []
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
# obtain updated param group
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = old_pg["params"] # The parameters in the same group shouldn't change.
updated_groups.append(new_pg)
# ep param group
if len(optimizer.optim.param_groups) > len(saved_groups):
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = optimizer.optim.param_groups[-1]["params"]
updated_groups.append(new_pg)
optimizer.optim.__dict__.update({"param_groups": updated_groups})
# Load saved states to optimizer.
# Keep a record of loaded files so that file will not be repeatedly loaded.
loaded_file = set()
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
if param is None:
continue
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
if param_id not in weight_map:
continue
filename = weight_map[param_id]
# If this param's states have been loaded before, skip loading.
if filename in loaded_file:
continue
file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
# Then shard the loaded optimizer states if using tp/zero.
for pid, state in list(state_dict.items()):
if pid in id_map:
param = id_map[pid]
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
elif (
hasattr(optimizer, "moe_master_to_working_map")
and id(param) in optimizer.moe_master_to_working_map
):
working_param = optimizer.moe_master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info["param2shape"][id(working_param)]
sharded_state = self.pre_load_optim(
state,
working_param,
current_shape=working_param.shape,
original_shape=original_shape,
device="cpu",
inplace=True,
)
state_dict[pid] = sharded_state
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
loaded_file.add(filename)
sharded_optimizer_loading_epilogue(optimizer.optim)
if self.verbose and self.coordinator.is_master():
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
dist.barrier()
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
"""
Load optimizer from a file with given path.
Args:
optimizer (OptimizerWrapper): The optimizer to be loaded.
checkpoint_index_file (str): Path to the checkpoint file.
"""
def _get_param_id_from_optimizer_param(
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
):
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
else:
working_param = param
if id(working_param) in optimizer.param_info["param2id"]:
return optimizer.param_info["param2id"][id(working_param)]
else:
return None
if self.coordinator.is_master():
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
# Complete optimizer state_dict loaded from checkpoint, need to be processed later.
state_dict = load_state_dict(checkpoint)
# Load param_groups.
updated_groups = []
saved_groups = state_dict["param_groups"]
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = old_pg["params"] # Only keep the parameters kept by current pipeline stage.
updated_groups.append(new_pg)
# ep extra group
if MOE_MANAGER.parallel == "EP":
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = optimizer.optim.param_groups[-1][
"params"
] # Only keep the parameters kept by current pipeline stage.
for param in new_pg["params"]:
param.data = param.data.to(torch.float32)
updated_groups.append(new_pg)
optimizer.optim.__dict__.update({"param_groups": updated_groups})
# Load saved states to optimizer. First discard those states not belonging to current pipeline stage.
master_to_working_map = optimizer.get_master_to_working_map()
id_map = {}
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
if param_id is not None:
id_map[param_id] = param
load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True)
# Then shard the loaded optimizer states if using tp/zero.
for param, state in optimizer.optim.state.items():
if param is None:
continue
device = param.device
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info["param2shape"][id(working_param)]
sharded_state = self.pre_load_optim(
state,
param,
current_shape=working_param.shape,
original_shape=original_shape,
device=device,
inplace=True,
)
optimizer.optim.state[param] = sharded_state
sharded_optimizer_loading_epilogue(optimizer.optim)
dist.barrier()
def pre_save_optim(
self,
state: OrderedDict,
param: torch.Tensor,
inplace: bool,
device: torch.device = torch.device("cpu"),
) -> OrderedDict:
"""
With given parameter and its optimizer states, gather the complete optimizer state for saving.
Args:
state (OrderedDict): Optimizer states of the given parameter, which might be distributed among tp/dp groups if using TP/Zero.
param (torch.Tensor): The given parameter. It should be working_param when using Zero.
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
device (torch.device): The destination device of loaded optimizer states. Defaults to torch.device('cpu').
Returns:
OrderedDict: The complete optimizer state of given parameter.
"""
if is_moe_tensor(param):
moe_dp_group = get_dp_group(param)
moe_dp_size = get_dp_size(param)
moe_ep_group = get_ep_group(param)
moe_ep_size = get_ep_size(param)
state_ = state if inplace else copy.deepcopy(state)
for k, v in state_.items():
if isinstance(v, torch.Tensor) and k != "step":
# moe param
if is_moe_tensor(param):
# dp gather
v = v.cuda()
gather_tensor = [torch.zeros_like(v) for _ in range(moe_dp_size)]
dist.all_gather(gather_tensor, v, group=moe_dp_group)
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
# ep gather
gather_tensor = [torch.zeros_like(v) for _ in range(moe_ep_size)]
dist.all_gather(gather_tensor, v, group=moe_ep_group)
v = torch.cat(gather_tensor, dim=0)
else:
# global dp
v = v.cuda()
gather_tensor = [torch.zeros_like(v) for _ in range(dist.get_world_size(self.dp_group))]
dist.all_gather(gather_tensor, v, group=self.dp_group)
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
state_[k] = v.detach().clone().to(device)
return state_
def _optimizer_sharder(
self,
optimizer: OptimizerWrapper,
size_per_shard: int = 1024,
):
# An internal method that breaks the optimizer state_dict into shards within a limited size.
state_dict_sharder = StateDictSharder(size_per_shard)
param_info = optimizer.param_info
master_to_working_map = optimizer.get_master_to_working_map()
for param, state in optimizer.optim.state.items():
if param is None:
continue
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
working_param = optimizer.moe_master_to_working_map[id(param)]
else:
working_param = param
param_id = param_info["param2id"][id(working_param)]
state_ = self.pre_save_optim(
state,
working_param,
inplace=False,
device=torch.device("cuda"),
)
block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
if block is not None:
yield block, block_size
# Return the last block in sharder.
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
def save_sharded_optimizer(
self,
optimizer: OptimizerWrapper,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
):
"""
Save sharded optimizer checkpoint under the given checkpointing path.
The following files will be created under the path:
- An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
- A group file (pytorch_optim_group.bin) recording information of param_groups
- Multiple files that store state tensors of optimizers.
If pipeline parallelism is used, the filenames are in the form of "pytorch_optim.<prefix>-stage-000XX-shard-000XX.bin".
If pipeline parallelism is not used, "pytorch_optim.<prefix>-000XX.bin"
Args:
optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
checkpoint (str): Path to save optimizer state_dict
gather_dtensor (bool): Whether to gather_dtensor, not used
prefix (str): Prefix of file to save
size_per_shard (int): Max file size of each file shard that store state tensors
"""
torch.cuda.empty_cache()
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
Path(checkpoint).mkdir(parents=True, exist_ok=True)
# Devices along the same dp_group share the same copies of states when zero is not used.
# In this case only let the device with dp_rank == 0 save the optimizer states.
if not self.use_zero and self.dp_rank != 0:
return
# Then collect the sharded states along dp_group(if using zero)/tp_group.
# Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
state_dict_shard = self._optimizer_sharder(
optimizer,
size_per_shard=size_per_shard,
)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
index_file = CheckpointIndexFile(checkpoint)
control_saving = self.dp_rank == 0 and self.tp_rank == 0
if self.pp_size == 1:
# When pipeline is not used, save the optimizer shards as in general checkpointIO
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
)
if control_saving:
# Store param groups.
index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
save_param_groups(optimizer.param_info, group_file_path)
# Store index file.
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
else:
# When pipeline is used, each stage produces its own shard files and index files.
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
# After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
final_index_file_path = copy.deepcopy(save_index_file)
tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
# Manage filenames of sharded weights and index file for each pipeline stage.
states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
use_pp_format=True,
)
if control_saving:
assert (
self.dp_rank == 0 and self.tp_rank == 0
), "The saving process should have both dp_rank and tp_rank as 0."
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
else:
return
dist.barrier(self.pp_group)
# The global master rank integrates the index files and cleans up the folder.
if self.pp_rank == 0:
final_index_file = CheckpointIndexFile(checkpoint)
final_index_file.append_meta_data("total_size", 0)
for filename in os.listdir(tmp_index_file_folder):
stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
for param_id, state_filename in stage_index_file.weight_map.items():
final_index_file.append_weight_map(param_id, state_filename)
# Store param groups.
final_index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
save_param_groups(optimizer.param_info, group_file_path)
final_index_file.write_index_file(final_index_file_path)
rmtree(tmp_index_file_folder)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {final_index_file_path}."
)
torch.cuda.empty_cache()
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
"""
Save optimizer state dict to a file with given path.
Args:
optimizer (OptimizerWrapper): Optimizer to save sharded state_dict.
checkpoint (str): Path to save optimizer state_dict.
gather_dtensor (bool): Whether to gather_dtensor, not used.
"""
if self.coordinator.is_master():
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
# optimizer states of parameters kept by the local device (i.e., the local pipeline stage)
local_states = dict()
for param, state in optimizer.optim.state.items():
if param is None:
continue
# working param is needed for obtaining correct param_id
master_to_working_map = optimizer.get_master_to_working_map()
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
else:
working_param = param
# gather complete state from tp shards & dp shards
param_id = optimizer.param_info["param2id"][id(working_param)]
local_states[param_id] = self.pre_save_optim(
state,
working_param,
inplace=False,
device=torch.device("cuda"),
)
if self.pp_size == 1:
# When pipeline is not used, let master rank directly save the collected state_dict.
state_dict = {"param_groups": optimizer.optim.param_groups, "state": local_states}
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors=False)
else:
# When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
states_list = [None for _ in range(self.pp_size)]
dist.barrier(self.pp_group)
dist.all_gather_object(states_list, local_states, self.pp_group)
# Only the master rank does the saving.
if self.coordinator.is_master():
state_dict = {"param_groups": optimizer.optim.param_groups, "state": dict()}
for _states in states_list:
state_dict["state"].update(_states)
save_state_dict(state_dict, checkpoint, use_safetensors=False)
dist.barrier()

View File

@@ -7,8 +7,8 @@ from torch import Tensor, nn
from torch.distributed import ProcessGroup
from colossalai.cluster import ProcessGroupMesh
from colossalai.moe.experts import MLPExperts
from colossalai.moe.manager import MOE_MANAGER
from colossalai.shardformer.layer.moe import MLPExperts
from colossalai.zero.low_level import LowLevelZeroOptimizer
@@ -292,7 +292,7 @@ class LoadBalancer:
exp_avg_ptr = optim.optim.state[working_weight_ptr]["exp_avg"]
exp_avg_sq_ptr = optim.optim.state[working_weight_ptr]["exp_avg_sq"]
else:
master_weight_ptr = optim._param_store.working_to_master_param[id(weight)]
master_weight_ptr = optim.working_to_master_param[id(weight)]
working_weight_ptr = weight
exp_avg_ptr = optim.optim.state[master_weight_ptr]["exp_avg"]
exp_avg_sq_ptr = optim.optim.state[master_weight_ptr]["exp_avg_sq"]
@@ -344,7 +344,7 @@ class LoadBalancer:
# gate optim should be obtained first
gate_shape = self.gate.shape
# get master weight and optim
master_gate_weight = optim._param_store.working_to_master_param[id(self.gate)]
master_gate_weight = optim.working_to_master_param[id(self.gate)]
gate_exp_avg = optim.optim.state[master_gate_weight]["exp_avg"]
gate_exp_avg_sq = optim.optim.state[master_gate_weight]["exp_avg_sq"]
# gather

View File

@@ -1,78 +0,0 @@
import torch.nn as nn
from torch.nn.modules.loss import _Loss
from colossalai.moe.manager import MOE_MANAGER
class MoeCrossEntropyLoss(_Loss):
r"""torch.nn.CrossEntropyLoss added with auxiliary loss.
Args:
input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
aux_weight (float, optional): Weight of the auxiliary loss in the total loss. Defaults to 0.01.
The ``args`` and ``kwargs`` should include parameters below:
::
weight (Tensor, optional)
size_average (bool, optional)
ignore_index (int, optional)
reduce (bool, optional)
reduction (str, optional)
label_smoothing (float, optional)
More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
`Cross_entropy <https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html#torch.nn.functional.cross_entropy>`_.
"""
def __init__(self, aux_weight: float = 0.01, *args, **kwargs):
super().__init__()
self.loss = nn.CrossEntropyLoss(*args, **kwargs)
self.aux_weight = aux_weight
def forward(self, *args):
"""
The ``args`` should at least include parameters below:
::
input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
`Cross_entropy <https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html#torch.nn.functional.cross_entropy>`_.
"""
main_loss = self.loss(*args)
aux_loss = MOE_MANAGER.get_loss()
return main_loss + self.aux_weight * aux_loss
class MoeLoss(_Loss):
"""A wrapper class for any loss module to add with auxiliary loss.
Args:
aux_weight (float): Weight of auxiliary loss in total loss.
loss_fn (``Callable``): Loss function.
args (list): Args in loss function.
kwargs (dict): Kwargs in loss function
"""
def __init__(self, aux_weight: float, loss_fn, *args, **kwargs):
super().__init__()
self.loss_fn = loss_fn(*args, **kwargs)
self.aux_weight = aux_weight
def forward(self, *args, **kwargs):
"""
The ``args`` and ``kwargs`` should at least include parameters below:
::
input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
Note:
The ``args`` and ``kwargs`` may include different parameters varying with different loss function.
"""
main_loss = self.loss_fn(*args, **kwargs)
aux_loss = MOE_MANAGER.get_loss()
return main_loss + self.aux_weight * aux_loss

View File

@@ -1,466 +0,0 @@
import math
from abc import ABC
from typing import Callable, Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import ProcessGroup
from colossalai.accelerator import get_accelerator
from colossalai.moe._operation import moe_cumsum
from colossalai.moe.manager import MOE_MANAGER
class MoeRouter(nn.Module, ABC):
"""Base class for all MoE routers.
Args:
k_value (int): The value of top_k.
capacity_factor_train (float): Capacity factor in routing of training.
capacity_factor_eval (float): Capacity factor in routing of evaluation.
min_capacity (int): The minimum capacity of each expert.
noisy_func (:class:`typing.Callable`, optional): Noise function applied to the logits.
drop_tks (bool, optional): Whether to drop tokens in evaluation.
"""
def __init__(
self,
k_value: int,
capacity_factor_train: float,
capacity_factor_eval: float,
min_capacity: int,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True,
use_kernel: bool = False,
):
super().__init__()
self.k_value = k_value
self.capacity_factor_train = capacity_factor_train
self.capacity_factor_eval = capacity_factor_eval
self.min_capacity = min_capacity
self.noisy_func = noisy_func
self.drop_tks = drop_tks
self._aux_loss = None
self._z_loss = None
self.use_kernel = use_kernel
def get_capacity(self, num_tokens, num_experts, ep_group=None):
if ep_group is not None:
num_tokens_tensor = torch.tensor(num_tokens, device=get_accelerator().get_current_device())
dist.all_reduce(num_tokens_tensor, group=ep_group)
num_tokens = num_tokens_tensor.item() // dist.get_world_size(ep_group)
capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
capacity = math.floor(self.k_value * capacity_factor * num_tokens / num_experts)
capacity += capacity % 2
capacity = max(capacity, self.min_capacity)
assert capacity > 0
return int(capacity)
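A worked instance of the capacity formula above, with all numbers purely illustrative:
```
import math

k_value, capacity_factor, num_tokens, num_experts, min_capacity = 2, 1.25, 1024, 8, 4
capacity = math.floor(k_value * capacity_factor * num_tokens / num_experts)  # floor(2 * 1.25 * 1024 / 8) = 320
capacity += capacity % 2               # bump odd capacities to the next even number (320 stays 320)
capacity = max(capacity, min_capacity)
assert int(capacity) == 320
```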
def set_aux_loss(self, router_probs: torch.Tensor, expert_indices: torch.Tensor, num_experts: int) -> None:
"""Computes auxiliary load balancing loss as in Switch Transformer.
See Switch Transformer (https://arxiv.org/abs/2101.03961). This function
implements the loss function presented in equations (4) - (6). It aims to
penalize those cases where the routing between experts is unbalanced.
Args:
router_probs: Probability assigned to each expert per token. Shape:
<float32>[num_groups, tokens_per_group, num_experts].
expert_indices: <int>[num_groups, tokens_per_group, num_selected_experts]
indices identifying the top num_selected_experts for a given token.
"""
assert self._aux_loss is None
if router_probs.dim() == expert_indices.dim() == 2:
router_probs = router_probs.unsqueeze(0)
expert_indices = expert_indices.unsqueeze(0)
assert (
router_probs.dim() == expert_indices.dim() == 3
), "router_probs must be 3D tensor and expert_indices must be 4D tensor"
# Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
expert_mask = F.one_hot(expert_indices, num_experts)
# For a given token, determine if it was routed to a given expert.
# Shape: [num_groups, tokens_per_group, num_experts]
expert_mask = expert_mask.max(dim=-2)[0]
tokens_per_group_and_expert = torch.mean(expert_mask.float(), dim=-2)
router_prob_per_group_and_expert = torch.mean(router_probs.float(), dim=-2)
aux_loss = num_experts**2 * torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert)
self._aux_loss = aux_loss
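Written out, the quantity computed above (per group, then averaged over groups) is the Switch Transformer load-balancing loss, with $N$ experts, $f_e$ the fraction of tokens whose selected experts include expert $e$, and $P_e$ the mean router probability assigned to expert $e$:

$$\mathcal{L}_{\text{aux}} \;=\; N^{2}\cdot\operatorname{mean}_{e}\!\left(f_e\,P_e\right) \;=\; N\sum_{e=1}^{N} f_e\,P_e$$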
def set_z_loss(self, router_logits: torch.Tensor):
"""Compute router z-loss.
The router z-loss was introduced in Designing Effective Sparse Expert Models
(https://arxiv.org/abs/2202.08906). It encourages router logits to remain
small in an effort to improve stability.
Args:
router_logits: <float>[num_groups, tokens_per_group, num_experts] router logits.
"""
assert self._z_loss is None
if router_logits.dim() == 2:
router_logits = router_logits.unsqueeze(0)
assert router_logits.dim() == 3, "router_logits must be 3D tensor"
num_groups, tokens_per_group, _ = router_logits.shape
log_z = torch.logsumexp(router_logits, dim=-1)
z_loss = torch.sum(log_z**2, dtype=torch.float32) / (num_groups * tokens_per_group)
self._z_loss = z_loss
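Equivalently, for router logits $x_{g,t,e}$ over $G$ groups, $T$ tokens per group, and $N$ experts, the z-loss computed above is:

$$\mathcal{L}_{z} \;=\; \frac{1}{G\,T}\sum_{g=1}^{G}\sum_{t=1}^{T}\left(\log\sum_{e=1}^{N} e^{x_{g,t,e}}\right)^{2}$$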
def pop_router_loss(self) -> torch.Tensor:
assert self._aux_loss is not None
MOE_MANAGER.add_loss(self._aux_loss, self._z_loss)
self._aux_loss = None
self._z_loss = None
class Top1Router(MoeRouter):
"""Top1 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity)
and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More details
can be found in Google's Switch Transformer paper.
Args:
capacity_factor_train (float, optional): Capacity factor in routing of training.
capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
min_capacity (int, optional): The minimum capacity of each expert.
select_policy (str, optional): The policy for token selection.
noisy_func (:class:`typing.Callable`, optional): Noise function applied to the logits.
drop_tks (bool, optional): Whether to drop tokens in evaluation.
"""
def __init__(
self,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
select_policy: str = "first",
noisy_func: Optional[Callable] = None,
drop_tks: bool = True,
):
super().__init__(
k_value=1,
capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity,
noisy_func=noisy_func,
drop_tks=drop_tks,
)
self.select_policy = select_policy
assert select_policy in {"first", "random"}
if select_policy == "random":
self.uniform = torch.distributions.uniform.Uniform(
low=torch.tensor(0.0, device=get_accelerator().get_current_device()),
high=torch.tensor(1.0, device=get_accelerator().get_current_device()),
).rsample
def forward(
self,
inputs: torch.Tensor,
use_kernel: bool = False,
ep_group: Optional[ProcessGroup] = None,
use_loss: bool = False,
use_norm: bool = False,
) -> Tuple:
"""
Args:
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
Returns:
1. use_kernel is False:
The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity).
The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity).
2. use_kernel is True:
...
"""
if self.noisy_func is not None and self.training:
inputs = self.noisy_func(inputs)
assert inputs.dtype == torch.float
probs = F.softmax(inputs, dim=-1)
num_experts = probs.size(-1)
num_tokens = inputs.size(0)
capacity = self.get_capacity(num_tokens, num_experts, ep_group)
top1_idx = torch.argmax(inputs, dim=-1)
mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
# calculate router loss
self.set_aux_loss(probs, top1_idx.unsqueeze(-1), num_experts)
self.set_z_loss(inputs)
self.pop_router_loss()
if not self.training and not self.drop_tks and ep_group is not None:
max_num = torch.max(torch.sum(mask, dim=0))
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
capacity = max_num.item()
if self.select_policy == "random":
rand_mask = mask * self.uniform(mask.shape)
_, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
ranks = moe_cumsum(mask, use_kernel=self.use_kernel)
elif self.select_policy == "first":
ranks = moe_cumsum(mask, use_kernel=self.use_kernel)
mask = mask * torch.lt(ranks, capacity)
else:
raise NotImplementedError("Not support such select policy yet.")
ranks = torch.sum(mask * ranks, dim=-1)
used_capacity = mask.sum(dim=0)
if use_kernel:
mask = torch.sum(mask, dim=-1)
mask = torch.stack([mask], dim=0).to(torch.int32)
dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
return used_capacity, probs, mask, dest_idx, num_experts * capacity
else:
ranks = F.one_hot(ranks, num_classes=capacity)
weight = mask * probs.type_as(inputs)
combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
sec_mask = combine_weights.bool()
return used_capacity, combine_weights, sec_mask, probs
class Top2Router(MoeRouter):
"""Top2 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity)
and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More details
can be found in the ViT-MoE paper.
Args:
capacity_factor_train (float, optional): Capacity factor in routing of training.
capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
min_capacity (int, optional): The minimum capacity of each expert.
noisy_func (:class:`typing.Callable`, optional): Noise function applied to the logits.
drop_tks (bool, optional): Whether to drop tokens in evaluation.
"""
def __init__(
self,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True,
):
super().__init__(
k_value=2,
capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity,
noisy_func=noisy_func,
drop_tks=drop_tks,
)
def forward(
self,
inputs: torch.Tensor,
use_kernel: bool = False,
ep_group: Optional[ProcessGroup] = None,
use_norm: bool = False,
use_loss: bool = True,
) -> Tuple:
"""
Args:
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
Returns:
1. use_kernel is False:
The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity).
The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity).
2. use_kernel is True:
...
"""
if self.noisy_func is not None and self.training:
inputs = self.noisy_func(inputs)
assert inputs.dtype == torch.float
probs = F.softmax(inputs, dim=-1)
if use_norm:
routing_weights, _ = torch.topk(probs, 2, dim=-1)
probs = probs / routing_weights.sum(dim=-1, keepdim=True)
num_experts = probs.size(-1)
num_tokens = inputs.size(0)
capacity = self.get_capacity(num_tokens, num_experts, ep_group)
top1_idx = torch.argmax(probs, dim=-1)
mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
logits_except1 = probs.masked_fill(mask1.bool(), float("-inf"))
top2_idx = torch.argmax(logits_except1, dim=-1)
mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)
cmask = mask1 + mask2 # loss: [s, e]
cmask = cmask.float() / 2.0 # div 2 to normalize it to 1
# calculate loss
if use_loss:
expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
self.set_aux_loss(probs, expert_indices, num_experts)
self.set_z_loss(inputs)
self.pop_router_loss()
if not self.training and not self.drop_tks and ep_group is not None:
max_num = torch.max(torch.sum(cmask, dim=0))
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
capacity = max_num.item()
rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel) # rank1: [s, e]
rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel)
rank2 += torch.sum(mask1, dim=-2, keepdim=True)
mask1 *= torch.lt(rank1, capacity)
mask2 *= torch.lt(rank2, capacity)
used_capacity = mask1.sum(dim=0) + mask2.sum(dim=0)
rank1 = torch.sum(mask1 * rank1, dim=-1)
rank2 = torch.sum(mask2 * rank2, dim=-1)
if use_kernel:
mask1 = torch.sum(mask1, dim=-1)
mask2 = torch.sum(mask2, dim=-1)
mask = torch.stack([mask1, mask2], dim=0).to(torch.int32)
dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32)
return used_capacity, probs, mask, dest_idx, num_experts * capacity
else:
"""
The following code is equivalent to:
```
weight1 = mask1 * probs.type_as(inputs)
weight2 = mask2 * probs.type_as(inputs)
rank1_sc = F.one_hot(rank1, num_classes=capacity)
rank2_sc = F.one_hot(rank2, num_classes=capacity)
cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)
cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1)
cb_weight = cb_weight1 + cb_weight2
sec_mask = cb_weight.bool()
```
"""
weight1 = mask1 * probs.type_as(inputs)
weight2 = mask2 * probs.type_as(inputs)
cb_weight = torch.zeros(inputs.shape + (capacity,), device=inputs.device)
sec_mask = torch.zeros_like(cb_weight, dtype=torch.bool)
indices = torch.arange(0, inputs.shape[0], device=inputs.device)
cb_weight[indices, top1_idx[indices], rank1[indices]] += weight1[indices, top1_idx[indices]]
cb_weight[indices, top2_idx[indices], rank2[indices]] += weight2[indices, top2_idx[indices]]
sec_mask[indices, top1_idx[indices], rank1[indices]] |= mask1.bool()[indices, top1_idx[indices]]
sec_mask[indices, top2_idx[indices], rank2[indices]] |= mask2.bool()[indices, top2_idx[indices]]
return used_capacity, cb_weight, sec_mask
class TopKRouter(MoeRouter):
"""Masked matmul router using tokens choose top-k experts assignment.
NOTE: this is modified from flaxformer.
This router uses the same mechanism as in Switch Transformer
(https://arxiv.org/abs/2101.03961) and V-MoE
(https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are
sorted by router_probs and then routed to their choice of expert until the
expert's expert_capacity is reached. There is no guarantee that each token is
processed by an expert, or that each expert receives at least one token.
Attributes:
num_selected_experts: Maximum number of experts to which each token is
routed. Tokens may be routed to fewer experts if particular experts are
oversubscribed / reach capacity.
"""
def __init__(
self,
num_selected_experts: int,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True,
):
super().__init__(
num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, drop_tks
)
def forward(
self,
router_probs: torch.Tensor,
expert_capacity: int,
) -> Tuple:
"""Computes masks for the top-k experts per token.
Args:
router_probs: <float32>[num_groups, tokens_per_group, num_experts]
probabilities used to determine the routing of tokens to the experts.
Returns:
Dispatch and combine arrays for routing with masked matmuls.
"""
# TODO: FIXME: add parallel group
num_groups, _, num_experts = router_probs.shape
# Top-k router probability and corresponding expert indices for each token.
# Shape: [num_groups, tokens_per_group, num_selected_experts].
expert_gate, expert_index = torch.topk(router_probs, self.k_value)
self.set_aux_loss(router_probs, expert_index, num_experts)
self.pop_router_loss()
# Make num_selected_experts the leading axis to ensure that top-1 choices
# have priority over top-2 choices, which have priority over top-3 choices,
# etc.
expert_index = torch.transpose(expert_index, 1, 2)
# Shape: [num_groups, num_selected_experts * tokens_per_group]
expert_index = expert_index.reshape(num_groups, -1)
# Create mask out of indices.
# Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts].
expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32)
# Experts have a fixed capacity that we cannot exceed. A token's priority
# within the expert's buffer is given by the masked, cumulative capacity of
# its target expert.
# Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts].
token_priority = torch.cumsum(expert_mask, dim=1) * expert_mask - 1
# Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts].
token_priority = token_priority.reshape((num_groups, self.k_value, -1, num_experts))
# Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
token_priority = torch.transpose(token_priority, 1, 2)
# For each token, across all selected experts, select the only non-negative
# (unmasked) priority. Now, for group G routing to expert E, token T has
# non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E
# is its targeted expert.
# Shape: [num_groups, tokens_per_group, num_experts].
token_priority = torch.max(token_priority, dim=2)[0]
# Token T can only be routed to expert E if its priority is positive and
# less than the expert capacity. One-hot matrix will ignore indices outside
# the range [0, expert_capacity).
# Shape: [num_groups, tokens_per_group, num_experts, expert_capacity].
valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity)
token_priority = torch.masked_fill(token_priority, ~valid_mask, 0)
dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool)
valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, -1, expert_capacity)
dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0)
# The combine array will be used for combining expert outputs, scaled by the
# router probabilities. Shape: [num_groups, tokens_per_group, num_experts,
# expert_capacity].
combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask)
return combine_array, dispatch_mask
def get_router_cls(top_k: int, grouped: bool = False) -> MoeRouter:
if not grouped:
if top_k == 1:
return Top1Router
elif top_k == 2:
return Top2Router
else:
raise NotImplementedError("top_k > 2 is not supported yet")
else:
return TopKRouter
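For readers unfamiliar with the masked-matmul routing used above, the following standalone PyTorch sketch (not the ColossalAI API; expert_fn is a hypothetical stand-in) shows how the dispatch and combine arrays returned by TopKRouter.forward are typically consumed: tokens are scattered into fixed-capacity per-expert buffers with one einsum and gathered back, weighted by the router probabilities, with another.

import torch

def route_with_masks(tokens, dispatch_mask, combine_array, expert_fn):
    # tokens:        [num_groups, tokens_per_group, hidden]
    # dispatch_mask: [num_groups, tokens_per_group, num_experts, capacity], bool
    # combine_array: same shape as dispatch_mask, float router probabilities
    # Scatter tokens into per-expert, per-capacity-slot buffers.
    expert_inputs = torch.einsum("gtec,gth->gech", dispatch_mask.to(tokens.dtype), tokens)
    # expert_fn maps [num_groups, num_experts, capacity, hidden] -> same shape.
    expert_outputs = expert_fn(expert_inputs)
    # Gather expert outputs back to token positions, weighted by router probs;
    # dropped tokens (all-zero mask rows) simply receive zeros.
    return torch.einsum("gtec,gech->gth", combine_array, expert_outputs)

# e.g. route_with_masks(x, dispatch_mask, combine_array, lambda e: e)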

View File

@ -6,10 +6,11 @@ import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.distributed_c10d import get_process_group_ranks
from colossalai.accelerator import get_accelerator
from colossalai.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import get_dp_group, get_dp_group_ranks, get_ep_size, is_moe_tensor
from colossalai.tensor.moe_tensor.api import is_moe_tensor
class ForceFP32Parameter(torch.nn.Parameter):
@ -145,7 +146,7 @@ def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]
if not is_moe_tensor(param):
ep_size = 1 # set ep_size to 1 for dp parameters
else:
ep_size = get_ep_size(param)
ep_size = dist.get_world_size(param.ep_group)
if ep_size not in epsize_param_dict:
epsize_param_dict[ep_size] = []
epsize_param_dict[ep_size].append(param)
@ -170,8 +171,8 @@ def sync_moe_model_param(model: nn.Module):
# When ep_size = world_size, communication is not needed
if ep_size != 1 and ep_size != MOE_MANAGER.world_size:
for param in param_dict[ep_size]:
src_rank = get_dp_group_ranks(param)[0]
dist.broadcast(param, src=src_rank, group=get_dp_group(param))
src_rank = get_process_group_ranks(param.dp_group)[0]
dist.broadcast(param, src=src_rank, group=param.dp_group)
def set_moe_args(config: Any, args: dict):

View File

@ -0,0 +1,3 @@
from .experts import *
from .layers import *
from .routers import *

View File

@ -9,7 +9,7 @@ from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import get_activation
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size, set_moe_tensor_info
from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size
if HAS_TRITON:
from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
@ -35,7 +35,7 @@ class MLPExperts(nn.Module):
num_experts: int,
hidden_size: int,
intermediate_size: int,
expert_parallel: Optional[str] = None,
expert_parallel: Optional[str] = "EP",
activation: Optional[Callable] = None,
drop_rate: Optional[float] = 0,
gated: Optional[bool] = False,

View File

@ -8,11 +8,9 @@ import torch.nn as nn
import torch.nn.functional as F
from colossalai.moe._operation import AllGather, AllToAll, HierarchicalAllToAll, MoeCombine, MoeDispatch, ReduceScatter
from colossalai.moe.experts import MLPExperts
from colossalai.moe.load_balance import LoadBalancer
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.routers import MoeRouter, get_router_cls
from colossalai.moe.utils import create_ep_hierarchical_group, get_noise_generator
from colossalai.shardformer.layer.moe import MLPExperts
from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_group_ranks, get_ep_size
@ -23,6 +21,7 @@ class SparseMLP(nn.Module):
dim_model (int): Hidden dimension of training model
num_experts (int): The number of experts
top_k (int, optional): The number of experts each token is dispatched to
parallel (str): parallel mode. Should be "EP", "TP" or None
capacity_factor_train (float, optional): Capacity factor in routing during training
capacity_factor_eval (float, optional): Capacity factor in routing during evaluation
min_capacity (int, optional): The minimum number of the capacity of each expert
@ -51,6 +50,7 @@ class SparseMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
router_top_k: int = 1,
parallel: str = "EP",
router_loss: bool = True,
router_norm: bool = False,
router_capacity_factor_train: float = 1.25,
@ -66,7 +66,7 @@ class SparseMLP(nn.Module):
load_balance_group_swap_factor: float = 0.4,
enable_kernel: bool = False,
enable_comm_overlap: bool = False,
enable_hierarchical_comm: bool = False,
enable_hierarchical_comm: bool = True,
return_gate_logits: bool = False,
):
super().__init__()
@ -77,7 +77,9 @@ class SparseMLP(nn.Module):
self.return_gate_logits = return_gate_logits
self.enable_kernel = enable_kernel
self.enable_comm_overlap = enable_comm_overlap
self.expert_parallel = MOE_MANAGER.get_parallel()
# self.expert_parallel = MOE_MANAGER.get_parallel()
assert parallel in ["EP", "TP", None], "parallel mode must be EP, TP or None"
self.parallel = parallel
self.router_loss = router_loss
self.router_norm = router_norm
@ -99,7 +101,7 @@ class SparseMLP(nn.Module):
# moe experts
self.experts = MLPExperts(
num_experts=self.num_experts,
expert_parallel=self.expert_parallel,
expert_parallel=self.parallel,
hidden_size=self.hidden_size,
intermediate_size=self.intermediate_size,
activation=mlp_activation,
@ -108,11 +110,12 @@ class SparseMLP(nn.Module):
)
# get parallel settings
if self.expert_parallel is not None:
if self.parallel is not None:
self.ep_group = get_ep_group(self.experts)
self.ep_size = get_ep_size(self.experts)
self.ep_hierarchical_group = None
if enable_hierarchical_comm:
# TODO: move to plugin
self.ep_intra_src_rank, *self.ep_hierarchical_group = create_ep_hierarchical_group(
get_ep_group_ranks(self.experts)
)
@ -186,11 +189,11 @@ class SparseMLP(nn.Module):
dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
# expert_output: (num_groups, num_experts, capacity, hidden_size)
if self.expert_parallel == "EP":
if self.parallel == "EP":
expert_output = self._ep_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
elif self.expert_parallel == "TP":
elif self.parallel == "TP":
expert_output = self._tp_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
elif self.expert_parallel is None:
elif self.parallel is None:
expert_output = self._local_process(dispatch_data)
else:
raise NotImplementedError(

View File

@ -0,0 +1,161 @@
import math
from typing import Callable, Optional, Tuple
import torch
import torch.nn as nn
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import get_activation
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size
if HAS_TRITON:
from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
class MLPExperts(nn.Module):
"""
MLPExperts is a multi-layer perceptron with sparse expert parallel layers.
Args:
num_experts (int): The number of experts
hidden_size (int): The hidden size of MLP
intermediate_size (int): The intermediate size of MLP
expert_parallel (str, optional): The parallelism of experts. Now we have None, EP and TP.
activation (optional): The activation function of MLP
drop_rate (float, optional): The drop rate of MLP
gated (bool, optional): Whether to use gated MLP
use_kernel (bool, optional): Whether to use kernel optimization
"""
def __init__(
self,
num_experts: int,
hidden_size: int,
intermediate_size: int,
expert_parallel: Optional[str] = "EP",
activation: Optional[Callable] = None,
drop_rate: Optional[float] = 0,
gated: Optional[bool] = False,
use_kernel: Optional[bool] = False,
):
super().__init__()
assert expert_parallel in ["EP", "TP", None]
self.expert_parallel = expert_parallel
self.num_total_experts = num_experts
self.gated = gated
self.use_kernel = use_kernel
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
# get expert parallel info
if expert_parallel is not None:
self.num_local_experts, self.moe_info = MOE_MANAGER.get_info(
num_experts, use_tp=True if expert_parallel == "TP" else False
)
# get settings for different parallel
self.ep_size = get_ep_size(self)
if expert_parallel == "TP":
intermediate_size = intermediate_size // self.ep_size
num_experts = self.num_total_experts
else:
num_experts = self.num_local_experts
else:
self.num_local_experts = self.num_total_experts
self.ep_size = 1
if gated:
self.wi_gate = nn.Parameter(
torch.empty(
num_experts, hidden_size, intermediate_size * 2 if activation == "swiglu" else intermediate_size
)
)
self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
else:
self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
self.wo = nn.Parameter(torch.empty(num_experts, intermediate_size, hidden_size))
self.act_name = activation
self.act = get_activation(activation)
self.drop = nn.Dropout(p=drop_rate)
if expert_parallel is not None:
for param in self.parameters():
set_moe_tensor_info(param, self.moe_info)
# init param
self.reset_parameters()
@torch.no_grad()
def reset_parameters(self):
# expert param should be different
if self.expert_parallel is not None:
seed_ctx = Randomizer(get_ep_rank(self)).fork_rng(enable_cpu=True)
else:
seed_ctx = Randomizer(42).fork_rng(enable_cpu=True)
with seed_ctx:
if self.gated:
torch.nn.init.normal_(self.wi_gate, std=math.sqrt(0.1 / self.hidden_size))
torch.nn.init.normal_(self.wi_up, std=math.sqrt(0.1 / self.hidden_size))
else:
torch.nn.init.normal_(self.wi, std=math.sqrt(0.1 / self.hidden_size))
torch.nn.init.normal_(self.wo, std=math.sqrt(0.1 / self.intermediate_size))
def forward(
self,
x: torch.Tensor,
param_slice: Tuple[slice] = (slice(None),),
use_sparse: bool = True,
) -> torch.Tensor:
"""
forward: hidden_size --> intermediate_size --> hidden_size
Args:
x (torch.Tensor): The input tensor of shape (num_groups, num_experts, capacity, hidden_size)
Returns:
torch.Tensor: The output tensor of shape (num_groups, num_experts, capacity, hidden_size)
"""
x = MoeInGradScaler.apply(x, self.ep_size)
e = x.size(1)
h = x.size(-1)
x = x.transpose(0, 1)
inshape = x.shape
x = x.reshape(e, -1, h)
if self.use_kernel and use_sparse:
seq_len = x.shape[1]
with torch.no_grad():
mask = x[:, :, 0] != 0.0
mask = torch.sum(mask, dim=-1)
x_list = []
for i in range(e):
x_list.append(x[i, : mask[i]])
x = x_list
if self.gated:
x_gate = [torch.mm(x[i], self.wi_gate[param_slice][i]) for i in range(e)]
x_up = [torch.mm(x[i], self.wi_up[param_slice][i]) for i in range(e)]
if self.use_kernel and HAS_TRITON and self.act_name == "swiglu":
x = [LlamaActCombine.apply(x_gate[i], x_up[i]) for i in range(e)]
else:
x = [self.act(x_gate[i]) * x_up[i] for i in range(e)]
else:
x = [torch.mm(x[i], self.wi[param_slice][i]) for i in range(e)]
x = [self.act(x[i]) for i in range(e)]
x = [self.drop(x[i]) for i in range(e)]
x = [torch.mm(x[i], self.wo[param_slice][i]) for i in range(e)]
if self.use_kernel and use_sparse:
for i in range(e):
x[i] = torch.nn.functional.pad(x[i], (0, 0, 0, seq_len - x[i].shape[0]), mode="constant", value=0)
x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0)
x = x.reshape(inshape)
x = x.transpose(0, 1).contiguous()
x = MoeOutGradScaler.apply(x, self.ep_size)
return x
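The forward above is essentially a batch of independent per-expert MLPs, one weight slice per expert. A toy, framework-free sketch of that pattern (plain PyTorch, not the MLPExperts API, names hypothetical) may make the shape handling clearer:

import torch
import torch.nn.functional as F

def grouped_expert_mlp(x, wi, wo):
    # x:  [num_experts, capacity, hidden]        tokens already dispatched per expert
    # wi: [num_experts, hidden, intermediate]
    # wo: [num_experts, intermediate, hidden]
    h = torch.bmm(x, wi)          # independent matmul for every expert
    h = F.gelu(h)
    return torch.bmm(h, wo)       # project back to hidden size

x = torch.randn(4, 8, 16)
wi = torch.randn(4, 16, 32)
wo = torch.randn(4, 32, 16)
assert grouped_expert_mlp(x, wi, wo).shape == (4, 8, 16)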

View File

@ -1,222 +1,108 @@
from functools import partial
from typing import Callable, Dict, List, Optional, Union
from typing import List, Optional
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import CrossEntropyLoss, Module
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed import ProcessGroup
# from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
from torch.nn import CrossEntropyLoss
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.models.mixtral.modeling_mixtral import (
MixtralDecoderLayer,
MixtralForCausalLM,
MixtralModel,
MixtralSparseMoeBlock,
MoeCausalLMOutputWithPast,
_prepare_4d_causal_attention_mask,
load_balancing_loss_func,
)
from transformers.utils import logging
from transformers.utils import is_flash_attn_2_available, logging
from colossalai.lazy import LazyInitContext
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
from colossalai.shardformer.shard import ShardConfig
from .mixtral_layer import EPMixtralSparseMoeBlock
__all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"]
from colossalai.shardformer.shard.utils import set_tensors_to_none
class MixtralPolicy(Policy):
def config_sanity_check(self):
pass
class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
def __init__(self, config):
self.moe_info = None
super().__init__(config)
def preprocess(self):
if self.shard_config.enable_tensor_parallelism:
# Resize embedding
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
def setup_ep(self, ep_group: ProcessGroup):
ep_group = ep_group
self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
assert self.num_experts % self.ep_size == 0
self.ep_group = ep_group
self.num_experts_per_ep = self.num_experts // self.ep_size
self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
set_tensors_to_none(self.experts, exclude=set(held_experts))
for p in self.experts.parameters():
p.ep_group = ep_group
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
@staticmethod
def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock":
LazyInitContext.materialize(module)
module.__class__ = EPMixtralSparseMoeBlock
# if "ep_group" in kwargs:
assert "ep_group" in kwargs, "You should pass ep_group in SubModuleReplacementDescription via shard_config!!"
module.setup_ep(kwargs["ep_group"])
return module
return self.model
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
policy = {}
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
raise NotImplementedError(
"Mixtral dosen't support sequence parallelism now, will ignore the sequence parallelism flag."
)
selected_experts = selected_experts.t().reshape(-1)
selected_experts_idx = selected_experts.argsort()
dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
output_split_sizes = torch.zeros_like(input_split_sizes)
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
if self.shard_config.enable_tensor_parallelism:
raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
# expert parallel
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="block_sparse_moe",
target_module=EPMixtralSparseMoeBlock,
)
],
policy=policy,
target_key=MixtralDecoderLayer,
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=FusedRMSNorm,
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=FusedRMSNorm,
),
],
policy=policy,
target_key=MixtralDecoderLayer,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="norm",
target_module=FusedRMSNorm,
),
policy=policy,
target_key=MixtralModel,
)
if self.shard_config.enable_flash_attention:
raise NotImplementedError("Flash attention has already been replaced in mixtral.")
return policy
def postprocess(self):
return self.model
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if self.pipeline_stage_manager:
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == "MixtralModel":
module = self.model
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
# compute expert output
output_states = MoeInGradScaler.apply(output_states, self.ep_size)
if output_states.size(0) > 0:
if self.num_experts_per_ep == 1:
# no need to split
expert = self.experts[self.expert_start_idx]
output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
output_states = expert.w2(output_states)
else:
module = self.model.model
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
stage_index = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=model_cls
output_states_splits = output_states.split(output_split_sizes.tolist())
output_states_list = []
for i, split_states in enumerate(output_states_splits):
if split_states.size(0) == 0:
continue
expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
split_states = expert.w2(split_states)
output_states_list.append(split_states)
output_states = torch.cat(output_states_list)
output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
recover_experts_idx = torch.empty_like(selected_experts_idx)
recover_experts_idx[selected_experts_idx] = torch.arange(
selected_experts_idx.size(0), device=selected_experts_idx.device
)
return
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None
if self.model.__class__.__name__ == "MixtralModel":
module = self.model
else:
module = self.model.model
stage_manager = self.pipeline_stage_manager
held_layers = []
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.norm)
return held_layers
class MixtralModelPolicy(MixtralPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
if self.pipeline_stage_manager:
# set None as default
self.set_pipeline_forward(
model_cls=MixtralModel,
new_forward=MixtralPipelineForwards.mixtral_model_forward,
policy=policy,
)
return policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
held_layers = super().get_held_layers()
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in llama model"""
return []
class MixtralForCausalLMPolicy(MixtralPolicy):
def module_policy(self):
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
# add a new item for causal lm
new_item = {
MixtralForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True),
)
]
)
}
policy.update(new_item)
if self.pipeline_stage_manager:
# set None as default
self.set_pipeline_forward(
model_cls=MixtralForCausalLM,
new_forward=MixtralPipelineForwards.mixtral_for_causal_lm_forward,
policy=policy,
)
return policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage():
held_layers.append(self.model.lm_head)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
llama_model = self.model.model
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
if (
id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
and self.pipeline_stage_manager.num_stages > 1
):
# tie weights
return [
{
0: llama_model.embed_tokens.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
}
]
return []
dispatch_states = dispatch_states[recover_experts_idx]
k_hidden_states = dispatch_states.chunk(self.top_k)
output_states = k_hidden_states[0] * routing_weights[:, 0, None]
for i in range(1, self.top_k):
output_states += k_hidden_states[i] * routing_weights[:, i, None]
output_states = output_states.reshape(batch_size, sequence_length, hidden_dim)
return output_states, router_logits
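The expert-parallel forward above exchanges tokens between ranks with an uneven all-to-all whose split sizes come from the routing decision. The snippet below is a hedged, single-process illustration of how those split sizes are derived (the distributed exchange itself needs an initialized process group and is omitted here); the sizes are arbitrary.

import torch

num_experts, ep_size, top_k, num_tokens = 8, 4, 2, 6
num_experts_per_ep = num_experts // ep_size
# Flattened expert choice per (k, token) pair, analogous to selected_experts above.
selected_experts = torch.randint(0, num_experts, (top_k * num_tokens,))
# How many routed tokens this rank holds for each expert ...
input_split_sizes = selected_experts.bincount(minlength=num_experts)
# ... aggregated per destination EP rank, which is the granularity at which
# all_to_all_uneven / dist.all_to_all_single move the hidden states.
input_split_list = input_split_sizes.view(ep_size, num_experts_per_ep).sum(dim=-1).tolist()
assert sum(input_split_list) == top_k * num_tokens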
class MixtralPipelineForwards:
@ -332,7 +218,7 @@ class MixtralPipelineForwards:
# embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage
if self._use_flash_attention_2:
if is_flash_attn_2_available():
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:

View File

@ -176,6 +176,7 @@ _POLICY_LIST = {
"transformers.models.falcon.modeling_falcon.FalconForQuestionAnswering": PolicyLocation(
file_name="falcon", class_name="FalconForQuestionAnsweringPolicy"
),
# mistral
"transformers.models.mistral.modeling_mistral.MistralModel": PolicyLocation(
file_name="mistral", class_name="MistralModelPolicy"
),
@ -185,6 +186,13 @@ _POLICY_LIST = {
"transformers.models.mistral.modeling_mistral.MistralForSequenceClassification": PolicyLocation(
file_name="mistral", class_name="MistralForSequenceClassificationPolicy"
),
# mixtral
"transformers.models.mixtral.modeling_mixtral.MixtralModel": PolicyLocation(
file_name="mixtral", class_name="MixtralModelPolicy"
),
"transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM": PolicyLocation(
file_name="mixtral", class_name="MixtralForCausalLMPolicy"
),
# Qwen2
"transformers.models.qwen2.modeling_qwen2.Qwen2Model": PolicyLocation(
file_name="qwen2", class_name="Qwen2ModelPolicy"
@ -195,7 +203,7 @@ _POLICY_LIST = {
"transformers.models.qwen2.modeling_qwen2.Qwen2ForSequenceClassification": PolicyLocation(
file_name="qwen2", class_name="Qwen2ForSequenceClassificationPolicy"
),
# Command-R
# command
"transformers.models.cohere.modeling_cohere.CohereModel": PolicyLocation(
file_name="command", class_name="CommandModelPolicy"
),

View File

@ -0,0 +1,210 @@
from functools import partial
from typing import Callable, Dict, List, Union
import torch.nn as nn
from torch import Tensor
from torch.nn import Module
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralForCausalLM, MixtralModel
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock, MixtralPipelineForwards
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"]
class MixtralPolicy(Policy):
def config_sanity_check(self):
pass
def preprocess(self):
if self.shard_config.enable_tensor_parallelism:
# Resize embedding
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
policy = {}
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
raise NotImplementedError(
"Mixtral dosen't support sequence parallelism now, will ignore the sequence parallelism flag."
)
if self.shard_config.enable_tensor_parallelism:
raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
if getattr(self.shard_config, "ep_group", None) is None:
raise ValueError("You must pass in ep_group via shard_config for expert parallel!")
# expert parallel
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="block_sparse_moe",
target_module=EPMixtralSparseMoeBlock,
kwargs={"ep_group": self.shard_config.ep_group},
)
],
policy=policy,
target_key=MixtralDecoderLayer,
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=FusedRMSNorm,
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=FusedRMSNorm,
),
],
policy=policy,
target_key=MixtralDecoderLayer,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="norm",
target_module=FusedRMSNorm,
),
policy=policy,
target_key=MixtralModel,
)
if self.shard_config.enable_flash_attention:
raise NotImplementedError("Flash attention has already been replaced in mixtral.")
return policy
def postprocess(self):
return self.model
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if self.pipeline_stage_manager:
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == "MixtralModel":
module = self.model
else:
module = self.model.model
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
stage_index = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=model_cls
)
return
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None
if self.model.__class__.__name__ == "MixtralModel":
module = self.model
else:
module = self.model.model
stage_manager = self.pipeline_stage_manager
held_layers = []
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.norm)
return held_layers
class MixtralModelPolicy(MixtralPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
if self.pipeline_stage_manager:
# set None as default
self.set_pipeline_forward(
model_cls=MixtralModel,
new_forward=MixtralPipelineForwards.mixtral_model_forward,
policy=policy,
)
return policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
held_layers = super().get_held_layers()
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in llama model"""
return []
class MixtralForCausalLMPolicy(MixtralPolicy):
def module_policy(self):
policy = super().module_policy()
# TODO: assign pg mesh from plugin to all modules
if self.shard_config.enable_tensor_parallelism:
# add a new item for causal lm
new_item = {
MixtralForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True),
)
]
)
}
policy.update(new_item)
if self.pipeline_stage_manager:
# set None as default
self.set_pipeline_forward(
model_cls=MixtralForCausalLM,
new_forward=MixtralPipelineForwards.mixtral_for_causal_lm_forward,
policy=policy,
)
return policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage():
held_layers.append(self.model.lm_head)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
llama_model = self.model.model
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
if (
id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
and self.pipeline_stage_manager.num_stages > 1
):
# tie weights
return [
{
0: llama_model.embed_tokens.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
}
]
return []

View File

@ -46,6 +46,7 @@ class ShardConfig:
make_vocab_size_divisible_by: int = 64
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
ep_group: Optional[ProcessGroup] = None
# pipeline_parallel_size: int
# data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']

View File

@ -17,10 +17,10 @@ def is_moe_tensor(tensor: torch.Tensor) -> bool:
Returns:
bool: Whether the given tensor is a moe tensor.
"""
return hasattr(tensor, "moe_info")
return hasattr(tensor, "ep_group")
def set_moe_tensor_info(tensor: torch.Tensor, moe_info: MoeParallelInfo) -> None:
def set_moe_tensor_ep_group(tensor: torch.Tensor, ep_group: ProcessGroup) -> None:
"""
Set moe info for the given tensor.
@ -29,7 +29,7 @@ def set_moe_tensor_info(tensor: torch.Tensor, moe_info: MoeParallelInfo) -> None
moe_info (dict): The moe info to be set.
"""
tensor.__setattr__("moe_info", moe_info)
tensor.__setattr__("ep_group", ep_group)
def get_moe_info(ep_size: int, dp_size: int, pp_size: int, ep_inside: bool) -> MoeParallelInfo:
@ -58,7 +58,7 @@ def get_ep_group(tensor: torch.Tensor) -> ProcessGroup:
Returns:
torch.distributed.ProcessGroup: The expert parallel group of the given tensor.
"""
return tensor.moe_info.ep_group
return tensor.ep_group
def get_ep_size(tensor: torch.Tensor) -> int:
@ -71,7 +71,8 @@ def get_ep_size(tensor: torch.Tensor) -> int:
Returns:
int: The expert parallel size of the given tensor.
"""
return tensor.moe_info.ep_size
assert getattr(tensor, "ep_group") is not None, "The tensor does not have expert parallel group."
return dist.get_world_size(tensor.ep_group)
def get_dp_size(tensor: torch.Tensor) -> int:

View File

@ -1,6 +1,5 @@
from .bucket_store import BucketStore
from .gradient_store import GradientStore
from .parameter_store import ParameterStore
from .tensor_bucket import TensorBucket
__all__ = ["GradientStore", "ParameterStore", "BucketStore", "TensorBucket"]
__all__ = ["GradientStore", "BucketStore", "TensorBucket"]

View File

@ -1,12 +1,11 @@
from typing import Dict, Optional
from typing import Dict
import torch
import torch.distributed as dist
from torch import Tensor
from torch._utils import _flatten_dense_tensors
from torch.distributed import ProcessGroup
from colossalai.accelerator import get_accelerator
from colossalai.accelerator.api import get_accelerator
from .base_store import BaseStore
@ -16,29 +15,11 @@ class BucketStore(BaseStore):
self,
torch_pg: ProcessGroup,
reduce_bucket_size: int,
overlap_communication: bool,
communication_dtype: Optional[torch.dtype] = None,
moe_extra_dp_process_group: ProcessGroup = None,
):
super().__init__(torch_pg)
self.reduce_bucket_size = reduce_bucket_size
# communication params
self._overlap_communication = overlap_communication
self._communication_dtype = communication_dtype
if self._overlap_communication:
self.comm_stream = get_accelerator().Stream()
self.zero_local_rank = dist.get_rank(group=self.torch_pg)
self.zero_world_size = dist.get_world_size(group=self.torch_pg)
# extra dp
# This group is used to sync moe params, dp_world_size = moe_duplicates * extra_dp_size.
# Non-moe params are synced by the global dp pg, moe params are synced by the extra dp pg.
# Moe param grads are split like non-moe param grads by the global dp pg, and grads are merged in step.
# And moe working and master params are split by the extra dp pg.
self.moe_extra_dp_pg = moe_extra_dp_process_group
if self.moe_extra_dp_pg is not None:
self.moe_extra_dp_pg_size = dist.get_world_size(group=self.moe_extra_dp_pg)
self.moe_extra_dp_pg_rank = dist.get_rank(group=self.moe_extra_dp_pg)
self.reset_all()
self.comm_stream = get_accelerator().Stream()
def reset_all(self) -> None:
# init

View File

@ -1,4 +1,4 @@
from typing import List
from typing import List, Optional
from torch import Tensor
@ -6,7 +6,7 @@ from .base_store import BaseStore
class GradientStore(BaseStore):
def __init__(self, *args, partition_grad: bool = False, require_grad_sync: bool = True):
def __init__(self, *args, partition_grad: bool = False):
super().__init__(*args)
"""
self._grads_of_params mapping the parameter and its gradient slices
@ -20,8 +20,6 @@ class GradientStore(BaseStore):
self._grads_of_params = dict()
# stage 2
self._partition_grads = partition_grad
# grad accumulation
self.require_grad_sync = require_grad_sync
self._working_index = 0 if partition_grad else self._local_rank
# for zero2, it's `param_id: [grad_local_rank]`
self.grad_to_param_mapping = dict()
@ -107,8 +105,7 @@ class GradientStore(BaseStore):
for group in self._grads_of_params.values():
if param_id in group.keys():
return group[param_id][self._working_index]
raise KeyError(f"Working gradient for param_id {param_id} not found.")
return None
def reset_grads_by_group_id(self, group_id: int):
self._grads_of_params[group_id] = dict()
@ -116,7 +113,7 @@ class GradientStore(BaseStore):
def reset_all_gradients(self):
self._grads_of_params = dict()
def get_param_id_for_grad(self, grad: Tensor) -> int:
def get_param_id_for_grad(self, grad: Tensor) -> Optional[int]:
"""Return the id of a parameter which the gradient slice belongs to
Args:
@ -126,4 +123,4 @@ class GradientStore(BaseStore):
int: the id of a parameter which the gradient slice belongs to
"""
return self.grad_to_param_mapping[id(grad)]
return self.grad_to_param_mapping.get(id(grad), None)

View File

@ -1,60 +0,0 @@
from typing import Dict
from torch import Tensor
from torch.distributed import ProcessGroup
from .base_store import BaseStore
class ParameterStore(BaseStore):
def __init__(self, torch_pg: ProcessGroup):
super().__init__(torch_pg)
# record the padding size of each param
self._padding_map = dict()
# mapping working param and master param
self.master_to_working_param = dict()
self.working_to_master_param = dict()
def record_param_padding_size(self, param: Tensor, padding_size: int):
"""Record the padding size of a param
Args:
param (Tensor): The parameter
padding_size (int): The padding size of the parameter
"""
self._padding_map[id(param)] = padding_size
def get_param_padding_size(self, param: Tensor) -> int:
"""Return the padding size of the parameter
Args:
param (Tensor): The parameter
Returns:
int: the padding size of the parameter
"""
return self._padding_map[id(param)]
def link_master_and_working_param(self, master_param: Tensor, working_param: Tensor):
"""Mapping master parameter and working parameter
Args:
master_param (Tensor): The parameter copy in optimizer
working_param (Tensor): The parameter of the model
"""
self.master_to_working_param[id(master_param)] = working_param
self.working_to_master_param[id(working_param)] = master_param
def get_padding_map(self) -> Dict[int, Tensor]:
"""Return the padding map
Returns:
Dict[int, Tensor]: The padding map
"""
return self._padding_map
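With ParameterStore removed, the working-to-master parameter mapping and the padding bookkeeping live on the low-level ZeRO optimizer itself (the checkpoint test further below now reads optimizer.working_to_master_param directly). As a reminder of the arithmetic involved, here is a standalone sketch, plain PyTorch with a hypothetical helper name, of how a working parameter is padded and sharded into per-rank master slices, mirroring what that test verifies:

import torch

def shard_working_param(working_param: torch.Tensor, world_size: int, rank: int):
    numel = working_param.numel()
    padding = (-numel) % world_size                        # pad so the flat view splits evenly
    padded = torch.nn.functional.pad(working_param.view(-1), (0, padding))
    master_shard = padded.chunk(world_size)[rank].clone()  # this rank's slice of the flat param
    return master_shard, padding

w = torch.randn(10)
shard, pad = shard_working_param(w, world_size=4, rank=1)
assert pad == 2 and shard.numel() == 3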

File diff suppressed because it is too large

View File

@ -176,7 +176,7 @@ def main():
use_ep_inside = False
plugin = MoeHybridParallelPlugin(
pp_size=1,
extra_dp_size=args.extra_dp_size,
ep_size=args.ep_size,
use_ep_inside=use_ep_inside,
**hybrid_dict,
)
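Note that the example now passes ep_size to MoeHybridParallelPlugin instead of the old extra_dp_size. As a hedged reminder of what that size means (plain arithmetic, not read from the plugin implementation), with pp_size = 1 and no tensor parallelism the ranks factor into an expert-parallel dimension and a MoE data-parallel dimension:

world_size = 8                          # total ranks when pp_size == 1 and tp_size == 1
ep_size = 2                             # ranks that jointly hold one copy of the experts
moe_dp_size = world_size // ep_size     # replicas of each expert shard kept in sync by ZeRO/DP
assert world_size == ep_size * moe_dp_size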

View File

@ -50,9 +50,9 @@ try:
except:
HAS_FLASH_ATTN = False
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
from colossalai.moe.layers import SparseMLP
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import get_activation, set_moe_args
from colossalai.shardformer.layer.moe import SparseMLP
if HAS_TRITON:
from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
@ -83,7 +83,7 @@ def set_openmoe_args(
load_balance_group_swap_factor: float = 0.4,
enable_kernel: bool = False,
enable_comm_overlap: bool = False,
enable_hierarchical_alltoall: bool = False,
enable_hierarchical_alltoall: bool = True,
) -> None:
"""
MoE related arguments.
@ -465,7 +465,7 @@ class OpenMoeDecoderLayer(nn.Module):
load_balance_beam_width=config.load_balance_beam_width,
load_balance_group_swap_factor=config.load_balance_group_swap_factor,
enable_kernel=config.enable_kernel,
enable_comm_overlap=config.enable_comm_overlap,
enable_hierarchical_comm=config.enable_hierarchical_alltoall,
)
self.pre_extra_mlp_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.extra_mlp = OpenMoeMLP(config)
@ -903,7 +903,7 @@ class OpenMoeForCausalLM(OpenMoePreTrainedModel):
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
# reset moe loss
MOE_MANAGER.reset_loss()
MOE_MANAGER.reset_loss() # TODO: remove
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -1027,7 +1027,7 @@ class OpenMoeForCausalLM(OpenMoePreTrainedModel):
def _calculate_router_loss(self, aux_loss: list = None, z_loss: list = None):
if aux_loss is None or z_loss is None:
aux_loss, z_loss = MOE_MANAGER.get_loss()
aux_loss, z_loss = MOE_MANAGER.get_loss() # TODO: remove
assert len(aux_loss) == len(z_loss) == self.config.num_hidden_layers // self.config.moe_layer_interval
aux_loss = self.config.router_aux_loss_factor * sum(aux_loss) / len(aux_loss)
z_loss = self.config.router_z_loss_factor * sum(z_loss) / len(z_loss)

View File

@ -172,6 +172,7 @@ class OpenMoeForCausalLMPolicy(OpenMoePolicy):
if self.shard_config.enable_tensor_parallelism:
# add a new item for casual lm
# TODO: recursively assign ep group for all modules
new_item = {
OpenMoeForCausalLM: ModulePolicyDescription(
sub_module_replacement=[

View File

@ -1,37 +1,37 @@
pip install -r requirements.txt
# pip install -r requirements.txt
# inference
python infer.py --model "test"
# python infer.py --model "test"
# train
torchrun --standalone --nproc_per_node 4 train.py \
--num_epoch 1 \
--model_name "test" \
--plugin "ep" \
--batch_size 1
# torchrun --standalone --nproc_per_node 4 train.py \
# --num_epoch 1 \
# --model_name "test" \
# --plugin "ep" \
# --batch_size 1
torchrun --standalone --nproc_per_node 4 train.py \
--num_epoch 1 \
--model_name "test" \
--plugin "ep_zero" \
--batch_size 1 \
--zero_stage 1 \
--extra_dp_size 2 \
# torchrun --standalone --nproc_per_node 4 train.py \
# --num_epoch 1 \
# --model_name "test" \
# --plugin "ep_zero" \
# --batch_size 1 \
# --zero_stage 1 \
# --extra_dp_size 2 \
torchrun --standalone --nproc_per_node 4 train.py \
--num_epoch 1 \
--model_name "test" \
--plugin "ep_zero" \
--batch_size 1 \
--zero_stage 2 \
--extra_dp_size 2 \
# torchrun --standalone --nproc_per_node 4 train.py \
# --num_epoch 1 \
# --model_name "test" \
# --plugin "ep_zero" \
# --batch_size 1 \
# --zero_stage 2 \
# --extra_dp_size 2 \
torchrun --standalone --nproc_per_node 4 train.py \
--model_name "test" \
--plugin "hybrid" \
--num_epoch 1 \
--pp_size 2 \
--dp_size 1 \
--ep_size 2 \
--zero_stage 1 \
--batch_size 1
# torchrun --standalone --nproc_per_node 4 train.py \
# --model_name "test" \
# --plugin "hybrid" \
# --num_epoch 1 \
# --pp_size 2 \
# --dp_size 1 \
# --ep_size 2 \
# --zero_stage 1 \
# --batch_size 1

View File

@ -19,10 +19,9 @@ from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.moe.layers import apply_load_balance
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import skip_init
from colossalai.nn.optimizer import HybridAdam
from colossalai.shardformer.layer.moe import apply_load_balance
def move_to_cuda(batch, device):
@ -221,48 +220,49 @@ def main():
"precision": args.precision,
"zero_stage": args.zero_stage,
}
mgr_dict = {}
if args.plugin == "ep":
dp_size = dist.get_world_size()
plugin = MoeHybridParallelPlugin(
pp_size=1,
ep_size=args.ep_size,
**hybrid_dict,
)
MOE_MANAGER.setup(
parallel="EP",
max_ep_size=dp_size,
**mgr_dict,
)
# MOE_MANAGER.setup(
# parallel="EP",
# max_ep_size=dp_size,
# **mgr_dict,
# )
elif args.plugin == "ep_zero":
dp_size = dist.get_world_size()
use_ep_inside = False
plugin = MoeHybridParallelPlugin(
pp_size=1,
extra_dp_size=args.extra_dp_size,
ep_size=dp_size // args.ep_size,
use_ep_inside=use_ep_inside,
**hybrid_dict,
)
MOE_MANAGER.setup(
parallel="EP",
max_ep_size=dp_size // args.extra_dp_size,
use_ep_inside=use_ep_inside,
**mgr_dict,
)
# MOE_MANAGER.setup(
# parallel="EP",
# max_ep_size=dp_size // args.extra_dp_size,
# use_ep_inside=use_ep_inside,
# **mgr_dict,
# )
elif args.plugin == "hybrid":
dp_size = dist.get_world_size() // args.pp_size
plugin = MoeHybridParallelPlugin(
pp_size=args.pp_size,
ep_size=args.ep_size,
microbatch_size=args.microbatch_size,
**hybrid_dict,
)
MOE_MANAGER.setup(
parallel="EP",
mode="fixed",
fixed_dp_size=args.dp_size,
fixed_ep_size=args.ep_size,
fixed_pp_size=args.pp_size,
**mgr_dict,
)
# MOE_MANAGER.setup(
# parallel="EP",
# mode="fixed",
# fixed_dp_size=args.dp_size,
# fixed_ep_size=args.ep_size,
# fixed_pp_size=args.pp_size,
# **mgr_dict,
# )
else:
raise ValueError(f"Invalid plugin {args.plugin}")
coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")

View File

@ -59,10 +59,10 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
# check master weight
assert isinstance(new_optimizer, LowLevelZeroOptimizer)
working_param_id_set = set(id(p) for p in new_model.parameters())
for p_id, master_param in new_optimizer._param_store.working_to_master_param.items():
for p_id, master_param in new_optimizer.working_to_master_param.items():
assert p_id in working_param_id_set
working_param = new_optimizer._param_store.master_to_working_param[id(master_param)]
padding = new_optimizer._param_store.get_param_padding_size(working_param)
working_param = new_optimizer.master_to_working_param[id(master_param)]
padding = new_optimizer.get_param_padding_size(working_param)
padded_param = torch.nn.functional.pad(working_param.data.view(-1), (0, padding))
working_shard = padded_param.chunk(dist.get_world_size())[dist.get_rank()]
assert torch.equal(
@ -115,10 +115,10 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo
# check master weight
assert isinstance(new_optimizer, LowLevelZeroOptimizer)
working_param_id_set = set(id(p) for p in new_model.parameters())
for p_id, master_param in new_optimizer._param_store.working_to_master_param.items():
for p_id, master_param in new_optimizer.working_to_master_param.items():
assert p_id in working_param_id_set
working_param = new_optimizer._param_store.master_to_working_param[id(master_param)]
padding = new_optimizer._param_store.get_param_padding_size(working_param)
working_param = new_optimizer.master_to_working_param[id(master_param)]
padding = new_optimizer.get_param_padding_size(working_param)
padded_param = torch.nn.functional.pad(working_param.data.view(-1), (0, padding))
working_shard = padded_param.chunk(dist.get_world_size())[dist.get_rank()]
assert torch.equal(

View File

@ -1,48 +1,37 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.testing import assert_close
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler
from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce
from colossalai.legacy.registry import GRADIENT_HANDLER
from colossalai.moe import SparseMLP
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import get_moe_epsize_param_dict
from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size
# from colossalai.shardformer.layer.moe import SparseMLP
from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size, set_moe_tensor_ep_group
def delete_moe_info(model):
for _, param in model.named_parameters():
if hasattr(param, "moe_info"):
delattr(param, "moe_info")
if hasattr(param, "ep_group"):
delattr(param, "ep_group")
class MoeModel(nn.Module):
def __init__(self, enable_load_balance: bool = False):
class TestSubModule(nn.Module):
def __init__(self):
def __init__(self, ep_group: ProcessGroup = None):
super().__init__()
self.moe = SparseMLP(
num_experts=8, hidden_size=16, intermediate_size=32, enable_load_balance=enable_load_balance
)
self.proj = nn.Linear(16, 4)
self.test_embed = nn.Linear(4, 16, bias=False)
self.w1 = torch.nn.Parameter(torch.randn(16, 8))
if ep_group:
set_moe_tensor_ep_group(self.w1, ep_group)
def forward(self, x):
x = self.moe(x)
x = self.proj(x)
return x
super().__init__()
self.test_embed = nn.Linear(4, 16)
self.test_transform = TestSubModule()
def forward(self, x):
MOE_MANAGER.reset_loss()
x = self.test_embed(x)
x = self.test_transform(x)
x = torch.matmul(x, self.w1)
return x
@ -116,7 +105,7 @@ def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False)
return y
def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
def sync_local_from_ep(local_model, ep_model, assert_grad_flag: bool = False) -> None:
"""Sync the parameters of tp model from ep model
Args:
@ -126,7 +115,6 @@ def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_
for (local_name, local_param), (ep_name, ep_param) in zip(
local_model.named_parameters(), ep_model.named_parameters()
):
assert local_name in ep_name, print(f"{local_name} != {ep_name}")
if "experts" not in local_name:
if assert_grad_flag:
assert torch.allclose(local_param, ep_param), f"local_param: {local_param}, ep_param: {ep_param}"

View File

@ -5,8 +5,9 @@ import torch.nn as nn
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.moe import SparseMLP
from colossalai.moe.manager import MOE_MANAGER
# from colossalai.shardformer.layer.moe.layers import SparseMLP
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
from tests.test_moe.moe_utils import MoeGradientHandler
@ -69,6 +70,7 @@ def run_test(rank, world_size, port):
# MoE grad handler test passed
@pytest.mark.skip(reason="moe need to be refactored")
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_grad_handler():

View File

@ -1,98 +1,96 @@
import os
import pytest
import torch
import torch.distributed as dist
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.moe import SparseMLP
from colossalai.moe.manager import MOE_MANAGER
from colossalai.testing import rerun_if_address_is_in_use, spawn
BATCH_SIZE = 4
# from colossalai.moe import SparseMLP
from colossalai.moe._operation import MoeCombine, MoeDispatch, moe_cumsum
NUM_EXPERTS = 4
BATCH_SIZE = 4
SEQ_LEN = 4
MOE_TENSOR_PATH = os.getenv("MOE_TENSOR_PATH")
def check_equal(tensor_a, tensor_b, atol=1e-06):
assert torch.allclose(tensor_a, tensor_b, rtol=0, atol=atol) is True
def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32, topk=1):
# Here we do not need TF32, since it brings absolute error on results
torch.backends.cuda.matmul.allow_tf32 = False
def run_moe_cumsum():
test_mask = torch.tensor(
[
[0, 1, 0, 0],
[1, 0, 0, 0],
[0, 1, 0, 0],
[1, 0, 0, 0],
],
dtype=torch.int32,
).to("cuda")
out_no_kernel = moe_cumsum(test_mask, use_kernel=False)
out_kernel = moe_cumsum(test_mask, use_kernel=True)
print(out_no_kernel.dtype, out_kernel.dtype)
check_equal(out_no_kernel.to(torch.int32), out_kernel)
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
local_rank = dist.get_rank()
MOE_MANAGER.setup(parallel="EP") # MOE environment initialization
MOE_MANAGER.reset_loss()
torch.manual_seed(rs + local_rank) # set each process has different random seed
# get randomized data
def run_moe_dispatch_combine_fwd_bwd(data_type=torch.float32, hidden_size=128, num_experts=4):
tokens = torch.randn(
BATCH_SIZE, hidden_size, dtype=data_type, device=get_accelerator().get_current_device(), requires_grad=True
)
layer = SparseMLP(
hidden_size=hidden_size,
intermediate_size=hidden_size * 2,
num_experts=NUM_EXPERTS,
router_top_k=topk,
router_capacity_factor_train=1.0,
)
layer = layer.to(get_accelerator().get_current_device())
if data_type == torch.float16:
layer = layer.half()
# use kernel
route_result_list_kernel = torch.load(f"{MOE_TENSOR_PATH}/True_4_{data_type}.pt")
# dispatch
dispatch_data_kernel = MoeDispatch.apply(tokens, *route_result_list_kernel[1:])
dispatch_data_kernel = dispatch_data_kernel.reshape(num_experts, -1, hidden_size)
# combine
expert_output = dispatch_data_kernel.reshape(-1, hidden_size)
ans_kernel = MoeCombine.apply(expert_output, *route_result_list_kernel)
# use matrix multiplication instead of COL_MOE_KERNEL in MOE dispatch and combine
layer.enable_kernel = False
old_out = layer(tokens)
ech = old_out.shape
grad = torch.randn(ech, device=get_accelerator().get_current_device())
old_out.backward(grad) # get gradient
# no kernel
route_result_list_no_kernel = torch.load(f"{MOE_TENSOR_PATH}/False_2_{data_type}.pt")
# dispatch
sec_mask_f = route_result_list_no_kernel[1].type_as(tokens)
dispatch_data_no_kernel = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
# combine
combine_weights = route_result_list_no_kernel[0].type_as(tokens)
combine_weights = combine_weights.view(combine_weights.shape[0], -1)
expert_output = expert_output.view(-1, expert_output.shape[-1])
ans_no_kernel = torch.matmul(combine_weights, expert_output)
# save all results
o_tk_grad = tokens.grad.data.clone()
o_gt_grad = layer.gate_weight.grad.data.clone()
# check fwd
if data_type == torch.float32:
check_equal(dispatch_data_kernel.reshape(dispatch_data_no_kernel.shape), dispatch_data_no_kernel)
else:
check_equal(dispatch_data_kernel.reshape(dispatch_data_no_kernel.shape), dispatch_data_no_kernel, 1e-2)
# reset all gradients
if data_type == torch.float32:
check_equal(ans_kernel, ans_no_kernel)
else:
check_equal(ans_kernel, ans_no_kernel, 1e-2)
# check bwd
out_shape = ans_kernel.shape
grad = torch.randn(out_shape, device=get_accelerator().get_current_device())
ans_kernel.backward(grad, retain_graph=True)
grad_kernel = tokens.grad.data.clone()
tokens.grad.zero_()
layer.gate_weight.grad.zero_()
layer.enable_kernel = True
new_out = layer(tokens) # get outputs through colossal kernel
ans_no_kernel.backward(grad) # get gradient
grad_no_kernel = tokens.grad.data.clone()
tokens.grad.zero_()
if data_type == torch.float32:
check_equal(old_out, new_out)
check_equal(grad_no_kernel, grad_kernel)
else:
check_equal(old_out, new_out, 1e-2)
# forward function passed
new_out.backward(grad) # get new type gradient
n_tk_grad = tokens.grad.data.clone()
n_gt_grad = layer.gate_weight.grad.data.clone()
if data_type == torch.float32:
check_equal(o_tk_grad, n_tk_grad)
else:
check_equal(o_tk_grad, o_tk_grad, 1e-2)
# tokens gradient is correct
if data_type == torch.float32:
check_equal(o_gt_grad, n_gt_grad, 5e-05)
else:
check_equal(o_gt_grad, n_gt_grad, 2e-01)
# bias gradient is correct
check_equal(grad_no_kernel, grad_kernel, 1e-2)
@pytest.mark.dist
@pytest.mark.parametrize("rs", [131])
@pytest.mark.parametrize("hidden_size", [32, 144])
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
@pytest.mark.parametrize("topk", [1, 2])
@rerun_if_address_is_in_use()
def test_moe_kernel(rs, hidden_size, data_type, topk):
spawn(run_routing, 4, rs=rs, hidden_size=hidden_size, data_type=data_type, topk=topk)
if __name__ == "__main__":
test_moe_kernel(2, 256, torch.float16, 2)
def test_moe_kernel(data_type):
torch.manual_seed(1024)
run_moe_cumsum()
run_moe_dispatch_combine_fwd_bwd(data_type=data_type)

View File

@ -3,13 +3,13 @@ from copy import deepcopy
import pytest
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_layer import EPMixtralSparseMoeBlock
from torch.testing import assert_close
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import colossalai
from colossalai.moe import MOE_MANAGER
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock
from colossalai.testing.utils import spawn
tokens, n_experts = 7, 4
@ -19,8 +19,11 @@ top_k = 2
def check_mixtral_moe_layer():
torch.cuda.set_device(dist.get_rank())
MOE_MANAGER.setup(
parallel="EP", mode="fixed", fixed_dp_size=1, fixed_ep_size=dist.get_world_size(), fixed_pp_size=1
plugin = MoeHybridParallelPlugin(
precision="bf16",
tp_size=1,
pp_size=1,
ep_size=dist.get_world_size(),
)
config = MixtralConfig(
hidden_size=hidden_size,
@ -33,7 +36,7 @@ def check_mixtral_moe_layer():
x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
orig_output, orig_logits = orig_model(x)
model = deepcopy(orig_model)
model = EPMixtralSparseMoeBlock.from_native_module(model)
model = EPMixtralSparseMoeBlock.from_native_module(model, ep_group=plugin.ep_group)
ep_output, ep_logits = model(x)
assert_close(orig_logits, ep_logits)
assert_close(orig_output, ep_output)

View File

@ -1,201 +1,176 @@
import importlib
import os
import shutil
import sys
import tempfile
from contextlib import nullcontext
from copy import deepcopy
import pytest
import torch
import torch.distributed as dist
from transformers.models.llama import LlamaConfig
from torch.optim import Adam
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing.utils import spawn
sys.path.append(
os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
"examples/language/openmoe",
)
)
OpenMoeForCausalLM = importlib.import_module("model.modeling_openmoe").OpenMoeForCausalLM
set_openmoe_args = importlib.import_module("model.modeling_openmoe").set_openmoe_args
OpenMoeForCausalLMPolicy = importlib.import_module("model.openmoe_policy").OpenMoeForCausalLMPolicy
tokens, n_experts = 7, 4
hidden_size = 8
top_k = 2
def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20):
input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_accelerator().get_current_device())
attention_mask = torch.ones_like(input_ids)
def check_model_equal(model1, model2):
assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
if not torch.equal(p1.half(), p2.half()):
print(f"Model parameter {name} is not equal. is_moe_tensor: {is_moe_tensor(p1)}")
raise AssertionError(f"Model parameter {name} is not equal")
def get_optimizer_snapshot(optim):
state = {id(k): deepcopy(v) for k, v in optim.state.items()}
param_groups = []
for group in optim.param_groups:
params = [id(p) for p in group["params"]]
new_group = {"params": params}
for k, v in group.items():
if k != "params":
new_group[k] = v
param_groups.append(new_group)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": input_ids,
"state": state,
"param_groups": param_groups,
}
def run_fwd_bwd(
model, data, label, criterion, optimizer, enable_autocast=False, pipeline=False, booster=None, plugin=None
):
model.train()
if pipeline:
train_dataloader_iter = DummyDataloader(data_gen_fn, length=1)
is_pp_last_stage = booster.plugin.stage_manager.is_last_stage()
y = booster.execute_pipeline(
train_dataloader_iter,
model,
lambda x, y: x.loss,
optimizer,
return_loss=True,
)
# Backward and optimize
if is_pp_last_stage:
loss = y["loss"]
else:
if criterion:
y = model(data).logits
loss = criterion(y)
else:
loss = model(data, label)
loss = loss.float()
def check_optimizer_snapshot_equal(snapshot1, snapshot2, param2name, moe_dp_group=None):
assert len(snapshot1["param_groups"]) == len(snapshot2["param_groups"])
for group1, group2 in zip(snapshot1["param_groups"], snapshot2["param_groups"]):
assert set(group1.keys()) == set(group2.keys())
for k in group1.keys():
assert group1[k] == group2[k]
# check state
assert set(snapshot1["state"].keys()) == set(
snapshot2["state"].keys()
), f"{snapshot1['state'].keys()}, {snapshot2['state'].keys()}"
if optimizer is not None:
optimizer.backward(loss)
passed = True
count = 0
for pid in snapshot1["state"].keys():
state1, state2 = snapshot1["state"][pid], snapshot2["state"][pid]
assert set(state1.keys()) == set(state2.keys())
bug = False
for k in state1.keys():
if isinstance(state1[k], torch.Tensor):
if not torch.equal(state1[k], state2[k]):
bug = True
count += 1
else:
loss.backward()
return y
assert state1[k] == state2[k]
if bug:
passed = False
if not passed:
raise AssertionError(f"A total of {count} optim states are not equal")
def get_config():
config = LlamaConfig(
vocab_size=300,
hidden_size=16,
intermediate_size=32,
num_hidden_layers=2,
def check_mixtral_moe_layer():
context = tempfile.TemporaryDirectory() if dist.get_rank() == 0 else nullcontext()
with context as f:
torch.cuda.set_device(dist.get_rank())
if dist.get_rank() == 0:
broadcast_objects = [f] # any picklable object
else:
broadcast_objects = [None]
dist.broadcast_object_list(broadcast_objects, src=0)
config = MixtralConfig(
hidden_size=hidden_size,
intermediate_size=hidden_size * 2,
num_local_experts=n_experts,
num_experts_per_tok=top_k,
num_attention_heads=2,
head_dim=4,
dropout_rate=0.0,
hidden_act="swiglu",
num_key_value_heads=2,
)
set_openmoe_args(config, num_experts=8, moe_layer_interval=1)
return config
def get_model(parallel):
config = get_config()
model = OpenMoeForCausalLM(config)
optim = torch.optim.Adam(model.parameters())
if parallel == None:
torch.manual_seed(0)
input_ids = torch.randint(0, 100, (2, tokens)).cuda()
orig_model = MixtralForCausalLM(config).cuda()
model = deepcopy(orig_model)
optimizer = Adam(model.parameters(), lr=1e-3)
plugin = MoeHybridParallelPlugin(
precision="bf16",
tp_size=1,
pp_size=1,
ep_size=1,
zero_stage=2,
custom_policy=OpenMoeForCausalLMPolicy(),
)
elif parallel == "ep":
plugin = MoeHybridParallelPlugin(
precision="bf16",
tp_size=1,
pp_size=1,
ep_size=dist.get_world_size(),
zero_stage=2,
custom_policy=OpenMoeForCausalLMPolicy(),
)
elif parallel == "ep_zero":
plugin = MoeHybridParallelPlugin(
precision="bf16",
tp_size=1,
pp_size=1,
ep_size=2,
zero_stage=2,
extra_dp_size=2,
custom_policy=OpenMoeForCausalLMPolicy(),
)
elif parallel == "hybrid":
plugin = MoeHybridParallelPlugin(
precision="bf16",
tp_size=1,
pp_size=2,
ep_size=2,
zero_stage=1,
tp_size=1,
checkpoint_io=MoECheckpointIO,
microbatch_size=1,
custom_policy=OpenMoeForCausalLMPolicy(),
zero_stage=1,
)
booster = Booster(plugin=plugin)
model, optim, _, _, _ = booster.boost(model=model, optimizer=optim)
return model, booster, optim
def _test_moe_checkpoint(rank, parallel):
model1, booster1, optim1 = get_model(parallel)
model2, booster2, optim2 = get_model(parallel)
model3, booster3, optim3 = get_model(parallel)
# param ckpt
# shard
booster1.save_model(model1, "./tmp_ckpt1", shard=True, size_per_shard=1)
booster2.load_model(model2, "./tmp_ckpt1")
# unshard
booster1.save_model(model1, "./tmp_ckpt1.pth")
booster3.load_model(model3, "./tmp_ckpt1.pth")
# check
check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
check_state_dict_equal(model1.state_dict(), model3.state_dict(), False)
# optim ckpt
criterion = lambda x: x.mean()
data = torch.randint(0, 4, (2, 4)).cuda()
label = torch.randint(0, 4, (2,)).cuda()
if parallel == "hybrid":
kwargs = {"pipeline": True, "booster": booster1, "plugin": booster1.plugin}
else:
kwargs = {}
run_fwd_bwd(model1, data, label, criterion, optim1, **kwargs)
optim1.step()
optim1.zero_grad()
# shard
booster1.save_optimizer(optim1, "./tmp_ckpt2", shard=True, size_per_shard=1)
dist.barrier()
booster2.load_optimizer(optim2, "./tmp_ckpt2")
# unshard
booster1.save_optimizer(optim1, "./tmp_ckpt2.pth")
booster3.load_optimizer(optim3, "./tmp_ckpt2.pth")
# check
check_state_dict_equal(optim1.optim.state_dict(), optim2.optim.state_dict(), False)
check_state_dict_equal(optim1.optim.state_dict(), optim3.optim.state_dict(), False)
if dist.get_rank() == 0:
shutil.rmtree("./tmp_ckpt1")
shutil.rmtree("./tmp_ckpt2")
os.remove("./tmp_ckpt1.pth")
os.remove("./tmp_ckpt2.pth")
def _run_dist(rank, world_size, port, parallel):
colossalai.launch(
config=dict(),
rank=rank,
world_size=world_size,
host="localhost",
port=port,
backend="nccl",
model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer)
# initialize grads
data_iter = iter(
[{"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "labels": input_ids.clone()}]
)
_test_moe_checkpoint(rank, parallel)
booster.execute_pipeline(
data_iter,
model,
lambda outputs, inputs: outputs.loss,
optimizer,
)
tmpdirname = broadcast_objects[0]
model_dir = os.path.join(tmpdirname, "mixtral_model")
hf_model_dir = os.path.join(tmpdirname, "mixtral_hf_model")
optim_dir = os.path.join(tmpdirname, "mixtral_optim")
booster.save_model(model, model_dir, shard=True)
dist.barrier()
if dist.get_rank() == 0:
saved_model = MixtralForCausalLM.from_pretrained(model_dir).cuda()
check_model_equal(orig_model, saved_model)
# check_model_equal(model, saved_model)
saved_model.save_pretrained(hf_model_dir)
dist.barrier()
# check load model
new_model = MixtralForCausalLM(config).cuda()
new_optimizer = Adam(new_model.parameters(), lr=1e-3)
new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
booster.load_model(new_model, hf_model_dir)
check_model_equal(model, new_model)
# check save optimizer
optimizer.step()
for group in optimizer.param_groups:
group["lr"] = 0.1
snapshot = get_optimizer_snapshot(optimizer.unwrap())
booster.save_optimizer(optimizer, optim_dir, shard=True)
dist.barrier()
# reset optimizer state
for state in optimizer.unwrap().state.values():
for v in state.values():
if isinstance(v, torch.Tensor):
v.zero_()
booster.load_optimizer(optimizer, optim_dir)
loaded_snapshot = get_optimizer_snapshot(optimizer.unwrap())
check_optimizer_snapshot_equal(snapshot, loaded_snapshot, None, model)
# Ensure rank 0 waits for all other ranks to finish
dist.barrier()
@pytest.mark.skip(reason="This is tested in ColossalMOE")
@pytest.mark.dist
def run_dist(rank: int, world_size: int, port: int):
colossalai.launch(rank, world_size, "localhost", port)
check_mixtral_moe_layer()
# Test EP + ZeRO + PP
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("parallel", [None, "ep", "ep_zero", "hybrid"])
@rerun_if_address_is_in_use()
def test_moe_checkpoint(world_size, parallel):
spawn(_run_dist, world_size, parallel=parallel)
def test_mixtral_moe_layer(world_size: int):
spawn(run_dist, world_size)
if __name__ == "__main__":
test_moe_checkpoint(world_size=4, parallel="hybrid")
test_mixtral_moe_layer(4)
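The rewritten test above saves and reloads both the sharded model and the sharded optimizer through the Booster. A condensed sketch of that round-trip follows; it assumes a booster-boosted Mixtral model and optimizer plus a temporary directory agreed on across ranks, and the directory names are illustrative only.

```python
# Sketch of the sharded MoE checkpoint round-trip exercised above; assumes a
# booster-boosted model/optimizer and a tmp directory shared by all ranks.
import os
import torch.distributed as dist


def checkpoint_round_trip(booster, model, optimizer, tmpdir: str):
    model_dir = os.path.join(tmpdir, "mixtral_model")  # illustrative paths
    optim_dir = os.path.join(tmpdir, "mixtral_optim")

    # Save sharded states; every rank writes its own shards.
    booster.save_model(model, model_dir, shard=True)
    booster.save_optimizer(optimizer, optim_dir, shard=True)
    dist.barrier()

    # Reload into the boosted objects; the checkpoint IO reassembles the
    # expert-parallel and ZeRO-partitioned states.
    booster.load_model(model, model_dir)
    booster.load_optimizer(optimizer, optim_dir)
    dist.barrier()
```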

View File

@@ -8,15 +8,16 @@ import torch.distributed as dist
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.moe import SparseMLP
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import sync_moe_model_param
# from colossalai.shardformer.layer import SparseMLP
from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
from tests.test_moe.moe_utils import MoeGradientHandler
def sync_tp_from_local(tp_model: SparseMLP, local_model: SparseMLP, assert_grad_flag: bool = False) -> None:
def sync_tp_from_local(tp_model, local_model, assert_grad_flag: bool = False) -> None:
"""Sync the parameters of tp model from local model
Args:
@@ -48,7 +49,7 @@ def sync_tp_from_local(tp_model: SparseMLP, local_model: SparseMLP, assert_grad_
tp_param.data.copy_(local_param[tuple(tp_slice)].data)
def sync_tp_from_ep(tp_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
def sync_tp_from_ep(tp_model, ep_model, assert_grad_flag: bool = False) -> None:
"""Sync the parameters of tp model from ep model
Args:
@@ -90,7 +91,7 @@ def sync_tp_from_ep(tp_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag:
tp_param.data.copy_(new_tp_param.data)
def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
def sync_local_from_ep(local_model, ep_model, assert_grad_flag: bool = False) -> None:
"""Sync the parameters of tp model from ep model
Args:
@@ -216,6 +217,7 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
)
@pytest.mark.skip(reason="moe need to be refactored")
@pytest.mark.dist
@pytest.mark.parametrize("num_experts", [4, 64])
@pytest.mark.parametrize("batch_size", [16])

View File

@@ -4,9 +4,10 @@ import torch.nn as nn
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.moe.experts import MLPExperts
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import sync_moe_model_param
# from colossalai.shardformer.layer.moe import MLPExperts
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
HIDDEN_SIZE = 4
@@ -69,6 +70,7 @@ def _run_test(rank, world_size, port, expert_parallel):
run_moe_init(expert_parallel)
@pytest.mark.skip(reason="moe need to be refactored")
@pytest.mark.dist
@pytest.mark.parametrize("expert_parallel", ["EP", "TP"])
@rerun_if_address_is_in_use()

View File

@@ -86,6 +86,7 @@ def run_dist(rank, world_size, port):
run_zero_optim_test(rank, world_size, stage=2)
@pytest.mark.skip(reason="moe need to be refactored")
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()

View File

@@ -6,8 +6,9 @@ import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.moe.layers import apply_load_balance
from colossalai.moe.manager import MOE_MANAGER
# from colossalai.shardformer.layer.moe import apply_load_balance
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel
@@ -176,6 +177,7 @@ def run_dist(rank, world_size, port):
run_hybrid_zero_optim_test(rank, world_size, stage=2)
@pytest.mark.skip(reason="moe need to be refactored")
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()

View File

@@ -1,47 +0,0 @@
import pytest
import torch
from colossalai.moe.routers import MoeRouter, Top1Router, Top2Router, TopKRouter
@pytest.mark.parametrize(
["router", "num_groups"],
[
(Top1Router(), 1),
(Top2Router(), 1),
# (TopKRouter(num_selected_experts=3), 4),
],
)
@pytest.mark.parametrize(
["batch_size", "seq_len", "num_experts"],
[
(4, 5, 8),
(3, 4, 4),
],
)
def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_experts: int, num_groups: int):
x = torch.randn((batch_size * seq_len, num_experts)).cuda()
if num_groups > 1:
x = x.expand(num_groups, -1, -1)
router.train()
if isinstance(router, TopKRouter):
combine_array, dispatch_mask = router(x, expert_capacity=2)
else:
combine_array, dispatch_mask = router(x)[1:3]
assert combine_array.shape[:-1] == x.shape
assert dispatch_mask.shape[:-1] == x.shape
assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)
router.eval()
if isinstance(router, TopKRouter):
combine_array, dispatch_mask = router(x, expert_capacity=2)
else:
combine_array, dispatch_mask = router(x)[1:3]
assert combine_array.shape[:-1] == x.shape
assert dispatch_mask.shape[:-1] == x.shape
assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)
if __name__ == "__main__":
test_router_forward(Top2Router(), 4, 4, 4, 1)

View File

@@ -1,78 +0,0 @@
import pytest
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.moe.manager import MOE_MANAGER
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import MoeModel, delete_moe_info, run_fwd_bwd, sync_local_from_ep
def run_zero_test(local_rank, stage=1):
criterion = torch.nn.CrossEntropyLoss()
MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel="EP")
moe_model = MoeModel().bfloat16()
moe_optimizer = torch.optim.Adam(moe_model.parameters())
moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
moe_booster = Booster(plugin=moe_plugin)
moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer)
MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel=None)
zero_model = MoeModel().bfloat16()
delete_moe_info(zero_model)
zero_optimizer = torch.optim.Adam(zero_model.parameters())
zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
zero_booster = Booster(plugin=zero_plugin)
zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer)
sync_local_from_ep(zero_model, moe_model)
data = torch.randn(16, 4).bfloat16().cuda()
label = torch.randint(0, 4, (16,)).cuda()
zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer)
assert torch.allclose(zero_out, moe_out)
for (moe_name, moe_param), (zero_name, zero_param) in zip(
moe_model.module.named_parameters(), zero_model.module.named_parameters()
):
assert moe_name == zero_name
moe_grad_list = moe_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(moe_param))
zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param))
if hasattr(moe_param, "moe_info"):
assert len(moe_grad_list) == 0
if stage == 1:
zero_grad = zero_grad_list[local_rank].view(moe_param.grad.shape)
else:
zero_grad = zero_grad_list[0].view(moe_param.grad.shape)
assert torch.allclose(
moe_param.grad, zero_grad, atol=1e-5
), f"zero grad:\n{moe_param.grad}\ntorch grad:\n{zero_grad}\nmax diff: {(moe_param.grad - zero_grad).abs().max()}, mean diff: {(moe_param.grad - zero_grad).abs().mean()}"
else:
assert len(moe_grad_list) > 0
assert len(moe_grad_list) == len(zero_grad_list)
for moe_grad, zero_grad in zip(moe_grad_list, zero_grad_list):
assert torch.allclose(moe_grad, zero_grad)
def run_dist(rank, world_size, port, stage):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
seed_all(42 + rank)
run_zero_test(rank, stage=stage)
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("stage", [1, 2])
@rerun_if_address_is_in_use()
def test_moe_zero_model(world_size, stage):
spawn(run_dist, world_size, stage=stage)
if __name__ == "__main__":
test_moe_zero_model(world_size=2, stage=1)

View File

@@ -0,0 +1,132 @@
from copy import deepcopy
import pytest
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import colossalai
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from colossalai.zero import LowLevelZeroOptimizer
from tests.test_moe.moe_utils import loose_close
tokens, n_experts = 7, 4
hidden_size = 8
top_k = 2
def split_grad(grad, world_size):
with torch.no_grad():
grad = grad.clone().detach().flatten()
padding_size = (world_size - grad.numel() % world_size) % world_size
if padding_size > 0:
grad = torch.nn.functional.pad(grad, [0, padding_size])
splited_grad = grad.split(grad.numel() // world_size)
return splited_grad
@parameterize("dtype", [torch.float16, torch.bfloat16])
@parameterize("master_weights", [True, False])
@parameterize("stage", [1, 2])
def run_zero_with_original_model(world_size, master_weights: bool, dtype: torch.dtype, stage: int):
rank = torch.distributed.get_rank()
torch.cuda.set_device(dist.get_rank())
plugin = MoeHybridParallelPlugin(
tp_size=1,
pp_size=1,
ep_size=dist.get_world_size() // 2,
)
seed_all(10086)
config = MixtralConfig(
hidden_size=hidden_size,
intermediate_size=hidden_size * 2,
num_local_experts=n_experts,
num_experts_per_tok=top_k,
)
orig_model = MixtralSparseMoeBlock(config).to(dtype).cuda()
ori_model = DDP(orig_model.cuda(), static_graph=True).cuda()
zero_model = deepcopy(orig_model).to(dtype)
zero_model = EPMixtralSparseMoeBlock.from_native_module(zero_model, ep_group=plugin.ep_group)
zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
pg_param_list = {plugin.global_dp_group: [], plugin.moe_dp_group: []}
for p in zero_model.parameters():
if is_moe_tensor(p):
pg_param_list[plugin.moe_dp_group].append(p)
else:
pg_param_list[plugin.global_dp_group].append(p)
zero_optimizer = LowLevelZeroOptimizer(
zero_optimizer,
pg_to_param_list=pg_param_list,
master_weights=master_weights,
initial_scale=1,
overlap_communication=False,
partition_grad=True,
)
ori_optimizer = torch.optim.SGD(ori_model.parameters(), lr=1)
# create
seed_all(1453 + rank)
for _ in range(2):
# zero-dp forward
input_data = torch.rand(1, tokens, hidden_size).cuda()
zero_output, zero_logits = zero_model(input_data.to(dtype))
# torch-ddp forward
ori_output, ori_logits = ori_model(input_data.to(dtype))
loose_close(zero_output, ori_output, dtype=dtype)
# zero-dp backward
zero_optimizer.backward(zero_output.mean().float())
# torch-ddp backward
ori_output.mean().backward()
# check grad
name_to_p = {n: p for n, p in ori_model.module.named_parameters()}
for n, p in zero_model.named_parameters():
zero_grad = zero_optimizer.get_param_grad(p)
if name_to_p[n].grad is None:
assert zero_grad is None
continue
loose_close(zero_grad, name_to_p[n].grad, dtype=dtype)
# zero-dp step
zero_optimizer.step()
# original model step
ori_optimizer.step()
# check updated param
for n, p in zero_model.named_parameters():
loose_close(p.data, name_to_p[n].data, dtype=dtype)
def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_zero_with_original_model(world_size=world_size)
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2, 4])
@rerun_if_address_is_in_use()
def test_moe_zero_model(world_size):
spawn(run_dist, world_size)
if __name__ == "__main__":
test_moe_zero_model(world_size=4)
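The new test above feeds LowLevelZeroOptimizer a pg_to_param_list that routes expert parameters to the MoE data-parallel group and everything else to the global data-parallel group, then reads gradients back with get_param_grad. A condensed sketch of that wiring, assuming plugin.moe_dp_group and plugin.global_dp_group exist as in the diff:

```python
# Condensed sketch of the pg_to_param_list wiring from the test above;
# assumes plugin.moe_dp_group / plugin.global_dp_group exist as in the diff.
import torch

from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.zero import LowLevelZeroOptimizer


def build_moe_zero_optimizer(model, plugin, master_weights: bool = True):
    # Expert (MoE) parameters go to the MoE data-parallel group, the rest to
    # the global data-parallel group.
    pg_param_list = {plugin.global_dp_group: [], plugin.moe_dp_group: []}
    for p in model.parameters():
        group = plugin.moe_dp_group if is_moe_tensor(p) else plugin.global_dp_group
        pg_param_list[group].append(p)

    return LowLevelZeroOptimizer(
        torch.optim.SGD(model.parameters(), lr=1),
        pg_to_param_list=pg_param_list,
        master_weights=master_weights,
        initial_scale=1,
        overlap_communication=False,
        partition_grad=True,  # ZeRO-2 style gradient partitioning
    )
```

After optimizer.backward(loss), per-parameter gradients can be read back with optimizer.get_param_grad(p), which the test compares against a DDP baseline.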

View File

@@ -1,83 +0,0 @@
import pytest
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import MoeModel, delete_moe_info, loose_close, run_fwd_bwd, sync_local_from_ep
def run_zero_test(local_rank, stage=1):
criterion = torch.nn.CrossEntropyLoss()
MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel="EP")
moe_model = MoeModel().bfloat16()
moe_optimizer = torch.optim.Adam(moe_model.parameters(), lr=1.0)
moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
moe_booster = Booster(plugin=moe_plugin)
moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer)
MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel=None)
zero_model = MoeModel().bfloat16()
delete_moe_info(zero_model)
sync_local_from_ep(zero_model, moe_model)
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1.0)
zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
zero_booster = Booster(plugin=zero_plugin)
zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer)
for (moe_name, moe_param), (zero_name, zero_param) in zip(
moe_model.named_parameters(), zero_model.named_parameters()
):
if ".experts." in moe_name:
continue
assert moe_name == zero_name
assert torch.allclose(
moe_param.data, zero_param.data
), f"{moe_name}\ntorch_param {moe_param.data}\nzero_param {zero_param.data}"
for _ in range(1):
data = torch.randn(2, 4).bfloat16().cuda()
label = torch.randint(0, 4, (2,)).cuda()
moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer)
zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
assert torch.allclose(zero_out, moe_out)
moe_optimizer.step()
zero_optimizer.step()
for (moe_name, moe_param), (zero_name, zero_param) in zip(
moe_model.named_parameters(), zero_model.named_parameters()
):
assert moe_name == zero_name
if is_moe_tensor(moe_param):
param_size = moe_param.shape[0]
zero_param = zero_param[local_rank * param_size : (local_rank + 1) * param_size]
loose_close(moe_param.data, zero_param.data, dtype=moe_param.dtype)
moe_optimizer.zero_grad()
zero_optimizer.zero_grad()
def run_dist(rank, world_size, port, stage):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
seed_all(42 + rank)
run_zero_test(rank, stage=stage)
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("stage", [1, 2])
@rerun_if_address_is_in_use()
def test_moe_zero_optim(world_size, stage):
spawn(run_dist, world_size, stage=stage)
if __name__ == "__main__":
test_moe_zero_optim(world_size=2, stage=1)

View File

@@ -234,7 +234,7 @@ def check_dist_grad(sharded_optimizer, org_model, sharded_model, weight_layer_fo
if org_name in weight_layer_for_check:
org_grad = org_param.grad
group_id = dist.get_rank(sharded_optimizer.optim.dp_group)
dist_grad = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(group_id, id(sharded_param))
dist_grad = sharded_optimizer.get_partitioned_gradients_by_param_id(group_id, id(sharded_param))
# dist_grad concat then reshape to org_grad shape
if dist_grad:

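With the ZeRO refactor, partitioned gradients are fetched from the optimizer itself rather than through the private _grad_store. A minimal sketch of the updated call site, mirroring the test above:

```python
# Sketch of the updated gradient lookup: the _grad_store indirection is gone
# and the ZeRO-wrapped optimizer answers the query directly.
import torch.distributed as dist


def fetch_partitioned_grad(sharded_optimizer, sharded_param):
    group_id = dist.get_rank(sharded_optimizer.optim.dp_group)
    # Previously: sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(...)
    return sharded_optimizer.get_partitioned_gradients_by_param_id(group_id, id(sharded_param))
```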
View File

@@ -316,7 +316,7 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
dp_process_group=dp_group,
verbose=True,
)
shard_to_param = dist_optim._param_store.master_to_working_param # {id(): param tensor} but flattened
shard_to_param = dist_optim.master_to_working_param # {id(): param tensor} but flattened
dist_optim.optim.setup_distributed(
tp_group=tp_group,
dp_group=dp_group,

View File

@@ -200,7 +200,7 @@ def exam_dist_came_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
dp_process_group=dp_group,
verbose=True,
)
shard_to_param = dist_optim._param_store.master_to_working_param # {id(): param tensor} but flattened
shard_to_param = dist_optim.master_to_working_param # {id(): param tensor} but flattened
dist_optim.optim.setup_distributed(
tp_group=tp_group,
dp_group=dp_group,

View File

@@ -229,7 +229,7 @@ def run_dist_lamb_fwd_bwd(
dp_process_group=dp_group,
verbose=True,
)
shard_to_param = optim._param_store.master_to_working_param
shard_to_param = optim.master_to_working_param
optim.optim.setup_distributed(tp_group, dp_group, shard_to_param, is_zero=True)
else:
optim.setup_distributed(tp_group)
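The three hunks above apply the same rename: master_to_working_param now lives on the ZeRO wrapper instead of its _param_store. A sketch of how the distributed optimizers consume that mapping, mirroring the Lamb call site above:

```python
# Sketch mirroring the call sites above: the master->working parameter map is
# read straight off the ZeRO wrapper and handed to the inner distributed optimizer.
def setup_inner_optimizer(dist_optim, tp_group, dp_group):
    # {id(master_param): flattened working param}, per the comment in the diff.
    shard_to_param = dist_optim.master_to_working_param
    dist_optim.optim.setup_distributed(tp_group, dp_group, shard_to_param, is_zero=True)
```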

View File

@@ -32,6 +32,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
dp_group = booster.plugin.dp_group
bert = unwrap_model(org_model, "BertModel", "bert")
sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")
@@ -53,8 +54,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
device = origin_norm.device
norm_groups = []
for group_id in range(sharded_optimizer.num_param_groups):
working_grads = sharded_optimizer._grad_store.get_working_grads_by_group_id(group_id)
norm_group = sharded_optimizer._compute_grad_norm(gradients=working_grads)
working_grads = sharded_optimizer.get_working_grads_by_group_id(group_id)
norm_group = sharded_optimizer._compute_grad_norm(dp_group, gradients=working_grads)
norm_groups.append(norm_group)
total_norm = 0.0
for norm in norm_groups:

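The hunk above switches the grad-norm check to the flattened optimizer API and passes the data-parallel group into _compute_grad_norm. A hedged sketch of the aggregate-norm computation, following the structure shown in the diff (the reduction across groups is an assumption):

```python
# Sketch of the total-norm check above; assumes a ZeRO-wrapped sharded_optimizer
# and a dp_group taken from the booster plugin, as in the diff.
def compute_total_grad_norm(sharded_optimizer, dp_group, norm_type: float = 2.0):
    norm_groups = []
    for group_id in range(sharded_optimizer.num_param_groups):
        # Working gradients are now fetched from the optimizer directly.
        working_grads = sharded_optimizer.get_working_grads_by_group_id(group_id)
        norm_groups.append(sharded_optimizer._compute_grad_norm(dp_group, gradients=working_grads))
    # Combine per-group L-p norms into a single norm (assumed reduction).
    total_norm = 0.0
    for norm in norm_groups:
        total_norm += norm**norm_type
    return total_norm ** (1.0 / norm_type)
```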
View File

@@ -62,10 +62,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
):
for p1, p2 in zip(command_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
working_p = sharded_optimizer._param_store.master_to_working_param[id(p2)]
grads = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p))
working_p = sharded_optimizer.master_to_working_param[id(p2)]
grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p))
grad_index = (
0 if sharded_optimizer._grad_store._partition_grads else sharded_optimizer._bucket_store.zero_local_rank
0 if sharded_optimizer._partition_grads else sharded_optimizer.pid_to_bucket_store[id(p2)].local_rank
)
grad = grads[grad_index]
sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]

View File

@@ -62,10 +62,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
):
for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
working_p = sharded_optimizer._param_store.master_to_working_param[id(p2)]
grads = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p))
working_p = sharded_optimizer.master_to_working_param[id(p2)]
grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p))
grad_index = (
0 if sharded_optimizer._grad_store._partition_grads else sharded_optimizer._bucket_store.zero_local_rank
0
if sharded_optimizer._partition_grads
else sharded_optimizer.pid_to_bucket_store[id(working_p)].local_rank
)
grad = grads[grad_index]
sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]

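Both hunks above pick the local gradient shard differently after the bucket-store refactor: under gradient partitioning (ZeRO-2) each rank only keeps its own shard at index 0, otherwise the index comes from the per-parameter bucket store's local rank. A sketch of that selection, relying on the internals named in the diff and intended only as illustration:

```python
# Sketch of the shard-index selection above; relies on internals named in the
# diff (_partition_grads, pid_to_bucket_store) and is purely illustrative.
def local_grad_shard(sharded_optimizer, working_param):
    grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_param))
    grad_index = (
        0
        if sharded_optimizer._partition_grads  # ZeRO-2: only the local shard is kept
        else sharded_optimizer.pid_to_bucket_store[id(working_param)].local_rank
    )
    return grads[grad_index]
```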
View File

@@ -0,0 +1,61 @@
import pytest
import torch
import torch.nn as nn
import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.zero import LowLevelZeroOptimizer
class MlpModel(nn.Module):
def __init__(self):
super(MlpModel, self).__init__()
self.linear1 = nn.Linear(123, 253)
def forward(self, x):
x = self.linear1(x)
return x
DEL_CALLED = False
class TestLowLevelZeroOptimizer(LowLevelZeroOptimizer):
def __del__(self):
super().__del__()
global DEL_CALLED
DEL_CALLED = True
def exam_mem_leak(world_size):
"""
In this test, we test whether del will be called after the optimizer
is out of scope.
"""
# create models
zero_model = MlpModel().cuda()
# we only test stage 1 here
# in `check_sharded_param_consistency.py`, we will test whether
# level 1 and 2 will produce exactly the same results
zero_optimizer = TestLowLevelZeroOptimizer(torch.optim.SGD(zero_model.parameters(), lr=1))
del zero_optimizer
assert DEL_CALLED
def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
exam_mem_leak(world_size=world_size)
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_zero_1_2():
spawn(run_dist, 2)
if __name__ == "__main__":
test_zero_1_2()
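The test above subclasses LowLevelZeroOptimizer purely to assert that __del__ still runs once the wrapper goes out of scope, i.e. that the hooks the optimizer registers are actually released. The sketch below illustrates the general pattern being protected; the class and attribute names are hypothetical and not the real implementation.

```python
# Hedged sketch of the hook-cleanup pattern the test protects; the class and
# the `_hooks` attribute are hypothetical, not the actual implementation.
class HookOwningOptimizer:
    def __init__(self, params):
        # Register a backward hook per parameter and keep the handles.
        self._hooks = [p.register_hook(lambda g: g) for p in params if p.requires_grad]

    def remove_hooks(self):
        for handle in self._hooks:
            handle.remove()
        self._hooks.clear()

    def __del__(self):
        # Without this, the hooks (and whatever they close over) outlive the
        # optimizer, which is the leak the test above guards against.
        self.remove_hooks()
```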

View File

@@ -91,10 +91,13 @@ def exam_zero_1_2():
zero2_optimizer.backward(zero2_output.mean().float())
# check grad
z1g_list = zero1_optimizer._grad_store.get_working_grads_by_group_id(0)
z2g_list = zero2_optimizer._grad_store.get_working_grads_by_group_id(0)
for z1g, z2g in zip(z1g_list, z2g_list):
assert torch.equal(z1g, z2g)
for p1, p2 in zip(zero1_model.parameters(), zero2_model.parameters()):
g1 = zero1_optimizer.get_param_grad(p1)
g2 = zero2_optimizer.get_param_grad(p2)
if g1 is None or g2 is None:
assert g1 is None and g2 is None
continue
assert torch.allclose(g1, g2)
# step
zero1_optimizer.step()
@@ -102,7 +105,7 @@ def exam_zero_1_2():
# check updated param
for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
assert torch.equal(z1p.data, z2p.data)
assert torch.allclose(z1p, z2p)
@parameterize("dtype", [torch.float16, torch.bfloat16])
@@ -120,7 +123,7 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
seed_all(1453)
# create models
torch_model = MlpModel().cuda()
torch_model = MlpModel().cuda().to(dtype)
zero_model = copy.deepcopy(torch_model).to(dtype)
torch_model = DDP(torch_model.cuda(), static_graph=True).cuda()
@@ -142,29 +145,31 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
seed_all(1453 + local_rank)
for _ in range(2):
# create
input_data = torch.rand(32, 123).cuda()
input_data = torch.rand(32, 123).cuda().to(dtype)
# zero-dp forward
zero_output = zero_model(input_data.to(dtype))
zero_output = zero_model(input_data)
# torch-ddp forward
torch_output = torch_model(input_data)
loose_close(zero_output, torch_output, dtype=dtype)
# zero-dp backward
zero_optimizer.backward(zero_output.mean().float())
zero_optimizer.backward(zero_output.mean())
# torch-ddp backward
torch_output.mean().backward()
# check grad
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
if p.grad is not None:
zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(z1p))
torch_grad_list = split_ddp_grad(p.grad, world_size)
for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list):
loose_close(zero_grad, torch_grad, dtype=dtype)
zero_grad = zero_optimizer.get_param_grad(z1p)
if p.grad is None:
assert zero_grad is None
continue
loose_close(p.grad, zero_grad, dtype=dtype)
# zero-dp step
zero_optimizer.step()
@@ -174,7 +179,7 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
# check updated param
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
loose_close(p.data, z1p.data, dtype=dtype)
loose_close(p, z1p, dtype=dtype)
def run_dist(rank, world_size, port):