update tls_pointer_win.cpp

This commit is contained in:
mutouyun 2020-03-28 13:55:34 +08:00
parent e87d516b1d
commit 91cc1b7767

View File

@ -3,7 +3,6 @@
#include <Windows.h> // ::Tls... #include <Windows.h> // ::Tls...
#include <atomic> #include <atomic>
#include <unordered_set> // std::unordered_set
namespace ipc { namespace ipc {
@ -26,59 +25,110 @@ namespace {
struct tls_data { struct tls_data {
using destructor_t = void(*)(void*); using destructor_t = void(*)(void*);
unsigned index_;
DWORD win_key_; DWORD win_key_;
destructor_t destructor_; destructor_t destructor_;
void destruct(void* data) { bool valid() const noexcept {
if ((destructor_ != nullptr) && (data != nullptr)) { return win_key_ != TLS_OUT_OF_INDEXES;
destructor_(data); }
void* get() const {
return ::TlsGetValue(win_key_);
}
bool set(void* p) {
return TRUE == ::TlsSetValue(win_key_, static_cast<LPVOID>(p));
}
void destruct() {
void* data = valid() ? get() : nullptr;
if (data != nullptr) {
if (destructor_ != nullptr) destructor_(data);
set(nullptr);
} }
} }
void clear_self() {
if (valid()) {
destruct();
::TlsFree(win_key_);
}
delete this;
}
}; };
using rec_t = std::unordered_set<tls_data*>; struct tls_recs {
tls_data* recs_[TLS_MINIMUM_AVAILABLE] {};
unsigned index_ = 0;
DWORD& record_key() { bool insert(tls_data* data) noexcept {
if (index_ >= TLS_MINIMUM_AVAILABLE) {
struct key_gen { ipc::error("[tls_recs] insert tls_data failed[index_ >= TLS_MINIMUM_AVAILABLE].\n");
DWORD rec_key_; return false;
key_gen() : rec_key_(::TlsAlloc()) {
if (rec_key_ == TLS_OUT_OF_INDEXES) {
ipc::error("[record_key] TlsAlloc failed[%lu].\n", ::GetLastError());
}
} }
~key_gen() { ::TlsFree(rec_key_); } recs_[(data->index_ = index_++)] = data;
}; return true;
}
static key_gen gen; bool erase(tls_data* data) noexcept {
return gen.rec_key_; if (data->index_ >= TLS_MINIMUM_AVAILABLE) return false;
recs_[data->index_] = nullptr;
return true;
}
tls_data* * begin() noexcept { return &recs_[0]; }
tls_data* const * begin() const noexcept { return &recs_[0]; }
tls_data* * end () noexcept { return &recs_[index_]; }
tls_data* const * end () const noexcept { return &recs_[index_]; }
};
struct key_gen {
DWORD rec_key_;
key_gen() : rec_key_(::TlsAlloc()) {
if (rec_key_ == TLS_OUT_OF_INDEXES) {
ipc::error("[record_key] TlsAlloc failed[%lu].\n", ::GetLastError());
}
}
~key_gen() {
::TlsFree(rec_key_);
rec_key_ = TLS_OUT_OF_INDEXES;
}
} gen__;
DWORD& record_key() noexcept {
return gen__.rec_key_;
} }
bool record(tls_data* tls) { bool record(tls_data* tls_dat) {
auto rec = static_cast<rec_t*>(::TlsGetValue(record_key())); if (record_key() == TLS_OUT_OF_INDEXES) return false;
auto rec = static_cast<tls_recs*>(::TlsGetValue(record_key()));
if (rec == nullptr) { if (rec == nullptr) {
if (FALSE == ::TlsSetValue(record_key(), static_cast<LPVOID>(rec = new rec_t))) { if (FALSE == ::TlsSetValue(record_key(), static_cast<LPVOID>(rec = new tls_recs))) {
ipc::error("[record] TlsSetValue failed[%lu].\n", ::GetLastError()); ipc::error("[record] TlsSetValue failed[%lu].\n", ::GetLastError());
return false; return false;
} }
} }
rec->insert(tls); return rec->insert(tls_dat);
return true;
} }
static void erase_record(tls_data* tls) { void erase_record(tls_data* tls_dat) {
auto rec = static_cast<rec_t*>(::TlsGetValue(record_key())); if (tls_dat == nullptr) return;
if (record_key() == TLS_OUT_OF_INDEXES) return;
auto rec = static_cast<tls_recs*>(::TlsGetValue(record_key()));
if (rec == nullptr) return; if (rec == nullptr) return;
rec->erase(tls); rec->erase(tls_dat);
tls_dat->clear_self();
} }
static void clear_all_records() { void clear_all_records() {
auto rec = static_cast<rec_t*>(::TlsGetValue(record_key())); if (record_key() == TLS_OUT_OF_INDEXES) return;
auto rec = static_cast<tls_recs*>(::TlsGetValue(record_key()));
if (rec == nullptr) return; if (rec == nullptr) return;
for (auto tls : *rec) { for (auto tls_dat : *rec) {
if (tls != nullptr) { if (tls_dat != nullptr) tls_dat->destruct();
tls->destruct(::TlsGetValue(tls->win_key_));
}
} }
delete rec; delete rec;
::TlsSetValue(record_key(), static_cast<LPVOID>(nullptr)); ::TlsSetValue(record_key(), static_cast<LPVOID>(nullptr));
@ -90,11 +140,11 @@ namespace tls {
key_t create(destructor_t destructor) { key_t create(destructor_t destructor) {
record_key(); // gen record-key record_key(); // gen record-key
auto tls_dat = new tls_data { ::TlsAlloc(), destructor }; auto tls_dat = new tls_data { unsigned(-1), ::TlsAlloc(), destructor };
std::atomic_thread_fence(std::memory_order_seq_cst); std::atomic_thread_fence(std::memory_order_seq_cst);
if (tls_dat->win_key_ == TLS_OUT_OF_INDEXES) { if (!tls_dat->valid()) {
ipc::error("[tls::create] TlsAlloc failed[%lu].\n", ::GetLastError()); ipc::error("[tls::create] TlsAlloc failed[%lu].\n", ::GetLastError());
delete tls_dat; tls_dat->clear_self();
return invalid_value; return invalid_value;
} }
return reinterpret_cast<key_t>(tls_dat); return reinterpret_cast<key_t>(tls_dat);
@ -111,8 +161,6 @@ void release(key_t tls_key) {
return; return;
} }
erase_record(tls_dat); erase_record(tls_dat);
::TlsFree(tls_dat->win_key_);
delete tls_dat;
} }
bool set(key_t tls_key, void* ptr) { bool set(key_t tls_key, void* ptr) {
@ -125,7 +173,7 @@ bool set(key_t tls_key, void* ptr) {
ipc::error("[tls::set] tls_dat is nullptr.\n"); ipc::error("[tls::set] tls_dat is nullptr.\n");
return false; return false;
} }
if (FALSE == ::TlsSetValue(tls_dat->win_key_, static_cast<LPVOID>(ptr))) { if (!tls_dat->set(ptr)) {
ipc::error("[tls::set] TlsSetValue failed[%lu].\n", ::GetLastError()); ipc::error("[tls::set] TlsSetValue failed[%lu].\n", ::GetLastError());
return false; return false;
} }
@ -143,19 +191,15 @@ void* get(key_t tls_key) {
ipc::error("[tls::get] tls_dat is nullptr.\n"); ipc::error("[tls::get] tls_dat is nullptr.\n");
return nullptr; return nullptr;
} }
return ::TlsGetValue(tls_dat->win_key_); return tls_dat->get();
} }
} // namespace tls } // namespace tls
namespace { namespace {
void OnThreadExit() {
clear_all_records();
}
void NTAPI OnTlsCallback(PVOID, DWORD dwReason, PVOID) { void NTAPI OnTlsCallback(PVOID, DWORD dwReason, PVOID) {
if (dwReason == DLL_THREAD_DETACH) OnThreadExit(); if (dwReason == DLL_THREAD_DETACH) clear_all_records();
} }
} // internal-linkage } // internal-linkage