mirror of https://github.com/InternLM/InternLM
				
				
				
			Merge main to develop (#309)
* fix(chat): fix stream_chat to return generator (#123) * fix(configs/7B_sft.py): model dtype float16 to bfloat16 (#302) * fix(convert2hf.py): fix the rotary_emb.inv_freq KeyError (#299) --------- Co-authored-by: yingtongxiong <974106207@qq.com> Co-authored-by: zhjunqin <zhjunqin@users.noreply.github.com> Co-authored-by: jiangtann <39088437+jiangtann@users.noreply.github.com>pull/310/head
							parent
							
								
									882a07011c
								
							
						
					
					
						commit
						07fc5f674a
					
				|  | @ -127,7 +127,7 @@ model = dict( | |||
|     num_layers=NUM_LAYER, | ||||
|     mlp_ratio=MLP_RATIO, | ||||
|     apply_post_layer_norm=False, | ||||
|     dtype="torch.float16",  # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" | ||||
|     dtype="torch.bfloat16",  # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" | ||||
|     norm_type="rmsnorm", | ||||
|     layer_norm_epsilon=1e-5, | ||||
|     use_flash_attn=True, | ||||
|  |  | |||
|  | @ -38,7 +38,7 @@ def convert2hf(model_config, states_tp_pps): | |||
|         current_states["lm_head.weight"] = states.pop("head.weight") | ||||
| 
 | ||||
|         for i in range(model_config["num_layers"]): | ||||
|             states.pop(f"blocks.{i}.mixer.rotary_emb.inv_freq") | ||||
|             states.pop(f"blocks.{i}.mixer.rotary_emb.inv_freq", None) | ||||
| 
 | ||||
|             wqkv = states.pop(f"blocks.{i}.mixer.Wqkv.weight").reshape( | ||||
|                 3, model_config["num_attention_heads"], -1, model_config["hidden_size"] | ||||
|  |  | |||
|  | @ -20,6 +20,7 @@ | |||
| """ PyTorch InternLM model.""" | ||||
| import math | ||||
| from typing import List, Optional, Tuple, Union | ||||
| import threading, queue | ||||
| 
 | ||||
| import torch | ||||
| import torch.utils.checkpoint | ||||
|  | @ -810,35 +811,70 @@ class InternLMForCausalLM(InternLMPreTrainedModel): | |||
|                     temperature: float = 0.8, | ||||
|                     top_p: float = 0.8, | ||||
|                     **kwargs): | ||||
|         """ | ||||
|         Return a generator in format: (response, history) | ||||
|         Eg. | ||||
|         ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')]) | ||||
|         ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')]) | ||||
|         """ | ||||
| 
 | ||||
|         response_queue = queue.Queue(maxsize=20) | ||||
| 
 | ||||
|         class ChatStreamer(BaseStreamer): | ||||
|             def __init__(self, tokenizer) -> None: | ||||
|                 super().__init__() | ||||
|                 self.tokenizer = tokenizer | ||||
|                  | ||||
|                 self.queue = response_queue | ||||
|                 self.query = query | ||||
|                 self.history = history | ||||
|                 self.response = "" | ||||
|                 self.received_inputs = False | ||||
|                 self.queue.put((self.response, history + [(self.query, self.response)])) | ||||
| 
 | ||||
|             def put(self, value): | ||||
|                 if len(value.shape) > 1 and value.shape[0] > 1: | ||||
|                     raise ValueError("ChatStreamer only supports batch size 1") | ||||
|                 elif len(value.shape) > 1: | ||||
|                     value = value[0] | ||||
| 
 | ||||
|                 if not self.received_inputs: | ||||
|                     # The first received value is input_ids, ignore here | ||||
|                     self.received_inputs = True | ||||
|                     return | ||||
| 
 | ||||
|                 token = self.tokenizer.decode([value[-1]], skip_special_tokens=True) | ||||
|                 if token.strip() != "<eoa>": | ||||
|                     print(token, end="") | ||||
|                  | ||||
|                     self.response = self.response + token | ||||
|                     history = self.history + [(self.query, self.response)] | ||||
|                     self.queue.put((self.response, history)) | ||||
| 
 | ||||
|             def end(self): | ||||
|                 print("") | ||||
|              | ||||
|         return self.chat( | ||||
|             tokenizer=tokenizer, | ||||
|             query=query, | ||||
|             streamer=ChatStreamer(tokenizer=tokenizer), | ||||
|             history=history,  | ||||
|             max_new_tokens=max_new_tokens, | ||||
|             do_sample=do_sample, | ||||
|             temperature=temperature, | ||||
|             top_p=top_p, | ||||
|             **kwargs | ||||
|         ) | ||||
|                  | ||||
|                 self.queue.put(None) | ||||
| 
 | ||||
|         def stream_producer(): | ||||
|             return self.chat( | ||||
|                 tokenizer=tokenizer, | ||||
|                 query=query, | ||||
|                 streamer=ChatStreamer(tokenizer=tokenizer), | ||||
|                 history=history,  | ||||
|                 max_new_tokens=max_new_tokens, | ||||
|                 do_sample=do_sample, | ||||
|                 temperature=temperature, | ||||
|                 top_p=top_p, | ||||
|                 **kwargs | ||||
|             ) | ||||
| 
 | ||||
|         def consumer(): | ||||
|             producer = threading.Thread(target=stream_producer) | ||||
|             producer.start() | ||||
|             while True: | ||||
|                 res = response_queue.get() | ||||
|                 if res is not None: | ||||
|                     return | ||||
|                 yield res | ||||
| 
 | ||||
|         return consumer() | ||||
| 
 | ||||
| 
 | ||||
| @add_start_docstrings( | ||||
|     """ | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 huangting4201
						huangting4201