mirror of https://github.com/InternLM/InternLM
				
				
				
			Merge main to develop (#309)
* fix(chat): fix stream_chat to return generator (#123) * fix(configs/7B_sft.py): model dtype float16 to bfloat16 (#302) * fix(convert2hf.py): fix the rotary_emb.inv_freq KeyError (#299) --------- Co-authored-by: yingtongxiong <974106207@qq.com> Co-authored-by: zhjunqin <zhjunqin@users.noreply.github.com> Co-authored-by: jiangtann <39088437+jiangtann@users.noreply.github.com>pull/310/head
							parent
							
								
									882a07011c
								
							
						
					
					
						commit
						07fc5f674a
					
				|  | @ -127,7 +127,7 @@ model = dict( | ||||||
|     num_layers=NUM_LAYER, |     num_layers=NUM_LAYER, | ||||||
|     mlp_ratio=MLP_RATIO, |     mlp_ratio=MLP_RATIO, | ||||||
|     apply_post_layer_norm=False, |     apply_post_layer_norm=False, | ||||||
|     dtype="torch.float16",  # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" |     dtype="torch.bfloat16",  # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" | ||||||
|     norm_type="rmsnorm", |     norm_type="rmsnorm", | ||||||
|     layer_norm_epsilon=1e-5, |     layer_norm_epsilon=1e-5, | ||||||
|     use_flash_attn=True, |     use_flash_attn=True, | ||||||
|  |  | ||||||
|  | @ -38,7 +38,7 @@ def convert2hf(model_config, states_tp_pps): | ||||||
|         current_states["lm_head.weight"] = states.pop("head.weight") |         current_states["lm_head.weight"] = states.pop("head.weight") | ||||||
| 
 | 
 | ||||||
|         for i in range(model_config["num_layers"]): |         for i in range(model_config["num_layers"]): | ||||||
|             states.pop(f"blocks.{i}.mixer.rotary_emb.inv_freq") |             states.pop(f"blocks.{i}.mixer.rotary_emb.inv_freq", None) | ||||||
| 
 | 
 | ||||||
|             wqkv = states.pop(f"blocks.{i}.mixer.Wqkv.weight").reshape( |             wqkv = states.pop(f"blocks.{i}.mixer.Wqkv.weight").reshape( | ||||||
|                 3, model_config["num_attention_heads"], -1, model_config["hidden_size"] |                 3, model_config["num_attention_heads"], -1, model_config["hidden_size"] | ||||||
|  |  | ||||||
|  | @ -20,6 +20,7 @@ | ||||||
| """ PyTorch InternLM model.""" | """ PyTorch InternLM model.""" | ||||||
| import math | import math | ||||||
| from typing import List, Optional, Tuple, Union | from typing import List, Optional, Tuple, Union | ||||||
|  | import threading, queue | ||||||
| 
 | 
 | ||||||
| import torch | import torch | ||||||
| import torch.utils.checkpoint | import torch.utils.checkpoint | ||||||
|  | @ -810,23 +811,47 @@ class InternLMForCausalLM(InternLMPreTrainedModel): | ||||||
|                     temperature: float = 0.8, |                     temperature: float = 0.8, | ||||||
|                     top_p: float = 0.8, |                     top_p: float = 0.8, | ||||||
|                     **kwargs): |                     **kwargs): | ||||||
|  |         """ | ||||||
|  |         Return a generator in format: (response, history) | ||||||
|  |         Eg. | ||||||
|  |         ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')]) | ||||||
|  |         ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')]) | ||||||
|  |         """ | ||||||
|  | 
 | ||||||
|  |         response_queue = queue.Queue(maxsize=20) | ||||||
|  | 
 | ||||||
|         class ChatStreamer(BaseStreamer): |         class ChatStreamer(BaseStreamer): | ||||||
|             def __init__(self, tokenizer) -> None: |             def __init__(self, tokenizer) -> None: | ||||||
|                 super().__init__() |                 super().__init__() | ||||||
|                 self.tokenizer = tokenizer |                 self.tokenizer = tokenizer | ||||||
|  |                 self.queue = response_queue | ||||||
|  |                 self.query = query | ||||||
|  |                 self.history = history | ||||||
|  |                 self.response = "" | ||||||
|  |                 self.received_inputs = False | ||||||
|  |                 self.queue.put((self.response, history + [(self.query, self.response)])) | ||||||
| 
 | 
 | ||||||
|             def put(self, value): |             def put(self, value): | ||||||
|                 if len(value.shape) > 1 and value.shape[0] > 1: |                 if len(value.shape) > 1 and value.shape[0] > 1: | ||||||
|                     raise ValueError("ChatStreamer only supports batch size 1") |                     raise ValueError("ChatStreamer only supports batch size 1") | ||||||
|                 elif len(value.shape) > 1: |                 elif len(value.shape) > 1: | ||||||
|                     value = value[0] |                     value = value[0] | ||||||
|  | 
 | ||||||
|  |                 if not self.received_inputs: | ||||||
|  |                     # The first received value is input_ids, ignore here | ||||||
|  |                     self.received_inputs = True | ||||||
|  |                     return | ||||||
|  | 
 | ||||||
|                 token = self.tokenizer.decode([value[-1]], skip_special_tokens=True) |                 token = self.tokenizer.decode([value[-1]], skip_special_tokens=True) | ||||||
|                 if token.strip() != "<eoa>": |                 if token.strip() != "<eoa>": | ||||||
|                     print(token, end="") |                     self.response = self.response + token | ||||||
|  |                     history = self.history + [(self.query, self.response)] | ||||||
|  |                     self.queue.put((self.response, history)) | ||||||
| 
 | 
 | ||||||
|             def end(self): |             def end(self): | ||||||
|                 print("") |                 self.queue.put(None) | ||||||
| 
 | 
 | ||||||
|  |         def stream_producer(): | ||||||
|             return self.chat( |             return self.chat( | ||||||
|                 tokenizer=tokenizer, |                 tokenizer=tokenizer, | ||||||
|                 query=query, |                 query=query, | ||||||
|  | @ -839,6 +864,17 @@ class InternLMForCausalLM(InternLMPreTrainedModel): | ||||||
|                 **kwargs |                 **kwargs | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|  |         def consumer(): | ||||||
|  |             producer = threading.Thread(target=stream_producer) | ||||||
|  |             producer.start() | ||||||
|  |             while True: | ||||||
|  |                 res = response_queue.get() | ||||||
|  |                 if res is not None: | ||||||
|  |                     return | ||||||
|  |                 yield res | ||||||
|  | 
 | ||||||
|  |         return consumer() | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| @add_start_docstrings( | @add_start_docstrings( | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	 huangting4201
						huangting4201