1 //===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file implements miscellaneous inlining utilities. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/Transforms/InliningUtils.h" 23 24 #include "mlir/IR/BlockAndValueMapping.h" 25 #include "mlir/IR/Builders.h" 26 #include "mlir/IR/Function.h" 27 #include "mlir/IR/Operation.h" 28 #include "llvm/ADT/MapVector.h" 29 #include "llvm/Support/raw_ostream.h" 30 31 #define DEBUG_TYPE "inlining" 32 33 using namespace mlir; 34 35 /// Remap locations from the inlined blocks with CallSiteLoc locations with the 36 /// provided caller location. 37 static void 38 remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks, 39 Location callerLoc) { 40 DenseMap<Location, Location> mappedLocations; 41 auto remapOpLoc = [&](Operation *op) { 42 auto it = mappedLocations.find(op->getLoc()); 43 if (it == mappedLocations.end()) { 44 auto newLoc = CallSiteLoc::get(op->getLoc(), callerLoc); 45 it = mappedLocations.try_emplace(op->getLoc(), newLoc).first; 46 } 47 op->setLoc(it->second); 48 }; 49 for (auto &block : inlinedBlocks) 50 block.walk(remapOpLoc); 51 } 52 53 static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks, 54 BlockAndValueMapping &mapper) { 55 auto remapOperands = [&](Operation *op) { 56 for (auto &operand : op->getOpOperands()) 57 if (auto *mappedOp = mapper.lookupOrNull(operand.get())) 58 operand.set(mappedOp); 59 }; 60 for (auto &block : inlinedBlocks) 61 block.walk(remapOperands); 62 } 63 64 //===----------------------------------------------------------------------===// 65 // InlinerInterface 66 //===----------------------------------------------------------------------===// 67 68 bool InlinerInterface::isLegalToInline( 69 Region *dest, Region *src, BlockAndValueMapping &valueMapping) const { 70 // Regions can always be inlined into functions. 71 if (isa<FuncOp>(dest->getParentOp())) 72 return true; 73 74 auto *handler = getInterfaceFor(dest->getParentOp()); 75 return handler ? handler->isLegalToInline(dest, src, valueMapping) : false; 76 } 77 78 bool InlinerInterface::isLegalToInline( 79 Operation *op, Region *dest, BlockAndValueMapping &valueMapping) const { 80 auto *handler = getInterfaceFor(op); 81 return handler ? handler->isLegalToInline(op, dest, valueMapping) : false; 82 } 83 84 bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const { 85 auto *handler = getInterfaceFor(op); 86 return handler ? handler->shouldAnalyzeRecursively(op) : true; 87 } 88 89 /// Handle the given inlined terminator by replacing it with a new operation 90 /// as necessary. 91 void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const { 92 auto *handler = getInterfaceFor(op); 93 assert(handler && "expected valid dialect handler"); 94 handler->handleTerminator(op, newDest); 95 } 96 97 /// Handle the given inlined terminator by replacing it with a new operation 98 /// as necessary. 99 void InlinerInterface::handleTerminator(Operation *op, 100 ArrayRef<Value *> valuesToRepl) const { 101 auto *handler = getInterfaceFor(op); 102 assert(handler && "expected valid dialect handler"); 103 handler->handleTerminator(op, valuesToRepl); 104 } 105 106 /// Utility to check that all of the operations within 'src' can be inlined. 107 static bool isLegalToInline(InlinerInterface &interface, Region *src, 108 Region *insertRegion, 109 BlockAndValueMapping &valueMapping) { 110 for (auto &block : *src) { 111 for (auto &op : block) { 112 // Check this operation. 113 if (!interface.isLegalToInline(&op, insertRegion, valueMapping)) 114 return false; 115 // Check any nested regions. 116 if (interface.shouldAnalyzeRecursively(&op) && 117 llvm::any_of(op.getRegions(), [&](Region ®ion) { 118 return !isLegalToInline(interface, ®ion, insertRegion, 119 valueMapping); 120 })) 121 return false; 122 } 123 } 124 return true; 125 } 126 127 //===----------------------------------------------------------------------===// 128 // Inline Methods 129 //===----------------------------------------------------------------------===// 130 131 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, 132 Operation *inlinePoint, 133 BlockAndValueMapping &mapper, 134 ArrayRef<Value *> resultsToReplace, 135 Optional<Location> inlineLoc, 136 bool shouldCloneInlinedRegion) { 137 // We expect the region to have at least one block. 138 if (src->empty()) 139 return failure(); 140 141 // Check that all of the region arguments have been mapped. 142 auto *srcEntryBlock = &src->front(); 143 if (llvm::any_of(srcEntryBlock->getArguments(), 144 [&](BlockArgument *arg) { return !mapper.contains(arg); })) 145 return failure(); 146 147 // The insertion point must be within a block. 148 Block *insertBlock = inlinePoint->getBlock(); 149 if (!insertBlock) 150 return failure(); 151 Region *insertRegion = insertBlock->getParent(); 152 153 // Check that the operations within the source region are valid to inline. 154 if (!interface.isLegalToInline(insertRegion, src, mapper) || 155 !isLegalToInline(interface, src, insertRegion, mapper)) 156 return failure(); 157 158 // Split the insertion block. 159 Block *postInsertBlock = 160 insertBlock->splitBlock(++inlinePoint->getIterator()); 161 162 // Check to see if the region is being cloned, or moved inline. In either 163 // case, move the new blocks after the 'insertBlock' to improve IR 164 // readability. 165 if (shouldCloneInlinedRegion) 166 src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper); 167 else 168 insertRegion->getBlocks().splice(postInsertBlock->getIterator(), 169 src->getBlocks(), src->begin(), 170 src->end()); 171 172 // Get the range of newly inserted blocks. 173 auto newBlocks = llvm::make_range(std::next(insertBlock->getIterator()), 174 postInsertBlock->getIterator()); 175 Block *firstNewBlock = &*newBlocks.begin(); 176 177 // Remap the locations of the inlined operations if a valid source location 178 // was provided. 179 if (inlineLoc && !inlineLoc->isa<UnknownLoc>()) 180 remapInlinedLocations(newBlocks, *inlineLoc); 181 182 // If the blocks were moved in-place, make sure to remap any necessary 183 // operands. 184 if (!shouldCloneInlinedRegion) 185 remapInlinedOperands(newBlocks, mapper); 186 187 // Process the newly inlined blocks. 188 interface.processInlinedBlocks(newBlocks); 189 190 // Handle the case where only a single block was inlined. 191 if (std::next(newBlocks.begin()) == newBlocks.end()) { 192 // Have the interface handle the terminator of this block. 193 auto *firstBlockTerminator = firstNewBlock->getTerminator(); 194 interface.handleTerminator(firstBlockTerminator, resultsToReplace); 195 firstBlockTerminator->erase(); 196 197 // Merge the post insert block into the cloned entry block. 198 firstNewBlock->getOperations().splice(firstNewBlock->end(), 199 postInsertBlock->getOperations()); 200 postInsertBlock->erase(); 201 } else { 202 // Otherwise, there were multiple blocks inlined. Add arguments to the post 203 // insertion block to represent the results to replace. 204 for (Value *resultToRepl : resultsToReplace) { 205 resultToRepl->replaceAllUsesWith( 206 postInsertBlock->addArgument(resultToRepl->getType())); 207 } 208 209 /// Handle the terminators for each of the new blocks. 210 for (auto &newBlock : newBlocks) 211 interface.handleTerminator(newBlock.getTerminator(), postInsertBlock); 212 } 213 214 // Splice the instructions of the inlined entry block into the insert block. 215 insertBlock->getOperations().splice(insertBlock->end(), 216 firstNewBlock->getOperations()); 217 firstNewBlock->erase(); 218 return success(); 219 } 220 221 /// This function is an overload of the above 'inlineRegion' that allows for 222 /// providing the set of operands ('inlinedOperands') that should be used 223 /// in-favor of the region arguments when inlining. 224 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, 225 Operation *inlinePoint, 226 ArrayRef<Value *> inlinedOperands, 227 ArrayRef<Value *> resultsToReplace, 228 Optional<Location> inlineLoc, 229 bool shouldCloneInlinedRegion) { 230 // We expect the region to have at least one block. 231 if (src->empty()) 232 return failure(); 233 234 auto *entryBlock = &src->front(); 235 if (inlinedOperands.size() != entryBlock->getNumArguments()) 236 return failure(); 237 238 // Map the provided call operands to the arguments of the region. 239 BlockAndValueMapping mapper; 240 for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) { 241 // Verify that the types of the provided values match the function argument 242 // types. 243 BlockArgument *regionArg = entryBlock->getArgument(i); 244 if (inlinedOperands[i]->getType() != regionArg->getType()) 245 return failure(); 246 mapper.map(regionArg, inlinedOperands[i]); 247 } 248 249 // Call into the main region inliner function. 250 return inlineRegion(interface, src, inlinePoint, mapper, resultsToReplace, 251 inlineLoc, shouldCloneInlinedRegion); 252 } 253 254 /// Utility function used to generate a cast operation from the given interface, 255 /// or return nullptr if a cast could not be generated. 256 static Value *materializeConversion(const DialectInlinerInterface *interface, 257 SmallVectorImpl<Operation *> &castOps, 258 OpBuilder &castBuilder, Value *arg, 259 Type type, Location conversionLoc) { 260 if (!interface) 261 return nullptr; 262 263 // Check to see if the interface for the call can materialize a conversion. 264 Operation *castOp = interface->materializeCallConversion(castBuilder, arg, 265 type, conversionLoc); 266 if (!castOp) 267 return nullptr; 268 castOps.push_back(castOp); 269 270 // Ensure that the generated cast is correct. 271 assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg && 272 castOp->getNumResults() == 1 && *castOp->result_type_begin() == type); 273 return castOp->getResult(0); 274 } 275 276 /// This function inlines a given region, 'src', of a callable operation, 277 /// 'callable', into the location defined by the given call operation. This 278 /// function returns failure if inlining is not possible, success otherwise. On 279 /// failure, no changes are made to the module. 'shouldCloneInlinedRegion' 280 /// corresponds to whether the source region should be cloned into the 'call' or 281 /// spliced directly. 282 LogicalResult mlir::inlineCall(InlinerInterface &interface, 283 CallOpInterface call, 284 CallableOpInterface callable, Region *src, 285 bool shouldCloneInlinedRegion) { 286 // We expect the region to have at least one block. 287 if (src->empty()) 288 return failure(); 289 auto *entryBlock = &src->front(); 290 ArrayRef<Type> callableResultTypes = callable.getCallableResults(src); 291 292 // Make sure that the number of arguments and results matchup between the call 293 // and the region. 294 SmallVector<Value *, 8> callOperands(call.getArgOperands()); 295 SmallVector<Value *, 8> callResults(call.getOperation()->getResults()); 296 if (callOperands.size() != entryBlock->getNumArguments() || 297 callResults.size() != callableResultTypes.size()) 298 return failure(); 299 300 // A set of cast operations generated to matchup the signature of the region 301 // with the signature of the call. 302 SmallVector<Operation *, 4> castOps; 303 castOps.reserve(callOperands.size() + callResults.size()); 304 305 // Functor used to cleanup generated state on failure. 306 auto cleanupState = [&] { 307 for (auto *op : castOps) { 308 op->getResult(0)->replaceAllUsesWith(op->getOperand(0)); 309 op->erase(); 310 } 311 return failure(); 312 }; 313 314 // Builder used for any conversion operations that need to be materialized. 315 OpBuilder castBuilder(call); 316 Location castLoc = call.getLoc(); 317 auto *callInterface = interface.getInterfaceFor(call.getDialect()); 318 319 // Map the provided call operands to the arguments of the region. 320 BlockAndValueMapping mapper; 321 for (unsigned i = 0, e = callOperands.size(); i != e; ++i) { 322 BlockArgument *regionArg = entryBlock->getArgument(i); 323 Value *operand = callOperands[i]; 324 325 // If the call operand doesn't match the expected region argument, try to 326 // generate a cast. 327 Type regionArgType = regionArg->getType(); 328 if (operand->getType() != regionArgType) { 329 if (!(operand = materializeConversion(callInterface, castOps, castBuilder, 330 operand, regionArgType, castLoc))) 331 return cleanupState(); 332 } 333 mapper.map(regionArg, operand); 334 } 335 336 // Ensure that the resultant values of the call, match the callable. 337 castBuilder.setInsertionPointAfter(call); 338 for (unsigned i = 0, e = callResults.size(); i != e; ++i) { 339 Value *callResult = callResults[i]; 340 if (callResult->getType() == callableResultTypes[i]) 341 continue; 342 343 // Generate a conversion that will produce the original type, so that the IR 344 // is still valid after the original call gets replaced. 345 Value *castResult = 346 materializeConversion(callInterface, castOps, castBuilder, callResult, 347 callResult->getType(), castLoc); 348 if (!castResult) 349 return cleanupState(); 350 callResult->replaceAllUsesWith(castResult); 351 castResult->getDefiningOp()->replaceUsesOfWith(castResult, callResult); 352 } 353 354 // Attempt to inline the call. 355 if (failed(inlineRegion(interface, src, call, mapper, callResults, 356 call.getLoc(), shouldCloneInlinedRegion))) 357 return cleanupState(); 358 return success(); 359 } 360