diff --git a/include/libconcur/concurrent.h b/include/libconcur/concurrent.h index 33d8692..dff57fd 100644 --- a/include/libconcur/concurrent.h +++ b/include/libconcur/concurrent.h @@ -325,8 +325,8 @@ struct producer { hdr.w_idx.fetch_add(1, std::memory_order_release); // Set data & flag. elem.set_flag(w_idx | state::enqueue_mask); - elem.set_data(std::forward(src)); // Here should not be interrupted. - elem.set_flag(w_idx | state::commit_mask); + elem.set_data(std::forward(src)); + elem.set_flag(w_idx); return true; } }; @@ -336,14 +336,14 @@ template <> struct producer { struct header_impl { - std::atomic w_flags {0}; ///< write flags, combined current and starting index. - private: padding ___; + std::atomic w_contexts {0}; ///< write contexts, combined current and starting index. + private: padding ___; public: void get(index_t &idx, index_t &beg) const noexcept { - auto w_flags = this->w_flags.load(std::memory_order_relaxed); - idx = get_index(w_flags); - beg = get_begin(w_flags); + auto w_contexts = this->w_contexts.load(std::memory_order_relaxed); + idx = get_index(w_contexts); + beg = get_begin(w_contexts); } }; @@ -351,28 +351,29 @@ struct producer { verify_elems_header = true, convertible = true> static bool enqueue(::LIBIMP::span> elems, H &hdr, C &/*ctx*/, U &&src) noexcept { - auto w_flags = hdr.w_flags.load(std::memory_order_acquire); + auto w_contexts = hdr.w_contexts.load(std::memory_order_acquire); index_t w_idx; for (;;) { - w_idx = get_index(w_flags); - auto w_beg = get_begin(w_flags); + w_idx = get_index(w_contexts); + auto w_beg = get_begin(w_contexts); // Move the queue head index. if (w_beg + hdr.circ_size <= w_idx) { w_beg += 1; } - // Update flags. - auto n_flags = make_flags(w_idx + 1/*iterate backwards*/, w_beg); - if (hdr.w_flags.compare_exchange_weak(w_flags, n_flags, std::memory_order_acq_rel)) { + // Update write contexts. + auto n_contexts = make_w_contexts(w_idx + 1/*iterate backwards*/, w_beg); + if (hdr.w_contexts.compare_exchange_weak(w_contexts, n_contexts, std::memory_order_acq_rel)) { break; } } // Get element. auto w_cur = trunc_index(hdr, w_idx); auto &elem = elems[w_cur]; - // Set data & flag. + // Set data & flag. Dirty write is not considered here. + // By default, when dirty write occurs, the previous writer must no longer exist. elem.set_flag(w_idx | state::enqueue_mask); - elem.set_data(std::forward(src)); // Here should not be interrupted. - elem.set_flag(w_idx | state::commit_mask); + elem.set_data(std::forward(src)); + elem.set_flag(w_idx); return true; } @@ -387,8 +388,8 @@ private: return index_t(flags >> (sizeof(index_t) * CHAR_BIT)); } - static constexpr state::flag_t make_flags(index_t idx, index_t beg) noexcept { - return state::flag_t(idx) | (state::flag_t(beg) << (sizeof(index_t) * CHAR_BIT)); + static constexpr std::uint64_t make_w_contexts(index_t idx, index_t beg) noexcept { + return std::uint64_t(idx) | (std::uint64_t(beg) << (sizeof(index_t) * CHAR_BIT)); } }; @@ -426,7 +427,7 @@ struct consumer { } // Try getting data. for (;;) { - if ((f_ct & state::enqueue_mask) == state::enqueue_mask) { + if (f_ct & state::enqueue_mask) { return false; // unreadable } des = LIBCONCUR::get(elem); @@ -434,8 +435,8 @@ struct consumer { // the elem data is not modified during the getting process. if (elem.cas_flag(f_ct, f_ct)) break; } - ctx.w_lst = (f_ct & ~state::enqueue_mask) + 1; // Get a valid index and iterate backwards. + ctx.w_lst = index_t(f_ct) + 1; ctx.r_idx += 1; return true; } diff --git a/include/libconcur/element.h b/include/libconcur/element.h index b50c7cf..c0f06d2 100644 --- a/include/libconcur/element.h +++ b/include/libconcur/element.h @@ -27,7 +27,6 @@ using flag_t = std::uint64_t; enum : flag_t { invalid_value = ~flag_t(0), enqueue_mask = invalid_value << 32, - commit_mask = ~flag_t(1) << 32, }; } // namespace state diff --git a/test/concur/test_concur_concurrent.cpp b/test/concur/test_concur_concurrent.cpp index f8e1837..0362de8 100644 --- a/test/concur/test_concur_concurrent.cpp +++ b/test/concur/test_concur_concurrent.cpp @@ -13,6 +13,8 @@ #include "libimp/log.h" #include "libimp/nameof.h" +#include "test_util.h" + TEST(concurrent, cache_line_size) { std::cout << concur::cache_line_size << "\n"; EXPECT_TRUE(concur::cache_line_size >= alignof(std::max_align_t)); @@ -311,4 +313,61 @@ TEST(concurrent, broadcast) { /// \brief 8-8 test_broadcast>(8, 8); -} \ No newline at end of file +} + +TEST(concurrent, broadcast_multi_dirtywrite) { + using namespace concur; + + struct data { + std::uint64_t n{}; + + data &operator=(test::latch &l) noexcept { + l.arrive_and_wait(); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + n = 1; + return *this; + } + + data &operator=(data const &rhs) noexcept { + n = rhs.n; + return *this; + } + }; + + element circ[2] {}; + prod_cons pc; + typename traits::header hdr {imp::make_span(circ)}; + + auto push_one = [&, ctx = typename concur::traits::context{}](auto &i) mutable { + return pc.enqueue(imp::make_span(circ), hdr, ctx, i); + }; + auto pop_one = [&, ctx = typename concur::traits::context{}]() mutable { + data i; + if (pc.dequeue(imp::make_span(circ), hdr, ctx, i)) { + return i; + } + return data{}; + }; + + test::latch l(2); + std::thread t[2]; + t[0] = std::thread([&] { + push_one(l); // 1 + }); + t[1] = std::thread([&] { + l.arrive_and_wait(); + push_one(data{3}); + push_one(data{2}); // dirty write + }); + + for (int i = 0; i < 2; ++i) { + t[i].join(); + } + std::set s{1, 2, 3}; + for (int i = 0; i < 2; ++i) { + auto d = pop_one(); + EXPECT_TRUE(s.find(d.n) != s.end()); + s.erase(d.n); + } + EXPECT_TRUE(s.find(3) == s.end()); +} diff --git a/test/test_util.h b/test/test_util.h index a55a128..da22ea8 100644 --- a/test/test_util.h +++ b/test/test_util.h @@ -1,13 +1,22 @@ #pragma once -#include -#include +#include "libimp/detect_plat.h" +#ifndef LIBIMP_OS_WIN +# include +# include +#else +# define pid_t int +#endif + +#include +#include namespace test { template pid_t subproc(Fn&& fn) { +#ifndef LIBIMP_OS_WIN pid_t pid = fork(); if (pid == -1) { return pid; @@ -21,13 +30,49 @@ pid_t subproc(Fn&& fn) { exit(0); } return pid; +#else + return -1; +#endif } inline void join_subproc(pid_t pid) { +#ifndef LIBIMP_OS_WIN int ret_code; waitpid(pid, &ret_code, 0); +#endif } +/// \brief A simple latch implementation. +class latch { +public: + explicit latch(int count) : count_(count) {} + + void count_down() { + std::unique_lock lock(mutex_); + if (count_ > 0) { + --count_; + if (count_ == 0) { + cv_.notify_all(); + } + } + } + + void wait() { + std::unique_lock lock(mutex_); + cv_.wait(lock, [&] { return count_ == 0; }); + } + + void arrive_and_wait() { + count_down(); + wait(); + } + +private: + std::mutex mutex_; + std::condition_variable cv_; + int count_; +}; + } // namespace test #define REQUIRE_EXIT(...) \