Update the import and fix lints

pull/51/head
gaoyang07 2023-07-13 16:52:17 +08:00
parent f0d9e56a1a
commit 2969032439
1 changed files with 12 additions and 13 deletions

View File

@@ -2,14 +2,14 @@ import argparse
import json
import os
import sys
import warnings
import numpy as np
sys.path.append("tools/transformers")
from tokenization_internlm import InternLMTokenizer
current_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(current_dir, "V7_sft.model")
sys.path.append(os.path.join(current_dir, "transformers"))
from tokenization_internlm import InternLMTokenizer
tokenizer = InternLMTokenizer(vocab_file=model_path)
@@ -82,15 +82,15 @@ def text2bin(text_input_path: str, bin_output_path: str):
"""
# Check if the txt file exists
if not os.path.isfile(text_input_path):
warnings.warn(f"{text_input_path} does not exist.")
return
raise FileNotFoundError(f"{text_input_path} does not exist.")
file_format = text_input_path.split(".")[-1]
assert file_format in ['txt', 'json', 'jsonl'], \
print("Invalid input file type. Currently support `txt`, `json` and `jsonl`.")
assert file_format in ["txt", "json", "jsonl"], print(
"Invalid input file type. Currently support `txt`, `json` and `jsonl`."
)
with open(text_input_path, "r") as text_file, open(bin_output_path, "ab") as bin_file:
if file_format == 'txt':
if file_format == "txt":
for line in text_file:
# Strip any leading/trailing whitespace
stripped_line = line.strip()
@@ -98,7 +98,7 @@ def text2bin(text_input_path: str, bin_output_path: str):
# Pass each line to the write_bin function
write_bin(stripped_line, bin_file)
elif file_format == 'json':
elif file_format == "json":
data = json.load(text_file)
# assuming data is a list of dictionaries
for record in data:
@@ -106,8 +106,8 @@ def text2bin(text_input_path: str, bin_output_path: str):
context = json.dumps(record)
# encode the str and write into bin
write_bin(context, bin_file)
elif file_format == 'jsonl':
elif file_format == "jsonl":
for line in text_file:
# encode the str and write into bin
write_bin(line, bin_file)
@@ -121,8 +121,7 @@ def parse_args():
required=True,
help="Path to the input text file.",
)
parser.add_argument(
"--bin_output_path", type=str, required=True, help="Path to the output bin file.")
parser.add_argument("--bin_output_path", type=str, required=True, help="Path to the output bin file.")
return parser.parse_args()