mirror of https://github.com/hpcaitech/ColossalAI
[Inference] Lazy Init Support (#5785)
* lazy init support
* lazy init llama support
* lazy init support for baichuan
* align rpc
* add note for baichuan

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent d9d5e7ea1f
commit 3c7cda0c9a
branch pull/5874/head
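At a glance, the commit replaces eager `from_pretrained` loading with lazy initialization followed by explicit checkpoint loading from the pretrained path. The following is a condensed sketch of that flow, stitched together from the calls exactly as they appear in the hunks below; it is illustrative rather than the engine's actual control flow: the checkpoint path and dtype are placeholders, `AutoModelForCausalLM` stands in for the engine's `_supported_models[arch]` lookup, and the ShardFormer sharding step is elided.

    import torch
    from transformers import AutoModelForCausalLM

    import colossalai.interface.pretrained as pretrained_utils
    from colossalai.inference.core.plugin import InferCheckpoint_io
    from colossalai.inference.utils import has_index_file
    from colossalai.lazy import LazyInitContext

    model_or_path = "/path/to/llama-checkpoint"  # placeholder

    # 1. Build the module lazily: under LazyInitContext, from_pretrained does not
    #    materialize the real weights.
    with LazyInitContext(default_device="cuda"):
        model = AutoModelForCausalLM.from_pretrained(
            model_or_path, trust_remote_code=True, torch_dtype=torch.float16
        )
    pretrained_path = pretrained_utils.get_pretrained_path(model)

    # 2. (The engine shards the model with ShardFormer at this point.)

    # 3. Load the real weights through the inference checkpoint IO, driven by the
    #    sharded checkpoint's *.index.json weight map.
    if pretrained_path:
        cpt_io = InferCheckpoint_io()
        if_has_index_file, model_index_file = has_index_file(pretrained_path)
        assert if_has_index_file, "the model path is invalid"
        cpt_io.load_model(model, model_index_file)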
@@ -24,8 +24,9 @@ from colossalai.inference.modeling.policy import model_policy_map
 from colossalai.inference.sampler import search_tokens
 from colossalai.inference.spec import Drafter, GlideInput
 from colossalai.inference.struct import Sequence
-from colossalai.inference.utils import get_model_size
+from colossalai.inference.utils import get_model_size, has_index_file
 from colossalai.interface import ModelWrapper
+from colossalai.lazy import LazyInitContext
 from colossalai.logging import get_dist_logger
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
@@ -122,16 +123,24 @@ class InferenceEngine:
             model_inference_config: the configuration for modeling initialization when inference.
             model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
         """
+        pretrained_path = None
         if isinstance(model_or_path, str):
+            import colossalai.interface.pretrained as pretrained_utils
+
             try:
-                hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
+                hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
                 arch = getattr(hf_config, "architectures")[0]
                 if arch in _supported_models.keys():
-                    # NOTE(lry89757) Currently we load the model using transformers-api,
-                    # but we will use lazy tensor and checkpoint io to accelerate
-                    # the model load process in the future.
-                    model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True)
+                    if arch is "BaichuanForCausalLM":
+                        self.logger.warning(
+                            "Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers"
+                        )
+                    ctx = LazyInitContext(default_device="cuda")
+                    with ctx:
+                        model = _supported_models[arch].from_pretrained(
+                            model_or_path, trust_remote_code=True, torch_dtype=self.dtype
+                        )
+                    pretrained_path = pretrained_utils.get_pretrained_path(model)
                 else:
                     # TODO(char-1ee): if the model not supported, use transformers APIs to load and generate
                     raise ValueError(f"Model {arch} is not supported.")
@@ -189,14 +198,13 @@ class InferenceEngine:
                 f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
             )

-        # NOTE(lry89757) Deprecated currently, will reused when introduce lazy tensor
-        # if isinstance(model_or_path, str) and not isinstance(casuallm, AutoModelForCausalLM):
-        #     from colossalai.inference.core.plugin import InferCheckpoint_io
+        if pretrained_path:
+            from colossalai.inference.core.plugin import InferCheckpoint_io

-        #     cpt_io = InferCheckpoint_io()
-        #     if_has_index_file, model_index_file = has_index_file(model_or_path)
-        #     assert if_has_index_file, "the model path is invalid"
-        #     cpt_io.load_model(self.model, model_index_file)
+            cpt_io = InferCheckpoint_io()
+            if_has_index_file, model_index_file = has_index_file(pretrained_path)
+            assert if_has_index_file, "the model path is invalid"
+            cpt_io.load_model(self.model, model_index_file)

         free_gpu_memory, _ = torch.cuda.mem_get_info()
         peak_memory = init_gpu_memory - free_gpu_memory
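The `has_index_file` / `load_model` pair above builds on the Hugging Face sharded-checkpoint layout, where a `*.index.json` file maps every parameter name to the shard file that contains it. A rough, hypothetical illustration of what locating and reading such a weight map involves (this is not ColossalAI's `has_index_file` implementation, and the path is a placeholder):

    import json
    import os

    def find_index_file(checkpoint_dir: str):
        # Sharded HF checkpoints ship an index such as
        # model.safetensors.index.json or pytorch_model.bin.index.json.
        for name in ("model.safetensors.index.json", "pytorch_model.bin.index.json"):
            candidate = os.path.join(checkpoint_dir, name)
            if os.path.isfile(candidate):
                return True, candidate
        return False, None

    has_index, index_file = find_index_file("/path/to/llama-checkpoint")  # placeholder
    if has_index:
        with open(index_file) as f:
            weight_map = json.load(f)["weight_map"]
        # weight_map maps parameter names to shard files, e.g.
        # "model.layers.0.self_attn.q_proj.weight" -> "model-00001-of-00004.safetensors"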
@@ -73,7 +73,9 @@ class RPCInferenceEngine(InferenceEngine):

         try:
             if isinstance(model_or_path, str):
-                self.model_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
+                self.model_config = AutoConfig.from_pretrained(
+                    model_or_path, trust_remote_code=True, torch_dtype=self.dtype
+                )
             elif isinstance(model_or_path, nn.Module):
                 self.logger.error(
                     f"An exception occurred during loading model Config: For {__class__.__name__}, we don't support param like nn.Module currently\n"
@@ -18,8 +18,9 @@ from colossalai.inference.modeling.policy import (
     model_policy_map,
 )
 from colossalai.inference.sampler import search_tokens
-from colossalai.inference.utils import get_model_size
+from colossalai.inference.utils import get_model_size, has_index_file
 from colossalai.interface import ModelWrapper
+from colossalai.lazy import LazyInitContext
 from colossalai.logging import get_dist_logger
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
@@ -178,20 +179,23 @@ class rpcWorkerService(rpyc.Service):
             model_policy (Policy): the policy to replace the model
         """

+        pretrained_path = None
         if isinstance(model_or_path, str):
-            # is_local = os.path.isdir(model_or_path)
+            import colossalai.interface.pretrained as pretrained_utils
+
             try:
-                hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
+                hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
                 arch = getattr(hf_config, "architectures")[0]
-                # NOTE(lry89757) Currently we load the model using transformers-api,
-                # but we will use lazy tensor and checkpoint io to accelerate
-                # the model load process in the future.
-                model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True)
-                # if is_local:
-                #     model = _SUPPORTED_MODELS[arch](hf_config)
-                # else:
-                #     # load the real checkpoint
-                #     model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True)
+                if arch is "BaichuanForCausalLM":
+                    self.logger.warning(
+                        "Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers"
+                    )
+                ctx = LazyInitContext(default_device="cuda")
+                with ctx:
+                    model = _SUPPORTED_MODELS[arch].from_pretrained(
+                        model_or_path, trust_remote_code=True, torch_dtype=self.dtype
+                    )
+                pretrained_path = pretrained_utils.get_pretrained_path(model)
             except Exception as e:
                 logger.error(
                     f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
@@ -240,14 +244,13 @@ class rpcWorkerService(rpyc.Service):
             f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
         )

-        # NOTE(lry89757) Deprecated currently, will reused when introduce lazy tensor
-        # if isinstance(model_or_path, str) and is_local:
-        #     from colossalai.inference.core.plugin import InferCheckpoint_io
+        if pretrained_path:
+            from colossalai.inference.core.plugin import InferCheckpoint_io

-        #     cpt_io = InferCheckpoint_io()
-        #     if_has_index_file, model_index_file = has_index_file(model_or_path)
-        #     assert if_has_index_file, "the model path is invalid"
-        #     cpt_io.load_model(self.model, model_index_file)
+            cpt_io = InferCheckpoint_io()
+            if_has_index_file, model_index_file = has_index_file(pretrained_path)
+            assert if_has_index_file, "the model path is invalid"
+            cpt_io.load_model(self.model, model_index_file)

         free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
         peak_memory = init_gpu_memory - free_gpu_memory
@@ -1,8 +1,10 @@
 from typing import List, Union

+import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed import ProcessGroup

+from colossalai.lazy import LazyInitContext
 from colossalai.shardformer.layer import Linear1D_Col
 from colossalai.shardformer.layer.parallel_module import ParallelModule

@@ -12,17 +14,51 @@ class BaichuanLMHeadLinear1D_Col(Linear1D_Col):
     def from_native_module(
         module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
     ) -> ParallelModule:
+        LazyInitContext.materialize(module)
         module.in_features = module.weight.size(1)
         module.out_features = module.weight.size(0)
         module.bias = None
         module.weight.data = nn.functional.normalize(
             module.weight
-        )  # TODO(lry89757) This behavior may not apply to lazy init. When we use lazy init, the weight of shardformer is not the real weight.
+        )  # NOTE(lry89757) This behavior may not apply to lazy init. When we use lazy init, the weight of shardformer is not the real weight.
         # So we should rewrite our own load_from_state_dict of `BaichuanLMHeadLinear1D_Col` to fix this potential issue.

-        return Linear1D_Col.from_native_module(
-            module,
-            process_group,
-            *args,
+        # get the attributes
+        in_features = module.in_features
+        out_features = module.out_features
+        bias = module.bias is not None
+        device = module.weight.device
+        # ensure only one process group is passed
+        if isinstance(process_group, (list, tuple)):
+            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
+            process_group = process_group[0]
+
+        tp_size = dist.get_world_size(process_group)
+        if out_features < tp_size:
+            return module
+
+        if out_features % tp_size != 0:
+            raise ValueError(
+                f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!"
+            )
+
+        lmhead_1d = BaichuanLMHeadLinear1D_Col(
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+            device=device,
+            process_group=process_group,
+            weight=module.weight,
+            bias_=module.bias,
             **kwargs,
         )
+
+        return lmhead_1d
+
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        state_dict[prefix + "weight"] = nn.functional.normalize(state_dict[prefix + "weight"])
+        super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
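The `_load_from_state_dict` override added above exists because, under lazy init, the normalization done in `from_native_module` runs on a placeholder rather than the real weight, so the same transform has to be re-applied to the checkpoint tensor at load time. A self-contained toy module showing that pattern (the class and names here are illustrative, not ColossalAI's):

    import torch
    import torch.nn as nn

    class NormalizedHead(nn.Module):
        """Toy linear head whose weight must be row-normalized when loaded."""

        def __init__(self, in_features: int, out_features: int):
            super().__init__()
            self.weight = nn.Parameter(torch.empty(out_features, in_features))

        def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
            key = prefix + "weight"
            if key in state_dict:
                # Transform the checkpoint tensor itself, so the result is correct
                # even when the module was built lazily with a placeholder weight.
                state_dict[key] = nn.functional.normalize(state_dict[key])
            super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    head = NormalizedHead(4, 8)
    head.load_state_dict({"weight": torch.randn(8, 4)})
    assert torch.allclose(head.weight.norm(dim=1), torch.ones(8))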
@@ -70,7 +70,6 @@ class NopadBaichuanAttention(ParallelModule):
             attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj. Defaults to None.
         """
         ParallelModule.__init__(self)
-        self.o_proj = attn_oproj

         self.config = config
         self.num_heads = num_heads
@@ -78,6 +77,7 @@ class NopadBaichuanAttention(ParallelModule):
         self.head_dim = self.hidden_size // self.num_heads
         self.process_group = process_group
         self.W_pack = W_pack
+        self.o_proj = attn_oproj
         self.use_cuda_kernel = model_shard_infer_config.use_cuda_kernel
         self.attention_backend = get_attention_backend(model_shard_infer_config)
         self.pre_attention_backend = get_pre_attention_backend(model_shard_infer_config)
@@ -284,6 +284,10 @@ class NopadLlamaMLP(LlamaMLP, ParallelModule):
         self.gate_up_weight = nn.Parameter(
             torch.stack([mlp_gproj_w.transpose(0, 1), mlp_uproj_w.transpose(0, 1)], dim=0)
         )
+        self.gate_up_dict = {
+            "gate_proj.weight": None,
+            "up_proj.weight": None,
+        }  # used and delattr in load/shard of gate/up weight
         self.down_proj = mlp_dproj
         self.process_group = process_group

@@ -321,44 +325,47 @@ class NopadLlamaMLP(LlamaMLP, ParallelModule):
     ):
         # NOTE This is a hack to ensure we could load the right weight from LlamaMLP checkpoint due to the use of torch.stack(gate_weight, up_weight)

-        for hook in self._load_state_dict_pre_hooks.values():
-            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
-
-        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
-        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
-        local_state = {k: v for k, v in local_name_params if v is not None}
-
-        key = "gate_up_weight"
-        k1 = "gate_proj.weight"
-        k2 = "up_proj.weight"
-
-        gate_w = state_dict[prefix + k1]
-        up_w = state_dict[prefix + k2]
-
-        device_mesh = self.helper_layout.device_mesh
-        sharding_spec = self.helper_layout.sharding_spec
-        gate_w = distribute_tensor(gate_w, device_mesh, sharding_spec)
-        up_w = distribute_tensor(up_w, device_mesh, sharding_spec)
-
-        gate_up_w = torch.stack([gate_w.T, up_w.T], dim=0)
-
-        input_param = nn.Parameter(
-            gate_up_w
-        )  # NOTE gate_up_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
-        param = local_state[key]
-
-        try:
-            with torch.no_grad():
-                param.copy_(input_param)
-        except Exception as ex:
-            error_msgs.append(
-                'While copying the parameter named "{}", '
-                "whose dimensions in the model are {} and "
-                "whose dimensions in the checkpoint are {}, "
-                "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
-            )
+        if hasattr(self, "gate_up_dict"):
+            for hook in self._load_state_dict_pre_hooks.values():
+                hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+            persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+            local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+            local_state = {k: v for k, v in local_name_params if v is not None}
+
+            device_mesh = self.helper_layout.device_mesh
+            sharding_spec = self.helper_layout.sharding_spec
+            for weight_name in self.gate_up_dict:
+                prefix_weight_name = prefix + weight_name
+                if prefix_weight_name in state_dict.keys():
+                    w = distribute_tensor(state_dict[prefix_weight_name], device_mesh, sharding_spec)
+                    self.gate_up_dict[weight_name] = w.T
+
+            if None not in self.gate_up_dict.values():
+                # we've got all the weights of gate/up
+                gate_up_w = torch.stack(list(self.gate_up_dict.values()), dim=0)
+
+                input_param = nn.Parameter(
+                    gate_up_w
+                )  # NOTE gate_up_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
+
+                key = "gate_up_weight"
+                param = local_state.get(key, None)
+
+                try:
+                    with torch.no_grad():
+                        param.copy_(input_param)
+                except Exception as ex:
+                    error_msgs.append(
+                        'While copying the parameter named "{}", '
+                        "whose dimensions in the model are {} and "
+                        "whose dimensions in the checkpoint are {}, "
+                        "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
+                    )
+
+                del self.gate_up_dict

         strict = False  # to avoid unexpected_keys
         super()._load_from_state_dict(
             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
         )
@@ -429,7 +436,15 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
             self.helper_layout = (
                 attn_qproj_w.dist_layout
             )  # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)
+            self.qkv_dict = {
+                "q_proj.weight": None,
+                "k_proj.weight": None,
+                "v_proj.weight": None,
+            }  # used and delattr in load/shard of qkv weight
         else:
+            self.helper_layout = (
+                attn_qproj_w.dist_layout
+            )  # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)
             self.q_proj_weight = nn.Parameter(attn_qproj_w.transpose(0, 1).contiguous())
             self.k_proj_weight = nn.Parameter(attn_kproj_w.transpose(0, 1).contiguous())
             self.v_proj_weight = nn.Parameter(attn_vproj_w.transpose(0, 1).contiguous())
@@ -577,49 +592,83 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
     def _load_from_state_dict(
         self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
     ):
-        if self.num_heads == self.num_key_value_heads:
-            # NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight)
-            for hook in self._load_state_dict_pre_hooks.values():
-                hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
-
-            persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
-            local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
-            local_state = {k: v for k, v in local_name_params if v is not None}
-
-            key = "qkv_weight"
-            k1 = "q_proj.weight"
-            k2 = "k_proj.weight"
-            k3 = "v_proj.weight"
-            q_w = state_dict[prefix + k1]
-            k_w = state_dict[prefix + k2]
-            v_w = state_dict[prefix + k3]
-
-            device_mesh = self.helper_layout.device_mesh
-            sharding_spec = self.helper_layout.sharding_spec
-            q_w = distribute_tensor(q_w, device_mesh, sharding_spec)
-            k_w = distribute_tensor(k_w, device_mesh, sharding_spec)
-            v_w = distribute_tensor(v_w, device_mesh, sharding_spec)
-
-            qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)
-
-            input_param = nn.Parameter(
-                qkv_w
-            )  # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
-
-            param = local_state[key]
-
-            try:
-                with torch.no_grad():
-                    param.copy_(input_param)
-            except Exception as ex:
-                error_msgs.append(
-                    'While copying the parameter named "{}", '
-                    "whose dimensions in the model are {} and "
-                    "whose dimensions in the checkpoint are {}, "
-                    "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
-                )
-
-            strict = False  # to avoid unexpected_keys
+        for hook in self._load_state_dict_pre_hooks.values():
+            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+        local_state = {k: v for k, v in local_name_params if v is not None}
+
+        device_mesh = self.helper_layout.device_mesh
+        sharding_spec = self.helper_layout.sharding_spec
+
+        if self.num_heads == self.num_key_value_heads and hasattr(self, "qkv_dict"):
+            # NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight)
+
+            key = "qkv_weight"
+
+            # NOTE(@lry89757) We will load the sharded checkpoint file according to the weight map from *.index.json
+            # Here we need the weight of q,k,v to stack the weights of q,k,v into one qkv weight.
+            # Unfortunately, it is highly like that all weights of q,k,v are not in the same sharded checkpoint file(like meta-llama/llama3-70B)
+            # so here we will stack them when we really collect all the three weights.
+            for weight_name in self.qkv_dict:
+                prefix_weight_name = prefix + weight_name
+                if prefix_weight_name in state_dict.keys():
+                    w = distribute_tensor(state_dict[prefix_weight_name], device_mesh, sharding_spec)
+                    self.qkv_dict[weight_name] = w.T
+
+            if None not in self.qkv_dict.values():
+                # we've got all the weights of q, k, v
+                qkv_w = torch.stack(list(self.qkv_dict.values()), dim=0)
+
+                input_param = nn.Parameter(
+                    qkv_w
+                )  # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
+
+                param = local_state[key]
+
+                try:
+                    with torch.no_grad():
+                        param.copy_(input_param)
+                except Exception as ex:
+                    error_msgs.append(
+                        'While copying the parameter named "{}", '
+                        "whose dimensions in the model are {} and "
+                        "whose dimensions in the checkpoint are {}, "
+                        "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
+                    )
+
+                del self.qkv_dict
+
+        else:
+
+            def _load(origin_weight_name="q_proj.weight", local_weight_name="q_proj_weight"):
+                if prefix + origin_weight_name in state_dict.keys():
+                    attn_qproj_w = state_dict[prefix + origin_weight_name]
+                    w = distribute_tensor(attn_qproj_w, device_mesh, sharding_spec)
+                    input_param = nn.Parameter(w.T)
+                    param = local_state[local_weight_name]
+                    try:
+                        with torch.no_grad():
+                            param.copy_(input_param)
+                    except Exception as ex:
+                        key = local_weight_name
+                        error_msgs.append(
+                            'While copying the parameter named "{}", '
+                            "whose dimensions in the model are {} and "
+                            "whose dimensions in the checkpoint are {}, "
+                            "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
+                        )
+
+            if prefix + "q_proj.weight" in state_dict.keys():
+                _load(origin_weight_name="q_proj.weight", local_weight_name="q_proj_weight")
+
+            if prefix + "k_proj.weight" in state_dict.keys():
+                _load(origin_weight_name="k_proj.weight", local_weight_name="k_proj_weight")
+
+            if prefix + "v_proj.weight" in state_dict.keys():
+                _load(origin_weight_name="v_proj.weight", local_weight_name="v_proj_weight")
+
+        strict = False  # to avoid unexpected_keys
         super()._load_from_state_dict(
             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
         )
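Because a sharded checkpoint is loaded file by file, `q_proj`, `k_proj` and `v_proj` may arrive in different calls to `_load_from_state_dict`, which is why the hunk above stashes them in `self.qkv_dict` and only stacks once all three are present. A stripped-down toy of that accumulate-then-stack logic, using plain tensors with no device mesh or sharding (names and shapes are illustrative):

    import torch

    qkv_dict = {"q_proj.weight": None, "k_proj.weight": None, "v_proj.weight": None}

    def feed_partial_state_dict(state_dict):
        """Called once per checkpoint shard; returns the fused weight once all parts arrived."""
        for name in qkv_dict:
            if name in state_dict:
                qkv_dict[name] = state_dict[name].T  # transpose, as the diff does before stacking
        if None not in qkv_dict.values():
            # All three projections collected: fuse into one (3, in_dim, out_dim) tensor.
            # (The real code then copies it into the qkv_weight parameter and deletes the dict.)
            return torch.stack(list(qkv_dict.values()), dim=0)
        return None

    # Shard 1 carries only q and k; shard 2 carries v.
    assert feed_partial_state_dict({"q_proj.weight": torch.randn(8, 4), "k_proj.weight": torch.randn(8, 4)}) is None
    qkv_weight = feed_partial_state_dict({"v_proj.weight": torch.randn(8, 4)})
    assert qkv_weight.shape == (3, 4, 8)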
@@ -674,6 +674,8 @@ class FusedLinear1D_Col(ParallelModule):
             process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
             n_fused (int): The number of layers to be fused. In common, Q,K,V are fused in one weight.
         """
+        LazyInitContext.materialize(module)
+
         # get the attributes
         in_features = module.in_features
         out_features = module.out_features