pull/5190/head
Xuanlei Zhao 2023-12-18 10:37:07 +08:00
parent ebd8cc579a
commit 7c5b1a585f
4 changed files with 358 additions and 4 deletions

View File

@ -0,0 +1,126 @@
import logging
import os
from pathlib import Path
import torch.distributed as dist
import torch.nn as nn
from colossalai.checkpoint_io import CheckpointIndexFile
from colossalai.checkpoint_io.utils import is_safetensors_available, load_shard_state_dict, load_state_dict_into_model
from colossalai.moe import MoECheckpintIO
from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size, is_moe_tensor
class MixtralMoECheckpointIO(MoECheckpintIO):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict:
"""
Preprocess state_dict before loading and slice the state_dict of MOE tensors.
"""
model_param_dict = dict(model.named_parameters())
for name, param in list(state_dict.items()):
if ".experts." in name:
if ".experts.gate.weight" in name:
new_name = name.replace(".experts.gate.weight", ".experts.gate_weight")
state_dict[new_name] = state_dict.pop(name)
else:
str_idx = name.index(".experts.")
int(name.split(".")[-3])
if ".w1." in name:
model_param_name = name.replace(name[str_idx:], ".experts.wi_gate")
elif ".w2." in name:
model_param_name = name.replace(name[str_idx:], ".experts.wi_up")
elif ".w3." in name:
model_param_name = name.replace(name[str_idx:], ".experts.wo")
model_param = model_param_dict[model_param_name]
assert is_moe_tensor(model_param)
ep_rank = get_ep_rank(model_param)
ep_size = get_ep_size(model_param)
expert_num = 8 // ep_size
range(ep_rank * expert_num, (ep_rank + 1) * expert_num)
state_dict[name] = param
for name, param in list(state_dict.items()):
new_name = "module." + name
state_dict[new_name] = state_dict.pop(name)
assert new_name in model_param_dict, f"{new_name} not in model"
dist.barrier()
return state_dict
def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
"""
Load sharded model with the given path to index file of checkpoint folder.
Args:
model (nn.Module): The model to be loaded.
checkpoint_index_file (str): Path to the index file of checkpointing folder.
strict (bool, optional): For name matching during loading state_dict. Defaults to False.
This argument should be manually set to False since params on same device might be stored in different files.
"""
# Check whether the checkpoint uses safetensors.
use_safetensors = False
if "safetensors" in checkpoint_index_file.name:
use_safetensors = True
if use_safetensors and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
ckpt_root_path = ckpt_index_file.root_path
weight_map = ckpt_index_file.weight_map
strict = False
# Load params & buffers to model.
# Keep a record of loaded files so that file will not be repeatedly loaded.
loaded_file = set()
def _load(name: str):
if name not in weight_map:
raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
filename = weight_map[name]
# If this param/buffer has been loaded before, directly return.
if filename in loaded_file:
return
file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
state_dict = self.pre_load_model(model, state_dict)
missing_keys = []
load_state_dict_into_model(
model,
state_dict,
missing_keys=missing_keys,
strict=strict,
load_sub_module=True,
)
loaded_file.add(filename)
# Load parameters.
for name, _ in model.named_parameters():
name = name.replace("module.", "")
name = name.replace(".gate_weight", ".gate.weight")
if ".experts.wi_gate" in name:
for i in range(8):
new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
_load(new_name)
elif ".experts.wi_up" in name:
for i in range(8):
new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
_load(new_name)
elif ".experts.wo" in name:
for i in range(8):
new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
_load(new_name)
else:
_load(name)
if self.verbose:
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")

View File

@ -0,0 +1,219 @@
import importlib
import os
import shutil
import sys
import pytest
import torch
import torch.distributed as dist
from transformers.models.llama import LlamaConfig
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.moe.manager import MOE_MANAGER
from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
sys.path.append(
os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
"examples/language/openmoe",
)
)
OpenMoeForCausalLM = importlib.import_module("model.modeling_openmoe").OpenMoeForCausalLM
set_openmoe_args = importlib.import_module("model.modeling_openmoe").set_openmoe_args
OpenMoeForCausalLMPolicy = importlib.import_module("model.openmoe_policy").OpenMoeForCausalLMPolicy
def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20):
input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_current_device())
attention_mask = torch.ones_like(input_ids)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": input_ids,
}
def run_fwd_bwd(
model, data, label, criterion, optimizer, enable_autocast=False, pipeline=False, booster=None, plugin=None
):
model.train()
if pipeline:
train_dataloader_iter = DummyDataloader(data_gen_fn, length=1)
is_pp_last_stage = booster.plugin.stage_manager.is_last_stage()
y = booster.execute_pipeline(
train_dataloader_iter,
model,
lambda x, y: x.loss,
optimizer,
return_loss=True,
return_outputs=True,
)
# Backward and optimize
if is_pp_last_stage:
loss = y["loss"]
else:
if criterion:
y = model(data).logits
loss = criterion(y)
else:
loss = model(data, label)
loss = loss.float()
if optimizer is not None:
optimizer.backward(loss)
else:
loss.backward()
return y
def get_config():
config = LlamaConfig(
vocab_size=300,
hidden_size=16,
intermediate_size=32,
num_hidden_layers=2,
num_attention_heads=2,
head_dim=4,
dropout_rate=0.0,
hidden_act="swiglu",
)
set_openmoe_args(config, num_experts=8, moe_layer_interval=1)
return config
def get_model(parallel):
config = get_config()
model = OpenMoeForCausalLM(config)
optim = torch.optim.Adam(model.parameters())
if parallel == None:
plugin = MoeHybridParallelPlugin(
precision="bf16",
tp_size=1,
pp_size=1,
zero_stage=2,
custom_policy=OpenMoeForCausalLMPolicy(),
)
elif parallel == "ep":
plugin = MoeHybridParallelPlugin(
precision="bf16",
tp_size=1,
pp_size=1,
zero_stage=2,
custom_policy=OpenMoeForCausalLMPolicy(),
)
elif parallel == "ep_zero":
plugin = MoeHybridParallelPlugin(
precision="bf16",
tp_size=1,
pp_size=1,
zero_stage=2,
extra_dp_size=2,
custom_policy=OpenMoeForCausalLMPolicy(),
)
elif parallel == "hybrid":
plugin = MoeHybridParallelPlugin(
precision="bf16",
tp_size=1,
pp_size=2,
zero_stage=1,
microbatch_size=1,
custom_policy=OpenMoeForCausalLMPolicy(),
)
booster = Booster(plugin=plugin)
model, optim, _, _, _ = booster.boost(model=model, optimizer=optim)
return model, booster, optim
def _test_moe_checkpoint(rank, parallel):
if parallel == None:
MOE_MANAGER.setup(
parallel=None,
)
elif parallel == "ep":
MOE_MANAGER.setup(
parallel="EP",
)
elif parallel == "ep_zero":
MOE_MANAGER.setup(
parallel="EP",
max_ep_size=2,
)
elif parallel == "hybrid":
MOE_MANAGER.setup(
parallel="EP",
mode="fixed",
fixed_dp_size=1,
fixed_ep_size=2,
fixed_pp_size=2,
)
model1, booster1, optim1 = get_model(parallel)
model2, booster2, optim2 = get_model(parallel)
model3, booster3, optim3 = get_model(parallel)
# param ckpt
# shard
booster1.save_model(model1, "./tmp_ckpt1", shard=True, size_per_shard=1)
booster2.load_model(model2, "./tmp_ckpt1")
# unshard
booster1.save_model(model1, "./tmp_ckpt1.pth")
booster3.load_model(model3, "./tmp_ckpt1.pth")
# check
check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
check_state_dict_equal(model1.state_dict(), model3.state_dict(), False)
# optim ckpt
criterion = lambda x: x.mean()
data = torch.randint(0, 4, (2, 4)).cuda()
label = torch.randint(0, 4, (2,)).cuda()
if parallel == "hybrid":
kwargs = {"pipeline": True, "booster": booster1, "plugin": booster1.plugin}
else:
kwargs = {}
run_fwd_bwd(model1, data, label, criterion, optim1, **kwargs)
optim1.step()
optim1.zero_grad()
# shard
booster1.save_optimizer(optim1, "./tmp_ckpt2", shard=True, size_per_shard=1)
dist.barrier()
booster2.load_optimizer(optim2, "./tmp_ckpt2")
# unshard
booster1.save_optimizer(optim1, "./tmp_ckpt2.pth")
booster3.load_optimizer(optim3, "./tmp_ckpt2.pth")
# check
check_state_dict_equal(optim1.optim.state_dict(), optim2.optim.state_dict(), False)
check_state_dict_equal(optim1.optim.state_dict(), optim3.optim.state_dict(), False)
if dist.get_rank() == 0:
shutil.rmtree("./tmp_ckpt1")
shutil.rmtree("./tmp_ckpt2")
os.remove("./tmp_ckpt1.pth")
os.remove("./tmp_ckpt2.pth")
def _run_dist(rank, world_size, port, parallel):
colossalai.launch(
config=dict(),
rank=rank,
world_size=world_size,
host="localhost",
port=port,
backend="nccl",
)
_test_moe_checkpoint(rank, parallel)
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("parallel", [None, "ep", "ep_zero", "hybrid"])
@rerun_if_address_is_in_use()
def test_moe_checkpoint(world_size, parallel):
spawn(_run_dist, world_size, parallel=parallel)
if __name__ == "__main__":
test_moe_checkpoint(world_size=4, parallel="hybrid")

View File

@ -2,6 +2,7 @@ import argparse
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from torch.utils.data import Dataset
from tqdm import tqdm
@ -182,6 +183,7 @@ def main():
"enable_jit_fused": args.use_kernel,
"precision": args.precision,
"zero_stage": args.zero_stage,
"checkpoint_io": MixtralMoECheckpointIO,
}
mgr_dict = {}
if args.plugin == "ep":
@ -240,10 +242,12 @@ def main():
# )
config = MixtralConfig.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
config.use_cache = False
# config.num_local_experts = 1
init_ctx = LazyInitContext(default_device=get_current_device())
with init_ctx:
model = MixtralForCausalLM(config).bfloat16()
model = MixtralForCausalLM.from_pretrained(
"/home/lczxl/.cache/huggingface/hub/models--mistralai--Mixtral-8x7B-Instruct-v0.1/snapshots/f1ca00645f0b1565c7f9a1c863d2be6ebf896b04",
config=config,
).bfloat16()
coordinator.print_on_master(f"Finish init model with config:\n{config}")
# Enable gradient checkpointing
@ -320,7 +324,7 @@ def main():
booster.save_model(model, args.output_path, shard=True)
# save checkpoint at the end of each epochs
booster.save_model(model, args.output_path, shard=True)
booster.save_model(model, args.output_path, shard=True, size_per_shard=5120)
coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
# Finish training

View File

@ -181,6 +181,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
overlap_communication: bool = True,
use_ep_inside: bool = True,
custom_policy: Policy = None,
checkpoint_io: Optional[MoECheckpintIO] = None,
) -> None:
assert (
dist.get_world_size() % (tp_size * pp_size) == 0
@ -200,6 +201,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.enable_flash_attention = enable_flash_attention
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
self.checkpoint_io = checkpoint_io
# we change pg mesh to (pp, dp, tp) for better moe performance
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size)
@ -323,7 +325,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
)
def get_checkpoint_io(self) -> MoECheckpintIO:
self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
if self.checkpoint_io is None:
self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
else:
self.checkpoint_io = self.checkpoint_io(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
return self.checkpoint_io
def configure(