mirror of
https://github.com/ETLCPP/etl.git
synced 2026-04-30 19:09:10 +08:00
Added etl::standard_deviation + corrected covariance
This commit is contained in:
parent
3b367ce2b5
commit
203f082c91
@ -239,8 +239,8 @@ namespace etl
|
||||
double mean1 = sum1 / n;
|
||||
double mean2 = sum2 / n;
|
||||
|
||||
double stddev1_squared = sum_of_squares1 - (n * mean1 * mean1);
|
||||
double stddev2_squared = sum_of_squares2 - (n * mean2 * mean2);
|
||||
double stddev1_squared = (sum_of_squares1 / n) - (mean1 * mean1);
|
||||
double stddev2_squared = (sum_of_squares2 / n) - (mean2 * mean2);
|
||||
|
||||
double stddev1 = 0.0;
|
||||
double stddev2 = 0.0;
|
||||
@ -255,7 +255,7 @@ namespace etl
|
||||
stddev2 = sqrt(stddev2_squared);
|
||||
}
|
||||
|
||||
covariance_value = inner_product - (n * mean1 * mean2);
|
||||
covariance_value = (inner_product / n) - (mean1 * mean2);
|
||||
|
||||
if ((stddev1 > 0.0) && (stddev2 > 0.0))
|
||||
{
|
||||
|
||||
@ -202,7 +202,7 @@ namespace etl
|
||||
double mean1 = sum1 / n;
|
||||
double mean2 = sum2 / n;
|
||||
|
||||
covariance_value = inner_product - (n * mean1 * mean2);
|
||||
covariance_value = (inner_product / n) - (mean1 * mean2);
|
||||
|
||||
recalulate = false;
|
||||
}
|
||||
|
||||
234
include/etl/standard_deviation.h
Normal file
234
include/etl/standard_deviation.h
Normal file
@ -0,0 +1,234 @@
|
||||
///\file
|
||||
|
||||
/******************************************************************************
|
||||
The MIT License(MIT)
|
||||
|
||||
Embedded Template Library.
|
||||
https://github.com/ETLCPP/etl
|
||||
https://www.etlcpp.com
|
||||
|
||||
Copyright(c) 2021 jwellbelove
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files(the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and / or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions :
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
******************************************************************************/
|
||||
|
||||
#ifndef ETL_STANDARD_DEVIATION_INCLUDED
|
||||
#define ETL_STANDARD_DEVIATION_INCLUDED
|
||||
|
||||
#include "functional.h"
|
||||
#include "type_traits.h"
|
||||
|
||||
#include <math.h>
|
||||
#include <stdint.h>
|
||||
|
||||
namespace etl
|
||||
{
|
||||
namespace private_standard_deviation
|
||||
{
|
||||
//***************************************************************************
|
||||
/// Types for generic covariance.
|
||||
//***************************************************************************
|
||||
template <typename TInput, typename TCalc>
|
||||
struct standard_deviation_types
|
||||
{
|
||||
TCalc sum_of_squares;
|
||||
TCalc sum;
|
||||
uint32_t counter;
|
||||
|
||||
//*********************************
|
||||
/// Clear the histogram.
|
||||
//*********************************
|
||||
void clear()
|
||||
{
|
||||
sum_of_squares = TCalc(0);
|
||||
sum = TCalc(0);
|
||||
counter = 0U;
|
||||
}
|
||||
};
|
||||
|
||||
//***************************************************************************
|
||||
/// Types for float covariance.
|
||||
//***************************************************************************
|
||||
template <typename TCalc>
|
||||
struct standard_deviation_types<float, TCalc>
|
||||
{
|
||||
float sum_of_squares;
|
||||
float sum;
|
||||
uint32_t counter;
|
||||
|
||||
//*********************************
|
||||
/// Clear the histogram.
|
||||
//*********************************
|
||||
void clear()
|
||||
{
|
||||
sum_of_squares = float(0);
|
||||
sum = float(0);
|
||||
counter = 0U;
|
||||
}
|
||||
};
|
||||
|
||||
//***************************************************************************
|
||||
/// Types for double covariance.
|
||||
//***************************************************************************
|
||||
template <typename TCalc>
|
||||
struct standard_deviation_types<double, TCalc>
|
||||
{
|
||||
double sum_of_squares;
|
||||
double sum;
|
||||
uint32_t counter;
|
||||
|
||||
//*********************************
|
||||
/// Clear the histogram.
|
||||
//*********************************
|
||||
void clear()
|
||||
{
|
||||
sum_of_squares = double(0);
|
||||
sum = double(0);
|
||||
counter = 0U;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
//***************************************************************************
|
||||
/// Constructor.
|
||||
//***************************************************************************
|
||||
template <typename TInput, typename TCalc = TInput>
|
||||
class standard_deviation
|
||||
: public private_standard_deviation::standard_deviation_types<TInput, TCalc>
|
||||
, public etl::binary_function<TInput, TInput, void>
|
||||
{
|
||||
public:
|
||||
|
||||
//*********************************
|
||||
/// Constructor.
|
||||
//*********************************
|
||||
standard_deviation()
|
||||
: recalulate(true)
|
||||
{
|
||||
this->clear();
|
||||
}
|
||||
|
||||
//*********************************
|
||||
/// Constructor.
|
||||
//*********************************
|
||||
template <typename TIterator>
|
||||
standard_deviation(TIterator first, TIterator last)
|
||||
: recalulate(true)
|
||||
{
|
||||
this->clear();
|
||||
add(first, last);
|
||||
}
|
||||
|
||||
//*********************************
|
||||
/// Add a pair of values.
|
||||
//*********************************
|
||||
void add(TInput value)
|
||||
{
|
||||
sum_of_squares += TCalc(value * value);
|
||||
sum += TCalc(value);
|
||||
++counter;
|
||||
recalulate = true;
|
||||
}
|
||||
|
||||
//*********************************
|
||||
/// Add a range.
|
||||
//*********************************
|
||||
template <typename TIterator>
|
||||
typename etl::enable_if<!etl::is_same<TIterator, TInput>::value, void>::type
|
||||
add(TIterator first, TIterator last)
|
||||
{
|
||||
while (first != last)
|
||||
{
|
||||
add(*first++);
|
||||
}
|
||||
}
|
||||
|
||||
//*********************************
|
||||
/// operator ()
|
||||
/// Add a pair of values.
|
||||
//*********************************
|
||||
void operator ()(TInput value)
|
||||
{
|
||||
add(value);
|
||||
}
|
||||
|
||||
//*********************************
|
||||
/// operator ()
|
||||
/// Add a range.
|
||||
//*********************************
|
||||
template <typename TIterator>
|
||||
typename etl::enable_if<!etl::is_same<TIterator, TInput>::value, void>::type
|
||||
operator ()(TIterator first, TIterator last)
|
||||
{
|
||||
add(first, last);
|
||||
}
|
||||
|
||||
//*********************************
|
||||
/// Get the standard_deviation.
|
||||
//*********************************
|
||||
double get_standard_deviation()
|
||||
{
|
||||
if (recalulate)
|
||||
{
|
||||
standard_deviation_value = 0.0;
|
||||
|
||||
if (counter != 0)
|
||||
{
|
||||
double n = double(counter);
|
||||
|
||||
double mean = sum / n;
|
||||
|
||||
double stddev_squared = (sum_of_squares / n) - (mean * mean);
|
||||
|
||||
if (stddev_squared > 0)
|
||||
{
|
||||
standard_deviation_value = sqrt(stddev_squared);
|
||||
}
|
||||
}
|
||||
|
||||
recalulate = false;
|
||||
}
|
||||
|
||||
return standard_deviation_value;
|
||||
}
|
||||
|
||||
//*********************************
|
||||
/// Get the standard_deviation.
|
||||
//*********************************
|
||||
operator double()
|
||||
{
|
||||
return get_standard_deviation();
|
||||
}
|
||||
|
||||
//*********************************
|
||||
/// Get the total number added entries.
|
||||
//*********************************
|
||||
size_t count() const
|
||||
{
|
||||
return size_t(counter);
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
double standard_deviation_value;
|
||||
bool recalulate;
|
||||
};
|
||||
}
|
||||
|
||||
#endif
|
||||
@ -124,7 +124,7 @@ namespace
|
||||
correlation_result = correlation1;
|
||||
CHECK_CLOSE(-1.0, correlation_result, 0.1);
|
||||
covariance_result = correlation1.get_covariance();
|
||||
CHECK_CLOSE(-82.5, covariance_result, 0.1);
|
||||
CHECK_CLOSE(-8.25, covariance_result, 0.1);
|
||||
|
||||
// Zero correlation
|
||||
etl::correlation<char, int32_t> correlation2(input_c.begin(), input_c.end(), input_c_flat.begin());
|
||||
@ -138,7 +138,7 @@ namespace
|
||||
correlation_result = correlation3;
|
||||
CHECK_CLOSE(1.0, correlation_result, 0.1);
|
||||
covariance_result = correlation3.get_covariance();
|
||||
CHECK_CLOSE(82.5, covariance_result, 0.1);
|
||||
CHECK_CLOSE(8.25, covariance_result, 0.1);
|
||||
}
|
||||
|
||||
//*************************************************************************
|
||||
@ -152,7 +152,7 @@ namespace
|
||||
correlation_result = correlation1;
|
||||
CHECK_CLOSE(-1.0, correlation_result, 0.1);
|
||||
covariance_result = correlation1.get_covariance();
|
||||
CHECK_CLOSE(-82.5, covariance_result, 0.1);
|
||||
CHECK_CLOSE(-8.25, covariance_result, 0.1);
|
||||
|
||||
// Zero correlation
|
||||
etl::correlation<float> correlation2(input_f.begin(), input_f.end(), input_f_flat.begin());
|
||||
@ -166,7 +166,7 @@ namespace
|
||||
correlation_result = correlation3;
|
||||
CHECK_CLOSE(1.0, correlation_result, 0.1);
|
||||
covariance_result = correlation3.get_covariance();
|
||||
CHECK_CLOSE(82.5, covariance_result, 0.1);
|
||||
CHECK_CLOSE(8.25, covariance_result, 0.1);
|
||||
}
|
||||
|
||||
//*************************************************************************
|
||||
@ -180,7 +180,7 @@ namespace
|
||||
correlation_result = correlation1;
|
||||
CHECK_CLOSE(-1.0, correlation_result, 0.1);
|
||||
covariance_result = correlation1.get_covariance();
|
||||
CHECK_CLOSE(-82.5, covariance_result, 0.1);
|
||||
CHECK_CLOSE(-8.25, covariance_result, 0.1);
|
||||
|
||||
// Zero correlation
|
||||
etl::correlation<double> correlation2(input_d.begin(), input_d.end(), input_d_flat.begin());
|
||||
@ -194,7 +194,7 @@ namespace
|
||||
correlation_result = correlation3;
|
||||
CHECK_CLOSE(1.0, correlation_result, 0.1);
|
||||
covariance_result = correlation3.get_covariance();
|
||||
CHECK_CLOSE(82.5, covariance_result, 0.1);
|
||||
CHECK_CLOSE(8.25, covariance_result, 0.1);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@ -121,7 +121,7 @@ namespace
|
||||
// Negative covariance.
|
||||
etl::covariance<char, int32_t> covariance1(input_c.begin(), input_c.end(), input_c_inv.begin());
|
||||
covariance_result = covariance1.get_covariance();
|
||||
CHECK_CLOSE(-82.5, covariance_result, 0.1);
|
||||
CHECK_CLOSE(-8.25, covariance_result, 0.1);
|
||||
|
||||
// Zero covariance
|
||||
etl::covariance<char, int32_t> covariance2(input_c.begin(), input_c.end(), input_c_flat.begin());
|
||||
@ -131,7 +131,7 @@ namespace
|
||||
// Positive covariance.
|
||||
etl::covariance<char, int32_t> covariance3(input_c.begin(), input_c.end(), input_c.begin());
|
||||
covariance_result = covariance3.get_covariance();
|
||||
CHECK_CLOSE(82.5, covariance_result, 0.1);
|
||||
CHECK_CLOSE(8.25, covariance_result, 0.1);
|
||||
}
|
||||
|
||||
//*************************************************************************
|
||||
@ -142,7 +142,7 @@ namespace
|
||||
// Negative covariance.
|
||||
etl::covariance<float> covariance1(input_f.begin(), input_f.end(), input_f_inv.begin());
|
||||
covariance_result = covariance1.get_covariance();
|
||||
CHECK_CLOSE(-82.5, covariance_result, 0.1);
|
||||
CHECK_CLOSE(-8.25, covariance_result, 0.1);
|
||||
|
||||
// Zero covariance
|
||||
etl::covariance<float> covariance2(input_f.begin(), input_f.end(), input_f_flat.begin());
|
||||
@ -152,7 +152,7 @@ namespace
|
||||
// Positive covariance.
|
||||
etl::covariance<float> covariance3(input_f.begin(), input_f.end(), input_f.begin());
|
||||
covariance_result = covariance3.get_covariance();
|
||||
CHECK_CLOSE(82.5, covariance_result, 0.1);
|
||||
CHECK_CLOSE(8.25, covariance_result, 0.1);
|
||||
}
|
||||
|
||||
//*************************************************************************
|
||||
@ -163,7 +163,7 @@ namespace
|
||||
// Negative covariance.
|
||||
etl::covariance<double> covariance1(input_d.begin(), input_d.end(), input_d_inv.begin());
|
||||
covariance_result = covariance1.get_covariance();
|
||||
CHECK_CLOSE(-82.5, covariance_result, 0.1);
|
||||
CHECK_CLOSE(-8.25, covariance_result, 0.1);
|
||||
|
||||
// Zero covariance
|
||||
etl::covariance<double> covariance2(input_d.begin(), input_d.end(), input_d_flat.begin());
|
||||
@ -173,7 +173,7 @@ namespace
|
||||
// Positive covariance.
|
||||
etl::covariance<double> covariance3(input_d.begin(), input_d.end(), input_d.begin());
|
||||
covariance_result = covariance3.get_covariance();
|
||||
CHECK_CLOSE(82.5, covariance_result, 0.1);
|
||||
CHECK_CLOSE(8.25, covariance_result, 0.1);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
119
test/test_standard_deviation.cpp
Normal file
119
test/test_standard_deviation.cpp
Normal file
@ -0,0 +1,119 @@
|
||||
/******************************************************************************
|
||||
The MIT License(MIT)
|
||||
|
||||
Embedded Template Library.
|
||||
https://github.com/ETLCPP/etl
|
||||
https://www.etlcpp.com
|
||||
|
||||
Copyright(c) 2021 jwellbelove
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files(the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and / or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions :
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
******************************************************************************/
|
||||
|
||||
#include "unit_test_framework.h"
|
||||
|
||||
#include "etl/standard_deviation.h"
|
||||
|
||||
#include <array>
|
||||
|
||||
namespace
|
||||
{
|
||||
std::array<char, 10> input_c
|
||||
{
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9
|
||||
};
|
||||
|
||||
//*********************************
|
||||
std::array<float, 10> input_f
|
||||
{
|
||||
0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f
|
||||
};
|
||||
|
||||
//*********************************
|
||||
std::array<double, 10> input_d
|
||||
{
|
||||
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0
|
||||
};
|
||||
|
||||
SUITE(test_standard_deviation)
|
||||
{
|
||||
//*************************************************************************
|
||||
TEST(test_char_standard_deviation_default_constuctor)
|
||||
{
|
||||
etl::standard_deviation<char, int32_t> standard_deviation;
|
||||
|
||||
double standard_deviation_result = standard_deviation;
|
||||
|
||||
CHECK_EQUAL(0.0, standard_deviation_result);
|
||||
}
|
||||
|
||||
//*************************************************************************
|
||||
TEST(test_float_standard_deviation_default_constuctor)
|
||||
{
|
||||
etl::standard_deviation<float> standard_deviation;
|
||||
|
||||
double standard_deviation_result = standard_deviation;
|
||||
|
||||
CHECK_EQUAL(0.0, standard_deviation_result);
|
||||
}
|
||||
|
||||
//*************************************************************************
|
||||
TEST(test_double_standard_deviation_default_constuctor)
|
||||
{
|
||||
etl::standard_deviation<double> standard_deviation;
|
||||
|
||||
double standard_deviation_result = standard_deviation;
|
||||
|
||||
CHECK_EQUAL(0.0, standard_deviation_result);
|
||||
}
|
||||
|
||||
//*************************************************************************
|
||||
TEST(test_char_standard_deviation_constuctor)
|
||||
{
|
||||
double standard_deviation_result;
|
||||
|
||||
// Negative standard_deviation.
|
||||
etl::standard_deviation<char, int32_t> standard_deviation1(input_c.begin(), input_c.end());
|
||||
standard_deviation_result = standard_deviation1.get_standard_deviation();
|
||||
CHECK_CLOSE(2.872281323269, standard_deviation_result, 0.1);
|
||||
}
|
||||
|
||||
//*************************************************************************
|
||||
TEST(test_float_standard_deviation_constuctor)
|
||||
{
|
||||
double standard_deviation_result;
|
||||
|
||||
// Negative standard_deviation.
|
||||
etl::standard_deviation<float> standard_deviation1(input_f.begin(), input_f.end());
|
||||
standard_deviation_result = standard_deviation1.get_standard_deviation();
|
||||
CHECK_CLOSE(2.872281323269, standard_deviation_result, 0.1);
|
||||
}
|
||||
|
||||
//*************************************************************************
|
||||
TEST(test_double_standard_deviation_constuctor)
|
||||
{
|
||||
double standard_deviation_result;
|
||||
|
||||
// Negative standard_deviation.
|
||||
etl::standard_deviation<double> standard_deviation1(input_d.begin(), input_d.end());
|
||||
standard_deviation_result = standard_deviation1.get_standard_deviation();
|
||||
CHECK_CLOSE(2.872281323269, standard_deviation_result, 0.1);
|
||||
}
|
||||
};
|
||||
}
|
||||
@ -4613,6 +4613,7 @@
|
||||
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
|
||||
</ClCompile>
|
||||
<ClCompile Include="..\test_span.cpp" />
|
||||
<ClCompile Include="..\test_standard_deviation.cpp" />
|
||||
<ClCompile Include="..\test_state_chart.cpp" />
|
||||
<ClCompile Include="..\test_smallest.cpp" />
|
||||
<ClCompile Include="..\test_stack.cpp" />
|
||||
|
||||
@ -2414,6 +2414,9 @@
|
||||
<ClCompile Include="..\test_covariance.cpp">
|
||||
<Filter>Source Files</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="..\test_standard_deviation.cpp">
|
||||
<Filter>Source Files</Filter>
|
||||
</ClCompile>
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<None Include="..\..\library.properties">
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user