Added etl::partition and etl::nth_element

This commit is contained in:
John Wellbelove 2024-04-16 08:07:33 +01:00
parent 676b5f330a
commit 55f508b315
2 changed files with 162 additions and 0 deletions

View File

@ -3140,6 +3140,87 @@ namespace etl
return first;
}
//*********************************************************
namespace private_algorithm
{
using ETL_OR_STD::swap;
template <typename TIterator, typename TCompare>
#if (ETL_USING_CPP20 && ETL_USING_STL) || (ETL_USING_CPP14 && ETL_NOT_USING_STL && !defined(ETL_IN_UNIT_TEST))
constexpr
#endif
TIterator nth_partition(TIterator first, TIterator last, TCompare compare)
{
typedef typename etl::iterator_traits<TIterator>::value_type value_type;
TIterator pivot = last; // Maybe find a better pivot choice?
value_type pivot_value = *pivot;
// Swap the pivot with the last, if necessary.
if (pivot != last)
{
swap(*pivot, *last);
}
TIterator i = first;
for (TIterator j = first; j < last; ++j)
{
if (!compare(pivot_value, *j)) // Hack to get '*j <= pivot_value' in terms of 'pivot_value < *j'
{
swap(*i, *j);
++i;
}
}
swap(*i, *last);
return i;
}
}
//*********************************************************
/// nth_element
/// see https://en.cppreference.com/w/cpp/algorithm/nth_element
//*********************************************************
#if ETL_USING_CPP11
template <typename TIterator, typename TCompare = etl::less<typename etl::iterator_traits<TIterator>::value_type> >
#else
template <typename TIterator, typename TCompare>
#endif
#if (ETL_USING_CPP20 && ETL_USING_STL) || (ETL_USING_CPP14 && ETL_NOT_USING_STL && !defined(ETL_IN_UNIT_TEST))
constexpr
#endif
typename etl::enable_if<etl::is_random_access_iterator_concept<TIterator>::value, void>::type
nth_element(TIterator first, TIterator nth, TIterator last, TCompare compare = TCompare())
{
if (first == last)
{
return;
}
// 'last' must point to the actual last value.
--last;
while (first <= last)
{
TIterator p = private_algorithm::nth_partition(first, last, compare);
if (p == nth)
{
return;
}
else if (p > nth)
{
last = p - 1;
}
else
{
first = p + 1;
}
}
}
}
#include "private/minmax_pop.h"

View File

@ -2303,5 +2303,86 @@ namespace
data = initial;
}
}
//*************************************************************************
TEST(nth_element_with_default_less_than_comparison)
{
// 40,320 permutations.
std::array<int, 8> initial = { 0, 1, 2, 3, 4, 5, 6, 7 };
std::array<int, 8> compare = initial;
std::array<int, 8> data = initial;
bool complete = false;
// For each nth position of each permutation.
while (!complete)
{
// Try each nth position.
for (size_t i = 0; i < initial.size(); ++i)
{
std::sort(compare.begin(), compare.end());
etl::nth_element(data.begin(), data.begin() + i, data.end());
CHECK_EQUAL(compare[i], data[i]);
}
complete = !std::next_permutation(initial.begin(), initial.end());
compare = initial;
data = initial;
}
}
#if (ETL_USING_CPP20 && ETL_USING_STL) || (ETL_USING_CPP14 && ETL_NOT_USING_STL && !defined(ETL_IN_UNIT_TEST))
//*************************************************************************
constexpr int MakeNth(int nth_index)
{
std::array<int, 8> data = { 5, 1, 3, 7, 6, 2, 4, 0 };
etl::nth_element(data.begin(), data.begin() + nth_index, data.end());
return data[nth_index];
}
TEST(constexpr_nth_element_with_default_less_than_comparison)
{
std::array<int, 8> compare = { 0, 1, 2, 3, 4, 5, 6, 7 };
constexpr int nth = MakeNth(3);
CHECK_EQUAL(compare[3], nth);
}
#endif
//*************************************************************************
TEST(nth_element_with_custom_comparison)
{
// 40,320 permutations.
std::array<int, 8> initial = { 0, 1, 2, 3, 4, 5, 6, 7 };
std::array<int, 8> compare = initial;
std::array<int, 8> data = initial;
bool complete = false;
// For each nth position of each permutation.
while (!complete)
{
// Try each nth position.
for (size_t i = 0; i < initial.size(); ++i)
{
std::sort(compare.begin(), compare.end(), std::greater<int>());
etl::nth_element(data.begin(), data.begin() + i, data.end(), std::greater<int>());
CHECK_EQUAL(compare[i], data[i]);
}
complete = !std::next_permutation(initial.begin(), initial.end());
compare = initial;
data = initial;
}
}
};
}