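// Whole-word masking utilities for Chinese BERT-style pretraining, exposed to
// Python through pybind11: get_new_segment() re-tags character tokens that
// belong to the same segmented word with the "##" prefix, and
// create_whole_masked_lm_predictions() applies the 80/10/10 masking scheme at
// whole-word granularity.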
#include <math.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <algorithm>
#include <chrono>
#include <cstdint>
#include <iostream>
#include <limits>
#include <map>
#include <random>
#include <set>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace py = pybind11;

const int32_t LONG_SENTENCE_LEN = 512;

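// (index, original label) record for a masked position; only referenced in the
// commented-out bookkeeping further below.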
struct MaskedLMInstance {
  int index;
  std::string label;
  MaskedLMInstance(int index, std::string label) {
    this->index = index;
    this->label = label;
  }
};

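// Re-segment a character-level token sequence for whole-word masking:
// `segment` is the character/WordPiece token list, `segment_jieba` is the same
// sentence segmented into words (e.g. by jieba), and `chinese_vocab[i]` flags
// whether segment[i] is a Chinese token. Characters that fall inside a matched
// word are emitted with the "##" prefix so they are later masked together.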
auto get_new_segment(
    std::vector<std::string> segment, std::vector<std::string> segment_jieba,
    const std::vector<bool> chinese_vocab) {  // alternative: const std::unordered_set<std::string> &chinese_vocab
  std::unordered_set<std::string> seq_cws_dict;
  for (auto word : segment_jieba) {
    seq_cws_dict.insert(word);
  }
  int i = 0;
  std::vector<std::string> new_segment;
  int segment_size = segment.size();
  while (i < segment_size) {
    // Non-Chinese tokens are copied through unchanged.
    if (!chinese_vocab[i]) {  // chinese_vocab.find(segment[i]) == chinese_vocab.end()
      new_segment.emplace_back(segment[i]);
      i += 1;
      continue;
    }
    // Try to match the longest segmented word (up to 3 pieces) starting at i;
    // continuation pieces of a matched word get the "##" prefix.
    bool has_add = false;
    for (int length = 3; length >= 1; length--) {
      if (i + length > segment_size) {
        continue;
      }
      std::string chinese_word = "";
      for (int j = i; j < i + length; j++) {
        chinese_word += segment[j];
      }
      if (seq_cws_dict.find(chinese_word) != seq_cws_dict.end()) {
        new_segment.emplace_back(segment[i]);
        for (int j = i + 1; j < i + length; j++) {
          new_segment.emplace_back("##" + segment[j]);
        }
        i += length;
        has_add = true;
        break;
      }
    }
    if (!has_add) {
      new_segment.emplace_back(segment[i]);
      i += 1;
    }
  }

  return new_segment;
}

bool startsWith(const std::string &s, const std::string &sub) {
  return s.rfind(sub, 0) == 0;  // only check position 0, no full scan
}

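// Whole-word masked LM sampling (mirrors BERT's create_masked_lm_predictions):
// tokens prefixed with "##" are grouped with the preceding piece, word groups
// are shuffled, and up to max_predictions_per_seq positions (roughly
// masked_lm_prob * len(tokens)) are replaced. Returns the masked token list
// and a per-position label vector holding the original vocab id, or -1 where
// no prediction is made.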
auto create_whole_masked_lm_predictions(
    std::vector<std::string> &tokens,
    const std::vector<std::string> &original_tokens,
    const std::vector<std::string> &vocab_words,
    std::map<std::string, int> &vocab, const int max_predictions_per_seq,
    const double masked_lm_prob) {
  // for (auto item : vocab) {
  //   std::cout << "key=" << std::string(py::str(item.first)) << ", "
  //             << "value=" << std::string(py::str(item.second)) << std::endl;
  // }
  std::vector<std::vector<int> > cand_indexes;
  std::vector<int> cand_temp;
  int tokens_size = tokens.size();
  std::string prefix = "##";
  bool do_whole_masked = true;

  // Group "##" continuation pieces with the piece that starts the word, so
  // masking is applied to whole words.
  for (int i = 0; i < tokens_size; i++) {
    if (tokens[i] == "[CLS]" || tokens[i] == "[SEP]") {
      continue;
    }
    // A token is a continuation only if a word is already being built in
    // cand_temp.
    if (do_whole_masked && (cand_temp.size() > 0) &&
        (tokens[i].rfind(prefix, 0) == 0)) {
      cand_temp.emplace_back(i);
    } else {
      if (cand_temp.size() > 0) {
        cand_indexes.emplace_back(cand_temp);
      }
      cand_temp.clear();
      cand_temp.emplace_back(i);
    }
  }
  // Flush the final word group; otherwise the last word can never be masked.
  if (cand_temp.size() > 0) {
    cand_indexes.emplace_back(cand_temp);
  }

  // Shuffle the whole-word candidates so the masked positions are random.
  auto seed = std::chrono::system_clock::now().time_since_epoch().count();
  std::shuffle(cand_indexes.begin(), cand_indexes.end(),
               std::default_random_engine(seed));
  // for (auto i : cand_indexes) {
  //   for (auto j : i) {
  //     std::cout << tokens[j] << " ";
  //   }
  //   std::cout << std::endl;
  // }
  // for (auto i : output_tokens) {
  //   std::cout << i;
  // }
  // std::cout << std::endl;

  int num_to_predict = std::min(max_predictions_per_seq,
                                std::max(1, int(tokens_size * masked_lm_prob)));
  // std::cout << num_to_predict << std::endl;

  std::set<int> covered_indexes;
  std::vector<int> masked_lm_output(tokens_size, -1);
  int vocab_words_len = vocab_words.size();
  std::default_random_engine e(seed);
  std::uniform_real_distribution<double> u1(0.0, 1.0);
  std::uniform_int_distribution<unsigned> u2(0, vocab_words_len - 1);
  int mask_cnt = 0;
  std::vector<std::string> output_tokens = original_tokens;

  for (auto index_set : cand_indexes) {
    if (mask_cnt > num_to_predict) {
      break;
    }
    int index_set_size = index_set.size();
    // Skip this word if masking all of its pieces would exceed the budget.
    if (mask_cnt + index_set_size > num_to_predict) {
      continue;
    }
    bool is_any_index_covered = false;
    for (auto index : index_set) {
      if (covered_indexes.find(index) != covered_indexes.end()) {
        is_any_index_covered = true;
        break;
      }
    }
    if (is_any_index_covered) {
      continue;
    }
    for (auto index : index_set) {
      covered_indexes.insert(index);
      // BERT masking scheme: 80% [MASK], 10% keep the original token,
      // 10% replace with a random vocabulary token.
      std::string masked_token;
      if (u1(e) < 0.8) {
        masked_token = "[MASK]";
      } else {
        if (u1(e) < 0.5) {
          masked_token = output_tokens[index];
        } else {
          int random_index = u2(e);
          masked_token = vocab_words[random_index];
        }
      }
      // masked_lms.emplace_back(MaskedLMInstance(index, output_tokens[index]));
      masked_lm_output[index] = vocab[output_tokens[index]];
      output_tokens[index] = masked_token;
      mask_cnt++;
    }
  }

  // for (auto p : masked_lms) {
  //   masked_lm_output[p.index] = vocab[p.label];
  // }
  return std::make_tuple(output_tokens, masked_lm_output);
}

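// Rough usage sketch from Python (hedged: the argument values shown are only
// illustrative of how the bindings below are typically called, not a
// documented API):
//   import mask
//   new_tokens = mask.get_new_segment(char_tokens, jieba_tokens, is_chinese_flags)
//   out_tokens, labels = mask.create_whole_masked_lm_predictions(
//       tokens, original_tokens, vocab_words, vocab, 20, 0.15)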
PYBIND11_MODULE(mask, m) {
  m.def("create_whole_masked_lm_predictions",
        &create_whole_masked_lm_predictions);
  m.def("get_new_segment", &get_new_segment);
}