mirror of https://github.com/InternLM/InternLM
Merge main to develop (#309)
* fix(chat): fix stream_chat to return generator (#123) * fix(configs/7B_sft.py): model dtype float16 to bfloat16 (#302) * fix(convert2hf.py): fix the rotary_emb.inv_freq KeyError (#299) --------- Co-authored-by: yingtongxiong <974106207@qq.com> Co-authored-by: zhjunqin <zhjunqin@users.noreply.github.com> Co-authored-by: jiangtann <39088437+jiangtann@users.noreply.github.com>pull/310/head
parent
882a07011c
commit
07fc5f674a
|
@ -127,7 +127,7 @@ model = dict(
|
|||
num_layers=NUM_LAYER,
|
||||
mlp_ratio=MLP_RATIO,
|
||||
apply_post_layer_norm=False,
|
||||
dtype="torch.float16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
|
||||
dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
|
||||
norm_type="rmsnorm",
|
||||
layer_norm_epsilon=1e-5,
|
||||
use_flash_attn=True,
|
||||
|
|
|
@ -38,7 +38,7 @@ def convert2hf(model_config, states_tp_pps):
|
|||
current_states["lm_head.weight"] = states.pop("head.weight")
|
||||
|
||||
for i in range(model_config["num_layers"]):
|
||||
states.pop(f"blocks.{i}.mixer.rotary_emb.inv_freq")
|
||||
states.pop(f"blocks.{i}.mixer.rotary_emb.inv_freq", None)
|
||||
|
||||
wqkv = states.pop(f"blocks.{i}.mixer.Wqkv.weight").reshape(
|
||||
3, model_config["num_attention_heads"], -1, model_config["hidden_size"]
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
""" PyTorch InternLM model."""
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import threading, queue
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
|
@ -810,35 +811,70 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
|
|||
temperature: float = 0.8,
|
||||
top_p: float = 0.8,
|
||||
**kwargs):
|
||||
"""
|
||||
Return a generator in format: (response, history)
|
||||
Eg.
|
||||
('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')])
|
||||
('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')])
|
||||
"""
|
||||
|
||||
response_queue = queue.Queue(maxsize=20)
|
||||
|
||||
class ChatStreamer(BaseStreamer):
|
||||
def __init__(self, tokenizer) -> None:
|
||||
super().__init__()
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.queue = response_queue
|
||||
self.query = query
|
||||
self.history = history
|
||||
self.response = ""
|
||||
self.received_inputs = False
|
||||
self.queue.put((self.response, history + [(self.query, self.response)]))
|
||||
|
||||
def put(self, value):
|
||||
if len(value.shape) > 1 and value.shape[0] > 1:
|
||||
raise ValueError("ChatStreamer only supports batch size 1")
|
||||
elif len(value.shape) > 1:
|
||||
value = value[0]
|
||||
|
||||
if not self.received_inputs:
|
||||
# The first received value is input_ids, ignore here
|
||||
self.received_inputs = True
|
||||
return
|
||||
|
||||
token = self.tokenizer.decode([value[-1]], skip_special_tokens=True)
|
||||
if token.strip() != "<eoa>":
|
||||
print(token, end="")
|
||||
|
||||
self.response = self.response + token
|
||||
history = self.history + [(self.query, self.response)]
|
||||
self.queue.put((self.response, history))
|
||||
|
||||
def end(self):
|
||||
print("")
|
||||
|
||||
return self.chat(
|
||||
tokenizer=tokenizer,
|
||||
query=query,
|
||||
streamer=ChatStreamer(tokenizer=tokenizer),
|
||||
history=history,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=do_sample,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self.queue.put(None)
|
||||
|
||||
def stream_producer():
|
||||
return self.chat(
|
||||
tokenizer=tokenizer,
|
||||
query=query,
|
||||
streamer=ChatStreamer(tokenizer=tokenizer),
|
||||
history=history,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=do_sample,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def consumer():
|
||||
producer = threading.Thread(target=stream_producer)
|
||||
producer.start()
|
||||
while True:
|
||||
res = response_queue.get()
|
||||
if res is not None:
|
||||
return
|
||||
yield res
|
||||
|
||||
return consumer()
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue