#include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace py = pybind11; const int32_t LONG_SENTENCE_LEN = 512; struct MaskedLMInstance { int index; std::string label; MaskedLMInstance(int index, std::string label) { this->index = index; this->label = label; } }; auto get_new_segment(std::vector segment, std::vector segment_jieba, const std::vector chinese_vocab) { // const std::unordered_set &chinese_vocab std::unordered_set seq_cws_dict; for (auto word : segment_jieba) { seq_cws_dict.insert(word); } int i = 0; std::vector new_segment; int segment_size = segment.size(); while (i < segment_size) { if (!chinese_vocab[i]) { //chinese_vocab.find(segment[i]) == chinese_vocab.end() new_segment.emplace_back(segment[i]); i += 1; continue; } bool has_add = false; for (int length = 3; length >= 1; length--) { if (i + length > segment_size) { continue; } std::string chinese_word = ""; for (int j = i; j < i + length; j++) { chinese_word += segment[j]; } if (seq_cws_dict.find(chinese_word) != seq_cws_dict.end()) { new_segment.emplace_back(segment[i]); for (int j = i + 1; j < i + length; j++) { new_segment.emplace_back("##" + segment[j]); } i += length; has_add = true; break; } } if (!has_add) { new_segment.emplace_back(segment[i]); i += 1; } } return new_segment; } bool startsWith(const std::string& s, const std::string& sub) { return s.find(sub) == 0 ? true : false; } auto create_whole_masked_lm_predictions(std::vector &tokens, const std::vector &original_tokens, const std::vector &vocab_words, std::map &vocab, const int max_predictions_per_seq, const double masked_lm_prob) { // for (auto item : vocab) { // std::cout << "key=" << std::string(py::str(item.first)) << ", " // << "value=" << std::string(py::str(item.second)) << std::endl; // } std::vector > cand_indexes; std::vector cand_temp; int tokens_size = tokens.size(); std::string prefix = "##"; bool do_whole_masked = true; for (int i = 0; i < tokens_size; i++) { if (tokens[i] == "[CLS]" || tokens[i] == "[SEP]") { continue; } if (do_whole_masked && (cand_indexes.size() > 0) && (tokens[i].rfind(prefix, 0) == 0)) { cand_temp.emplace_back(i); } else { if (cand_temp.size() > 0) { cand_indexes.emplace_back(cand_temp); } cand_temp.clear(); cand_temp.emplace_back(i); } } auto seed = std::chrono::system_clock::now().time_since_epoch().count(); std::shuffle(cand_indexes.begin(), cand_indexes.end(), std::default_random_engine(seed)); // for (auto i : cand_indexes) { // for (auto j : i) { // std::cout << tokens[j] << " "; // } // std::cout << std::endl; // } // for (auto i : output_tokens) { // std::cout << i; // } // std::cout << std::endl; int num_to_predict = std::min(max_predictions_per_seq, std::max(1, int(tokens_size * masked_lm_prob))); // std::cout << num_to_predict << std::endl; std::set covered_indexes; std::vector masked_lm_output(tokens_size, -1); int vocab_words_len = vocab_words.size(); std::default_random_engine e(seed); std::uniform_real_distribution u1(0.0, 1.0); std::uniform_int_distribution u2(0, vocab_words_len - 1); int mask_cnt = 0; std::vector output_tokens; output_tokens = original_tokens; for (auto index_set : cand_indexes) { if (mask_cnt > num_to_predict) { break; } int index_set_size = index_set.size(); if (mask_cnt + index_set_size > num_to_predict) { continue; } bool is_any_index_covered = false; for (auto index : index_set) { if (covered_indexes.find(index) != covered_indexes.end()) { is_any_index_covered = true; break; } } if (is_any_index_covered) { continue; } for (auto index : index_set) { covered_indexes.insert(index); std::string masked_token; if (u1(e) < 0.8) { masked_token = "[MASK]"; } else { if (u1(e) < 0.5) { masked_token = output_tokens[index]; } else { int random_index = u2(e); masked_token = vocab_words[random_index]; } } // masked_lms.emplace_back(MaskedLMInstance(index, output_tokens[index])); masked_lm_output[index] = vocab[output_tokens[index]]; output_tokens[index] = masked_token; mask_cnt++; } } // for (auto p : masked_lms) { // masked_lm_output[p.index] = vocab[p.label]; // } return std::make_tuple(output_tokens, masked_lm_output); } PYBIND11_MODULE(mask, m) { m.def("create_whole_masked_lm_predictions", &create_whole_masked_lm_predictions); m.def("get_new_segment", &get_new_segment); }