1 //===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file implements miscellaneous inlining utilities. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/Transforms/InliningUtils.h" 23 24 #include "mlir/IR/BlockAndValueMapping.h" 25 #include "mlir/IR/Function.h" 26 #include "mlir/IR/Operation.h" 27 #include "llvm/ADT/MapVector.h" 28 #include "llvm/Support/raw_ostream.h" 29 30 #define DEBUG_TYPE "inlining" 31 32 using namespace mlir; 33 34 /// Remap locations from the inlined blocks with CallSiteLoc locations with the 35 /// provided caller location. 36 static void 37 remapInlinedLocations(llvm::iterator_range<Region::iterator> inlinedBlocks, 38 Location callerLoc) { 39 DenseMap<Location, Location> mappedLocations; 40 auto remapOpLoc = [&](Operation *op) { 41 auto it = mappedLocations.find(op->getLoc()); 42 if (it == mappedLocations.end()) { 43 auto newLoc = CallSiteLoc::get(op->getLoc(), callerLoc); 44 it = mappedLocations.try_emplace(op->getLoc(), newLoc).first; 45 } 46 op->setLoc(it->second); 47 }; 48 for (auto &block : inlinedBlocks) 49 block.walk(remapOpLoc); 50 } 51 52 static void 53 remapInlinedOperands(llvm::iterator_range<Region::iterator> inlinedBlocks, 54 BlockAndValueMapping &mapper) { 55 auto remapOperands = [&](Operation *op) { 56 for (auto &operand : op->getOpOperands()) 57 if (auto *mappedOp = mapper.lookupOrNull(operand.get())) 58 operand.set(mappedOp); 59 }; 60 for (auto &block : inlinedBlocks) 61 block.walk(remapOperands); 62 } 63 64 //===----------------------------------------------------------------------===// 65 // InlinerInterface 66 //===----------------------------------------------------------------------===// 67 68 InlinerInterface::~InlinerInterface() {} 69 70 bool InlinerInterface::isLegalToInline( 71 Region *dest, Region *src, BlockAndValueMapping &valueMapping) const { 72 // Regions can always be inlined into functions. 73 if (isa<FuncOp>(dest->getParentOp())) 74 return true; 75 76 auto *handler = getInterfaceFor(dest->getParentOp()); 77 return handler ? handler->isLegalToInline(src, dest, valueMapping) : false; 78 } 79 80 bool InlinerInterface::isLegalToInline( 81 Operation *op, Region *dest, BlockAndValueMapping &valueMapping) const { 82 auto *handler = getInterfaceFor(op); 83 return handler ? handler->isLegalToInline(op, dest, valueMapping) : false; 84 } 85 86 bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const { 87 auto *handler = getInterfaceFor(op); 88 return handler ? handler->shouldAnalyzeRecursively(op) : true; 89 } 90 91 /// Handle the given inlined terminator by replacing it with a new operation 92 /// as necessary. 93 void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const { 94 auto *handler = getInterfaceFor(op); 95 assert(handler && "expected valid dialect handler"); 96 handler->handleTerminator(op, newDest); 97 } 98 99 /// Handle the given inlined terminator by replacing it with a new operation 100 /// as necessary. 101 void InlinerInterface::handleTerminator(Operation *op, 102 ArrayRef<Value *> valuesToRepl) const { 103 auto *handler = getInterfaceFor(op); 104 assert(handler && "expected valid dialect handler"); 105 handler->handleTerminator(op, valuesToRepl); 106 } 107 108 /// Utility to check that all of the operations within 'src' can be inlined. 109 static bool isLegalToInline(InlinerInterface &interface, Region *src, 110 Region *insertRegion, 111 BlockAndValueMapping &valueMapping) { 112 for (auto &block : *src) { 113 for (auto &op : block) { 114 // Check this operation. 115 if (!interface.isLegalToInline(&op, insertRegion, valueMapping)) 116 return false; 117 // Check any nested regions. 118 if (interface.shouldAnalyzeRecursively(&op) && 119 llvm::any_of(op.getRegions(), [&](Region ®ion) { 120 return !isLegalToInline(interface, ®ion, insertRegion, 121 valueMapping); 122 })) 123 return false; 124 } 125 } 126 return true; 127 } 128 129 //===----------------------------------------------------------------------===// 130 // Inline Methods 131 //===----------------------------------------------------------------------===// 132 133 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, 134 Operation *inlinePoint, 135 BlockAndValueMapping &mapper, 136 ArrayRef<Value *> resultsToReplace, 137 llvm::Optional<Location> inlineLoc, 138 bool shouldCloneInlinedRegion) { 139 // We expect the region to have at least one block. 140 if (src->empty()) 141 return failure(); 142 143 // Check that all of the region arguments have been mapped. 144 auto *srcEntryBlock = &src->front(); 145 if (llvm::any_of(srcEntryBlock->getArguments(), 146 [&](BlockArgument *arg) { return !mapper.contains(arg); })) 147 return failure(); 148 149 // The insertion point must be within a block. 150 Block *insertBlock = inlinePoint->getBlock(); 151 if (!insertBlock) 152 return failure(); 153 Region *insertRegion = insertBlock->getParent(); 154 155 // Check that the operations within the source region are valid to inline. 156 if (!interface.isLegalToInline(insertRegion, src, mapper) || 157 !isLegalToInline(interface, src, insertRegion, mapper)) 158 return failure(); 159 160 // Split the insertion block. 161 Block *postInsertBlock = 162 insertBlock->splitBlock(++inlinePoint->getIterator()); 163 164 // Check to see if the region is being cloned, or moved inline. In either 165 // case, move the new blocks after the 'insertBlock' to improve IR 166 // readability. 167 if (shouldCloneInlinedRegion) 168 src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper); 169 else 170 insertRegion->getBlocks().splice(postInsertBlock->getIterator(), 171 src->getBlocks(), src->begin(), 172 src->end()); 173 174 // Get the range of newly inserted blocks. 175 auto newBlocks = llvm::make_range(std::next(insertBlock->getIterator()), 176 postInsertBlock->getIterator()); 177 Block *firstNewBlock = &*newBlocks.begin(); 178 179 // Remap the locations of the inlined operations if a valid source location 180 // was provided. 181 if (inlineLoc && !inlineLoc->isa<UnknownLoc>()) 182 remapInlinedLocations(newBlocks, *inlineLoc); 183 184 // If the blocks were moved in-place, make sure to remap any necessary 185 // operands. 186 if (!shouldCloneInlinedRegion) 187 remapInlinedOperands(newBlocks, mapper); 188 189 // Process the newly inlined blocks. 190 interface.processInlinedBlocks(newBlocks); 191 192 // Handle the case where only a single block was inlined. 193 if (std::next(newBlocks.begin()) == newBlocks.end()) { 194 // Have the interface handle the terminator of this block. 195 auto *firstBlockTerminator = firstNewBlock->getTerminator(); 196 interface.handleTerminator(firstBlockTerminator, resultsToReplace); 197 firstBlockTerminator->erase(); 198 199 // Merge the post insert block into the cloned entry block. 200 firstNewBlock->getOperations().splice(firstNewBlock->end(), 201 postInsertBlock->getOperations()); 202 postInsertBlock->erase(); 203 } else { 204 // Otherwise, there were multiple blocks inlined. Add arguments to the post 205 // insertion block to represent the results to replace. 206 for (Value *resultToRepl : resultsToReplace) { 207 resultToRepl->replaceAllUsesWith( 208 postInsertBlock->addArgument(resultToRepl->getType())); 209 } 210 211 /// Handle the terminators for each of the new blocks. 212 for (auto &newBlock : newBlocks) 213 interface.handleTerminator(newBlock.getTerminator(), postInsertBlock); 214 } 215 216 // Splice the instructions of the inlined entry block into the insert block. 217 insertBlock->getOperations().splice(insertBlock->end(), 218 firstNewBlock->getOperations()); 219 firstNewBlock->erase(); 220 return success(); 221 } 222 223 /// This function is an overload of the above 'inlineRegion' that allows for 224 /// providing the set of operands ('inlinedOperands') that should be used 225 /// in-favor of the region arguments when inlining. 226 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, 227 Operation *inlinePoint, 228 ArrayRef<Value *> inlinedOperands, 229 ArrayRef<Value *> resultsToReplace, 230 llvm::Optional<Location> inlineLoc, 231 bool shouldCloneInlinedRegion) { 232 // We expect the region to have at least one block. 233 if (src->empty()) 234 return failure(); 235 236 auto *entryBlock = &src->front(); 237 if (inlinedOperands.size() != entryBlock->getNumArguments()) 238 return failure(); 239 240 // Map the provided call operands to the arguments of the region. 241 BlockAndValueMapping mapper; 242 for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) { 243 // Verify that the types of the provided values match the function argument 244 // types. 245 BlockArgument *regionArg = entryBlock->getArgument(i); 246 if (inlinedOperands[i]->getType() != regionArg->getType()) 247 return failure(); 248 mapper.map(regionArg, inlinedOperands[i]); 249 } 250 251 // Call into the main region inliner function. 252 return inlineRegion(interface, src, inlinePoint, mapper, resultsToReplace, 253 inlineLoc, shouldCloneInlinedRegion); 254 } 255 256 /// This function inlines a FuncOp into another. This function returns failure 257 /// if it is not possible to inline this FuncOp. If the function returned 258 /// failure, then no changes to the module have been made. 259 /// 260 /// Note that this only does one level of inlining. For example, if the 261 /// instruction 'call B' is inlined, and 'B' calls 'C', then the call to 'C' now 262 /// exists in the instruction stream. Similarly this will inline a recursive 263 /// FuncOp by one level. 264 /// 265 LogicalResult mlir::inlineFunction(InlinerInterface &interface, FuncOp callee, 266 Operation *inlinePoint, 267 ArrayRef<Value *> callOperands, 268 ArrayRef<Value *> callResults, 269 Location inlineLoc) { 270 // We don't inline if the provided callee function is a declaration. 271 assert(callee && "expected valid function to inline"); 272 if (callee.isExternal()) 273 return failure(); 274 275 // Verify that the provided arguments match the function arguments. 276 if (callOperands.size() != callee.getNumArguments()) 277 return failure(); 278 279 // Verify that the provided values to replace match the function results. 280 auto funcResultTypes = callee.getType().getResults(); 281 if (callResults.size() != funcResultTypes.size()) 282 return failure(); 283 for (unsigned i = 0, e = callResults.size(); i != e; ++i) 284 if (callResults[i]->getType() != funcResultTypes[i]) 285 return failure(); 286 287 // Call into the main region inliner function. 288 return inlineRegion(interface, &callee.getBody(), inlinePoint, callOperands, 289 callResults, inlineLoc); 290 } 291