mirror of https://github.com/InternLM/InternLM
Add meta instruction in chat
parent
695d76eb31
commit
09a2b5ba50
|
@ -226,7 +226,6 @@ def rotate_half(x):
|
||||||
x2 = x[..., x.shape[-1] // 2 :]
|
x2 = x[..., x.shape[-1] // 2 :]
|
||||||
return torch.cat((-x2, x1), dim=-1)
|
return torch.cat((-x2, x1), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
||||||
if position_ids.size(1) == 1:
|
if position_ids.size(1) == 1:
|
||||||
q_cos = cos[position_ids].unsqueeze(1).expand(q.shape)
|
q_cos = cos[position_ids].unsqueeze(1).expand(q.shape)
|
||||||
|
@ -880,11 +879,15 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
|
||||||
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
||||||
return reordered_past
|
return reordered_past
|
||||||
|
|
||||||
def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = []):
|
def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = [], meta_instruction=""):
|
||||||
prompt = ""
|
prompt = ""
|
||||||
|
if meta_instruction:
|
||||||
|
prompt += f"""<s><|System|>:{meta_instruction}\n"""
|
||||||
|
else:
|
||||||
|
prompt += "<s>"
|
||||||
for record in history:
|
for record in history:
|
||||||
prompt += f"""<|User|>:{record[0]}<eoh>\n<|Bot|>:{record[1]}<eoa>\n"""
|
prompt += f"""<|User|>:{record[0]}\n<|Bot|>:{record[1]}<eoa>\n"""
|
||||||
prompt += f"""<|User|>:{query}<eoh>\n<|Bot|>:"""
|
prompt += f"""<|User|>:{query}\n<|Bot|>:"""
|
||||||
return tokenizer([prompt], return_tensors="pt")
|
return tokenizer([prompt], return_tensors="pt")
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
@ -898,9 +901,12 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
|
||||||
do_sample: bool = True,
|
do_sample: bool = True,
|
||||||
temperature: float = 0.8,
|
temperature: float = 0.8,
|
||||||
top_p: float = 0.8,
|
top_p: float = 0.8,
|
||||||
|
meta_instruction: str = "You are an AI assistant whose name is InternLM (书生·浦语).\n"
|
||||||
|
"- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
|
||||||
|
"- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文.",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
inputs = self.build_inputs(tokenizer, query, history)
|
inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
|
||||||
inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
|
inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
|
||||||
outputs = self.generate(
|
outputs = self.generate(
|
||||||
**inputs,
|
**inputs,
|
||||||
|
|
Loading…
Reference in New Issue