mirror of
https://github.com/ETLCPP/etl.git
synced 2026-06-16 00:46:03 +08:00
Added etl::partition and etl::nth_element
This commit is contained in:
parent
676b5f330a
commit
55f508b315
@ -3140,6 +3140,87 @@ namespace etl
|
||||
|
||||
return first;
|
||||
}
|
||||
|
||||
//*********************************************************
|
||||
namespace private_algorithm
|
||||
{
|
||||
using ETL_OR_STD::swap;
|
||||
|
||||
template <typename TIterator, typename TCompare>
|
||||
#if (ETL_USING_CPP20 && ETL_USING_STL) || (ETL_USING_CPP14 && ETL_NOT_USING_STL && !defined(ETL_IN_UNIT_TEST))
|
||||
constexpr
|
||||
#endif
|
||||
TIterator nth_partition(TIterator first, TIterator last, TCompare compare)
|
||||
{
|
||||
typedef typename etl::iterator_traits<TIterator>::value_type value_type;
|
||||
|
||||
TIterator pivot = last; // Maybe find a better pivot choice?
|
||||
value_type pivot_value = *pivot;
|
||||
|
||||
// Swap the pivot with the last, if necessary.
|
||||
if (pivot != last)
|
||||
{
|
||||
swap(*pivot, *last);
|
||||
}
|
||||
|
||||
TIterator i = first;
|
||||
|
||||
for (TIterator j = first; j < last; ++j)
|
||||
{
|
||||
if (!compare(pivot_value, *j)) // Hack to get '*j <= pivot_value' in terms of 'pivot_value < *j'
|
||||
{
|
||||
swap(*i, *j);
|
||||
++i;
|
||||
}
|
||||
}
|
||||
|
||||
swap(*i, *last);
|
||||
|
||||
return i;
|
||||
}
|
||||
}
|
||||
|
||||
//*********************************************************
|
||||
/// nth_element
|
||||
/// see https://en.cppreference.com/w/cpp/algorithm/nth_element
|
||||
//*********************************************************
|
||||
#if ETL_USING_CPP11
|
||||
template <typename TIterator, typename TCompare = etl::less<typename etl::iterator_traits<TIterator>::value_type> >
|
||||
#else
|
||||
template <typename TIterator, typename TCompare>
|
||||
#endif
|
||||
#if (ETL_USING_CPP20 && ETL_USING_STL) || (ETL_USING_CPP14 && ETL_NOT_USING_STL && !defined(ETL_IN_UNIT_TEST))
|
||||
constexpr
|
||||
#endif
|
||||
typename etl::enable_if<etl::is_random_access_iterator_concept<TIterator>::value, void>::type
|
||||
nth_element(TIterator first, TIterator nth, TIterator last, TCompare compare = TCompare())
|
||||
{
|
||||
if (first == last)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// 'last' must point to the actual last value.
|
||||
--last;
|
||||
|
||||
while (first <= last)
|
||||
{
|
||||
TIterator p = private_algorithm::nth_partition(first, last, compare);
|
||||
|
||||
if (p == nth)
|
||||
{
|
||||
return;
|
||||
}
|
||||
else if (p > nth)
|
||||
{
|
||||
last = p - 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
first = p + 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#include "private/minmax_pop.h"
|
||||
|
||||
@ -2303,5 +2303,86 @@ namespace
|
||||
data = initial;
|
||||
}
|
||||
}
|
||||
|
||||
//*************************************************************************
|
||||
TEST(nth_element_with_default_less_than_comparison)
|
||||
{
|
||||
// 40,320 permutations.
|
||||
std::array<int, 8> initial = { 0, 1, 2, 3, 4, 5, 6, 7 };
|
||||
|
||||
std::array<int, 8> compare = initial;
|
||||
std::array<int, 8> data = initial;
|
||||
|
||||
bool complete = false;
|
||||
|
||||
// For each nth position of each permutation.
|
||||
while (!complete)
|
||||
{
|
||||
// Try each nth position.
|
||||
for (size_t i = 0; i < initial.size(); ++i)
|
||||
{
|
||||
std::sort(compare.begin(), compare.end());
|
||||
etl::nth_element(data.begin(), data.begin() + i, data.end());
|
||||
|
||||
CHECK_EQUAL(compare[i], data[i]);
|
||||
}
|
||||
|
||||
complete = !std::next_permutation(initial.begin(), initial.end());
|
||||
|
||||
compare = initial;
|
||||
data = initial;
|
||||
}
|
||||
}
|
||||
|
||||
#if (ETL_USING_CPP20 && ETL_USING_STL) || (ETL_USING_CPP14 && ETL_NOT_USING_STL && !defined(ETL_IN_UNIT_TEST))
|
||||
//*************************************************************************
|
||||
constexpr int MakeNth(int nth_index)
|
||||
{
|
||||
std::array<int, 8> data = { 5, 1, 3, 7, 6, 2, 4, 0 };
|
||||
|
||||
etl::nth_element(data.begin(), data.begin() + nth_index, data.end());
|
||||
|
||||
return data[nth_index];
|
||||
}
|
||||
|
||||
TEST(constexpr_nth_element_with_default_less_than_comparison)
|
||||
{
|
||||
std::array<int, 8> compare = { 0, 1, 2, 3, 4, 5, 6, 7 };
|
||||
|
||||
constexpr int nth = MakeNth(3);
|
||||
|
||||
CHECK_EQUAL(compare[3], nth);
|
||||
}
|
||||
#endif
|
||||
|
||||
//*************************************************************************
|
||||
TEST(nth_element_with_custom_comparison)
|
||||
{
|
||||
// 40,320 permutations.
|
||||
std::array<int, 8> initial = { 0, 1, 2, 3, 4, 5, 6, 7 };
|
||||
|
||||
std::array<int, 8> compare = initial;
|
||||
std::array<int, 8> data = initial;
|
||||
|
||||
bool complete = false;
|
||||
|
||||
// For each nth position of each permutation.
|
||||
while (!complete)
|
||||
{
|
||||
// Try each nth position.
|
||||
for (size_t i = 0; i < initial.size(); ++i)
|
||||
{
|
||||
std::sort(compare.begin(), compare.end(), std::greater<int>());
|
||||
etl::nth_element(data.begin(), data.begin() + i, data.end(), std::greater<int>());
|
||||
|
||||
CHECK_EQUAL(compare[i], data[i]);
|
||||
}
|
||||
|
||||
complete = !std::next_permutation(initial.begin(), initial.end());
|
||||
|
||||
compare = initial;
|
||||
data = initial;
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user