diff --git a/include/rcp.h b/include/rcp.h index 3fe3a0d..96af0d7 100644 --- a/include/rcp.h +++ b/include/rcp.h @@ -54,10 +54,18 @@ public: rcp & operator=(const rcp & other) { - ptr = other.ptr; - if (ptr) + if (ptr != other.ptr) { - ptr->rcp_inc(); + T * old_ptr = ptr; + ptr = other.ptr; + if (ptr) + { + ptr->rcp_inc(); + } + if (old_ptr) + { + old_ptr->rcp_dec(); + } } return *this; } diff --git a/test/tests.cpp b/test/tests.cpp index 556db5f..c24e664 100644 --- a/test/tests.cpp +++ b/test/tests.cpp @@ -28,6 +28,7 @@ public: int x; int y; }; +typedef rcp MyB; class MyDerived : public MyBase { @@ -85,7 +86,6 @@ void test_booleans() void test_create() { - typedef rcp MyB; MyB myb = MyB::create(8, 9); assert(myb->x == 8); } @@ -124,6 +124,14 @@ void test_multi_construct_from_raw_pointers() } } +void test_copy_assignment_decrements_previous_reference() +{ + MyB myb = MyB::create(12, 13); + MyB myb2 = MyB::create(14, 15); + myb = myb2; + assert(myb->x == 14); +} + int main(int argc, char * argv[]) { test_class_hierarchy(); @@ -131,5 +139,6 @@ int main(int argc, char * argv[]) test_booleans(); test_create(); test_multi_construct_from_raw_pointers(); + test_copy_assignment_decrements_previous_reference(); return 0; }