avoid frequent file I/O ops

pull/51/head
gaoyang07 2023-07-13 16:26:00 +08:00
parent 6e561e65f6
commit f0d9e56a1a
1 changed files with 8 additions and 9 deletions

View File

@ -13,13 +13,13 @@ model_path = os.path.join(current_dir, "V7_sft.model")
tokenizer = InternLMTokenizer(vocab_file=model_path) tokenizer = InternLMTokenizer(vocab_file=model_path)
def write_bin(context: str, bin_output_path: str) -> None: def write_bin(context: str, bin_file) -> None:
""" """
Write bin file based on the context. Write bin file based on the context.
Args: Args:
context (str): the context of raw file. context (str): the context of raw file.
bin_output_path (str): the path for output bin file. bin_file (file handler): the opened bin file.
Example: Example:
>>> write_bin("今天天气晴朗适合出门散步", "out.bin") # the output file format is 'txt' >>> write_bin("今天天气晴朗适合出门散步", "out.bin") # the output file format is 'txt'
@ -34,9 +34,8 @@ def write_bin(context: str, bin_output_path: str) -> None:
# encode the data into bytes to save # encode the data into bytes to save
saved_bin = str.encode(json.dumps(data) + "\n") saved_bin = str.encode(json.dumps(data) + "\n")
# write bytes into bin path # write bytes into bin_file
with open(bin_output_path, "ab") as f: bin_file.write(saved_bin)
f.write(saved_bin)
def prepare_meta(bin_output_path: str): def prepare_meta(bin_output_path: str):
@ -90,14 +89,14 @@ def text2bin(text_input_path: str, bin_output_path: str):
assert file_format in ['txt', 'json', 'jsonl'], \ assert file_format in ['txt', 'json', 'jsonl'], \
print("Invalid input file type. Currently support `txt`, `json` and `jsonl`.") print("Invalid input file type. Currently support `txt`, `json` and `jsonl`.")
with open(text_input_path, "r") as text_file: with open(text_input_path, "r") as text_file, open(bin_output_path, "ab") as bin_file:
if file_format == 'txt': if file_format == 'txt':
for line in text_file: for line in text_file:
# Strip any leading/trailing whitespace # Strip any leading/trailing whitespace
stripped_line = line.strip() stripped_line = line.strip()
if stripped_line: if stripped_line:
# Pass each line to the write_bin function # Pass each line to the write_bin function
write_bin(stripped_line, bin_output_path) write_bin(stripped_line, bin_file)
elif file_format == 'json': elif file_format == 'json':
data = json.load(text_file) data = json.load(text_file)
@ -106,12 +105,12 @@ def text2bin(text_input_path: str, bin_output_path: str):
# the type of record is dict, transfer the dict into str # the type of record is dict, transfer the dict into str
context = json.dumps(record) context = json.dumps(record)
# encode the str and write into bin # encode the str and write into bin
write_bin(context, bin_output_path) write_bin(context, bin_file)
elif file_format == 'jsonl': elif file_format == 'jsonl':
for line in text_file: for line in text_file:
# encode the str and write into bin # encode the str and write into bin
write_bin(line, bin_output_path) write_bin(line, bin_file)
def parse_args(): def parse_args():