@@ -1,5 +1,6 @@
 import argparse
 
+from torch import bfloat16, float16, float32
 from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
 
 import colossalai
@@ -12,6 +13,12 @@ from colossalai.inference.modeling.policy.nopadding_llama import NoPaddingLlamaM
 MODEL_CLS = AutoModelForCausalLM
 POLICY_CLS = NoPaddingLlamaModelInferPolicy
 
+TORCH_DTYPE_MAP = {
+    "fp16": float16,
+    "fp32": float32,
+    "bf16": bfloat16,
+}
+
 
 def infer(args):
     # ==============================
@@ -24,7 +31,7 @@ def infer(args):
     # Load model and tokenizer
     # ==============================
     model_path_or_name = args.model
-    model = MODEL_CLS.from_pretrained(model_path_or_name)
+    model = MODEL_CLS.from_pretrained(model_path_or_name, torch_dtype=TORCH_DTYPE_MAP.get(args.dtype, None))
     tokenizer = AutoTokenizer.from_pretrained(model_path_or_name)
     tokenizer.pad_token = tokenizer.eos_token
     # coordinator.print_on_master(f"Model Config:\n{model.config}")
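
For context: the new `torch_dtype=TORCH_DTYPE_MAP.get(args.dtype, None)` argument assumes the script's argparse parser (outside these hunks) exposes a `--dtype` flag whose values match the map's keys. Unmapped values yield `None`, which leaves `from_pretrained` at its default dtype behavior. A minimal sketch of that assumed wiring, with hypothetical defaults and help text:

    import argparse

    parser = argparse.ArgumentParser()
    # Assumed flag; choices mirror the keys of TORCH_DTYPE_MAP, so
    # TORCH_DTYPE_MAP.get(args.dtype, None) never falls back to None via a typo.
    parser.add_argument(
        "--dtype",
        type=str,
        default="fp16",
        choices=["fp16", "fp32", "bf16"],
        help="torch dtype used when loading the model weights",
    )
    args = parser.parse_args()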