mirror of https://github.com/InternLM/InternLM
Merge main to develop (#309)
* fix(chat): fix stream_chat to return generator (#123)
* fix(configs/7B_sft.py): model dtype float16 to bfloat16 (#302)
* fix(convert2hf.py): fix the rotary_emb.inv_freq KeyError (#299)

Co-authored-by: yingtongxiong <974106207@qq.com>
Co-authored-by: zhjunqin <zhjunqin@users.noreply.github.com>
Co-authored-by: jiangtann <39088437+jiangtann@users.noreply.github.com>
parent 882a07011c
commit 07fc5f674a
configs/7B_sft.py
@@ -127,7 +127,7 @@ model = dict(
     num_layers=NUM_LAYER,
     mlp_ratio=MLP_RATIO,
     apply_post_layer_norm=False,
-    dtype="torch.float16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
+    dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
     norm_type="rmsnorm",
     layer_norm_epsilon=1e-5,
     use_flash_attn=True,
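For reference (not part of the commit): a common motivation for preferring bfloat16 in training configs is its wider dynamic range, which can be checked directly in PyTorch.

import torch

# float16 saturates just above 65504, while bfloat16 shares float32's exponent range,
# so large activations/gradients are far less likely to overflow under bfloat16.
print(torch.finfo(torch.float16).max)   # 65504.0
print(torch.finfo(torch.bfloat16).max)  # ~3.39e+38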
convert2hf.py
@@ -38,7 +38,7 @@ def convert2hf(model_config, states_tp_pps):
     current_states["lm_head.weight"] = states.pop("head.weight")

     for i in range(model_config["num_layers"]):
-        states.pop(f"blocks.{i}.mixer.rotary_emb.inv_freq")
+        states.pop(f"blocks.{i}.mixer.rotary_emb.inv_freq", None)

         wqkv = states.pop(f"blocks.{i}.mixer.Wqkv.weight").reshape(
             3, model_config["num_attention_heads"], -1, model_config["hidden_size"]
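The fix relies on dict.pop accepting a default, which is returned instead of raising KeyError when the key is absent. A minimal illustration with a made-up state dict (not taken from the commit):

# A state dict that lacks the rotary_emb.inv_freq entry for block 0.
states = {"blocks.0.mixer.Wqkv.weight": 0}

# Without a default this would raise KeyError; with None it silently skips the missing key.
states.pop("blocks.0.mixer.rotary_emb.inv_freq", None)  # returns None, no exception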
modeling_internlm.py
@@ -20,6 +20,7 @@
 """ PyTorch InternLM model."""
 import math
 from typing import List, Optional, Tuple, Union
+import threading, queue

 import torch
 import torch.utils.checkpoint
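These imports support the producer/consumer pattern used in the stream_chat hunk below. A minimal standalone sketch of that pattern (illustrative only, not InternLM code):

import threading, queue

q = queue.Queue(maxsize=20)

def producer():
    # Push results into the queue, then a None sentinel to signal completion.
    for i in range(3):
        q.put(i)
    q.put(None)

def consumer():
    # Run the producer in a background thread and yield items until the sentinel arrives.
    threading.Thread(target=producer).start()
    while True:
        item = q.get()
        if item is None:
            return
        yield item

print(list(consumer()))  # [0, 1, 2]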
modeling_internlm.py
@@ -810,35 +811,70 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
         temperature: float = 0.8,
         top_p: float = 0.8,
         **kwargs):
+        """
+        Return a generator in format: (response, history)
+        Eg.
+        ('你好，有什么可以帮助您的吗', [('你好', '你好，有什么可以帮助您的吗')])
+        ('你好，有什么可以帮助您的吗？', [('你好', '你好，有什么可以帮助您的吗？')])
+        """
+
+        response_queue = queue.Queue(maxsize=20)
+
         class ChatStreamer(BaseStreamer):
             def __init__(self, tokenizer) -> None:
                 super().__init__()
                 self.tokenizer = tokenizer
+                self.queue = response_queue
+                self.query = query
+                self.history = history
+                self.response = ""
+                self.received_inputs = False
+                self.queue.put((self.response, history + [(self.query, self.response)]))

             def put(self, value):
                 if len(value.shape) > 1 and value.shape[0] > 1:
                     raise ValueError("ChatStreamer only supports batch size 1")
                 elif len(value.shape) > 1:
                     value = value[0]

+                if not self.received_inputs:
+                    # The first received value is input_ids, ignore here
+                    self.received_inputs = True
+                    return
+
                 token = self.tokenizer.decode([value[-1]], skip_special_tokens=True)
                 if token.strip() != "<eoa>":
-                    print(token, end="")
+                    self.response = self.response + token
+                    history = self.history + [(self.query, self.response)]
+                    self.queue.put((self.response, history))

             def end(self):
-                print("")
-
-        return self.chat(
-            tokenizer=tokenizer,
-            query=query,
-            streamer=ChatStreamer(tokenizer=tokenizer),
-            history=history,
-            max_new_tokens=max_new_tokens,
-            do_sample=do_sample,
-            temperature=temperature,
-            top_p=top_p,
-            **kwargs
-        )
+                self.queue.put(None)
+
+        def stream_producer():
+            return self.chat(
+                tokenizer=tokenizer,
+                query=query,
+                streamer=ChatStreamer(tokenizer=tokenizer),
+                history=history,
+                max_new_tokens=max_new_tokens,
+                do_sample=do_sample,
+                temperature=temperature,
+                top_p=top_p,
+                **kwargs
+            )
+
+        def consumer():
+            producer = threading.Thread(target=stream_producer)
+            producer.start()
+            while True:
+                res = response_queue.get()
+                if res is None:
+                    return
+                yield res
+
+        return consumer()


 @add_start_docstrings(
     """
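With this change, stream_chat yields (response, history) tuples instead of printing tokens to stdout. A hedged usage sketch, assuming the internlm/internlm-chat-7b checkpoint and a CUDA device (neither is specified by this commit):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True).cuda().eval()

history = []
for response, history in model.stream_chat(tokenizer, "Hello", history=history):
    # Each iteration yields the partial response generated so far and the updated history.
    print(response)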