1*51c0b2f7Stbbdev /*
2*51c0b2f7Stbbdev     Copyright (c) 2005-2020 Intel Corporation
3*51c0b2f7Stbbdev 
4*51c0b2f7Stbbdev     Licensed under the Apache License, Version 2.0 (the "License");
5*51c0b2f7Stbbdev     you may not use this file except in compliance with the License.
6*51c0b2f7Stbbdev     You may obtain a copy of the License at
7*51c0b2f7Stbbdev 
8*51c0b2f7Stbbdev         http://www.apache.org/licenses/LICENSE-2.0
9*51c0b2f7Stbbdev 
10*51c0b2f7Stbbdev     Unless required by applicable law or agreed to in writing, software
11*51c0b2f7Stbbdev     distributed under the License is distributed on an "AS IS" BASIS,
12*51c0b2f7Stbbdev     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*51c0b2f7Stbbdev     See the License for the specific language governing permissions and
14*51c0b2f7Stbbdev     limitations under the License.
15*51c0b2f7Stbbdev */
16*51c0b2f7Stbbdev 
17*51c0b2f7Stbbdev #include "tbb/detail/_config.h"
18*51c0b2f7Stbbdev #include "tbb/detail/_assert.h"
19*51c0b2f7Stbbdev #include "../tbb/assert_impl.h"
20*51c0b2f7Stbbdev 
21*51c0b2f7Stbbdev #if !__TBB_WIN8UI_SUPPORT && defined(_WIN32)
22*51c0b2f7Stbbdev 
23*51c0b2f7Stbbdev #ifndef _CRT_SECURE_NO_DEPRECATE
24*51c0b2f7Stbbdev #define _CRT_SECURE_NO_DEPRECATE 1
25*51c0b2f7Stbbdev #endif
26*51c0b2f7Stbbdev 
27*51c0b2f7Stbbdev // no standard-conforming implementation of snprintf prior to VS 2015
28*51c0b2f7Stbbdev #if !defined(_MSC_VER) || _MSC_VER>=1900
29*51c0b2f7Stbbdev #define LOG_PRINT(s, n, format, ...) snprintf(s, n, format, __VA_ARGS__)
30*51c0b2f7Stbbdev #else
31*51c0b2f7Stbbdev #define LOG_PRINT(s, n, format, ...) _snprintf_s(s, n, _TRUNCATE, format, __VA_ARGS__)
32*51c0b2f7Stbbdev #endif
33*51c0b2f7Stbbdev 
34*51c0b2f7Stbbdev #include <windows.h>
35*51c0b2f7Stbbdev #include <new>
36*51c0b2f7Stbbdev #include <stdio.h>
37*51c0b2f7Stbbdev #include <string.h>
38*51c0b2f7Stbbdev 
39*51c0b2f7Stbbdev #include "function_replacement.h"
40*51c0b2f7Stbbdev 
// Identifies a standard memory allocation function in the replacement log.
// The members are non-owning pointers; nothing in this struct copies or frees them.
struct FunctionInfo {
    const char *funcName;   // name of the function
    const char *dllName;    // name of the DLL the function is looked up in
};
46*51c0b2f7Stbbdev 
47*51c0b2f7Stbbdev // Namespace that processes and manages the output of records to the Log journal
48*51c0b2f7Stbbdev // that will be provided to user by TBB_malloc_replacement_log()
49*51c0b2f7Stbbdev namespace Log {
50*51c0b2f7Stbbdev     // Value of RECORDS_COUNT is set due to the fact that we maximally
51*51c0b2f7Stbbdev     // scan 8 modules, and in every module we can swap 6 opcodes. (rounded to 8)
52*51c0b2f7Stbbdev     static const unsigned RECORDS_COUNT = 8 * 8;
53*51c0b2f7Stbbdev     static const unsigned RECORD_LENGTH = MAX_PATH;
54*51c0b2f7Stbbdev 
55*51c0b2f7Stbbdev     // Need to add 1 to count of records, because last record must be always NULL
56*51c0b2f7Stbbdev     static char *records[RECORDS_COUNT + 1];
57*51c0b2f7Stbbdev     static bool replacement_status = true;
58*51c0b2f7Stbbdev 
59*51c0b2f7Stbbdev     // Internal counter that contains number of next string for record
60*51c0b2f7Stbbdev     static unsigned record_number = 0;
61*51c0b2f7Stbbdev 
62*51c0b2f7Stbbdev     // Function that writes info about (not)found opcodes to the Log journal
63*51c0b2f7Stbbdev     // functionInfo - information about a standard memory allocation function for the replacement log
64*51c0b2f7Stbbdev     // opcodeString - string, that contain byte code of this function
65*51c0b2f7Stbbdev     // status - information about function replacement status
66*51c0b2f7Stbbdev     static void record(FunctionInfo functionInfo, const char * opcodeString, bool status) {
67*51c0b2f7Stbbdev         __TBB_ASSERT(functionInfo.dllName, "Empty DLL name value");
68*51c0b2f7Stbbdev         __TBB_ASSERT(functionInfo.funcName, "Empty function name value");
69*51c0b2f7Stbbdev         __TBB_ASSERT(opcodeString, "Empty opcode");
70*51c0b2f7Stbbdev         __TBB_ASSERT(record_number <= RECORDS_COUNT, "Incorrect record number");
71*51c0b2f7Stbbdev 
72*51c0b2f7Stbbdev         //If some replacement failed -> set status to false
73*51c0b2f7Stbbdev         replacement_status &= status;
74*51c0b2f7Stbbdev 
75*51c0b2f7Stbbdev         // If we reach the end of the log, write this message to the last line
76*51c0b2f7Stbbdev         if (record_number == RECORDS_COUNT) {
77*51c0b2f7Stbbdev             // %s - workaround to fix empty variable argument parsing behavior in GCC
78*51c0b2f7Stbbdev             LOG_PRINT(records[RECORDS_COUNT - 1], RECORD_LENGTH, "%s", "Log was truncated.");
79*51c0b2f7Stbbdev             return;
80*51c0b2f7Stbbdev         }
81*51c0b2f7Stbbdev 
82*51c0b2f7Stbbdev         char* entry = (char*)HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, RECORD_LENGTH);
83*51c0b2f7Stbbdev         __TBB_ASSERT(entry, "Invalid memory was returned");
84*51c0b2f7Stbbdev 
85*51c0b2f7Stbbdev         LOG_PRINT(entry, RECORD_LENGTH, "%s: %s (%s), byte pattern: <%s>",
86*51c0b2f7Stbbdev             status ? "Success" : "Fail", functionInfo.funcName, functionInfo.dllName, opcodeString);
87*51c0b2f7Stbbdev 
88*51c0b2f7Stbbdev         records[record_number++] = entry;
89*51c0b2f7Stbbdev     }
90*51c0b2f7Stbbdev };
91*51c0b2f7Stbbdev 
92*51c0b2f7Stbbdev inline UINT_PTR Ptr2Addrint(LPVOID ptr)
93*51c0b2f7Stbbdev {
94*51c0b2f7Stbbdev     Int2Ptr i2p;
95*51c0b2f7Stbbdev     i2p.lpv = ptr;
96*51c0b2f7Stbbdev     return i2p.uip;
97*51c0b2f7Stbbdev }
98*51c0b2f7Stbbdev 
99*51c0b2f7Stbbdev inline LPVOID Addrint2Ptr(UINT_PTR ptr)
100*51c0b2f7Stbbdev {
101*51c0b2f7Stbbdev     Int2Ptr i2p;
102*51c0b2f7Stbbdev     i2p.uip = ptr;
103*51c0b2f7Stbbdev     return i2p.lpv;
104*51c0b2f7Stbbdev }
105*51c0b2f7Stbbdev 
106*51c0b2f7Stbbdev // Is the distance between addr1 and addr2 smaller than dist
107*51c0b2f7Stbbdev inline bool IsInDistance(UINT_PTR addr1, UINT_PTR addr2, __int64 dist)
108*51c0b2f7Stbbdev {
109*51c0b2f7Stbbdev     __int64 diff = addr1>addr2 ? addr1-addr2 : addr2-addr1;
110*51c0b2f7Stbbdev     return diff<dist;
111*51c0b2f7Stbbdev }
112*51c0b2f7Stbbdev 
113*51c0b2f7Stbbdev /*
114*51c0b2f7Stbbdev  * When inserting a probe in 64 bits process the distance between the insertion
115*51c0b2f7Stbbdev  * point and the target may be bigger than 2^32. In this case we are using
116*51c0b2f7Stbbdev  * indirect jump through memory where the offset to this memory location
117*51c0b2f7Stbbdev  * is smaller than 2^32 and it contains the absolute address (8 bytes).
118*51c0b2f7Stbbdev  *
119*51c0b2f7Stbbdev  * This class is used to hold the pages used for the above trampolines.
120*51c0b2f7Stbbdev  * Since this utility will be used to replace malloc functions this implementation
121*51c0b2f7Stbbdev  * doesn't allocate memory dynamically.
122*51c0b2f7Stbbdev  *
123*51c0b2f7Stbbdev  * The struct MemoryBuffer holds the data about a page in the memory used for
124*51c0b2f7Stbbdev  * replacing functions in 64-bit code where the target is too far to be replaced
125*51c0b2f7Stbbdev  * with a short jump. All the calculations of m_base and m_next are in a multiple
126*51c0b2f7Stbbdev  * of SIZE_OF_ADDRESS (which is 8 in Win64).
127*51c0b2f7Stbbdev  */
128*51c0b2f7Stbbdev class MemoryProvider {
129*51c0b2f7Stbbdev private:
130*51c0b2f7Stbbdev     struct MemoryBuffer {
131*51c0b2f7Stbbdev         UINT_PTR m_base;    // base address of the buffer
132*51c0b2f7Stbbdev         UINT_PTR m_next;    // next free location in the buffer
133*51c0b2f7Stbbdev         DWORD    m_size;    // size of buffer
134*51c0b2f7Stbbdev 
135*51c0b2f7Stbbdev         // Default constructor
136*51c0b2f7Stbbdev         MemoryBuffer() : m_base(0), m_next(0), m_size(0) {}
137*51c0b2f7Stbbdev 
138*51c0b2f7Stbbdev         // Constructor
139*51c0b2f7Stbbdev         MemoryBuffer(void *base, DWORD size)
140*51c0b2f7Stbbdev         {
141*51c0b2f7Stbbdev             m_base = Ptr2Addrint(base);
142*51c0b2f7Stbbdev             m_next = m_base;
143*51c0b2f7Stbbdev             m_size = size;
144*51c0b2f7Stbbdev         }
145*51c0b2f7Stbbdev     };
146*51c0b2f7Stbbdev 
147*51c0b2f7Stbbdev MemoryBuffer *CreateBuffer(UINT_PTR addr)
148*51c0b2f7Stbbdev     {
149*51c0b2f7Stbbdev         // No more room in the pages database
150*51c0b2f7Stbbdev         if (m_lastBuffer - m_pages == MAX_NUM_BUFFERS)
151*51c0b2f7Stbbdev             return 0;
152*51c0b2f7Stbbdev 
153*51c0b2f7Stbbdev         void *newAddr = Addrint2Ptr(addr);
154*51c0b2f7Stbbdev         // Get information for the region which the given address belongs to
155*51c0b2f7Stbbdev         MEMORY_BASIC_INFORMATION memInfo;
156*51c0b2f7Stbbdev         if (VirtualQuery(newAddr, &memInfo, sizeof(memInfo)) != sizeof(memInfo))
157*51c0b2f7Stbbdev             return 0;
158*51c0b2f7Stbbdev 
159*51c0b2f7Stbbdev         for(;;) {
160*51c0b2f7Stbbdev             // The new address to check is beyond the current region and aligned to allocation size
161*51c0b2f7Stbbdev             newAddr = Addrint2Ptr( (Ptr2Addrint(memInfo.BaseAddress) + memInfo.RegionSize + m_allocSize) & ~(UINT_PTR)(m_allocSize-1) );
162*51c0b2f7Stbbdev 
163*51c0b2f7Stbbdev             // Check that the address is in the right distance.
164*51c0b2f7Stbbdev             // VirtualAlloc can only round the address down; so it will remain in the right distance
165*51c0b2f7Stbbdev             if (!IsInDistance(addr, Ptr2Addrint(newAddr), MAX_DISTANCE))
166*51c0b2f7Stbbdev                 break;
167*51c0b2f7Stbbdev 
168*51c0b2f7Stbbdev             if (VirtualQuery(newAddr, &memInfo, sizeof(memInfo)) != sizeof(memInfo))
169*51c0b2f7Stbbdev                 break;
170*51c0b2f7Stbbdev 
171*51c0b2f7Stbbdev             if (memInfo.State == MEM_FREE && memInfo.RegionSize >= m_allocSize)
172*51c0b2f7Stbbdev             {
173*51c0b2f7Stbbdev                 // Found a free region, try to allocate a page in this region
174*51c0b2f7Stbbdev                 void *newPage = VirtualAlloc(newAddr, m_allocSize, MEM_COMMIT|MEM_RESERVE, PAGE_READWRITE);
175*51c0b2f7Stbbdev                 if (!newPage)
176*51c0b2f7Stbbdev                     break;
177*51c0b2f7Stbbdev 
178*51c0b2f7Stbbdev                 // Add the new page to the pages database
179*51c0b2f7Stbbdev                 MemoryBuffer *pBuff = new (m_lastBuffer) MemoryBuffer(newPage, m_allocSize);
180*51c0b2f7Stbbdev                 ++m_lastBuffer;
181*51c0b2f7Stbbdev                 return pBuff;
182*51c0b2f7Stbbdev             }
183*51c0b2f7Stbbdev         }
184*51c0b2f7Stbbdev 
185*51c0b2f7Stbbdev         // Failed to find a buffer in the distance
186*51c0b2f7Stbbdev         return 0;
187*51c0b2f7Stbbdev     }
188*51c0b2f7Stbbdev 
189*51c0b2f7Stbbdev public:
190*51c0b2f7Stbbdev     MemoryProvider()
191*51c0b2f7Stbbdev     {
192*51c0b2f7Stbbdev         SYSTEM_INFO sysInfo;
193*51c0b2f7Stbbdev         GetSystemInfo(&sysInfo);
194*51c0b2f7Stbbdev         m_allocSize = sysInfo.dwAllocationGranularity;
195*51c0b2f7Stbbdev         m_lastBuffer = &m_pages[0];
196*51c0b2f7Stbbdev     }
197*51c0b2f7Stbbdev 
198*51c0b2f7Stbbdev     // We can't free the pages in the destructor because the trampolines
199*51c0b2f7Stbbdev     // are using these memory locations and a replaced function might be called
200*51c0b2f7Stbbdev     // after the destructor was called.
201*51c0b2f7Stbbdev     ~MemoryProvider()
202*51c0b2f7Stbbdev     {
203*51c0b2f7Stbbdev     }
204*51c0b2f7Stbbdev 
205*51c0b2f7Stbbdev     // Return a memory location in distance less than 2^31 from input address
206*51c0b2f7Stbbdev     UINT_PTR GetLocation(UINT_PTR addr)
207*51c0b2f7Stbbdev     {
208*51c0b2f7Stbbdev         MemoryBuffer *pBuff = m_pages;
209*51c0b2f7Stbbdev         for (; pBuff<m_lastBuffer && IsInDistance(pBuff->m_next, addr, MAX_DISTANCE); ++pBuff)
210*51c0b2f7Stbbdev         {
211*51c0b2f7Stbbdev             if (pBuff->m_next < pBuff->m_base + pBuff->m_size)
212*51c0b2f7Stbbdev             {
213*51c0b2f7Stbbdev                 UINT_PTR loc = pBuff->m_next;
214*51c0b2f7Stbbdev                 pBuff->m_next += MAX_PROBE_SIZE;
215*51c0b2f7Stbbdev                 return loc;
216*51c0b2f7Stbbdev             }
217*51c0b2f7Stbbdev         }
218*51c0b2f7Stbbdev 
219*51c0b2f7Stbbdev         pBuff = CreateBuffer(addr);
220*51c0b2f7Stbbdev         if(!pBuff)
221*51c0b2f7Stbbdev             return 0;
222*51c0b2f7Stbbdev 
223*51c0b2f7Stbbdev         UINT_PTR loc = pBuff->m_next;
224*51c0b2f7Stbbdev         pBuff->m_next += MAX_PROBE_SIZE;
225*51c0b2f7Stbbdev         return loc;
226*51c0b2f7Stbbdev     }
227*51c0b2f7Stbbdev 
228*51c0b2f7Stbbdev private:
229*51c0b2f7Stbbdev     MemoryBuffer m_pages[MAX_NUM_BUFFERS];
230*51c0b2f7Stbbdev     MemoryBuffer *m_lastBuffer;
231*51c0b2f7Stbbdev     DWORD m_allocSize;
232*51c0b2f7Stbbdev };
233*51c0b2f7Stbbdev 
234*51c0b2f7Stbbdev static MemoryProvider memProvider;
235*51c0b2f7Stbbdev 
// Compare opcodes from dictionary (str1) and opcodes from code (str2).
// str1 might contain '*' and '#' characters, which match any character of str2.
// Only the first strlen(str1) characters of str2 are examined; if str2 ends
// before str1 does, the strings do not match (and str2 is never read past its end).
// RETURN: 0 if opcodes did not match, 1 on success
size_t compareStrings( const char *str1, const char *str2 )
{
   for (size_t i=0; str1[i]!=0; i++){
       // The code string is shorter than the pattern: stop before reading beyond it
       if( str2[i]==0 ) return 0;
       if( str1[i]!='*' && str1[i]!='#' && str1[i]!=str2[i] ) return 0;
   }
   return 1;
}
246*51c0b2f7Stbbdev 
247*51c0b2f7Stbbdev // Check function prologue with known prologues from the dictionary
248*51c0b2f7Stbbdev // opcodes - dictionary
249*51c0b2f7Stbbdev // inpAddr - pointer to function prologue
250*51c0b2f7Stbbdev // Dictionary contains opcodes for several full asm instructions
251*51c0b2f7Stbbdev // + one opcode byte for the next asm instruction for safe address processing
252*51c0b2f7Stbbdev // RETURN: 1 + the index of the matched pattern, or 0 if no match found.
253*51c0b2f7Stbbdev static UINT CheckOpcodes( const char ** opcodes, void *inpAddr, bool abortOnError, const FunctionInfo* functionInfo = NULL)
254*51c0b2f7Stbbdev {
255*51c0b2f7Stbbdev     static size_t opcodesStringsCount = 0;
256*51c0b2f7Stbbdev     static size_t maxOpcodesLength = 0;
257*51c0b2f7Stbbdev     static size_t opcodes_pointer = (size_t)opcodes;
258*51c0b2f7Stbbdev     char opcodeString[2*MAX_PATTERN_SIZE+1];
259*51c0b2f7Stbbdev     size_t i;
260*51c0b2f7Stbbdev     size_t result = 0;
261*51c0b2f7Stbbdev 
262*51c0b2f7Stbbdev     // Get the values for static variables
263*51c0b2f7Stbbdev     // max length and number of patterns
264*51c0b2f7Stbbdev     if( !opcodesStringsCount || opcodes_pointer != (size_t)opcodes ){
265*51c0b2f7Stbbdev         while( *(opcodes + opcodesStringsCount)!= NULL ){
266*51c0b2f7Stbbdev             if( (i=strlen(*(opcodes + opcodesStringsCount))) > maxOpcodesLength )
267*51c0b2f7Stbbdev                 maxOpcodesLength = i;
268*51c0b2f7Stbbdev             opcodesStringsCount++;
269*51c0b2f7Stbbdev         }
270*51c0b2f7Stbbdev         opcodes_pointer = (size_t)opcodes;
271*51c0b2f7Stbbdev         __TBB_ASSERT( maxOpcodesLength/2 <= MAX_PATTERN_SIZE, "Pattern exceeded the limit of 28 opcodes/56 symbols" );
272*51c0b2f7Stbbdev     }
273*51c0b2f7Stbbdev 
274*51c0b2f7Stbbdev     // Translate prologue opcodes to string format to compare
275*51c0b2f7Stbbdev     for( i=0; i<maxOpcodesLength/2 && i<MAX_PATTERN_SIZE; ++i ){
276*51c0b2f7Stbbdev         sprintf( opcodeString + 2*i, "%.2X", *((unsigned char*)inpAddr+i) );
277*51c0b2f7Stbbdev     }
278*51c0b2f7Stbbdev     opcodeString[2*i] = 0;
279*51c0b2f7Stbbdev 
280*51c0b2f7Stbbdev     // Compare translated opcodes with patterns
281*51c0b2f7Stbbdev     for( UINT idx=0; idx<opcodesStringsCount; ++idx ){
282*51c0b2f7Stbbdev         result = compareStrings( opcodes[idx],opcodeString );
283*51c0b2f7Stbbdev         if( result ) {
284*51c0b2f7Stbbdev             if (functionInfo) {
285*51c0b2f7Stbbdev                 Log::record(*functionInfo, opcodeString, /*status*/ true);
286*51c0b2f7Stbbdev             }
287*51c0b2f7Stbbdev             return idx + 1; // avoid 0 which indicates a failure
288*51c0b2f7Stbbdev         }
289*51c0b2f7Stbbdev     }
290*51c0b2f7Stbbdev     if (functionInfo) {
291*51c0b2f7Stbbdev         Log::record(*functionInfo, opcodeString, /*status*/ false);
292*51c0b2f7Stbbdev     }
293*51c0b2f7Stbbdev     if (abortOnError) {
294*51c0b2f7Stbbdev         // Impossibility to find opcodes in the dictionary is a serious issue,
295*51c0b2f7Stbbdev         // as if we unable to call original function, leak or crash is expected result.
296*51c0b2f7Stbbdev         __TBB_ASSERT_RELEASE( false, "CheckOpcodes failed" );
297*51c0b2f7Stbbdev     }
298*51c0b2f7Stbbdev     return 0;
299*51c0b2f7Stbbdev }
300*51c0b2f7Stbbdev 
301*51c0b2f7Stbbdev // Modify offsets in original code after moving it to a trampoline.
302*51c0b2f7Stbbdev // We do not have more than one offset to correct in existing opcode patterns.
303*51c0b2f7Stbbdev static void CorrectOffset( UINT_PTR address, const char* pattern, UINT distance )
304*51c0b2f7Stbbdev {
305*51c0b2f7Stbbdev     const char* pos = strstr(pattern, "#*******");
306*51c0b2f7Stbbdev     if( pos ) {
307*51c0b2f7Stbbdev         address += (pos - pattern)/2; // compute the offset position
308*51c0b2f7Stbbdev         UINT value;
309*51c0b2f7Stbbdev         // UINT assignment is not used to avoid potential alignment issues
310*51c0b2f7Stbbdev         memcpy(&value, Addrint2Ptr(address), sizeof(value));
311*51c0b2f7Stbbdev         value += distance;
312*51c0b2f7Stbbdev         memcpy(Addrint2Ptr(address), &value, sizeof(value));
313*51c0b2f7Stbbdev     }
314*51c0b2f7Stbbdev }
315*51c0b2f7Stbbdev 
316*51c0b2f7Stbbdev // Insert jump relative instruction to the input address
317*51c0b2f7Stbbdev // RETURN: the size of the trampoline or 0 on failure
318*51c0b2f7Stbbdev static DWORD InsertTrampoline32(void *inpAddr, void *targetAddr, const char* pattern, void** storedAddr)
319*51c0b2f7Stbbdev {
320*51c0b2f7Stbbdev     size_t bytesToMove = SIZE_OF_RELJUMP;
321*51c0b2f7Stbbdev     UINT_PTR srcAddr = Ptr2Addrint(inpAddr);
322*51c0b2f7Stbbdev     UINT_PTR tgtAddr = Ptr2Addrint(targetAddr);
323*51c0b2f7Stbbdev     // Check that the target fits in 32 bits
324*51c0b2f7Stbbdev     if (!IsInDistance(srcAddr, tgtAddr, MAX_DISTANCE))
325*51c0b2f7Stbbdev         return 0;
326*51c0b2f7Stbbdev 
327*51c0b2f7Stbbdev     UINT_PTR offset;
328*51c0b2f7Stbbdev     UINT offset32;
329*51c0b2f7Stbbdev     UCHAR *codePtr = (UCHAR *)inpAddr;
330*51c0b2f7Stbbdev 
331*51c0b2f7Stbbdev     if ( storedAddr ){ // If requested, store original function code
332*51c0b2f7Stbbdev         bytesToMove = strlen(pattern)/2-1; // The last byte matching the pattern must not be copied
333*51c0b2f7Stbbdev         __TBB_ASSERT_RELEASE( bytesToMove >= SIZE_OF_RELJUMP, "Incorrect bytecode pattern?" );
334*51c0b2f7Stbbdev         UINT_PTR trampAddr = memProvider.GetLocation(srcAddr);
335*51c0b2f7Stbbdev         if (!trampAddr)
336*51c0b2f7Stbbdev             return 0;
337*51c0b2f7Stbbdev         *storedAddr = Addrint2Ptr(trampAddr);
338*51c0b2f7Stbbdev         // Set 'executable' flag for original instructions in the new place
339*51c0b2f7Stbbdev         DWORD pageFlags = PAGE_EXECUTE_READWRITE;
340*51c0b2f7Stbbdev         if (!VirtualProtect(*storedAddr, MAX_PROBE_SIZE, pageFlags, &pageFlags)) return 0;
341*51c0b2f7Stbbdev         // Copy original instructions to the new place
342*51c0b2f7Stbbdev         memcpy(*storedAddr, codePtr, bytesToMove);
343*51c0b2f7Stbbdev         offset = srcAddr - trampAddr;
344*51c0b2f7Stbbdev         offset32 = (UINT)(offset & 0xFFFFFFFF);
345*51c0b2f7Stbbdev         CorrectOffset( trampAddr, pattern, offset32 );
346*51c0b2f7Stbbdev         // Set jump to the code after replacement
347*51c0b2f7Stbbdev         offset32 -= SIZE_OF_RELJUMP;
348*51c0b2f7Stbbdev         *(UCHAR*)(trampAddr+bytesToMove) = 0xE9;
349*51c0b2f7Stbbdev         memcpy((UCHAR*)(trampAddr+bytesToMove+1), &offset32, sizeof(offset32));
350*51c0b2f7Stbbdev     }
351*51c0b2f7Stbbdev 
352*51c0b2f7Stbbdev     // The following will work correctly even if srcAddr>tgtAddr, as long as
353*51c0b2f7Stbbdev     // address difference is less than 2^31, which is guaranteed by IsInDistance.
354*51c0b2f7Stbbdev     offset = tgtAddr - srcAddr - SIZE_OF_RELJUMP;
355*51c0b2f7Stbbdev     offset32 = (UINT)(offset & 0xFFFFFFFF);
356*51c0b2f7Stbbdev     // Insert the jump to the new code
357*51c0b2f7Stbbdev     *codePtr = 0xE9;
358*51c0b2f7Stbbdev     memcpy(codePtr+1, &offset32, sizeof(offset32));
359*51c0b2f7Stbbdev 
360*51c0b2f7Stbbdev     // Fill the rest with NOPs to correctly see disassembler of old code in debugger.
361*51c0b2f7Stbbdev     for( unsigned i=SIZE_OF_RELJUMP; i<bytesToMove; i++ ){
362*51c0b2f7Stbbdev         *(codePtr+i) = 0x90;
363*51c0b2f7Stbbdev     }
364*51c0b2f7Stbbdev 
365*51c0b2f7Stbbdev     return SIZE_OF_RELJUMP;
366*51c0b2f7Stbbdev }
367*51c0b2f7Stbbdev 
368*51c0b2f7Stbbdev // This function is called when the offset doesn't fit in 32 bits
369*51c0b2f7Stbbdev // 1  Find and allocate a page in the small distance (<2^31) from input address
370*51c0b2f7Stbbdev // 2  Put jump RIP relative indirect through the address in the close page
371*51c0b2f7Stbbdev // 3  Put the absolute address of the target in the allocated location
372*51c0b2f7Stbbdev // RETURN: the size of the trampoline or 0 on failure
373*51c0b2f7Stbbdev static DWORD InsertTrampoline64(void *inpAddr, void *targetAddr, const char* pattern, void** storedAddr)
374*51c0b2f7Stbbdev {
375*51c0b2f7Stbbdev     size_t bytesToMove = SIZE_OF_INDJUMP;
376*51c0b2f7Stbbdev 
377*51c0b2f7Stbbdev     UINT_PTR srcAddr = Ptr2Addrint(inpAddr);
378*51c0b2f7Stbbdev     UINT_PTR tgtAddr = Ptr2Addrint(targetAddr);
379*51c0b2f7Stbbdev 
380*51c0b2f7Stbbdev     // Get a location close to the source address
381*51c0b2f7Stbbdev     UINT_PTR location = memProvider.GetLocation(srcAddr);
382*51c0b2f7Stbbdev     if (!location)
383*51c0b2f7Stbbdev         return 0;
384*51c0b2f7Stbbdev 
385*51c0b2f7Stbbdev     UINT_PTR offset;
386*51c0b2f7Stbbdev     UINT offset32;
387*51c0b2f7Stbbdev     UCHAR *codePtr = (UCHAR *)inpAddr;
388*51c0b2f7Stbbdev 
389*51c0b2f7Stbbdev     // Fill the location
390*51c0b2f7Stbbdev     UINT_PTR *locPtr = (UINT_PTR *)Addrint2Ptr(location);
391*51c0b2f7Stbbdev     *locPtr = tgtAddr;
392*51c0b2f7Stbbdev 
393*51c0b2f7Stbbdev     if ( storedAddr ){ // If requested, store original function code
394*51c0b2f7Stbbdev         bytesToMove = strlen(pattern)/2-1; // The last byte matching the pattern must not be copied
395*51c0b2f7Stbbdev         __TBB_ASSERT_RELEASE( bytesToMove >= SIZE_OF_INDJUMP, "Incorrect bytecode pattern?" );
396*51c0b2f7Stbbdev         UINT_PTR trampAddr = memProvider.GetLocation(srcAddr);
397*51c0b2f7Stbbdev         if (!trampAddr)
398*51c0b2f7Stbbdev             return 0;
399*51c0b2f7Stbbdev         *storedAddr = Addrint2Ptr(trampAddr);
400*51c0b2f7Stbbdev         // Set 'executable' flag for original instructions in the new place
401*51c0b2f7Stbbdev         DWORD pageFlags = PAGE_EXECUTE_READWRITE;
402*51c0b2f7Stbbdev         if (!VirtualProtect(*storedAddr, MAX_PROBE_SIZE, pageFlags, &pageFlags)) return 0;
403*51c0b2f7Stbbdev         // Copy original instructions to the new place
404*51c0b2f7Stbbdev         memcpy(*storedAddr, codePtr, bytesToMove);
405*51c0b2f7Stbbdev         offset = srcAddr - trampAddr;
406*51c0b2f7Stbbdev         offset32 = (UINT)(offset & 0xFFFFFFFF);
407*51c0b2f7Stbbdev         CorrectOffset( trampAddr, pattern, offset32 );
408*51c0b2f7Stbbdev         // Set jump to the code after replacement. It is within the distance of relative jump!
409*51c0b2f7Stbbdev         offset32 -= SIZE_OF_RELJUMP;
410*51c0b2f7Stbbdev         *(UCHAR*)(trampAddr+bytesToMove) = 0xE9;
411*51c0b2f7Stbbdev         memcpy((UCHAR*)(trampAddr+bytesToMove+1), &offset32, sizeof(offset32));
412*51c0b2f7Stbbdev     }
413*51c0b2f7Stbbdev 
414*51c0b2f7Stbbdev     // Fill the buffer
415*51c0b2f7Stbbdev     offset = location - srcAddr - SIZE_OF_INDJUMP;
416*51c0b2f7Stbbdev     offset32 = (UINT)(offset & 0xFFFFFFFF);
417*51c0b2f7Stbbdev     *(codePtr) = 0xFF;
418*51c0b2f7Stbbdev     *(codePtr+1) = 0x25;
419*51c0b2f7Stbbdev     memcpy(codePtr+2, &offset32, sizeof(offset32));
420*51c0b2f7Stbbdev 
421*51c0b2f7Stbbdev     // Fill the rest with NOPs to correctly see disassembler of old code in debugger.
422*51c0b2f7Stbbdev     for( unsigned i=SIZE_OF_INDJUMP; i<bytesToMove; i++ ){
423*51c0b2f7Stbbdev         *(codePtr+i) = 0x90;
424*51c0b2f7Stbbdev     }
425*51c0b2f7Stbbdev 
426*51c0b2f7Stbbdev     return SIZE_OF_INDJUMP;
427*51c0b2f7Stbbdev }
428*51c0b2f7Stbbdev 
429*51c0b2f7Stbbdev // Insert a jump instruction in the inpAddr to the targetAddr
430*51c0b2f7Stbbdev // 1. Get the memory protection of the page containing the input address
431*51c0b2f7Stbbdev // 2. Change the memory protection to writable
432*51c0b2f7Stbbdev // 3. Call InsertTrampoline32 or InsertTrampoline64
433*51c0b2f7Stbbdev // 4. Restore memory protection
434*51c0b2f7Stbbdev // RETURN: FALSE on failure, TRUE on success
435*51c0b2f7Stbbdev static bool InsertTrampoline(void *inpAddr, void *targetAddr, const char ** opcodes, void** origFunc)
436*51c0b2f7Stbbdev {
437*51c0b2f7Stbbdev     DWORD probeSize;
438*51c0b2f7Stbbdev     // Change page protection to EXECUTE+WRITE
439*51c0b2f7Stbbdev     DWORD origProt = 0;
440*51c0b2f7Stbbdev     if (!VirtualProtect(inpAddr, MAX_PROBE_SIZE, PAGE_EXECUTE_WRITECOPY, &origProt))
441*51c0b2f7Stbbdev         return FALSE;
442*51c0b2f7Stbbdev 
443*51c0b2f7Stbbdev     const char* pattern = NULL;
444*51c0b2f7Stbbdev     if ( origFunc ){ // Need to store original function code
445*51c0b2f7Stbbdev         UCHAR * const codePtr = (UCHAR *)inpAddr;
446*51c0b2f7Stbbdev         if ( *codePtr == 0xE9 ){ // JMP relative instruction
447*51c0b2f7Stbbdev             // For the special case when a system function consists of a single near jump,
448*51c0b2f7Stbbdev             // instead of moving it somewhere we use the target of the jump as the original function.
449*51c0b2f7Stbbdev             unsigned offsetInJmp = *(unsigned*)(codePtr + 1);
450*51c0b2f7Stbbdev             *origFunc = (void*)(Ptr2Addrint(inpAddr) + offsetInJmp + SIZE_OF_RELJUMP);
451*51c0b2f7Stbbdev             origFunc = NULL; // now it must be ignored by InsertTrampoline32/64
452*51c0b2f7Stbbdev         } else {
453*51c0b2f7Stbbdev             // find the right opcode pattern
454*51c0b2f7Stbbdev             UINT opcodeIdx = CheckOpcodes( opcodes, inpAddr, /*abortOnError=*/true );
455*51c0b2f7Stbbdev             __TBB_ASSERT( opcodeIdx > 0, "abortOnError ignored in CheckOpcodes?" );
456*51c0b2f7Stbbdev             pattern = opcodes[opcodeIdx-1];  // -1 compensates for +1 in CheckOpcodes
457*51c0b2f7Stbbdev         }
458*51c0b2f7Stbbdev     }
459*51c0b2f7Stbbdev 
460*51c0b2f7Stbbdev     probeSize = InsertTrampoline32(inpAddr, targetAddr, pattern, origFunc);
461*51c0b2f7Stbbdev     if (!probeSize)
462*51c0b2f7Stbbdev         probeSize = InsertTrampoline64(inpAddr, targetAddr, pattern, origFunc);
463*51c0b2f7Stbbdev 
464*51c0b2f7Stbbdev     // Restore original protection
465*51c0b2f7Stbbdev     VirtualProtect(inpAddr, MAX_PROBE_SIZE, origProt, &origProt);
466*51c0b2f7Stbbdev 
467*51c0b2f7Stbbdev     if (!probeSize)
468*51c0b2f7Stbbdev         return FALSE;
469*51c0b2f7Stbbdev 
470*51c0b2f7Stbbdev     FlushInstructionCache(GetCurrentProcess(), inpAddr, probeSize);
471*51c0b2f7Stbbdev     FlushInstructionCache(GetCurrentProcess(), origFunc, probeSize);
472*51c0b2f7Stbbdev 
473*51c0b2f7Stbbdev     return TRUE;
474*51c0b2f7Stbbdev }
475*51c0b2f7Stbbdev 
476*51c0b2f7Stbbdev // Routine to replace the functions
477*51c0b2f7Stbbdev // TODO: replace opcodesNumber with opcodes and opcodes number to check if we replace right code.
478*51c0b2f7Stbbdev FRR_TYPE ReplaceFunctionA(const char *dllName, const char *funcName, FUNCPTR newFunc, const char ** opcodes, FUNCPTR* origFunc)
479*51c0b2f7Stbbdev {
480*51c0b2f7Stbbdev     // Cache the results of the last search for the module
481*51c0b2f7Stbbdev     // Assume that there was no DLL unload between
482*51c0b2f7Stbbdev     static char cachedName[MAX_PATH+1];
483*51c0b2f7Stbbdev     static HMODULE cachedHM = 0;
484*51c0b2f7Stbbdev 
485*51c0b2f7Stbbdev     if (!dllName || !*dllName)
486*51c0b2f7Stbbdev         return FRR_NODLL;
487*51c0b2f7Stbbdev 
488*51c0b2f7Stbbdev     if (!cachedHM || strncmp(dllName, cachedName, MAX_PATH) != 0)
489*51c0b2f7Stbbdev     {
490*51c0b2f7Stbbdev         // Find the module handle for the input dll
491*51c0b2f7Stbbdev         HMODULE hModule = GetModuleHandleA(dllName);
492*51c0b2f7Stbbdev         if (hModule == 0)
493*51c0b2f7Stbbdev         {
494*51c0b2f7Stbbdev             // Couldn't find the module with the input name
495*51c0b2f7Stbbdev             cachedHM = 0;
496*51c0b2f7Stbbdev             return FRR_NODLL;
497*51c0b2f7Stbbdev         }
498*51c0b2f7Stbbdev 
499*51c0b2f7Stbbdev         cachedHM = hModule;
500*51c0b2f7Stbbdev         strncpy(cachedName, dllName, MAX_PATH);
501*51c0b2f7Stbbdev     }
502*51c0b2f7Stbbdev 
503*51c0b2f7Stbbdev     FARPROC inpFunc = GetProcAddress(cachedHM, funcName);
504*51c0b2f7Stbbdev     if (inpFunc == 0)
505*51c0b2f7Stbbdev     {
506*51c0b2f7Stbbdev         // Function was not found
507*51c0b2f7Stbbdev         return FRR_NOFUNC;
508*51c0b2f7Stbbdev     }
509*51c0b2f7Stbbdev 
510*51c0b2f7Stbbdev     if (!InsertTrampoline((void*)inpFunc, (void*)newFunc, opcodes, (void**)origFunc)){
511*51c0b2f7Stbbdev         // Failed to insert the trampoline to the target address
512*51c0b2f7Stbbdev         return FRR_FAILED;
513*51c0b2f7Stbbdev     }
514*51c0b2f7Stbbdev 
515*51c0b2f7Stbbdev     return FRR_OK;
516*51c0b2f7Stbbdev }
517*51c0b2f7Stbbdev 
518*51c0b2f7Stbbdev FRR_TYPE ReplaceFunctionW(const wchar_t *dllName, const char *funcName, FUNCPTR newFunc, const char ** opcodes, FUNCPTR* origFunc)
519*51c0b2f7Stbbdev {
520*51c0b2f7Stbbdev     // Cache the results of the last search for the module
521*51c0b2f7Stbbdev     // Assume that there was no DLL unload between
522*51c0b2f7Stbbdev     static wchar_t cachedName[MAX_PATH+1];
523*51c0b2f7Stbbdev     static HMODULE cachedHM = 0;
524*51c0b2f7Stbbdev 
525*51c0b2f7Stbbdev     if (!dllName || !*dllName)
526*51c0b2f7Stbbdev         return FRR_NODLL;
527*51c0b2f7Stbbdev 
528*51c0b2f7Stbbdev     if (!cachedHM || wcsncmp(dllName, cachedName, MAX_PATH) != 0)
529*51c0b2f7Stbbdev     {
530*51c0b2f7Stbbdev         // Find the module handle for the input dll
531*51c0b2f7Stbbdev         HMODULE hModule = GetModuleHandleW(dllName);
532*51c0b2f7Stbbdev         if (hModule == 0)
533*51c0b2f7Stbbdev         {
534*51c0b2f7Stbbdev             // Couldn't find the module with the input name
535*51c0b2f7Stbbdev             cachedHM = 0;
536*51c0b2f7Stbbdev             return FRR_NODLL;
537*51c0b2f7Stbbdev         }
538*51c0b2f7Stbbdev 
539*51c0b2f7Stbbdev         cachedHM = hModule;
540*51c0b2f7Stbbdev         wcsncpy(cachedName, dllName, MAX_PATH);
541*51c0b2f7Stbbdev     }
542*51c0b2f7Stbbdev 
543*51c0b2f7Stbbdev     FARPROC inpFunc = GetProcAddress(cachedHM, funcName);
544*51c0b2f7Stbbdev     if (inpFunc == 0)
545*51c0b2f7Stbbdev     {
546*51c0b2f7Stbbdev         // Function was not found
547*51c0b2f7Stbbdev         return FRR_NOFUNC;
548*51c0b2f7Stbbdev     }
549*51c0b2f7Stbbdev 
550*51c0b2f7Stbbdev     if (!InsertTrampoline((void*)inpFunc, (void*)newFunc, opcodes, (void**)origFunc)){
551*51c0b2f7Stbbdev         // Failed to insert the trampoline to the target address
552*51c0b2f7Stbbdev         return FRR_FAILED;
553*51c0b2f7Stbbdev     }
554*51c0b2f7Stbbdev 
555*51c0b2f7Stbbdev     return FRR_OK;
556*51c0b2f7Stbbdev }
557*51c0b2f7Stbbdev 
558*51c0b2f7Stbbdev bool IsPrologueKnown(const char* dllName, const char *funcName, const char **opcodes, HMODULE module)
559*51c0b2f7Stbbdev {
560*51c0b2f7Stbbdev     FARPROC inpFunc = GetProcAddress(module, funcName);
561*51c0b2f7Stbbdev     FunctionInfo functionInfo = { funcName, dllName };
562*51c0b2f7Stbbdev 
563*51c0b2f7Stbbdev     if (!inpFunc) {
564*51c0b2f7Stbbdev         Log::record(functionInfo, "unknown", /*status*/ false);
565*51c0b2f7Stbbdev         return false;
566*51c0b2f7Stbbdev     }
567*51c0b2f7Stbbdev 
568*51c0b2f7Stbbdev     return CheckOpcodes( opcodes, (void*)inpFunc, /*abortOnError=*/false, &functionInfo) != 0;
569*51c0b2f7Stbbdev }
570*51c0b2f7Stbbdev 
571*51c0b2f7Stbbdev // Public Windows API
572*51c0b2f7Stbbdev extern "C" __declspec(dllexport) int TBB_malloc_replacement_log(char *** function_replacement_log_ptr)
573*51c0b2f7Stbbdev {
574*51c0b2f7Stbbdev     if (function_replacement_log_ptr != NULL) {
575*51c0b2f7Stbbdev         *function_replacement_log_ptr = Log::records;
576*51c0b2f7Stbbdev     }
577*51c0b2f7Stbbdev 
578*51c0b2f7Stbbdev     // If we have no logs -> return false status
579*51c0b2f7Stbbdev     return Log::replacement_status && Log::records[0] != NULL ? 0 : -1;
580*51c0b2f7Stbbdev }
581*51c0b2f7Stbbdev 
582*51c0b2f7Stbbdev #endif /* !__TBB_WIN8UI_SUPPORT && defined(_WIN32) */
583