diff --git a/src/alloc-override-win.c b/src/alloc-override-win.c index 0d675cd8..3ac499b8 100644 --- a/src/alloc-override-win.c +++ b/src/alloc-override-win.c @@ -379,18 +379,19 @@ typedef enum patch_apply_e { PATCH_TARGET_TERM } patch_apply_t; +#define MAX_ENTRIES 4 // maximum number of patched entry points (like `malloc` in ucrtbase and msvcrt) + typedef struct mi_patch_s { - const char* name; // name of the function to patch - int priority; // priority to patch this one (used to prioritize over multiple entries in various dll's) - void* original; // the resolved address of the function (or NULL) - void* target; // the address of the new target (never NULL) - void* target_term;// the address of the target during termination (or NULL) - patch_apply_t applied; // what target has been applied? - mi_jump_t save; // the saved instructions in case it was applied + const char* name; // name of the function to patch + void* target; // the address of the new target (never NULL) + void* target_term; // the address of the target during termination (or NULL) + patch_apply_t applied; // what target has been applied? + void* originals[MAX_ENTRIES]; // the resolved addresses of the function (or NULLs) + mi_jump_t saves[MAX_ENTRIES]; // the saved instructions in case it was applied } mi_patch_t; -#define MI_PATCH_NAME3(name,target,term) { name, 0, NULL, &target, &term, PATCH_NONE } -#define MI_PATCH_NAME2(name,target) { name, 0, NULL, &target, NULL, PATCH_NONE } +#define MI_PATCH_NAME3(name,target,term) { name, &target, &term, PATCH_NONE, {NULL,NULL,NULL,NULL} } +#define MI_PATCH_NAME2(name,target) { name, &target, NULL, PATCH_NONE, {NULL,NULL,NULL,NULL} } #define MI_PATCH3(name,target,term) MI_PATCH_NAME3(#name, target, term) #define MI_PATCH2(name,target) MI_PATCH_NAME2(#name, target) #define MI_PATCH1(name) MI_PATCH2(name,mi_##name) @@ -463,29 +464,34 @@ static mi_patch_t patches[] = { MI_PATCH_NAME3("??_V@YAXPAXABUnothrow_t@std@@@Z", mi_free, mi_free_term), #endif #endif - { NULL, 0, NULL, NULL, NULL, PATCH_NONE } + { NULL, NULL, NULL, PATCH_NONE, {NULL,NULL,NULL,NULL} } }; // Apply a patch static bool mi_patch_apply(mi_patch_t* patch, patch_apply_t apply) { - if (patch->original == NULL) return true; // unresolved + if (patch->originals[0] == NULL) return true; // unresolved if (apply == PATCH_TARGET_TERM && patch->target_term == NULL) apply = PATCH_TARGET; // avoid re-applying non-term variants if (patch->applied == apply) return false; - DWORD protect = PAGE_READWRITE; - if (!VirtualProtect(patch->original, MI_JUMP_SIZE, PAGE_EXECUTE_READWRITE, &protect)) return false; - if (apply == PATCH_NONE) { - mi_jump_restore(patch->original, &patch->save); - } - else { - void* target = (apply == PATCH_TARGET ? patch->target : patch->target_term); - mi_assert_internal(target!=NULL); - if (target != NULL) mi_jump_write(patch->original, target, &patch->save); + for (int i = 0; i < MAX_ENTRIES; i++) { + void* original = patch->originals[i]; + if (original == NULL) break; // no more + + DWORD protect = PAGE_READWRITE; + if (!VirtualProtect(original, MI_JUMP_SIZE, PAGE_EXECUTE_READWRITE, &protect)) return false; + if (apply == PATCH_NONE) { + mi_jump_restore(original, &patch->saves[i]); + } + else { + void* target = (apply == PATCH_TARGET ? patch->target : patch->target_term); + mi_assert_internal(target != NULL); + if (target != NULL) mi_jump_write(original, target, &patch->saves[i]); + } + VirtualProtect(original, MI_JUMP_SIZE, protect, &protect); } patch->applied = apply; - VirtualProtect(patch->original, MI_JUMP_SIZE, protect, &protect); return true; } @@ -542,13 +548,17 @@ static void mi_module_resolve(const char* fname, HMODULE mod, int priority) { // see if any patches apply for (size_t i = 0; patches[i].name != NULL; i++) { mi_patch_t* patch = &patches[i]; - if (!patch->applied && patch->priority < priority) { - void* addr = GetProcAddress(mod, patch->name); - if (addr != NULL) { - // found it! set the address - patch->original = addr; - patch->priority = priority; - _mi_trace_message(" override %s at %s!%p, priority %i\n", patch->name, fname, addr, priority); + if (patch->applied == PATCH_NONE) { + // find an available entry + int i = 0; + while (i < MAX_ENTRIES && patch->originals[i] != NULL) i++; + if (i < MAX_ENTRIES) { + void* addr = GetProcAddress(mod, patch->name); + if (addr != NULL) { + // found it! set the address + patch->originals[i] = addr; + _mi_trace_message(" override %s at %s!%p (entry %i)\n", patch->name, fname, addr, i); + } } } } @@ -595,7 +605,8 @@ static bool mi_patches_resolve(void) { int priority = 0; if (i == 0) priority = 2; // main module to allow static crt linking else if (_strnicmp(basename, "ucrt", 4) == 0) priority = 3; // new ucrtbase.dll in windows 10 - else if (_strnicmp(basename, "msvcr", 5) == 0) priority = 1; // older runtimes + // NOTE: don't override msvcr -- leads to crashes in setlocale (needs more testing) + // else if (_strnicmp(basename, "msvcr", 5) == 0) priority = 1; // older runtimes if (priority > 0) { // probably found a crt module, try to patch it