diff --git a/include/chaiscript/language/chaiscript_eval.hpp b/include/chaiscript/language/chaiscript_eval.hpp index fe84a161..4f1d1d61 100644 --- a/include/chaiscript/language/chaiscript_eval.hpp +++ b/include/chaiscript/language/chaiscript_eval.hpp @@ -192,6 +192,85 @@ namespace chaiscript { mutable std::atomic_uint_fast32_t m_loc{0}; }; + class Strong_Typedef_Compound_Assign_Op final : public dispatch::Proxy_Function_Base { + public: + Strong_Typedef_Compound_Assign_Op( + std::string t_type_name, + std::string t_op_name, + Operators::Opers t_base_oper, + std::string t_base_op_name, + chaiscript::detail::Dispatch_Engine &t_engine) + : Proxy_Function_Base( + {user_type(), + user_type(), + user_type()}, + 2) + , m_type_name(std::move(t_type_name)) + , m_op_name(std::move(t_op_name)) + , m_base_oper(t_base_oper) + , m_base_op_name(std::move(t_base_op_name)) + , m_engine(t_engine) { + } + + bool operator==(const Proxy_Function_Base &f) const noexcept override { + if (const auto *other = dynamic_cast(&f)) { + return m_type_name == other->m_type_name && m_op_name == other->m_op_name; + } + return false; + } + + bool call_match(const Function_Params &vals, const Type_Conversions_State &t_conversions) const noexcept override { + return vals.size() == 2 + && type_matches(vals[0], t_conversions) + && type_matches(vals[1], t_conversions); + } + + protected: + Boxed_Value do_call(const Function_Params ¶ms, const Type_Conversions_State &t_conversions) const override { + if (!call_match(params, t_conversions)) { + throw chaiscript::exception::guard_error(); + } + + auto &lhs = boxed_cast(params[0], &t_conversions); + const auto &rhs = boxed_cast(params[1], &t_conversions); + const auto lhs_val = lhs.get_attr("__value"); + const auto rhs_val = rhs.get_attr("__value"); + + Boxed_Value result; + if (m_base_oper != Operators::Opers::invalid + && lhs_val.get_type_info().is_arithmetic() + && rhs_val.get_type_info().is_arithmetic()) { + result = Boxed_Number::do_oper(m_base_oper, lhs_val, rhs_val); + } else { + std::array underlying_params{lhs_val, rhs_val}; + result = m_engine.call_function(m_base_op_name, m_loc, Function_Params(underlying_params), t_conversions); + } + + lhs.get_attr("__value") = result; + return params[0]; + } + + private: + bool type_matches(const Boxed_Value &bv, const Type_Conversions_State &t_conversions) const noexcept { + if (!bv.get_type_info().bare_equal(user_type())) { + return false; + } + try { + const auto &d = boxed_cast(bv, &t_conversions); + return d.get_type_name() == m_type_name; + } catch (...) { + return false; + } + } + + std::string m_type_name; + std::string m_op_name; + Operators::Opers m_base_oper; + std::string m_base_op_name; + chaiscript::detail::Dispatch_Engine &m_engine; + mutable std::atomic_uint_fast32_t m_loc{0}; + }; + } // namespace detail template @@ -1047,40 +1126,37 @@ namespace chaiscript { {"!=", Operators::Opers::not_equal, false}, }; - const auto op_exists_for_base_type = [&t_ss, &base_type_name](const char *op_name) { - std::atomic_uint_fast32_t loc{0}; - const auto [func_loc, funcs] = t_ss->get_function(op_name, loc); - if (!funcs || funcs->empty()) { - return false; - } + for (const auto &op : ops) { + t_ss->add( + chaiscript::make_shared( + new_type_name, std::string(op.name), op.oper, op.rewrap, engine), + op.name); + } - std::atomic_uint_fast32_t ctor_loc{0}; - Boxed_Value test_val; - try { - const std::array empty_params{}; - test_val = t_ss->call_function(base_type_name, ctor_loc, - Function_Params(empty_params), t_ss.conversions()); - } catch (...) { - return false; - } - - const std::array test_params{test_val, test_val}; - const Function_Params fp(test_params); - for (const auto &func : *funcs) { - if (func->call_match(fp, t_ss.conversions())) { - return true; - } - } - return false; + struct Compound_Op_Entry { + const char *name; + Operators::Opers base_oper; + const char *base_op_name; }; - for (const auto &op : ops) { - if (op_exists_for_base_type(op.name)) { - t_ss->add( - chaiscript::make_shared( - new_type_name, std::string(op.name), op.oper, op.rewrap, engine), - op.name); - } + static constexpr Compound_Op_Entry compound_ops[] = { + {"+=", Operators::Opers::sum, "+"}, + {"-=", Operators::Opers::difference, "-"}, + {"*=", Operators::Opers::product, "*"}, + {"/=", Operators::Opers::quotient, "/"}, + {"%=", Operators::Opers::remainder, "%"}, + {"<<=", Operators::Opers::shift_left, "<<"}, + {">>=", Operators::Opers::shift_right, ">>"}, + {"&=", Operators::Opers::bitwise_and, "&"}, + {"|=", Operators::Opers::bitwise_or, "|"}, + {"^=", Operators::Opers::bitwise_xor, "^"}, + }; + + for (const auto &op : compound_ops) { + t_ss->add( + chaiscript::make_shared( + new_type_name, std::string(op.name), op.base_oper, std::string(op.base_op_name), engine), + op.name); } return void_var(); diff --git a/unittests/strong_typedef.chai b/unittests/strong_typedef.chai index 66a4140d..73aaa4a9 100644 --- a/unittests/strong_typedef.chai +++ b/unittests/strong_typedef.chai @@ -118,30 +118,30 @@ def takes_strong_string(StrongString ss) { } takes_strong_string(ss_cat) -// Operators not supported by the underlying type are not registered +// Operators not supported by the underlying type error at call time try { var bad = ss1 * ss2 assert_equal(true, false) } catch(e) { - // Expected: no * operator for strings + // Expected: underlying string has no * operator } try { var bad = ss1 - ss2 assert_equal(true, false) } catch(e) { - // Expected: no - operator for strings + // Expected: underlying string has no - operator } try { var bad = ss1 / ss2 assert_equal(true, false) } catch(e) { - // Expected: no / operator for strings + // Expected: underlying string has no / operator } try { var bad = ss1 % ss2 assert_equal(true, false) } catch(e) { - // Expected: no % operator for strings + // Expected: underlying string has no % operator } // Comparison on StrongString @@ -168,3 +168,33 @@ def `[]`(StrongString ss, int offset) { return to_string(to_underlying(ss)[offset]) } assert_equal(StrongString("hello")[1], "e") + +// --- Compound assignment operators --- +var m3 = Meters(10) +m3 += Meters(5) +assert_equal(to_underlying(m3), 15) +measure(m3) + +m3 -= Meters(3) +assert_equal(to_underlying(m3), 12) + +m3 *= Meters(2) +assert_equal(to_underlying(m3), 24) + +m3 /= Meters(4) +assert_equal(to_underlying(m3), 6) + +m3 %= Meters(4) +assert_equal(to_underlying(m3), 2) + +// Compound assignment result is still the strong typedef +var m4 = Meters(10) +m4 += Meters(5) +assert_equal(to_underlying(m4), 15) +measure(m4) + +// Compound assignment on StrongString +var ss3 = StrongString("hello") +ss3 += StrongString(" world") +assert_equal(to_underlying(ss3), "hello world") +takes_strong_string(ss3)