diff --git a/tools/tokenizer.py b/tools/tokenizer.py index bc92ea4..9a969ad 100644 --- a/tools/tokenizer.py +++ b/tools/tokenizer.py @@ -13,13 +13,13 @@ model_path = os.path.join(current_dir, "V7_sft.model") tokenizer = InternLMTokenizer(vocab_file=model_path) -def write_bin(context: str, bin_output_path: str) -> None: +def write_bin(context: str, bin_file) -> None: """ Write bin file based on the context. Args: context (str): the context of raw file. - bin_output_path (str): the path for output bin file. + bin_file (file handler): the opened bin file. Example: >>> write_bin("今天天气晴朗适合出门散步", "out.bin") # the output file format is 'txt' @@ -34,9 +34,8 @@ def write_bin(context: str, bin_output_path: str) -> None: # encode the data into bytes to save saved_bin = str.encode(json.dumps(data) + "\n") - # write bytes into bin path - with open(bin_output_path, "ab") as f: - f.write(saved_bin) + # write bytes into bin_file + bin_file.write(saved_bin) def prepare_meta(bin_output_path: str): @@ -90,14 +89,14 @@ def text2bin(text_input_path: str, bin_output_path: str): assert file_format in ['txt', 'json', 'jsonl'], \ print("Invalid input file type. Currently support `txt`, `json` and `jsonl`.") - with open(text_input_path, "r") as text_file: + with open(text_input_path, "r") as text_file, open(bin_output_path, "ab") as bin_file: if file_format == 'txt': for line in text_file: # Strip any leading/trailing whitespace stripped_line = line.strip() if stripped_line: # Pass each line to the write_bin function - write_bin(stripped_line, bin_output_path) + write_bin(stripped_line, bin_file) elif file_format == 'json': data = json.load(text_file) @@ -106,12 +105,12 @@ def text2bin(text_input_path: str, bin_output_path: str): # the type of record is dict, transfer the dict into str context = json.dumps(record) # encode the str and write into bin - write_bin(context, bin_output_path) + write_bin(context, bin_file) elif file_format == 'jsonl': for line in text_file: # encode the str and write into bin - write_bin(line, bin_output_path) + write_bin(line, bin_file) def parse_args():