Add meta instruction in chat

pull/536/head
x54-729 2024-01-09 15:37:07 +08:00
parent 695d76eb31
commit 09a2b5ba50
1 changed files with 12 additions and 6 deletions

View File

@ -226,7 +226,6 @@ def rotate_half(x):
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
if position_ids.size(1) == 1:
q_cos = cos[position_ids].unsqueeze(1).expand(q.shape)
@ -879,12 +878,16 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past
def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = []):
def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = [], meta_instruction=""):
prompt = ""
if meta_instruction:
prompt += f"""<s><|System|>:{meta_instruction}\n"""
else:
prompt += "<s>"
for record in history:
prompt += f"""<|User|>:{record[0]}<eoh>\n<|Bot|>:{record[1]}<eoa>\n"""
prompt += f"""<|User|>:{query}<eoh>\n<|Bot|>:"""
prompt += f"""<|User|>:{record[0]}\n<|Bot|>:{record[1]}<eoa>\n"""
prompt += f"""<|User|>:{query}\n<|Bot|>:"""
return tokenizer([prompt], return_tensors="pt")
@torch.no_grad()
@ -898,9 +901,12 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
do_sample: bool = True,
temperature: float = 0.8,
top_p: float = 0.8,
meta_instruction: str = "You are an AI assistant whose name is InternLM (书生·浦语).\n"
"- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
"- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文.",
**kwargs,
):
inputs = self.build_inputs(tokenizer, query, history)
inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
outputs = self.generate(
**inputs,