From 015003fa033eb98fb703acc166b37ce33252d62d Mon Sep 17 00:00:00 2001 From: John Wellbelove Date: Fri, 26 Nov 2021 19:35:18 +0000 Subject: [PATCH] contains() for etl::map --- include/etl/functional.h | 121 +++++++++++++++++++++++++++++++++++++++ include/etl/map.h | 109 ++++++++++++++++++++++++++++++++++- test/test_map.cpp | 32 +++++++++++ 3 files changed, 259 insertions(+), 3 deletions(-) diff --git a/include/etl/functional.h b/include/etl/functional.h index 5f274af1..04eed104 100644 --- a/include/etl/functional.h +++ b/include/etl/functional.h @@ -118,6 +118,26 @@ namespace etl } }; + template <> + struct less + { + typedef int is_transparent; + +#if ETL_CPP11_SUPPORTED + template + constexpr auto operator()(T1&& lhs, T2&& rhs) const -> decltype(static_cast(lhs) < static_cast(rhs)) + { + return static_cast(lhs) < static_cast(rhs); + } +#else + template + bool operator()(T1&& lhs, T2&& rhs) const + { + return static_cast(lhs) < static_cast(rhs); + } +#endif + }; + //*************************************************************************** template struct less_equal @@ -130,6 +150,26 @@ namespace etl } }; + template <> + struct less_equal + { + typedef int is_transparent; + +#if ETL_CPP11_SUPPORTED + template + constexpr auto operator()(T1&& lhs, T2&& rhs) const -> decltype(static_cast(lhs) < static_cast(rhs)) + { + return !(static_cast(lhs) < static_cast(rhs)); + } +#else + template + bool operator()(T1&& lhs, T2&& rhs) const + { + return !(static_cast(lhs) < static_cast(rhs)); + } +#endif + }; + //*************************************************************************** template struct greater @@ -142,6 +182,26 @@ namespace etl } }; + template <> + struct greater + { + typedef int is_transparent; + +#if ETL_CPP11_SUPPORTED + template + constexpr auto operator()(T1&& lhs, T2&& rhs) const -> decltype(static_cast(lhs) < static_cast(rhs)) + { + return static_cast(rhs) < static_cast(lhs); + } +#else + template + bool operator()(T1&& lhs, T2&& rhs) const + { + return static_cast(rhs) < static_cast(lhs); + } +#endif + }; + //*************************************************************************** template struct greater_equal @@ -154,6 +214,26 @@ namespace etl } }; + template <> + struct greater_equal + { + typedef int is_transparent; + +#if ETL_CPP11_SUPPORTED + template + constexpr auto operator()(T1&& lhs, T2&& rhs) const -> decltype(static_cast(lhs) < static_cast(rhs)) + { + return static_cast(rhs) < static_cast(lhs); + } +#else + template + bool operator()(T1&& lhs, T2&& rhs) const + { + return !(static_cast(rhs) < static_cast(lhs)); + } +#endif + }; + //*************************************************************************** template struct equal_to @@ -166,6 +246,27 @@ namespace etl } }; + template <> + struct equal_to + { + typedef void value_type; + typedef int is_transparent; + +#if ETL_CPP11_SUPPORTED + template + constexpr auto operator()(T1&& lhs, T2&& rhs) const -> decltype(static_cast(lhs) < static_cast(rhs)) + { + return static_cast(lhs) == static_cast(rhs); + } +#else + template + bool operator()(T1&& lhs, T2&& rhs) const + { + return !(static_cast(lhs) < static_cast(rhs)); + } +#endif + }; + //*************************************************************************** template struct not_equal_to @@ -178,6 +279,26 @@ namespace etl } }; + template <> + struct not_equal_to + { + typedef int is_transparent; + +#if ETL_CPP11_SUPPORTED + template + constexpr auto operator()(T1&& lhs, T2&& rhs) const -> decltype(static_cast(lhs) < static_cast(rhs)) + { + return !(static_cast(lhs) == static_cast(rhs)); + } +#else + template + bool operator()(T1&& lhs, T2&& rhs) const + { + return !(static_cast(lhs) == static_cast(rhs)); + } +#endif + }; + //*************************************************************************** template diff --git a/include/etl/map.h b/include/etl/map.h index 01a01ae5..f75e2bbf 100644 --- a/include/etl/map.h +++ b/include/etl/map.h @@ -510,7 +510,7 @@ namespace etl }; /// Defines the key value parameter type - typedef typename etl::parameter_type::type key_parameter_t; + typedef typename TKey key_parameter_t; //************************************************************************* /// How to compare node elements. @@ -519,15 +519,29 @@ namespace etl { return kcompare(node1.value.first, node2.value.first); } + bool node_comp(const Data_Node& node, key_parameter_t key) const { return kcompare(node.value.first, key); } + bool node_comp(key_parameter_t key, const Data_Node& node) const { return kcompare(key, node.value.first); } + template + bool node_comp(const Data_Node& node, const K& key) const + { + return kcompare(node.value.first, key); + } + + template + bool node_comp(const K& key, const Data_Node& node) const + { + return kcompare(key, node.value.first); + } + private: /// The pool of data nodes used in the map. @@ -1034,21 +1048,37 @@ namespace etl ///\param key The key to search for. ///\return An iterator pointing to the element or end() if not found. //********************************************************************* - iterator find(key_parameter_t key) + iterator find(const key_parameter_t& key) { return iterator(*this, find_node(root_node, key)); } + template + iterator find(const K& k) + { + Node* pn = find_node(root_node, k); + + return iterator(*this, pn); + } + //********************************************************************* /// Finds an element. ///\param key The key to search for. ///\return An iterator pointing to the element or end() if not found. //********************************************************************* - const_iterator find(key_parameter_t key) const + const_iterator find(const key_parameter_t& key) const { return const_iterator(*this, find_node(root_node, key)); } + template + const_iterator find(const K& k) const + { + const Node* pn = find_node(root_node, k); + + return const_iterator(*this, pn); + } + //********************************************************************* /// Inserts a value to the map. /// If asserts or exceptions are enabled, emits map_full if the map is already full. @@ -1260,6 +1290,17 @@ namespace etl return vcompare; } + bool contains(const TKey& key) const + { + return find(key) != end(); + } + + template + bool contains(const K& k) const + { + return find(k) != end(); + } + protected: //************************************************************************* @@ -1362,6 +1403,37 @@ namespace etl return found; } + template + Node* find_node(Node* position, const K& key) + { + Node* found = position; + while (found) + { + // Downcast found to Data_Node class for comparison and other operations + Data_Node& found_data_node = imap::data_cast(*found); + + // Compare the node value to the current position value + if (node_comp(key, found_data_node)) + { + // Keep searching for the node on the left + found = found->children[kLeft]; + } + else if (node_comp(found_data_node, key)) + { + // Keep searching for the node on the right + found = found->children[kRight]; + } + else + { + // Node that matches the key provided was found, exit loop + break; + } + } + + // Return the node found (might be ETL_NULLPTR) + return found; + } + //************************************************************************* /// Find the value matching the node provided //************************************************************************* @@ -1395,6 +1467,37 @@ namespace etl return found; } + template + const Node* find_node(const Node* position, const K& key) const + { + const Node* found = position; + while (found) + { + // Downcast found to Data_Node class for comparison and other operations + const Data_Node& found_data_node = imap::data_cast(*found); + + // Compare the node value to the current position value + if (node_comp(key, found_data_node)) + { + // Keep searching for the node on the left + found = found->children[kLeft]; + } + else if (node_comp(found_data_node, key)) + { + // Keep searching for the node on the right + found = found->children[kRight]; + } + else + { + // Node that matches the key provided was found, exit loop + break; + } + } + + // Return the node found (might be ETL_NULLPTR) + return found; + } + //************************************************************************* /// Find the reference node matching the node provided //************************************************************************* diff --git a/test/test_map.cpp b/test/test_map.cpp index 981e42a1..1d0f537e 100644 --- a/test/test_map.cpp +++ b/test/test_map.cpp @@ -61,6 +61,26 @@ using Data_const_iterator = Data::const_iterator; using Compare_Data_iterator = Compare_Data::iterator; using Compare_Data_const_iterator = Compare_Data::const_iterator; +struct Key +{ + Key(const char* k_) + : k(k_) + { + } + + std::string k; +}; + +bool operator <(const Key& lhs, const std::string& rhs) +{ + return (lhs.k < rhs); +} + +bool operator <(const std::string& lhs, const Key& rhs) +{ + return (lhs < rhs.k); +} + namespace { SUITE(test_map) @@ -181,6 +201,18 @@ namespace } }; + //************************************************************************* + TEST_FIXTURE(SetupFixture, test_contains) + { + etl::map> data(initial_data.begin(), initial_data.end()); + + CHECK(data.contains(std::string("1"))); + CHECK(data.contains(Key("1"))); + + CHECK(!data.contains(std::string("99"))); + CHECK(!data.contains(Key("99"))); + } + //************************************************************************* TEST_FIXTURE(SetupFixture, test_default_constructor) {