1 //===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file implements miscellaneous inlining utilities. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/Transforms/InliningUtils.h" 23 24 #include "mlir/IR/BlockAndValueMapping.h" 25 #include "mlir/IR/Builders.h" 26 #include "mlir/IR/Function.h" 27 #include "mlir/IR/Operation.h" 28 #include "llvm/ADT/MapVector.h" 29 #include "llvm/Support/raw_ostream.h" 30 31 #define DEBUG_TYPE "inlining" 32 33 using namespace mlir; 34 35 /// Remap locations from the inlined blocks with CallSiteLoc locations with the 36 /// provided caller location. 37 static void 38 remapInlinedLocations(llvm::iterator_range<Region::iterator> inlinedBlocks, 39 Location callerLoc) { 40 DenseMap<Location, Location> mappedLocations; 41 auto remapOpLoc = [&](Operation *op) { 42 auto it = mappedLocations.find(op->getLoc()); 43 if (it == mappedLocations.end()) { 44 auto newLoc = CallSiteLoc::get(op->getLoc(), callerLoc); 45 it = mappedLocations.try_emplace(op->getLoc(), newLoc).first; 46 } 47 op->setLoc(it->second); 48 }; 49 for (auto &block : inlinedBlocks) 50 block.walk(remapOpLoc); 51 } 52 53 static void 54 remapInlinedOperands(llvm::iterator_range<Region::iterator> inlinedBlocks, 55 BlockAndValueMapping &mapper) { 56 auto remapOperands = [&](Operation *op) { 57 for (auto &operand : op->getOpOperands()) 58 if (auto *mappedOp = mapper.lookupOrNull(operand.get())) 59 operand.set(mappedOp); 60 }; 61 for (auto &block : inlinedBlocks) 62 block.walk(remapOperands); 63 } 64 65 //===----------------------------------------------------------------------===// 66 // InlinerInterface 67 //===----------------------------------------------------------------------===// 68 69 bool InlinerInterface::isLegalToInline( 70 Region *dest, Region *src, BlockAndValueMapping &valueMapping) const { 71 // Regions can always be inlined into functions. 72 if (isa<FuncOp>(dest->getParentOp())) 73 return true; 74 75 auto *handler = getInterfaceFor(dest->getParentOp()); 76 return handler ? handler->isLegalToInline(dest, src, valueMapping) : false; 77 } 78 79 bool InlinerInterface::isLegalToInline( 80 Operation *op, Region *dest, BlockAndValueMapping &valueMapping) const { 81 auto *handler = getInterfaceFor(op); 82 return handler ? handler->isLegalToInline(op, dest, valueMapping) : false; 83 } 84 85 bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const { 86 auto *handler = getInterfaceFor(op); 87 return handler ? handler->shouldAnalyzeRecursively(op) : true; 88 } 89 90 /// Handle the given inlined terminator by replacing it with a new operation 91 /// as necessary. 92 void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const { 93 auto *handler = getInterfaceFor(op); 94 assert(handler && "expected valid dialect handler"); 95 handler->handleTerminator(op, newDest); 96 } 97 98 /// Handle the given inlined terminator by replacing it with a new operation 99 /// as necessary. 100 void InlinerInterface::handleTerminator(Operation *op, 101 ArrayRef<Value *> valuesToRepl) const { 102 auto *handler = getInterfaceFor(op); 103 assert(handler && "expected valid dialect handler"); 104 handler->handleTerminator(op, valuesToRepl); 105 } 106 107 /// Utility to check that all of the operations within 'src' can be inlined. 108 static bool isLegalToInline(InlinerInterface &interface, Region *src, 109 Region *insertRegion, 110 BlockAndValueMapping &valueMapping) { 111 for (auto &block : *src) { 112 for (auto &op : block) { 113 // Check this operation. 114 if (!interface.isLegalToInline(&op, insertRegion, valueMapping)) 115 return false; 116 // Check any nested regions. 117 if (interface.shouldAnalyzeRecursively(&op) && 118 llvm::any_of(op.getRegions(), [&](Region ®ion) { 119 return !isLegalToInline(interface, ®ion, insertRegion, 120 valueMapping); 121 })) 122 return false; 123 } 124 } 125 return true; 126 } 127 128 //===----------------------------------------------------------------------===// 129 // Inline Methods 130 //===----------------------------------------------------------------------===// 131 132 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, 133 Operation *inlinePoint, 134 BlockAndValueMapping &mapper, 135 ArrayRef<Value *> resultsToReplace, 136 llvm::Optional<Location> inlineLoc, 137 bool shouldCloneInlinedRegion) { 138 // We expect the region to have at least one block. 139 if (src->empty()) 140 return failure(); 141 142 // Check that all of the region arguments have been mapped. 143 auto *srcEntryBlock = &src->front(); 144 if (llvm::any_of(srcEntryBlock->getArguments(), 145 [&](BlockArgument *arg) { return !mapper.contains(arg); })) 146 return failure(); 147 148 // The insertion point must be within a block. 149 Block *insertBlock = inlinePoint->getBlock(); 150 if (!insertBlock) 151 return failure(); 152 Region *insertRegion = insertBlock->getParent(); 153 154 // Check that the operations within the source region are valid to inline. 155 if (!interface.isLegalToInline(insertRegion, src, mapper) || 156 !isLegalToInline(interface, src, insertRegion, mapper)) 157 return failure(); 158 159 // Split the insertion block. 160 Block *postInsertBlock = 161 insertBlock->splitBlock(++inlinePoint->getIterator()); 162 163 // Check to see if the region is being cloned, or moved inline. In either 164 // case, move the new blocks after the 'insertBlock' to improve IR 165 // readability. 166 if (shouldCloneInlinedRegion) 167 src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper); 168 else 169 insertRegion->getBlocks().splice(postInsertBlock->getIterator(), 170 src->getBlocks(), src->begin(), 171 src->end()); 172 173 // Get the range of newly inserted blocks. 174 auto newBlocks = llvm::make_range(std::next(insertBlock->getIterator()), 175 postInsertBlock->getIterator()); 176 Block *firstNewBlock = &*newBlocks.begin(); 177 178 // Remap the locations of the inlined operations if a valid source location 179 // was provided. 180 if (inlineLoc && !inlineLoc->isa<UnknownLoc>()) 181 remapInlinedLocations(newBlocks, *inlineLoc); 182 183 // If the blocks were moved in-place, make sure to remap any necessary 184 // operands. 185 if (!shouldCloneInlinedRegion) 186 remapInlinedOperands(newBlocks, mapper); 187 188 // Process the newly inlined blocks. 189 interface.processInlinedBlocks(newBlocks); 190 191 // Handle the case where only a single block was inlined. 192 if (std::next(newBlocks.begin()) == newBlocks.end()) { 193 // Have the interface handle the terminator of this block. 194 auto *firstBlockTerminator = firstNewBlock->getTerminator(); 195 interface.handleTerminator(firstBlockTerminator, resultsToReplace); 196 firstBlockTerminator->erase(); 197 198 // Merge the post insert block into the cloned entry block. 199 firstNewBlock->getOperations().splice(firstNewBlock->end(), 200 postInsertBlock->getOperations()); 201 postInsertBlock->erase(); 202 } else { 203 // Otherwise, there were multiple blocks inlined. Add arguments to the post 204 // insertion block to represent the results to replace. 205 for (Value *resultToRepl : resultsToReplace) { 206 resultToRepl->replaceAllUsesWith( 207 postInsertBlock->addArgument(resultToRepl->getType())); 208 } 209 210 /// Handle the terminators for each of the new blocks. 211 for (auto &newBlock : newBlocks) 212 interface.handleTerminator(newBlock.getTerminator(), postInsertBlock); 213 } 214 215 // Splice the instructions of the inlined entry block into the insert block. 216 insertBlock->getOperations().splice(insertBlock->end(), 217 firstNewBlock->getOperations()); 218 firstNewBlock->erase(); 219 return success(); 220 } 221 222 /// This function is an overload of the above 'inlineRegion' that allows for 223 /// providing the set of operands ('inlinedOperands') that should be used 224 /// in-favor of the region arguments when inlining. 225 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, 226 Operation *inlinePoint, 227 ArrayRef<Value *> inlinedOperands, 228 ArrayRef<Value *> resultsToReplace, 229 llvm::Optional<Location> inlineLoc, 230 bool shouldCloneInlinedRegion) { 231 // We expect the region to have at least one block. 232 if (src->empty()) 233 return failure(); 234 235 auto *entryBlock = &src->front(); 236 if (inlinedOperands.size() != entryBlock->getNumArguments()) 237 return failure(); 238 239 // Map the provided call operands to the arguments of the region. 240 BlockAndValueMapping mapper; 241 for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) { 242 // Verify that the types of the provided values match the function argument 243 // types. 244 BlockArgument *regionArg = entryBlock->getArgument(i); 245 if (inlinedOperands[i]->getType() != regionArg->getType()) 246 return failure(); 247 mapper.map(regionArg, inlinedOperands[i]); 248 } 249 250 // Call into the main region inliner function. 251 return inlineRegion(interface, src, inlinePoint, mapper, resultsToReplace, 252 inlineLoc, shouldCloneInlinedRegion); 253 } 254 255 /// Utility function used to generate a cast operation from the given interface, 256 /// or return nullptr if a cast could not be generated. 257 static Value *materializeConversion(const DialectInlinerInterface *interface, 258 SmallVectorImpl<Operation *> &castOps, 259 OpBuilder &castBuilder, Value *arg, 260 Type type, Location conversionLoc) { 261 if (!interface) 262 return nullptr; 263 264 // Check to see if the interface for the call can materialize a conversion. 265 Operation *castOp = interface->materializeCallConversion(castBuilder, arg, 266 type, conversionLoc); 267 if (!castOp) 268 return nullptr; 269 castOps.push_back(castOp); 270 271 // Ensure that the generated cast is correct. 272 assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg && 273 castOp->getNumResults() == 1 && *castOp->result_type_begin() == type); 274 return castOp->getResult(0); 275 } 276 277 /// This function inlines a given region, 'src', of a callable operation, 278 /// 'callable', into the location defined by the given call operation. This 279 /// function returns failure if inlining is not possible, success otherwise. On 280 /// failure, no changes are made to the module. 'shouldCloneInlinedRegion' 281 /// corresponds to whether the source region should be cloned into the 'call' or 282 /// spliced directly. 283 LogicalResult mlir::inlineCall(InlinerInterface &interface, 284 CallOpInterface call, 285 CallableOpInterface callable, Region *src, 286 bool shouldCloneInlinedRegion) { 287 // We expect the region to have at least one block. 288 if (src->empty()) 289 return failure(); 290 auto *entryBlock = &src->front(); 291 ArrayRef<Type> callableResultTypes = callable.getCallableResults(src); 292 293 // Make sure that the number of arguments and results matchup between the call 294 // and the region. 295 SmallVector<Value *, 8> callOperands(call.getArgOperands()); 296 SmallVector<Value *, 8> callResults(call.getOperation()->getResults()); 297 if (callOperands.size() != entryBlock->getNumArguments() || 298 callResults.size() != callableResultTypes.size()) 299 return failure(); 300 301 // A set of cast operations generated to matchup the signature of the region 302 // with the signature of the call. 303 SmallVector<Operation *, 4> castOps; 304 castOps.reserve(callOperands.size() + callResults.size()); 305 306 // Functor used to cleanup generated state on failure. 307 auto cleanupState = [&] { 308 for (auto *op : castOps) { 309 op->getResult(0)->replaceAllUsesWith(op->getOperand(0)); 310 op->erase(); 311 } 312 return failure(); 313 }; 314 315 // Builder used for any conversion operations that need to be materialized. 316 OpBuilder castBuilder(call); 317 Location castLoc = call.getLoc(); 318 auto *callInterface = interface.getInterfaceFor(call.getDialect()); 319 320 // Map the provided call operands to the arguments of the region. 321 BlockAndValueMapping mapper; 322 for (unsigned i = 0, e = callOperands.size(); i != e; ++i) { 323 BlockArgument *regionArg = entryBlock->getArgument(i); 324 Value *operand = callOperands[i]; 325 326 // If the call operand doesn't match the expected region argument, try to 327 // generate a cast. 328 Type regionArgType = regionArg->getType(); 329 if (operand->getType() != regionArgType) { 330 if (!(operand = materializeConversion(callInterface, castOps, castBuilder, 331 operand, regionArgType, castLoc))) 332 return cleanupState(); 333 } 334 mapper.map(regionArg, operand); 335 } 336 337 // Ensure that the resultant values of the call, match the callable. 338 castBuilder.setInsertionPointAfter(call); 339 for (unsigned i = 0, e = callResults.size(); i != e; ++i) { 340 Value *callResult = callResults[i]; 341 if (callResult->getType() == callableResultTypes[i]) 342 continue; 343 344 // Generate a conversion that will produce the original type, so that the IR 345 // is still valid after the original call gets replaced. 346 Value *castResult = 347 materializeConversion(callInterface, castOps, castBuilder, callResult, 348 callResult->getType(), castLoc); 349 if (!castResult) 350 return cleanupState(); 351 callResult->replaceAllUsesWith(castResult); 352 castResult->getDefiningOp()->replaceUsesOfWith(castResult, callResult); 353 } 354 355 // Attempt to inline the call. 356 if (failed(inlineRegion(interface, src, call, mapper, callResults, 357 call.getLoc(), shouldCloneInlinedRegion))) 358 return cleanupState(); 359 return success(); 360 } 361