diff --git a/include/mimalloc.h b/include/mimalloc.h index c752ac24..85311c11 100644 --- a/include/mimalloc.h +++ b/include/mimalloc.h @@ -389,6 +389,9 @@ mi_decl_nodiscard mi_decl_export mi_decl_restrict void* mi_new_n(size_t count, s mi_decl_nodiscard mi_decl_export void* mi_new_realloc(void* p, size_t newsize) mi_attr_alloc_size(2); mi_decl_nodiscard mi_decl_export void* mi_new_reallocn(void* p, size_t newcount, size_t size) mi_attr_alloc_size2(2, 3); +mi_decl_nodiscard mi_decl_export mi_decl_restrict void* mi_heap_new_(size_t size, mi_heap_t *heap) mi_attr_malloc mi_attr_alloc_size(1); +mi_decl_nodiscard mi_decl_export mi_decl_restrict void* mi_heap_new_n(size_t count, size_t size, mi_heap_t *heap) mi_attr_malloc mi_attr_alloc_size2(1, 2); + #ifdef __cplusplus } #endif @@ -448,6 +451,77 @@ template struct mi_stl_allocator { template bool operator==(const mi_stl_allocator& , const mi_stl_allocator& ) mi_attr_noexcept { return true; } template bool operator!=(const mi_stl_allocator& , const mi_stl_allocator& ) mi_attr_noexcept { return false; } + +#if (__cplusplus >= 201103L) || (_MSC_VER > 1900) // C++11 +#include + +template struct mi_heap_stl_allocator { + typedef T value_type; + typedef std::size_t size_type; + typedef std::ptrdiff_t difference_type; + typedef value_type& reference; + typedef value_type const& const_reference; + typedef value_type* pointer; + typedef value_type const* const_pointer; + template struct rebind { typedef mi_heap_stl_allocator other; }; + + mi_heap_stl_allocator() { + mi_heap_t *heap = mi_heap_new(); + this->_heap.reset(new(static_cast(mi_heap_new_(sizeof(managed_heap), heap))) managed_heap(heap), managed_heap::destroy); + } + mi_heap_stl_allocator(const mi_heap_stl_allocator&) mi_attr_noexcept = default; + template mi_heap_stl_allocator(const mi_heap_stl_allocator& other) mi_attr_noexcept : _heap(std::reinterpret_pointer_cast::managed_heap>(other._heap)) { } + mi_heap_stl_allocator select_on_container_copy_construction() const { return *this; } + void deallocate(T* p, size_type) { if (_heap->free_enabled) mi_free(p); } + + #if (__cplusplus >= 201703L) // C++17 + mi_decl_nodiscard T* allocate(size_type count) { return static_cast(mi_heap_new_n(count, sizeof(T), _heap->heap)); } + mi_decl_nodiscard T* allocate(size_type count, const void*) { return allocate(count); } + #else + mi_decl_nodiscard pointer allocate(size_type count, const void* = 0) { return static_cast(mi_heap_new_n(count, sizeof(value_type), _heap->heap)); } + #endif + + using propagate_on_container_copy_assignment = std::true_type; + using propagate_on_container_move_assignment = std::true_type; + using propagate_on_container_swap = std::true_type; + using is_always_equal = std::true_type; + template void construct(U* p, Args&& ...args) { ::new(p) U(std::forward(args)...); } + template void destroy(U* p) mi_attr_noexcept { p->~U(); } + + size_type max_size() const mi_attr_noexcept { return (PTRDIFF_MAX/sizeof(value_type)); } + pointer address(reference x) const { return &x; } + const_pointer address(const_reference x) const { return &x; } + + void enable_free() mi_attr_noexcept { this->_heap->free_enabled = true; } + void disable_free() mi_attr_noexcept { this->_heap->free_enabled = false; } + void collect(bool force = false) mi_attr_noexcept { mi_heap_collect(_heap->heap, force); } + + protected: + struct managed_heap { + managed_heap(mi_heap_t *heap): heap(heap), free_enabled(true) { } + managed_heap(const managed_heap&) = delete; + managed_heap& operator=(managed_heap const&) = delete; + ~managed_heap() = delete; + static void destroy(managed_heap *ptr) { mi_heap_destroy(ptr->heap); } + + mi_heap_t *heap; + bool free_enabled; + }; + + std::shared_ptr _heap; + + template + friend struct mi_heap_stl_allocator; + template + friend bool operator==(const mi_heap_stl_allocator& first, const mi_heap_stl_allocator& second) mi_attr_noexcept; + template + friend bool operator!=(const mi_heap_stl_allocator& first, const mi_heap_stl_allocator& second) mi_attr_noexcept; +}; + +template bool operator==(const mi_heap_stl_allocator& first, const mi_heap_stl_allocator& second) mi_attr_noexcept { return first._heap == second._heap; } +template bool operator!=(const mi_heap_stl_allocator& first, const mi_heap_stl_allocator& second) mi_attr_noexcept { return first._heap != second._heap; } +#endif // C++11 + #endif // __cplusplus #endif diff --git a/src/alloc.c b/src/alloc.c index 1a36b5da..5d330b6c 100644 --- a/src/alloc.c +++ b/src/alloc.c @@ -932,3 +932,20 @@ void* mi_new_reallocn(void* p, size_t newcount, size_t size) { return mi_new_realloc(p, total); } } + +mi_decl_restrict void* mi_heap_new_(size_t size, mi_heap_t *heap) { + void* p = mi_heap_malloc(heap, size); + if (mi_unlikely(p == NULL)) return mi_try_new(size,false); + return p; +} + +mi_decl_restrict void* mi_heap_new_n(size_t count, size_t size, mi_heap_t *heap) { + size_t total; + if (mi_unlikely(mi_count_size_overflow(count, size, &total))) { + mi_try_new_handler(false); // on overflow we invoke the try_new_handler once to potentially throw std::bad_alloc + return NULL; + } + else { + return mi_heap_new_(total, heap); + } +} diff --git a/test/main-override.cpp b/test/main-override.cpp index e0dba5a3..784618ec 100644 --- a/test/main-override.cpp +++ b/test/main-override.cpp @@ -127,6 +127,32 @@ static bool test_stl_allocator2() { return vec.size() == 0; } +static bool test_heap_stl_allocator1() { +#if (__cplusplus >= 201103L) || (_MSC_VER > 1900) + mi_heap_stl_allocator alloc; + std::vector > vec(alloc); + vec.push_back(1); + vec.pop_back(); + return vec.size() == 0; +#else + return true; +#endif +} + +static bool test_heap_stl_allocator2() { +#if (__cplusplus >= 201103L) || (_MSC_VER > 1900) + mi_heap_stl_allocator alloc; + std::vector > vec(alloc); + alloc.disable_free(); + vec.push_back(some_struct()); + vec.pop_back(); + alloc.enable_free(); + return vec.size() == 0; +#else + return true; +#endif +} + // issue 445 static void strdup_test() { #ifdef _MSC_VER diff --git a/test/test-api.c b/test/test-api.c index 0302464e..bf3bad02 100644 --- a/test/test-api.c +++ b/test/test-api.c @@ -45,6 +45,10 @@ bool test_heap1(void); bool test_heap2(void); bool test_stl_allocator1(void); bool test_stl_allocator2(void); +bool test_heap_stl_allocator1(void); +bool test_heap_stl_allocator2(void); +bool test_heap_stl_allocator3(void); +bool test_heap_stl_allocator4(void); // --------------------------------------------------------------------------- // Main testing @@ -193,6 +197,11 @@ int main(void) { CHECK("stl_allocator1", test_stl_allocator1()); CHECK("stl_allocator2", test_stl_allocator2()); + + CHECK("heap_stl_allocator1", test_heap_stl_allocator1()); + CHECK("heap_stl_allocator2", test_heap_stl_allocator2()); + CHECK("heap_stl_allocator3", test_heap_stl_allocator3()); + CHECK("heap_stl_allocator3", test_heap_stl_allocator4()); // --------------------------------------------------- // Done @@ -247,3 +256,60 @@ bool test_stl_allocator2() { return true; #endif } + +bool test_heap_stl_allocator1() { +#if (__cplusplus >= 201103L) || (_MSC_VER > 1900) + mi_heap_stl_allocator alloc; + std::vector > vec(alloc); + vec.push_back(1); + vec.pop_back(); + return vec.size() == 0; +#else + return true; +#endif +} + +bool test_heap_stl_allocator2() { +#if (__cplusplus >= 201103L) || (_MSC_VER > 1900) + mi_heap_stl_allocator alloc; + std::vector > vec(alloc); + vec.push_back(some_struct()); + vec.pop_back(); + return vec.size() == 0; +#else + return true; +#endif +} + +bool test_heap_stl_allocator3() { +#if (__cplusplus >= 201103L) || (_MSC_VER > 1900) + mi_heap_stl_allocator alloc; + alloc.disable_free(); + std::vector > vec(alloc); + for (int i = 0; i < 1000; i++) { + vec.push_back(i); + } + return vec.size() == 1000; +#else + return true; +#endif +} + +bool test_heap_stl_allocator4() { +#if (__cplusplus >= 201103L) || (_MSC_VER > 1900) + mi_heap_stl_allocator alloc; + alloc.disable_free(); + std::vector > vec(alloc); + for (int i = 0; i < 100; i++) { + vec.push_back(i); + } + alloc.enable_free(); + for (int i = 0; i < 1000; i++) { + vec.push_back(i); + } + alloc.collect(); + return vec.size() == 1100; +#else + return true; +#endif +}