Browse Source

[Fix] Fix spec-dec Glide LlamaModel for compatibility with transformers (#5837)

* fix glide llama model

* revise
pull/5833/head
Yuanheng Zhao 5 months ago committed by GitHub
parent
commit
7b249c76e5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 1
      colossalai/inference/core/engine.py
  2. 3
      colossalai/inference/modeling/models/glide_llama.py
  3. 1
      colossalai/inference/spec/struct.py
  4. 3
      examples/inference/llama/README.md

1
colossalai/inference/core/engine.py

@ -466,6 +466,7 @@ class InferenceEngine:
self.k_cache[-1], # use kv cahces of the last layer
self.v_cache[-1],
batch.get_sequence_lengths(),
n_spec_tokens=self.n_spec_tokens,
)
drafter_out = self.drafter.speculate(

3
colossalai/inference/modeling/models/glide_llama.py

@ -319,7 +319,8 @@ class LlamaCrossAttention(nn.Module):
query_states = query_states.view(bsz, -1, self.large_num_heads, self.large_head_dim).transpose(1, 2)
# for RoPE
cos, sin = self.rotary_emb(query_states, seq_len=kv_seq_len + 32)
position_ids = position_ids + glide_input.n_spec_tokens
cos, sin = self.rotary_emb(query_states, position_ids)
query_states = apply_single_rotary_pos_emb(query_states, cos, sin, position_ids)
query_states = query_states.transpose(1, 2)
query_states = query_states.reshape(-1, self.large_num_heads, self.large_head_dim)

1
colossalai/inference/spec/struct.py

@ -46,6 +46,7 @@ class GlideInput:
large_k_cache: torch.Tensor = None
large_v_cache: torch.Tensor = None
sequence_lengths: torch.Tensor = None
n_spec_tokens: int = 5
@property
def glimpse_ready(self):

3
examples/inference/llama/README.md

@ -43,5 +43,8 @@ colossalai run --nproc_per_node 2 llama_generation.py -m PATH_MODEL --drafter_mo
If you want to try the GLIDE model (glide-vicuna7b) as the drafter model with vicuna-7B, you could provide the GLIDE model path or model card as drafter model and enable the feature by
```python
from colossalai.inference.modeling.models.glide_llama import GlideLlamaForCausalLM
drafter_model = GlideLlamaForCausalLM.from_pretrained(drafter_model_path_or_name)
...
engine.enable_spec_dec(drafter_model, use_glide_drafter=True)
```

Loading…
Cancel
Save