Browse Source

[Fix] Llama3 Load/Omit CheckpointIO Temporarily (#5717)

* Fix Llama3 Load error
* Omit Checkpoint IO Temporarily
pull/5723/head
Runyu Lu 6 months ago committed by GitHub
parent
commit
74c47921fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 26
      colossalai/inference/core/engine.py
  2. 30
      colossalai/inference/executor/rpc_worker.py
  3. 1
      colossalai/inference/modeling/models/nopadding_llama.py

26
colossalai/inference/core/engine.py

@ -24,7 +24,7 @@ from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.sampler import search_tokens from colossalai.inference.sampler import search_tokens
from colossalai.inference.spec import Drafter, GlideInput from colossalai.inference.spec import Drafter, GlideInput
from colossalai.inference.struct import Sequence from colossalai.inference.struct import Sequence
from colossalai.inference.utils import get_model_size, has_index_file from colossalai.inference.utils import get_model_size
from colossalai.interface import ModelWrapper from colossalai.interface import ModelWrapper
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
@ -113,18 +113,15 @@ class InferenceEngine:
model_policy (Policy): the policy to replace the model model_policy (Policy): the policy to replace the model
""" """
casuallm = None
if isinstance(model_or_path, str): if isinstance(model_or_path, str):
try: try:
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True) hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
arch = getattr(hf_config, "architectures")[0] arch = getattr(hf_config, "architectures")[0]
if arch in _supported_models.keys(): if arch in _supported_models.keys():
casuallm = _supported_models[arch](hf_config) # NOTE(lry89757) Currently we load the model using transformers-api,
if isinstance(casuallm, AutoModelForCausalLM): # but we will use lazy tensor and checkpoint io to accelerate
# NOTE(caidi) It's necessary to add half() here, otherwise baichuan13B will overflow the memory. # the model load process in the future.
model = AutoModelForCausalLM.from_pretrained(model_or_path, trust_remote_code=True).half() model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True)
else:
model = _supported_models[arch](hf_config)
else: else:
raise ValueError(f"Model {arch} is not supported.") raise ValueError(f"Model {arch} is not supported.")
@ -175,13 +172,14 @@ class InferenceEngine:
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}" f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
) )
if isinstance(model_or_path, str) and not isinstance(casuallm, AutoModelForCausalLM): # NOTE(lry89757) Deprecated currently, will reused when introduce lazy tensor
from colossalai.inference.core.plugin import InferCheckpoint_io # if isinstance(model_or_path, str) and not isinstance(casuallm, AutoModelForCausalLM):
# from colossalai.inference.core.plugin import InferCheckpoint_io
cpt_io = InferCheckpoint_io() # cpt_io = InferCheckpoint_io()
if_has_index_file, model_index_file = has_index_file(model_or_path) # if_has_index_file, model_index_file = has_index_file(model_or_path)
assert if_has_index_file, "the model path is invalid" # assert if_has_index_file, "the model path is invalid"
cpt_io.load_model(self.model, model_index_file) # cpt_io.load_model(self.model, model_index_file)
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
peak_memory = init_gpu_memory - free_gpu_memory peak_memory = init_gpu_memory - free_gpu_memory

30
colossalai/inference/executor/rpc_worker.py

@ -1,4 +1,3 @@
import os
from typing import List, Tuple, Union from typing import List, Tuple, Union
import rpyc import rpyc
@ -19,7 +18,7 @@ from colossalai.inference.modeling.policy import (
model_policy_map, model_policy_map,
) )
from colossalai.inference.sampler import search_tokens from colossalai.inference.sampler import search_tokens
from colossalai.inference.utils import get_model_size, has_index_file from colossalai.inference.utils import get_model_size
from colossalai.interface import ModelWrapper from colossalai.interface import ModelWrapper
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
@ -178,15 +177,19 @@ class rpcWorkerService(rpyc.Service):
""" """
if isinstance(model_or_path, str): if isinstance(model_or_path, str):
is_local = os.path.isdir(model_or_path) # is_local = os.path.isdir(model_or_path)
try: try:
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True) hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
arch = getattr(hf_config, "architectures")[0] arch = getattr(hf_config, "architectures")[0]
if is_local: # NOTE(lry89757) Currently we load the model using transformers-api,
model = _SUPPORTED_MODELS[arch](hf_config) # but we will use lazy tensor and checkpoint io to accelerate
else: # the model load process in the future.
# load the real checkpoint
model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True) model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True)
# if is_local:
# model = _SUPPORTED_MODELS[arch](hf_config)
# else:
# # load the real checkpoint
# model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True)
except Exception as e: except Exception as e:
logger.error( logger.error(
f"An exception occurred during loading model: {e}, model should be loaded by transformers\n" f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
@ -235,13 +238,14 @@ class rpcWorkerService(rpyc.Service):
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}" f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
) )
if isinstance(model_or_path, str) and is_local: # NOTE(lry89757) Deprecated currently, will reused when introduce lazy tensor
from colossalai.inference.core.plugin import InferCheckpoint_io # if isinstance(model_or_path, str) and is_local:
# from colossalai.inference.core.plugin import InferCheckpoint_io
cpt_io = InferCheckpoint_io() # cpt_io = InferCheckpoint_io()
if_has_index_file, model_index_file = has_index_file(model_or_path) # if_has_index_file, model_index_file = has_index_file(model_or_path)
assert if_has_index_file, "the model path is invalid" # assert if_has_index_file, "the model path is invalid"
cpt_io.load_model(self.model, model_index_file) # cpt_io.load_model(self.model, model_index_file)
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
peak_memory = init_gpu_memory - free_gpu_memory peak_memory = init_gpu_memory - free_gpu_memory

1
colossalai/inference/modeling/models/nopadding_llama.py

@ -646,6 +646,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
def _load_from_state_dict( def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
): ):
if self.num_heads == self.num_key_value_heads:
# NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight) # NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight)
for hook in self._load_state_dict_pre_hooks.values(): for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

Loading…
Cancel
Save