diff --git a/src/platform/tls_pointer_win.cpp b/src/platform/tls_pointer_win.cpp index 655a9f3..3c4c341 100644 --- a/src/platform/tls_pointer_win.cpp +++ b/src/platform/tls_pointer_win.cpp @@ -3,7 +3,6 @@ #include // ::Tls... #include -#include // std::unordered_set namespace ipc { @@ -26,59 +25,110 @@ namespace { struct tls_data { using destructor_t = void(*)(void*); + unsigned index_; DWORD win_key_; destructor_t destructor_; - void destruct(void* data) { - if ((destructor_ != nullptr) && (data != nullptr)) { - destructor_(data); + bool valid() const noexcept { + return win_key_ != TLS_OUT_OF_INDEXES; + } + + void* get() const { + return ::TlsGetValue(win_key_); + } + + bool set(void* p) { + return TRUE == ::TlsSetValue(win_key_, static_cast(p)); + } + + void destruct() { + void* data = valid() ? get() : nullptr; + if (data != nullptr) { + if (destructor_ != nullptr) destructor_(data); + set(nullptr); } } + + void clear_self() { + if (valid()) { + destruct(); + ::TlsFree(win_key_); + } + delete this; + } }; -using rec_t = std::unordered_set; +struct tls_recs { + tls_data* recs_[TLS_MINIMUM_AVAILABLE] {}; + unsigned index_ = 0; -DWORD& record_key() { - - struct key_gen { - DWORD rec_key_; - key_gen() : rec_key_(::TlsAlloc()) { - if (rec_key_ == TLS_OUT_OF_INDEXES) { - ipc::error("[record_key] TlsAlloc failed[%lu].\n", ::GetLastError()); - } + bool insert(tls_data* data) noexcept { + if (index_ >= TLS_MINIMUM_AVAILABLE) { + ipc::error("[tls_recs] insert tls_data failed[index_ >= TLS_MINIMUM_AVAILABLE].\n"); + return false; } - ~key_gen() { ::TlsFree(rec_key_); } - }; + recs_[(data->index_ = index_++)] = data; + return true; + } - static key_gen gen; - return gen.rec_key_; + bool erase(tls_data* data) noexcept { + if (data->index_ >= TLS_MINIMUM_AVAILABLE) return false; + recs_[data->index_] = nullptr; + return true; + } + + tls_data* * begin() noexcept { return &recs_[0]; } + tls_data* const * begin() const noexcept { return &recs_[0]; } + tls_data* * end () noexcept { return &recs_[index_]; } + tls_data* const * end () const noexcept { return &recs_[index_]; } +}; + +struct key_gen { + DWORD rec_key_; + + key_gen() : rec_key_(::TlsAlloc()) { + if (rec_key_ == TLS_OUT_OF_INDEXES) { + ipc::error("[record_key] TlsAlloc failed[%lu].\n", ::GetLastError()); + } + } + + ~key_gen() { + ::TlsFree(rec_key_); + rec_key_ = TLS_OUT_OF_INDEXES; + } +} gen__; + +DWORD& record_key() noexcept { + return gen__.rec_key_; } -bool record(tls_data* tls) { - auto rec = static_cast(::TlsGetValue(record_key())); +bool record(tls_data* tls_dat) { + if (record_key() == TLS_OUT_OF_INDEXES) return false; + auto rec = static_cast(::TlsGetValue(record_key())); if (rec == nullptr) { - if (FALSE == ::TlsSetValue(record_key(), static_cast(rec = new rec_t))) { + if (FALSE == ::TlsSetValue(record_key(), static_cast(rec = new tls_recs))) { ipc::error("[record] TlsSetValue failed[%lu].\n", ::GetLastError()); return false; } } - rec->insert(tls); - return true; + return rec->insert(tls_dat); } -static void erase_record(tls_data* tls) { - auto rec = static_cast(::TlsGetValue(record_key())); +void erase_record(tls_data* tls_dat) { + if (tls_dat == nullptr) return; + if (record_key() == TLS_OUT_OF_INDEXES) return; + auto rec = static_cast(::TlsGetValue(record_key())); if (rec == nullptr) return; - rec->erase(tls); + rec->erase(tls_dat); + tls_dat->clear_self(); } -static void clear_all_records() { - auto rec = static_cast(::TlsGetValue(record_key())); +void clear_all_records() { + if (record_key() == TLS_OUT_OF_INDEXES) return; + auto rec = static_cast(::TlsGetValue(record_key())); if (rec == nullptr) return; - for (auto tls : *rec) { - if (tls != nullptr) { - tls->destruct(::TlsGetValue(tls->win_key_)); - } + for (auto tls_dat : *rec) { + if (tls_dat != nullptr) tls_dat->destruct(); } delete rec; ::TlsSetValue(record_key(), static_cast(nullptr)); @@ -90,11 +140,11 @@ namespace tls { key_t create(destructor_t destructor) { record_key(); // gen record-key - auto tls_dat = new tls_data { ::TlsAlloc(), destructor }; + auto tls_dat = new tls_data { unsigned(-1), ::TlsAlloc(), destructor }; std::atomic_thread_fence(std::memory_order_seq_cst); - if (tls_dat->win_key_ == TLS_OUT_OF_INDEXES) { + if (!tls_dat->valid()) { ipc::error("[tls::create] TlsAlloc failed[%lu].\n", ::GetLastError()); - delete tls_dat; + tls_dat->clear_self(); return invalid_value; } return reinterpret_cast(tls_dat); @@ -111,8 +161,6 @@ void release(key_t tls_key) { return; } erase_record(tls_dat); - ::TlsFree(tls_dat->win_key_); - delete tls_dat; } bool set(key_t tls_key, void* ptr) { @@ -125,7 +173,7 @@ bool set(key_t tls_key, void* ptr) { ipc::error("[tls::set] tls_dat is nullptr.\n"); return false; } - if (FALSE == ::TlsSetValue(tls_dat->win_key_, static_cast(ptr))) { + if (!tls_dat->set(ptr)) { ipc::error("[tls::set] TlsSetValue failed[%lu].\n", ::GetLastError()); return false; } @@ -143,19 +191,15 @@ void* get(key_t tls_key) { ipc::error("[tls::get] tls_dat is nullptr.\n"); return nullptr; } - return ::TlsGetValue(tls_dat->win_key_); + return tls_dat->get(); } } // namespace tls namespace { -void OnThreadExit() { - clear_all_records(); -} - void NTAPI OnTlsCallback(PVOID, DWORD dwReason, PVOID) { - if (dwReason == DLL_THREAD_DETACH) OnThreadExit(); + if (dwReason == DLL_THREAD_DETACH) clear_all_records(); } } // internal-linkage