# ColossalAI/examples/inference/colossal_llama2_demo.py
#
# Listing metadata captured from the repository viewer: 82 lines, 2.9 KiB, Python.

import os
import warnings
import torch
import torch.distributed as dist
import argparse
from packaging import version
import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from transformers import AutoModelForCausalLM, AutoTokenizer
# Ask `transformers` to keep its advisory warnings out of the demo output.
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'

# Tensor-parallel size placed in the `parameterize` config of run_llama_test.
TPSIZE = 1

# Sizing values handed to TPInferEngine (and, for MAX_OUTPUT_LEN, to generate()).
BATCH_SIZE = 4
MAX_INPUT_LEN = 32
MAX_OUTPUT_LEN = 128

# `torch.version.cuda` is None on PyTorch builds without CUDA; parsing None
# would raise at import time, so such builds are reported as unsupported.
CUDA_SUPPORT = (
    torch.version.cuda is not None
    and version.parse(torch.version.cuda) > version.parse('11.5')
)
@parameterize('test_config', [{
    'tp_size': TPSIZE,
}])
def run_llama_test(test_config, args):
    """Load a causal LM, wrap it in a TPInferEngine and print greedy completions.

    Args:
        test_config: Dict injected by ``parameterize``. Its ``'tp_size'`` entry
            is used only as a fallback when ``args`` carries no ``tp_size``.
        args: Namespace-like object. ``args.path`` is the model name or local
            directory; ``args.tp_size`` (optional) is the number of ranks the
            run was launched with.
    """
    model_path = args.path

    # NOTE: trust_remote_code=True allows the model repository to run its own
    # Python code while loading; only point `path` at sources you trust.
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    # Pad with the unknown-token id so the padded batch encoding below works.
    tokenizer.pad_token_id = tokenizer.unk_token_id

    model = AutoModelForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id)
    model = model.half()

    text = ["Introduce London.", "What is the genus of Poodle?"]
    input_ids = tokenizer.batch_encode_plus(text, return_tensors='pt', padding=True)
    print(input_ids)

    # test_llama spawns `args.tp_size` ranks, so sharding has to follow that
    # value. The static TPSIZE in `test_config` is always 1 and would leave
    # tensor parallelism disabled even for a multi-rank launch.
    tp_size = getattr(args, 'tp_size', test_config['tp_size'])
    shard_config = ShardConfig(enable_tensor_parallelism=tp_size > 1,
                               extra_kwargs={"inference_only": True})
    infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)

    # do_sample=False selects deterministic (greedy) decoding.
    generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
    outputs = infer_engine.generate(input_ids, **generate_kwargs)
    assert outputs is not None

    # Print once: from rank 0 when distributed, or unconditionally otherwise.
    if not dist.is_initialized() or dist.get_rank() == 0:
        for o in outputs:
            output_text = tokenizer.decode(o)
            print(output_text)
def check_llama(rank, world_size, port, args):
    """Per-process worker: set up the ColossalAI runtime, then run the demo.

    Args:
        rank: Rank of this process, forwarded to ``colossalai.launch``.
        world_size: Total number of processes, forwarded to ``colossalai.launch``.
        port: Rendezvous port on localhost, forwarded to ``colossalai.launch``.
        args: Parsed command-line namespace handed on to ``run_llama_test``.
    """
    disable_existing_loggers()

    # Single-node launch over NCCL with an empty ColossalAI config.
    launch_options = dict(
        config={},
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl',
    )
    colossalai.launch(**launch_options)

    run_llama_test(args=args)
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama(args):
    """Spawn ``args.tp_size`` worker processes, each running ``check_llama``.

    Args:
        args: Parsed command-line namespace; ``tp_size`` sets the process
            count and the whole namespace is forwarded to every worker.
    """
    # NOTE(review): the two decorators come from colossalai.testing; their
    # exact retry / cache-clearing behavior is not visible in this file.
    num_workers = args.tp_size
    spawn(check_llama, num_workers, args=args)
if __name__ == "__main__":
    # Command-line entry point: parse the options, then launch the demo.
    # NOTE(review): --batch_size, --input_len, --output_len and --test_mode are
    # accepted here, but no code visible in this file reads them.
    parser = argparse.ArgumentParser()

    # (flags, keyword options) pairs, registered in this order.
    option_table = (
        (("-p", "--path"),
         dict(type=str, default="hpcai-tech/Colossal-LLaMA-2-7b-base", help="Model path")),
        (("-tp", "--tp_size"),
         dict(type=int, default=1, help="Tensor parallel size")),
        (("-b", "--batch_size"),
         dict(type=int, default=32, help="Maximum batch size")),
        (("--input_len",),
         dict(type=int, default=1024, help="Maximum input length")),
        (("--output_len",),
         dict(type=int, default=128, help="Maximum output length")),
        (("--test_mode",),
         dict(type=str, help="Test mode", default="e2e_test", choices=["e2e_test", "decoder_test"])),
    )
    for flags, options in option_table:
        parser.add_argument(*flags, **options)

    args = parser.parse_args()
    test_llama(args)