From 55f508b315932b362fcdd7056b991f10ffcfefe9 Mon Sep 17 00:00:00 2001 From: John Wellbelove Date: Tue, 16 Apr 2024 08:07:33 +0100 Subject: [PATCH] Added etl::partition and etl::nth_element --- include/etl/algorithm.h | 81 +++++++++++++++++++++++++++++++++++++++++ test/test_algorithm.cpp | 81 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 162 insertions(+) diff --git a/include/etl/algorithm.h b/include/etl/algorithm.h index 4d88910c..313a2c87 100644 --- a/include/etl/algorithm.h +++ b/include/etl/algorithm.h @@ -3140,6 +3140,87 @@ namespace etl return first; } + + //********************************************************* + namespace private_algorithm + { + using ETL_OR_STD::swap; + + template +#if (ETL_USING_CPP20 && ETL_USING_STL) || (ETL_USING_CPP14 && ETL_NOT_USING_STL && !defined(ETL_IN_UNIT_TEST)) + constexpr +#endif + TIterator nth_partition(TIterator first, TIterator last, TCompare compare) + { + typedef typename etl::iterator_traits::value_type value_type; + + TIterator pivot = last; // Maybe find a better pivot choice? + value_type pivot_value = *pivot; + + // Swap the pivot with the last, if necessary. + if (pivot != last) + { + swap(*pivot, *last); + } + + TIterator i = first; + + for (TIterator j = first; j < last; ++j) + { + if (!compare(pivot_value, *j)) // Hack to get '*j <= pivot_value' in terms of 'pivot_value < *j' + { + swap(*i, *j); + ++i; + } + } + + swap(*i, *last); + + return i; + } + } + + //********************************************************* + /// nth_element + /// see https://en.cppreference.com/w/cpp/algorithm/nth_element + //********************************************************* +#if ETL_USING_CPP11 + template ::value_type> > +#else + template +#endif +#if (ETL_USING_CPP20 && ETL_USING_STL) || (ETL_USING_CPP14 && ETL_NOT_USING_STL && !defined(ETL_IN_UNIT_TEST)) + constexpr +#endif + typename etl::enable_if::value, void>::type + nth_element(TIterator first, TIterator nth, TIterator last, TCompare compare = TCompare()) + { + if (first == last) + { + return; + } + + // 'last' must point to the actual last value. + --last; + + while (first <= last) + { + TIterator p = private_algorithm::nth_partition(first, last, compare); + + if (p == nth) + { + return; + } + else if (p > nth) + { + last = p - 1; + } + else + { + first = p + 1; + } + } + } } #include "private/minmax_pop.h" diff --git a/test/test_algorithm.cpp b/test/test_algorithm.cpp index b7e8e663..875884c6 100644 --- a/test/test_algorithm.cpp +++ b/test/test_algorithm.cpp @@ -2303,5 +2303,86 @@ namespace data = initial; } } + + //************************************************************************* + TEST(nth_element_with_default_less_than_comparison) + { + // 40,320 permutations. + std::array initial = { 0, 1, 2, 3, 4, 5, 6, 7 }; + + std::array compare = initial; + std::array data = initial; + + bool complete = false; + + // For each nth position of each permutation. + while (!complete) + { + // Try each nth position. + for (size_t i = 0; i < initial.size(); ++i) + { + std::sort(compare.begin(), compare.end()); + etl::nth_element(data.begin(), data.begin() + i, data.end()); + + CHECK_EQUAL(compare[i], data[i]); + } + + complete = !std::next_permutation(initial.begin(), initial.end()); + + compare = initial; + data = initial; + } + } + +#if (ETL_USING_CPP20 && ETL_USING_STL) || (ETL_USING_CPP14 && ETL_NOT_USING_STL && !defined(ETL_IN_UNIT_TEST)) + //************************************************************************* + constexpr int MakeNth(int nth_index) + { + std::array data = { 5, 1, 3, 7, 6, 2, 4, 0 }; + + etl::nth_element(data.begin(), data.begin() + nth_index, data.end()); + + return data[nth_index]; + } + + TEST(constexpr_nth_element_with_default_less_than_comparison) + { + std::array compare = { 0, 1, 2, 3, 4, 5, 6, 7 }; + + constexpr int nth = MakeNth(3); + + CHECK_EQUAL(compare[3], nth); + } +#endif + + //************************************************************************* + TEST(nth_element_with_custom_comparison) + { + // 40,320 permutations. + std::array initial = { 0, 1, 2, 3, 4, 5, 6, 7 }; + + std::array compare = initial; + std::array data = initial; + + bool complete = false; + + // For each nth position of each permutation. + while (!complete) + { + // Try each nth position. + for (size_t i = 0; i < initial.size(); ++i) + { + std::sort(compare.begin(), compare.end(), std::greater()); + etl::nth_element(data.begin(), data.begin() + i, data.end(), std::greater()); + + CHECK_EQUAL(compare[i], data[i]); + } + + complete = !std::next_permutation(initial.begin(), initial.end()); + + compare = initial; + data = initial; + } + } }; }