ColossalAI/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py


import copy
import logging
import os
from pathlib import Path
from shutil import rmtree
from typing import Dict, Iterator, Optional, OrderedDict, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup

from colossalai.checkpoint_io import CheckpointIndexFile, HybridParallelCheckpointIO
from colossalai.checkpoint_io.utils import (
    StateDictSharder,
    gather_distributed_param,
    get_model_base_filenames,
    get_optimizer_base_filenames,
    is_safetensors_available,
    load_shard_state_dict,
    load_state_dict,
    load_state_dict_into_model,
    load_states_into_optimizer,
    save_config_file,
    save_param_groups,
    save_state_dict,
    save_state_dict_shards,
    sharded_optimizer_loading_epilogue,
)
from colossalai.interface import OptimizerWrapper
from colossalai.moe import MoECheckpintIO
from colossalai.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import (
    get_dp_group,
    get_dp_rank,
    get_dp_size,
    get_ep_group,
    get_ep_rank,
    get_ep_size,
    is_moe_tensor,
)


class MixtralMoECheckpointIO(MoECheckpintIO):
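    """
    Checkpoint IO for Mixtral MoE models: converts between the HuggingFace Mixtral
    checkpoint layout (one w1/w2/w3 weight per expert) and the fused, expert-parallel
    parameter layout used by this MoE module (wi_gate/wi_up/wo).
    """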
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@torch.no_grad()
def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict:
"""
        Preprocess the state_dict before loading and slice the state_dict of MoE tensors.
"""
model_param_dict = dict(model.named_parameters())
for name, param in list(state_dict.items()):
if ".gate.weight" in name:
new_name = "module." + name.replace(".gate.weight", ".gate_weight")
state_dict[new_name] = state_dict.pop(name)
elif ".experts." in name:
                # This is an expert weight. In our MoE module all experts of a layer are
                # concatenated into a single tensor, while the Mixtral checkpoint stores
                # each expert separately, so the loaded expert is copied into its slot of
                # the concatenated tensor.
                # Resolve the corresponding model parameter name.
str_idx = name.index(".experts.")
expert_idx = int(name.split(".")[-3])
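                # Map Mixtral's per-expert linear names onto the fused expert tensors:
                # w1 -> wi_gate, w3 -> wi_up, w2 -> wo. The fused tensors hold each expert's
                # weight transposed relative to nn.Linear's (out_features, in_features)
                # layout, hence the .transpose(0, 1) when copying below.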
if ".w1." in name:
model_param_name = name.replace(name[str_idx:], ".experts.wi_gate")
elif ".w2." in name:
model_param_name = name.replace(name[str_idx:], ".experts.wo")
elif ".w3." in name:
model_param_name = name.replace(name[str_idx:], ".experts.wi_up")
model_param_name = "module." + model_param_name
model_param = model_param_dict[model_param_name]
assert is_moe_tensor(model_param)
                # Determine which slice of the 8 Mixtral experts this EP rank owns.
ep_rank = get_ep_rank(model_param)
ep_size = get_ep_size(model_param)
expert_num = 8 // ep_size
expert_range = list(range(ep_rank * expert_num, (ep_rank + 1) * expert_num))
# insert new param
if expert_idx in expert_range:
new_param = model_param
new_param[expert_idx - ep_rank * expert_num] = param.transpose(0, 1)
state_dict[model_param_name] = new_param
state_dict.pop(name)
else:
new_name = "module." + name
state_dict[new_name] = state_dict.pop(name)
for name, param in list(state_dict.items()):
assert name in model_param_dict, f"{name} not in model. model param dict: {model_param_dict.keys()}"
dist.barrier()
return state_dict
def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
"""
Load sharded model with the given path to index file of checkpoint folder.
Args:
model (nn.Module): The model to be loaded.
checkpoint_index_file (str): Path to the index file of checkpointing folder.
strict (bool, optional): For name matching during loading state_dict. Defaults to False.
                This argument should be manually set to False since params on the same device might be stored in different files.
"""
# Check whether the checkpoint uses safetensors.
use_safetensors = False
if "safetensors" in checkpoint_index_file.name:
use_safetensors = True
if use_safetensors and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
ckpt_root_path = ckpt_index_file.root_path
weight_map = ckpt_index_file.weight_map
        strict = False  # params on the same device may be stored in different shard files, so never enforce strict matching
# Load params & buffers to model.
# Keep a record of loaded files so that file will not be repeatedly loaded.
loaded_file = set()
def _load(name: str):
if name not in weight_map:
raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
filename = weight_map[name]
# If this param/buffer has been loaded before, directly return.
if filename in loaded_file:
return
file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
state_dict = self.pre_load_model(model, state_dict)
missing_keys = []
load_state_dict_into_model(
model,
state_dict,
missing_keys=missing_keys,
strict=strict,
load_sub_module=True,
)
loaded_file.add(filename)
        # Load parameters. A fused expert parameter (wi_gate/wi_up/wo) corresponds to the
        # 8 per-expert keys (w1/w3/w2) in the Mixtral checkpoint, so each of their shard
        # files is loaded.
for name, _ in model.named_parameters():
name = name.replace("module.", "")
name = name.replace(".gate_weight", ".gate.weight")
if ".experts.wi_gate" in name:
for i in range(8):
new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
_load(new_name)
elif ".experts.wi_up" in name:
for i in range(8):
new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
_load(new_name)
elif ".experts.wo" in name:
for i in range(8):
new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
_load(new_name)
else:
_load(name)
if self.verbose:
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
@torch.no_grad()
def pre_save_model(self, model: nn.Module) -> dict:
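        """
        Convert the model state_dict back to the original Mixtral layout before saving:
        gather fused expert tensors from all EP ranks, split and transpose them into
        per-expert w1/w2/w3 weights, restore the router gate name, strip the "module."
        prefix, and merge pipeline stages on rank 0.
        """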
state_dict = model.state_dict()
for name, param in model.named_parameters():
if ".experts." in name:
if ".experts.gate_weight" in name:
new_name = name.replace(".experts.gate_weight", ".experts.gate.weight")
state_dict[new_name] = state_dict.pop(name)
elif ".experts." in name and is_moe_tensor(param):
ep_group = get_ep_group(param)
ep_rank = get_ep_rank(param)
ep_size = get_ep_size(param)
dp_rank = get_dp_rank(param)
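                    # Each EP rank holds only its slice of the experts; DP rank 0 gathers all
                    # slices over the EP group, and EP rank 0 concatenates them into the full
                    # set of 8 experts below.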
if dp_rank == 0:
param = param.data.cuda()
all_param = [torch.zeros_like(param) for _ in range(ep_size)]
# gather param from every ep rank
dist.all_gather(all_param, param, group=ep_group)
if ep_rank == 0:
all_param = torch.cat(all_param, dim=0)
assert all_param.shape[0] == 8
for i in range(8):
if ".wi_gate" in name:
new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
elif ".wi_up" in name:
new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
elif ".wo" in name:
new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
new_name = new_name.replace("module.", "")
new_param = all_param[i].transpose(-1, -2)
state_dict[new_name] = new_param.cpu()
state_dict.pop(name)
for name, param in list(state_dict.items()):
new_name = name.replace("module.", "")
state_dict[new_name] = state_dict.pop(name)
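        # With pipeline parallelism, each stage holds only its own layers, so the partial
        # state dicts are gathered across the PP group and merged on PP rank 0.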
if self.pp_size > 1:
if self.dp_rank == 0:
out = [None for _ in range(self.pp_size)]
dist.all_gather_object(out, state_dict, group=self.pp_group)
if self.pp_rank == 0:
new_state_dict = {}
for o in out:
new_state_dict.update(o)
state_dict = new_state_dict
dist.barrier()
return state_dict
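

# Usage sketch (assuming the usual ColossalAI Booster setup; the exact plugin wiring is
# defined by the training script): this checkpoint IO is intended to be returned by the
# plugin's get_checkpoint_io() so that booster.save_model / booster.load_model route
# through pre_save_model / load_sharded_model above, e.g.:
#
#     booster = Booster(plugin=plugin)  # plugin configured to use MixtralMoECheckpointIO
#     booster.load_model(model, "path/to/mixtral_ckpt/model.safetensors.index.json")
#     booster.save_model(model, "outputs/mixtral", shard=True)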