diff --git a/include/rcp.h b/include/rcp.h index eaacf74..bc79cf4 100644 --- a/include/rcp.h +++ b/include/rcp.h @@ -33,12 +33,9 @@ class rcp private: T * ptr = nullptr; - explicit rcp(T * p) : ptr(p) + void init_ptr(T * p) { - if (p) - { - ptr->rcp_inc(); - } + ptr = p; } friend T; @@ -54,6 +51,16 @@ public: } } + rcp & operator=(const rcp & other) + { + ptr = other.ptr; + if (ptr) + { + ptr->rcp_inc(); + } + return *this; + } + ~rcp() { if (ptr) @@ -93,12 +100,12 @@ class root private: mutable std::atomic ref_count{0}; - rcp root_rcp; + mutable rcp root_rcp; protected: root() { - root_rcp = rcp(this); + root_rcp.init_ptr(this); } virtual ~root() = default; @@ -111,8 +118,9 @@ public: void rcp_dec() const { - if (ref_count.fetch_sub(1, std::memory_order_acq_rel) == 0) + if (ref_count.fetch_sub(1, std::memory_order_acq_rel) == 1) { + root_rcp.init_ptr(nullptr); delete this; } } diff --git a/test/tests.cpp b/test/tests.cpp index 04645bc..6c978d1 100644 --- a/test/tests.cpp +++ b/test/tests.cpp @@ -1,9 +1,22 @@ #include +static int mybase_construct; +static int mybase_destroy; +static int myderived_construct; +static int myderived_destroy; + class MyBase : public root { protected: - MyBase(int x, int y) {} + MyBase(int x, int y) + { + mybase_construct++; + } + + virtual ~MyBase() + { + mybase_destroy++; + } public: rcp_managed(MyBase); @@ -12,13 +25,28 @@ public: class MyDerived : public MyBase { protected: - MyDerived(double v) : MyBase(1, 2) {} + MyDerived(double v) : MyBase(1, 2) + { + myderived_construct++; + } + + virtual ~MyDerived() + { + myderived_destroy++; + } public: rcp_managed(MyDerived) }; +void t1() +{ + rcp mybase = MyBase::create(4, 5); + rcp myderived = MyDerived::create(42.5); +} + int main(int argc, char * argv[]) { + t1(); return 0; }