mirror of https://github.com/hpcaitech/ColossalAI
[Fix/Example] Fix Llama Inference Loading Data Type (#5763)
* [fix/example] fix llama inference loading dtype * revise loading dtype of benchmark llama3pull/5758/head
parent
023ea13cb5
commit
677cbfacf8
|
@ -17,6 +17,13 @@ GIGABYTE = 1024**3
|
|||
MEGABYTE = 1024**2
|
||||
N_WARMUP_STEPS = 2
|
||||
|
||||
TORCH_DTYPE_MAP = {
|
||||
"fp16": torch.float16,
|
||||
"fp32": torch.float32,
|
||||
"bf16": torch.bfloat16,
|
||||
}
|
||||
|
||||
|
||||
CONFIG_MAP = {
|
||||
"toy": transformers.LlamaConfig(num_hidden_layers=4),
|
||||
"llama-7b": transformers.LlamaConfig(
|
||||
|
@ -104,10 +111,13 @@ def print_details_info(model_config, whole_end2end, total_token_num, dtype, coor
|
|||
def benchmark_inference(args):
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
torch_dtype = TORCH_DTYPE_MAP.get(args.dtype, None)
|
||||
config = CONFIG_MAP[args.model]
|
||||
config.torch_dtype = torch_dtype
|
||||
config.pad_token_id = config.eos_token_id
|
||||
|
||||
if args.model_path is not None:
|
||||
model = transformers.LlamaForCausalLM.from_pretrained(args.model_path)
|
||||
model = transformers.LlamaForCausalLM.from_pretrained(args.model_path, torch_dtype=torch_dtype)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
|
||||
else:
|
||||
# Random weights
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import argparse
|
||||
|
||||
from torch import bfloat16, float16, float32
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
||||
|
||||
import colossalai
|
||||
|
@ -12,6 +13,12 @@ from colossalai.inference.modeling.policy.nopadding_llama import NoPaddingLlamaM
|
|||
MODEL_CLS = AutoModelForCausalLM
|
||||
POLICY_CLS = NoPaddingLlamaModelInferPolicy
|
||||
|
||||
TORCH_DTYPE_MAP = {
|
||||
"fp16": float16,
|
||||
"fp32": float32,
|
||||
"bf16": bfloat16,
|
||||
}
|
||||
|
||||
|
||||
def infer(args):
|
||||
# ==============================
|
||||
|
@ -24,7 +31,7 @@ def infer(args):
|
|||
# Load model and tokenizer
|
||||
# ==============================
|
||||
model_path_or_name = args.model
|
||||
model = MODEL_CLS.from_pretrained(model_path_or_name)
|
||||
model = MODEL_CLS.from_pretrained(model_path_or_name, torch_dtype=TORCH_DTYPE_MAP.get(args.dtype, None))
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path_or_name)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
# coordinator.print_on_master(f"Model Config:\n{model.config}")
|
||||
|
|
Loading…
Reference in New Issue