Merge main to develop (#309)

* fix(chat): fix stream_chat to return generator (#123)

* fix(configs/7B_sft.py): model dtype float16 to bfloat16 (#302)

* fix(convert2hf.py): fix the rotary_emb.inv_freq KeyError (#299)

---------

Co-authored-by: yingtongxiong <974106207@qq.com>
Co-authored-by: zhjunqin <zhjunqin@users.noreply.github.com>
Co-authored-by: jiangtann <39088437+jiangtann@users.noreply.github.com>
huangting4201 2023-09-14 16:32:15 +08:00 committed by GitHub
parent 882a07011c
commit 07fc5f674a
3 changed files with 55 additions and 19 deletions

configs/7B_sft.py

@@ -127,7 +127,7 @@ model = dict(
     num_layers=NUM_LAYER,
     mlp_ratio=MLP_RATIO,
     apply_post_layer_norm=False,
-    dtype="torch.float16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
+    dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
     norm_type="rmsnorm",
     layer_norm_epsilon=1e-5,
     use_flash_attn=True,
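
bfloat16 keeps float32's 8-bit exponent range, which makes it less prone to overflow and underflow during large-model training than float16. The config stores the dtype as a string; how InternLM resolves that string internally is not part of this diff, but a minimal sketch of the kind of mapping involved (the resolve_dtype helper below is hypothetical, not the project's code) could look like this:

import torch

def resolve_dtype(dtype_str: str) -> torch.dtype:
    # Hypothetical helper: "torch.tf32" is a matmul mode rather than a tensor dtype,
    # so treat it as float32 with TF32 kernels enabled; everything else resolves
    # directly to the attribute of the same name on the torch module.
    if dtype_str == "torch.tf32":
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        return torch.float32
    return getattr(torch, dtype_str.split(".")[-1])

assert resolve_dtype("torch.bfloat16") is torch.bfloat16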

convert2hf.py

@@ -38,7 +38,7 @@ def convert2hf(model_config, states_tp_pps):
     current_states["lm_head.weight"] = states.pop("head.weight")
     for i in range(model_config["num_layers"]):
-        states.pop(f"blocks.{i}.mixer.rotary_emb.inv_freq")
+        states.pop(f"blocks.{i}.mixer.rotary_emb.inv_freq", None)
         wqkv = states.pop(f"blocks.{i}.mixer.Wqkv.weight").reshape(
             3, model_config["num_attention_heads"], -1, model_config["hidden_size"]
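
The fix uses the two-argument form of dict.pop, which returns a default instead of raising when the key is missing, so checkpoints that do not persist the rotary_emb.inv_freq buffer no longer trigger a KeyError. A toy illustration (the state dict below is made up, not a real checkpoint):

states = {"blocks.0.mixer.Wqkv.weight": object()}  # no rotary_emb.inv_freq entry

# Old behavior: pop() without a default raises when the buffer is absent.
try:
    states.pop("blocks.0.mixer.rotary_emb.inv_freq")
except KeyError:
    print("KeyError raised")

# New behavior: supplying a default returns None and conversion continues.
value = states.pop("blocks.0.mixer.rotary_emb.inv_freq", None)
print(value)  # None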

modeling_internlm.py

@@ -20,6 +20,7 @@
 """ PyTorch InternLM model."""
 import math
 from typing import List, Optional, Tuple, Union
+import threading, queue
 import torch
 import torch.utils.checkpoint
@@ -810,23 +811,47 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
                    temperature: float = 0.8,
                    top_p: float = 0.8,
                    **kwargs):
        """
        Return a generator in format: (response, history)
        Eg.
        ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')])
        ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')])
        """
        response_queue = queue.Queue(maxsize=20)

        class ChatStreamer(BaseStreamer):
            def __init__(self, tokenizer) -> None:
                super().__init__()
                self.tokenizer = tokenizer
                self.queue = response_queue
                self.query = query
                self.history = history
                self.response = ""
                self.received_inputs = False
                self.queue.put((self.response, history + [(self.query, self.response)]))

            def put(self, value):
                if len(value.shape) > 1 and value.shape[0] > 1:
                    raise ValueError("ChatStreamer only supports batch size 1")
                elif len(value.shape) > 1:
                    value = value[0]
                if not self.received_inputs:
                    # The first received value is input_ids, ignore here
                    self.received_inputs = True
                    return
                token = self.tokenizer.decode([value[-1]], skip_special_tokens=True)
                if token.strip() != "<eoa>":
                    print(token, end="")
                    self.response = self.response + token
                    history = self.history + [(self.query, self.response)]
                    self.queue.put((self.response, history))

            def end(self):
                print("")
                self.queue.put(None)

        def stream_producer():
            return self.chat(
                tokenizer=tokenizer,
                query=query,
@@ -839,6 +864,17 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
                **kwargs
            )

        def consumer():
            producer = threading.Thread(target=stream_producer)
            producer.start()
            while True:
                res = response_queue.get()
                if res is None:
                    return
                yield res

        return consumer()
@add_start_docstrings(
"""