1*51c0b2f7Stbbdev /* 2*51c0b2f7Stbbdev Copyright (c) 2005-2020 Intel Corporation 3*51c0b2f7Stbbdev 4*51c0b2f7Stbbdev Licensed under the Apache License, Version 2.0 (the "License"); 5*51c0b2f7Stbbdev you may not use this file except in compliance with the License. 6*51c0b2f7Stbbdev You may obtain a copy of the License at 7*51c0b2f7Stbbdev 8*51c0b2f7Stbbdev http://www.apache.org/licenses/LICENSE-2.0 9*51c0b2f7Stbbdev 10*51c0b2f7Stbbdev Unless required by applicable law or agreed to in writing, software 11*51c0b2f7Stbbdev distributed under the License is distributed on an "AS IS" BASIS, 12*51c0b2f7Stbbdev WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13*51c0b2f7Stbbdev See the License for the specific language governing permissions and 14*51c0b2f7Stbbdev limitations under the License. 15*51c0b2f7Stbbdev */ 16*51c0b2f7Stbbdev 17*51c0b2f7Stbbdev #include "tbb/detail/_config.h" 18*51c0b2f7Stbbdev #include "tbb/detail/_assert.h" 19*51c0b2f7Stbbdev #include "../tbb/assert_impl.h" 20*51c0b2f7Stbbdev 21*51c0b2f7Stbbdev #if !__TBB_WIN8UI_SUPPORT && defined(_WIN32) 22*51c0b2f7Stbbdev 23*51c0b2f7Stbbdev #ifndef _CRT_SECURE_NO_DEPRECATE 24*51c0b2f7Stbbdev #define _CRT_SECURE_NO_DEPRECATE 1 25*51c0b2f7Stbbdev #endif 26*51c0b2f7Stbbdev 27*51c0b2f7Stbbdev // no standard-conforming implementation of snprintf prior to VS 2015 28*51c0b2f7Stbbdev #if !defined(_MSC_VER) || _MSC_VER>=1900 29*51c0b2f7Stbbdev #define LOG_PRINT(s, n, format, ...) snprintf(s, n, format, __VA_ARGS__) 30*51c0b2f7Stbbdev #else 31*51c0b2f7Stbbdev #define LOG_PRINT(s, n, format, ...) _snprintf_s(s, n, _TRUNCATE, format, __VA_ARGS__) 32*51c0b2f7Stbbdev #endif 33*51c0b2f7Stbbdev 34*51c0b2f7Stbbdev #include <windows.h> 35*51c0b2f7Stbbdev #include <new> 36*51c0b2f7Stbbdev #include <stdio.h> 37*51c0b2f7Stbbdev #include <string.h> 38*51c0b2f7Stbbdev 39*51c0b2f7Stbbdev #include "function_replacement.h" 40*51c0b2f7Stbbdev 41*51c0b2f7Stbbdev // The information about a standard memory allocation function for the replacement log 42*51c0b2f7Stbbdev struct FunctionInfo { 43*51c0b2f7Stbbdev const char* funcName; 44*51c0b2f7Stbbdev const char* dllName; 45*51c0b2f7Stbbdev }; 46*51c0b2f7Stbbdev 47*51c0b2f7Stbbdev // Namespace that processes and manages the output of records to the Log journal 48*51c0b2f7Stbbdev // that will be provided to user by TBB_malloc_replacement_log() 49*51c0b2f7Stbbdev namespace Log { 50*51c0b2f7Stbbdev // Value of RECORDS_COUNT is set due to the fact that we maximally 51*51c0b2f7Stbbdev // scan 8 modules, and in every module we can swap 6 opcodes. (rounded to 8) 52*51c0b2f7Stbbdev static const unsigned RECORDS_COUNT = 8 * 8; 53*51c0b2f7Stbbdev static const unsigned RECORD_LENGTH = MAX_PATH; 54*51c0b2f7Stbbdev 55*51c0b2f7Stbbdev // Need to add 1 to count of records, because last record must be always NULL 56*51c0b2f7Stbbdev static char *records[RECORDS_COUNT + 1]; 57*51c0b2f7Stbbdev static bool replacement_status = true; 58*51c0b2f7Stbbdev 59*51c0b2f7Stbbdev // Internal counter that contains number of next string for record 60*51c0b2f7Stbbdev static unsigned record_number = 0; 61*51c0b2f7Stbbdev 62*51c0b2f7Stbbdev // Function that writes info about (not)found opcodes to the Log journal 63*51c0b2f7Stbbdev // functionInfo - information about a standard memory allocation function for the replacement log 64*51c0b2f7Stbbdev // opcodeString - string, that contain byte code of this function 65*51c0b2f7Stbbdev // status - information about function replacement status 66*51c0b2f7Stbbdev static void record(FunctionInfo functionInfo, const char * opcodeString, bool status) { 67*51c0b2f7Stbbdev __TBB_ASSERT(functionInfo.dllName, "Empty DLL name value"); 68*51c0b2f7Stbbdev __TBB_ASSERT(functionInfo.funcName, "Empty function name value"); 69*51c0b2f7Stbbdev __TBB_ASSERT(opcodeString, "Empty opcode"); 70*51c0b2f7Stbbdev __TBB_ASSERT(record_number <= RECORDS_COUNT, "Incorrect record number"); 71*51c0b2f7Stbbdev 72*51c0b2f7Stbbdev //If some replacement failed -> set status to false 73*51c0b2f7Stbbdev replacement_status &= status; 74*51c0b2f7Stbbdev 75*51c0b2f7Stbbdev // If we reach the end of the log, write this message to the last line 76*51c0b2f7Stbbdev if (record_number == RECORDS_COUNT) { 77*51c0b2f7Stbbdev // %s - workaround to fix empty variable argument parsing behavior in GCC 78*51c0b2f7Stbbdev LOG_PRINT(records[RECORDS_COUNT - 1], RECORD_LENGTH, "%s", "Log was truncated."); 79*51c0b2f7Stbbdev return; 80*51c0b2f7Stbbdev } 81*51c0b2f7Stbbdev 82*51c0b2f7Stbbdev char* entry = (char*)HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, RECORD_LENGTH); 83*51c0b2f7Stbbdev __TBB_ASSERT(entry, "Invalid memory was returned"); 84*51c0b2f7Stbbdev 85*51c0b2f7Stbbdev LOG_PRINT(entry, RECORD_LENGTH, "%s: %s (%s), byte pattern: <%s>", 86*51c0b2f7Stbbdev status ? "Success" : "Fail", functionInfo.funcName, functionInfo.dllName, opcodeString); 87*51c0b2f7Stbbdev 88*51c0b2f7Stbbdev records[record_number++] = entry; 89*51c0b2f7Stbbdev } 90*51c0b2f7Stbbdev }; 91*51c0b2f7Stbbdev 92*51c0b2f7Stbbdev inline UINT_PTR Ptr2Addrint(LPVOID ptr) 93*51c0b2f7Stbbdev { 94*51c0b2f7Stbbdev Int2Ptr i2p; 95*51c0b2f7Stbbdev i2p.lpv = ptr; 96*51c0b2f7Stbbdev return i2p.uip; 97*51c0b2f7Stbbdev } 98*51c0b2f7Stbbdev 99*51c0b2f7Stbbdev inline LPVOID Addrint2Ptr(UINT_PTR ptr) 100*51c0b2f7Stbbdev { 101*51c0b2f7Stbbdev Int2Ptr i2p; 102*51c0b2f7Stbbdev i2p.uip = ptr; 103*51c0b2f7Stbbdev return i2p.lpv; 104*51c0b2f7Stbbdev } 105*51c0b2f7Stbbdev 106*51c0b2f7Stbbdev // Is the distance between addr1 and addr2 smaller than dist 107*51c0b2f7Stbbdev inline bool IsInDistance(UINT_PTR addr1, UINT_PTR addr2, __int64 dist) 108*51c0b2f7Stbbdev { 109*51c0b2f7Stbbdev __int64 diff = addr1>addr2 ? addr1-addr2 : addr2-addr1; 110*51c0b2f7Stbbdev return diff<dist; 111*51c0b2f7Stbbdev } 112*51c0b2f7Stbbdev 113*51c0b2f7Stbbdev /* 114*51c0b2f7Stbbdev * When inserting a probe in 64 bits process the distance between the insertion 115*51c0b2f7Stbbdev * point and the target may be bigger than 2^32. In this case we are using 116*51c0b2f7Stbbdev * indirect jump through memory where the offset to this memory location 117*51c0b2f7Stbbdev * is smaller than 2^32 and it contains the absolute address (8 bytes). 118*51c0b2f7Stbbdev * 119*51c0b2f7Stbbdev * This class is used to hold the pages used for the above trampolines. 120*51c0b2f7Stbbdev * Since this utility will be used to replace malloc functions this implementation 121*51c0b2f7Stbbdev * doesn't allocate memory dynamically. 122*51c0b2f7Stbbdev * 123*51c0b2f7Stbbdev * The struct MemoryBuffer holds the data about a page in the memory used for 124*51c0b2f7Stbbdev * replacing functions in 64-bit code where the target is too far to be replaced 125*51c0b2f7Stbbdev * with a short jump. All the calculations of m_base and m_next are in a multiple 126*51c0b2f7Stbbdev * of SIZE_OF_ADDRESS (which is 8 in Win64). 127*51c0b2f7Stbbdev */ 128*51c0b2f7Stbbdev class MemoryProvider { 129*51c0b2f7Stbbdev private: 130*51c0b2f7Stbbdev struct MemoryBuffer { 131*51c0b2f7Stbbdev UINT_PTR m_base; // base address of the buffer 132*51c0b2f7Stbbdev UINT_PTR m_next; // next free location in the buffer 133*51c0b2f7Stbbdev DWORD m_size; // size of buffer 134*51c0b2f7Stbbdev 135*51c0b2f7Stbbdev // Default constructor 136*51c0b2f7Stbbdev MemoryBuffer() : m_base(0), m_next(0), m_size(0) {} 137*51c0b2f7Stbbdev 138*51c0b2f7Stbbdev // Constructor 139*51c0b2f7Stbbdev MemoryBuffer(void *base, DWORD size) 140*51c0b2f7Stbbdev { 141*51c0b2f7Stbbdev m_base = Ptr2Addrint(base); 142*51c0b2f7Stbbdev m_next = m_base; 143*51c0b2f7Stbbdev m_size = size; 144*51c0b2f7Stbbdev } 145*51c0b2f7Stbbdev }; 146*51c0b2f7Stbbdev 147*51c0b2f7Stbbdev MemoryBuffer *CreateBuffer(UINT_PTR addr) 148*51c0b2f7Stbbdev { 149*51c0b2f7Stbbdev // No more room in the pages database 150*51c0b2f7Stbbdev if (m_lastBuffer - m_pages == MAX_NUM_BUFFERS) 151*51c0b2f7Stbbdev return 0; 152*51c0b2f7Stbbdev 153*51c0b2f7Stbbdev void *newAddr = Addrint2Ptr(addr); 154*51c0b2f7Stbbdev // Get information for the region which the given address belongs to 155*51c0b2f7Stbbdev MEMORY_BASIC_INFORMATION memInfo; 156*51c0b2f7Stbbdev if (VirtualQuery(newAddr, &memInfo, sizeof(memInfo)) != sizeof(memInfo)) 157*51c0b2f7Stbbdev return 0; 158*51c0b2f7Stbbdev 159*51c0b2f7Stbbdev for(;;) { 160*51c0b2f7Stbbdev // The new address to check is beyond the current region and aligned to allocation size 161*51c0b2f7Stbbdev newAddr = Addrint2Ptr( (Ptr2Addrint(memInfo.BaseAddress) + memInfo.RegionSize + m_allocSize) & ~(UINT_PTR)(m_allocSize-1) ); 162*51c0b2f7Stbbdev 163*51c0b2f7Stbbdev // Check that the address is in the right distance. 164*51c0b2f7Stbbdev // VirtualAlloc can only round the address down; so it will remain in the right distance 165*51c0b2f7Stbbdev if (!IsInDistance(addr, Ptr2Addrint(newAddr), MAX_DISTANCE)) 166*51c0b2f7Stbbdev break; 167*51c0b2f7Stbbdev 168*51c0b2f7Stbbdev if (VirtualQuery(newAddr, &memInfo, sizeof(memInfo)) != sizeof(memInfo)) 169*51c0b2f7Stbbdev break; 170*51c0b2f7Stbbdev 171*51c0b2f7Stbbdev if (memInfo.State == MEM_FREE && memInfo.RegionSize >= m_allocSize) 172*51c0b2f7Stbbdev { 173*51c0b2f7Stbbdev // Found a free region, try to allocate a page in this region 174*51c0b2f7Stbbdev void *newPage = VirtualAlloc(newAddr, m_allocSize, MEM_COMMIT|MEM_RESERVE, PAGE_READWRITE); 175*51c0b2f7Stbbdev if (!newPage) 176*51c0b2f7Stbbdev break; 177*51c0b2f7Stbbdev 178*51c0b2f7Stbbdev // Add the new page to the pages database 179*51c0b2f7Stbbdev MemoryBuffer *pBuff = new (m_lastBuffer) MemoryBuffer(newPage, m_allocSize); 180*51c0b2f7Stbbdev ++m_lastBuffer; 181*51c0b2f7Stbbdev return pBuff; 182*51c0b2f7Stbbdev } 183*51c0b2f7Stbbdev } 184*51c0b2f7Stbbdev 185*51c0b2f7Stbbdev // Failed to find a buffer in the distance 186*51c0b2f7Stbbdev return 0; 187*51c0b2f7Stbbdev } 188*51c0b2f7Stbbdev 189*51c0b2f7Stbbdev public: 190*51c0b2f7Stbbdev MemoryProvider() 191*51c0b2f7Stbbdev { 192*51c0b2f7Stbbdev SYSTEM_INFO sysInfo; 193*51c0b2f7Stbbdev GetSystemInfo(&sysInfo); 194*51c0b2f7Stbbdev m_allocSize = sysInfo.dwAllocationGranularity; 195*51c0b2f7Stbbdev m_lastBuffer = &m_pages[0]; 196*51c0b2f7Stbbdev } 197*51c0b2f7Stbbdev 198*51c0b2f7Stbbdev // We can't free the pages in the destructor because the trampolines 199*51c0b2f7Stbbdev // are using these memory locations and a replaced function might be called 200*51c0b2f7Stbbdev // after the destructor was called. 201*51c0b2f7Stbbdev ~MemoryProvider() 202*51c0b2f7Stbbdev { 203*51c0b2f7Stbbdev } 204*51c0b2f7Stbbdev 205*51c0b2f7Stbbdev // Return a memory location in distance less than 2^31 from input address 206*51c0b2f7Stbbdev UINT_PTR GetLocation(UINT_PTR addr) 207*51c0b2f7Stbbdev { 208*51c0b2f7Stbbdev MemoryBuffer *pBuff = m_pages; 209*51c0b2f7Stbbdev for (; pBuff<m_lastBuffer && IsInDistance(pBuff->m_next, addr, MAX_DISTANCE); ++pBuff) 210*51c0b2f7Stbbdev { 211*51c0b2f7Stbbdev if (pBuff->m_next < pBuff->m_base + pBuff->m_size) 212*51c0b2f7Stbbdev { 213*51c0b2f7Stbbdev UINT_PTR loc = pBuff->m_next; 214*51c0b2f7Stbbdev pBuff->m_next += MAX_PROBE_SIZE; 215*51c0b2f7Stbbdev return loc; 216*51c0b2f7Stbbdev } 217*51c0b2f7Stbbdev } 218*51c0b2f7Stbbdev 219*51c0b2f7Stbbdev pBuff = CreateBuffer(addr); 220*51c0b2f7Stbbdev if(!pBuff) 221*51c0b2f7Stbbdev return 0; 222*51c0b2f7Stbbdev 223*51c0b2f7Stbbdev UINT_PTR loc = pBuff->m_next; 224*51c0b2f7Stbbdev pBuff->m_next += MAX_PROBE_SIZE; 225*51c0b2f7Stbbdev return loc; 226*51c0b2f7Stbbdev } 227*51c0b2f7Stbbdev 228*51c0b2f7Stbbdev private: 229*51c0b2f7Stbbdev MemoryBuffer m_pages[MAX_NUM_BUFFERS]; 230*51c0b2f7Stbbdev MemoryBuffer *m_lastBuffer; 231*51c0b2f7Stbbdev DWORD m_allocSize; 232*51c0b2f7Stbbdev }; 233*51c0b2f7Stbbdev 234*51c0b2f7Stbbdev static MemoryProvider memProvider; 235*51c0b2f7Stbbdev 236*51c0b2f7Stbbdev // Compare opcodes from dictionary (str1) and opcodes from code (str2) 237*51c0b2f7Stbbdev // str1 might contain '*' to mask addresses 238*51c0b2f7Stbbdev // RETURN: 0 if opcodes did not match, 1 on success 239*51c0b2f7Stbbdev size_t compareStrings( const char *str1, const char *str2 ) 240*51c0b2f7Stbbdev { 241*51c0b2f7Stbbdev for (size_t i=0; str1[i]!=0; i++){ 242*51c0b2f7Stbbdev if( str1[i]!='*' && str1[i]!='#' && str1[i]!=str2[i] ) return 0; 243*51c0b2f7Stbbdev } 244*51c0b2f7Stbbdev return 1; 245*51c0b2f7Stbbdev } 246*51c0b2f7Stbbdev 247*51c0b2f7Stbbdev // Check function prologue with known prologues from the dictionary 248*51c0b2f7Stbbdev // opcodes - dictionary 249*51c0b2f7Stbbdev // inpAddr - pointer to function prologue 250*51c0b2f7Stbbdev // Dictionary contains opcodes for several full asm instructions 251*51c0b2f7Stbbdev // + one opcode byte for the next asm instruction for safe address processing 252*51c0b2f7Stbbdev // RETURN: 1 + the index of the matched pattern, or 0 if no match found. 253*51c0b2f7Stbbdev static UINT CheckOpcodes( const char ** opcodes, void *inpAddr, bool abortOnError, const FunctionInfo* functionInfo = NULL) 254*51c0b2f7Stbbdev { 255*51c0b2f7Stbbdev static size_t opcodesStringsCount = 0; 256*51c0b2f7Stbbdev static size_t maxOpcodesLength = 0; 257*51c0b2f7Stbbdev static size_t opcodes_pointer = (size_t)opcodes; 258*51c0b2f7Stbbdev char opcodeString[2*MAX_PATTERN_SIZE+1]; 259*51c0b2f7Stbbdev size_t i; 260*51c0b2f7Stbbdev size_t result = 0; 261*51c0b2f7Stbbdev 262*51c0b2f7Stbbdev // Get the values for static variables 263*51c0b2f7Stbbdev // max length and number of patterns 264*51c0b2f7Stbbdev if( !opcodesStringsCount || opcodes_pointer != (size_t)opcodes ){ 265*51c0b2f7Stbbdev while( *(opcodes + opcodesStringsCount)!= NULL ){ 266*51c0b2f7Stbbdev if( (i=strlen(*(opcodes + opcodesStringsCount))) > maxOpcodesLength ) 267*51c0b2f7Stbbdev maxOpcodesLength = i; 268*51c0b2f7Stbbdev opcodesStringsCount++; 269*51c0b2f7Stbbdev } 270*51c0b2f7Stbbdev opcodes_pointer = (size_t)opcodes; 271*51c0b2f7Stbbdev __TBB_ASSERT( maxOpcodesLength/2 <= MAX_PATTERN_SIZE, "Pattern exceeded the limit of 28 opcodes/56 symbols" ); 272*51c0b2f7Stbbdev } 273*51c0b2f7Stbbdev 274*51c0b2f7Stbbdev // Translate prologue opcodes to string format to compare 275*51c0b2f7Stbbdev for( i=0; i<maxOpcodesLength/2 && i<MAX_PATTERN_SIZE; ++i ){ 276*51c0b2f7Stbbdev sprintf( opcodeString + 2*i, "%.2X", *((unsigned char*)inpAddr+i) ); 277*51c0b2f7Stbbdev } 278*51c0b2f7Stbbdev opcodeString[2*i] = 0; 279*51c0b2f7Stbbdev 280*51c0b2f7Stbbdev // Compare translated opcodes with patterns 281*51c0b2f7Stbbdev for( UINT idx=0; idx<opcodesStringsCount; ++idx ){ 282*51c0b2f7Stbbdev result = compareStrings( opcodes[idx],opcodeString ); 283*51c0b2f7Stbbdev if( result ) { 284*51c0b2f7Stbbdev if (functionInfo) { 285*51c0b2f7Stbbdev Log::record(*functionInfo, opcodeString, /*status*/ true); 286*51c0b2f7Stbbdev } 287*51c0b2f7Stbbdev return idx + 1; // avoid 0 which indicates a failure 288*51c0b2f7Stbbdev } 289*51c0b2f7Stbbdev } 290*51c0b2f7Stbbdev if (functionInfo) { 291*51c0b2f7Stbbdev Log::record(*functionInfo, opcodeString, /*status*/ false); 292*51c0b2f7Stbbdev } 293*51c0b2f7Stbbdev if (abortOnError) { 294*51c0b2f7Stbbdev // Impossibility to find opcodes in the dictionary is a serious issue, 295*51c0b2f7Stbbdev // as if we unable to call original function, leak or crash is expected result. 296*51c0b2f7Stbbdev __TBB_ASSERT_RELEASE( false, "CheckOpcodes failed" ); 297*51c0b2f7Stbbdev } 298*51c0b2f7Stbbdev return 0; 299*51c0b2f7Stbbdev } 300*51c0b2f7Stbbdev 301*51c0b2f7Stbbdev // Modify offsets in original code after moving it to a trampoline. 302*51c0b2f7Stbbdev // We do not have more than one offset to correct in existing opcode patterns. 303*51c0b2f7Stbbdev static void CorrectOffset( UINT_PTR address, const char* pattern, UINT distance ) 304*51c0b2f7Stbbdev { 305*51c0b2f7Stbbdev const char* pos = strstr(pattern, "#*******"); 306*51c0b2f7Stbbdev if( pos ) { 307*51c0b2f7Stbbdev address += (pos - pattern)/2; // compute the offset position 308*51c0b2f7Stbbdev UINT value; 309*51c0b2f7Stbbdev // UINT assignment is not used to avoid potential alignment issues 310*51c0b2f7Stbbdev memcpy(&value, Addrint2Ptr(address), sizeof(value)); 311*51c0b2f7Stbbdev value += distance; 312*51c0b2f7Stbbdev memcpy(Addrint2Ptr(address), &value, sizeof(value)); 313*51c0b2f7Stbbdev } 314*51c0b2f7Stbbdev } 315*51c0b2f7Stbbdev 316*51c0b2f7Stbbdev // Insert jump relative instruction to the input address 317*51c0b2f7Stbbdev // RETURN: the size of the trampoline or 0 on failure 318*51c0b2f7Stbbdev static DWORD InsertTrampoline32(void *inpAddr, void *targetAddr, const char* pattern, void** storedAddr) 319*51c0b2f7Stbbdev { 320*51c0b2f7Stbbdev size_t bytesToMove = SIZE_OF_RELJUMP; 321*51c0b2f7Stbbdev UINT_PTR srcAddr = Ptr2Addrint(inpAddr); 322*51c0b2f7Stbbdev UINT_PTR tgtAddr = Ptr2Addrint(targetAddr); 323*51c0b2f7Stbbdev // Check that the target fits in 32 bits 324*51c0b2f7Stbbdev if (!IsInDistance(srcAddr, tgtAddr, MAX_DISTANCE)) 325*51c0b2f7Stbbdev return 0; 326*51c0b2f7Stbbdev 327*51c0b2f7Stbbdev UINT_PTR offset; 328*51c0b2f7Stbbdev UINT offset32; 329*51c0b2f7Stbbdev UCHAR *codePtr = (UCHAR *)inpAddr; 330*51c0b2f7Stbbdev 331*51c0b2f7Stbbdev if ( storedAddr ){ // If requested, store original function code 332*51c0b2f7Stbbdev bytesToMove = strlen(pattern)/2-1; // The last byte matching the pattern must not be copied 333*51c0b2f7Stbbdev __TBB_ASSERT_RELEASE( bytesToMove >= SIZE_OF_RELJUMP, "Incorrect bytecode pattern?" ); 334*51c0b2f7Stbbdev UINT_PTR trampAddr = memProvider.GetLocation(srcAddr); 335*51c0b2f7Stbbdev if (!trampAddr) 336*51c0b2f7Stbbdev return 0; 337*51c0b2f7Stbbdev *storedAddr = Addrint2Ptr(trampAddr); 338*51c0b2f7Stbbdev // Set 'executable' flag for original instructions in the new place 339*51c0b2f7Stbbdev DWORD pageFlags = PAGE_EXECUTE_READWRITE; 340*51c0b2f7Stbbdev if (!VirtualProtect(*storedAddr, MAX_PROBE_SIZE, pageFlags, &pageFlags)) return 0; 341*51c0b2f7Stbbdev // Copy original instructions to the new place 342*51c0b2f7Stbbdev memcpy(*storedAddr, codePtr, bytesToMove); 343*51c0b2f7Stbbdev offset = srcAddr - trampAddr; 344*51c0b2f7Stbbdev offset32 = (UINT)(offset & 0xFFFFFFFF); 345*51c0b2f7Stbbdev CorrectOffset( trampAddr, pattern, offset32 ); 346*51c0b2f7Stbbdev // Set jump to the code after replacement 347*51c0b2f7Stbbdev offset32 -= SIZE_OF_RELJUMP; 348*51c0b2f7Stbbdev *(UCHAR*)(trampAddr+bytesToMove) = 0xE9; 349*51c0b2f7Stbbdev memcpy((UCHAR*)(trampAddr+bytesToMove+1), &offset32, sizeof(offset32)); 350*51c0b2f7Stbbdev } 351*51c0b2f7Stbbdev 352*51c0b2f7Stbbdev // The following will work correctly even if srcAddr>tgtAddr, as long as 353*51c0b2f7Stbbdev // address difference is less than 2^31, which is guaranteed by IsInDistance. 354*51c0b2f7Stbbdev offset = tgtAddr - srcAddr - SIZE_OF_RELJUMP; 355*51c0b2f7Stbbdev offset32 = (UINT)(offset & 0xFFFFFFFF); 356*51c0b2f7Stbbdev // Insert the jump to the new code 357*51c0b2f7Stbbdev *codePtr = 0xE9; 358*51c0b2f7Stbbdev memcpy(codePtr+1, &offset32, sizeof(offset32)); 359*51c0b2f7Stbbdev 360*51c0b2f7Stbbdev // Fill the rest with NOPs to correctly see disassembler of old code in debugger. 361*51c0b2f7Stbbdev for( unsigned i=SIZE_OF_RELJUMP; i<bytesToMove; i++ ){ 362*51c0b2f7Stbbdev *(codePtr+i) = 0x90; 363*51c0b2f7Stbbdev } 364*51c0b2f7Stbbdev 365*51c0b2f7Stbbdev return SIZE_OF_RELJUMP; 366*51c0b2f7Stbbdev } 367*51c0b2f7Stbbdev 368*51c0b2f7Stbbdev // This function is called when the offset doesn't fit in 32 bits 369*51c0b2f7Stbbdev // 1 Find and allocate a page in the small distance (<2^31) from input address 370*51c0b2f7Stbbdev // 2 Put jump RIP relative indirect through the address in the close page 371*51c0b2f7Stbbdev // 3 Put the absolute address of the target in the allocated location 372*51c0b2f7Stbbdev // RETURN: the size of the trampoline or 0 on failure 373*51c0b2f7Stbbdev static DWORD InsertTrampoline64(void *inpAddr, void *targetAddr, const char* pattern, void** storedAddr) 374*51c0b2f7Stbbdev { 375*51c0b2f7Stbbdev size_t bytesToMove = SIZE_OF_INDJUMP; 376*51c0b2f7Stbbdev 377*51c0b2f7Stbbdev UINT_PTR srcAddr = Ptr2Addrint(inpAddr); 378*51c0b2f7Stbbdev UINT_PTR tgtAddr = Ptr2Addrint(targetAddr); 379*51c0b2f7Stbbdev 380*51c0b2f7Stbbdev // Get a location close to the source address 381*51c0b2f7Stbbdev UINT_PTR location = memProvider.GetLocation(srcAddr); 382*51c0b2f7Stbbdev if (!location) 383*51c0b2f7Stbbdev return 0; 384*51c0b2f7Stbbdev 385*51c0b2f7Stbbdev UINT_PTR offset; 386*51c0b2f7Stbbdev UINT offset32; 387*51c0b2f7Stbbdev UCHAR *codePtr = (UCHAR *)inpAddr; 388*51c0b2f7Stbbdev 389*51c0b2f7Stbbdev // Fill the location 390*51c0b2f7Stbbdev UINT_PTR *locPtr = (UINT_PTR *)Addrint2Ptr(location); 391*51c0b2f7Stbbdev *locPtr = tgtAddr; 392*51c0b2f7Stbbdev 393*51c0b2f7Stbbdev if ( storedAddr ){ // If requested, store original function code 394*51c0b2f7Stbbdev bytesToMove = strlen(pattern)/2-1; // The last byte matching the pattern must not be copied 395*51c0b2f7Stbbdev __TBB_ASSERT_RELEASE( bytesToMove >= SIZE_OF_INDJUMP, "Incorrect bytecode pattern?" ); 396*51c0b2f7Stbbdev UINT_PTR trampAddr = memProvider.GetLocation(srcAddr); 397*51c0b2f7Stbbdev if (!trampAddr) 398*51c0b2f7Stbbdev return 0; 399*51c0b2f7Stbbdev *storedAddr = Addrint2Ptr(trampAddr); 400*51c0b2f7Stbbdev // Set 'executable' flag for original instructions in the new place 401*51c0b2f7Stbbdev DWORD pageFlags = PAGE_EXECUTE_READWRITE; 402*51c0b2f7Stbbdev if (!VirtualProtect(*storedAddr, MAX_PROBE_SIZE, pageFlags, &pageFlags)) return 0; 403*51c0b2f7Stbbdev // Copy original instructions to the new place 404*51c0b2f7Stbbdev memcpy(*storedAddr, codePtr, bytesToMove); 405*51c0b2f7Stbbdev offset = srcAddr - trampAddr; 406*51c0b2f7Stbbdev offset32 = (UINT)(offset & 0xFFFFFFFF); 407*51c0b2f7Stbbdev CorrectOffset( trampAddr, pattern, offset32 ); 408*51c0b2f7Stbbdev // Set jump to the code after replacement. It is within the distance of relative jump! 409*51c0b2f7Stbbdev offset32 -= SIZE_OF_RELJUMP; 410*51c0b2f7Stbbdev *(UCHAR*)(trampAddr+bytesToMove) = 0xE9; 411*51c0b2f7Stbbdev memcpy((UCHAR*)(trampAddr+bytesToMove+1), &offset32, sizeof(offset32)); 412*51c0b2f7Stbbdev } 413*51c0b2f7Stbbdev 414*51c0b2f7Stbbdev // Fill the buffer 415*51c0b2f7Stbbdev offset = location - srcAddr - SIZE_OF_INDJUMP; 416*51c0b2f7Stbbdev offset32 = (UINT)(offset & 0xFFFFFFFF); 417*51c0b2f7Stbbdev *(codePtr) = 0xFF; 418*51c0b2f7Stbbdev *(codePtr+1) = 0x25; 419*51c0b2f7Stbbdev memcpy(codePtr+2, &offset32, sizeof(offset32)); 420*51c0b2f7Stbbdev 421*51c0b2f7Stbbdev // Fill the rest with NOPs to correctly see disassembler of old code in debugger. 422*51c0b2f7Stbbdev for( unsigned i=SIZE_OF_INDJUMP; i<bytesToMove; i++ ){ 423*51c0b2f7Stbbdev *(codePtr+i) = 0x90; 424*51c0b2f7Stbbdev } 425*51c0b2f7Stbbdev 426*51c0b2f7Stbbdev return SIZE_OF_INDJUMP; 427*51c0b2f7Stbbdev } 428*51c0b2f7Stbbdev 429*51c0b2f7Stbbdev // Insert a jump instruction in the inpAddr to the targetAddr 430*51c0b2f7Stbbdev // 1. Get the memory protection of the page containing the input address 431*51c0b2f7Stbbdev // 2. Change the memory protection to writable 432*51c0b2f7Stbbdev // 3. Call InsertTrampoline32 or InsertTrampoline64 433*51c0b2f7Stbbdev // 4. Restore memory protection 434*51c0b2f7Stbbdev // RETURN: FALSE on failure, TRUE on success 435*51c0b2f7Stbbdev static bool InsertTrampoline(void *inpAddr, void *targetAddr, const char ** opcodes, void** origFunc) 436*51c0b2f7Stbbdev { 437*51c0b2f7Stbbdev DWORD probeSize; 438*51c0b2f7Stbbdev // Change page protection to EXECUTE+WRITE 439*51c0b2f7Stbbdev DWORD origProt = 0; 440*51c0b2f7Stbbdev if (!VirtualProtect(inpAddr, MAX_PROBE_SIZE, PAGE_EXECUTE_WRITECOPY, &origProt)) 441*51c0b2f7Stbbdev return FALSE; 442*51c0b2f7Stbbdev 443*51c0b2f7Stbbdev const char* pattern = NULL; 444*51c0b2f7Stbbdev if ( origFunc ){ // Need to store original function code 445*51c0b2f7Stbbdev UCHAR * const codePtr = (UCHAR *)inpAddr; 446*51c0b2f7Stbbdev if ( *codePtr == 0xE9 ){ // JMP relative instruction 447*51c0b2f7Stbbdev // For the special case when a system function consists of a single near jump, 448*51c0b2f7Stbbdev // instead of moving it somewhere we use the target of the jump as the original function. 449*51c0b2f7Stbbdev unsigned offsetInJmp = *(unsigned*)(codePtr + 1); 450*51c0b2f7Stbbdev *origFunc = (void*)(Ptr2Addrint(inpAddr) + offsetInJmp + SIZE_OF_RELJUMP); 451*51c0b2f7Stbbdev origFunc = NULL; // now it must be ignored by InsertTrampoline32/64 452*51c0b2f7Stbbdev } else { 453*51c0b2f7Stbbdev // find the right opcode pattern 454*51c0b2f7Stbbdev UINT opcodeIdx = CheckOpcodes( opcodes, inpAddr, /*abortOnError=*/true ); 455*51c0b2f7Stbbdev __TBB_ASSERT( opcodeIdx > 0, "abortOnError ignored in CheckOpcodes?" ); 456*51c0b2f7Stbbdev pattern = opcodes[opcodeIdx-1]; // -1 compensates for +1 in CheckOpcodes 457*51c0b2f7Stbbdev } 458*51c0b2f7Stbbdev } 459*51c0b2f7Stbbdev 460*51c0b2f7Stbbdev probeSize = InsertTrampoline32(inpAddr, targetAddr, pattern, origFunc); 461*51c0b2f7Stbbdev if (!probeSize) 462*51c0b2f7Stbbdev probeSize = InsertTrampoline64(inpAddr, targetAddr, pattern, origFunc); 463*51c0b2f7Stbbdev 464*51c0b2f7Stbbdev // Restore original protection 465*51c0b2f7Stbbdev VirtualProtect(inpAddr, MAX_PROBE_SIZE, origProt, &origProt); 466*51c0b2f7Stbbdev 467*51c0b2f7Stbbdev if (!probeSize) 468*51c0b2f7Stbbdev return FALSE; 469*51c0b2f7Stbbdev 470*51c0b2f7Stbbdev FlushInstructionCache(GetCurrentProcess(), inpAddr, probeSize); 471*51c0b2f7Stbbdev FlushInstructionCache(GetCurrentProcess(), origFunc, probeSize); 472*51c0b2f7Stbbdev 473*51c0b2f7Stbbdev return TRUE; 474*51c0b2f7Stbbdev } 475*51c0b2f7Stbbdev 476*51c0b2f7Stbbdev // Routine to replace the functions 477*51c0b2f7Stbbdev // TODO: replace opcodesNumber with opcodes and opcodes number to check if we replace right code. 478*51c0b2f7Stbbdev FRR_TYPE ReplaceFunctionA(const char *dllName, const char *funcName, FUNCPTR newFunc, const char ** opcodes, FUNCPTR* origFunc) 479*51c0b2f7Stbbdev { 480*51c0b2f7Stbbdev // Cache the results of the last search for the module 481*51c0b2f7Stbbdev // Assume that there was no DLL unload between 482*51c0b2f7Stbbdev static char cachedName[MAX_PATH+1]; 483*51c0b2f7Stbbdev static HMODULE cachedHM = 0; 484*51c0b2f7Stbbdev 485*51c0b2f7Stbbdev if (!dllName || !*dllName) 486*51c0b2f7Stbbdev return FRR_NODLL; 487*51c0b2f7Stbbdev 488*51c0b2f7Stbbdev if (!cachedHM || strncmp(dllName, cachedName, MAX_PATH) != 0) 489*51c0b2f7Stbbdev { 490*51c0b2f7Stbbdev // Find the module handle for the input dll 491*51c0b2f7Stbbdev HMODULE hModule = GetModuleHandleA(dllName); 492*51c0b2f7Stbbdev if (hModule == 0) 493*51c0b2f7Stbbdev { 494*51c0b2f7Stbbdev // Couldn't find the module with the input name 495*51c0b2f7Stbbdev cachedHM = 0; 496*51c0b2f7Stbbdev return FRR_NODLL; 497*51c0b2f7Stbbdev } 498*51c0b2f7Stbbdev 499*51c0b2f7Stbbdev cachedHM = hModule; 500*51c0b2f7Stbbdev strncpy(cachedName, dllName, MAX_PATH); 501*51c0b2f7Stbbdev } 502*51c0b2f7Stbbdev 503*51c0b2f7Stbbdev FARPROC inpFunc = GetProcAddress(cachedHM, funcName); 504*51c0b2f7Stbbdev if (inpFunc == 0) 505*51c0b2f7Stbbdev { 506*51c0b2f7Stbbdev // Function was not found 507*51c0b2f7Stbbdev return FRR_NOFUNC; 508*51c0b2f7Stbbdev } 509*51c0b2f7Stbbdev 510*51c0b2f7Stbbdev if (!InsertTrampoline((void*)inpFunc, (void*)newFunc, opcodes, (void**)origFunc)){ 511*51c0b2f7Stbbdev // Failed to insert the trampoline to the target address 512*51c0b2f7Stbbdev return FRR_FAILED; 513*51c0b2f7Stbbdev } 514*51c0b2f7Stbbdev 515*51c0b2f7Stbbdev return FRR_OK; 516*51c0b2f7Stbbdev } 517*51c0b2f7Stbbdev 518*51c0b2f7Stbbdev FRR_TYPE ReplaceFunctionW(const wchar_t *dllName, const char *funcName, FUNCPTR newFunc, const char ** opcodes, FUNCPTR* origFunc) 519*51c0b2f7Stbbdev { 520*51c0b2f7Stbbdev // Cache the results of the last search for the module 521*51c0b2f7Stbbdev // Assume that there was no DLL unload between 522*51c0b2f7Stbbdev static wchar_t cachedName[MAX_PATH+1]; 523*51c0b2f7Stbbdev static HMODULE cachedHM = 0; 524*51c0b2f7Stbbdev 525*51c0b2f7Stbbdev if (!dllName || !*dllName) 526*51c0b2f7Stbbdev return FRR_NODLL; 527*51c0b2f7Stbbdev 528*51c0b2f7Stbbdev if (!cachedHM || wcsncmp(dllName, cachedName, MAX_PATH) != 0) 529*51c0b2f7Stbbdev { 530*51c0b2f7Stbbdev // Find the module handle for the input dll 531*51c0b2f7Stbbdev HMODULE hModule = GetModuleHandleW(dllName); 532*51c0b2f7Stbbdev if (hModule == 0) 533*51c0b2f7Stbbdev { 534*51c0b2f7Stbbdev // Couldn't find the module with the input name 535*51c0b2f7Stbbdev cachedHM = 0; 536*51c0b2f7Stbbdev return FRR_NODLL; 537*51c0b2f7Stbbdev } 538*51c0b2f7Stbbdev 539*51c0b2f7Stbbdev cachedHM = hModule; 540*51c0b2f7Stbbdev wcsncpy(cachedName, dllName, MAX_PATH); 541*51c0b2f7Stbbdev } 542*51c0b2f7Stbbdev 543*51c0b2f7Stbbdev FARPROC inpFunc = GetProcAddress(cachedHM, funcName); 544*51c0b2f7Stbbdev if (inpFunc == 0) 545*51c0b2f7Stbbdev { 546*51c0b2f7Stbbdev // Function was not found 547*51c0b2f7Stbbdev return FRR_NOFUNC; 548*51c0b2f7Stbbdev } 549*51c0b2f7Stbbdev 550*51c0b2f7Stbbdev if (!InsertTrampoline((void*)inpFunc, (void*)newFunc, opcodes, (void**)origFunc)){ 551*51c0b2f7Stbbdev // Failed to insert the trampoline to the target address 552*51c0b2f7Stbbdev return FRR_FAILED; 553*51c0b2f7Stbbdev } 554*51c0b2f7Stbbdev 555*51c0b2f7Stbbdev return FRR_OK; 556*51c0b2f7Stbbdev } 557*51c0b2f7Stbbdev 558*51c0b2f7Stbbdev bool IsPrologueKnown(const char* dllName, const char *funcName, const char **opcodes, HMODULE module) 559*51c0b2f7Stbbdev { 560*51c0b2f7Stbbdev FARPROC inpFunc = GetProcAddress(module, funcName); 561*51c0b2f7Stbbdev FunctionInfo functionInfo = { funcName, dllName }; 562*51c0b2f7Stbbdev 563*51c0b2f7Stbbdev if (!inpFunc) { 564*51c0b2f7Stbbdev Log::record(functionInfo, "unknown", /*status*/ false); 565*51c0b2f7Stbbdev return false; 566*51c0b2f7Stbbdev } 567*51c0b2f7Stbbdev 568*51c0b2f7Stbbdev return CheckOpcodes( opcodes, (void*)inpFunc, /*abortOnError=*/false, &functionInfo) != 0; 569*51c0b2f7Stbbdev } 570*51c0b2f7Stbbdev 571*51c0b2f7Stbbdev // Public Windows API 572*51c0b2f7Stbbdev extern "C" __declspec(dllexport) int TBB_malloc_replacement_log(char *** function_replacement_log_ptr) 573*51c0b2f7Stbbdev { 574*51c0b2f7Stbbdev if (function_replacement_log_ptr != NULL) { 575*51c0b2f7Stbbdev *function_replacement_log_ptr = Log::records; 576*51c0b2f7Stbbdev } 577*51c0b2f7Stbbdev 578*51c0b2f7Stbbdev // If we have no logs -> return false status 579*51c0b2f7Stbbdev return Log::replacement_status && Log::records[0] != NULL ? 0 : -1; 580*51c0b2f7Stbbdev } 581*51c0b2f7Stbbdev 582*51c0b2f7Stbbdev #endif /* !__TBB_WIN8UI_SUPPORT && defined(_WIN32) */ 583