diff --git a/lib/trie.hpp b/lib/trie.hpp index 2310811..46670bd 100644 --- a/lib/trie.hpp +++ b/lib/trie.hpp @@ -1,45 +1,37 @@ #pragma once -#include - -#include -#include +#include +#include #include #include /// -/// \brief A Trie datastructure used to remove prefixes in a set of words -/// -/// The datastructure only works for words over integral unsigned types. In principle the symbols -/// can be unbounded, however having very large symbols degrades the performance a lot. Some random -/// testing shows that for symbols <= 50 the performance is similar to std::set (which is solving a -/// different problem). +/// \brief A Trie datastructure used to remove prefixes in a set of words. +/// Insert-only. Iteration over the structure only uses longest matches. /// /// Tests : 1M words, avg words length 4 (geometric dist.), alphabet 50 symbols -/// trie reduction 58% in 1.15s -/// set reduction 49% in 0.92s +/// trie reduction 58% in 0.4s +/// set reduction 49% in 1.1s /// /// I did not implement any iterators, as those are quite hard to get right. /// There are, however, "internal iterators" exposed as a for_each() member /// function (if only we had coroutines already...) /// +/// TODO: implement `bool member(...)` +/// template struct trie { - static_assert(std::is_integral::value && std::is_unsigned::value, ""); - /// \brief Inserts a word (given by iterators \p begin and \p end) /// \returns true if the element was inserted, false if already there template bool insert(Iterator && begin, Iterator && end) { - if (begin == end) return false; + if (!node) { + node.reset(new trie_node()); - size_t i = *begin++; - if (i >= branches.size()) branches.resize(i + 1); + if (begin == end) { + return true; + } + } - auto & b = branches[i]; - if (b) return b->insert(begin, end); - - b = trie(); - b->insert(begin, end); - return true; + return node->insert(begin, end); } /// \brief Inserts a word given as range \p r @@ -48,34 +40,67 @@ template struct trie { /// \brief Applies \p function to all word (not to the prefixes) template void for_each(Fun && function) const { - std::vector word; - return for_each_impl(std::forward(function), word); + if (node) { + node->for_each(std::forward(function)); + } else { + // empty set, so we don't call the function + } } /// \brief Empties the complete set - void clear() { branches.clear(); } + void clear() { node.reset(nullptr); } private: - template void for_each_impl(Fun && function, std::vector & word) const { - size_t count = 0; - for (T i = 0; i < branches.size(); ++i) { - auto const & b = branches[i]; - if (b) { - ++count; - word.push_back(i); - b->for_each_impl(function, word); + struct trie_node; + std::unique_ptr node = nullptr; + + // A node always contains the empty word + struct trie_node { + template bool insert(Iterator && begin, Iterator && end) { + if (begin == end) return false; + + T i = *begin++; + auto it = find(i); + + if (it != data.end() && it->first == i) { + return it->second.insert(begin, end); + } + + // else, does not yet exist + it = data.emplace(it, i, trie_node()); + it->second.insert(begin, end); + return true; + } + + template void for_each(Fun && function) const { + std::vector word; + return for_each_impl(std::forward(function), word); + } + + private: + template void for_each_impl(Fun && function, std::vector & word) const { + if (data.empty()) { + // we don't want function to modify word + const auto & cword = word; + function(cword); + } + + for (auto const & kv : data) { + // for each letter, we extend the word, recurse and remove extension. + word.push_back(kv.first); + kv.second.for_each_impl(function, word); word.resize(word.size() - 1); } } - if (count == 0) { - const auto & cword = word; - function(cword); // we don't want function to modify word - return; + typename std::vector>::iterator find(T const & key) { + return std::lower_bound( + data.begin(), data.end(), key, + [](std::pair const & kv, T const & k) { return kv.first < k; }); } - } - std::vector> branches; + std::vector> data; + }; }; /// \brief Flattens a trie \p t @@ -90,7 +115,7 @@ template std::vector> flatten(trie const & t) { template std::pair total_size(trie const & t) { size_t count = 0; size_t total_count = 0; - t.for_each([&count, &total_count](std::vector const & w) { + t.for_each([&count, &total_count](std::vector const & w) { ++count; total_count += w.size(); });