mirror of https://github.com/hpcaitech/ColossalAI
commit 7c5b1a585f (parent ebd8cc579a): update
@@ -0,0 +1,126 @@
import logging
import os
from pathlib import Path

import torch.distributed as dist
import torch.nn as nn

from colossalai.checkpoint_io import CheckpointIndexFile
from colossalai.checkpoint_io.utils import is_safetensors_available, load_shard_state_dict, load_state_dict_into_model
from colossalai.moe import MoECheckpintIO
from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size, is_moe_tensor


class MixtralMoECheckpointIO(MoECheckpintIO):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict:
        """
        Preprocess state_dict before loading and slice the state_dict of MOE tensors.
        """
        model_param_dict = dict(model.named_parameters())
        for name, param in list(state_dict.items()):
            if ".experts." in name:
                if ".experts.gate.weight" in name:
                    new_name = name.replace(".experts.gate.weight", ".experts.gate_weight")
                    state_dict[new_name] = state_dict.pop(name)
                else:
                    str_idx = name.index(".experts.")
                    expert_idx = int(name.split(".")[-3])  # expert index parsed from the checkpoint key (currently unused)
                    # Mixtral MLP convention: w1 = gate_proj, w3 = up_proj, w2 = down_proj,
                    # matching the fused wi_gate / wi_up / wo parameters used by load_sharded_model below.
                    if ".w1." in name:
                        model_param_name = name.replace(name[str_idx:], ".experts.wi_gate")
                    elif ".w2." in name:
                        model_param_name = name.replace(name[str_idx:], ".experts.wo")
                    elif ".w3." in name:
                        model_param_name = name.replace(name[str_idx:], ".experts.wi_up")
                    model_param = model_param_dict[model_param_name]
                    assert is_moe_tensor(model_param)

                    ep_rank = get_ep_rank(model_param)
                    ep_size = get_ep_size(model_param)
                    expert_num = 8 // ep_size  # Mixtral-8x7B uses 8 experts per MoE layer
                    # Experts owned by this expert-parallel rank (computed but not yet used to slice the state_dict).
                    local_expert_range = range(ep_rank * expert_num, (ep_rank + 1) * expert_num)

                    state_dict[name] = param

        for name, param in list(state_dict.items()):
            new_name = "module." + name
            state_dict[new_name] = state_dict.pop(name)
            assert new_name in model_param_dict, f"{new_name} not in model"
        dist.barrier()
        return state_dict

    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
        """
        Load sharded model with the given path to index file of checkpoint folder.

        Args:
            model (nn.Module): The model to be loaded.
            checkpoint_index_file (str): Path to the index file of checkpointing folder.
            strict (bool, optional): For name matching during loading state_dict. Defaults to False.
                This argument should be manually set to False since params on same device might be stored in different files.
        """

        # Check whether the checkpoint uses safetensors.
        use_safetensors = False
        if "safetensors" in checkpoint_index_file.name:
            use_safetensors = True

        if use_safetensors and not is_safetensors_available():
            raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")

        # Read checkpoint index file.
        ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
        ckpt_root_path = ckpt_index_file.root_path
        weight_map = ckpt_index_file.weight_map
        strict = False

        # Load params & buffers to model.
        # Keep a record of loaded files so that file will not be repeatedly loaded.
        loaded_file = set()

        def _load(name: str):
            if name not in weight_map:
                raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
            filename = weight_map[name]

            # If this param/buffer has been loaded before, directly return.
            if filename in loaded_file:
                return

            file_path = os.path.join(ckpt_root_path, filename)
            state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
            state_dict = self.pre_load_model(model, state_dict)
            missing_keys = []

            load_state_dict_into_model(
                model,
                state_dict,
                missing_keys=missing_keys,
                strict=strict,
                load_sub_module=True,
            )
            loaded_file.add(filename)

        # Load parameters.
        for name, _ in model.named_parameters():
            name = name.replace("module.", "")
            name = name.replace(".gate_weight", ".gate.weight")
            if ".experts.wi_gate" in name:
                for i in range(8):
                    new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
                    _load(new_name)
            elif ".experts.wi_up" in name:
                for i in range(8):
                    new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
                    _load(new_name)
            elif ".experts.wo" in name:
                for i in range(8):
                    new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
                    _load(new_name)
            else:
                _load(name)

        if self.verbose:
            logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
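As a side note on the name translation implemented above: a Mixtral checkpoint stores each expert's projections as separate experts.{i}.w1/w2/w3.weight entries, while the fused module exposes single wi_gate/wi_up/wo parameters. The short standalone sketch below is illustration only, not part of the checkpoint IO; the helper name, example prefix, and num_experts default are assumptions.

# Illustrative sketch: expand a fused expert parameter name into the per-expert
# keys that load_sharded_model above reads from the HF Mixtral checkpoint.
FUSED_TO_HF = {".experts.wi_gate": "w1", ".experts.wi_up": "w3", ".experts.wo": "w2"}

def expand_fused_name(name: str, num_experts: int = 8) -> list:
    for fused_suffix, hf_weight in FUSED_TO_HF.items():
        if fused_suffix in name:
            return [name.replace(fused_suffix, f".experts.{i}.{hf_weight}.weight") for i in range(num_experts)]
    return [name]  # non-expert parameters keep their original key

# expand_fused_name("model.layers.0.block_sparse_moe.experts.wi_gate")
# -> ["model.layers.0.block_sparse_moe.experts.0.w1.weight", ..., "...experts.7.w1.weight"]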
@@ -0,0 +1,219 @@
import importlib
import os
import shutil
import sys

import pytest
import torch
import torch.distributed as dist
from transformers.models.llama import LlamaConfig

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.moe.manager import MOE_MANAGER
from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device

sys.path.append(
    os.path.join(
        os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
        "examples/language/openmoe",
    )
)

OpenMoeForCausalLM = importlib.import_module("model.modeling_openmoe").OpenMoeForCausalLM
set_openmoe_args = importlib.import_module("model.modeling_openmoe").set_openmoe_args
OpenMoeForCausalLMPolicy = importlib.import_module("model.openmoe_policy").OpenMoeForCausalLMPolicy


def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20):
    input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_current_device())
    attention_mask = torch.ones_like(input_ids)
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": input_ids,
    }


def run_fwd_bwd(
    model, data, label, criterion, optimizer, enable_autocast=False, pipeline=False, booster=None, plugin=None
):
    model.train()
    if pipeline:
        train_dataloader_iter = DummyDataloader(data_gen_fn, length=1)
        is_pp_last_stage = booster.plugin.stage_manager.is_last_stage()
        y = booster.execute_pipeline(
            train_dataloader_iter,
            model,
            lambda x, y: x.loss,
            optimizer,
            return_loss=True,
            return_outputs=True,
        )
        # Backward and optimize
        if is_pp_last_stage:
            loss = y["loss"]
    else:
        if criterion:
            y = model(data).logits
            loss = criterion(y)
        else:
            loss = model(data, label)
        loss = loss.float()

        if optimizer is not None:
            optimizer.backward(loss)
        else:
            loss.backward()
    return y


def get_config():
    config = LlamaConfig(
        vocab_size=300,
        hidden_size=16,
        intermediate_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        head_dim=4,
        dropout_rate=0.0,
        hidden_act="swiglu",
    )
    set_openmoe_args(config, num_experts=8, moe_layer_interval=1)
    return config


def get_model(parallel):
    config = get_config()
    model = OpenMoeForCausalLM(config)
    optim = torch.optim.Adam(model.parameters())

    # Note: the expert-parallel layout itself is configured through MOE_MANAGER.setup
    # (see _test_moe_checkpoint), so the None and "ep" branches build the plugin with
    # identical arguments.
    if parallel is None:
        plugin = MoeHybridParallelPlugin(
            precision="bf16",
            tp_size=1,
            pp_size=1,
            zero_stage=2,
            custom_policy=OpenMoeForCausalLMPolicy(),
        )
    elif parallel == "ep":
        plugin = MoeHybridParallelPlugin(
            precision="bf16",
            tp_size=1,
            pp_size=1,
            zero_stage=2,
            custom_policy=OpenMoeForCausalLMPolicy(),
        )
    elif parallel == "ep_zero":
        plugin = MoeHybridParallelPlugin(
            precision="bf16",
            tp_size=1,
            pp_size=1,
            zero_stage=2,
            extra_dp_size=2,
            custom_policy=OpenMoeForCausalLMPolicy(),
        )
    elif parallel == "hybrid":
        plugin = MoeHybridParallelPlugin(
            precision="bf16",
            tp_size=1,
            pp_size=2,
            zero_stage=1,
            microbatch_size=1,
            custom_policy=OpenMoeForCausalLMPolicy(),
        )
    booster = Booster(plugin=plugin)
    model, optim, _, _, _ = booster.boost(model=model, optimizer=optim)
    return model, booster, optim


def _test_moe_checkpoint(rank, parallel):
    if parallel is None:
        MOE_MANAGER.setup(
            parallel=None,
        )
    elif parallel == "ep":
        MOE_MANAGER.setup(
            parallel="EP",
        )
    elif parallel == "ep_zero":
        MOE_MANAGER.setup(
            parallel="EP",
            max_ep_size=2,
        )
    elif parallel == "hybrid":
        MOE_MANAGER.setup(
            parallel="EP",
            mode="fixed",
            fixed_dp_size=1,
            fixed_ep_size=2,
            fixed_pp_size=2,
        )
    model1, booster1, optim1 = get_model(parallel)
    model2, booster2, optim2 = get_model(parallel)
    model3, booster3, optim3 = get_model(parallel)

    # param ckpt
    # shard
    booster1.save_model(model1, "./tmp_ckpt1", shard=True, size_per_shard=1)
    booster2.load_model(model2, "./tmp_ckpt1")
    # unshard
    booster1.save_model(model1, "./tmp_ckpt1.pth")
    booster3.load_model(model3, "./tmp_ckpt1.pth")
    # check
    check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
    check_state_dict_equal(model1.state_dict(), model3.state_dict(), False)

    # optim ckpt
    criterion = lambda x: x.mean()
    data = torch.randint(0, 4, (2, 4)).cuda()
    label = torch.randint(0, 4, (2,)).cuda()
    if parallel == "hybrid":
        kwargs = {"pipeline": True, "booster": booster1, "plugin": booster1.plugin}
    else:
        kwargs = {}
    run_fwd_bwd(model1, data, label, criterion, optim1, **kwargs)
    optim1.step()
    optim1.zero_grad()
    # shard
    booster1.save_optimizer(optim1, "./tmp_ckpt2", shard=True, size_per_shard=1)
    dist.barrier()
    booster2.load_optimizer(optim2, "./tmp_ckpt2")
    # unshard
    booster1.save_optimizer(optim1, "./tmp_ckpt2.pth")
    booster3.load_optimizer(optim3, "./tmp_ckpt2.pth")
    # check
    check_state_dict_equal(optim1.optim.state_dict(), optim2.optim.state_dict(), False)
    check_state_dict_equal(optim1.optim.state_dict(), optim3.optim.state_dict(), False)

    if dist.get_rank() == 0:
        shutil.rmtree("./tmp_ckpt1")
        shutil.rmtree("./tmp_ckpt2")
        os.remove("./tmp_ckpt1.pth")
        os.remove("./tmp_ckpt2.pth")


def _run_dist(rank, world_size, port, parallel):
    colossalai.launch(
        config=dict(),
        rank=rank,
        world_size=world_size,
        host="localhost",
        port=port,
        backend="nccl",
    )
    _test_moe_checkpoint(rank, parallel)


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("parallel", [None, "ep", "ep_zero", "hybrid"])
@rerun_if_address_is_in_use()
def test_moe_checkpoint(world_size, parallel):
    spawn(_run_dist, world_size, parallel=parallel)


if __name__ == "__main__":
    test_moe_checkpoint(world_size=4, parallel="hybrid")
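Stripped of the test scaffolding, the round trip exercised above is just a handful of Booster calls. A minimal sketch, assuming model, optim and booster come from a MoeHybridParallelPlugin + booster.boost setup like get_model above, and assuming size_per_shard is measured in MB; the paths are placeholders.

# Sketch of the checkpoint round trip from the test, as user-facing code.
ckpt_dir = "./moe_ckpt"  # placeholder directory for the sharded model checkpoint
booster.save_model(model, ckpt_dir, shard=True, size_per_shard=1024)
booster.save_optimizer(optim, ckpt_dir + "_optim", shard=True, size_per_shard=1024)
dist.barrier()  # make sure shards are on disk before other ranks read them, as in the test

# Later, on a freshly boosted model/optimizer with the same parallel layout:
booster.load_model(model, ckpt_dir)
booster.load_optimizer(optim, ckpt_dir + "_optim")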
@@ -2,6 +2,7 @@ import argparse

import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from torch.utils.data import Dataset
from tqdm import tqdm

@@ -182,6 +183,7 @@ def main():
        "enable_jit_fused": args.use_kernel,
        "precision": args.precision,
        "zero_stage": args.zero_stage,
        "checkpoint_io": MixtralMoECheckpointIO,
    }
    mgr_dict = {}
    if args.plugin == "ep":

@@ -240,10 +242,12 @@ def main():
    # )
    config = MixtralConfig.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
    config.use_cache = False
    # config.num_local_experts = 1
    init_ctx = LazyInitContext(default_device=get_current_device())
    with init_ctx:
        model = MixtralForCausalLM(config).bfloat16()
        model = MixtralForCausalLM.from_pretrained(
            "/home/lczxl/.cache/huggingface/hub/models--mistralai--Mixtral-8x7B-Instruct-v0.1/snapshots/f1ca00645f0b1565c7f9a1c863d2be6ebf896b04",
            config=config,
        ).bfloat16()
    coordinator.print_on_master(f"Finish init model with config:\n{config}")

    # Enable gradient checkpointing
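The hard-coded snapshot path above pins the weights to one machine's HF cache. A hedged alternative, untested in this diff, is to pass the hub id already used for the config and let from_pretrained resolve the cached snapshot itself:

# Sketch: resolve the Mixtral weights by hub id instead of an absolute snapshot path.
# Assumes the checkpoint is available in the local Hugging Face cache or downloadable.
with init_ctx:
    model = MixtralForCausalLM.from_pretrained(
        "mistralai/Mixtral-8x7B-Instruct-v0.1",
        config=config,
    ).bfloat16()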
@@ -320,7 +324,7 @@ def main():
                booster.save_model(model, args.output_path, shard=True)

        # save checkpoint at the end of each epoch
        booster.save_model(model, args.output_path, shard=True)
        booster.save_model(model, args.output_path, shard=True, size_per_shard=5120)
        coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")

    # Finish training

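To resume from the sharded checkpoint written above, the same Booster API can load it back through the plugin's checkpoint IO. A minimal sketch, assuming a restarted job rebuilds the plugin with checkpoint_io=MixtralMoECheckpointIO and the same parallel layout, and that optimizer is the optimizer built earlier in main():

# Sketch: reload the sharded checkpoint saved above on a restarted job.
model = MixtralForCausalLM(config).bfloat16()
model, optimizer, _, _, _ = booster.boost(model=model, optimizer=optimizer)
booster.load_model(model, args.output_path)
coordinator.print_on_master(f"Loaded model checkpoint from {args.output_path}")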
@@ -181,6 +181,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
        overlap_communication: bool = True,
        use_ep_inside: bool = True,
        custom_policy: Policy = None,
        checkpoint_io: Optional[MoECheckpintIO] = None,
    ) -> None:
        assert (
            dist.get_world_size() % (tp_size * pp_size) == 0

@@ -200,6 +201,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
        self.enable_flash_attention = enable_flash_attention
        self.enable_jit_fused = enable_jit_fused
        self.enable_sequence_parallelism = enable_sequence_parallelism
        self.checkpoint_io = checkpoint_io
        # We change the pg mesh to (pp, dp, tp) for better MoE performance.
        self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size)

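To make the (pp, dp, tp) ordering above concrete, here is a small standalone illustration; it is not the plugin's actual grouping code, and the axis sizes are made up.

# Conceptual sketch of a (pp, dp, tp) process mesh; values are hypothetical.
import torch

pp_size, dp_size, tp_size = 2, 2, 2
mesh = torch.arange(pp_size * dp_size * tp_size).reshape(pp_size, dp_size, tp_size)
# mesh[pp, dp, tp] is the global rank at that coordinate:
# tensor([[[0, 1],
#          [2, 3]],
#         [[4, 5],
#          [6, 7]]])
# Ranks in the same dp group differ only along dim 1 (e.g. ranks 0 and 2); the MoE
# plugin carves its expert-parallel groups out of this dp axis (see use_ep_inside above).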
@@ -323,7 +325,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
        )

    def get_checkpoint_io(self) -> MoECheckpintIO:
        if self.checkpoint_io is None:
            self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
        else:
            self.checkpoint_io = self.checkpoint_io(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
        return self.checkpoint_io

    def configure(
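Since checkpoint_io is stored as a class and only instantiated here with the plugin's process groups, callers pass the bare class, as train.py above does through its plugin kwargs. A hedged sketch of that wiring; the parallel sizes are placeholders.

# Sketch: hand the checkpoint IO class (not an instance) to the plugin;
# get_checkpoint_io() later constructs it with (dp_group, pp_group, tp_group, zero_stage).
plugin = MoeHybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    zero_stage=1,
    precision="bf16",
    checkpoint_io=MixtralMoECheckpointIO,
    custom_policy=MixtralForCausalLMPolicy(),
)
booster = Booster(plugin=plugin)  # Booster obtains the IO via plugin.get_checkpoint_io()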