avoid frequent file I/O ops

pull/51/head
gaoyang07 2023-07-13 16:26:00 +08:00
parent 6e561e65f6
commit f0d9e56a1a
1 changed files with 8 additions and 9 deletions

View File

@ -13,13 +13,13 @@ model_path = os.path.join(current_dir, "V7_sft.model")
tokenizer = InternLMTokenizer(vocab_file=model_path)
def write_bin(context: str, bin_output_path: str) -> None:
def write_bin(context: str, bin_file) -> None:
"""
Write bin file based on the context.
Args:
context (str): the context of raw file.
bin_output_path (str): the path for output bin file.
bin_file (file handler): the opened bin file.
Example:
>>> write_bin("今天天气晴朗适合出门散步", "out.bin") # the output file format is 'txt'
@ -34,9 +34,8 @@ def write_bin(context: str, bin_output_path: str) -> None:
# encode the data into bytes to save
saved_bin = str.encode(json.dumps(data) + "\n")
# write bytes into bin path
with open(bin_output_path, "ab") as f:
f.write(saved_bin)
# write bytes into bin_file
bin_file.write(saved_bin)
def prepare_meta(bin_output_path: str):
@ -90,14 +89,14 @@ def text2bin(text_input_path: str, bin_output_path: str):
assert file_format in ['txt', 'json', 'jsonl'], \
print("Invalid input file type. Currently support `txt`, `json` and `jsonl`.")
with open(text_input_path, "r") as text_file:
with open(text_input_path, "r") as text_file, open(bin_output_path, "ab") as bin_file:
if file_format == 'txt':
for line in text_file:
# Strip any leading/trailing whitespace
stripped_line = line.strip()
if stripped_line:
# Pass each line to the write_bin function
write_bin(stripped_line, bin_output_path)
write_bin(stripped_line, bin_file)
elif file_format == 'json':
data = json.load(text_file)
@ -106,12 +105,12 @@ def text2bin(text_input_path: str, bin_output_path: str):
# the type of record is dict, transfer the dict into str
context = json.dumps(record)
# encode the str and write into bin
write_bin(context, bin_output_path)
write_bin(context, bin_file)
elif file_format == 'jsonl':
for line in text_file:
# encode the str and write into bin
write_bin(line, bin_output_path)
write_bin(line, bin_file)
def parse_args():