Added etl::standard_deviation + corrected covariance

This commit is contained in:
John Wellbelove 2021-04-05 20:31:20 +01:00
parent 3b367ce2b5
commit 203f082c91
8 changed files with 373 additions and 16 deletions

View File

@ -239,8 +239,8 @@ namespace etl
double mean1 = sum1 / n;
double mean2 = sum2 / n;
double stddev1_squared = sum_of_squares1 - (n * mean1 * mean1);
double stddev2_squared = sum_of_squares2 - (n * mean2 * mean2);
double stddev1_squared = (sum_of_squares1 / n) - (mean1 * mean1);
double stddev2_squared = (sum_of_squares2 / n) - (mean2 * mean2);
double stddev1 = 0.0;
double stddev2 = 0.0;
@ -255,7 +255,7 @@ namespace etl
stddev2 = sqrt(stddev2_squared);
}
covariance_value = inner_product - (n * mean1 * mean2);
covariance_value = (inner_product / n) - (mean1 * mean2);
if ((stddev1 > 0.0) && (stddev2 > 0.0))
{

View File

@ -202,7 +202,7 @@ namespace etl
double mean1 = sum1 / n;
double mean2 = sum2 / n;
covariance_value = inner_product - (n * mean1 * mean2);
covariance_value = (inner_product / n) - (mean1 * mean2);
recalulate = false;
}

View File

@ -0,0 +1,234 @@
///\file
/******************************************************************************
The MIT License(MIT)
Embedded Template Library.
https://github.com/ETLCPP/etl
https://www.etlcpp.com
Copyright(c) 2021 jwellbelove
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files(the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and / or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions :
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
******************************************************************************/
#ifndef ETL_STANDARD_DEVIATION_INCLUDED
#define ETL_STANDARD_DEVIATION_INCLUDED
#include "functional.h"
#include "type_traits.h"
#include <math.h>
#include <stdint.h>
namespace etl
{
namespace private_standard_deviation
{
//***************************************************************************
/// Types for generic covariance.
//***************************************************************************
template <typename TInput, typename TCalc>
struct standard_deviation_types
{
TCalc sum_of_squares;
TCalc sum;
uint32_t counter;
//*********************************
/// Clear the histogram.
//*********************************
void clear()
{
sum_of_squares = TCalc(0);
sum = TCalc(0);
counter = 0U;
}
};
//***************************************************************************
/// Types for float covariance.
//***************************************************************************
template <typename TCalc>
struct standard_deviation_types<float, TCalc>
{
float sum_of_squares;
float sum;
uint32_t counter;
//*********************************
/// Clear the histogram.
//*********************************
void clear()
{
sum_of_squares = float(0);
sum = float(0);
counter = 0U;
}
};
//***************************************************************************
/// Types for double covariance.
//***************************************************************************
template <typename TCalc>
struct standard_deviation_types<double, TCalc>
{
double sum_of_squares;
double sum;
uint32_t counter;
//*********************************
/// Clear the histogram.
//*********************************
void clear()
{
sum_of_squares = double(0);
sum = double(0);
counter = 0U;
}
};
}
//***************************************************************************
/// Constructor.
//***************************************************************************
template <typename TInput, typename TCalc = TInput>
class standard_deviation
: public private_standard_deviation::standard_deviation_types<TInput, TCalc>
, public etl::binary_function<TInput, TInput, void>
{
public:
//*********************************
/// Constructor.
//*********************************
standard_deviation()
: recalulate(true)
{
this->clear();
}
//*********************************
/// Constructor.
//*********************************
template <typename TIterator>
standard_deviation(TIterator first, TIterator last)
: recalulate(true)
{
this->clear();
add(first, last);
}
//*********************************
/// Add a pair of values.
//*********************************
void add(TInput value)
{
sum_of_squares += TCalc(value * value);
sum += TCalc(value);
++counter;
recalulate = true;
}
//*********************************
/// Add a range.
//*********************************
template <typename TIterator>
typename etl::enable_if<!etl::is_same<TIterator, TInput>::value, void>::type
add(TIterator first, TIterator last)
{
while (first != last)
{
add(*first++);
}
}
//*********************************
/// operator ()
/// Add a pair of values.
//*********************************
void operator ()(TInput value)
{
add(value);
}
//*********************************
/// operator ()
/// Add a range.
//*********************************
template <typename TIterator>
typename etl::enable_if<!etl::is_same<TIterator, TInput>::value, void>::type
operator ()(TIterator first, TIterator last)
{
add(first, last);
}
//*********************************
/// Get the standard_deviation.
//*********************************
double get_standard_deviation()
{
if (recalulate)
{
standard_deviation_value = 0.0;
if (counter != 0)
{
double n = double(counter);
double mean = sum / n;
double stddev_squared = (sum_of_squares / n) - (mean * mean);
if (stddev_squared > 0)
{
standard_deviation_value = sqrt(stddev_squared);
}
}
recalulate = false;
}
return standard_deviation_value;
}
//*********************************
/// Get the standard_deviation.
//*********************************
operator double()
{
return get_standard_deviation();
}
//*********************************
/// Get the total number added entries.
//*********************************
size_t count() const
{
return size_t(counter);
}
private:
double standard_deviation_value;
bool recalulate;
};
}
#endif

View File

@ -124,7 +124,7 @@ namespace
correlation_result = correlation1;
CHECK_CLOSE(-1.0, correlation_result, 0.1);
covariance_result = correlation1.get_covariance();
CHECK_CLOSE(-82.5, covariance_result, 0.1);
CHECK_CLOSE(-8.25, covariance_result, 0.1);
// Zero correlation
etl::correlation<char, int32_t> correlation2(input_c.begin(), input_c.end(), input_c_flat.begin());
@ -138,7 +138,7 @@ namespace
correlation_result = correlation3;
CHECK_CLOSE(1.0, correlation_result, 0.1);
covariance_result = correlation3.get_covariance();
CHECK_CLOSE(82.5, covariance_result, 0.1);
CHECK_CLOSE(8.25, covariance_result, 0.1);
}
//*************************************************************************
@ -152,7 +152,7 @@ namespace
correlation_result = correlation1;
CHECK_CLOSE(-1.0, correlation_result, 0.1);
covariance_result = correlation1.get_covariance();
CHECK_CLOSE(-82.5, covariance_result, 0.1);
CHECK_CLOSE(-8.25, covariance_result, 0.1);
// Zero correlation
etl::correlation<float> correlation2(input_f.begin(), input_f.end(), input_f_flat.begin());
@ -166,7 +166,7 @@ namespace
correlation_result = correlation3;
CHECK_CLOSE(1.0, correlation_result, 0.1);
covariance_result = correlation3.get_covariance();
CHECK_CLOSE(82.5, covariance_result, 0.1);
CHECK_CLOSE(8.25, covariance_result, 0.1);
}
//*************************************************************************
@ -180,7 +180,7 @@ namespace
correlation_result = correlation1;
CHECK_CLOSE(-1.0, correlation_result, 0.1);
covariance_result = correlation1.get_covariance();
CHECK_CLOSE(-82.5, covariance_result, 0.1);
CHECK_CLOSE(-8.25, covariance_result, 0.1);
// Zero correlation
etl::correlation<double> correlation2(input_d.begin(), input_d.end(), input_d_flat.begin());
@ -194,7 +194,7 @@ namespace
correlation_result = correlation3;
CHECK_CLOSE(1.0, correlation_result, 0.1);
covariance_result = correlation3.get_covariance();
CHECK_CLOSE(82.5, covariance_result, 0.1);
CHECK_CLOSE(8.25, covariance_result, 0.1);
}
};
}

View File

@ -121,7 +121,7 @@ namespace
// Negative covariance.
etl::covariance<char, int32_t> covariance1(input_c.begin(), input_c.end(), input_c_inv.begin());
covariance_result = covariance1.get_covariance();
CHECK_CLOSE(-82.5, covariance_result, 0.1);
CHECK_CLOSE(-8.25, covariance_result, 0.1);
// Zero covariance
etl::covariance<char, int32_t> covariance2(input_c.begin(), input_c.end(), input_c_flat.begin());
@ -131,7 +131,7 @@ namespace
// Positive covariance.
etl::covariance<char, int32_t> covariance3(input_c.begin(), input_c.end(), input_c.begin());
covariance_result = covariance3.get_covariance();
CHECK_CLOSE(82.5, covariance_result, 0.1);
CHECK_CLOSE(8.25, covariance_result, 0.1);
}
//*************************************************************************
@ -142,7 +142,7 @@ namespace
// Negative covariance.
etl::covariance<float> covariance1(input_f.begin(), input_f.end(), input_f_inv.begin());
covariance_result = covariance1.get_covariance();
CHECK_CLOSE(-82.5, covariance_result, 0.1);
CHECK_CLOSE(-8.25, covariance_result, 0.1);
// Zero covariance
etl::covariance<float> covariance2(input_f.begin(), input_f.end(), input_f_flat.begin());
@ -152,7 +152,7 @@ namespace
// Positive covariance.
etl::covariance<float> covariance3(input_f.begin(), input_f.end(), input_f.begin());
covariance_result = covariance3.get_covariance();
CHECK_CLOSE(82.5, covariance_result, 0.1);
CHECK_CLOSE(8.25, covariance_result, 0.1);
}
//*************************************************************************
@ -163,7 +163,7 @@ namespace
// Negative covariance.
etl::covariance<double> covariance1(input_d.begin(), input_d.end(), input_d_inv.begin());
covariance_result = covariance1.get_covariance();
CHECK_CLOSE(-82.5, covariance_result, 0.1);
CHECK_CLOSE(-8.25, covariance_result, 0.1);
// Zero covariance
etl::covariance<double> covariance2(input_d.begin(), input_d.end(), input_d_flat.begin());
@ -173,7 +173,7 @@ namespace
// Positive covariance.
etl::covariance<double> covariance3(input_d.begin(), input_d.end(), input_d.begin());
covariance_result = covariance3.get_covariance();
CHECK_CLOSE(82.5, covariance_result, 0.1);
CHECK_CLOSE(8.25, covariance_result, 0.1);
}
};
}

View File

@ -0,0 +1,119 @@
/******************************************************************************
The MIT License(MIT)
Embedded Template Library.
https://github.com/ETLCPP/etl
https://www.etlcpp.com
Copyright(c) 2021 jwellbelove
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files(the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and / or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions :
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
******************************************************************************/
#include "unit_test_framework.h"
#include "etl/standard_deviation.h"
#include <array>
namespace
{
std::array<char, 10> input_c
{
0, 1, 2, 3, 4, 5, 6, 7, 8, 9
};
//*********************************
std::array<float, 10> input_f
{
0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f
};
//*********************************
std::array<double, 10> input_d
{
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0
};
SUITE(test_standard_deviation)
{
//*************************************************************************
TEST(test_char_standard_deviation_default_constuctor)
{
etl::standard_deviation<char, int32_t> standard_deviation;
double standard_deviation_result = standard_deviation;
CHECK_EQUAL(0.0, standard_deviation_result);
}
//*************************************************************************
TEST(test_float_standard_deviation_default_constuctor)
{
etl::standard_deviation<float> standard_deviation;
double standard_deviation_result = standard_deviation;
CHECK_EQUAL(0.0, standard_deviation_result);
}
//*************************************************************************
TEST(test_double_standard_deviation_default_constuctor)
{
etl::standard_deviation<double> standard_deviation;
double standard_deviation_result = standard_deviation;
CHECK_EQUAL(0.0, standard_deviation_result);
}
//*************************************************************************
TEST(test_char_standard_deviation_constuctor)
{
double standard_deviation_result;
// Negative standard_deviation.
etl::standard_deviation<char, int32_t> standard_deviation1(input_c.begin(), input_c.end());
standard_deviation_result = standard_deviation1.get_standard_deviation();
CHECK_CLOSE(2.872281323269, standard_deviation_result, 0.1);
}
//*************************************************************************
TEST(test_float_standard_deviation_constuctor)
{
double standard_deviation_result;
// Negative standard_deviation.
etl::standard_deviation<float> standard_deviation1(input_f.begin(), input_f.end());
standard_deviation_result = standard_deviation1.get_standard_deviation();
CHECK_CLOSE(2.872281323269, standard_deviation_result, 0.1);
}
//*************************************************************************
TEST(test_double_standard_deviation_constuctor)
{
double standard_deviation_result;
// Negative standard_deviation.
etl::standard_deviation<double> standard_deviation1(input_d.begin(), input_d.end());
standard_deviation_result = standard_deviation1.get_standard_deviation();
CHECK_CLOSE(2.872281323269, standard_deviation_result, 0.1);
}
};
}

View File

@ -4613,6 +4613,7 @@
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
</ClCompile>
<ClCompile Include="..\test_span.cpp" />
<ClCompile Include="..\test_standard_deviation.cpp" />
<ClCompile Include="..\test_state_chart.cpp" />
<ClCompile Include="..\test_smallest.cpp" />
<ClCompile Include="..\test_stack.cpp" />

View File

@ -2414,6 +2414,9 @@
<ClCompile Include="..\test_covariance.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="..\test_standard_deviation.cpp">
<Filter>Source Files</Filter>
</ClCompile>
</ItemGroup>
<ItemGroup>
<None Include="..\..\library.properties">