@ -24,7 +24,7 @@ from colossalai.inference.modeling.policy import model_policy_map
from colossalai . inference . sampler import search_tokens
from colossalai . inference . spec import Drafter , GlideInput
from colossalai . inference . struct import Sequence
from colossalai . inference . utils import get_model_size , has_index_file
from colossalai . inference . utils import get_model_size
from colossalai . interface import ModelWrapper
from colossalai . logging import get_dist_logger
from colossalai . pipeline . stage_manager import PipelineStageManager
@ -113,18 +113,15 @@ class InferenceEngine:
model_policy ( Policy ) : the policy to replace the model
"""
casuallm = None
if isinstance ( model_or_path , str ) :
try :
hf_config = AutoConfig . from_pretrained ( model_or_path , trust_remote_code = True )
arch = getattr ( hf_config , " architectures " ) [ 0 ]
if arch in _supported_models . keys ( ) :
casuallm = _supported_models [ arch ] ( hf_config )
if isinstance ( casuallm , AutoModelForCausalLM ) :
# NOTE(caidi) It's necessary to add half() here, otherwise baichuan13B will overflow the memory.
model = AutoModelForCausalLM . from_pretrained ( model_or_path , trust_remote_code = True ) . half ( )
else :
model = _supported_models [ arch ] ( hf_config )
# NOTE(lry89757) Currently we load the model using transformers-api,
# but we will use lazy tensor and checkpoint io to accelerate
# the model load process in the future.
model = _supported_models [ arch ] . from_pretrained ( model_or_path , trust_remote_code = True )
else :
raise ValueError ( f " Model { arch } is not supported. " )
@ -175,13 +172,14 @@ class InferenceEngine:
f " After the shard, Rank: [ { dist . get_rank ( ) } ], model size: { get_model_size ( self . model ) } GB, model ' s device is: { model . device } "
)
if isinstance ( model_or_path , str ) and not isinstance ( casuallm , AutoModelForCausalLM ) :
from colossalai . inference . core . plugin import InferCheckpoint_io
# NOTE(lry89757) Deprecated currently, will reused when introduce lazy tensor
# if isinstance(model_or_path, str) and not isinstance(casuallm, AutoModelForCausalLM):
# from colossalai.inference.core.plugin import InferCheckpoint_io
cpt_io = InferCheckpoint_io ( )
if_has_index_file , model_index_file = has_index_file ( model_or_path )
assert if_has_index_file , " the model path is invalid "
cpt_io . load_model ( self . model , model_index_file )
# cpt_io = InferCheckpoint_io( )
# if_has_index_file, model_index_file = has_index_file(model_or_path )
# assert if_has_index_file, "the model path is invalid "
# cpt_io.load_model(self.model, model_index_file )
free_gpu_memory , total_gpu_memory = torch . cuda . mem_get_info ( )
peak_memory = init_gpu_memory - free_gpu_memory