mirror of https://github.com/hpcaitech/ColossalAI

add context_attention_unpadded

parent 07b5283b6a
commit 02c1bf8b2a
@@ -232,11 +232,7 @@ class InferenceEngine:

         # Decode completed sentences.
         for seq in finished_sequences:
-            if seq.prompt:
-                output_str = self.tokenizer.decode(seq.output_token_id, skip_special_tokens=True)
-                output_list.append(seq.prompt + output_str)
-            else:
-                output_str = self.tokenizer.decode(seq.input_token_id + seq.output_token_id, skip_special_tokens=True)
-                output_list.append(output_str)
+            output_str = self.tokenizer.decode(seq.input_token_id + seq.output_token_id, skip_special_tokens=True)
+            output_list.append(output_str)

         return output_list
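A minimal sketch of what the simplified decode loop above produces, using a made-up toy tokenizer in place of self.tokenizer (the vocabulary and token ids are invented for illustration): decoding the concatenated input and output ids in one call yields the prompt and the completion as a single string.

class ToyTokenizer:
    # stand-in for the real tokenizer; id 0 plays the role of a special token
    vocab = {0: "<eos>", 1: "Hello", 2: " world", 3: "!"}

    def decode(self, ids, skip_special_tokens=True):
        return "".join(self.vocab[i] for i in ids if not (skip_special_tokens and i == 0))

tokenizer = ToyTokenizer()
input_token_id = [1]          # prompt ids held on the sequence
output_token_id = [2, 3, 0]   # generated ids, ending with the special token

# one decode over the concatenation, as in the new loop body
print(tokenizer.decode(input_token_id + output_token_id))  # -> "Hello world!"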
@@ -156,9 +156,9 @@ class RequestHandler:
     def _sample(self, probs: torch.Tensor, logprobs: torch.Tensor, generation_config):
         if generation_config.num_beams == 1:
             if generation_config.do_sample:
-                sample_tokens = greedy_sample(generation_config, logprobs)
-            else:
                 sample_tokens = multinomial_sample(generation_config, probs)
+            else:
+                sample_tokens = greedy_sample(generation_config, logprobs)
         else:
             sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=not self.prefill_batch.is_empty)

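For reference, simplified stand-ins for the two single-beam samplers that _sample dispatches between; the real greedy_sample and multinomial_sample also take a generation_config, which these toy versions omit. Greedy picks the arg-max token and is deterministic; multinomial draws from the distribution and varies between runs.

import torch

def greedy_sample_sketch(logprobs: torch.Tensor) -> torch.Tensor:
    # deterministic: highest log-probability token per sequence
    return torch.argmax(logprobs, dim=-1)

def multinomial_sample_sketch(probs: torch.Tensor) -> torch.Tensor:
    # stochastic: draw one token per sequence from its distribution
    return torch.multinomial(probs, num_samples=1).squeeze(-1)

logits = torch.tensor([[2.0, 0.5, 0.1], [0.2, 1.5, 0.3]])
probs = torch.softmax(logits, dim=-1)
logprobs = torch.log_softmax(logits, dim=-1)

print(greedy_sample_sketch(logprobs))    # tensor([0, 1]) on every run
print(multinomial_sample_sketch(probs))  # varies from run to run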
@@ -5,6 +5,7 @@ import torch
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel

 from colossalai.inference.struct import BatchInfo
+from colossalai.kernel.triton import context_attention_unpadded


 def rotate_half(x):
@@ -53,7 +54,6 @@ def llama_causal_lm_forward(
         v_caches=v_caches,
     )
     logits = self.lm_head(hidden_states)
-
     return logits


@@ -157,15 +157,17 @@ def llama_attn_forward(
     key_states = key_states.view(-1, self.num_heads, self.head_dim)
     value_states = value_states.view(-1, self.num_heads, self.head_dim)

-    # TODO: The code below will be uncommented after the development of attention-related kernel is completed.
-    # memcpy_to_block(key_states, value_states, k_cache, v_cache, block_tables, block_size, sequence_lengths)
-    # if is_prompts:
-    #     attn_output = context_attention_unpadded(query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size)
-    # else:
-    #     attn_output = torch.empty(bsz, self.num_heads, self.head_dim)
-    #     decoding_attention(query_states, k_cache, v_cache, block_tables, sequence_lengths, attn_output, block_tables.shape[1], block_size)
+    _, _, _, block_size = k_cache.shape
+
+    # NOTE: context_attention_unpadded is unsed for testing accuracy and we can only use aligned inputs.
+    # The code below will be uncommented after the development of attention-related kernel is completed.
+    if is_prompts:
+        attn_output = context_attention_unpadded(
+            query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
+        )
+    # else:
+    #     attn_output = context_attention_unpadded(query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size)

     attn_output = query_states
     attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     attn_output = self.o_proj(attn_output)
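The Triton kernel itself is not part of this diff. As a rough reference only, the plain-PyTorch sketch below shows what causal attention over unpadded (packed, variable-length) prompts computes, with q/k/v laid out as [total_tokens, num_heads, head_dim] plus per-sequence lengths; the KV-cache and block-table arguments of the real context_attention_unpadded are omitted here.

import torch

def context_attention_unpadded_ref(q, k, v, seq_lengths):
    # each packed sequence attends causally within itself; no padding anywhere
    head_dim = q.shape[-1]
    outputs, start = [], 0
    for length in seq_lengths.tolist():
        qs = q[start : start + length].transpose(0, 1)  # [heads, len, dim]
        ks = k[start : start + length].transpose(0, 1)
        vs = v[start : start + length].transpose(0, 1)
        scores = qs @ ks.transpose(-1, -2) / head_dim**0.5
        causal = torch.triu(torch.ones(length, length, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(causal, float("-inf"))
        outputs.append((scores.softmax(dim=-1) @ vs).transpose(0, 1))
        start += length
    return torch.cat(outputs, dim=0)  # [total_tokens, heads, dim]

q = k = v = torch.randn(5, 2, 4)  # two prompts of lengths 2 and 3, packed together
out = context_attention_unpadded_ref(q, k, v, torch.tensor([2, 3]))
print(out.shape)  # torch.Size([5, 2, 4])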
@@ -21,7 +21,6 @@ def multinomial_sample(
     """
     Sample tokens in a random phase.
     """
-    # max_best_of = generation_config.best_of
     random_results = torch.multinomial(probs, num_samples=1, replacement=True).cpu()
     return random_results

@@ -1,4 +1,9 @@
+import random
+
+import numpy as np
 import pytest
+import torch
 import transformers
 from transformers import AutoTokenizer, GenerationConfig
+
 import colossalai
@@ -7,7 +12,15 @@ from colossalai.inference.core.engine import InferenceEngine
 from colossalai.testing import spawn


+def setup_seed(seed):
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+
+
 def check_inference_engine(test_cai=False):
+    setup_seed(20)
     tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
     model = transformers.LlamaForCausalLM(
         transformers.LlamaConfig(
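A quick illustration, separate from the test itself, of why setup_seed(20) is called before each run that will be compared: after reseeding, stochastic draws such as torch.multinomial repeat exactly.

import random

import numpy as np
import torch

def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

setup_seed(20)
first = torch.multinomial(torch.tensor([0.1, 0.2, 0.7]), num_samples=5, replacement=True)
setup_seed(20)
second = torch.multinomial(torch.tensor([0.1, 0.2, 0.7]), num_samples=5, replacement=True)
assert torch.equal(first, second)  # identical draws after reseeding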
@@ -16,8 +29,8 @@ def check_inference_engine(test_cai=False):
     )

     inputs = [
-        "介绍一下今天的北京",
-        "介绍一下武汉",
+        "介绍一下北京,",
+        "介绍一下武汉,",
     ]

     if test_cai:
@@ -25,28 +38,26 @@ def check_inference_engine(test_cai=False):
         inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
         inference_engine.add_request(prompts=inputs)
         assert inference_engine.request_handler._has_waiting()
-        generation_config = GenerationConfig(top_k=2, top_p=0.8, do_sample=True)
+        generation_config = GenerationConfig(do_sample=False)
         outputs = inference_engine.generate(generation_config)
     else:
         tokenizer.pad_token = tokenizer.eos_token
         tokenizer.pad_token_id = tokenizer.eos_token_id
         inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
-        generation_config = GenerationConfig(
-            top_k=2, top_p=0.8, do_sample=True, pad_token_id=tokenizer.pad_token_id, max_new_tokens=1
-        )
+        generation_config = GenerationConfig(do_sample=False, pad_token_id=tokenizer.pad_token_id, max_new_tokens=1)
         outputs = model.generate(inputs, generation_config=generation_config)
         outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

     return outputs


 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
-    check_inference_engine(True)
-    check_inference_engine(False)
+    cai_outputs = check_inference_engine(True)
+    transformer_outputs = check_inference_engine(False)

     # TODO: There are some bugs in sampler.
-    # for s1, s2 in zip(cai_outputs, transformer_outputs):
-    #     assert s1 == s2
+    for s1, s2 in zip(cai_outputs, transformer_outputs):
+        assert s1 == s2


 @pytest.mark.dist
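The comparison enabled above only makes sense with deterministic decoding: do_sample=False selects greedy search in transformers' generate(), and the engine's _sample routes it to greedy_sample, so both sides should emit identical strings. A small sketch of that check, with a helper name and error message added here purely for illustration:

from transformers import GenerationConfig

deterministic = GenerationConfig(do_sample=False, max_new_tokens=1)
print(deterministic.do_sample)  # False -> greedy decoding on both sides

def assert_outputs_match(cai_outputs, transformer_outputs):
    # the same element-wise check run_dist now performs
    for s1, s2 in zip(cai_outputs, transformer_outputs):
        assert s1 == s2, f"engine output {s1!r} != transformers output {s2!r}"

assert_outputs_match(["hello world"], ["hello world"])  # passes; a mismatch would raise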