diff --git a/CMakeLists.txt b/CMakeLists.txt index a43c99f90..1bce6ca4c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -156,8 +156,10 @@ file(GLOB SD_LIB_SOURCES "src/*.h" "src/*.cpp" "src/*.hpp" - "src/vocab/*.h" - "src/vocab/*.cpp" + "src/tokenizers/*.h" + "src/tokenizers/*.cpp" + "src/tokenizers/vocab/*.h" + "src/tokenizers/vocab/*.cpp" ) find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH) @@ -250,7 +252,7 @@ endif() add_subdirectory(thirdparty) target_link_libraries(${SD_LIB} PUBLIC ggml zip) -target_include_directories(${SD_LIB} PUBLIC . include) +target_include_directories(${SD_LIB} PUBLIC . src include) target_include_directories(${SD_LIB} PUBLIC . thirdparty) target_compile_features(${SD_LIB} PUBLIC c_std_11 cxx_std_17) diff --git a/format-code.sh b/format-code.sh index 2e87da414..5c30fb4ff 100644 --- a/format-code.sh +++ b/format-code.sh @@ -1,4 +1,4 @@ -for f in src/*.cpp src/*.h src/*.hpp src/vocab/*.h src/vocab/*.cpp \ +for f in src/*.cpp src/*.h src/*.hpp src/tokenizers/*.h src/tokenizers/*.cpp src/tokenizers/vocab/*.h src/tokenizers/vocab/*.cpp \ examples/cli/*.cpp examples/cli/*.h examples/server/*.cpp \ examples/common/*.hpp examples/common/*.h examples/common/*.cpp; do [[ "$f" == vocab* ]] && continue diff --git a/src/clip.hpp b/src/clip.hpp index 8f2ac0643..8a2070e0b 100644 --- a/src/clip.hpp +++ b/src/clip.hpp @@ -3,455 +3,7 @@ #include "ggml_extend.hpp" #include "model.h" -#include "tokenize_util.h" -#include "vocab/vocab.h" - -/*================================================== CLIPTokenizer ===================================================*/ - -__STATIC_INLINE__ std::vector> bytes_to_unicode() { - std::vector> byte_unicode_pairs; - std::set byte_set; - for (int b = static_cast('!'); b <= static_cast('~'); ++b) { - byte_set.insert(b); - byte_unicode_pairs.push_back(std::pair(b, unicode_value_to_utf32(b))); - } - for (int b = 161; b <= 172; ++b) { - byte_set.insert(b); - byte_unicode_pairs.push_back(std::pair(b, unicode_value_to_utf32(b))); - } - for (int b = 174; b <= 255; ++b) { - byte_set.insert(b); - byte_unicode_pairs.push_back(std::pair(b, unicode_value_to_utf32(b))); - } - int n = 0; - for (int b = 0; b < 256; ++b) { - if (byte_set.find(b) == byte_set.end()) { - byte_unicode_pairs.push_back(std::pair(b, unicode_value_to_utf32(n + 256))); - ++n; - } - } - // LOG_DEBUG("byte_unicode_pairs %d", byte_unicode_pairs.size()); - return byte_unicode_pairs; -} - -// Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py - -typedef std::function&)> on_new_token_cb_t; - -class CLIPTokenizer { -private: - std::map byte_encoder; - std::map byte_decoder; - std::map encoder; - std::map decoder; - std::map, int> bpe_ranks; - std::regex pat; - int encoder_len; - int bpe_len; - - std::vector special_tokens; - -public: - const std::string UNK_TOKEN = "<|endoftext|>"; - const std::string BOS_TOKEN = "<|startoftext|>"; - const std::string EOS_TOKEN = "<|endoftext|>"; - const std::string PAD_TOKEN = "<|endoftext|>"; - - const int UNK_TOKEN_ID = 49407; - const int BOS_TOKEN_ID = 49406; - const int EOS_TOKEN_ID = 49407; - const int PAD_TOKEN_ID = 49407; - -private: - static std::string strip(const std::string& str) { - std::string::size_type start = str.find_first_not_of(" \t\n\r\v\f"); - std::string::size_type end = str.find_last_not_of(" \t\n\r\v\f"); - - if (start == std::string::npos) { - // String contains only whitespace characters - return ""; - } - - return str.substr(start, end - start + 1); - } - - static std::string whitespace_clean(std::string text) { - text = std::regex_replace(text, std::regex(R"(\s+)"), " "); - text = strip(text); - return text; - } - - static std::set> get_pairs(const std::vector& subwords) { - std::set> pairs; - if (subwords.size() == 0) { - return pairs; - } - std::u32string prev_subword = subwords[0]; - for (int i = 1; i < subwords.size(); i++) { - std::u32string subword = subwords[i]; - std::pair pair(prev_subword, subword); - pairs.insert(pair); - prev_subword = subword; - } - return pairs; - } - - bool is_special_token(const std::string& token) { - for (auto& special_token : special_tokens) { - if (special_token == token) { - return true; - } - } - return false; - } - -public: - CLIPTokenizer(int pad_token_id = 49407, const std::string& merges_utf8_str = "") - : PAD_TOKEN_ID(pad_token_id) { - if (merges_utf8_str.size() > 0) { - load_from_merges(merges_utf8_str); - } else { - load_from_merges(load_clip_merges()); - } - add_special_token("<|startoftext|>"); - add_special_token("<|endoftext|>"); - } - - void load_from_merges(const std::string& merges_utf8_str) { - auto byte_unicode_pairs = bytes_to_unicode(); - // printf("byte_unicode_pairs have %lu pairs \n", byte_unicode_pairs.size()); - byte_encoder = std::map(byte_unicode_pairs.begin(), byte_unicode_pairs.end()); - for (auto& pair : byte_unicode_pairs) { - byte_decoder[pair.second] = pair.first; - } - // for (auto & pair: byte_unicode_pairs) { - // std::cout << pair.first << ": " << pair.second << std::endl; - // } - std::vector merges; - size_t start = 0; - size_t pos; - std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str); - while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) { - merges.push_back(merges_utf32_str.substr(start, pos - start)); - start = pos + 1; - } - // LOG_DEBUG("merges size %llu", merges.size()); - GGML_ASSERT(merges.size() == 48895); - merges = std::vector(merges.begin() + 1, merges.end()); - std::vector> merge_pairs; - for (const auto& merge : merges) { - size_t space_pos = merge.find(' '); - merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1)); - // LOG_DEBUG("%s", utf32_to_utf8(merge.substr(space_pos + 1)).c_str()); - // printf("%s :: %s | %s \n", utf32_to_utf8(merge).c_str(), utf32_to_utf8(merge.substr(0, space_pos)).c_str(), - // utf32_to_utf8(merge.substr(space_pos + 1)).c_str()); - } - std::vector vocab; - for (const auto& pair : byte_unicode_pairs) { - vocab.push_back(pair.second); - } - for (const auto& pair : byte_unicode_pairs) { - vocab.push_back(pair.second + utf8_to_utf32("")); - } - for (const auto& merge : merge_pairs) { - vocab.push_back(merge.first + merge.second); - } - vocab.push_back(utf8_to_utf32("<|startoftext|>")); - vocab.push_back(utf8_to_utf32("<|endoftext|>")); - LOG_DEBUG("vocab size: %llu", vocab.size()); - int i = 0; - for (const auto& token : vocab) { - encoder[token] = i; - decoder[i] = token; - i++; - } - encoder_len = i; - - auto it = encoder.find(utf8_to_utf32("img")); - if (it != encoder.end()) { - LOG_DEBUG("trigger word img already in vocab"); - } else { - LOG_DEBUG("trigger word img not in vocab yet"); - } - - int rank = 0; - for (const auto& merge : merge_pairs) { - bpe_ranks[merge] = rank++; - } - bpe_len = rank; - }; - - void add_token(const std::string& text) { - std::u32string token = utf8_to_utf32(text); - auto it = encoder.find(token); - if (it != encoder.end()) { - encoder[token] = encoder_len; - decoder[encoder_len] = token; - encoder_len++; - } - } - - void add_special_token(const std::string& token) { - special_tokens.push_back(token); - } - - std::u32string bpe(const std::u32string& token) { - std::vector word; - - for (int i = 0; i < token.size() - 1; i++) { - word.emplace_back(1, token[i]); - } - word.push_back(token.substr(token.size() - 1) + utf8_to_utf32("")); - - std::set> pairs = get_pairs(word); - - if (pairs.empty()) { - return token + utf8_to_utf32(""); - } - - while (true) { - auto min_pair_iter = std::min_element(pairs.begin(), - pairs.end(), - [&](const std::pair& a, - const std::pair& b) { - if (bpe_ranks.find(a) == bpe_ranks.end()) { - return false; - } else if (bpe_ranks.find(b) == bpe_ranks.end()) { - return true; - } - return bpe_ranks.at(a) < bpe_ranks.at(b); - }); - - const std::pair& bigram = *min_pair_iter; - - if (bpe_ranks.find(bigram) == bpe_ranks.end()) { - break; - } - - std::u32string first = bigram.first; - std::u32string second = bigram.second; - std::vector new_word; - int32_t i = 0; - - while (i < word.size()) { - auto it = std::find(word.begin() + i, word.end(), first); - if (it == word.end()) { - new_word.insert(new_word.end(), word.begin() + i, word.end()); - break; - } - new_word.insert(new_word.end(), word.begin() + i, it); - i = static_cast(std::distance(word.begin(), it)); - - if (word[i] == first && i < static_cast(word.size()) - 1 && word[i + 1] == second) { - new_word.push_back(first + second); - i += 2; - } else { - new_word.push_back(word[i]); - i += 1; - } - } - - word = new_word; - - if (word.size() == 1) { - break; - } - pairs = get_pairs(word); - } - - std::u32string result; - for (int i = 0; i < word.size(); i++) { - result += word[i]; - if (i != word.size() - 1) { - result += utf8_to_utf32(" "); - } - } - - return result; - } - - std::vector tokenize(std::string text, - on_new_token_cb_t on_new_token_cb, - size_t max_length = 0, - bool padding = false) { - std::vector tokens = encode(text, on_new_token_cb); - - tokens.insert(tokens.begin(), BOS_TOKEN_ID); - if (max_length > 0) { - if (tokens.size() > max_length - 1) { - tokens.resize(max_length - 1); - tokens.push_back(EOS_TOKEN_ID); - } else { - tokens.push_back(EOS_TOKEN_ID); - if (padding) { - tokens.insert(tokens.end(), max_length - tokens.size(), PAD_TOKEN_ID); - } - } - } - - return tokens; - } - - void pad_tokens(std::vector& tokens, - std::vector& weights, - size_t max_length = 0, - bool padding = false) { - if (max_length > 0 && padding) { - size_t n = static_cast(std::ceil(tokens.size() * 1.0 / (max_length - 2))); - if (n == 0) { - n = 1; - } - size_t length = max_length * n; - LOG_DEBUG("token length: %llu", length); - std::vector new_tokens; - std::vector new_weights; - new_tokens.push_back(BOS_TOKEN_ID); - new_weights.push_back(1.0); - int token_idx = 0; - for (int i = 1; i < length; i++) { - if (token_idx >= tokens.size()) { - break; - } - if (i % max_length == 0) { - new_tokens.push_back(BOS_TOKEN_ID); - new_weights.push_back(1.0); - } else if (i % max_length == max_length - 1) { - new_tokens.push_back(EOS_TOKEN_ID); - new_weights.push_back(1.0); - } else { - new_tokens.push_back(tokens[token_idx]); - new_weights.push_back(weights[token_idx]); - token_idx++; - } - } - - new_tokens.push_back(EOS_TOKEN_ID); - new_weights.push_back(1.0); - tokens = new_tokens; - weights = new_weights; - - if (padding) { - tokens.insert(tokens.end(), length - tokens.size(), PAD_TOKEN_ID); - weights.insert(weights.end(), length - weights.size(), 1.0); - } - } - } - - std::string clean_up_tokenization(std::string& text) { - std::regex pattern(R"( ,)"); - // Replace " ," with "," - std::string result = std::regex_replace(text, pattern, ","); - return result; - } - - std::string decode(const std::vector& tokens) { - std::string text = ""; - for (int t : tokens) { - if (t == 49406 || t == 49407) - continue; - std::u32string ts = decoder[t]; - // printf("%d, %s \n", t, utf32_to_utf8(ts).c_str()); - std::string s = utf32_to_utf8(ts); - if (s.length() >= 4) { - if (ends_with(s, "")) { - text += s.replace(s.length() - 4, s.length() - 1, "") + " "; - } else { - text += s; - } - } else { - text += " " + s; - } - } - // std::vector bytes; - // for (auto c : text){ - // bytes.push_back(byte_decoder[c]); - // } - - // std::string s((char *)bytes.data()); - // std::string s = ""; - text = clean_up_tokenization(text); - return trim(text); - } - - std::vector token_split(const std::string& text) { - std::regex pat(R"('s|'t|'re|'ve|'m|'ll|'d|[[:alpha:]]+|[[:digit:]]|[^[:space:][:alpha:][:digit:]]+)", - std::regex::icase); - std::sregex_iterator iter(text.begin(), text.end(), pat); - std::sregex_iterator end; - - std::vector result; - for (; iter != end; ++iter) { - result.emplace_back(iter->str()); - } - - return result; - } - - std::vector encode(std::string text, on_new_token_cb_t on_new_token_cb) { - std::string original_text = text; - std::vector bpe_tokens; - text = whitespace_clean(text); - std::transform(text.begin(), text.end(), text.begin(), [](unsigned char c) { return std::tolower(c); }); - - std::string str = text; - std::vector token_strs; - - auto splited_texts = split_with_special_tokens(text, special_tokens); - - for (auto& splited_text : splited_texts) { - LOG_DEBUG("token %s", splited_text.c_str()); - if (is_special_token(splited_text)) { - LOG_DEBUG("special %s", splited_text.c_str()); - bool skip = on_new_token_cb(splited_text, bpe_tokens); - if (skip) { - token_strs.push_back(splited_text); - continue; - } - continue; - } - - auto tokens = token_split(splited_text); - for (auto& token : tokens) { - if (on_new_token_cb != nullptr) { - bool skip = on_new_token_cb(token, bpe_tokens); - if (skip) { - token_strs.push_back(token); - continue; - } - } - - std::string token_str = token; - std::u32string utf32_token; - for (int i = 0; i < token_str.length(); i++) { - unsigned char b = token_str[i]; - utf32_token += byte_encoder[b]; - } - auto bpe_strs = bpe(utf32_token); - size_t start = 0; - size_t pos; - while ((pos = bpe_strs.find(' ', start)) != std::u32string::npos) { - auto bpe_str = bpe_strs.substr(start, pos - start); - bpe_tokens.push_back(encoder[bpe_str]); - token_strs.push_back(utf32_to_utf8(bpe_str)); - - start = pos + 1; - } - auto bpe_str = bpe_strs.substr(start, bpe_strs.size() - start); - bpe_tokens.push_back(encoder[bpe_str]); - token_strs.push_back(utf32_to_utf8(bpe_str)); - } - } - // std::stringstream ss; - // ss << "["; - // for (auto token : token_strs) { - // ss << "\"" << token << "\", "; - // } - // ss << "]"; - // LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str()); - // printf("split prompt \"%s\" to tokens %s \n", original_text.c_str(), ss.str().c_str()); - return bpe_tokens; - } -}; +#include "tokenizers/clip_tokenizer.h" /*================================================ FrozenCLIPEmbedder ================================================*/ diff --git a/src/conditioner.hpp b/src/conditioner.hpp index 5564373eb..a39346cbf 100644 --- a/src/conditioner.hpp +++ b/src/conditioner.hpp @@ -256,15 +256,6 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { return true; } - std::tuple, std::vector, std::vector> - tokenize_with_trigger_token(std::string text, - int num_input_imgs, - int32_t image_token, - bool padding = false) { - return tokenize_with_trigger_token(text, num_input_imgs, image_token, - text_model->model.n_token, padding); - } - std::vector convert_token_to_id(std::string text) { auto on_new_token_cb = [&](std::string& str, std::vector& bpe_tokens) -> bool { auto iter = embedding_map.find(str); @@ -288,9 +279,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { std::tuple, std::vector, std::vector> tokenize_with_trigger_token(std::string text, int num_input_imgs, - int32_t image_token, - size_t max_length = 0, - bool padding = false) { + int32_t image_token) { auto parsed_attention = parse_prompt_attention(text); { @@ -377,7 +366,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { // tokens.insert(tokens.begin(), tokenizer.BOS_TOKEN_ID); // weights.insert(weights.begin(), 1.0); - tokenizer.pad_tokens(tokens, weights, max_length, padding); + tokenizer.pad_tokens(tokens, &weights, nullptr, text_model->model.n_token, text_model->model.n_token, true); int offset = pm_version == PM_VERSION_2 ? 2 * num_input_imgs : num_input_imgs; for (int i = 0; i < tokens.size(); i++) { // if (class_idx + 1 <= i && i < class_idx + 1 + 2*num_input_imgs) // photomaker V2 has num_tokens(=2)*num_input_imgs @@ -403,13 +392,9 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { } std::pair, std::vector> tokenize(std::string text, - bool padding = false) { - return tokenize(text, text_model->model.n_token, padding); - } - - std::pair, std::vector> tokenize(std::string text, - size_t max_length = 0, - bool padding = false) { + size_t min_length = 0, + size_t max_length = 0, + bool allow_overflow_expand = true) { auto parsed_attention = parse_prompt_attention(text); { @@ -460,7 +445,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { weights.insert(weights.end(), curr_tokens.size(), curr_weight); } - tokenizer.pad_tokens(tokens, weights, max_length, padding); + tokenizer.pad_tokens(tokens, &weights, nullptr, min_length, max_length, allow_overflow_expand); // for (int i = 0; i < tokens.size(); i++) { // std::cout << tokens[i] << ":" << weights[i] << ", "; @@ -603,8 +588,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { GGML_ASSERT(image_tokens.size() == 1); auto tokens_and_weights = tokenize_with_trigger_token(conditioner_params.text, conditioner_params.num_input_imgs, - image_tokens[0], - true); + image_tokens[0]); std::vector& tokens = std::get<0>(tokens_and_weights); std::vector& weights = std::get<1>(tokens_and_weights); std::vector& clsm = std::get<2>(tokens_and_weights); @@ -630,7 +614,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { std::string remove_trigger_from_prompt(const std::string& prompt) override { auto image_tokens = convert_token_to_id(trigger_word); GGML_ASSERT(image_tokens.size() == 1); - auto tokens_and_weights = tokenize(prompt, false); + auto tokens_and_weights = tokenize(prompt); std::vector& tokens = tokens_and_weights.first; auto it = std::find(tokens.begin(), tokens.end(), image_tokens[0]); GGML_ASSERT(it != tokens.end()); // prompt must have trigger word @@ -640,7 +624,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { SDCondition get_learned_condition(int n_threads, const ConditionerParams& conditioner_params) override { - auto tokens_and_weights = tokenize(conditioner_params.text, true); + auto tokens_and_weights = tokenize(conditioner_params.text, text_model->model.n_token, text_model->model.n_token, true); std::vector& tokens = tokens_and_weights.first; std::vector& weights = tokens_and_weights.second; return get_learned_condition_common(n_threads, @@ -822,8 +806,9 @@ struct SD3CLIPEmbedder : public Conditioner { } std::vector, std::vector>> tokenize(std::string text, - size_t max_length = 0, - bool padding = false) { + size_t min_length = 0, + size_t max_length = 0, + bool allow_overflow_expand = true) { auto parsed_attention = parse_prompt_attention(text); { @@ -860,20 +845,20 @@ struct SD3CLIPEmbedder : public Conditioner { clip_g_weights.insert(clip_g_weights.end(), curr_tokens.size(), curr_weight); } if (t5) { - std::vector curr_tokens = t5_tokenizer.Encode(curr_text, true); + std::vector curr_tokens = t5_tokenizer.encode(curr_text); t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end()); t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight); } } if (clip_l) { - clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, max_length, padding); + clip_l_tokenizer.pad_tokens(clip_l_tokens, &clip_l_weights, nullptr, min_length, max_length, allow_overflow_expand); } if (clip_g) { - clip_g_tokenizer.pad_tokens(clip_g_tokens, clip_g_weights, max_length, padding); + clip_g_tokenizer.pad_tokens(clip_g_tokens, &clip_g_weights, nullptr, min_length, max_length, allow_overflow_expand); } if (t5) { - t5_tokenizer.pad_tokens(t5_tokens, t5_weights, nullptr, max_length, padding); + t5_tokenizer.pad_tokens(t5_tokens, &t5_weights, nullptr, min_length, max_length, true); } // for (int i = 0; i < clip_l_tokens.size(); i++) { @@ -1056,7 +1041,7 @@ struct SD3CLIPEmbedder : public Conditioner { SDCondition get_learned_condition(int n_threads, const ConditionerParams& conditioner_params) override { - auto tokens_and_weights = tokenize(conditioner_params.text, 77, true); + auto tokens_and_weights = tokenize(conditioner_params.text, 77, 77, true); return get_learned_condition_common(n_threads, tokens_and_weights, conditioner_params.clip_skip, @@ -1158,8 +1143,8 @@ struct FluxCLIPEmbedder : public Conditioner { } std::vector, std::vector>> tokenize(std::string text, - size_t max_length = 0, - bool padding = false) { + size_t min_length = 0, + size_t max_length = 0) { auto parsed_attention = parse_prompt_attention(text); { @@ -1189,17 +1174,17 @@ struct FluxCLIPEmbedder : public Conditioner { clip_l_weights.insert(clip_l_weights.end(), curr_tokens.size(), curr_weight); } if (t5) { - std::vector curr_tokens = t5_tokenizer.Encode(curr_text, true); + std::vector curr_tokens = t5_tokenizer.encode(curr_text); t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end()); t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight); } } if (clip_l) { - clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, 77, padding); + clip_l_tokenizer.pad_tokens(clip_l_tokens, &clip_l_weights, nullptr, 77, 77, true); } if (t5) { - t5_tokenizer.pad_tokens(t5_tokens, t5_weights, nullptr, max_length, padding); + t5_tokenizer.pad_tokens(t5_tokens, &t5_weights, nullptr, min_length, max_length, true); } // for (int i = 0; i < clip_l_tokens.size(); i++) { @@ -1300,7 +1285,7 @@ struct FluxCLIPEmbedder : public Conditioner { SDCondition get_learned_condition(int n_threads, const ConditionerParams& conditioner_params) override { - auto tokens_and_weights = tokenize(conditioner_params.text, chunk_len, true); + auto tokens_and_weights = tokenize(conditioner_params.text, chunk_len, chunk_len); return get_learned_condition_common(n_threads, tokens_and_weights, conditioner_params.clip_skip, @@ -1377,8 +1362,8 @@ struct T5CLIPEmbedder : public Conditioner { } std::tuple, std::vector, std::vector> tokenize(std::string text, - size_t max_length = 0, - bool padding = false) { + size_t min_length = 0, + size_t max_length = 0) { auto parsed_attention = parse_prompt_attention(text); { @@ -1403,12 +1388,15 @@ struct T5CLIPEmbedder : public Conditioner { const std::string& curr_text = item.first; float curr_weight = item.second; - std::vector curr_tokens = t5_tokenizer.Encode(curr_text, true); + std::vector curr_tokens = t5_tokenizer.encode(curr_text); t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end()); t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight); } - t5_tokenizer.pad_tokens(t5_tokens, t5_weights, &t5_mask, max_length, padding); + t5_tokenizer.pad_tokens(t5_tokens, &t5_weights, &t5_mask, min_length, max_length, true); + for (auto& mask_value : t5_mask) { + mask_value = mask_value > 0.0f ? 0.0f : -HUGE_VALF; + } } return {t5_tokens, t5_weights, t5_mask}; } @@ -1496,7 +1484,7 @@ struct T5CLIPEmbedder : public Conditioner { SDCondition get_learned_condition(int n_threads, const ConditionerParams& conditioner_params) override { - auto tokens_and_weights = tokenize(conditioner_params.text, chunk_len, true); + auto tokens_and_weights = tokenize(conditioner_params.text, chunk_len, chunk_len); return get_learned_condition_common(n_threads, tokens_and_weights, conditioner_params.clip_skip, @@ -1505,14 +1493,14 @@ struct T5CLIPEmbedder : public Conditioner { }; struct AnimaConditioner : public Conditioner { - std::shared_ptr qwen_tokenizer; + std::shared_ptr qwen_tokenizer; T5UniGramTokenizer t5_tokenizer; std::shared_ptr llm; AnimaConditioner(ggml_backend_t backend, bool offload_params_to_cpu, const String2TensorStorage& tensor_storage_map = {}) { - qwen_tokenizer = std::make_shared(); + qwen_tokenizer = std::make_shared(); llm = std::make_shared(LLM::LLMArch::QWEN3, backend, offload_params_to_cpu, @@ -1578,7 +1566,7 @@ struct AnimaConditioner : public Conditioner { for (const auto& item : parsed_attention) { const std::string& curr_text = item.first; float curr_weight = item.second; - std::vector curr_tokens = t5_tokenizer.Encode(curr_text, true); + std::vector curr_tokens = t5_tokenizer.tokenize(curr_text, nullptr, true); t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end()); t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight); } @@ -1620,7 +1608,7 @@ struct AnimaConditioner : public Conditioner { struct LLMEmbedder : public Conditioner { SDVersion version; - std::shared_ptr tokenizer; + std::shared_ptr tokenizer; std::shared_ptr llm; LLMEmbedder(ggml_backend_t backend, @@ -1637,9 +1625,9 @@ struct LLMEmbedder : public Conditioner { arch = LLM::LLMArch::QWEN3; } if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) { - tokenizer = std::make_shared(); + tokenizer = std::make_shared(); } else { - tokenizer = std::make_shared(); + tokenizer = std::make_shared(); } llm = std::make_shared(arch, backend, @@ -1677,10 +1665,10 @@ struct LLMEmbedder : public Conditioner { } } - std::tuple, std::vector> tokenize(std::string text, - const std::pair& attn_range, - size_t max_length = 0, - bool padding = false) { + std::tuple, std::vector, std::vector> tokenize(std::string text, + const std::pair& attn_range, + size_t min_length = 0, + size_t max_length = 100000000) { std::vector> parsed_attention; if (attn_range.first >= 0 && attn_range.second > 0) { parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f); @@ -1710,39 +1698,34 @@ struct LLMEmbedder : public Conditioner { for (const auto& item : parsed_attention) { const std::string& curr_text = item.first; float curr_weight = item.second; - std::vector curr_tokens = tokenizer->tokenize(curr_text, nullptr); + std::vector curr_tokens = tokenizer->encode(curr_text, nullptr); tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end()); weights.insert(weights.end(), curr_tokens.size(), curr_weight); } - tokenizer->pad_tokens(tokens, weights, max_length, padding); + std::vector mask; + tokenizer->pad_tokens(tokens, &weights, &mask, min_length, max_length); // for (int i = 0; i < tokens.size(); i++) { // std::cout << tokens[i] << ":" << weights[i] << ", " << i << std::endl; // } // std::cout << std::endl; - return {tokens, weights}; + return {tokens, weights, mask}; } sd::Tensor encode_prompt(int n_threads, const std::string prompt, const std::pair& prompt_attn_range, - int max_length, int min_length, + int hidden_states_min_length, const std::vector>>& image_embeds, const std::set& out_layers, int prompt_template_encode_start_idx) { - auto tokens_and_weights = tokenize(prompt, prompt_attn_range); - auto& tokens = std::get<0>(tokens_and_weights); - auto& weights = std::get<1>(tokens_and_weights); - std::vector mask; - - if (max_length > 0 && tokens.size() < max_length) { - mask.insert(mask.end(), tokens.size(), 1.f); - mask.insert(mask.end(), max_length - tokens.size(), 0.f); - tokenizer->pad_tokens(tokens, weights, max_length, true); - } + auto tokens_weights_mask = tokenize(prompt, prompt_attn_range, min_length); + auto& tokens = std::get<0>(tokens_weights_mask); + auto& weights = std::get<1>(tokens_weights_mask); + auto& mask = std::get<2>(tokens_weights_mask); sd::Tensor input_ids({static_cast(tokens.size())}, tokens); sd::Tensor attention_mask; @@ -1769,9 +1752,9 @@ struct LLMEmbedder : public Conditioner { GGML_ASSERT(hidden_states.shape()[1] > prompt_template_encode_start_idx); int64_t zero_pad_len = 0; - if (min_length > 0) { - if (hidden_states.shape()[1] - prompt_template_encode_start_idx < min_length) { - zero_pad_len = min_length - hidden_states.shape()[1] + prompt_template_encode_start_idx; + if (hidden_states_min_length > 0) { + if (hidden_states.shape()[1] - prompt_template_encode_start_idx < hidden_states_min_length) { + zero_pad_len = hidden_states_min_length - hidden_states.shape()[1] + prompt_template_encode_start_idx; } } @@ -1798,8 +1781,8 @@ struct LLMEmbedder : public Conditioner { std::vector> extra_prompts_attn_range; std::vector>> image_embeds; int prompt_template_encode_start_idx = 34; - int max_length = 0; // pad tokens - int min_length = 0; // zero pad hidden_states + int min_length = 0; // pad tokens + int hidden_states_min_length = 0; // zero pad hidden_states std::set out_layers; int64_t t0 = ggml_time_ms(); @@ -1874,7 +1857,7 @@ struct LLMEmbedder : public Conditioner { } } else if (version == VERSION_FLUX2) { prompt_template_encode_start_idx = 0; - min_length = 512; + hidden_states_min_length = 512; out_layers = {10, 20, 30}; prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]"; @@ -1907,7 +1890,7 @@ struct LLMEmbedder : public Conditioner { } } else if (version == VERSION_FLUX2_KLEIN) { prompt_template_encode_start_idx = 0; - max_length = 512; + min_length = 512; out_layers = {9, 18, 27}; prompt = "<|im_start|>user\n"; @@ -1919,7 +1902,7 @@ struct LLMEmbedder : public Conditioner { prompt += "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"; } else if (version == VERSION_OVIS_IMAGE) { prompt_template_encode_start_idx = 28; - max_length = prompt_template_encode_start_idx + 256; + min_length = prompt_template_encode_start_idx + 256; prompt = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background:"; @@ -1935,8 +1918,8 @@ struct LLMEmbedder : public Conditioner { auto hidden_states = encode_prompt(n_threads, prompt, prompt_attn_range, - max_length, min_length, + hidden_states_min_length, image_embeds, out_layers, prompt_template_encode_start_idx); @@ -1945,8 +1928,8 @@ struct LLMEmbedder : public Conditioner { auto extra_hidden_states = encode_prompt(n_threads, extra_prompts[i], extra_prompts_attn_range[i], - max_length, min_length, + hidden_states_min_length, image_embeds, out_layers, prompt_template_encode_start_idx); diff --git a/src/llm.hpp b/src/llm.hpp index c6c296149..9eacdb905 100644 --- a/src/llm.hpp +++ b/src/llm.hpp @@ -14,465 +14,16 @@ #include #include -#include "clip.hpp" #include "ggml_extend.hpp" #include "json.hpp" #include "rope.hpp" -#include "tokenize_util.h" -#include "vocab/vocab.h" +#include "tokenizers/bpe_tokenizer.h" +#include "tokenizers/mistral_tokenizer.h" +#include "tokenizers/qwen2_tokenizer.h" namespace LLM { constexpr int LLM_GRAPH_SIZE = 10240; - class BPETokenizer { - protected: - std::map byte_encoder; - std::map byte_decoder; - std::map encoder; - std::map decoder; - std::map, int> bpe_ranks; - std::regex pat; - int encoder_len; - int bpe_len; - - std::string UNK_TOKEN; - std::string BOS_TOKEN; - std::string EOS_TOKEN; - std::string PAD_TOKEN; - - int UNK_TOKEN_ID; - int BOS_TOKEN_ID; - int EOS_TOKEN_ID; - int PAD_TOKEN_ID; - - std::vector special_tokens; - - bool add_bos_token = false; - - protected: - static std::string strip(const std::string& str) { - std::string::size_type start = str.find_first_not_of(" \t\n\r\v\f"); - std::string::size_type end = str.find_last_not_of(" \t\n\r\v\f"); - - if (start == std::string::npos) { - // String contains only whitespace characters - return ""; - } - - return str.substr(start, end - start + 1); - } - - static std::string whitespace_clean(std::string text) { - text = std::regex_replace(text, std::regex(R"(\s+)"), " "); - text = strip(text); - return text; - } - - static std::set> get_pairs(const std::vector& subwords) { - std::set> pairs; - if (subwords.size() == 0) { - return pairs; - } - std::u32string prev_subword = subwords[0]; - for (int i = 1; i < subwords.size(); i++) { - std::u32string subword = subwords[i]; - std::pair pair(prev_subword, subword); - pairs.insert(pair); - prev_subword = subword; - } - return pairs; - } - - bool is_special_token(const std::string& token) { - for (auto& special_token : special_tokens) { - if (special_token == token) { - return true; - } - } - return false; - } - - public: - BPETokenizer() = default; - - std::u32string bpe(const std::u32string& token) { - std::vector word; - - for (int i = 0; i < token.size(); i++) { - word.emplace_back(1, token[i]); - } - - std::set> pairs = get_pairs(word); - - if (pairs.empty()) { - return token; - } - - while (true) { - auto min_pair_iter = std::min_element(pairs.begin(), - pairs.end(), - [&](const std::pair& a, - const std::pair& b) { - if (bpe_ranks.find(a) == bpe_ranks.end()) { - return false; - } else if (bpe_ranks.find(b) == bpe_ranks.end()) { - return true; - } - return bpe_ranks.at(a) < bpe_ranks.at(b); - }); - - const std::pair& bigram = *min_pair_iter; - - if (bpe_ranks.find(bigram) == bpe_ranks.end()) { - break; - } - - std::u32string first = bigram.first; - std::u32string second = bigram.second; - std::vector new_word; - int32_t i = 0; - - while (i < word.size()) { - auto it = std::find(word.begin() + i, word.end(), first); - if (it == word.end()) { - new_word.insert(new_word.end(), word.begin() + i, word.end()); - break; - } - new_word.insert(new_word.end(), word.begin() + i, it); - i = static_cast(std::distance(word.begin(), it)); - - if (word[i] == first && i < static_cast(word.size()) - 1 && word[i + 1] == second) { - new_word.push_back(first + second); - i += 2; - } else { - new_word.push_back(word[i]); - i += 1; - } - } - - word = new_word; - - if (word.size() == 1) { - break; - } - pairs = get_pairs(word); - } - - std::u32string result; - for (int i = 0; i < word.size(); i++) { - result += word[i]; - if (i != word.size() - 1) { - result += utf8_to_utf32(" "); - } - } - - return result; - } - - std::vector tokenize(std::string text, - on_new_token_cb_t on_new_token_cb = nullptr, - size_t max_length = 0, - bool padding = false) { - std::vector tokens = encode(text, on_new_token_cb); - - if (max_length > 0) { - if (tokens.size() < max_length) { - tokens.resize(max_length); - } else { - if (padding) { - tokens.insert(tokens.end(), max_length - tokens.size(), PAD_TOKEN_ID); - } - } - } - - return tokens; - } - - void pad_tokens(std::vector& tokens, - std::vector& weights, - size_t max_length = 0, - bool padding = false) { - if (add_bos_token) { - tokens.insert(tokens.begin(), BOS_TOKEN_ID); - weights.insert(weights.begin(), 1.f); - } - if (max_length > 0 && padding) { - size_t n = static_cast(std::ceil(tokens.size() * 1.f / max_length)); - if (n == 0) { - n = 1; - } - size_t length = max_length * n; - LOG_DEBUG("token length: %llu", length); - tokens.insert(tokens.end(), length - tokens.size(), PAD_TOKEN_ID); - weights.insert(weights.end(), length - weights.size(), 1.f); - } - } - - std::vector encode(std::string text, on_new_token_cb_t on_new_token_cb = nullptr) { - std::string original_text = text; - std::vector bpe_tokens; - std::vector token_strs; - - auto splited_texts = split_with_special_tokens(text, special_tokens); - - for (auto& splited_text : splited_texts) { - if (is_special_token(splited_text)) { - bpe_tokens.push_back(encoder[utf8_to_utf32(splited_text)]); - token_strs.push_back(splited_text); - continue; - } - auto tokens = token_split(splited_text); - for (auto& token : tokens) { - if (on_new_token_cb != nullptr) { - bool skip = on_new_token_cb(token, bpe_tokens); - if (skip) { - continue; - } - } - - std::string token_str = token; - std::u32string utf32_token; - for (int i = 0; i < token_str.length(); i++) { - unsigned char b = token_str[i]; - utf32_token += byte_encoder[b]; - } - auto bpe_strs = bpe(utf32_token); - size_t start = 0; - size_t pos; - while ((pos = bpe_strs.find(' ', start)) != std::u32string::npos) { - auto bpe_str = bpe_strs.substr(start, pos - start); - bpe_tokens.push_back(encoder[bpe_str]); - token_strs.push_back(utf32_to_utf8(bpe_str)); - - start = pos + 1; - } - auto bpe_str = bpe_strs.substr(start, bpe_strs.size() - start); - bpe_tokens.push_back(encoder[bpe_str]); - token_strs.push_back(utf32_to_utf8(bpe_str)); - } - } - - std::stringstream ss; - ss << "["; - for (auto token : token_strs) { - ss << "\"" << token << "\", "; - } - ss << "]"; - LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str()); - // printf("split prompt \"%s\" to tokens %s \n", original_text.c_str(), ss.str().c_str()); - return bpe_tokens; - } - }; - - class Qwen2Tokenizer : public BPETokenizer { - protected: - void load_from_merges(const std::string& merges_utf8_str) { - auto byte_unicode_pairs = bytes_to_unicode(); - // printf("byte_unicode_pairs have %lu pairs \n", byte_unicode_pairs.size()); - byte_encoder = std::map(byte_unicode_pairs.begin(), byte_unicode_pairs.end()); - for (auto& pair : byte_unicode_pairs) { - byte_decoder[pair.second] = pair.first; - } - // for (auto & pair: byte_unicode_pairs) { - // std::cout << pair.first << ": " << pair.second << std::endl; - // } - std::vector merges; - size_t start = 0; - size_t pos; - std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str); - while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) { - merges.push_back(merges_utf32_str.substr(start, pos - start)); - start = pos + 1; - } - LOG_DEBUG("merges size %llu", merges.size()); - merges = std::vector(merges.begin(), merges.end()); - std::vector> merge_pairs; - // int print_num = 10; - for (const auto& merge : merges) { - size_t space_pos = merge.find(' '); - merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1)); - // if (print_num > 0) { - // print_num--; - // printf("%s :: %s | %s \n", utf32_to_utf8(merge).c_str(), utf32_to_utf8(merge.substr(0, space_pos)).c_str(), - // utf32_to_utf8(merge.substr(space_pos + 1)).c_str()); - // } - } - - std::vector tokens; - for (const auto& pair : byte_unicode_pairs) { - tokens.push_back(pair.second); - } - for (const auto& merge : merge_pairs) { - tokens.push_back(merge.first + merge.second); - } - for (auto& special_token : special_tokens) { - tokens.push_back(utf8_to_utf32(special_token)); - } - - int i = 0; - for (const auto& token : tokens) { - encoder[token] = i; - decoder[i] = token; - i++; - } - encoder_len = i; - LOG_DEBUG("vocab size: %d", encoder_len); - - int rank = 0; - for (const auto& merge : merge_pairs) { - bpe_ranks[merge] = rank++; - } - bpe_len = rank; - }; - - public: - explicit Qwen2Tokenizer(const std::string& merges_utf8_str = "") { - UNK_TOKEN = "<|endoftext|>"; - EOS_TOKEN = "<|endoftext|>"; - PAD_TOKEN = "<|endoftext|>"; - - UNK_TOKEN_ID = 151643; - EOS_TOKEN_ID = 151643; - PAD_TOKEN_ID = 151643; - - special_tokens = { - "<|endoftext|>", - "<|im_start|>", - "<|im_end|>", - "<|object_ref_start|>", - "<|object_ref_end|>", - "<|box_start|>", - "<|box_end|>", - "<|quad_start|>", - "<|quad_end|>", - "<|vision_start|>", - "<|vision_end|>", - "<|vision_pad|>", - "<|image_pad|>", - "<|video_pad|>", - "", - "", - "<|fim_prefix|>", - "<|fim_middle|>", - "<|fim_suffix|>", - "<|fim_pad|>", - "<|repo_name|>", - "<|file_sep|>", - "", - "", - "", - "", - }; - - if (merges_utf8_str.size() > 0) { - load_from_merges(merges_utf8_str); - } else { - load_from_merges(load_qwen2_merges()); - } - } - }; - - class MistralTokenizer : public BPETokenizer { - protected: - void load_from_merges(const std::string& merges_utf8_str, const std::string& vocab_utf8_str) { - nlohmann::json vocab; - - try { - vocab = nlohmann::json::parse(vocab_utf8_str); - } catch (const nlohmann::json::parse_error&) { - GGML_ABORT("invalid vocab json str"); - } - for (const auto& [key, value] : vocab.items()) { - std::u32string token = utf8_to_utf32(key); - int i = value; - encoder[token] = i; - decoder[i] = token; - } - encoder_len = static_cast(vocab.size()); - LOG_DEBUG("vocab size: %d", encoder_len); - - auto byte_unicode_pairs = bytes_to_unicode(); - byte_encoder = std::map(byte_unicode_pairs.begin(), byte_unicode_pairs.end()); - for (auto& pair : byte_unicode_pairs) { - byte_decoder[pair.second] = pair.first; - } - std::vector merges; - size_t start = 0; - size_t pos; - std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str); - while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) { - merges.push_back(merges_utf32_str.substr(start, pos - start)); - start = pos + 1; - } - LOG_DEBUG("merges size %llu", merges.size()); - merges = std::vector(merges.begin(), merges.end()); - std::vector> merge_pairs; - // int print_num = 10; - for (const auto& merge : merges) { - size_t space_pos = merge.find(' '); - merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1)); - // if (print_num > 0) { - // print_num--; - // printf("%s :: %s | %s \n", utf32_to_utf8(merge).c_str(), utf32_to_utf8(merge.substr(0, space_pos)).c_str(), - // utf32_to_utf8(merge.substr(space_pos + 1)).c_str()); - // } - } - - int rank = 0; - for (const auto& merge : merge_pairs) { - bpe_ranks[merge] = rank++; - } - bpe_len = rank; - }; - - public: - explicit MistralTokenizer(const std::string& merges_utf8_str = "", const std::string& vocab_utf8_str = "") { - add_bos_token = true; - - UNK_TOKEN = ""; - BOS_TOKEN = ""; - EOS_TOKEN = ""; - PAD_TOKEN = ""; - - UNK_TOKEN_ID = 0; - BOS_TOKEN_ID = 1; - EOS_TOKEN_ID = 2; - PAD_TOKEN_ID = 11; - - special_tokens = { - "", - "", - "", - "[INST]", - "[/INST]", - "[AVAILABLE_TOOLS]", - "[/AVAILABLE_TOOLS]", - "[TOOL_RESULTS]", - "[/TOOL_RESULTS]", - "[TOOL_CALLS]", - "[IMG]", - "", - "[IMG_BREAK]", - "[IMG_END]", - "[PREFIX]", - "[MIDDLE]", - "[SUFFIX]", - "[SYSTEM_PROMPT]", - "[/SYSTEM_PROMPT]", - "[TOOL_CONTENT]", - }; - for (int i = 20; i < 1000; i++) { - special_tokens.push_back(""); - } - - if (merges_utf8_str.size() > 0 && vocab_utf8_str.size() > 0) { - load_from_merges(merges_utf8_str, vocab_utf8_str); - } else { - load_from_merges(load_mistral_merges(), load_mistral_vocab_json()); - } - } - }; - enum class LLMArch { QWEN2_5_VL, QWEN3, @@ -1479,7 +1030,7 @@ namespace LLM { weights.insert(weights.end(), curr_tokens.size(), curr_weight); } - tokenizer->pad_tokens(tokens, weights, max_length, padding); + tokenizer->pad_tokens(tokens, &weights, nullptr, padding ? max_length : 0, padding ? max_length : 100000000, padding); // for (int i = 0; i < tokens.size(); i++) { // std::cout << tokens[i] << ":" << weights[i] << ", "; diff --git a/src/t5.hpp b/src/t5.hpp index 60d0c6208..bbd13e498 100644 --- a/src/t5.hpp +++ b/src/t5.hpp @@ -10,452 +10,9 @@ #include #include -#include "darts.h" #include "ggml_extend.hpp" -#include "json.hpp" #include "model.h" -#include "vocab/vocab.h" - -// Port from: https://github.com/google/sentencepiece/blob/master/src/unigram_model.h -// and https://github.com/google/sentencepiece/blob/master/src/unigram_model.h. -// Original License: https://github.com/google/sentencepiece/blob/master/LICENSE -// -// Since tokenization is not the bottleneck in SD, performance was not a major consideration -// during the migration. -class MetaspacePreTokenizer { -private: - std::string replacement; - bool add_prefix_space; - -public: - MetaspacePreTokenizer(const std::string replacement = " ", bool add_prefix_space = true) - : replacement(replacement), add_prefix_space(add_prefix_space) {} - - std::string tokenize(const std::string& input) const { - std::string tokens; - std::stringstream ss(input); - - if (add_prefix_space) { - tokens += replacement; - } - - std::string token; - bool firstToken = true; - while (std::getline(ss, token, ' ')) { - if (!firstToken) - tokens += replacement + token; - else - tokens += token; - - firstToken = false; - } - - return tokens; - } -}; - -using EncodeResult = std::vector>; -class T5UniGramTokenizer { -public: - enum Status { - OK, - NO_PIECES_LOADED, - NO_ENTRY_FOUND, - BUILD_DOUBLE_ARRAY_FAILED, - PIECE_ALREADY_DEFINED, - INVLIAD_JSON - }; - -protected: - MetaspacePreTokenizer pre_tokenizer; - - // all pairs - std::vector> piece_score_pairs; - - float min_score_ = 0.0; - float max_score_ = 0.0; - std::unique_ptr trie_; - - // Maximum size of the return value of Trie, which corresponds - // to the maximum size of shared common prefix in the sentence pieces. - int trie_results_size_; - // unknown id. - int unk_id_ = 2; - std::string eos_token_ = ""; - int eos_id_ = 1; - int pad_id_ = 0; - // status. - Status status_ = OK; - - float kUnkPenalty = 10.0; - - std::string replacement; - bool add_prefix_space = true; - - void InitializePieces(const std::string& json_str) { - nlohmann::json data; - - try { - data = nlohmann::json::parse(json_str); - } catch (const nlohmann::json::parse_error&) { - status_ = INVLIAD_JSON; - return; - } - if (!data.contains("model")) { - status_ = INVLIAD_JSON; - return; - } - nlohmann::json model = data["model"]; - if (!model.contains("vocab")) { - status_ = INVLIAD_JSON; - return; - } - if (model.contains("unk_id")) { - unk_id_ = model["unk_id"]; - } - - replacement = data["pre_tokenizer"]["replacement"]; - add_prefix_space = data["pre_tokenizer"]["add_prefix_space"]; - - pre_tokenizer = MetaspacePreTokenizer(replacement, add_prefix_space); - - for (const auto& item : model["vocab"]) { - if (item.size() != 2 || !item[0].is_string() || !item[1].is_number_float()) { - status_ = INVLIAD_JSON; - return; - } - std::string piece = item[0]; - if (piece.empty()) { - piece = ""; - } - float score = item[1]; - piece_score_pairs.emplace_back(piece, score); - } - } - - // Builds a Trie index. - void BuildTrie(std::vector>* pieces) { - if (status_ != OK) - return; - - if (pieces->empty()) { - status_ = NO_PIECES_LOADED; - return; - } - - // sort by sentencepiece since DoubleArray::build() - // only accepts sorted strings. - sort(pieces->begin(), pieces->end()); - - // Makes key/value set for DoubleArrayTrie. - std::vector key(pieces->size()); - std::vector value(pieces->size()); - for (size_t i = 0; i < pieces->size(); ++i) { - // LOG_DEBUG("%s %d", (*pieces)[i].first.c_str(), (*pieces)[i].second); - key[i] = (*pieces)[i].first.data(); // sorted piece. - value[i] = (*pieces)[i].second; // vocab_id - } - - trie_ = std::unique_ptr(new Darts::DoubleArray()); - if (trie_->build(key.size(), const_cast(&key[0]), nullptr, - &value[0]) != 0) { - status_ = BUILD_DOUBLE_ARRAY_FAILED; - return; - } - - // Computes the maximum number of shared prefixes in the trie. - const int kMaxTrieResultsSize = 1024; - std::vector results( - kMaxTrieResultsSize); - trie_results_size_ = 0; - for (const auto& p : *pieces) { - const size_t num_nodes = trie_->commonPrefixSearch( - p.first.data(), results.data(), results.size(), p.first.size()); - trie_results_size_ = std::max(trie_results_size_, static_cast(num_nodes)); - } - - if (trie_results_size_ == 0) - status_ = NO_ENTRY_FOUND; - } - - // Non-virtual (inlined) implementation for faster execution. - inline float GetScoreInlined(int id) const { - return piece_score_pairs[id].second; - } - - inline bool IsUnusedInlined(int id) const { - return false; // TODO - } - - inline bool IsUserDefinedInlined(int id) const { - return false; // TODO - } - - inline size_t OneCharLen(const char* src) const { - return "\1\1\1\1\1\1\1\1\1\1\1\1\2\2\3\4"[(*src & 0xFF) >> 4]; - } - - // The optimized Viterbi encode. - // Main differences from the original function: - // 1. Memorizes the best path at each postion so far, - // 2. No need to store the Lattice nodes, - // 3. Works in utf-8 directly, - // 4. Defines a new struct with fewer fields than Lattice, - // 5. Does not depend on `class Lattice` nor call `SetSentence()`, - // `PopulateNodes()`, or `Viterbi()`. It does everything in one function. - // For detailed explanations please see the comments inside the function body. - EncodeResult EncodeOptimized(const std::string& normalized) const { - // An optimized Viterbi algorithm for unigram language models. Benchmarking - // results show that it generates almost identical outputs and achieves 2.1x - // speedup on average for 102 languages compared to the original - // implementation. It's based on the following three ideas: - // - // 1. Because it uses the *unigram* model: - // best_score(x1, x2, ... xt) = best_score(x1, x2, ... x{t-1}) + score(xt) - // Deciding the best path (and score) can be decoupled into two isolated - // terms: (a) the best path ended before the last token `best_score(x1, x2, ...)` - // x{t-1})`, and (b) the last token and its `score(xt)`. The two terms are - // not related to each other at all. - // - // Therefore, we can compute once and store the *best_path ending at - // each character position*. In this way, when we know best_path_ends_at[M], - // we can reuse it to compute all the best_path_ends_at_[...] where the last - // token starts at the same character position M. - // - // This improves the time complexity from O(n*k*k) to O(n*k) because it - // eliminates the extra loop of recomputing the best path ending at the same - // position, where n is the input length and k is the maximum number of tokens - // that can be recognized starting at each position. - // - // 2. Again, because it uses the *unigram* model, we don't need to actually - // store the lattice nodes. We still recognize all the tokens and lattice - // nodes from the input, but along identifying them, we use and discard them - // on the fly. There is no need to actually store them for best path Viterbi - // decoding. The only thing we need to store is the best_path ending at - // each character position. - // - // This improvement reduces the things needed to store in memory from O(n*k) - // to O(n), where n is the input length and k is the maximum number of tokens - // that can be recognized starting at each position. - // - // It also avoids the need of dynamic-size lattice node pool, because the - // number of things to store is fixed as n. - // - // 3. SentencePiece is designed to work with unicode, taking utf-8 encoding - // inputs. In the original implementation, the lattice positions are based on - // unicode positions. A mapping from unicode position to the utf-8 position is - // maintained to recover the utf-8 string piece. - // - // We found that it is sufficient and beneficial to directly work with utf-8 - // positions: - // - // Firstly, it saves the conversion and mapping between unicode positions and - // utf-8 positions. - // - // Secondly, it reduces the number of fields we need to maintain in the - // node/path structure. Specifically, there are 8 fields defined in - // `Lattice::Node` used by the original encoder, but here in the optimized - // encoder we only need to define 3 fields in `BestPathNode`. - - if (status() != OK || normalized.empty()) { - return {}; - } - // Represents the last node of the best path. - struct BestPathNode { - int id = -1; // The vocab id. (maybe -1 for UNK) - float best_path_score = - 0; // The total score of the best path ending at this node. - int starts_at = - -1; // The starting position (in utf-8) of this node. The entire best - // path can be constructed by backtracking along this link. - }; - const int size = static_cast(normalized.size()); - const float unk_score = min_score() - kUnkPenalty; - // The ends are exclusive. - std::vector best_path_ends_at(size + 1); - // Generate lattice on-the-fly (not stored) and update best_path_ends_at. - int starts_at = 0; - while (starts_at < size) { - std::size_t node_pos = 0; - std::size_t key_pos = starts_at; - const auto best_path_score_till_here = - best_path_ends_at[starts_at].best_path_score; - bool has_single_node = false; - const int mblen = - std::min(static_cast(OneCharLen(normalized.data() + starts_at)), - size - starts_at); - while (key_pos < size) { - const int ret = - trie_->traverse(normalized.data(), node_pos, key_pos, key_pos + 1); - if (ret == -2) - break; - if (ret >= 0) { - if (IsUnusedInlined(ret)) - continue; - // Update the best path node. - auto& target_node = best_path_ends_at[key_pos]; - const auto length = (key_pos - starts_at); - // User defined symbol receives extra bonus to always be selected. - const auto score = IsUserDefinedInlined(ret) - ? (length * max_score_ - 0.1) - : GetScoreInlined(ret); - const auto candidate_best_path_score = - score + best_path_score_till_here; - if (target_node.starts_at == -1 || - candidate_best_path_score > target_node.best_path_score) { - target_node.best_path_score = static_cast(candidate_best_path_score); - target_node.starts_at = starts_at; - target_node.id = ret; - } - if (!has_single_node && length == mblen) { - has_single_node = true; - } - } - } - if (!has_single_node) { - auto& target_node = best_path_ends_at[starts_at + mblen]; - const auto candidate_best_path_score = - unk_score + best_path_score_till_here; - if (target_node.starts_at == -1 || - candidate_best_path_score > target_node.best_path_score) { - target_node.best_path_score = candidate_best_path_score; - target_node.starts_at = starts_at; - target_node.id = unk_id_; - } - } - // Move by one unicode character. - starts_at += mblen; - } - // Backtrack to identify the best path. - EncodeResult results; - int ends_at = size; - while (ends_at > 0) { - const auto& node = best_path_ends_at[ends_at]; - results.emplace_back( - normalized.substr(node.starts_at, ends_at - node.starts_at), node.id); - ends_at = node.starts_at; - } - std::reverse(results.begin(), results.end()); - return results; - } - -public: - explicit T5UniGramTokenizer(bool is_umt5 = false) { - if (is_umt5) { - InitializePieces(load_umt5_tokenizer_json()); - } else { - InitializePieces(load_t5_tokenizer_json()); - } - - min_score_ = FLT_MAX; - max_score_ = FLT_MIN; - - std::vector> pieces; - for (int i = 0; i < piece_score_pairs.size(); i++) { - const auto& sp = piece_score_pairs[i]; - - min_score_ = std::min(min_score_, sp.second); - max_score_ = std::max(max_score_, sp.second); - - pieces.emplace_back(sp.first, i); - } - - BuildTrie(&pieces); - } - ~T5UniGramTokenizer(){}; - - std::string Normalize(const std::string& input) const { - // Ref: https://github.com/huggingface/tokenizers/blob/1ff56c0c70b045f0cd82da1af9ac08cd4c7a6f9f/bindings/python/py_src/tokenizers/implementations/sentencepiece_unigram.py#L29 - // TODO: nmt-nfkc - std::string normalized = std::regex_replace(input, std::regex(" {2,}"), " "); - return normalized; - } - - std::vector Encode(const std::string& input, bool append_eos_if_not_present = true) const { - std::string normalized = Normalize(input); - normalized = pre_tokenizer.tokenize(normalized); - EncodeResult result = EncodeOptimized(normalized); - if (result.size() > 0 && append_eos_if_not_present) { - auto item = result[result.size() - 1]; - if (item.first != eos_token_) { - result.emplace_back(eos_token_, eos_id_); - } - } - std::vector tokens; - for (auto item : result) { - tokens.push_back(item.second); - } - return tokens; - } - - void pad_tokens(std::vector& tokens, - std::vector& weights, - std::vector* attention_mask, - size_t max_length = 0, - bool padding = false) { - if (max_length > 0 && padding) { - size_t orig_token_num = tokens.size() - 1; - size_t n = static_cast(std::ceil(orig_token_num * 1.0 / (max_length - 1))); - if (n == 0) { - n = 1; - } - size_t length = max_length * n; - LOG_DEBUG("token length: %llu", length); - std::vector new_tokens; - std::vector new_weights; - std::vector new_attention_mask; - int token_idx = 0; - for (int i = 0; i < length; i++) { - if (token_idx >= orig_token_num) { - break; - } - if (attention_mask != nullptr) { - new_attention_mask.push_back(0.0); - } - if (i % max_length == max_length - 1) { - new_tokens.push_back(eos_id_); - new_weights.push_back(1.0); - } else { - new_tokens.push_back(tokens[token_idx]); - new_weights.push_back(weights[token_idx]); - token_idx++; - } - } - - new_tokens.push_back(eos_id_); - new_weights.push_back(1.0); - if (attention_mask != nullptr) { - new_attention_mask.push_back(0.0); - } - - tokens = new_tokens; - weights = new_weights; - if (attention_mask != nullptr) { - *attention_mask = new_attention_mask; - } - - if (padding) { - int pad_token_id = pad_id_; - tokens.insert(tokens.end(), length - tokens.size(), pad_token_id); - weights.insert(weights.end(), length - weights.size(), 1.0); - if (attention_mask != nullptr) { - // maybe keep some padding tokens unmasked? - attention_mask->insert(attention_mask->end(), length - attention_mask->size(), -HUGE_VALF); - } - } - } - } - - // Returns the minimum score in sentence pieces. - // min_score() - 10 is used for the cost of unknown sentence. - float min_score() const { return min_score_; } - - // Returns the maximum score in sentence pieces. - // max_score() is used for the cost of user defined symbols. - float max_score() const { return max_score_; } - - Status status() const { return status_; } -}; +#include "tokenizers/t5_unigram_tokenizer.h" class T5LayerNorm : public UnaryBlock { protected: @@ -937,18 +494,17 @@ struct T5Embedder { for (const auto& item : parsed_attention) { const std::string& curr_text = item.first; float curr_weight = item.second; - std::vector curr_tokens = tokenizer.Encode(curr_text, false); + std::vector curr_tokens = tokenizer.encode(curr_text); tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end()); weights.insert(weights.end(), curr_tokens.size(), curr_weight); } - int EOS_TOKEN_ID = 1; - tokens.push_back(EOS_TOKEN_ID); - weights.push_back(1.0); - std::vector attention_mask; - tokenizer.pad_tokens(tokens, weights, &attention_mask, max_length, padding); + tokenizer.pad_tokens(tokens, &weights, &attention_mask, padding ? max_length : 0, padding ? max_length : 100000000, padding); + for (auto& mask_value : attention_mask) { + mask_value = mask_value > 0.0f ? 0.0f : -HUGE_VALF; + } // for (int i = 0; i < tokens.size(); i++) { // std::cout << tokens[i] << ":" << weights[i] << ", "; diff --git a/src/tokenizers/bpe_tokenizer.cpp b/src/tokenizers/bpe_tokenizer.cpp new file mode 100644 index 000000000..1ad5d9428 --- /dev/null +++ b/src/tokenizers/bpe_tokenizer.cpp @@ -0,0 +1,189 @@ +#include "bpe_tokenizer.h" + +#include +#include + +#include "tokenize_util.h" +#include "util.h" + +std::vector> BPETokenizer::bytes_to_unicode() { + std::vector> byte_unicode_pairs; + std::set byte_set; + for (int b = static_cast('!'); b <= static_cast('~'); ++b) { + byte_set.insert(b); + byte_unicode_pairs.push_back(std::pair(b, unicode_value_to_utf32(b))); + } + for (int b = 161; b <= 172; ++b) { + byte_set.insert(b); + byte_unicode_pairs.push_back(std::pair(b, unicode_value_to_utf32(b))); + } + for (int b = 174; b <= 255; ++b) { + byte_set.insert(b); + byte_unicode_pairs.push_back(std::pair(b, unicode_value_to_utf32(b))); + } + int n = 0; + for (int b = 0; b < 256; ++b) { + if (byte_set.find(b) == byte_set.end()) { + byte_unicode_pairs.push_back(std::pair(b, unicode_value_to_utf32(n + 256))); + ++n; + } + } + return byte_unicode_pairs; +} + +std::vector BPETokenizer::token_split(const std::string& text) const { + return ::token_split(text); +} + +std::vector BPETokenizer::split_utf32(const std::string& text, char32_t delimiter) { + std::vector result; + size_t start = 0; + size_t pos = 0; + std::u32string utf32_text = utf8_to_utf32(text); + while ((pos = utf32_text.find(delimiter, start)) != std::u32string::npos) { + result.push_back(utf32_text.substr(start, pos - start)); + start = pos + 1; + } + return result; +} + +static std::set> get_pairs(const std::vector& subwords) { + std::set> pairs; + if (subwords.empty()) { + return pairs; + } + + std::u32string prev_subword = subwords[0]; + for (int i = 1; i < static_cast(subwords.size()); i++) { + std::u32string subword = subwords[i]; + std::pair pair(prev_subword, subword); + pairs.insert(pair); + prev_subword = subword; + } + return pairs; +} + +std::vector BPETokenizer::bpe(const std::u32string& token) const { + std::vector word; + + for (int i = 0; i < static_cast(token.size()) - 1; i++) { + word.emplace_back(1, token[i]); + } + word.push_back(token.substr(token.size() - 1) + utf8_to_utf32(end_of_word_suffix)); + + std::set> pairs = get_pairs(word); + + if (pairs.empty()) { + return {token + utf8_to_utf32(end_of_word_suffix)}; + } + + while (true) { + auto min_pair_iter = std::min_element(pairs.begin(), + pairs.end(), + [&](const std::pair& a, + const std::pair& b) { + if (bpe_ranks.find(a) == bpe_ranks.end()) { + return false; + } else if (bpe_ranks.find(b) == bpe_ranks.end()) { + return true; + } + return bpe_ranks.at(a) < bpe_ranks.at(b); + }); + + const std::pair& bigram = *min_pair_iter; + + if (bpe_ranks.find(bigram) == bpe_ranks.end()) { + break; + } + + std::u32string first = bigram.first; + std::u32string second = bigram.second; + std::vector new_word; + int32_t i = 0; + + while (i < static_cast(word.size())) { + auto it = std::find(word.begin() + i, word.end(), first); + if (it == word.end()) { + new_word.insert(new_word.end(), word.begin() + i, word.end()); + break; + } + new_word.insert(new_word.end(), word.begin() + i, it); + i = static_cast(std::distance(word.begin(), it)); + + if (word[i] == first && i < static_cast(word.size()) - 1 && word[i + 1] == second) { + new_word.push_back(first + second); + i += 2; + } else { + new_word.push_back(word[i]); + i += 1; + } + } + + word = new_word; + + if (word.size() == 1) { + break; + } + pairs = get_pairs(word); + } + + return word; +} + +std::vector BPETokenizer::encode(const std::string& text, on_new_token_cb_t on_new_token_cb) { + std::string normalized_text = normalize(text); + std::vector bpe_tokens; + std::vector token_strs; + + auto splited_texts = split_with_special_tokens(normalized_text, special_tokens); + + for (auto& splited_text : splited_texts) { + if (is_special_token(splited_text)) { + if (on_new_token_cb != nullptr) { + bool skip = on_new_token_cb(splited_text, bpe_tokens); + if (skip) { + token_strs.push_back(splited_text); + continue; + } + } + bpe_tokens.push_back(encoder[utf8_to_utf32(splited_text)]); + token_strs.push_back(splited_text); + continue; + } + auto tokens = token_split(splited_text); + for (auto& token : tokens) { + if (on_new_token_cb != nullptr) { + bool skip = on_new_token_cb(token, bpe_tokens); + if (skip) { + token_strs.push_back(splited_text); + continue; + } + } + + std::string token_str = token; + std::u32string utf32_token; + for (int i = 0; i < static_cast(token_str.length()); i++) { + unsigned char b = token_str[i]; + utf32_token += byte_encoder[b]; + } + auto bpe_strs = bpe(utf32_token); + for (auto bpe_str : bpe_strs) { + bpe_tokens.push_back(encoder[bpe_str]); + token_strs.push_back(utf32_to_utf8(bpe_str)); + } + } + } + + std::stringstream ss; + ss << "["; + for (auto token : token_strs) { + ss << "\"" << token << "\", "; + } + ss << "]"; + LOG_DEBUG("split prompt \"%s\" to tokens %s", text.c_str(), ss.str().c_str()); + return bpe_tokens; +} + +std::string BPETokenizer::decode_token(int token_id) const { + return utf32_to_utf8(decoder.at(token_id)); +} diff --git a/src/tokenizers/bpe_tokenizer.h b/src/tokenizers/bpe_tokenizer.h new file mode 100644 index 000000000..4dca4e97a --- /dev/null +++ b/src/tokenizers/bpe_tokenizer.h @@ -0,0 +1,40 @@ +#ifndef __SD_TOKENIZERS_BPE_TOKENIZER_H__ +#define __SD_TOKENIZERS_BPE_TOKENIZER_H__ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tokenizer.h" + +class BPETokenizer : public Tokenizer { +protected: + std::map byte_encoder; + std::map byte_decoder; + std::map encoder; + std::map decoder; + std::map, int> bpe_ranks; + int encoder_len = 0; + int bpe_len = 0; + +protected: + static std::vector> bytes_to_unicode(); + static std::vector split_utf32(const std::string& text, char32_t delimiter = U'\n'); + virtual std::vector token_split(const std::string& text) const; + std::vector bpe(const std::u32string& token) const; + std::string decode_token(int token_id) const override; + +public: + BPETokenizer() = default; + virtual ~BPETokenizer() = default; + + std::vector encode(const std::string& text, on_new_token_cb_t on_new_token_cb = nullptr) override; +}; + +#endif // __SD_TOKENIZERS_BPE_TOKENIZER_H__ diff --git a/src/tokenizers/clip_tokenizer.cpp b/src/tokenizers/clip_tokenizer.cpp new file mode 100644 index 000000000..57319306f --- /dev/null +++ b/src/tokenizers/clip_tokenizer.cpp @@ -0,0 +1,116 @@ +#include "clip_tokenizer.h" + +#include +#include +#include +#include +#include + +#include "ggml.h" +#include "tokenize_util.h" +#include "util.h" +#include "vocab/vocab.h" + +CLIPTokenizer::CLIPTokenizer(int pad_token_id, const std::string& merges_utf8_str) { + UNK_TOKEN = "<|endoftext|>"; + BOS_TOKEN = "<|startoftext|>"; + EOS_TOKEN = "<|endoftext|>"; + PAD_TOKEN = "<|endoftext|>"; + + UNK_TOKEN_ID = 49407; + BOS_TOKEN_ID = 49406; + EOS_TOKEN_ID = 49407; + PAD_TOKEN_ID = pad_token_id; + + end_of_word_suffix = ""; + add_bos_token = true; + add_eos_token = true; + + if (merges_utf8_str.size() > 0) { + load_from_merges(merges_utf8_str); + } else { + load_from_merges(load_clip_merges()); + } + add_special_token("<|startoftext|>"); + add_special_token("<|endoftext|>"); +} + +void CLIPTokenizer::load_from_merges(const std::string& merges_utf8_str) { + auto byte_unicode_pairs = bytes_to_unicode(); + byte_encoder = std::map(byte_unicode_pairs.begin(), byte_unicode_pairs.end()); + for (auto& pair : byte_unicode_pairs) { + byte_decoder[pair.second] = pair.first; + } + + std::vector merges = split_utf32(merges_utf8_str); + GGML_ASSERT(merges.size() == 48895); + merges = std::vector(merges.begin() + 1, merges.end()); + std::vector> merge_pairs; + for (const auto& merge : merges) { + size_t space_pos = merge.find(' '); + merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1)); + } + std::vector vocab; + for (const auto& pair : byte_unicode_pairs) { + vocab.push_back(pair.second); + } + for (const auto& pair : byte_unicode_pairs) { + vocab.push_back(pair.second + utf8_to_utf32("")); + } + for (const auto& merge : merge_pairs) { + vocab.push_back(merge.first + merge.second); + } + vocab.push_back(utf8_to_utf32("<|startoftext|>")); + vocab.push_back(utf8_to_utf32("<|endoftext|>")); + LOG_DEBUG("vocab size: %llu", vocab.size()); + int i = 0; + for (const auto& token : vocab) { + encoder[token] = i; + decoder[i] = token; + i++; + } + encoder_len = i; + + int rank = 0; + for (const auto& merge : merge_pairs) { + bpe_ranks[merge] = rank++; + } + bpe_len = rank; +} + +static std::string strip(const std::string& str) { + std::string::size_type start = str.find_first_not_of(" \t\n\r\v\f"); + std::string::size_type end = str.find_last_not_of(" \t\n\r\v\f"); + + if (start == std::string::npos) { + return ""; + } + + return str.substr(start, end - start + 1); +} + +static std::string whitespace_clean(const std::string& text) { + auto result = std::regex_replace(text, std::regex(R"(\s+)"), " "); + result = strip(result); + return result; +} + +std::string CLIPTokenizer::normalize(const std::string& text) const { + auto normalized_text = whitespace_clean(text); + std::transform(normalized_text.begin(), normalized_text.end(), normalized_text.begin(), [](unsigned char c) { return static_cast(std::tolower(c)); }); + return normalized_text; +} + +std::vector CLIPTokenizer::token_split(const std::string& text) const { + std::regex clip_pat(R"('s|'t|'re|'ve|'m|'ll|'d|[[:alpha:]]+|[[:digit:]]|[^[:space:][:alpha:][:digit:]]+)", + std::regex::icase); + std::sregex_iterator iter(text.begin(), text.end(), clip_pat); + std::sregex_iterator end; + + std::vector result; + for (; iter != end; ++iter) { + result.emplace_back(iter->str()); + } + + return result; +} diff --git a/src/tokenizers/clip_tokenizer.h b/src/tokenizers/clip_tokenizer.h new file mode 100644 index 000000000..d4d71ae77 --- /dev/null +++ b/src/tokenizers/clip_tokenizer.h @@ -0,0 +1,20 @@ +#ifndef __SD_TOKENIZERS_CLIP_TOKENIZER_H__ +#define __SD_TOKENIZERS_CLIP_TOKENIZER_H__ + +#include +#include +#include + +#include "bpe_tokenizer.h" + +class CLIPTokenizer : public BPETokenizer { +protected: + void load_from_merges(const std::string& merges_utf8_str); + std::string normalize(const std::string& text) const override; + std::vector token_split(const std::string& text) const override; + +public: + explicit CLIPTokenizer(int pad_token_id = 49407, const std::string& merges_utf8_str = ""); +}; + +#endif // __SD_TOKENIZERS_CLIP_TOKENIZER_H__ diff --git a/src/tokenizers/mistral_tokenizer.cpp b/src/tokenizers/mistral_tokenizer.cpp new file mode 100644 index 000000000..0a56542aa --- /dev/null +++ b/src/tokenizers/mistral_tokenizer.cpp @@ -0,0 +1,89 @@ +#include "mistral_tokenizer.h" + +#include "ggml.h" +#include "json.hpp" +#include "util.h" +#include "vocab/vocab.h" + +void MistralTokenizer::load_from_merges(const std::string& merges_utf8_str, const std::string& vocab_utf8_str) { + nlohmann::json vocab; + + try { + vocab = nlohmann::json::parse(vocab_utf8_str); + } catch (const nlohmann::json::parse_error&) { + GGML_ABORT("invalid vocab json str"); + } + for (const auto& [key, value] : vocab.items()) { + std::u32string token = utf8_to_utf32(key); + int i = value; + encoder[token] = i; + decoder[i] = token; + } + encoder_len = static_cast(vocab.size()); + LOG_DEBUG("vocab size: %d", encoder_len); + + auto byte_unicode_pairs = bytes_to_unicode(); + byte_encoder = std::map(byte_unicode_pairs.begin(), byte_unicode_pairs.end()); + for (auto& pair : byte_unicode_pairs) { + byte_decoder[pair.second] = pair.first; + } + std::vector merges = split_utf32(merges_utf8_str); + LOG_DEBUG("merges size %llu", merges.size()); + std::vector> merge_pairs; + for (const auto& merge : merges) { + size_t space_pos = merge.find(' '); + merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1)); + } + + int rank = 0; + for (const auto& merge : merge_pairs) { + bpe_ranks[merge] = rank++; + } + bpe_len = rank; +} + +MistralTokenizer::MistralTokenizer(const std::string& merges_utf8_str, const std::string& vocab_utf8_str) { + add_bos_token = true; + + UNK_TOKEN = ""; + BOS_TOKEN = ""; + EOS_TOKEN = ""; + PAD_TOKEN = ""; + + UNK_TOKEN_ID = 0; + BOS_TOKEN_ID = 1; + EOS_TOKEN_ID = 2; + PAD_TOKEN_ID = 11; + + special_tokens = { + "", + "", + "", + "[INST]", + "[/INST]", + "[AVAILABLE_TOOLS]", + "[/AVAILABLE_TOOLS]", + "[TOOL_RESULTS]", + "[/TOOL_RESULTS]", + "[TOOL_CALLS]", + "[IMG]", + "", + "[IMG_BREAK]", + "[IMG_END]", + "[PREFIX]", + "[MIDDLE]", + "[SUFFIX]", + "[SYSTEM_PROMPT]", + "[/SYSTEM_PROMPT]", + "[TOOL_CONTENT]", + }; + for (int i = 20; i < 1000; i++) { + special_tokens.push_back(""); + } + + if (merges_utf8_str.size() > 0 && vocab_utf8_str.size() > 0) { + load_from_merges(merges_utf8_str, vocab_utf8_str); + } else { + load_from_merges(load_mistral_merges(), load_mistral_vocab_json()); + } +} diff --git a/src/tokenizers/mistral_tokenizer.h b/src/tokenizers/mistral_tokenizer.h new file mode 100644 index 000000000..6749f56f1 --- /dev/null +++ b/src/tokenizers/mistral_tokenizer.h @@ -0,0 +1,16 @@ +#ifndef __SD_TOKENIZERS_MISTRAL_TOKENIZER_H__ +#define __SD_TOKENIZERS_MISTRAL_TOKENIZER_H__ + +#include + +#include "bpe_tokenizer.h" + +class MistralTokenizer : public BPETokenizer { +protected: + void load_from_merges(const std::string& merges_utf8_str, const std::string& vocab_utf8_str); + +public: + explicit MistralTokenizer(const std::string& merges_utf8_str = "", const std::string& vocab_utf8_str = ""); +}; + +#endif // __SD_TOKENIZERS_MISTRAL_TOKENIZER_H__ diff --git a/src/tokenizers/qwen2_tokenizer.cpp b/src/tokenizers/qwen2_tokenizer.cpp new file mode 100644 index 000000000..5ddaf4ed1 --- /dev/null +++ b/src/tokenizers/qwen2_tokenizer.cpp @@ -0,0 +1,91 @@ +#include "qwen2_tokenizer.h" + +#include "util.h" +#include "vocab/vocab.h" + +void Qwen2Tokenizer::load_from_merges(const std::string& merges_utf8_str) { + auto byte_unicode_pairs = bytes_to_unicode(); + byte_encoder = std::map(byte_unicode_pairs.begin(), byte_unicode_pairs.end()); + for (auto& pair : byte_unicode_pairs) { + byte_decoder[pair.second] = pair.first; + } + + std::vector merges = split_utf32(merges_utf8_str); + LOG_DEBUG("merges size %llu", merges.size()); + std::vector> merge_pairs; + for (const auto& merge : merges) { + size_t space_pos = merge.find(' '); + merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1)); + } + + std::vector tokens; + for (const auto& pair : byte_unicode_pairs) { + tokens.push_back(pair.second); + } + for (const auto& merge : merge_pairs) { + tokens.push_back(merge.first + merge.second); + } + for (auto& special_token : special_tokens) { + tokens.push_back(utf8_to_utf32(special_token)); + } + + int i = 0; + for (const auto& token : tokens) { + encoder[token] = i; + decoder[i] = token; + i++; + } + encoder_len = i; + LOG_DEBUG("vocab size: %d", encoder_len); + + int rank = 0; + for (const auto& merge : merge_pairs) { + bpe_ranks[merge] = rank++; + } + bpe_len = rank; +} + +Qwen2Tokenizer::Qwen2Tokenizer(const std::string& merges_utf8_str) { + UNK_TOKEN = "<|endoftext|>"; + EOS_TOKEN = "<|endoftext|>"; + PAD_TOKEN = "<|endoftext|>"; + + UNK_TOKEN_ID = 151643; + EOS_TOKEN_ID = 151643; + PAD_TOKEN_ID = 151643; + + special_tokens = { + "<|endoftext|>", + "<|im_start|>", + "<|im_end|>", + "<|object_ref_start|>", + "<|object_ref_end|>", + "<|box_start|>", + "<|box_end|>", + "<|quad_start|>", + "<|quad_end|>", + "<|vision_start|>", + "<|vision_end|>", + "<|vision_pad|>", + "<|image_pad|>", + "<|video_pad|>", + "", + "", + "<|fim_prefix|>", + "<|fim_middle|>", + "<|fim_suffix|>", + "<|fim_pad|>", + "<|repo_name|>", + "<|file_sep|>", + "", + "", + "", + "", + }; + + if (merges_utf8_str.size() > 0) { + load_from_merges(merges_utf8_str); + } else { + load_from_merges(load_qwen2_merges()); + } +} diff --git a/src/tokenizers/qwen2_tokenizer.h b/src/tokenizers/qwen2_tokenizer.h new file mode 100644 index 000000000..04e92c2c3 --- /dev/null +++ b/src/tokenizers/qwen2_tokenizer.h @@ -0,0 +1,16 @@ +#ifndef __SD_TOKENIZERS_QWEN2_TOKENIZER_H__ +#define __SD_TOKENIZERS_QWEN2_TOKENIZER_H__ + +#include + +#include "bpe_tokenizer.h" + +class Qwen2Tokenizer : public BPETokenizer { +protected: + void load_from_merges(const std::string& merges_utf8_str); + +public: + explicit Qwen2Tokenizer(const std::string& merges_utf8_str = ""); +}; + +#endif // __SD_TOKENIZERS_QWEN2_TOKENIZER_H__ diff --git a/src/tokenizers/t5_unigram_tokenizer.cpp b/src/tokenizers/t5_unigram_tokenizer.cpp new file mode 100644 index 000000000..8ed4df539 --- /dev/null +++ b/src/tokenizers/t5_unigram_tokenizer.cpp @@ -0,0 +1,339 @@ +#include "t5_unigram_tokenizer.h" + +#include +#include +#include +#include +#include + +#include "json.hpp" +#include "tokenize_util.h" +#include "util.h" +#include "vocab/vocab.h" + +// Port from: https://github.com/google/sentencepiece/blob/master/src/unigram_model.h +// and https://github.com/google/sentencepiece/blob/master/src/unigram_model.h. +// Original License: https://github.com/google/sentencepiece/blob/master/LICENSE +// +// Since tokenization is not the bottleneck in SD, performance was not a major consideration +// during the migration. + +MetaspacePreTokenizer::MetaspacePreTokenizer(const std::string replacement, bool add_prefix_space) + : replacement(replacement), add_prefix_space(add_prefix_space) {} + +std::string MetaspacePreTokenizer::tokenize(const std::string& input) const { + std::string tokens; + std::stringstream ss(input); + + if (add_prefix_space) { + tokens += replacement; + } + + std::string token; + bool first_token = true; + while (std::getline(ss, token, ' ')) { + if (!first_token) { + tokens += replacement + token; + } else { + tokens += token; + } + + first_token = false; + } + + return tokens; +} + +void T5UniGramTokenizer::InitializePieces(const std::string& json_str) { + nlohmann::json data; + + try { + data = nlohmann::json::parse(json_str); + } catch (const nlohmann::json::parse_error&) { + status_ = INVLIAD_JSON; + return; + } + if (!data.contains("model")) { + status_ = INVLIAD_JSON; + return; + } + nlohmann::json model = data["model"]; + if (!model.contains("vocab")) { + status_ = INVLIAD_JSON; + return; + } + if (model.contains("unk_id")) { + UNK_TOKEN_ID = model["unk_id"]; + } + + replacement = data["pre_tokenizer"]["replacement"]; + add_prefix_space = data["pre_tokenizer"]["add_prefix_space"]; + + pre_tokenizer = MetaspacePreTokenizer(replacement, add_prefix_space); + + for (const auto& item : model["vocab"]) { + if (item.size() != 2 || !item[0].is_string() || !item[1].is_number_float()) { + status_ = INVLIAD_JSON; + return; + } + std::string piece = item[0]; + if (piece.empty()) { + piece = ""; + } + float score = item[1]; + piece_score_pairs.emplace_back(piece, score); + } +} + +void T5UniGramTokenizer::BuildTrie(std::vector>* pieces) { + if (status_ != OK) { + return; + } + + if (pieces->empty()) { + status_ = NO_PIECES_LOADED; + return; + } + + std::sort(pieces->begin(), pieces->end()); + + std::vector key(pieces->size()); + std::vector value(pieces->size()); + for (size_t i = 0; i < pieces->size(); ++i) { + key[i] = (*pieces)[i].first.data(); + value[i] = (*pieces)[i].second; + } + + trie_ = std::unique_ptr(new Darts::DoubleArray()); + if (trie_->build(key.size(), const_cast(&key[0]), nullptr, &value[0]) != 0) { + status_ = BUILD_DOUBLE_ARRAY_FAILED; + return; + } + + const int kMaxTrieResultsSize = 1024; + std::vector results(kMaxTrieResultsSize); + trie_results_size_ = 0; + for (const auto& p : *pieces) { + const size_t num_nodes = trie_->commonPrefixSearch( + p.first.data(), results.data(), results.size(), p.first.size()); + trie_results_size_ = std::max(trie_results_size_, static_cast(num_nodes)); + } + + if (trie_results_size_ == 0) { + status_ = NO_ENTRY_FOUND; + } +} + +float T5UniGramTokenizer::GetScoreInlined(int id) const { + return piece_score_pairs[id].second; +} + +bool T5UniGramTokenizer::IsUnusedInlined(int id) const { + (void)id; + return false; +} + +bool T5UniGramTokenizer::IsUserDefinedInlined(int id) const { + (void)id; + return false; +} + +size_t T5UniGramTokenizer::OneCharLen(const char* src) const { + return "\1\1\1\1\1\1\1\1\1\1\1\1\2\2\3\4"[(*src & 0xFF) >> 4]; +} + +EncodeResult T5UniGramTokenizer::EncodeOptimized(const std::string& normalized) const { + if (status() != OK || normalized.empty()) { + return {}; + } + + struct BestPathNode { + int id = -1; + float best_path_score = 0; + int starts_at = -1; + }; + + const int size = static_cast(normalized.size()); + const float unk_score = min_score() - kUnkPenalty; + std::vector best_path_ends_at(size + 1); + + int starts_at = 0; + while (starts_at < size) { + std::size_t node_pos = 0; + std::size_t key_pos = starts_at; + const auto best_path_score_till_here = best_path_ends_at[starts_at].best_path_score; + bool has_single_node = false; + const int mblen = std::min(static_cast(OneCharLen(normalized.data() + starts_at)), size - starts_at); + while (key_pos < static_cast(size)) { + const int ret = trie_->traverse(normalized.data(), node_pos, key_pos, key_pos + 1); + if (ret == -2) { + break; + } + if (ret >= 0) { + if (IsUnusedInlined(ret)) { + continue; + } + auto& target_node = best_path_ends_at[key_pos]; + const auto length = static_cast(key_pos - starts_at); + const auto score = IsUserDefinedInlined(ret) ? (length * max_score_ - 0.1f) : GetScoreInlined(ret); + const auto candidate_best_path_score = score + best_path_score_till_here; + if (target_node.starts_at == -1 || candidate_best_path_score > target_node.best_path_score) { + target_node.best_path_score = static_cast(candidate_best_path_score); + target_node.starts_at = starts_at; + target_node.id = ret; + } + if (!has_single_node && length == mblen) { + has_single_node = true; + } + } + } + if (!has_single_node) { + auto& target_node = best_path_ends_at[starts_at + mblen]; + const auto candidate_best_path_score = unk_score + best_path_score_till_here; + if (target_node.starts_at == -1 || candidate_best_path_score > target_node.best_path_score) { + target_node.best_path_score = candidate_best_path_score; + target_node.starts_at = starts_at; + target_node.id = UNK_TOKEN_ID; + } + } + starts_at += mblen; + } + + EncodeResult results; + int ends_at = size; + while (ends_at > 0) { + const auto& node = best_path_ends_at[ends_at]; + results.emplace_back(normalized.substr(node.starts_at, ends_at - node.starts_at), node.id); + ends_at = node.starts_at; + } + std::reverse(results.begin(), results.end()); + return results; +} + +T5UniGramTokenizer::T5UniGramTokenizer(bool is_umt5) { + add_bos_token = false; + add_eos_token = true; + + if (is_umt5) { + PAD_TOKEN_ID = 0; + EOS_TOKEN_ID = 1; + BOS_TOKEN_ID = 2; + UNK_TOKEN_ID = 3; + + PAD_TOKEN = ""; + EOS_TOKEN = ""; + BOS_TOKEN = ""; + UNK_TOKEN = ""; + } else { + PAD_TOKEN_ID = 0; + EOS_TOKEN_ID = 1; + UNK_TOKEN_ID = 2; + + PAD_TOKEN = ""; + EOS_TOKEN = ""; + UNK_TOKEN = ""; + } + + special_tokens = { + "", + "", + "", + }; + + if (is_umt5) { + special_tokens.push_back(""); + } + + if (is_umt5) { + InitializePieces(load_umt5_tokenizer_json()); + } else { + InitializePieces(load_t5_tokenizer_json()); + } + + min_score_ = FLT_MAX; + max_score_ = FLT_MIN; + + std::vector> pieces; + for (int i = 0; i < static_cast(piece_score_pairs.size()); i++) { + const auto& sp = piece_score_pairs[i]; + + min_score_ = std::min(min_score_, sp.second); + max_score_ = std::max(max_score_, sp.second); + + pieces.emplace_back(sp.first, i); + } + + BuildTrie(&pieces); +} + +T5UniGramTokenizer::~T5UniGramTokenizer() = default; + +std::string T5UniGramTokenizer::decode_token(int token_id) const { + if (token_id < 0 || token_id >= static_cast(piece_score_pairs.size())) { + return ""; + } + + const std::string& piece = piece_score_pairs[token_id].first; + if (piece == "") { + return ""; + } + return piece; +} + +std::string T5UniGramTokenizer::normalize(const std::string& input) const { + // Ref: https://github.com/huggingface/tokenizers/blob/1ff56c0c70b045f0cd82da1af9ac08cd4c7a6f9f/bindings/python/py_src/tokenizers/implementations/sentencepiece_unigram.py#L29 + // TODO: nmt-nfkc + std::string normalized = std::regex_replace(input, std::regex(" {2,}"), " "); + return normalized; +} + +std::vector T5UniGramTokenizer::encode(const std::string& input, on_new_token_cb_t on_new_token_cb) { + std::vector tokens; + std::vector token_strs; + std::string normalized = normalize(input); + auto splited_texts = split_with_special_tokens(normalized, special_tokens); + if (splited_texts.empty()) { + splited_texts.push_back(normalized); // for empty string + } + + for (auto& splited_text : splited_texts) { + if (is_special_token(splited_text)) { + if (on_new_token_cb != nullptr) { + bool skip = on_new_token_cb(splited_text, tokens); + if (skip) { + token_strs.push_back(splited_text); + continue; + } + } + + if (splited_text == UNK_TOKEN) { + tokens.push_back(UNK_TOKEN_ID); + token_strs.push_back(UNK_TOKEN); + } else if (splited_text == EOS_TOKEN) { + tokens.push_back(EOS_TOKEN_ID); + token_strs.push_back(EOS_TOKEN); + } else if (splited_text == PAD_TOKEN) { + tokens.push_back(PAD_TOKEN_ID); + token_strs.push_back(PAD_TOKEN); + } + continue; + } + + std::string pretokenized = pre_tokenizer.tokenize(splited_text); + EncodeResult result = EncodeOptimized(pretokenized); + for (const auto& item : result) { + tokens.push_back(item.second); + token_strs.push_back(item.first); + } + } + + std::stringstream ss; + ss << "["; + for (const auto& token_str : token_strs) { + ss << "\"" << token_str << "\", "; + } + ss << "]"; + LOG_DEBUG("split prompt \"%s\" to tokens %s", input.c_str(), ss.str().c_str()); + + return tokens; +} diff --git a/src/tokenizers/t5_unigram_tokenizer.h b/src/tokenizers/t5_unigram_tokenizer.h new file mode 100644 index 000000000..9c9f13f8b --- /dev/null +++ b/src/tokenizers/t5_unigram_tokenizer.h @@ -0,0 +1,70 @@ +#ifndef __SD_TOKENIZERS_T5_UNIGRAM_TOKENIZER_H__ +#define __SD_TOKENIZERS_T5_UNIGRAM_TOKENIZER_H__ + +#include +#include +#include +#include +#include + +#include "darts.h" +#include "tokenizer.h" + +class MetaspacePreTokenizer { +private: + std::string replacement; + bool add_prefix_space; + +public: + MetaspacePreTokenizer(const std::string replacement = " ", bool add_prefix_space = true); + + std::string tokenize(const std::string& input) const; +}; + +using EncodeResult = std::vector>; + +class T5UniGramTokenizer : public Tokenizer { +public: + enum Status { + OK, + NO_PIECES_LOADED, + NO_ENTRY_FOUND, + BUILD_DOUBLE_ARRAY_FAILED, + PIECE_ALREADY_DEFINED, + INVLIAD_JSON + }; + +protected: + MetaspacePreTokenizer pre_tokenizer; + std::vector> piece_score_pairs; + float min_score_ = 0.0f; + float max_score_ = 0.0f; + std::unique_ptr trie_; + int trie_results_size_ = 0; + Status status_ = OK; + float kUnkPenalty = 10.0f; + std::string replacement; + bool add_prefix_space = true; + + void InitializePieces(const std::string& json_str); + void BuildTrie(std::vector>* pieces); + float GetScoreInlined(int id) const; + bool IsUnusedInlined(int id) const; + bool IsUserDefinedInlined(int id) const; + size_t OneCharLen(const char* src) const; + EncodeResult EncodeOptimized(const std::string& normalized) const; + + float min_score() const { return min_score_; } + float max_score() const { return max_score_; } + Status status() const { return status_; } + std::string decode_token(int token_id) const override; + std::string normalize(const std::string& input) const override; + +public: + explicit T5UniGramTokenizer(bool is_umt5 = false); + ~T5UniGramTokenizer(); + + std::vector encode(const std::string& input, on_new_token_cb_t on_new_token_cb = nullptr) override; +}; + +#endif // __SD_TOKENIZERS_T5_UNIGRAM_TOKENIZER_H__ diff --git a/src/tokenize_util.cpp b/src/tokenizers/tokenize_util.cpp similarity index 100% rename from src/tokenize_util.cpp rename to src/tokenizers/tokenize_util.cpp diff --git a/src/tokenize_util.h b/src/tokenizers/tokenize_util.h similarity index 61% rename from src/tokenize_util.h rename to src/tokenizers/tokenize_util.h index e744d7503..efb0a1cc6 100644 --- a/src/tokenize_util.h +++ b/src/tokenizers/tokenize_util.h @@ -1,5 +1,5 @@ -#ifndef __TOKENIZE_UTIL__ -#define __TOKENIZE_UTIL__ +#ifndef __SD_TOKENIZERS_BPE_TOKENIZE_UTIL_H__ +#define __SD_TOKENIZERS_BPE_TOKENIZE_UTIL_H__ #include #include @@ -7,4 +7,4 @@ std::vector token_split(const std::string& text); std::vector split_with_special_tokens(const std::string& text, const std::vector& special_tokens); -#endif // __TOKENIZE_UTIL__ \ No newline at end of file +#endif // __SD_TOKENIZERS_BPE_TOKENIZE_UTIL_H__ \ No newline at end of file diff --git a/src/tokenizers/tokenizer.cpp b/src/tokenizers/tokenizer.cpp new file mode 100644 index 000000000..ebbc1a506 --- /dev/null +++ b/src/tokenizers/tokenizer.cpp @@ -0,0 +1,211 @@ +#include "tokenizer.h" + +#include +#include +#include + +#include "util.h" + +void Tokenizer::add_special_token(const std::string& token) { + special_tokens.push_back(token); +} + +bool Tokenizer::is_special_token(const std::string& token) const { + for (const auto& special_token : special_tokens) { + if (special_token == token) { + return true; + } + } + return false; +} + +std::string Tokenizer::normalize(const std::string& text) const { + return text; +} + +std::vector Tokenizer::tokenize(const std::string& text, + on_new_token_cb_t on_new_token_cb, + bool padding, + size_t min_length, + size_t max_length, + bool allow_overflow_expand) { + std::vector tokens = encode(text, on_new_token_cb); + if (padding) { + pad_tokens(tokens, nullptr, nullptr, min_length, max_length, allow_overflow_expand); + } + return tokens; +} + +void Tokenizer::pad_tokens(std::vector& tokens, + std::vector* weights, + std::vector* mask, + size_t min_length, + size_t max_length, + bool allow_overflow_expand) { + const bool use_weights = weights != nullptr; + const bool use_mask = mask != nullptr; + + if (use_weights && tokens.size() != weights->size()) { + LOG_ERROR("tokens size != weights size"); + return; + } + + const size_t bos_count = add_bos_token ? 1 : 0; + const size_t eos_count = add_eos_token ? 1 : 0; + const size_t special_token_count = bos_count + eos_count; + + auto build_sequence = [&](size_t begin, + size_t count, + size_t target_length, + std::vector& out_tokens, + std::vector& out_weights, + std::vector& out_mask) { + const size_t base_length = count + special_token_count; + const size_t final_length = std::max(target_length, base_length); + + out_tokens.clear(); + out_weights.clear(); + out_mask.clear(); + + out_tokens.reserve(final_length); + if (use_weights) { + out_weights.reserve(final_length); + } + if (use_mask) { + out_mask.reserve(final_length); + } + + if (add_bos_token) { + out_tokens.push_back(BOS_TOKEN_ID); + if (use_weights) { + out_weights.push_back(1.0f); + } + if (use_mask) { + out_mask.push_back(1.0f); + } + } + + for (size_t i = 0; i < count; ++i) { + out_tokens.push_back(tokens[begin + i]); + if (use_weights) { + out_weights.push_back((*weights)[begin + i]); + } + if (use_mask) { + out_mask.push_back(1.0f); + } + } + + if (add_eos_token) { + out_tokens.push_back(EOS_TOKEN_ID); + if (use_weights) { + out_weights.push_back(1.0f); + } + if (use_mask) { + out_mask.push_back(1.0f); + } + } + + if (final_length > out_tokens.size()) { + const size_t pad_count = final_length - out_tokens.size(); + out_tokens.insert(out_tokens.end(), pad_count, PAD_TOKEN_ID); + + if (use_weights) { + out_weights.insert(out_weights.end(), pad_count, 1.0f); + } + if (use_mask) { + out_mask.insert(out_mask.end(), pad_count, 0.0f); + } + } + }; + + const size_t single_length = std::max(min_length, tokens.size() + special_token_count); + const bool exceeds_max_length = max_length > 0 && single_length > max_length; + + std::vector new_tokens; + std::vector new_weights; + std::vector new_mask; + + if (!exceeds_max_length) { + build_sequence(0, tokens.size(), min_length, new_tokens, new_weights, new_mask); + } else if (!allow_overflow_expand) { + build_sequence(0, tokens.size(), 0, new_tokens, new_weights, new_mask); + + new_tokens.resize(max_length); + if (use_weights) { + new_weights.resize(max_length); + } + if (use_mask) { + new_mask.resize(max_length); + } + + if (add_eos_token && !new_tokens.empty()) { + new_tokens.back() = EOS_TOKEN_ID; + if (use_weights) { + new_weights.back() = 1.0f; + } + if (use_mask) { + new_mask.back() = 1.0f; + } + } + } else if (min_length > special_token_count) { + const size_t tokens_per_chunk = min_length - special_token_count; + size_t offset = 0; + + while (offset < tokens.size()) { + const size_t remaining = tokens.size() - offset; + const size_t take = std::min(tokens_per_chunk, remaining); + + std::vector chunk_tokens; + std::vector chunk_weights; + std::vector chunk_mask; + + build_sequence(offset, take, min_length, chunk_tokens, chunk_weights, chunk_mask); + + new_tokens.insert(new_tokens.end(), chunk_tokens.begin(), chunk_tokens.end()); + if (use_weights) { + new_weights.insert(new_weights.end(), chunk_weights.begin(), chunk_weights.end()); + } + if (use_mask) { + new_mask.insert(new_mask.end(), chunk_mask.begin(), chunk_mask.end()); + } + + offset += take; + } + } else { + build_sequence(0, tokens.size(), min_length, new_tokens, new_weights, new_mask); + } + + tokens = std::move(new_tokens); + if (use_weights) { + *weights = std::move(new_weights); + } + if (use_mask) { + *mask = std::move(new_mask); + } +} + +static std::string clean_up_tokenization(std::string& text) { + std::regex pattern(R"( ,)"); + return std::regex_replace(text, pattern, ","); +} + +std::string Tokenizer::decode(const std::vector& tokens) const { + std::string text; + + for (int token_id : tokens) { + if (token_id == BOS_TOKEN_ID || token_id == EOS_TOKEN_ID || token_id == PAD_TOKEN_ID) { + continue; + } + + std::string piece = decode_token(token_id); + if (!end_of_word_suffix.empty() && ends_with(piece, end_of_word_suffix)) { + piece.erase(piece.size() - end_of_word_suffix.size()); + text += piece + " "; + } else { + text += piece; + } + } + + text = clean_up_tokenization(text); + return trim(text); +} diff --git a/src/tokenizers/tokenizer.h b/src/tokenizers/tokenizer.h new file mode 100644 index 000000000..21ba067df --- /dev/null +++ b/src/tokenizers/tokenizer.h @@ -0,0 +1,52 @@ +#ifndef __SD_TOKENIZERS_TOKENIZER_H__ +#define __SD_TOKENIZERS_TOKENIZER_H__ + +#include +#include +#include +#include +#include + +using on_new_token_cb_t = std::function&)>; + +class Tokenizer { +protected: + std::vector special_tokens; + bool add_bos_token = false; + bool add_eos_token = false; + std::string end_of_word_suffix; + + virtual std::string decode_token(int token_id) const = 0; + virtual std::string normalize(const std::string& text) const; + +public: + std::string UNK_TOKEN; + std::string BOS_TOKEN; + std::string EOS_TOKEN; + std::string PAD_TOKEN; + int UNK_TOKEN_ID = 0; + int BOS_TOKEN_ID = 0; + int EOS_TOKEN_ID = 0; + int PAD_TOKEN_ID = 0; + + virtual ~Tokenizer() = default; + + void add_special_token(const std::string& token); + bool is_special_token(const std::string& token) const; + virtual std::vector encode(const std::string& text, on_new_token_cb_t on_new_token_cb = nullptr) = 0; + std::vector tokenize(const std::string& text, + on_new_token_cb_t on_new_token_cb = nullptr, + bool padding = false, + size_t min_length = 0, + size_t max_length = 100000000, + bool allow_overflow_expand = false); + void pad_tokens(std::vector& tokens, + std::vector* weights, + std::vector* mask, + size_t min_length = 0, + size_t max_length = 100000000, + bool allow_overflow_expand = false); + std::string decode(const std::vector& tokens) const; +}; + +#endif // __SD_TOKENIZERS_TOKENIZER_H__ diff --git a/src/vocab/clip_t5.hpp b/src/tokenizers/vocab/clip_t5.hpp similarity index 100% rename from src/vocab/clip_t5.hpp rename to src/tokenizers/vocab/clip_t5.hpp diff --git a/src/vocab/mistral.hpp b/src/tokenizers/vocab/mistral.hpp similarity index 100% rename from src/vocab/mistral.hpp rename to src/tokenizers/vocab/mistral.hpp diff --git a/src/vocab/qwen.hpp b/src/tokenizers/vocab/qwen.hpp similarity index 100% rename from src/vocab/qwen.hpp rename to src/tokenizers/vocab/qwen.hpp diff --git a/src/vocab/umt5.hpp b/src/tokenizers/vocab/umt5.hpp similarity index 100% rename from src/vocab/umt5.hpp rename to src/tokenizers/vocab/umt5.hpp diff --git a/src/vocab/vocab.cpp b/src/tokenizers/vocab/vocab.cpp similarity index 100% rename from src/vocab/vocab.cpp rename to src/tokenizers/vocab/vocab.cpp diff --git a/src/vocab/vocab.h b/src/tokenizers/vocab/vocab.h similarity index 66% rename from src/vocab/vocab.h rename to src/tokenizers/vocab/vocab.h index cfa033a49..de7a76406 100644 --- a/src/vocab/vocab.h +++ b/src/tokenizers/vocab/vocab.h @@ -1,5 +1,5 @@ -#ifndef __VOCAB_H__ -#define __VOCAB_H__ +#ifndef __SD_TOKENIZERS_VOCAB_VOCAB_H__ +#define __SD_TOKENIZERS_VOCAB_VOCAB_H__ #include @@ -10,4 +10,4 @@ std::string load_mistral_vocab_json(); std::string load_t5_tokenizer_json(); std::string load_umt5_tokenizer_json(); -#endif // __VOCAB_H__ \ No newline at end of file +#endif // __SD_TOKENIZERS_VOCAB_VOCAB_H__ \ No newline at end of file