1 //===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements miscellaneous inlining utilities. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Transforms/InliningUtils.h" 14 15 #include "mlir/IR/BlockAndValueMapping.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/Function.h" 18 #include "mlir/IR/Operation.h" 19 #include "llvm/ADT/MapVector.h" 20 #include "llvm/Support/Debug.h" 21 #include "llvm/Support/raw_ostream.h" 22 23 #define DEBUG_TYPE "inlining" 24 25 using namespace mlir; 26 27 /// Remap locations from the inlined blocks with CallSiteLoc locations with the 28 /// provided caller location. 29 static void 30 remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks, 31 Location callerLoc) { 32 DenseMap<Location, Location> mappedLocations; 33 auto remapOpLoc = [&](Operation *op) { 34 auto it = mappedLocations.find(op->getLoc()); 35 if (it == mappedLocations.end()) { 36 auto newLoc = CallSiteLoc::get(op->getLoc(), callerLoc); 37 it = mappedLocations.try_emplace(op->getLoc(), newLoc).first; 38 } 39 op->setLoc(it->second); 40 }; 41 for (auto &block : inlinedBlocks) 42 block.walk(remapOpLoc); 43 } 44 45 static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks, 46 BlockAndValueMapping &mapper) { 47 auto remapOperands = [&](Operation *op) { 48 for (auto &operand : op->getOpOperands()) 49 if (auto mappedOp = mapper.lookupOrNull(operand.get())) 50 operand.set(mappedOp); 51 }; 52 for (auto &block : inlinedBlocks) 53 block.walk(remapOperands); 54 } 55 56 //===----------------------------------------------------------------------===// 57 // InlinerInterface 58 //===----------------------------------------------------------------------===// 59 60 bool InlinerInterface::isLegalToInline(Operation *call, 61 Operation *callable) const { 62 auto *handler = getInterfaceFor(call); 63 return handler ? handler->isLegalToInline(call, callable) : false; 64 } 65 66 bool InlinerInterface::isLegalToInline( 67 Region *dest, Region *src, BlockAndValueMapping &valueMapping) const { 68 // Regions can always be inlined into functions. 69 if (isa<FuncOp>(dest->getParentOp())) 70 return true; 71 72 auto *handler = getInterfaceFor(dest->getParentOp()); 73 return handler ? handler->isLegalToInline(dest, src, valueMapping) : false; 74 } 75 76 bool InlinerInterface::isLegalToInline( 77 Operation *op, Region *dest, BlockAndValueMapping &valueMapping) const { 78 auto *handler = getInterfaceFor(op); 79 return handler ? handler->isLegalToInline(op, dest, valueMapping) : false; 80 } 81 82 bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const { 83 auto *handler = getInterfaceFor(op); 84 return handler ? handler->shouldAnalyzeRecursively(op) : true; 85 } 86 87 /// Handle the given inlined terminator by replacing it with a new operation 88 /// as necessary. 89 void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const { 90 auto *handler = getInterfaceFor(op); 91 assert(handler && "expected valid dialect handler"); 92 handler->handleTerminator(op, newDest); 93 } 94 95 /// Handle the given inlined terminator by replacing it with a new operation 96 /// as necessary. 97 void InlinerInterface::handleTerminator(Operation *op, 98 ArrayRef<Value> valuesToRepl) const { 99 auto *handler = getInterfaceFor(op); 100 assert(handler && "expected valid dialect handler"); 101 handler->handleTerminator(op, valuesToRepl); 102 } 103 104 /// Utility to check that all of the operations within 'src' can be inlined. 105 static bool isLegalToInline(InlinerInterface &interface, Region *src, 106 Region *insertRegion, 107 BlockAndValueMapping &valueMapping) { 108 for (auto &block : *src) { 109 for (auto &op : block) { 110 // Check this operation. 111 if (!interface.isLegalToInline(&op, insertRegion, valueMapping)) { 112 LLVM_DEBUG({ 113 llvm::dbgs() << "* Illegal to inline because of op: "; 114 op.dump(); 115 }); 116 return false; 117 } 118 // Check any nested regions. 119 if (interface.shouldAnalyzeRecursively(&op) && 120 llvm::any_of(op.getRegions(), [&](Region ®ion) { 121 return !isLegalToInline(interface, ®ion, insertRegion, 122 valueMapping); 123 })) 124 return false; 125 } 126 } 127 return true; 128 } 129 130 //===----------------------------------------------------------------------===// 131 // Inline Methods 132 //===----------------------------------------------------------------------===// 133 134 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, 135 Operation *inlinePoint, 136 BlockAndValueMapping &mapper, 137 ValueRange resultsToReplace, 138 TypeRange regionResultTypes, 139 Optional<Location> inlineLoc, 140 bool shouldCloneInlinedRegion) { 141 assert(resultsToReplace.size() == regionResultTypes.size()); 142 // We expect the region to have at least one block. 143 if (src->empty()) 144 return failure(); 145 146 // Check that all of the region arguments have been mapped. 147 auto *srcEntryBlock = &src->front(); 148 if (llvm::any_of(srcEntryBlock->getArguments(), 149 [&](BlockArgument arg) { return !mapper.contains(arg); })) 150 return failure(); 151 152 // The insertion point must be within a block. 153 Block *insertBlock = inlinePoint->getBlock(); 154 if (!insertBlock) 155 return failure(); 156 Region *insertRegion = insertBlock->getParent(); 157 158 // Check that the operations within the source region are valid to inline. 159 if (!interface.isLegalToInline(insertRegion, src, mapper) || 160 !isLegalToInline(interface, src, insertRegion, mapper)) 161 return failure(); 162 163 // Split the insertion block. 164 Block *postInsertBlock = 165 insertBlock->splitBlock(++inlinePoint->getIterator()); 166 167 // Check to see if the region is being cloned, or moved inline. In either 168 // case, move the new blocks after the 'insertBlock' to improve IR 169 // readability. 170 if (shouldCloneInlinedRegion) 171 src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper); 172 else 173 insertRegion->getBlocks().splice(postInsertBlock->getIterator(), 174 src->getBlocks(), src->begin(), 175 src->end()); 176 177 // Get the range of newly inserted blocks. 178 auto newBlocks = llvm::make_range(std::next(insertBlock->getIterator()), 179 postInsertBlock->getIterator()); 180 Block *firstNewBlock = &*newBlocks.begin(); 181 182 // Remap the locations of the inlined operations if a valid source location 183 // was provided. 184 if (inlineLoc && !inlineLoc->isa<UnknownLoc>()) 185 remapInlinedLocations(newBlocks, *inlineLoc); 186 187 // If the blocks were moved in-place, make sure to remap any necessary 188 // operands. 189 if (!shouldCloneInlinedRegion) 190 remapInlinedOperands(newBlocks, mapper); 191 192 // Process the newly inlined blocks. 193 interface.processInlinedBlocks(newBlocks); 194 195 // Handle the case where only a single block was inlined. 196 if (std::next(newBlocks.begin()) == newBlocks.end()) { 197 // Have the interface handle the terminator of this block. 198 auto *firstBlockTerminator = firstNewBlock->getTerminator(); 199 interface.handleTerminator(firstBlockTerminator, 200 llvm::to_vector<6>(resultsToReplace)); 201 firstBlockTerminator->erase(); 202 203 // Merge the post insert block into the cloned entry block. 204 firstNewBlock->getOperations().splice(firstNewBlock->end(), 205 postInsertBlock->getOperations()); 206 postInsertBlock->erase(); 207 } else { 208 // Otherwise, there were multiple blocks inlined. Add arguments to the post 209 // insertion block to represent the results to replace. 210 for (auto resultToRepl : llvm::enumerate(resultsToReplace)) { 211 resultToRepl.value().replaceAllUsesWith(postInsertBlock->addArgument( 212 regionResultTypes[resultToRepl.index()])); 213 } 214 215 /// Handle the terminators for each of the new blocks. 216 for (auto &newBlock : newBlocks) 217 interface.handleTerminator(newBlock.getTerminator(), postInsertBlock); 218 } 219 220 // Splice the instructions of the inlined entry block into the insert block. 221 insertBlock->getOperations().splice(insertBlock->end(), 222 firstNewBlock->getOperations()); 223 firstNewBlock->erase(); 224 return success(); 225 } 226 227 /// This function is an overload of the above 'inlineRegion' that allows for 228 /// providing the set of operands ('inlinedOperands') that should be used 229 /// in-favor of the region arguments when inlining. 230 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, 231 Operation *inlinePoint, 232 ValueRange inlinedOperands, 233 ValueRange resultsToReplace, 234 Optional<Location> inlineLoc, 235 bool shouldCloneInlinedRegion) { 236 // We expect the region to have at least one block. 237 if (src->empty()) 238 return failure(); 239 240 auto *entryBlock = &src->front(); 241 if (inlinedOperands.size() != entryBlock->getNumArguments()) 242 return failure(); 243 244 // Map the provided call operands to the arguments of the region. 245 BlockAndValueMapping mapper; 246 for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) { 247 // Verify that the types of the provided values match the function argument 248 // types. 249 BlockArgument regionArg = entryBlock->getArgument(i); 250 if (inlinedOperands[i].getType() != regionArg.getType()) 251 return failure(); 252 mapper.map(regionArg, inlinedOperands[i]); 253 } 254 255 // Call into the main region inliner function. 256 return inlineRegion(interface, src, inlinePoint, mapper, resultsToReplace, 257 resultsToReplace.getTypes(), inlineLoc, 258 shouldCloneInlinedRegion); 259 } 260 261 /// Utility function used to generate a cast operation from the given interface, 262 /// or return nullptr if a cast could not be generated. 263 static Value materializeConversion(const DialectInlinerInterface *interface, 264 SmallVectorImpl<Operation *> &castOps, 265 OpBuilder &castBuilder, Value arg, Type type, 266 Location conversionLoc) { 267 if (!interface) 268 return nullptr; 269 270 // Check to see if the interface for the call can materialize a conversion. 271 Operation *castOp = interface->materializeCallConversion(castBuilder, arg, 272 type, conversionLoc); 273 if (!castOp) 274 return nullptr; 275 castOps.push_back(castOp); 276 277 // Ensure that the generated cast is correct. 278 assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg && 279 castOp->getNumResults() == 1 && *castOp->result_type_begin() == type); 280 return castOp->getResult(0); 281 } 282 283 /// This function inlines a given region, 'src', of a callable operation, 284 /// 'callable', into the location defined by the given call operation. This 285 /// function returns failure if inlining is not possible, success otherwise. On 286 /// failure, no changes are made to the module. 'shouldCloneInlinedRegion' 287 /// corresponds to whether the source region should be cloned into the 'call' or 288 /// spliced directly. 289 LogicalResult mlir::inlineCall(InlinerInterface &interface, 290 CallOpInterface call, 291 CallableOpInterface callable, Region *src, 292 bool shouldCloneInlinedRegion) { 293 // We expect the region to have at least one block. 294 if (src->empty()) 295 return failure(); 296 auto *entryBlock = &src->front(); 297 ArrayRef<Type> callableResultTypes = callable.getCallableResults(); 298 299 // Make sure that the number of arguments and results matchup between the call 300 // and the region. 301 SmallVector<Value, 8> callOperands(call.getArgOperands()); 302 SmallVector<Value, 8> callResults(call.getOperation()->getResults()); 303 if (callOperands.size() != entryBlock->getNumArguments() || 304 callResults.size() != callableResultTypes.size()) 305 return failure(); 306 307 // A set of cast operations generated to matchup the signature of the region 308 // with the signature of the call. 309 SmallVector<Operation *, 4> castOps; 310 castOps.reserve(callOperands.size() + callResults.size()); 311 312 // Functor used to cleanup generated state on failure. 313 auto cleanupState = [&] { 314 for (auto *op : castOps) { 315 op->getResult(0).replaceAllUsesWith(op->getOperand(0)); 316 op->erase(); 317 } 318 return failure(); 319 }; 320 321 // Builder used for any conversion operations that need to be materialized. 322 OpBuilder castBuilder(call); 323 Location castLoc = call.getLoc(); 324 auto *callInterface = interface.getInterfaceFor(call.getDialect()); 325 326 // Map the provided call operands to the arguments of the region. 327 BlockAndValueMapping mapper; 328 for (unsigned i = 0, e = callOperands.size(); i != e; ++i) { 329 BlockArgument regionArg = entryBlock->getArgument(i); 330 Value operand = callOperands[i]; 331 332 // If the call operand doesn't match the expected region argument, try to 333 // generate a cast. 334 Type regionArgType = regionArg.getType(); 335 if (operand.getType() != regionArgType) { 336 if (!(operand = materializeConversion(callInterface, castOps, castBuilder, 337 operand, regionArgType, castLoc))) 338 return cleanupState(); 339 } 340 mapper.map(regionArg, operand); 341 } 342 343 // Ensure that the resultant values of the call match the callable. 344 castBuilder.setInsertionPointAfter(call); 345 for (unsigned i = 0, e = callResults.size(); i != e; ++i) { 346 Value callResult = callResults[i]; 347 if (callResult.getType() == callableResultTypes[i]) 348 continue; 349 350 // Generate a conversion that will produce the original type, so that the IR 351 // is still valid after the original call gets replaced. 352 Value castResult = 353 materializeConversion(callInterface, castOps, castBuilder, callResult, 354 callResult.getType(), castLoc); 355 if (!castResult) 356 return cleanupState(); 357 callResult.replaceAllUsesWith(castResult); 358 castResult.getDefiningOp()->replaceUsesOfWith(castResult, callResult); 359 } 360 361 // Check that it is legal to inline the callable into the call. 362 if (!interface.isLegalToInline(call, callable)) 363 return cleanupState(); 364 365 // Attempt to inline the call. 366 if (failed(inlineRegion(interface, src, call, mapper, callResults, 367 callableResultTypes, call.getLoc(), 368 shouldCloneInlinedRegion))) 369 return cleanupState(); 370 return success(); 371 } 372