Merge pull request #296 from dalle/dalle/float16

16-bit float support
This commit is contained in:
Daniel Lemire 2025-02-06 19:41:37 -05:00 committed by GitHub
commit 7a5ee5af60
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 988 additions and 120 deletions

View File

@ -107,9 +107,9 @@ The library seeks to follow the C++17 (see
[28.2.3.(6.1)](https://eel.is/c++draft/charconv.from.chars#6.1)) specification.
* The `from_chars` function does not skip leading white-space characters (unless
`fast_float::chars_format::chars_format` is set).
`fast_float::chars_format::skip_white_space` is set).
* [A leading `+` sign](https://en.cppreference.com/w/cpp/utility/from_chars) is
forbidden (unless `fast_float::chars_format::skip_white_space` is set).
forbidden (unless `fast_float::chars_format::allow_leading_plus` is set).
* It is generally impossible to represent a decimal value exactly as a binary
floating-point number (`float` and `double` types). We seek the nearest value.
We round to an even mantissa when we are in-between two binary floating-point
@ -118,8 +118,8 @@ The library seeks to follow the C++17 (see
Furthermore, we have the following restrictions:
* We support `float` and `double`, but not `long double`. We also support
fixed-width floating-point types such as `std::float32_t` and
`std::float64_t`.
fixed-width floating-point types such as `std::float64_t`, `std::float32_t`,
`std::float16_t`, and `std::bfloat16_t`.
* We only support the decimal format: we do not support hexadecimal strings.
* For values that are either very large or very small (e.g., `1e9999`), we
represent it using the infinity or negative infinity value and the returned
@ -241,7 +241,8 @@ constexpr double constexptest() {
## C++23: Fixed width floating-point types
The library also supports fixed-width floating-point types such as
`std::float32_t` and `std::float64_t`. E.g., you can write:
`std::float64_t`, `std::float32_t`, `std::float16_t`, and `std::bfloat16_t`.
E.g., you can write:
```C++
std::float32_t result;

View File

@ -5,6 +5,7 @@
#include <cstdint>
#include <cassert>
#include <cstring>
#include <limits>
#include <type_traits>
#include <system_error>
#ifdef __has_include
@ -221,15 +222,21 @@ fastfloat_really_inline constexpr bool cpp20_and_in_constexpr() {
// Compile-time trait: `value` is true when T is one of the floating-point
// types accepted here. `float` and `double` are always tested; each
// std::float64_t / std::float32_t / std::float16_t / std::bfloat16_t
// alternative is only tested when the matching __STDCPP_*_T__ macro is
// defined.
// NOTE(review): this span is taken from a rendered diff hunk, so the
// pre-change and post-change lines appear one after the other — confirm which
// side applies before compiling.
template <typename T>
struct is_supported_float_type
: std::integral_constant<bool, std::is_same<T, float>::value ||
std::is_same<T, double>::value
#ifdef __STDCPP_FLOAT32_T__
|| std::is_same<T, std::float32_t>::value
#endif
: std::integral_constant<
bool, std::is_same<T, double>::value || std::is_same<T, float>::value
#ifdef __STDCPP_FLOAT64_T__
|| std::is_same<T, std::float64_t>::value
|| std::is_same<T, std::float64_t>::value
#endif
> {
#ifdef __STDCPP_FLOAT32_T__
|| std::is_same<T, std::float32_t>::value
#endif
#ifdef __STDCPP_FLOAT16_T__
|| std::is_same<T, std::float16_t>::value
#endif
#ifdef __STDCPP_BFLOAT16_T__
|| std::is_same<T, std::bfloat16_t>::value
#endif
> {
};
template <typename T>
@ -431,25 +438,25 @@ template <typename T, typename U = void> struct binary_format_lookup_tables;
// Per-type description of a binary floating-point format. The struct only
// declares the accessors; the values are supplied by explicit specializations
// for each supported T, and the lookup tables are inherited from
// binary_format_lookup_tables<T>.
// NOTE(review): this span is taken from a rendered diff hunk, so both the
// `static inline constexpr` and the `static constexpr` spelling of every
// declaration are present — confirm which side applies before compiling.
template <typename T> struct binary_format : binary_format_lookup_tables<T> {
// Unsigned integer type used to hold the bit pattern of T.
using equiv_uint = equiv_uint_t<T>;
static inline constexpr int mantissa_explicit_bits();
static inline constexpr int minimum_exponent();
static inline constexpr int infinite_power();
static inline constexpr int sign_index();
static inline constexpr int
static constexpr int mantissa_explicit_bits();
static constexpr int minimum_exponent();
static constexpr int infinite_power();
static constexpr int sign_index();
static constexpr int
min_exponent_fast_path(); // used when fegetround() == FE_TONEAREST
static inline constexpr int max_exponent_fast_path();
static inline constexpr int max_exponent_round_to_even();
static inline constexpr int min_exponent_round_to_even();
static inline constexpr uint64_t max_mantissa_fast_path(int64_t power);
static inline constexpr uint64_t
static constexpr int max_exponent_fast_path();
static constexpr int max_exponent_round_to_even();
static constexpr int min_exponent_round_to_even();
static constexpr uint64_t max_mantissa_fast_path(int64_t power);
static constexpr uint64_t
max_mantissa_fast_path(); // used when fegetround() == FE_TONEAREST
static inline constexpr int largest_power_of_ten();
static inline constexpr int smallest_power_of_ten();
static inline constexpr T exact_power_of_ten(int64_t power);
static inline constexpr size_t max_digits();
static inline constexpr equiv_uint exponent_mask();
static inline constexpr equiv_uint mantissa_mask();
static inline constexpr equiv_uint hidden_bit_mask();
static constexpr int largest_power_of_ten();
static constexpr int smallest_power_of_ten();
static constexpr T exact_power_of_ten(int64_t power);
static constexpr size_t max_digits();
static constexpr equiv_uint exponent_mask();
static constexpr equiv_uint mantissa_mask();
static constexpr equiv_uint hidden_bit_mask();
};
template <typename U> struct binary_format_lookup_tables<double, U> {
@ -622,6 +629,260 @@ inline constexpr uint64_t binary_format<double>::max_mantissa_fast_path() {
return uint64_t(2) << mantissa_explicit_bits();
}
// float specialization of the argument-less overload: returns
// 2^(mantissa_explicit_bits() + 1). The declaration in binary_format notes
// that this overload is the one used when fegetround() == FE_TONEAREST.
template <>
inline constexpr uint64_t binary_format<float>::max_mantissa_fast_path() {
  return uint64_t(1) << (mantissa_explicit_bits() + 1);
}
// credit: Jakub Jelínek
#ifdef __STDCPP_FLOAT16_T__
// Lookup tables for std::float16_t. The extra template parameter U is not
// used inside the struct.
template <typename U> struct binary_format_lookup_tables<std::float16_t, U> {
// powers_of_ten[i] == 10^i as a std::float16_t literal, for i in [0, 4].
static constexpr std::float16_t powers_of_ten[] = {1e0f16, 1e1f16, 1e2f16,
1e3f16, 1e4f16};
// Largest integer value v so that (5**index * v) <= 1<<11.
// 0x800 == 1<<11
// NOTE(review): this table has six entries (index 0..5), one more than
// powers_of_ten; the accessor's comment only promises power <= 4.
static constexpr uint64_t max_mantissa[] = {0x800,
0x800 / 5,
0x800 / (5 * 5),
0x800 / (5 * 5 * 5),
0x800 / (5 * 5 * 5 * 5),
0x800 / (constant_55555)};
};
// Out-of-class definitions of the two static arrays, compiled only when
// FASTFLOAT_DETAIL_MUST_DEFINE_CONSTEXPR_VARIABLE is non-zero (presumably for
// language modes where static constexpr members are not implicitly inline —
// confirm against the macro's definition).
#if FASTFLOAT_DETAIL_MUST_DEFINE_CONSTEXPR_VARIABLE
template <typename U>
constexpr std::float16_t
binary_format_lookup_tables<std::float16_t, U>::powers_of_ten[];
template <typename U>
constexpr uint64_t
binary_format_lookup_tables<std::float16_t, U>::max_mantissa[];
#endif
// ---------------------------------------------------------------------------
// binary_format<std::float16_t>: explicit specializations of every accessor.
// The definition order is significant: mantissa_explicit_bits() is specialized
// before max_mantissa_fast_path() calls it.
// ---------------------------------------------------------------------------

// Returns powers_of_ten[power]; the table holds 10^0 .. 10^4, so the caller
// must keep power inside that range.
template <>
inline constexpr std::float16_t
binary_format<std::float16_t>::exact_power_of_ten(int64_t power) {
  // Work around clang bug https://godbolt.org/z/zedh7rrhc
  return (void)powers_of_ten[0], powers_of_ten[power];
}

// Bits 10..14 of the 16-bit pattern.
template <>
inline constexpr binary_format<std::float16_t>::equiv_uint
binary_format<std::float16_t>::exponent_mask() {
  return 0x7C00;
}

// Bits 0..9 of the 16-bit pattern.
template <>
inline constexpr binary_format<std::float16_t>::equiv_uint
binary_format<std::float16_t>::mantissa_mask() {
  return 0x03FF;
}

// Bit 10, i.e. the first bit above the stored mantissa.
template <>
inline constexpr binary_format<std::float16_t>::equiv_uint
binary_format<std::float16_t>::hidden_bit_mask() {
  return 0x0400;
}

// Matches the last index of powers_of_ten (10^4).
template <>
inline constexpr int binary_format<std::float16_t>::max_exponent_fast_path() {
  return 4;
}

// Number of mantissa bits stored explicitly; to_float() shifts the biased
// exponent left by this amount.
template <>
inline constexpr int binary_format<std::float16_t>::mantissa_explicit_bits() {
  return 10;
}

// 2^(mantissa_explicit_bits() + 1), i.e. 1 << 11.
template <>
inline constexpr uint64_t
binary_format<std::float16_t>::max_mantissa_fast_path() {
  return uint64_t(1) << (mantissa_explicit_bits() + 1);
}

// Returns max_mantissa[power].
template <>
inline constexpr uint64_t
binary_format<std::float16_t>::max_mantissa_fast_path(int64_t power) {
  // caller is responsible to ensure that
  // power >= 0 && power <= 4
  //
  // Work around clang bug https://godbolt.org/z/zedh7rrhc
  return (void)max_mantissa[0], max_mantissa[power];
}

template <>
inline constexpr int binary_format<std::float16_t>::min_exponent_fast_path() {
  return 0;
}

// NOTE(review): presumably the decimal-exponent window in which a halfway
// case can occur and round-to-even must be applied — confirm against the
// rounding code that consumes these two bounds.
template <>
inline constexpr int
binary_format<std::float16_t>::max_exponent_round_to_even() {
  return 5;
}

template <>
inline constexpr int
binary_format<std::float16_t>::min_exponent_round_to_even() {
  return -22;
}

// NOTE(review): presumably the negated exponent bias — confirm.
template <>
inline constexpr int binary_format<std::float16_t>::minimum_exponent() {
  return -15;
}

// 0x1F == 31: all five exponent bits set.
template <>
inline constexpr int binary_format<std::float16_t>::infinite_power() {
  return 0x1F;
}

// Bit position of the sign; to_float() shifts the `negative` flag here.
template <>
inline constexpr int binary_format<std::float16_t>::sign_index() {
  return 15;
}

// NOTE(review): presumably the decimal exponents beyond which the result is
// always infinity / zero — confirm against the parser's range checks.
template <>
inline constexpr int binary_format<std::float16_t>::largest_power_of_ten() {
  return 4;
}

template <>
inline constexpr int binary_format<std::float16_t>::smallest_power_of_ten() {
  return -27;
}

// NOTE(review): presumably the maximum number of significant decimal digits
// the slow path needs to examine — confirm.
template <>
inline constexpr size_t binary_format<std::float16_t>::max_digits() {
  return 22;
}
#endif // __STDCPP_FLOAT16_T__
// credit: Jakub Jelínek
#ifdef __STDCPP_BFLOAT16_T__
// Lookup tables for std::bfloat16_t. The extra template parameter U is not
// used inside the struct.
template <typename U> struct binary_format_lookup_tables<std::bfloat16_t, U> {
// powers_of_ten[i] == 10^i as a std::bfloat16_t literal, for i in [0, 3].
static constexpr std::bfloat16_t powers_of_ten[] = {1e0bf16, 1e1bf16, 1e2bf16,
1e3bf16};
// Largest integer value v so that (5**index * v) <= 1<<8.
// 0x100 == 1<<8
// NOTE(review): this table has five entries (index 0..4), one more than
// powers_of_ten; the accessor's comment only promises power <= 3.
static constexpr uint64_t max_mantissa[] = {0x100, 0x100 / 5, 0x100 / (5 * 5),
0x100 / (5 * 5 * 5),
0x100 / (5 * 5 * 5 * 5)};
};
// Out-of-class definitions of the two static arrays, compiled only when
// FASTFLOAT_DETAIL_MUST_DEFINE_CONSTEXPR_VARIABLE is non-zero (presumably for
// language modes where static constexpr members are not implicitly inline —
// confirm against the macro's definition).
#if FASTFLOAT_DETAIL_MUST_DEFINE_CONSTEXPR_VARIABLE
template <typename U>
constexpr std::bfloat16_t
binary_format_lookup_tables<std::bfloat16_t, U>::powers_of_ten[];
template <typename U>
constexpr uint64_t
binary_format_lookup_tables<std::bfloat16_t, U>::max_mantissa[];
#endif
// ---------------------------------------------------------------------------
// binary_format<std::bfloat16_t>: explicit specializations of every accessor.
// The definition order is significant: mantissa_explicit_bits() is specialized
// before max_mantissa_fast_path() calls it.
// ---------------------------------------------------------------------------

// Returns powers_of_ten[power]; the table holds 10^0 .. 10^3, so the caller
// must keep power inside that range.
template <>
inline constexpr std::bfloat16_t
binary_format<std::bfloat16_t>::exact_power_of_ten(int64_t power) {
  // Work around clang bug https://godbolt.org/z/zedh7rrhc
  return (void)powers_of_ten[0], powers_of_ten[power];
}

// Matches the last index of powers_of_ten (10^3).
template <>
inline constexpr int binary_format<std::bfloat16_t>::max_exponent_fast_path() {
  return 3;
}

// Bits 7..14 of the 16-bit pattern.
template <>
inline constexpr binary_format<std::bfloat16_t>::equiv_uint
binary_format<std::bfloat16_t>::exponent_mask() {
  return 0x7F80;
}

// Bits 0..6 of the 16-bit pattern.
template <>
inline constexpr binary_format<std::bfloat16_t>::equiv_uint
binary_format<std::bfloat16_t>::mantissa_mask() {
  return 0x007F;
}

// Bit 7, i.e. the first bit above the stored mantissa.
template <>
inline constexpr binary_format<std::bfloat16_t>::equiv_uint
binary_format<std::bfloat16_t>::hidden_bit_mask() {
  return 0x0080;
}

// Number of mantissa bits stored explicitly; to_float() shifts the biased
// exponent left by this amount.
template <>
inline constexpr int binary_format<std::bfloat16_t>::mantissa_explicit_bits() {
  return 7;
}

// 2^(mantissa_explicit_bits() + 1), i.e. 1 << 8.
template <>
inline constexpr uint64_t
binary_format<std::bfloat16_t>::max_mantissa_fast_path() {
  return uint64_t(1) << (mantissa_explicit_bits() + 1);
}

// Returns max_mantissa[power].
template <>
inline constexpr uint64_t
binary_format<std::bfloat16_t>::max_mantissa_fast_path(int64_t power) {
  // caller is responsible to ensure that
  // power >= 0 && power <= 3
  //
  // Work around clang bug https://godbolt.org/z/zedh7rrhc
  return (void)max_mantissa[0], max_mantissa[power];
}

template <>
inline constexpr int binary_format<std::bfloat16_t>::min_exponent_fast_path() {
  return 0;
}

// NOTE(review): presumably the decimal-exponent window in which a halfway
// case can occur and round-to-even must be applied — confirm against the
// rounding code that consumes these two bounds.
template <>
inline constexpr int
binary_format<std::bfloat16_t>::max_exponent_round_to_even() {
  return 3;
}

template <>
inline constexpr int
binary_format<std::bfloat16_t>::min_exponent_round_to_even() {
  return -24;
}

// NOTE(review): presumably the negated exponent bias — confirm.
template <>
inline constexpr int binary_format<std::bfloat16_t>::minimum_exponent() {
  return -127;
}

// 0xFF == 255: all eight exponent bits set.
template <>
inline constexpr int binary_format<std::bfloat16_t>::infinite_power() {
  return 0xFF;
}

// Bit position of the sign; to_float() shifts the `negative` flag here.
template <>
inline constexpr int binary_format<std::bfloat16_t>::sign_index() {
  return 15;
}

// NOTE(review): presumably the decimal exponents beyond which the result is
// always infinity / zero — confirm against the parser's range checks.
template <>
inline constexpr int binary_format<std::bfloat16_t>::largest_power_of_ten() {
  return 38;
}

template <>
inline constexpr int binary_format<std::bfloat16_t>::smallest_power_of_ten() {
  return -60;
}

// NOTE(review): presumably the maximum number of significant decimal digits
// the slow path needs to examine — confirm.
template <>
inline constexpr size_t binary_format<std::bfloat16_t>::max_digits() {
  return 98;
}
#endif // __STDCPP_BFLOAT16_T__
template <>
inline constexpr uint64_t
binary_format<double>::max_mantissa_fast_path(int64_t power) {
@ -632,11 +893,6 @@ binary_format<double>::max_mantissa_fast_path(int64_t power) {
return (void)max_mantissa[0], max_mantissa[power];
}
template <>
inline constexpr uint64_t binary_format<float>::max_mantissa_fast_path() {
return uint64_t(2) << mantissa_explicit_bits();
}
template <>
inline constexpr uint64_t
binary_format<float>::max_mantissa_fast_path(int64_t power) {
@ -726,8 +982,10 @@ fastfloat_really_inline FASTFLOAT_CONSTEXPR20 void
to_float(bool negative, adjusted_mantissa am, T &value) {
using equiv_uint = equiv_uint_t<T>;
equiv_uint word = equiv_uint(am.mantissa);
word |= equiv_uint(am.power2) << binary_format<T>::mantissa_explicit_bits();
word |= equiv_uint(negative) << binary_format<T>::sign_index();
word = equiv_uint(word | equiv_uint(am.power2)
<< binary_format<T>::mantissa_explicit_bits());
word =
equiv_uint(word | equiv_uint(negative) << binary_format<T>::sign_index());
#if FASTFLOAT_HAS_BIT_CAST
value = std::bit_cast<T>(word);
#else
@ -787,6 +1045,7 @@ template <> constexpr char16_t const *str_const_nan<char16_t>() {
// char32_t specialization: returns a pointer to the UTF-32 literal "nan".
template <> constexpr char32_t const *str_const_nan<char32_t>() {
return U"nan";
}
#ifdef __cpp_char8_t
template <> constexpr char8_t const *str_const_nan<char8_t>() {
return u8"nan";
@ -808,6 +1067,7 @@ template <> constexpr char16_t const *str_const_inf<char16_t>() {
// char32_t specialization: returns a pointer to the UTF-32 literal
// "infinity".
template <> constexpr char32_t const *str_const_inf<char32_t>() {
return U"infinity";
}
#ifdef __cpp_char8_t
template <> constexpr char8_t const *str_const_inf<char8_t>() {
return u8"infinity";
@ -881,18 +1141,47 @@ fastfloat_really_inline constexpr uint64_t min_safe_u64(int base) {
// Compile-time sanity checks. For every supported floating-point type, verify
// that (a) its equivalent unsigned integer type has the expected width and
// (b) std::numeric_limits reports the type as IEC 559 (IEEE 754) conforming.
// Checks for the fixed-width std:: types are guarded by the matching
// __STDCPP_*_T__ macro.
// NOTE(review): this span is taken from a rendered diff hunk; the bare
// `#endif` lines and the commented `#endif // ...` lines are the pre- and
// post-change versions of the same directive — confirm which side applies
// before compiling.
static_assert(std::is_same<equiv_uint_t<double>, uint64_t>::value,
"equiv_uint should be uint64_t for double");
static_assert(std::numeric_limits<double>::is_iec559,
"double must fulfill the requirements of IEC 559 (IEEE 754)");
static_assert(std::is_same<equiv_uint_t<float>, uint32_t>::value,
"equiv_uint should be uint32_t for float");
static_assert(std::numeric_limits<float>::is_iec559,
"float must fulfill the requirements of IEC 559 (IEEE 754)");
#ifdef __STDCPP_FLOAT64_T__
static_assert(std::is_same<equiv_uint_t<std::float64_t>, uint64_t>::value,
"equiv_uint should be uint64_t for std::float64_t");
#endif
static_assert(
std::numeric_limits<std::float64_t>::is_iec559,
"std::float64_t must fulfill the requirements of IEC 559 (IEEE 754)");
#endif // __STDCPP_FLOAT64_T__
#ifdef __STDCPP_FLOAT32_T__
static_assert(std::is_same<equiv_uint_t<std::float32_t>, uint32_t>::value,
"equiv_uint should be uint32_t for std::float32_t");
#endif
static_assert(
std::numeric_limits<std::float32_t>::is_iec559,
"std::float32_t must fulfill the requirements of IEC 559 (IEEE 754)");
#endif // __STDCPP_FLOAT32_T__
#ifdef __STDCPP_FLOAT16_T__
static_assert(
std::is_same<binary_format<std::float16_t>::equiv_uint, uint16_t>::value,
"equiv_uint should be uint16_t for std::float16_t");
static_assert(
std::numeric_limits<std::float16_t>::is_iec559,
"std::float16_t must fulfill the requirements of IEC 559 (IEEE 754)");
#endif // __STDCPP_FLOAT16_T__
#ifdef __STDCPP_BFLOAT16_T__
static_assert(
std::is_same<binary_format<std::bfloat16_t>::equiv_uint, uint16_t>::value,
"equiv_uint should be uint16_t for std::bfloat16_t");
static_assert(
std::numeric_limits<std::bfloat16_t>::is_iec559,
"std::bfloat16_t must fulfill the requirements of IEC 559 (IEEE 754)");
#endif // __STDCPP_BFLOAT16_T__
constexpr chars_format operator~(chars_format rhs) noexcept {
using int_type = std::underlying_type<chars_format>::type;

File diff suppressed because it is too large Load Diff

View File

@ -112,7 +112,7 @@ bool large() {
}
int main() {
std::string const input = "3.1416 xyz ";
std::string input = "3.1416 xyz ";
double result;
auto answer =
fast_float::from_chars(input.data(), input.data() + input.size(), result);
@ -121,6 +121,20 @@ int main() {
return EXIT_FAILURE;
}
std::cout << "parsed the number " << result << std::endl;
#ifdef __STDCPP_FLOAT16_T__
  // Smoke test for the std::float16_t overload of fast_float::from_chars.
  // Only built when the toolchain advertises std::float16_t.
  printf("16-bit float\n");
  // Parse as 16-bit float
  std::float16_t parsed_16{};
  // The decimal exponent is far outside anything 16 bits can hold, so
  // result_out_of_range is tolerated below; any other error code is a failure.
  input = "10000e-1452";
  auto fast_float_r16 = fast_float::from_chars(
      input.data(), input.data() + input.size(), parsed_16);
  if (fast_float_r16.ec != std::errc() &&
      fast_float_r16.ec != std::errc::result_out_of_range) {
    std::cerr << "16-bit fast_float parsing failure for: " + input + "\n";
    // main() returns int: `return false;` would convert to 0 and report
    // success to the test harness. Use EXIT_FAILURE like the other checks in
    // this function.
    return EXIT_FAILURE;
  }
  // Widen to float for printing.
  std::cout << "parsed the 16-bit value " << float(parsed_16) << std::endl;
#endif
if (!small()) {
printf("Bug\n");
return EXIT_FAILURE;