mirror of https://github.com/THUDM/ChatGLM-6B
bugfix: linux多卡部署时weight,input不在同一device上,导致RuntimeError
parent
8101d75ab8
commit
6a5267aef7
28
utils.py
28
utils.py
|
@ -1,8 +1,10 @@
|
||||||
import os
|
import os
|
||||||
from typing import Dict, Tuple, Union
|
from typing import Dict, Tuple, Union, Optional
|
||||||
|
|
||||||
from accelerate import load_checkpoint_and_dispatch
|
from accelerate import load_checkpoint_and_dispatch
|
||||||
|
from torch.nn import Module
|
||||||
from transformers import AutoModel, AutoTokenizer
|
from transformers import AutoModel, AutoTokenizer
|
||||||
|
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||||
|
|
||||||
|
|
||||||
def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
|
def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
|
||||||
|
@ -13,10 +15,16 @@ def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
|
||||||
num_trans_layers = 28
|
num_trans_layers = 28
|
||||||
per_gpu_layers = 30 / num_gpus
|
per_gpu_layers = 30 / num_gpus
|
||||||
|
|
||||||
|
# bugfix: 在linux中调用torch.embedding传入的weight,input不在同一device上,导致RuntimeError
|
||||||
|
# windows下 model.device 会被设置成 transformer.word_embeddings.device
|
||||||
|
# linux下 model.device 会被设置成 lm_head.device
|
||||||
|
# 在调用chat或者stream_chat时,input_ids会被放到model.device上
|
||||||
|
# 如果transformer.word_embeddings.device和model.device不同,则会导致RuntimeError
|
||||||
|
# 因此这里将transformer.word_embeddings,transformer.final_layernorm,lm_head都放到第一张卡上
|
||||||
device_map = {'transformer.word_embeddings': 0,
|
device_map = {'transformer.word_embeddings': 0,
|
||||||
'transformer.final_layernorm': num_gpus - 1, 'lm_head': num_gpus - 1}
|
'transformer.final_layernorm': 0, 'lm_head': 0}
|
||||||
|
|
||||||
used = 1
|
used = 2
|
||||||
gpu_target = 0
|
gpu_target = 0
|
||||||
for i in range(num_trans_layers):
|
for i in range(num_trans_layers):
|
||||||
if used >= per_gpu_layers:
|
if used >= per_gpu_layers:
|
||||||
|
@ -29,9 +37,9 @@ def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
|
||||||
return device_map
|
return device_map
|
||||||
|
|
||||||
|
|
||||||
def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike],
|
def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike], num_gpus: int = 2,
|
||||||
multi_gpu_model_cache_dir: Union[str, os.PathLike] = "./temp_model_dir",
|
multi_gpu_model_cache_dir: Union[str, os.PathLike] = "./temp_model_dir",
|
||||||
num_gpus: int = 2, **kwargs):
|
tokenizer: Optional[PreTrainedTokenizer] = None, **kwargs) -> Module:
|
||||||
model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs)
|
model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs)
|
||||||
model = model.eval()
|
model = model.eval()
|
||||||
|
|
||||||
|
@ -49,18 +57,22 @@ def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike],
|
||||||
model, multi_gpu_model_cache_dir, device_map=device_map,
|
model, multi_gpu_model_cache_dir, device_map=device_map,
|
||||||
offload_folder="offload", offload_state_dict=True).half()
|
offload_folder="offload", offload_state_dict=True).half()
|
||||||
|
|
||||||
|
if tokenizer is not None:
|
||||||
|
tokenizer.save_pretrained(multi_gpu_model_cache_dir)
|
||||||
print(f"loading model successfully, you should use checkpoint_path={multi_gpu_model_cache_dir} next time")
|
print(f"loading model successfully, you should use checkpoint_path={multi_gpu_model_cache_dir} next time")
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def load_model_and_tokenizer(checkpoint_path: Union[str, os.PathLike],
|
def load_model_and_tokenizer(checkpoint_path: Union[str, os.PathLike], num_gpus: int = 1,
|
||||||
multi_gpu_model_cache_dir: Union[str, os.PathLike] = "./temp_model_dir",
|
multi_gpu_model_cache_dir: Union[str, os.PathLike] = "./temp_model_dir",
|
||||||
num_gpus: int = 1, **kwargs) -> Tuple[AutoModel, AutoTokenizer]:
|
**kwargs) -> Tuple[Module, PreTrainedTokenizer]:
|
||||||
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs)
|
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs)
|
||||||
if num_gpus < 2:
|
if num_gpus < 2:
|
||||||
model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half().cuda()
|
model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half().cuda()
|
||||||
model = model.eval()
|
model = model.eval()
|
||||||
else:
|
else:
|
||||||
model = load_model_on_gpus(checkpoint_path, multi_gpu_model_cache_dir, num_gpus, **kwargs)
|
model = load_model_on_gpus(checkpoint_path, num_gpus=num_gpus,
|
||||||
|
multi_gpu_model_cache_dir=multi_gpu_model_cache_dir,
|
||||||
|
tokenizer=tokenizer, **kwargs)
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from utils import load_model_and_tokenizer
|
from utils import load_model_and_tokenizer
|
||||||
|
|
||||||
model, tokenizer = load_model_and_tokenizer("THUDM/chatglm-6b", num_gpus=1)
|
model, tokenizer = load_model_and_tokenizer("THUDM/chatglm-6b", num_gpus=1)
|
||||||
|
|
Loading…
Reference in New Issue