mirror of https://github.com/hpcaitech/ColossalAI
precision alignment
parent
62968588d1
commit
9489dc64d8
|
@ -230,11 +230,8 @@ class InferenceEngine:
|
|||
self.request_handler.search_tokens(self.generation_config, logits)
|
||||
finished_sequences = self.request_handler.update()
|
||||
|
||||
print("finished_sequences: ", finished_sequences)
|
||||
|
||||
# Decode completed sentences.
|
||||
for seq in finished_sequences:
|
||||
print("seq.output_token_id: ", seq.output_token_id)
|
||||
if seq.prompt:
|
||||
output_str = self.tokenizer.decode(seq.output_token_id, skip_special_tokens=True)
|
||||
output_list.append(seq.prompt + output_str)
|
||||
|
@ -242,6 +239,4 @@ class InferenceEngine:
|
|||
output_str = self.tokenizer.decode(seq.input_token_id + seq.output_token_id, skip_special_tokens=True)
|
||||
output_list.append(output_str)
|
||||
|
||||
print("len(output_list): ", len(output_list))
|
||||
|
||||
return output_list
|
||||
|
|
|
@ -67,19 +67,8 @@ def llama_model_forward(
|
|||
block_tables = batch.get_block_table_tensor()
|
||||
sequence_lengths = batch.get_sequence_lengths()
|
||||
|
||||
seq_length = input_ids.shape[1]
|
||||
device = input_ids.device
|
||||
|
||||
if batch.is_prompts:
|
||||
past_key_values_length = 0
|
||||
else:
|
||||
past_key_values_length = sequence_lengths[0].item() - 1
|
||||
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
|
||||
# Here, we generate position_ids through the input tensor, which can align with the output precision of the transformer.
|
||||
position_ids = generate_padding_position_id(input_ids)
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
|
||||
for layer_id, decoder_layer in enumerate(self.layers):
|
||||
|
@ -142,7 +131,7 @@ def llama_attn_forward(
|
|||
k_cache: torch.Tensor = None,
|
||||
v_cache: torch.Tensor = None,
|
||||
is_prompts: bool = True,
|
||||
sequence_lengths: int = None,
|
||||
sequence_lengths: torch.Tensor = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
|
@ -150,7 +139,9 @@ def llama_attn_forward(
|
|||
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2] + block_tables.shape[1]
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if not is_prompts:
|
||||
kv_seq_len = kv_seq_len + sequence_lengths[0].item()
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
|
@ -166,10 +157,8 @@ def llama_attn_forward(
|
|||
key_states = key_states.view(-1, self.num_heads, self.head_dim)
|
||||
value_states = value_states.view(-1, self.num_heads, self.head_dim)
|
||||
|
||||
k_cache.shape[-1]
|
||||
|
||||
# TODO: The code below will be uncommented after the development of attention-related kernel is completed.
|
||||
# memcpy_to_block(key_states, value_states, k_cache, v_cache, block_tables, block_size, sequence_lengths)
|
||||
|
||||
# if is_prompts:
|
||||
# attn_output = context_attention_unpadded(query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size)
|
||||
# else:
|
||||
|
@ -177,10 +166,16 @@ def llama_attn_forward(
|
|||
# decoding_attention(query_states, k_cache, v_cache, block_tables, sequence_lengths, attn_output, block_tables.shape[1], block_size)
|
||||
|
||||
attn_output = query_states
|
||||
|
||||
attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
def generate_padding_position_id(input_ids: torch.Tensor) -> torch.Tensor:
|
||||
padding_id = 2
|
||||
attention_mask = input_ids.ne(padding_id).long()
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
return position_ids
|
||||
|
|
|
@ -21,8 +21,8 @@ def multinomial_sample(
|
|||
"""
|
||||
Sample tokens in a random phase.
|
||||
"""
|
||||
max_best_of = generation_config.best_of
|
||||
random_results = torch.multinomial(probs, num_samples=max_best_of, replacement=True).cpu()
|
||||
# max_best_of = generation_config.best_of
|
||||
random_results = torch.multinomial(probs, num_samples=1, replacement=True).cpu()
|
||||
return random_results
|
||||
|
||||
|
||||
|
@ -44,7 +44,8 @@ def beam_search_sample(
|
|||
# NOTE: this beam search sample function is wrong now.
|
||||
"""
|
||||
|
||||
beam_width = generation_config.best_of
|
||||
# beam_width = generation_config.best_of
|
||||
beam_width = 1
|
||||
results = []
|
||||
if is_prompt:
|
||||
# Prompt phase.
|
||||
|
|
|
@ -308,7 +308,7 @@ class BatchInfo:
|
|||
input_len_list.append(1)
|
||||
|
||||
return torch.tensor(input_list, dtype=torch.long, device=self.device), torch.tensor(
|
||||
input_len_list, dtype=torch.int, device=device
|
||||
input_len_list, dtype=torch.int, device=self.device
|
||||
)
|
||||
|
||||
def get_sequence_lengths(self):
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import pytest
|
||||
import transformers
|
||||
from transformers import AutoTokenizer, GenerationConfig
|
||||
|
||||
import colossalai
|
||||
|
@ -8,38 +7,46 @@ from colossalai.inference.core.engine import InferenceEngine
|
|||
from colossalai.testing import spawn
|
||||
|
||||
|
||||
def check_inference_engine():
|
||||
def check_inference_engine(test_cai=False):
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
model = transformers.LlamaForCausalLM(
|
||||
transformers.LlamaConfig(
|
||||
vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
|
||||
)
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
inference_config = InferenceConfig(max_output_len=5)
|
||||
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
|
||||
|
||||
inputs = [
|
||||
"介绍一下今天的北京",
|
||||
"介绍一下武汉",
|
||||
]
|
||||
|
||||
inference_engine.add_request(prompts=inputs)
|
||||
assert inference_engine.request_handler._has_waiting()
|
||||
generation_config = GenerationConfig(top_k=2, top_p=0.8, do_sample=True)
|
||||
outputs = inference_engine.generate(generation_config)
|
||||
|
||||
print("len(outputs): ", len(outputs))
|
||||
print("outputs: ", outputs)
|
||||
|
||||
# Engine still gets some bug
|
||||
|
||||
# for s1, s2 in zip(inputs, outputs):
|
||||
# assert s1 == s2
|
||||
if test_cai:
|
||||
inference_config = InferenceConfig(max_output_len=1)
|
||||
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
|
||||
inference_engine.add_request(prompts=inputs)
|
||||
assert inference_engine.request_handler._has_waiting()
|
||||
generation_config = GenerationConfig(top_k=2, top_p=0.8, do_sample=True)
|
||||
outputs = inference_engine.generate(generation_config)
|
||||
else:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
|
||||
generation_config = GenerationConfig(
|
||||
top_k=2, top_p=0.8, do_sample=True, pad_token_id=tokenizer.pad_token_id, max_new_tokens=1
|
||||
)
|
||||
outputs = model.generate(inputs, generation_config=generation_config)
|
||||
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
return outputs
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
check_inference_engine()
|
||||
check_inference_engine(True)
|
||||
check_inference_engine(False)
|
||||
|
||||
# TODO: There are some in sampler
|
||||
# for s1, s2 in zip(cai_outputs, transformer_outputs):
|
||||
# assert s1 == s2
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
|
|
Loading…
Reference in New Issue