import copy
import logging
import os
from pathlib import Path
from shutil import rmtree
from typing import Dict, Iterator, Optional, OrderedDict, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup

from colossalai.checkpoint_io import CheckpointIndexFile, HybridParallelCheckpointIO
from colossalai.checkpoint_io.utils import (
    StateDictSharder,
    gather_distributed_param,
    get_model_base_filenames,
    get_optimizer_base_filenames,
    is_safetensors_available,
    load_shard_state_dict,
    load_state_dict,
    load_state_dict_into_model,
    load_states_into_optimizer,
    save_config_file,
    save_param_groups,
    save_state_dict,
    save_state_dict_shards,
    sharded_optimizer_loading_epilogue,
)
from colossalai.interface import OptimizerWrapper
from colossalai.moe import MoECheckpintIO
from colossalai.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import (
    get_dp_group,
    get_dp_rank,
    get_dp_size,
    get_ep_group,
    get_ep_rank,
    get_ep_size,
    is_moe_tensor,
)


class MixtralMoECheckpointIO(MoECheckpintIO):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @torch.no_grad()
    def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict:
        """
        Preprocess the state_dict before loading and slice the MoE tensors it contains.
        """
        model_param_dict = dict(model.named_parameters())
        for name, param in list(state_dict.items()):
            if ".gate.weight" in name:
                new_name = "module." + name.replace(".gate.weight", ".gate_weight")
                state_dict[new_name] = state_dict.pop(name)
            elif ".experts." in name:
                # In our MoE module the experts are concatenated into a single tensor,
                # while Mixtral stores each expert separately. Insert the loaded expert
                # weight into its slot of the concatenated tensor.

                # Resolve the corresponding model parameter name.
                str_idx = name.index(".experts.")
                expert_idx = int(name.split(".")[-3])
                if ".w1." in name:
                    model_param_name = name.replace(name[str_idx:], ".experts.wi_gate")
                elif ".w2." in name:
                    model_param_name = name.replace(name[str_idx:], ".experts.wo")
                elif ".w3." in name:
                    model_param_name = name.replace(name[str_idx:], ".experts.wi_up")
                model_param_name = "module." + model_param_name
                model_param = model_param_dict[model_param_name]
                assert is_moe_tensor(model_param)
                # Compute the range of expert indices owned by this expert-parallel rank.
                ep_rank = get_ep_rank(model_param)
                ep_size = get_ep_size(model_param)
                expert_num = 8 // ep_size
                expert_range = list(range(ep_rank * expert_num, (ep_rank + 1) * expert_num))
                # Insert the expert weight if this rank owns the expert.
                if expert_idx in expert_range:
                    new_param = model_param
                    new_param[expert_idx - ep_rank * expert_num] = param.transpose(0, 1)
                    state_dict[model_param_name] = new_param
                state_dict.pop(name)
            else:
                new_name = "module." + name
                state_dict[new_name] = state_dict.pop(name)

        for name, param in list(state_dict.items()):
            assert name in model_param_dict, f"{name} not in model. model param dict: {model_param_dict.keys()}"

        dist.barrier()
        return state_dict
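
    # Note (illustrative, not from the original source): `pre_load_model` maps Hugging Face
    # Mixtral checkpoint names onto this module's fused parameters. With hypothetical layer
    # and expert indices, the mapping looks roughly like:
    #   "model.layers.0.block_sparse_moe.gate.weight"
    #       -> "module.model.layers.0.block_sparse_moe.gate_weight"
    #   "model.layers.0.block_sparse_moe.experts.3.w1.weight"
    #       -> slot (3 - ep_rank * expert_num) of "module....block_sparse_moe.experts.wi_gate"
    # w1 feeds wi_gate, w3 feeds wi_up, and w2 feeds wo; each expert weight is transposed
    # before being written into its slot of the concatenated tensor.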

    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
        """
        Load a sharded model with the given path to the index file of the checkpoint folder.

        Args:
            model (nn.Module): The model to be loaded.
            checkpoint_index_file (str): Path to the index file of the checkpoint folder.
            strict (bool, optional): Whether to enforce strict name matching when loading the state_dict.
                Defaults to False. This argument should be manually set to False, since parameters on the
                same device might be stored in different files.
        """

        # Check whether the checkpoint uses safetensors.
        use_safetensors = False
        if "safetensors" in checkpoint_index_file.name:
            use_safetensors = True

        if use_safetensors and not is_safetensors_available():
            raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")

        # Read the checkpoint index file.
        ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
        ckpt_root_path = ckpt_index_file.root_path
        weight_map = ckpt_index_file.weight_map
        strict = False

        # Load params & buffers into the model.
        # Keep a record of loaded files so that no file is loaded repeatedly.
        loaded_file = set()

        def _load(name: str):
            if name not in weight_map:
                raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
            filename = weight_map[name]

            # If the shard holding this param/buffer has been loaded before, return directly.
            if filename in loaded_file:
                return

            file_path = os.path.join(ckpt_root_path, filename)
            state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
            state_dict = self.pre_load_model(model, state_dict)
            missing_keys = []

            load_state_dict_into_model(
                model,
                state_dict,
                missing_keys=missing_keys,
                strict=strict,
                load_sub_module=True,
            )
            loaded_file.add(filename)

        # Load parameters.
        for name, _ in model.named_parameters():
            name = name.replace("module.", "")
            name = name.replace(".gate_weight", ".gate.weight")
            if ".experts.wi_gate" in name:
                for i in range(8):
                    new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
                    _load(new_name)
            elif ".experts.wi_up" in name:
                for i in range(8):
                    new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
                    _load(new_name)
            elif ".experts.wo" in name:
                for i in range(8):
                    new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
                    _load(new_name)
            else:
                _load(name)

        if self.verbose:
            logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
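
    # Note (illustrative, not from the original source): `weight_map` comes from the standard
    # Hugging Face sharded-checkpoint index file (e.g. "model.safetensors.index.json") and maps
    # each parameter name to the shard file that stores it, roughly:
    #   {
    #       "model.layers.0.block_sparse_moe.experts.0.w1.weight": "model-00001-of-00019.safetensors",
    #       "model.layers.0.block_sparse_moe.gate.weight": "model-00001-of-00019.safetensors",
    #       ...
    #   }
    # The shard filenames above are hypothetical; `_load` simply looks up the shard for a name
    # and loads that whole shard file once.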
    @torch.no_grad()
    def pre_save_model(self, model: nn.Module) -> dict:
        state_dict = model.state_dict()
        for name, param in model.named_parameters():
            if ".gate_weight" in name:
                # Inverse of the mapping in `pre_load_model`: restore Mixtral's ".gate.weight" name.
                new_name = name.replace(".gate_weight", ".gate.weight")
                state_dict[new_name] = state_dict.pop(name)
            elif ".experts." in name and is_moe_tensor(param):
                ep_group = get_ep_group(param)
                ep_rank = get_ep_rank(param)
                ep_size = get_ep_size(param)
                dp_rank = get_dp_rank(param)

                if dp_rank == 0:
                    param = param.data.cuda()
                    all_param = [torch.zeros_like(param) for _ in range(ep_size)]
                    # Gather the expert parameters from every expert-parallel rank.
                    dist.all_gather(all_param, param, group=ep_group)
                    if ep_rank == 0:
                        all_param = torch.cat(all_param, dim=0)
                        assert all_param.shape[0] == 8
                        for i in range(8):
                            if ".wi_gate" in name:
                                new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
                            elif ".wi_up" in name:
                                new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
                            elif ".wo" in name:
                                new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
                            new_name = new_name.replace("module.", "")
                            new_param = all_param[i].transpose(-1, -2)
                            state_dict[new_name] = new_param.cpu()
                        state_dict.pop(name)

        for name, param in list(state_dict.items()):
            new_name = name.replace("module.", "")
            state_dict[new_name] = state_dict.pop(name)

        if self.pp_size > 1:
            if self.dp_rank == 0:
                # Gather the per-stage state dicts across the pipeline-parallel group.
                out = [None for _ in range(self.pp_size)]
                dist.all_gather_object(out, state_dict, group=self.pp_group)
                if self.pp_rank == 0:
                    new_state_dict = {}
                    for o in out:
                        new_state_dict.update(o)
                    state_dict = new_state_dict
        dist.barrier()
        return state_dict
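

# --- Illustrative sketch (not part of the checkpoint IO above) ---
# A minimal, self-contained demonstration of the expert-slot arithmetic used by
# `pre_load_model` / `pre_save_model`, assuming Mixtral's 8 experts per MoE layer.
# The values of `ep_size`, `ep_rank`, `hidden`, and `inter` below are hypothetical.
if __name__ == "__main__":
    num_experts, ep_size = 8, 4
    expert_num = num_experts // ep_size   # experts held by each expert-parallel rank
    hidden, inter = 16, 32                # toy dimensions, not Mixtral's real sizes

    # Fused local tensor: one slot per locally held expert, stored transposed
    # relative to the (out_features, in_features) layout of Mixtral's nn.Linear weights.
    fused = torch.zeros(expert_num, hidden, inter)

    # Pretend we are expert-parallel rank 2, which owns checkpoint experts 4 and 5.
    ep_rank, expert_idx = 2, 5
    ckpt_w1 = torch.randn(inter, hidden)  # HF layout: (out_features, in_features)
    if ep_rank * expert_num <= expert_idx < (ep_rank + 1) * expert_num:
        # Loading: checkpoint expert 5 lands in local slot 5 - 2 * 2 == 1.
        fused[expert_idx - ep_rank * expert_num] = ckpt_w1.transpose(0, 1)

    # Saving: transposing the slot back recovers the original checkpoint layout.
    assert torch.equal(fused[1].transpose(-1, -2), ckpt_w1)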