From 7cfea534e7c7191d80a3b6db61fa335bc229ff7e Mon Sep 17 00:00:00 2001 From: Shaoyuan Xie <66255889+Daniel-xsy@users.noreply.github.com> Date: Thu, 10 Aug 2023 17:53:46 +0800 Subject: [PATCH] [feat]: add pal reasoning script (#163) * [Feat] Add PAL inference script * Update README.md * Update tools/README.md Co-authored-by: BigDong * Update tools/pal_inference.py Co-authored-by: BigDong * Update pal script * Update README.md * restore .ore-commit-config.yaml * Update tools/README.md Co-authored-by: BigDong * Update tools/README.md Co-authored-by: BigDong * Update pal inference script * Update READMD.md * Update internlm/utils/interface.py Co-authored-by: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com> * Update pal script * Update pal script * Update script * Add docstring * Update format * Update script * Update script * Update script --------- Co-authored-by: BigDong Co-authored-by: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com> --- internlm/utils/interface.py | 138 ++++++++++ internlm/utils/simple_memory_profiler.py | 8 +- internlm/utils/timeout.py | 26 ++ tools/README.md | 61 ++++- tools/README_EN.md | 62 ++++- tools/pal_inference.py | 320 +++++++++++++++++++++++ 6 files changed, 607 insertions(+), 8 deletions(-) create mode 100644 internlm/utils/interface.py create mode 100644 internlm/utils/timeout.py create mode 100644 tools/pal_inference.py diff --git a/internlm/utils/interface.py b/internlm/utils/interface.py new file mode 100644 index 0000000..22b3743 --- /dev/null +++ b/internlm/utils/interface.py @@ -0,0 +1,138 @@ +import copy +import warnings +from dataclasses import dataclass +from typing import Callable, List, Optional + +import torch +from transformers import AutoModel, AutoTokenizer +from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +@dataclass +class GenerationConfig: + max_length: Optional[int] = None + top_p: Optional[float] = None + temperature: Optional[float] = None + do_sample: Optional[bool] = True + repetition_penalty: Optional[float] = 1.0 + + +@torch.inference_mode() +def generation_iterator( + model: AutoModel, + tokenizer: AutoTokenizer, + prompt: str, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + additional_eos_token_id: Optional[int] = None, + **kwargs, +): + inputs = tokenizer([prompt], padding=True, return_tensors="pt") + input_length = len(inputs["input_ids"][0]) + for k, v in inputs.items(): + inputs[k] = v.cuda() + input_ids = inputs["input_ids"] + input_ids_seq_length = input_ids.shape[-1] + if generation_config is None: + generation_config = model.generation_config + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) + eos_token_id = generation_config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + if additional_eos_token_id is not None: + eos_token_id.append(additional_eos_token_id) + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. 
" + "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation.", + UserWarning, + ) + elif generation_config.max_new_tokens is not None: + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + if not has_default_max_length: + logger.warning( + "Both `max_new_tokens` (={%s}) and `max_length`(=" + "{%s}) seem to have been set. `max_new_tokens` will take precedence. " + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", + generation_config.max_new_tokens, + generation_config.max_length, + ) + + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = "input_ids" + logger.warning( + "Input length of {%s} is {%s}, but `max_length` is set to" + " {%s}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`.", + input_ids_string, + input_ids_seq_length, + generation_config.max_length, + ) + + # 2. Set generation parameters if not already defined + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + logits_processor = model._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + stopping_criteria = model._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria + ) + logits_warper = model._get_logits_warper(generation_config) + + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + scores = None + while True: + model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs) + # forward pass to get next token + outputs = model( + **model_inputs, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # sample + probs = next_token_scores.softmax(dim=-1) + if generation_config.do_sample: + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(probs, dim=-1) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False) + unfinished_sequences = unfinished_sequences.mul((min(next_tokens != i for i in eos_token_id)).long()) + + output_token_ids = input_ids[0].cpu().tolist() + output_token_ids = output_token_ids[input_length:] + for each_eos_token_id in eos_token_id: + if output_token_ids[-1] == each_eos_token_id: + output_token_ids = output_token_ids[:-1] + response = tokenizer.decode(output_token_ids) + + yield response + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + break diff --git a/internlm/utils/simple_memory_profiler.py b/internlm/utils/simple_memory_profiler.py index 7fd8b28..4ca6679 100644 --- a/internlm/utils/simple_memory_profiler.py 
+++ b/internlm/utils/simple_memory_profiler.py @@ -218,9 +218,7 @@ class SimpleMemoryProfiler: # Calculate static optimizer state cuda memory self._os_params_mem_state = SimpleMemState("os_params_mem") self._os_state_mem_state = SimpleMemState("os_state_mem") - self._calc_tensor_group_memory( - self._os_params_mem_state, [(k, v) for k, v in enumerate(self._optimizer.param_groups)] - ) + self._calc_tensor_group_memory(self._os_params_mem_state, list(enumerate(self._optimizer.param_groups))) # Generate the first memory record self.point(create=True) @@ -302,9 +300,7 @@ class SimpleMemoryProfiler: # Update os state memory usage self._os_state_mem_state = SimpleMemState("os_state_mem") - self._calc_tensor_group_memory( - self._os_state_mem_state, [(k, v) for k, v in self._optimizer.state_dict()["state"].items()] - ) + self._calc_tensor_group_memory(self._os_state_mem_state, list(self._optimizer.state_dict()["state"].items())) if not self._stoped: # Do we need to print os_state_layout every time? Is it always constant? diff --git a/internlm/utils/timeout.py b/internlm/utils/timeout.py new file mode 100644 index 0000000..07a0911 --- /dev/null +++ b/internlm/utils/timeout.py @@ -0,0 +1,26 @@ +import signal + + +class Timeout: + """Timer to execute code + + Adapted from https://github.com/reasoning-machines/pal + + Args: + seconds (float): The maximum seconds to execute code + error_message (str) + """ + + def __init__(self, seconds=1, error_message="Timeout"): + self.seconds = seconds + self.error_message = error_message + + def timeout_handler(self, signum, frame): + raise TimeoutError(self.error_message) + + def __enter__(self): + signal.signal(signal.SIGALRM, self.timeout_handler) + signal.alarm(self.seconds) + + def __exit__(self, error_type, value, traceback): + signal.alarm(0) diff --git a/tools/README.md b/tools/README.md index f3ba385..ae7d196 100644 --- a/tools/README.md +++ b/tools/README.md @@ -1,4 +1,5 @@ 本目录提供辅助模型训练的一些工具,文件结构如下所示: + ```bash ├── transformers # 适配hugging face的transformers的一些工具 │ ├── configuration_internlm.py # config适配工具 @@ -9,9 +10,11 @@ ``` # tokenizer.py + 生成原始数据的`bin`和`meta`文件需要使用`tokenizer`,我们通过在`tools/tokenizer.py`中指定模型参数路径的方式来导入tokenizer模型。目前我们提供了`V7_sft.model`来生成tokens。若想使用不同的模型,可直接修改`tokernizer.py`中的模型参数路径。 可以运行以下命令生成原始数据对应的`bin`和`meta`文件,其中参数`text_input_path`表示原始文本数据路径,目前支持`txt`、`json`和`jsonl`三种输入格式,`bin_output_path`表示生成的`bin`文件的保存路径。 + ```bash $ python tools/tokenizer.py --text_input_path your_input_text_path --bin_output_path your_output_bin_path ``` @@ -19,6 +22,7 @@ $ python tools/tokenizer.py --text_input_path your_input_text_path --bin_output_ 下面是一个数据处理的例子: 给定一个包含原始数据集的文件`raw_data.txt`,原始数据集如下所示: + ```bash 感恩生活中的每一个细节,才能真正体会到幸福的滋味。 梦想是人生的动力源泉,努力追逐,才能实现自己的目标。 @@ -26,6 +30,7 @@ $ python tools/tokenizer.py --text_input_path your_input_text_path --bin_output_ ``` 可以通过运行以下命令来生成`bin`和`meta`文件: + ```bash $ python tools/tokenizer.py --text_input_path raw_data.txt --bin_output_path cn/output.bin ``` @@ -35,19 +40,73 @@ $ python tools/tokenizer.py --text_input_path raw_data.txt --bin_output_path cn/ 其中,`cn`表示中文数据集;`en`表示英文数据集;`code`表示代码数据集;`ja`表示日语数据集;`ar`表示阿拉伯语数据集;`kaoshi`表示考试数据集。 生成的bin文件的格式如下: + ```python {"tokens": [73075, 75302, 69522, 69022, 98899, 67713, 68015, 81269, 74637, 75445, 99157]} {"tokens": [69469, 60355, 73026, 68524, 60846, 61844, 98899, 67775, 79241, 98899, 67713, 67800, 67453, 67838, 99157]} {"tokens": [68057, 79017, 60378, 68014, 98899, 67713, 67990, 68015, 70381, 67428, 61003, 67622, 99157]} ``` + 
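As an illustrative sketch (not part of this patch): the snippet below shows how the `(starting index, token count)` tuples of the `meta` file described just below relate to the JSON-lines `bin` layout shown above. The helper name `build_meta` and the assumption that the starting index is a character offset counting each line's trailing newline are made for illustration only; `tools/tokenizer.py` may compute and store these values differently.

```python
import json


def build_meta(bin_path):
    """Recompute (starting index, token count) for each sequence in a bin file.

    Assumes the starting index is the character offset of the line within the
    file, counting each line's trailing newline, which is consistent with the
    example below (a first line of length 89 places the second sequence at 90).
    """
    meta = []
    offset = 0
    with open(bin_path, "r", encoding="utf-8") as f:
        for line in f:
            tokens = json.loads(line)["tokens"]  # each line is {"tokens": [...]}
            meta.append((offset, len(tokens)))
            offset += len(line)
    return meta


print(build_meta("cn/output.bin"))  # [(0, 11), (90, 15), (208, 13)] for the sample data above
```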
 `bin`文件中的每一行均对应原始数据集中的每一个句子,表示每个句子的`token`(下文将用sequence指代)。
 
 生成的`meta`文件的格式如下:
+
 ```bash
 (0, 11), (90, 15), (208, 13)
 ```
+
 在`meta`文件中,每个元组对应着`bin`文件中每一个`sequence`的元信息。其中,元组的第一个元素表示每个`sequence`在所有`sequence`中的`starting index`,第二个元素表示每个`sequence`中有多少个`tokens`。
 
 例如,对于第一个`sequence`,`starting index`为 0,有 11 个`tokens`;对于第二个`sequence`,由于第一个`sequence`转换为`string`后的长度为`89`,因此它的`starting index`为 90,有 15 个`tokens`。
 
-`json`和`jsonl`类型的文件的`bin`和`meta`文件格式和`txt`一致,此处不再赘叙。
\ No newline at end of file
+`json`和`jsonl`类型的文件的`bin`和`meta`文件格式和`txt`一致,此处不再赘叙。
+
+# pal_inference.py
+
+在 [GSM8K](https://huggingface.co/datasets/gsm8k) 数据集上使用 [PAL](https://github.com/reasoning-machines/pal) 范式推理,使模型编写代码并通过 Python 解释器执行来解决数学问题。其用法如下:
+
+```bash
+# 用法:
+python pal_inference.py [--dataset ] [--max_length ] [--top_p ] [--eoh ] [--eoa ] [--eos ] [--temperature ] [--time_out