Merge main to develop (#309)

* fix(chat): fix stream_chat to return generator (#123)

* fix(configs/7B_sft.py): model dtype float16 to bfloat16 (#302)

* fix(convert2hf.py): fix the rotary_emb.inv_freq KeyError (#299)

---------

Co-authored-by: yingtongxiong <974106207@qq.com>
Co-authored-by: zhjunqin <zhjunqin@users.noreply.github.com>
Co-authored-by: jiangtann <39088437+jiangtann@users.noreply.github.com>
huangting4201 2023-09-14 16:32:15 +08:00 committed by GitHub
parent 882a07011c
commit 07fc5f674a
3 changed files with 55 additions and 19 deletions

configs/7B_sft.py

@@ -127,7 +127,7 @@ model = dict(
     num_layers=NUM_LAYER,
     mlp_ratio=MLP_RATIO,
     apply_post_layer_norm=False,
-    dtype="torch.float16",  # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
+    dtype="torch.bfloat16",  # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
     norm_type="rmsnorm",
     layer_norm_epsilon=1e-5,
     use_flash_attn=True,
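
Note (not part of the diff): a common motivation for switching the default to bfloat16 is that it keeps float32's exponent range, so activations and gradients that overflow float16's ~6.5e4 maximum still fit. A quick check in PyTorch:

import torch

# bfloat16 trades mantissa precision for float32-like dynamic range.
print(torch.finfo(torch.float16).max)    # 65504.0
print(torch.finfo(torch.bfloat16).max)   # ~3.39e38, same order of magnitude as float32
print(torch.finfo(torch.float32).max)    # ~3.40e38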

tools/transformers/convert2hf.py

@@ -38,7 +38,7 @@ def convert2hf(model_config, states_tp_pps):
     current_states["lm_head.weight"] = states.pop("head.weight")
 
     for i in range(model_config["num_layers"]):
-        states.pop(f"blocks.{i}.mixer.rotary_emb.inv_freq")
+        states.pop(f"blocks.{i}.mixer.rotary_emb.inv_freq", None)
         wqkv = states.pop(f"blocks.{i}.mixer.Wqkv.weight").reshape(
             3, model_config["num_attention_heads"], -1, model_config["hidden_size"]
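
Note (not part of the diff): the KeyError goes away because dict.pop(key, default) returns the default instead of raising when the key is missing, so checkpoints saved without the cached rotary_emb.inv_freq buffer convert cleanly. A minimal sketch of the behavior (the state-dict keys here are illustrative only):

# Illustration of dict.pop with and without a default (keys are made up).
states = {"blocks.0.mixer.Wqkv.weight": "..."}  # no rotary_emb.inv_freq entry

states.pop("blocks.0.mixer.rotary_emb.inv_freq", None)   # returns None, no error
# states.pop("blocks.0.mixer.rotary_emb.inv_freq")       # would raise KeyError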

tools/transformers/modeling_internlm.py

@@ -20,6 +20,7 @@
 """ PyTorch InternLM model."""
 import math
 from typing import List, Optional, Tuple, Union
+import threading, queue
 
 import torch
 import torch.utils.checkpoint
@@ -810,35 +811,70 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
                     temperature: float = 0.8,
                     top_p: float = 0.8,
                     **kwargs):
+        """
+        Return a generator in format: (response, history)
+        Eg.
+        ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')])
+        ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')])
+        """
+        response_queue = queue.Queue(maxsize=20)
+
         class ChatStreamer(BaseStreamer):
             def __init__(self, tokenizer) -> None:
                 super().__init__()
                 self.tokenizer = tokenizer
+                self.queue = response_queue
+                self.query = query
+                self.history = history
+                self.response = ""
+                self.received_inputs = False
+                self.queue.put((self.response, history + [(self.query, self.response)]))
 
             def put(self, value):
                 if len(value.shape) > 1 and value.shape[0] > 1:
                     raise ValueError("ChatStreamer only supports batch size 1")
                 elif len(value.shape) > 1:
                     value = value[0]
+
+                if not self.received_inputs:
+                    # The first received value is input_ids, ignore here
+                    self.received_inputs = True
+                    return
+
                 token = self.tokenizer.decode([value[-1]], skip_special_tokens=True)
                 if token.strip() != "<eoa>":
-                    print(token, end="")
+                    self.response = self.response + token
+                    history = self.history + [(self.query, self.response)]
+                    self.queue.put((self.response, history))
 
             def end(self):
-                print("")
+                self.queue.put(None)
 
-        return self.chat(
-            tokenizer=tokenizer,
-            query=query,
-            streamer=ChatStreamer(tokenizer=tokenizer),
-            history=history,
-            max_new_tokens=max_new_tokens,
-            do_sample=do_sample,
-            temperature=temperature,
-            top_p=top_p,
-            **kwargs
-        )
+        def stream_producer():
+            return self.chat(
+                tokenizer=tokenizer,
+                query=query,
+                streamer=ChatStreamer(tokenizer=tokenizer),
+                history=history,
+                max_new_tokens=max_new_tokens,
+                do_sample=do_sample,
+                temperature=temperature,
+                top_p=top_p,
+                **kwargs
+            )
+
+        def consumer():
+            producer = threading.Thread(target=stream_producer)
+            producer.start()
+            while True:
+                res = response_queue.get()
+                if res is None:
+                    return
+                yield res
+
+        return consumer()
 
 
 @add_start_docstrings(
     """