import multiprocessing
import os
import re
from tqdm import tqdm
from typing import List
import json
import time
import argparse
import functools

# Closing quotation marks that stay attached to the sentence they terminate.
_CLOSING_QUOTES = "”’\"'"

# Sentence-terminating characters per language.  The Chinese set holds the
# full-width marks (。?!…) as well as the ASCII ?! that are commonly mixed
# into Chinese text.
_ZH_TERMINATORS = "。?!?!…"
_EN_TERMINATORS = ".?!"


def _compile_splitter(terminators):
    """Compile a regex matching one terminator plus an optional closing quote."""
    return re.compile(
        f"[{re.escape(terminators)}][{re.escape(_CLOSING_QUOTES)}]?"
    )


# Compiled once at import time instead of on every call.
_SPLITTERS = {
    "zh": _compile_splitter(_ZH_TERMINATORS),
    "en": _compile_splitter(_EN_TERMINATORS),
    "all": _compile_splitter(_ZH_TERMINATORS + _EN_TERMINATORS),
}


def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[str]:
    """Split a document into sentences on punctuation.

    A line break is inserted after every sentence terminator; when the
    terminator is directly followed by a closing quotation mark, the break
    goes after the quote so the quote stays with its sentence.  The text is
    then split into lines, each line is stripped, and fragments of two
    characters or fewer are dropped.  Fragments longer than ``limit`` are cut
    into consecutive pieces of at most ``limit`` characters.

    Args:
        document: Text to split.
        flag: ``"zh"`` splits on Chinese punctuation, ``"en"`` on English
            punctuation; any other value (default ``"all"``) uses both.
        limit: Maximum number of characters per returned sentence.

    Returns:
        The list of sentences.  If ``document`` is not a string, it is
        returned unchanged as the only element of the list.

    Raises:
        ValueError: If ``limit`` is smaller than 1 (chunking could never
            make progress).
    """
    if limit < 1:
        raise ValueError(f"limit must be at least 1, got {limit!r}")

    # `in` on a tuple compares with ==, so any unexpected flag value simply
    # falls through to the combined pattern.
    splitter = _SPLITTERS[flag] if flag in ("zh", "en") else _SPLITTERS["all"]

    try:
        # One pass: the optional trailing quote is consumed greedily, so the
        # newline lands after "terminator + quote" when a quote is present.
        lines = splitter.sub(r"\g<0>\n", document).splitlines()
    except TypeError:
        # Best-effort fallback for non-string input: hand it back untouched.
        return [document]

    sent_list = []
    for line in lines:
        sent = line.strip()
        if len(sent) <= 2:
            # Empty or too short to be a useful sentence.
            continue
        # Hard-wrap over-long sentences into pieces of at most `limit` chars.
        sent_list.extend(sent[i:i + limit] for i in range(0, len(sent), limit))
    return sent_list


def get_sent(output_path,
             input_path,
             fin_list=None, host=-1, seq_len=512, workers=32) -> None:
    """Split every document of the given JSON files into sentences.

    The result is written to ``<output_path>/<host>.txt``: one sentence per
    line, with a line containing only ``]]`` after each document.  Documents
    of a file are written in the order the worker pool finishes them, not in
    their order inside the file.

    Args:
        output_path: Directory for the output file; created if missing.
        input_path: Directory containing the input files.
        fin_list: Sequence of entries whose first element is a file name
            relative to ``input_path`` (e.g. the ``[name, size]`` pairs built
            by ``getFileSize``).  Missing files and names without ``.json``
            are skipped.
        host: Identifier used as the output file name.
        seq_len: Sequence length; sentences are limited to ``seq_len - 2``
            characters.
        workers: Number of worker processes used for splitting.
    """
    if fin_list is None:
        fin_list = []

    # os.path.join copes with a trailing slash itself, so input_path needs no
    # manual stripping (which broke on '' and turned '/' into a relative path).
    if output_path:
        os.makedirs(output_path, exist_ok=True)
    cur_path = os.path.join(output_path, str(host) + '.txt')
    new_split_sentence = functools.partial(split_sentence, limit=seq_len - 2)

    with open(cur_path, 'w', encoding='utf-8') as f:
        # One pool for all files; leaving the `with` block tears it down,
        # which is safe because every result has been consumed by then.
        with multiprocessing.Pool(workers) as pool:
            for fi, fin_path in enumerate(fin_list):
                src_path = os.path.join(input_path, fin_path[0])
                if not os.path.exists(src_path):
                    continue
                if '.json' not in fin_path[0]:
                    continue

                print("Processing ", fin_path[0], " ", fi)

                # Read explicitly as UTF-8 to match the output encoding.
                # NOTE(review): assumes the file holds a JSON array of
                # objects with a 'content' key - confirm against the corpus.
                with open(src_path, 'r', encoding='utf-8') as fin:
                    f_data = [record['content'] for record in json.load(fin)]

                all_sent = pool.imap_unordered(new_split_sentence, f_data, 32)
                for d in tqdm(all_sent, total=len(f_data)):
                    f.writelines(i.strip() + '\n' for i in d)
                    f.write(']]' + '\n')  # end-of-document marker
                print('finished..')


def getFileSize(filepath, shard):
    """Group the entries of a directory into shards of roughly equal size.

    Entries are ordered from largest to smallest (ties broken by name so the
    result does not depend on directory listing order) and appended to the
    current group until its accumulated size exceeds
    ``total_size / shard``; then a new group is started.  Because of this
    greedy rule the number of groups returned can differ from ``shard``.

    Args:
        filepath: Directory whose entries are grouped.  May be relative or
            absolute.
        shard: Target number of shards; must be at least 1.

    Returns:
        A list of groups, each a list of ``[name, size_in_bytes]`` pairs,
        where ``name`` is the entry name relative to ``filepath``.

    Raises:
        ValueError: If ``shard`` is smaller than 1.
    """
    if shard < 1:
        raise ValueError(f"shard must be at least 1, got {shard!r}")

    # Join the directory exactly once per entry and stat each entry once.
    ans = [[name, os.path.getsize(os.path.join(filepath, name))]
           for name in os.listdir(filepath)]
    ans.sort(key=lambda entry: (-entry[1], entry[0]))

    per_size = sum(size for _, size in ans) / shard

    real_shard = []
    temp = []
    accu_size = 0
    for entry in ans:
        accu_size += entry[1]
        temp.append(entry)
        if accu_size > per_size:
            real_shard.append(temp)
            accu_size = 0
            temp = []

    # Whatever is left did not fill a whole shard; keep it as the last group.
    if temp:
        real_shard.append(temp)

    return real_shard


def get_start_end(real_shard, base=0, server_num=10, server_name='GPU'):
    """Pick the shard this machine should process, based on its hostname.

    The server number is parsed from whatever follows the last occurrence of
    ``server_name`` in the local hostname (presumably hostnames look like
    ``GPU3`` - verify against the cluster naming scheme).  The shard at index
    ``server_num * base + host - 1`` is selected.

    Args:
        real_shard: Indexable collection of shards.
        base: Round number; each round advances the index by ``server_num``.
        server_num: Number of servers taking part in a round.
        server_name: Hostname prefix preceding the server number.

    Returns:
        A ``(fin_list, host)`` tuple: the selected shard and the parsed
        server number.
    """
    from socket import gethostname

    # The last piece of the split is the text after the final server_name.
    *_, suffix = gethostname().split(server_name)
    host = int(suffix)

    fin_list = real_shard[host - 1 + base * server_num]
    print(fin_list)
    print(f'I am server {host}, process {server_num * base + host - 1}, len {len(fin_list)}')
    return fin_list, host


if __name__ == '__main__':

    # Command-line interface: three integer settings, then two required paths.
    cli = argparse.ArgumentParser()
    for option, default, description in (
            ('--server_num', 10, 'number of servers'),
            ('--seq_len', 512, 'sequence length'),
            ('--shard', 100, 'number of shards, e.g., 10, 50, or 100'),
    ):
        cli.add_argument(option, type=int, default=default, help=description)
    for option, description in (
            ('--input_path', 'input path of original corpus'),
            ('--output_path', 'output path of shard which has split sentence'),
    ):
        cli.add_argument(option, type=str, required=True, help=description)
    args = cli.parse_args()

    server_num, seq_len, shard = args.server_num, args.seq_len, args.shard
    input_path, output_path = args.input_path, args.output_path

    # Group the input files into shards of roughly equal total size.
    real_shard = getFileSize(input_path, shard)

    # Single-machine mode: process every shard in turn, using the shard's
    # position as the name of its output file.
    start = time.time()
    for index, file_group in enumerate(real_shard):
        get_sent(output_path, input_path,
                 fin_list=file_group, host=index, seq_len=seq_len)
    print(f'cost {str(time.time() - start)}')

    # If you have multiple servers, you can use the code below, or adapt it to run under Open MPI.
    
    # for i in range(len(real_shard) // server_num + 1):
    #     fin_list, host = get_start_end(real_shard, i)
        
    #     start = time.time()
    #     get_sent(output_path,
    #             input_path,
    #             fin_list=fin_list, host= 10 * i + host - 1)

    #     print(f'cost {str(time.time() - start)}')