
[Fix/Example] Fix Llama Inference Loading Data Type (#5763)

* [fix/example] fix llama inference loading dtype

* revise loading dtype of benchmark llama3
Yuanheng Zhao committed 6 months ago
commit 677cbfacf8
2 changed files:
  1. examples/inference/llama/benchmark_llama3.py  (12 lines changed)
  2. examples/inference/llama/llama_generation.py  (9 lines changed)
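
Both files apply the same pattern: map the user-facing dtype string to a torch dtype and pass it to from_pretrained so the checkpoint is materialized directly in that precision. The following is a minimal, self-contained sketch of that pattern for illustration only; the checkpoint path is a placeholder and load_model is not a function from the repository.

import torch
from transformers import AutoModelForCausalLM

TORCH_DTYPE_MAP = {
    "fp16": torch.float16,
    "fp32": torch.float32,
    "bf16": torch.bfloat16,
}

def load_model(model_path: str, dtype: str):
    # Unrecognized dtype strings map to None; from_pretrained then falls back
    # to its default behavior and loads the weights in float32.
    torch_dtype = TORCH_DTYPE_MAP.get(dtype, None)
    return AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch_dtype)

model = load_model("path/to/llama-checkpoint", "bf16")  # placeholder path
print(model.dtype)  # torch.bfloat16

Before this commit the examples omitted the torch_dtype argument, so checkpoints were always loaded in float32 regardless of args.dtype.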

examples/inference/llama/benchmark_llama3.py

@@ -17,6 +17,13 @@ GIGABYTE = 1024**3
 MEGABYTE = 1024**2
 N_WARMUP_STEPS = 2
+TORCH_DTYPE_MAP = {
+    "fp16": torch.float16,
+    "fp32": torch.float32,
+    "bf16": torch.bfloat16,
+}
 CONFIG_MAP = {
     "toy": transformers.LlamaConfig(num_hidden_layers=4),
     "llama-7b": transformers.LlamaConfig(
@@ -104,10 +111,13 @@ def print_details_info(model_config, whole_end2end, total_token_num, dtype, coor
 def benchmark_inference(args):
     coordinator = DistCoordinator()
+    torch_dtype = TORCH_DTYPE_MAP.get(args.dtype, None)
     config = CONFIG_MAP[args.model]
+    config.torch_dtype = torch_dtype
     config.pad_token_id = config.eos_token_id
     if args.model_path is not None:
-        model = transformers.LlamaForCausalLM.from_pretrained(args.model_path)
+        model = transformers.LlamaForCausalLM.from_pretrained(args.model_path, torch_dtype=torch_dtype)
         tokenizer = AutoTokenizer.from_pretrained(args.model_path)
     else:
         # Random weights
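
A note on the config.torch_dtype assignment above: the likely motivation (an assumption here, since the "Random weights" branch is truncated in this diff) is that a model built from the config alone, rather than from a checkpoint, can only learn the requested precision through the config, and it still needs an explicit cast because the plain constructor creates parameters in the torch default dtype. A small illustrative sketch:

import torch
import transformers

config = transformers.LlamaConfig(num_hidden_layers=4)  # the "toy" config used above
config.torch_dtype = torch.float16
config.pad_token_id = config.eos_token_id

# The constructor ignores config.torch_dtype and builds float32 weights,
# so the recorded dtype is applied with an explicit cast afterwards.
model = transformers.LlamaForCausalLM(config).to(dtype=config.torch_dtype)
print(model.dtype)  # torch.float16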

examples/inference/llama/llama_generation.py

@@ -1,5 +1,6 @@
 import argparse
+from torch import bfloat16, float16, float32
 from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
 import colossalai
@@ -12,6 +13,12 @@ from colossalai.inference.modeling.policy.nopadding_llama import NoPaddingLlamaM
 MODEL_CLS = AutoModelForCausalLM
 POLICY_CLS = NoPaddingLlamaModelInferPolicy
+TORCH_DTYPE_MAP = {
+    "fp16": float16,
+    "fp32": float32,
+    "bf16": bfloat16,
+}
 def infer(args):
     # ==============================
@@ -24,7 +31,7 @@ def infer(args):
     # Load model and tokenizer
     # ==============================
     model_path_or_name = args.model
-    model = MODEL_CLS.from_pretrained(model_path_or_name)
+    model = MODEL_CLS.from_pretrained(model_path_or_name, torch_dtype=TORCH_DTYPE_MAP.get(args.dtype, None))
     tokenizer = AutoTokenizer.from_pretrained(model_path_or_name)
     tokenizer.pad_token = tokenizer.eos_token
     # coordinator.print_on_master(f"Model Config:\n{model.config}")
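
One way to confirm the change behaves as intended, sketched with a placeholder checkpoint path rather than project code: load a checkpoint with an explicit torch_dtype and inspect the result. Without that argument, as in the removed line above, transformers loads the weights in float32, roughly doubling the weight memory of an fp16 or bf16 run.

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "path/to/llama-checkpoint",  # placeholder
    torch_dtype=torch.bfloat16,
)
print(model.dtype)  # torch.bfloat16

# Approximate weight footprint: elements times bytes per element.
num_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
print(f"weights: {num_bytes / 1024**3:.2f} GiB")  # about half of a float32 load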
