mirror of https://github.com/hpcaitech/ColossalAI
[Fix] Llama3 Load/Omit CheckpointIO Temporarily (#5717)
* Fix Llama3 Load error
* Omit Checkpoint IO Temporarily
parent 5bbab1533a
commit 74c47921fa
@@ -24,7 +24,7 @@ from colossalai.inference.modeling.policy import model_policy_map
 from colossalai.inference.sampler import search_tokens
 from colossalai.inference.spec import Drafter, GlideInput
 from colossalai.inference.struct import Sequence
-from colossalai.inference.utils import get_model_size, has_index_file
+from colossalai.inference.utils import get_model_size
 from colossalai.interface import ModelWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -113,18 +113,15 @@ class InferenceEngine:
             model_policy (Policy): the policy to replace the model
         """
 
-        casuallm = None
         if isinstance(model_or_path, str):
             try:
                 hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
                 arch = getattr(hf_config, "architectures")[0]
                 if arch in _supported_models.keys():
-                    casuallm = _supported_models[arch](hf_config)
-                    if isinstance(casuallm, AutoModelForCausalLM):
-                        # NOTE(caidi) It's necessary to add half() here, otherwise baichuan13B will overflow the memory.
-                        model = AutoModelForCausalLM.from_pretrained(model_or_path, trust_remote_code=True).half()
-                    else:
-                        model = _supported_models[arch](hf_config)
+                    # NOTE(lry89757) Currently we load the model using transformers-api,
+                    # but we will use lazy tensor and checkpoint io to accelerate
+                    # the model load process in the future.
+                    model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True)
                 else:
                     raise ValueError(f"Model {arch} is not supported.")
 
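For context on the hunk above: instead of instantiating the architecture from its config (and optionally routing through AutoModelForCausalLM with a .half() cast), the engine now resolves the architecture name from the Hugging Face config and calls that class's from_pretrained directly. Below is a minimal standalone sketch of that lookup pattern; the _supported_models dict and load_supported_model helper are illustrative stand-ins, not the engine's actual registry.

from transformers import AutoConfig, LlamaForCausalLM

# Hypothetical registry, mirroring the role of _supported_models in the diff.
_supported_models = {"LlamaForCausalLM": LlamaForCausalLM}

def load_supported_model(model_or_path: str):
    # Read only the config first to discover the architecture name.
    hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
    arch = hf_config.architectures[0]
    if arch not in _supported_models:
        raise ValueError(f"Model {arch} is not supported.")
    # Load weights through the transformers API; per the NOTE in the diff,
    # lazy tensor + checkpoint IO may replace this step in the future.
    return _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True)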
@@ -175,13 +172,14 @@ class InferenceEngine:
             f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
         )
 
-        if isinstance(model_or_path, str) and not isinstance(casuallm, AutoModelForCausalLM):
-            from colossalai.inference.core.plugin import InferCheckpoint_io
+        # NOTE(lry89757) Deprecated currently, will reused when introduce lazy tensor
+        # if isinstance(model_or_path, str) and not isinstance(casuallm, AutoModelForCausalLM):
+        #     from colossalai.inference.core.plugin import InferCheckpoint_io
 
-            cpt_io = InferCheckpoint_io()
-            if_has_index_file, model_index_file = has_index_file(model_or_path)
-            assert if_has_index_file, "the model path is invalid"
-            cpt_io.load_model(self.model, model_index_file)
+        #     cpt_io = InferCheckpoint_io()
+        #     if_has_index_file, model_index_file = has_index_file(model_or_path)
+        #     assert if_has_index_file, "the model path is invalid"
+        #     cpt_io.load_model(self.model, model_index_file)
 
         free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
         peak_memory = init_gpu_memory - free_gpu_memory
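The block commented out above deferred weight loading to InferCheckpoint_io, gated on the checkpoint directory containing an index file. As a rough illustration only, here is a guess at what such an index check amounts to under the usual Hugging Face sharded-checkpoint layout; this is not the actual colossalai.inference.utils.has_index_file implementation.

import glob
import os

def has_index_file(model_path: str):
    # Sharded checkpoints ship a *.index.json that maps weight names to shard
    # files, e.g. pytorch_model.bin.index.json or model.safetensors.index.json.
    candidates = glob.glob(os.path.join(model_path, "*.index.json"))
    if candidates:
        return True, candidates[0]
    return False, None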
@@ -1,4 +1,3 @@
-import os
 from typing import List, Tuple, Union
 
 import rpyc
@@ -19,7 +18,7 @@ from colossalai.inference.modeling.policy import (
     model_policy_map,
 )
 from colossalai.inference.sampler import search_tokens
-from colossalai.inference.utils import get_model_size, has_index_file
+from colossalai.inference.utils import get_model_size
 from colossalai.interface import ModelWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -178,15 +177,19 @@ class rpcWorkerService(rpyc.Service):
         """
 
         if isinstance(model_or_path, str):
-            is_local = os.path.isdir(model_or_path)
+            # is_local = os.path.isdir(model_or_path)
             try:
                 hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
                 arch = getattr(hf_config, "architectures")[0]
-                if is_local:
-                    model = _SUPPORTED_MODELS[arch](hf_config)
-                else:
-                    # load the real checkpoint
-                    model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True)
+                # NOTE(lry89757) Currently we load the model using transformers-api,
+                # but we will use lazy tensor and checkpoint io to accelerate
+                # the model load process in the future.
+                model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True)
+                # if is_local:
+                #     model = _SUPPORTED_MODELS[arch](hf_config)
+                # else:
+                #     # load the real checkpoint
+                #     model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True)
             except Exception as e:
                 logger.error(
                     f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
@@ -235,13 +238,14 @@ class rpcWorkerService(rpyc.Service):
             f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
         )
 
-        if isinstance(model_or_path, str) and is_local:
-            from colossalai.inference.core.plugin import InferCheckpoint_io
+        # NOTE(lry89757) Deprecated currently, will reused when introduce lazy tensor
+        # if isinstance(model_or_path, str) and is_local:
+        #     from colossalai.inference.core.plugin import InferCheckpoint_io
 
-            cpt_io = InferCheckpoint_io()
-            if_has_index_file, model_index_file = has_index_file(model_or_path)
-            assert if_has_index_file, "the model path is invalid"
-            cpt_io.load_model(self.model, model_index_file)
+        #     cpt_io = InferCheckpoint_io()
+        #     if_has_index_file, model_index_file = has_index_file(model_or_path)
+        #     assert if_has_index_file, "the model path is invalid"
+        #     cpt_io.load_model(self.model, model_index_file)
 
         free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
         peak_memory = init_gpu_memory - free_gpu_memory
@@ -646,48 +646,49 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
     def _load_from_state_dict(
         self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
     ):
-        # NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight)
-        for hook in self._load_state_dict_pre_hooks.values():
-            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+        if self.num_heads == self.num_key_value_heads:
+            # NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight)
+            for hook in self._load_state_dict_pre_hooks.values():
+                hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
 
             persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
             local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
             local_state = {k: v for k, v in local_name_params if v is not None}
 
             key = "qkv_weight"
             k1 = "q_proj.weight"
             k2 = "k_proj.weight"
             k3 = "v_proj.weight"
             q_w = state_dict[prefix + k1]
             k_w = state_dict[prefix + k2]
             v_w = state_dict[prefix + k3]
 
             device_mesh = self.helper_layout.device_mesh
             sharding_spec = self.helper_layout.sharding_spec
             q_w = distribute_tensor(q_w, device_mesh, sharding_spec)
             k_w = distribute_tensor(k_w, device_mesh, sharding_spec)
             v_w = distribute_tensor(v_w, device_mesh, sharding_spec)
 
             qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)
 
             input_param = nn.Parameter(
                 qkv_w
             )  # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
 
             param = local_state[key]
 
             try:
                 with torch.no_grad():
                     param.copy_(input_param)
             except Exception as ex:
                 error_msgs.append(
                     'While copying the parameter named "{}", '
                     "whose dimensions in the model are {} and "
                     "whose dimensions in the checkpoint are {}, "
                     "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
                 )
 
             strict = False  # to avoid unexpected_keys
             super()._load_from_state_dict(
                 state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
             )
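The new num_heads == num_key_value_heads guard matters for Llama3 because it uses grouped-query attention: q_proj.weight has shape (num_heads * head_dim, hidden_size) while k_proj.weight and v_proj.weight have shape (num_key_value_heads * head_dim, hidden_size), so the torch.stack([q_w.T, k_w.T, v_w.T], dim=0) trick only works when all three shapes match. A small standalone check, using Llama3-8B-like dimensions purely for illustration:

import torch

hidden_size, head_dim = 4096, 128
num_heads, num_key_value_heads = 32, 8  # GQA: fewer key/value heads than query heads

q_w = torch.empty(num_heads * head_dim, hidden_size)            # (4096, 4096)
k_w = torch.empty(num_key_value_heads * head_dim, hidden_size)  # (1024, 4096)
v_w = torch.empty(num_key_value_heads * head_dim, hidden_size)  # (1024, 4096)

try:
    # The fused-qkv path assumes equal shapes, as when num_heads == num_key_value_heads.
    torch.stack([q_w.T, k_w.T, v_w.T], dim=0)
except RuntimeError as err:
    # With GQA the transposed shapes differ ((4096, 4096) vs (4096, 1024)),
    # so stack raises, which is why the guarded code skips the hack for such models.
    print(f"stack failed as expected: {err}")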