From a5f89326f1917b049de5e842b927b49b3744acd8 Mon Sep 17 00:00:00 2001 From: leftibot Date: Tue, 14 Apr 2026 13:00:15 -0600 Subject: [PATCH] Address review: register all operators unconditionally and add compound assignment operators Remove conditional operator registration (op_exists_for_base_type check) since users could add underlying operators later, and the runtime check was expensive. Operators that fail on the underlying type now error at call time instead of being absent. Add compound assignment operators (*=, +=, -=, /=, %=, <<=, >>=, &=, |=, ^=) via Strong_Typedef_Compound_Assign_Op which computes the base operation and stores the result back in __value. Requested by @lefticus in PR #680 review. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../chaiscript/language/chaiscript_eval.hpp | 138 ++++++++++++++---- unittests/strong_typedef.chai | 40 ++++- 2 files changed, 142 insertions(+), 36 deletions(-) diff --git a/include/chaiscript/language/chaiscript_eval.hpp b/include/chaiscript/language/chaiscript_eval.hpp index fe84a161..4f1d1d61 100644 --- a/include/chaiscript/language/chaiscript_eval.hpp +++ b/include/chaiscript/language/chaiscript_eval.hpp @@ -192,6 +192,85 @@ namespace chaiscript { mutable std::atomic_uint_fast32_t m_loc{0}; }; + class Strong_Typedef_Compound_Assign_Op final : public dispatch::Proxy_Function_Base { + public: + Strong_Typedef_Compound_Assign_Op( + std::string t_type_name, + std::string t_op_name, + Operators::Opers t_base_oper, + std::string t_base_op_name, + chaiscript::detail::Dispatch_Engine &t_engine) + : Proxy_Function_Base( + {user_type(), + user_type(), + user_type()}, + 2) + , m_type_name(std::move(t_type_name)) + , m_op_name(std::move(t_op_name)) + , m_base_oper(t_base_oper) + , m_base_op_name(std::move(t_base_op_name)) + , m_engine(t_engine) { + } + + bool operator==(const Proxy_Function_Base &f) const noexcept override { + if (const auto *other = dynamic_cast(&f)) { + return m_type_name == other->m_type_name && m_op_name == other->m_op_name; + } + return false; + } + + bool call_match(const Function_Params &vals, const Type_Conversions_State &t_conversions) const noexcept override { + return vals.size() == 2 + && type_matches(vals[0], t_conversions) + && type_matches(vals[1], t_conversions); + } + + protected: + Boxed_Value do_call(const Function_Params ¶ms, const Type_Conversions_State &t_conversions) const override { + if (!call_match(params, t_conversions)) { + throw chaiscript::exception::guard_error(); + } + + auto &lhs = boxed_cast(params[0], &t_conversions); + const auto &rhs = boxed_cast(params[1], &t_conversions); + const auto lhs_val = lhs.get_attr("__value"); + const auto rhs_val = rhs.get_attr("__value"); + + Boxed_Value result; + if (m_base_oper != Operators::Opers::invalid + && lhs_val.get_type_info().is_arithmetic() + && rhs_val.get_type_info().is_arithmetic()) { + result = Boxed_Number::do_oper(m_base_oper, lhs_val, rhs_val); + } else { + std::array underlying_params{lhs_val, rhs_val}; + result = m_engine.call_function(m_base_op_name, m_loc, Function_Params(underlying_params), t_conversions); + } + + lhs.get_attr("__value") = result; + return params[0]; + } + + private: + bool type_matches(const Boxed_Value &bv, const Type_Conversions_State &t_conversions) const noexcept { + if (!bv.get_type_info().bare_equal(user_type())) { + return false; + } + try { + const auto &d = boxed_cast(bv, &t_conversions); + return d.get_type_name() == m_type_name; + } catch (...) { + return false; + } + } + + std::string m_type_name; + std::string m_op_name; + Operators::Opers m_base_oper; + std::string m_base_op_name; + chaiscript::detail::Dispatch_Engine &m_engine; + mutable std::atomic_uint_fast32_t m_loc{0}; + }; + } // namespace detail template @@ -1047,40 +1126,37 @@ namespace chaiscript { {"!=", Operators::Opers::not_equal, false}, }; - const auto op_exists_for_base_type = [&t_ss, &base_type_name](const char *op_name) { - std::atomic_uint_fast32_t loc{0}; - const auto [func_loc, funcs] = t_ss->get_function(op_name, loc); - if (!funcs || funcs->empty()) { - return false; - } + for (const auto &op : ops) { + t_ss->add( + chaiscript::make_shared( + new_type_name, std::string(op.name), op.oper, op.rewrap, engine), + op.name); + } - std::atomic_uint_fast32_t ctor_loc{0}; - Boxed_Value test_val; - try { - const std::array empty_params{}; - test_val = t_ss->call_function(base_type_name, ctor_loc, - Function_Params(empty_params), t_ss.conversions()); - } catch (...) { - return false; - } - - const std::array test_params{test_val, test_val}; - const Function_Params fp(test_params); - for (const auto &func : *funcs) { - if (func->call_match(fp, t_ss.conversions())) { - return true; - } - } - return false; + struct Compound_Op_Entry { + const char *name; + Operators::Opers base_oper; + const char *base_op_name; }; - for (const auto &op : ops) { - if (op_exists_for_base_type(op.name)) { - t_ss->add( - chaiscript::make_shared( - new_type_name, std::string(op.name), op.oper, op.rewrap, engine), - op.name); - } + static constexpr Compound_Op_Entry compound_ops[] = { + {"+=", Operators::Opers::sum, "+"}, + {"-=", Operators::Opers::difference, "-"}, + {"*=", Operators::Opers::product, "*"}, + {"/=", Operators::Opers::quotient, "/"}, + {"%=", Operators::Opers::remainder, "%"}, + {"<<=", Operators::Opers::shift_left, "<<"}, + {">>=", Operators::Opers::shift_right, ">>"}, + {"&=", Operators::Opers::bitwise_and, "&"}, + {"|=", Operators::Opers::bitwise_or, "|"}, + {"^=", Operators::Opers::bitwise_xor, "^"}, + }; + + for (const auto &op : compound_ops) { + t_ss->add( + chaiscript::make_shared( + new_type_name, std::string(op.name), op.base_oper, std::string(op.base_op_name), engine), + op.name); } return void_var(); diff --git a/unittests/strong_typedef.chai b/unittests/strong_typedef.chai index 66a4140d..73aaa4a9 100644 --- a/unittests/strong_typedef.chai +++ b/unittests/strong_typedef.chai @@ -118,30 +118,30 @@ def takes_strong_string(StrongString ss) { } takes_strong_string(ss_cat) -// Operators not supported by the underlying type are not registered +// Operators not supported by the underlying type error at call time try { var bad = ss1 * ss2 assert_equal(true, false) } catch(e) { - // Expected: no * operator for strings + // Expected: underlying string has no * operator } try { var bad = ss1 - ss2 assert_equal(true, false) } catch(e) { - // Expected: no - operator for strings + // Expected: underlying string has no - operator } try { var bad = ss1 / ss2 assert_equal(true, false) } catch(e) { - // Expected: no / operator for strings + // Expected: underlying string has no / operator } try { var bad = ss1 % ss2 assert_equal(true, false) } catch(e) { - // Expected: no % operator for strings + // Expected: underlying string has no % operator } // Comparison on StrongString @@ -168,3 +168,33 @@ def `[]`(StrongString ss, int offset) { return to_string(to_underlying(ss)[offset]) } assert_equal(StrongString("hello")[1], "e") + +// --- Compound assignment operators --- +var m3 = Meters(10) +m3 += Meters(5) +assert_equal(to_underlying(m3), 15) +measure(m3) + +m3 -= Meters(3) +assert_equal(to_underlying(m3), 12) + +m3 *= Meters(2) +assert_equal(to_underlying(m3), 24) + +m3 /= Meters(4) +assert_equal(to_underlying(m3), 6) + +m3 %= Meters(4) +assert_equal(to_underlying(m3), 2) + +// Compound assignment result is still the strong typedef +var m4 = Meters(10) +m4 += Meters(5) +assert_equal(to_underlying(m4), 15) +measure(m4) + +// Compound assignment on StrongString +var ss3 = StrongString("hello") +ss3 += StrongString(" world") +assert_equal(to_underlying(ss3), "hello world") +takes_strong_string(ss3)