diff --git a/include/chaiscript/language/chaiscript_eval.hpp b/include/chaiscript/language/chaiscript_eval.hpp index f59b34ce..fe84a161 100644 --- a/include/chaiscript/language/chaiscript_eval.hpp +++ b/include/chaiscript/language/chaiscript_eval.hpp @@ -1034,6 +1034,11 @@ namespace chaiscript { {"*", Operators::Opers::product, true}, {"/", Operators::Opers::quotient, true}, {"%", Operators::Opers::remainder, true}, + {"<<", Operators::Opers::shift_left, true}, + {">>", Operators::Opers::shift_right, true}, + {"&", Operators::Opers::bitwise_and, true}, + {"|", Operators::Opers::bitwise_or, true}, + {"^", Operators::Opers::bitwise_xor, true}, {"<", Operators::Opers::less_than, false}, {">", Operators::Opers::greater_than, false}, {"<=", Operators::Opers::less_than_equal, false}, @@ -1042,11 +1047,40 @@ namespace chaiscript { {"!=", Operators::Opers::not_equal, false}, }; + const auto op_exists_for_base_type = [&t_ss, &base_type_name](const char *op_name) { + std::atomic_uint_fast32_t loc{0}; + const auto [func_loc, funcs] = t_ss->get_function(op_name, loc); + if (!funcs || funcs->empty()) { + return false; + } + + std::atomic_uint_fast32_t ctor_loc{0}; + Boxed_Value test_val; + try { + const std::array empty_params{}; + test_val = t_ss->call_function(base_type_name, ctor_loc, + Function_Params(empty_params), t_ss.conversions()); + } catch (...) { + return false; + } + + const std::array test_params{test_val, test_val}; + const Function_Params fp(test_params); + for (const auto &func : *funcs) { + if (func->call_match(fp, t_ss.conversions())) { + return true; + } + } + return false; + }; + for (const auto &op : ops) { - t_ss->add( - chaiscript::make_shared( - new_type_name, std::string(op.name), op.oper, op.rewrap, engine), - op.name); + if (op_exists_for_base_type(op.name)) { + t_ss->add( + chaiscript::make_shared( + new_type_name, std::string(op.name), op.oper, op.rewrap, engine), + op.name); + } } return void_var(); diff --git a/unittests/strong_typedef.chai b/unittests/strong_typedef.chai index c29a0f5d..66a4140d 100644 --- a/unittests/strong_typedef.chai +++ b/unittests/strong_typedef.chai @@ -89,6 +89,21 @@ assert_equal(Meters(5) <= Meters(5), true) assert_equal(Meters(3) >= Meters(3), true) assert_equal(Meters(3) >= Meters(5), false) +// --- Bitwise and shift operators --- +assert_equal(to_underlying(Meters(6) & Meters(3)), 2) +assert_equal(to_underlying(Meters(6) | Meters(3)), 7) +assert_equal(to_underlying(Meters(6) ^ Meters(3)), 5) +assert_equal(to_underlying(Meters(5) << Meters(2)), 20) +assert_equal(to_underlying(Meters(12) >> Meters(1)), 6) + +// Bitwise results are strongly typed +try { + takes_int(Meters(6) & Meters(3)) + assert_equal(true, false) +} catch(e) { + // Expected: result is Meters, not int +} + // --- Strong typedef over string --- using StrongString = string @@ -103,13 +118,31 @@ def takes_strong_string(StrongString ss) { } takes_strong_string(ss_cat) -// StrongString * StrongString -> error (no * for strings) +// Operators not supported by the underlying type are not registered try { var bad = ss1 * ss2 assert_equal(true, false) } catch(e) { // Expected: no * operator for strings } +try { + var bad = ss1 - ss2 + assert_equal(true, false) +} catch(e) { + // Expected: no - operator for strings +} +try { + var bad = ss1 / ss2 + assert_equal(true, false) +} catch(e) { + // Expected: no / operator for strings +} +try { + var bad = ss1 % ss2 + assert_equal(true, false) +} catch(e) { + // Expected: no % operator for strings +} // Comparison on StrongString assert_equal(StrongString("abc") < StrongString("def"), true)