diff --git a/chatglm_parallel.py b/chatglm_parallel.py index e500d64..da3f776 100644 --- a/chatglm_parallel.py +++ b/chatglm_parallel.py @@ -3,28 +3,38 @@ Author: lichuang Date: 2023-03-23 09:18:13 Description: 将模型加载到多张GPU卡中,根据gpu的数量自动分配平均的显存占用 ''' +from typing import Dict -from transformers import AutoModel, AutoTokenizer from accelerate import load_checkpoint_and_dispatch +from transformers import AutoModel + + +def auto_configure_device_map(num_gpus) -> Dict[str, int]: + # transformer.word_embeddings 占用1层 + # transformer.final_layernorm 和 lm_head 占用1层 + # transformer.layers 占用 28 层 + # 总共30层分配到num_gpus张卡上 + num_trans_layers = 28 + per_gpu_layers = 30 / num_gpus + + device_map = {'transformer.word_embeddings': 0, + 'transformer.final_layernorm': num_gpus - 1, 'lm_head': num_gpus - 1} + + used = 1 + gpu_target = 0 + for i in range(num_trans_layers): + if used >= per_gpu_layers: + gpu_target += 1 + used = 0 + assert gpu_target < num_gpus + device_map[f'transformer.layers.{i}'] = gpu_target + used += 1 + + return device_map def load_model_on_gpus(checkpoint_path, num_gpus=2): - # 总共占用13GB显存,28层transformer每层0.39GB左右 - # 第一层 word_embeddings和最后一层 lm_head 层各占用1.2GB左右 - num_trans_layers = 28 - vram_per_layer = 0.39 - average = 13/num_gpus - used = 1.2 - device_map = {'transformer.word_embeddings': 0, - 'transformer.final_layernorm': num_gpus-1, 'lm_head': num_gpus-1} - gpu_target = 0 - for i in range(num_trans_layers): - if used > average-vram_per_layer/2 and gpu_target < num_gpus: - gpu_target += 1 - used = 0 - else: - used += vram_per_layer - device_map['transformer.layers.%d' % i] = gpu_target + device_map = auto_configure_device_map(num_gpus) model = AutoModel.from_pretrained( checkpoint_path, trust_remote_code=True) diff --git a/requirements.txt b/requirements.txt index 2948480..f311c85 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ icetk cpm_kernels torch>=1.10 gradio +accelerate \ No newline at end of file diff --git a/web_demo.py b/web_demo.py index 07ddc33..6f4f34d 100644 --- a/web_demo.py +++ b/web_demo.py @@ -1,4 +1,4 @@ -from transformers import AutoModel, AutoTokenizer +from transformers import AutoTokenizer import gradio as gr from chatglm_parallel import load_model_on_gpus