mirror of https://github.com/hpcaitech/ColossalAI
[Fix] Fix spec-dec Glide LlamaModel for compatibility with transformers (#5837)
* fix glide llama model * revise pull/5833/head
parent
fd1dc417d8
commit
7b249c76e5
|
@ -466,6 +466,7 @@ class InferenceEngine:
|
||||||
self.k_cache[-1], # use kv caches of the last layer
|
self.k_cache[-1], # use kv caches of the last layer
|
||||||
self.v_cache[-1],
|
self.v_cache[-1],
|
||||||
batch.get_sequence_lengths(),
|
batch.get_sequence_lengths(),
|
||||||
|
n_spec_tokens=self.n_spec_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
drafter_out = self.drafter.speculate(
|
drafter_out = self.drafter.speculate(
|
||||||
|
|
|
@ -319,7 +319,8 @@ class LlamaCrossAttention(nn.Module):
|
||||||
query_states = query_states.view(bsz, -1, self.large_num_heads, self.large_head_dim).transpose(1, 2)
|
query_states = query_states.view(bsz, -1, self.large_num_heads, self.large_head_dim).transpose(1, 2)
|
||||||
|
|
||||||
# for RoPE
|
# for RoPE
|
||||||
cos, sin = self.rotary_emb(query_states, seq_len=kv_seq_len + 32)
|
position_ids = position_ids + glide_input.n_spec_tokens
|
||||||
|
cos, sin = self.rotary_emb(query_states, position_ids)
|
||||||
query_states = apply_single_rotary_pos_emb(query_states, cos, sin, position_ids)
|
query_states = apply_single_rotary_pos_emb(query_states, cos, sin, position_ids)
|
||||||
query_states = query_states.transpose(1, 2)
|
query_states = query_states.transpose(1, 2)
|
||||||
query_states = query_states.reshape(-1, self.large_num_heads, self.large_head_dim)
|
query_states = query_states.reshape(-1, self.large_num_heads, self.large_head_dim)
|
||||||
|
|
|
@ -46,6 +46,7 @@ class GlideInput:
|
||||||
large_k_cache: torch.Tensor = None
|
large_k_cache: torch.Tensor = None
|
||||||
large_v_cache: torch.Tensor = None
|
large_v_cache: torch.Tensor = None
|
||||||
sequence_lengths: torch.Tensor = None
|
sequence_lengths: torch.Tensor = None
|
||||||
|
n_spec_tokens: int = 5
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def glimpse_ready(self):
|
def glimpse_ready(self):
|
||||||
|
|
|
@ -43,5 +43,8 @@ colossalai run --nproc_per_node 2 llama_generation.py -m PATH_MODEL --drafter_mo
|
||||||
|
|
||||||
If you want to try the GLIDE model (glide-vicuna7b) as the drafter model with vicuna-7B, you can provide the GLIDE model path or model card as the drafter model and enable the feature by
|
If you want to try the GLIDE model (glide-vicuna7b) as the drafter model with vicuna-7B, you can provide the GLIDE model path or model card as the drafter model and enable the feature by
|
||||||
```python
|
```python
|
||||||
|
from colossalai.inference.modeling.models.glide_llama import GlideLlamaForCausalLM
|
||||||
|
drafter_model = GlideLlamaForCausalLM.from_pretrained(drafter_model_path_or_name)
|
||||||
|
...
|
||||||
engine.enable_spec_dec(drafter_model, use_glide_drafter=True)
|
engine.enable_spec_dec(drafter_model, use_glide_drafter=True)
|
||||||
```
|
```
|
||||||
|
|
Loading…
Reference in New Issue