@@ -2,17 +2,15 @@ import os

import pytest
import torch
import torch.distributed as dist
from packaging import version
from transformers import AutoTokenizer

import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo.transformers.chatglm2 import infer_config

os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
TPSIZE = 1
@@ -31,28 +29,31 @@ CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
    ],
)
def run_chatglm2_test(test_config):
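    # The ChatGLM2 tokenizer is custom code hosted on the Hub, so trust_remote_code=True is required.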
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
    # pad_token_id = 0
    model_fn = lambda: ChatGLMForConditionalGeneration(infer_config, empty_init=False)
    orig_model = model_fn()
    orig_model = orig_model.half()
    text = ["how is the weather today?"]
    input_ids = tokenizer.batch_encode_plus(text, return_tensors="pt", padding=True)
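    # Tiny ChatGLM2 config (2 layers, 8 heads, hidden size 1024): the model built from it
    # is randomly initialized, so no pretrained 6B weights need to be downloaded.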
    chatglm_config = ChatGLMConfig(
        num_layers=2,
        vocab_size=1200,
        use_cache=True,
        multi_query_attention=True,
        multi_query_group_num=2,
        num_attention_heads=8,
        hidden_size=1024,
    )
    model = ChatGLMForConditionalGeneration(chatglm_config)
    model = model.half()

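    # Shard for inference only; tensor parallelism is enabled only when test_config
    # requests more than one TP rank.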
    shard_config = ShardConfig(
        enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
    )
    infer_engine = TPInferEngine(orig_model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)

    infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
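    # do_sample=False requests deterministic (greedy) decoding, capped at MAX_OUTPUT_LEN new tokens.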
    generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
    outputs = infer_engine.generate(input_ids, **generate_kwargs)
    assert outputs is not None

    # print("outputs.shape: ", outputs[0].shape)
    # print("outputs: ", outputs[0])
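    # Decode and print the generated sequences on rank 0 only (or when torch.distributed is not initialized).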
    if not dist.is_initialized() or dist.get_rank() == 0:
        for o in outputs:
            output_text = tokenizer.decode(o)
            print(output_text)
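    # generate() also accepts a plain dict of CUDA tensors; exercise that path with random
    # token ids at MAX_INPUT_LEN and an all-ones attention mask.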
    input_tokens = {
        "input_ids": torch.randint(1, 1000, (BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
        "attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
    }
    outputs = infer_engine.generate(input_tokens, **generate_kwargs)

    assert outputs is not None


def check_chatglm2(rank, world_size, port):