
[Fix/Example] Fix Llama Inference Loading Data Type (#5763)

* [fix/example] fix llama inference loading dtype

* revise loading dtype of benchmark llama3
Yuanheng Zhao, 6 months ago (committed by GitHub)
commit 677cbfacf8
Changed files:
  1. examples/inference/llama/benchmark_llama3.py (12 changed lines)
  2. examples/inference/llama/llama_generation.py (9 changed lines)

examples/inference/llama/benchmark_llama3.py

@@ -17,6 +17,13 @@ GIGABYTE = 1024**3
 MEGABYTE = 1024**2
 N_WARMUP_STEPS = 2
 
+TORCH_DTYPE_MAP = {
+    "fp16": torch.float16,
+    "fp32": torch.float32,
+    "bf16": torch.bfloat16,
+}
+
 CONFIG_MAP = {
     "toy": transformers.LlamaConfig(num_hidden_layers=4),
     "llama-7b": transformers.LlamaConfig(
@@ -104,10 +111,13 @@ def print_details_info(model_config, whole_end2end, total_token_num, dtype, coordinator):
 def benchmark_inference(args):
     coordinator = DistCoordinator()
 
+    torch_dtype = TORCH_DTYPE_MAP.get(args.dtype, None)
     config = CONFIG_MAP[args.model]
+    config.torch_dtype = torch_dtype
     config.pad_token_id = config.eos_token_id
+
     if args.model_path is not None:
-        model = transformers.LlamaForCausalLM.from_pretrained(args.model_path)
+        model = transformers.LlamaForCausalLM.from_pretrained(args.model_path, torch_dtype=torch_dtype)
         tokenizer = AutoTokenizer.from_pretrained(args.model_path)
     else:
         # Random weights
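For reference, the change above boils down to one pattern: translate the benchmark's --dtype string into a torch.dtype and forward it to from_pretrained. A minimal sketch of that pattern follows; the load_llama helper and its arguments are illustrative, not part of the script:

import torch
import transformers

# Same mapping the patch introduces in benchmark_llama3.py.
TORCH_DTYPE_MAP = {
    "fp16": torch.float16,
    "fp32": torch.float32,
    "bf16": torch.bfloat16,
}

def load_llama(model_path: str, dtype: str = "fp16"):
    # Hypothetical helper for illustration. Using .get(..., None) lets an
    # unrecognized flag fall back to transformers' default loading
    # behavior instead of raising a KeyError.
    torch_dtype = TORCH_DTYPE_MAP.get(dtype, None)
    return transformers.LlamaForCausalLM.from_pretrained(
        model_path, torch_dtype=torch_dtype
    )

Without an explicit torch_dtype, from_pretrained materializes the weights in torch.float32 by default, so the benchmark's --dtype flag was silently ignored before this change.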

examples/inference/llama/llama_generation.py

@@ -1,5 +1,6 @@
 import argparse
 
+from torch import bfloat16, float16, float32
 from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
 
 import colossalai
@@ -12,6 +13,12 @@ from colossalai.inference.modeling.policy.nopadding_llama import NoPaddingLlamaModelInferPolicy
 MODEL_CLS = AutoModelForCausalLM
 POLICY_CLS = NoPaddingLlamaModelInferPolicy
 
+TORCH_DTYPE_MAP = {
+    "fp16": float16,
+    "fp32": float32,
+    "bf16": bfloat16,
+}
+
 
 def infer(args):
     # ==============================
@@ -24,7 +31,7 @@ def infer(args):
     # Load model and tokenizer
     # ==============================
     model_path_or_name = args.model
-    model = MODEL_CLS.from_pretrained(model_path_or_name)
+    model = MODEL_CLS.from_pretrained(model_path_or_name, torch_dtype=TORCH_DTYPE_MAP.get(args.dtype, None))
     tokenizer = AutoTokenizer.from_pretrained(model_path_or_name)
     tokenizer.pad_token = tokenizer.eos_token
     # coordinator.print_on_master(f"Model Config:\n{model.config}")
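The generation example gets the same treatment. A quick sanity check of the patched behavior, sketched under the assumption of a local Llama checkpoint (the path below is a placeholder):

from torch import bfloat16
from transformers import AutoModelForCausalLM

# Placeholder path; substitute a real Llama checkpoint directory.
model = AutoModelForCausalLM.from_pretrained(
    "/path/to/llama",
    torch_dtype=bfloat16,
)
# After the patch, the requested precision is what actually lands in memory.
print(next(model.parameters()).dtype)  # expected: torch.bfloat16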
