mirror of https://github.com/InternLM/InternLM
				
				
				
			
		
			
				
	
	
		
			321 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			321 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
# This file is modified from:
 | 
						|
# hhttps://github.com/reasoning-machines/pal/blob/main/pal/core/interface.py
 | 
						|
#
 | 
						|
# Copyright 2022 PAL Authors. All rights reserved.
 | 
						|
#
 | 
						|
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
# you may not use this file except in compliance with the License.
 | 
						|
# You may obtain a copy of the License at
 | 
						|
#
 | 
						|
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
#
 | 
						|
# Unless required by applicable law or agreed to in writing, software
 | 
						|
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
# See the License for the specific language governing permissions and
 | 
						|
# limitations under the License.
 | 
						|
 | 
						|
import argparse
 | 
						|
import copy
 | 
						|
import json
 | 
						|
import os
 | 
						|
from dataclasses import asdict
 | 
						|
from typing import Any, Dict, List
 | 
						|
 | 
						|
import torch
 | 
						|
import tqdm
 | 
						|
from datasets import load_dataset
 | 
						|
from transformers import AutoModelForCausalLM, AutoTokenizer
 | 
						|
 | 
						|
from tools.transformers.interface import GenerationConfig, generate_interactive
 | 
						|
from internlm.utils.timeout import Timeout
 | 
						|
 | 
						|
 | 
						|
def parse_args():
 | 
						|
    parser = argparse.ArgumentParser(description="PAL Inference")
 | 
						|
    parser.add_argument("model", type=str, help="Path to the pre-trained LLM used for inference.")
 | 
						|
    parser.add_argument(
 | 
						|
        "out_dir", type=str, help="Name of the output folder where generated code snippets will be saved."
 | 
						|
    )
 | 
						|
    parser.add_argument("--dataset", default="gsm8k", type=str, help="Name of the dataset used for code generation.")
 | 
						|
    parser.add_argument(
 | 
						|
        "--max_length",
 | 
						|
        default=2048,
 | 
						|
        type=int,
 | 
						|
        help="Maximum input token length for the natural language description.",
 | 
						|
    )
 | 
						|
    parser.add_argument(
 | 
						|
        "--top_p",
 | 
						|
        default=0.8,
 | 
						|
        type=float,
 | 
						|
        help="Probability threshold to choose sample tokens during generation.",
 | 
						|
    )
 | 
						|
    parser.add_argument(
 | 
						|
        "--eoh",
 | 
						|
        default="",
 | 
						|
        type=str,
 | 
						|
        help="End of human (user) token.",
 | 
						|
    )
 | 
						|
    parser.add_argument(
 | 
						|
        "--eoa",
 | 
						|
        default="",
 | 
						|
        type=str,
 | 
						|
        help="End of assistant (bot) token.",
 | 
						|
    )
 | 
						|
    parser.add_argument(
 | 
						|
        "--eos",
 | 
						|
        default="",
 | 
						|
        type=str,
 | 
						|
        help="End of system token.",
 | 
						|
    )
 | 
						|
    parser.add_argument(
 | 
						|
        "--temperature", "-t", default=1.0, type=float, help="Temperature of token sampling during generation."
 | 
						|
    )
 | 
						|
    parser.add_argument(
 | 
						|
        "--time_out", default=100, type=float, help="Maximum time allowed for executing generated code."
 | 
						|
    )
 | 
						|
    parser.add_argument(
 | 
						|
        "--verbose",
 | 
						|
        "-v",
 | 
						|
        action="store_true",
 | 
						|
        help="Print code error information when executing generated code (optional).",
 | 
						|
    )
 | 
						|
    parser.add_argument("--append", "-a", action="store_true", help="Append output to the history results (optional).")
 | 
						|
    args = parser.parse_args()
 | 
						|
    return args
 | 
						|
 | 
						|
 | 
						|
class GenericRuntime:
 | 
						|
    """Adapted from https://github.com/reasoning-machines/pal"""
 | 
						|
 | 
						|
    GLOBAL_DICT: dict = {}
 | 
						|
    LOCAL_DICT = None
 | 
						|
    HEADERS: List = []
 | 
						|
 | 
						|
    def __init__(self):
 | 
						|
        self._global_vars = copy.copy(self.GLOBAL_DICT)
 | 
						|
        self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None
 | 
						|
 | 
						|
        for c in self.HEADERS:
 | 
						|
            self.exec_code(c)
 | 
						|
 | 
						|
    def exec_code(self, code_piece: str) -> None:
 | 
						|
        exec(code_piece, self._global_vars)
 | 
						|
 | 
						|
    def eval_code(self, expr: str) -> Any:
 | 
						|
        return eval(expr, self._global_vars)
 | 
						|
 | 
						|
    def inject(self, var_dict: Dict[str, Any]) -> None:
 | 
						|
        for k, v in var_dict.items():
 | 
						|
            self._global_vars[k] = v
 | 
						|
 | 
						|
    @property
 | 
						|
    def answer(self):
 | 
						|
        return self._global_vars["answer"]
 | 
						|
 | 
						|
 | 
						|
class PALInterface:
 | 
						|
    """PAL interface wrap fun:`generate_interactive` to extract and execute
 | 
						|
    generated code.
 | 
						|
 | 
						|
    Adapted from https://github.com/reasoning-machines/pal
 | 
						|
 | 
						|
    Args:
 | 
						|
        model (AutoModelForCausalLM)
 | 
						|
        tokenizer (AutoTokenizer)
 | 
						|
        generation_config (GenerationConfig): Decode strategies
 | 
						|
        additional_eos_token_id (int): End of sentence token id, default: 103028
 | 
						|
        get_answer_expr (str): The function name of generated code, default: "solution()"
 | 
						|
        verbose (bool): Print error information
 | 
						|
    """
 | 
						|
 | 
						|
    def __init__(
 | 
						|
        self,
 | 
						|
        model: AutoModelForCausalLM,
 | 
						|
        tokenizer: AutoTokenizer,
 | 
						|
        generation_config: GenerationConfig,
 | 
						|
        additional_eos_token_id: int = 103028,
 | 
						|
        get_answer_expr: str = "solution()",
 | 
						|
        verbose: bool = False,
 | 
						|
    ):
 | 
						|
        self.runtime = GenericRuntime()
 | 
						|
        self.history: List = []
 | 
						|
        self.model = model
 | 
						|
        self.tokenizer = tokenizer
 | 
						|
        self.generation_config = generation_config
 | 
						|
        self.additional_eos_token_id = additional_eos_token_id
 | 
						|
        self.answer_expr = get_answer_expr
 | 
						|
        self.verbose = verbose
 | 
						|
 | 
						|
    def generate(self, prompt):
 | 
						|
        # The api will generate response word by word
 | 
						|
        # we only need the last generation as the final results
 | 
						|
        for cur_gen in generate_interactive(
 | 
						|
            model=self.model,
 | 
						|
            tokenizer=self.tokenizer,
 | 
						|
            prompt=prompt,
 | 
						|
            additional_eos_token_id=self.additional_eos_token_id,
 | 
						|
            **asdict(self.generation_config),
 | 
						|
        ):
 | 
						|
            continue
 | 
						|
        # Get final response
 | 
						|
        self.history.append(cur_gen)
 | 
						|
        # Extract code block
 | 
						|
        code = self.process_generation_to_code(cur_gen)
 | 
						|
        return code
 | 
						|
 | 
						|
    def process_generation_to_code(self, gens: str):
 | 
						|
        if "```python" in gens:
 | 
						|
            gens = gens.split("```python")[1].split("```")[0]
 | 
						|
        elif "```" in gens:
 | 
						|
            gens = gens.split("```")[1].split("```")[0]
 | 
						|
        code = gens.split("\n")
 | 
						|
        return code
 | 
						|
 | 
						|
    def run(self, prompt, time_out: float = 100):
 | 
						|
        code = self.generate(prompt)
 | 
						|
        with Timeout(time_out):
 | 
						|
            try:
 | 
						|
                exec_result = self.execute(code)
 | 
						|
            except Exception as e:
 | 
						|
                if self.verbose:
 | 
						|
                    print(e)
 | 
						|
        return exec_result
 | 
						|
 | 
						|
    def execute(self, code: List[str]):
 | 
						|
        self.runtime.exec_code("\n".join(code))
 | 
						|
        return self.runtime.eval_code(self.answer_expr)
 | 
						|
 | 
						|
    def clear_history(self):
 | 
						|
        self.history = []
 | 
						|
 | 
						|
 | 
						|
def load_model(args):
 | 
						|
    model = AutoModelForCausalLM.from_pretrained(args.model, trust_remote_code=True).to(torch.bfloat16).cuda()
 | 
						|
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
 | 
						|
    return model, tokenizer
 | 
						|
 | 
						|
 | 
						|
def load_data(args):
 | 
						|
    # Load data from huggingface dataset
 | 
						|
    if args.dataset == "gsm8k":
 | 
						|
        gsm8k = load_dataset(path=args.dataset, name="main")
 | 
						|
        test_set = gsm8k["test"]
 | 
						|
        input_data = []
 | 
						|
        for data in test_set:
 | 
						|
            question = data["question"]
 | 
						|
            target = float(data["answer"].split("#")[-1].replace(",", ""))
 | 
						|
            input_data.append({"question": question, "target": target})
 | 
						|
    else:
 | 
						|
        raise NotImplementedError
 | 
						|
    return input_data
 | 
						|
 | 
						|
 | 
						|
PROMPT = """<|System|>:You are a helpful assistant which use tools to solve mathematical reasoning questions. The tools you can use are:
 | 
						|
PythonExecutor: It can execute Python code. The code must be a function, and the function name must be 'solution'. The example format is as follows:
 | 
						|
```python
 | 
						|
def solution():
 | 
						|
    variable_names_with_real_meaning = func(variable)
 | 
						|
    return variable_names_with_real_meaning
 | 
						|
```{eos}
 | 
						|
<|User|>:Olivia has $23. She bought five bagels for $3 each. How much money does she have left?{eoh}
 | 
						|
<|Bot|>:
 | 
						|
```python
 | 
						|
def solution():
 | 
						|
    money_initial = 23
 | 
						|
    bagels = 5
 | 
						|
    bagel_cost = 3
 | 
						|
    money_spent = bagels * bagel_cost
 | 
						|
    money_left = money_initial - money_spent
 | 
						|
    result = money_left
 | 
						|
    return result
 | 
						|
```{eoa}
 | 
						|
<|User|>:Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?{eoh}
 | 
						|
<|Bot|>:
 | 
						|
```python
 | 
						|
def solution():
 | 
						|
    golf_balls_initial = 58
 | 
						|
    golf_balls_lost_tuesday = 23
 | 
						|
    golf_balls_lost_wednesday = 2
 | 
						|
    golf_balls_left = golf_balls_initial - golf_balls_lost_tuesday - golf_balls_lost_wednesday
 | 
						|
    result = golf_balls_left
 | 
						|
    return result
 | 
						|
```{eoa}
 | 
						|
<|User|>:There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?{eoh}
 | 
						|
<|Bot|>:
 | 
						|
```python
 | 
						|
def solution():
 | 
						|
    computers_initial = 9
 | 
						|
    computers_per_day = 5
 | 
						|
    num_days = 4  # 4 days between monday and thursday
 | 
						|
    computers_added = computers_per_day * num_days
 | 
						|
    computers_total = computers_initial + computers_added
 | 
						|
    result = computers_total
 | 
						|
    return result
 | 
						|
```{eoa}
 | 
						|
<|System|>:How about this question?{eos}
 | 
						|
<|User|>:{question}{eoh}
 | 
						|
<|Bot|>:""".strip()
 | 
						|
 | 
						|
 | 
						|
def main():
 | 
						|
 | 
						|
    args = parse_args()
 | 
						|
 | 
						|
    print("load model begin.")
 | 
						|
    model, tokenizer = load_model(args)
 | 
						|
    print("load model end.")
 | 
						|
 | 
						|
    generation_config = GenerationConfig(max_length=args.max_length, top_p=args.top_p, temperature=args.temperature)
 | 
						|
 | 
						|
    verbose = args.verbose
 | 
						|
    interface = PALInterface(model=model, tokenizer=tokenizer, generation_config=generation_config, verbose=verbose)
 | 
						|
 | 
						|
    if not os.path.exists(args.out_dir):
 | 
						|
        os.makedirs(args.out_dir)
 | 
						|
    savepath = os.path.join(args.out_dir, args.dataset + ".json")
 | 
						|
 | 
						|
    # Load from history results
 | 
						|
    if args.append and os.path.exists(savepath):
 | 
						|
        lines = open(savepath).readlines()
 | 
						|
        num_skip_exps = len(lines)
 | 
						|
        scores = [x["score"] for x in map(json.loads, lines)]
 | 
						|
    else:
 | 
						|
        num_skip_exps = 0
 | 
						|
        scores = []
 | 
						|
 | 
						|
    examples = load_data(args)
 | 
						|
    with open(savepath, "a" if args.append else "w") as f:
 | 
						|
        pbar = tqdm.tqdm(examples[num_skip_exps:], initial=num_skip_exps, total=len(examples))
 | 
						|
        for x in pbar:
 | 
						|
            question = x["question"]
 | 
						|
            result = copy.copy(x)
 | 
						|
 | 
						|
            try:
 | 
						|
                answer = interface.run(
 | 
						|
                    prompt=PROMPT.format(question=question, eoh=args.eoh, eoa=args.eoa, eos=args.eos),
 | 
						|
                    time_out=args.time_out,
 | 
						|
                )
 | 
						|
                answer = float(answer)
 | 
						|
                score = 1 if abs(answer - x["target"]) < 1e-3 else 0
 | 
						|
            except Exception as e:
 | 
						|
                if verbose:
 | 
						|
                    print(e)
 | 
						|
                answer = ""
 | 
						|
                score = 0
 | 
						|
            scores.append(score)
 | 
						|
            result["answer"] = answer
 | 
						|
            result["score"] = score
 | 
						|
            result["generation"] = interface.history
 | 
						|
            f.write(json.dumps(result) + "\n")
 | 
						|
 | 
						|
            interface.clear_history()
 | 
						|
            f.flush()
 | 
						|
 | 
						|
    print(f"{args.model}: Accuracy - {sum(scores) / len(scores)}")
 | 
						|
    torch.cuda.empty_cache()
 | 
						|
 | 
						|
 | 
						|
if __name__ == "__main__":
 | 
						|
    main()
 |