From d03823c895ec509247b2125fcc4ac43628b16fea Mon Sep 17 00:00:00 2001 From: Yedidya Feldblum Date: Wed, 8 Nov 2017 09:22:17 -0800 Subject: [PATCH] Heterogeneous lookups for sorted_vector types Summary: [Folly] Heterogeneous lookups for `sorted_vector` types. When the `Compare` type has member type or alias `is_transparent`, enable template overloads of `count`, `find`, `lower_bound`, `upper_bound`, and `equal_range` on both `sorted_vector_set` and `sorted_vector_map`. This is the protocol found in the equivalent `std::set` and `std::map` member functions. > This overload only participates in overload resolution if the qualified-id `Compare::is_transparent` is valid and denotes a type. They allow calling this function without constructing an instance of `Key`. > > http://en.cppreference.com/w/cpp/container/set/count (same wording in all 10 cases) Reviewed By: nbronson Differential Revision: D6256989 fbshipit-source-id: a40a181453a019564e8f7674e1e07e241d5ab068 --- folly/sorted_vector_types.h | 243 +++++++++++++++++++++++------- folly/test/sorted_vector_test.cpp | 149 ++++++++++++++++++ 2 files changed, 335 insertions(+), 57 deletions(-) diff --git a/folly/sorted_vector_types.h b/folly/sorted_vector_types.h index d011a83d..2cff300f 100644 --- a/folly/sorted_vector_types.h +++ b/folly/sorted_vector_types.h @@ -68,6 +68,8 @@ #include #include + +#include #include namespace folly { @@ -76,6 +78,18 @@ namespace folly { namespace detail { +template +struct sorted_vector_enable_if_is_transparent {}; + +template +struct sorted_vector_enable_if_is_transparent< + void_t, + Compare, + Key, + T> { + using type = T; +}; + // This wrapper goes around a GrowthPolicy and provides iterator // preservation semantics, but only if the growth policy is not the // default (i.e. nothing). @@ -212,6 +226,10 @@ class sorted_vector_set detail::growth_policy_wrapper& get_growth_policy() { return *this; } + template + using if_is_transparent = + _t>; + public: typedef T value_type; typedef T key_type; @@ -343,25 +361,32 @@ class sorted_vector_set } iterator find(const key_type& key) { - iterator it = lower_bound(key); - if (it == end() || !key_comp()(key, *it)) { - return it; - } - return end(); + return find(*this, key); } const_iterator find(const key_type& key) const { - const_iterator it = lower_bound(key); - if (it == end() || !key_comp()(key, *it)) { - return it; - } - return end(); + return find(*this, key); + } + + template + if_is_transparent find(const K& key) { + return find(*this, key); + } + + template + if_is_transparent find(const K& key) const { + return find(*this, key); } size_type count(const key_type& key) const { return find(key) == end() ? 0 : 1; } + template + if_is_transparent count(const K& key) const { + return find(key) == end() ? 0 : 1; + } + iterator lower_bound(const key_type& key) { return std::lower_bound(begin(), end(), key, key_comp()); } @@ -370,6 +395,16 @@ class sorted_vector_set return std::lower_bound(begin(), end(), key, key_comp()); } + template + if_is_transparent lower_bound(const K& key) { + return std::lower_bound(begin(), end(), key, key_comp()); + } + + template + if_is_transparent lower_bound(const K& key) const { + return std::lower_bound(begin(), end(), key, key_comp()); + } + iterator upper_bound(const key_type& key) { return std::upper_bound(begin(), end(), key, key_comp()); } @@ -378,12 +413,34 @@ class sorted_vector_set return std::upper_bound(begin(), end(), key, key_comp()); } - std::pair equal_range(const key_type& key) { + template + if_is_transparent upper_bound(const K& key) { + return std::upper_bound(begin(), end(), key, key_comp()); + } + + template + if_is_transparent upper_bound(const K& key) const { + return std::upper_bound(begin(), end(), key, key_comp()); + } + + std::pair equal_range(const key_type& key) { + return std::equal_range(begin(), end(), key, key_comp()); + } + + std::pair equal_range( + const key_type& key) const { + return std::equal_range(begin(), end(), key, key_comp()); + } + + template + if_is_transparent> equal_range( + const K& key) { return std::equal_range(begin(), end(), key, key_comp()); } - std::pair - equal_range(const key_type& key) const { + template + if_is_transparent> equal_range( + const K& key) const { return std::equal_range(begin(), end(), key, key_comp()); } @@ -423,6 +480,20 @@ class sorted_vector_set {} ContainerT cont_; } m_; + + template + using self_iterator_t = _t< + std::conditional::value, const_iterator, iterator>>; + + template + static self_iterator_t find(Self& self, K const& key) { + auto end = self.end(); + auto it = self.lower_bound(key); + if (it == end || !self.key_comp()(key, *it)) { + return it; + } + return end; + } }; // Swap function that can be found using ADL. @@ -465,6 +536,10 @@ class sorted_vector_map detail::growth_policy_wrapper& get_growth_policy() { return *this; } + template + using if_is_transparent = + _t>; + public: typedef Key key_type; typedef Value mapped_type; @@ -599,19 +674,21 @@ class sorted_vector_map } iterator find(const key_type& key) { - iterator it = lower_bound(key); - if (it == end() || !key_comp()(key, it->first)) { - return it; - } - return end(); + return find(*this, key); } const_iterator find(const key_type& key) const { - const_iterator it = lower_bound(key); - if (it == end() || !key_comp()(key, it->first)) { - return it; - } - return end(); + return find(*this, key); + } + + template + if_is_transparent find(const K& key) { + return find(*this, key); + } + + template + if_is_transparent find(const K& key) const { + return find(*this, key); } mapped_type& at(const key_type& key) { @@ -634,54 +711,66 @@ class sorted_vector_map return find(key) == end() ? 0 : 1; } + template + if_is_transparent count(const K& key) const { + return find(key) == end() ? 0 : 1; + } + iterator lower_bound(const key_type& key) { - auto c = key_comp(); - auto f = [&](const value_type& a, const key_type& b) { - return c(a.first, b); - }; - return std::lower_bound(begin(), end(), key, f); + return lower_bound(*this, key); } const_iterator lower_bound(const key_type& key) const { - auto c = key_comp(); - auto f = [&](const value_type& a, const key_type& b) { - return c(a.first, b); - }; - return std::lower_bound(begin(), end(), key, f); + return lower_bound(*this, key); + } + + template + if_is_transparent lower_bound(const K& key) { + return lower_bound(*this, key); + } + + template + if_is_transparent lower_bound(const K& key) const { + return lower_bound(*this, key); } iterator upper_bound(const key_type& key) { - auto c = key_comp(); - auto f = [&](const key_type& a, const value_type& b) { - return c(a, b.first); - }; - return std::upper_bound(begin(), end(), key, f); + return upper_bound(*this, key); } const_iterator upper_bound(const key_type& key) const { - auto c = key_comp(); - auto f = [&](const key_type& a, const value_type& b) { - return c(a, b.first); - }; - return std::upper_bound(begin(), end(), key, f); + return upper_bound(*this, key); } - std::pair equal_range(const key_type& key) { - // Note: std::equal_range can't be passed a functor that takes - // argument types different from the iterator value_type, so we - // have to do this. - iterator low = lower_bound(key); - auto c = key_comp(); - auto f = [&](const key_type& a, const value_type& b) { - return c(a, b.first); - }; - iterator high = std::upper_bound(low, end(), key, f); - return std::make_pair(low, high); + template + if_is_transparent upper_bound(const K& key) { + return upper_bound(*this, key); + } + + template + if_is_transparent upper_bound(const K& key) const { + return upper_bound(*this, key); } - std::pair - equal_range(const key_type& key) const { - return const_cast(this)->equal_range(key); + std::pair equal_range(const key_type& key) { + return equal_range(*this, key); + } + + std::pair equal_range( + const key_type& key) const { + return equal_range(*this, key); + } + + template + if_is_transparent> equal_range( + const K& key) { + return equal_range(*this, key); + } + + template + if_is_transparent> equal_range( + const K& key) const { + return equal_range(*this, key); } // Nothrow as long as swap() on the Compare type is nothrow. @@ -719,6 +808,46 @@ class sorted_vector_map {} ContainerT cont_; } m_; + + template + using self_iterator_t = _t< + std::conditional::value, const_iterator, iterator>>; + + template + static self_iterator_t find(Self& self, K const& key) { + auto end = self.end(); + auto it = self.lower_bound(key); + if (it == end || !self.key_comp()(key, it->first)) { + return it; + } + return end; + } + + template + static self_iterator_t lower_bound(Self& self, K const& key) { + auto f = [c = self.key_comp()](value_type const& a, K const& b) { + return c(a.first, b); + }; + return std::lower_bound(self.begin(), self.end(), key, f); + } + + template + static self_iterator_t upper_bound(Self& self, K const& key) { + auto f = [c = self.key_comp()](K const& a, value_type const& b) { + return c(a, b.first); + }; + return std::upper_bound(self.begin(), self.end(), key, f); + } + + template + static std::pair, self_iterator_t> equal_range( + Self& self, + K const& key) { + // Note: std::equal_range can't be passed a functor that takes + // argument types different from the iterator value_type, so we + // have to do this. + return {lower_bound(self, key), upper_bound(self, key)}; + } }; // Swap function that can be found using ADL. diff --git a/folly/test/sorted_vector_test.cpp b/folly/test/sorted_vector_test.cpp index fa9c977e..7efa4c8f 100644 --- a/folly/test/sorted_vector_test.cpp +++ b/folly/test/sorted_vector_test.cpp @@ -76,6 +76,27 @@ struct CountCopyCtor { int count_; }; +struct Opaque { + int value; + friend bool operator==(Opaque a, Opaque b) { + return a.value == b.value; + } + friend bool operator<(Opaque a, Opaque b) { + return a.value < b.value; + } + struct Compare : std::less, std::less { + using is_transparent = void; + using std::less::operator(); + using std::less::operator(); + bool operator()(int a, Opaque b) const { + return std::less::operator()(a, b.value); + } + bool operator()(Opaque a, int b) const { + return std::less::operator()(a.value, b); + } + }; +}; + } // namespace TEST(SortedVectorTypes, SimpleSetTest) { @@ -145,6 +166,73 @@ TEST(SortedVectorTypes, SimpleSetTest) { EXPECT_TRUE(cpy2 == cpy); } +TEST(SortedVectorTypes, TransparentSetTest) { + sorted_vector_set s; + EXPECT_TRUE(s.empty()); + for (int i = 0; i < 1000; ++i) { + s.insert(Opaque{rand() % 100000}); + } + EXPECT_FALSE(s.empty()); + check_invariant(s); + + sorted_vector_set s2; + s2.insert(s.begin(), s.end()); + check_invariant(s2); + EXPECT_TRUE(s == s2); + + auto it = s2.lower_bound(32); + if (it->value == 32) { + s2.erase(it); + it = s2.lower_bound(32); + } + check_invariant(s2); + auto oldSz = s2.size(); + s2.insert(it, Opaque{32}); + EXPECT_TRUE(s2.size() == oldSz + 1); + check_invariant(s2); + + const sorted_vector_set& cs2 = s2; + auto range = cs2.equal_range(32); + auto lbound = cs2.lower_bound(32); + auto ubound = cs2.upper_bound(32); + EXPECT_TRUE(range.first == lbound); + EXPECT_TRUE(range.second == ubound); + EXPECT_TRUE(range.first != cs2.end()); + EXPECT_TRUE(range.second != cs2.end()); + EXPECT_TRUE(cs2.count(32) == 1); + EXPECT_FALSE(cs2.find(32) == cs2.end()); + + // Bad insert hint. + s2.insert(s2.begin() + 3, Opaque{33}); + EXPECT_TRUE(s2.find(33) != s2.begin()); + EXPECT_TRUE(s2.find(33) != s2.end()); + check_invariant(s2); + s2.erase(Opaque{33}); + check_invariant(s2); + + it = s2.find(32); + EXPECT_FALSE(it == s2.end()); + s2.erase(it); + EXPECT_TRUE(s2.size() == oldSz); + check_invariant(s2); + + sorted_vector_set cpy(s); + check_invariant(cpy); + EXPECT_TRUE(cpy == s); + sorted_vector_set cpy2(s); + cpy2.insert(Opaque{100001}); + EXPECT_TRUE(cpy2 != cpy); + EXPECT_TRUE(cpy2 != s); + check_invariant(cpy2); + EXPECT_TRUE(cpy2.count(100001) == 1); + s.swap(cpy2); + check_invariant(cpy2); + check_invariant(s); + EXPECT_TRUE(s != cpy); + EXPECT_TRUE(s != cpy2); + EXPECT_TRUE(cpy2 == cpy); +} + TEST(SortedVectorTypes, BadHints) { for (int toInsert = -1; toInsert <= 7; ++toInsert) { for (int hintPos = 0; hintPos <= 4; ++hintPos) { @@ -221,6 +309,67 @@ TEST(SortedVectorTypes, SimpleMapTest) { check_invariant(m); } +TEST(SortedVectorTypes, TransparentMapTest) { + sorted_vector_map m; + for (int i = 0; i < 1000; ++i) { + m[Opaque{i}] = i / 1000.0; + } + check_invariant(m); + + m[Opaque{32}] = 100.0; + check_invariant(m); + EXPECT_TRUE(m.count(32) == 1); + EXPECT_DOUBLE_EQ(100.0, m.at(Opaque{32})); + EXPECT_FALSE(m.find(32) == m.end()); + m.erase(Opaque{32}); + EXPECT_TRUE(m.find(32) == m.end()); + check_invariant(m); + EXPECT_THROW(m.at(Opaque{32}), std::out_of_range); + + sorted_vector_map m2 = m; + EXPECT_TRUE(m2 == m); + EXPECT_FALSE(m2 != m); + auto it = m2.lower_bound(1 << 20); + EXPECT_TRUE(it == m2.end()); + m2.insert(it, std::make_pair(Opaque{1 << 20}, 10.0f)); + check_invariant(m2); + EXPECT_TRUE(m2.count(1 << 20) == 1); + EXPECT_TRUE(m < m2); + EXPECT_TRUE(m <= m2); + + const sorted_vector_map& cm = m; + auto range = cm.equal_range(42); + auto lbound = cm.lower_bound(42); + auto ubound = cm.upper_bound(42); + EXPECT_TRUE(range.first == lbound); + EXPECT_TRUE(range.second == ubound); + EXPECT_FALSE(range.first == cm.end()); + EXPECT_FALSE(range.second == cm.end()); + m.erase(m.lower_bound(42)); + check_invariant(m); + + sorted_vector_map m3; + m3.insert(m2.begin(), m2.end()); + check_invariant(m3); + EXPECT_TRUE(m3 == m2); + EXPECT_FALSE(m3 == m); + + EXPECT_TRUE(m != m2); + EXPECT_TRUE(m2 == m3); + EXPECT_TRUE(m3 != m); + m.swap(m3); + check_invariant(m); + check_invariant(m2); + check_invariant(m3); + EXPECT_TRUE(m3 != m2); + EXPECT_TRUE(m3 != m); + EXPECT_TRUE(m == m2); + + // Bad insert hint. + m.insert(m.begin() + 3, std::make_pair(Opaque{1 << 15}, 1.0f)); + check_invariant(m); +} + TEST(SortedVectorTypes, Sizes) { EXPECT_EQ(sizeof(sorted_vector_set), sizeof(std::vector)); -- 2.34.1