//===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements miscellaneous inlining utilities.
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/InliningUtils.h"

#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "inlining"

using namespace mlir;

/// Remap locations from the inlined blocks with CallSiteLoc locations with the
/// provided caller location.
38 static void 39 remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks, 40 Location callerLoc) { 41 DenseMap<Location, Location> mappedLocations; 42 auto remapOpLoc = [&](Operation *op) { 43 auto it = mappedLocations.find(op->getLoc()); 44 if (it == mappedLocations.end()) { 45 auto newLoc = CallSiteLoc::get(op->getLoc(), callerLoc); 46 it = mappedLocations.try_emplace(op->getLoc(), newLoc).first; 47 } 48 op->setLoc(it->second); 49 }; 50 for (auto &block : inlinedBlocks) 51 block.walk(remapOpLoc); 52 } 53 54 static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks, 55 BlockAndValueMapping &mapper) { 56 auto remapOperands = [&](Operation *op) { 57 for (auto &operand : op->getOpOperands()) 58 if (auto mappedOp = mapper.lookupOrNull(operand.get())) 59 operand.set(mappedOp); 60 }; 61 for (auto &block : inlinedBlocks) 62 block.walk(remapOperands); 63 } 64 65 //===----------------------------------------------------------------------===// 66 // InlinerInterface 67 //===----------------------------------------------------------------------===// 68 69 bool InlinerInterface::isLegalToInline( 70 Region *dest, Region *src, BlockAndValueMapping &valueMapping) const { 71 // Regions can always be inlined into functions. 72 if (isa<FuncOp>(dest->getParentOp())) 73 return true; 74 75 auto *handler = getInterfaceFor(dest->getParentOp()); 76 return handler ? handler->isLegalToInline(dest, src, valueMapping) : false; 77 } 78 79 bool InlinerInterface::isLegalToInline( 80 Operation *op, Region *dest, BlockAndValueMapping &valueMapping) const { 81 auto *handler = getInterfaceFor(op); 82 return handler ? handler->isLegalToInline(op, dest, valueMapping) : false; 83 } 84 85 bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const { 86 auto *handler = getInterfaceFor(op); 87 return handler ? 
handler->shouldAnalyzeRecursively(op) : true; 88 } 89 90 /// Handle the given inlined terminator by replacing it with a new operation 91 /// as necessary. 92 void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const { 93 auto *handler = getInterfaceFor(op); 94 assert(handler && "expected valid dialect handler"); 95 handler->handleTerminator(op, newDest); 96 } 97 98 /// Handle the given inlined terminator by replacing it with a new operation 99 /// as necessary. 100 void InlinerInterface::handleTerminator(Operation *op, 101 ArrayRef<ValuePtr> valuesToRepl) const { 102 auto *handler = getInterfaceFor(op); 103 assert(handler && "expected valid dialect handler"); 104 handler->handleTerminator(op, valuesToRepl); 105 } 106 107 /// Utility to check that all of the operations within 'src' can be inlined. 108 static bool isLegalToInline(InlinerInterface &interface, Region *src, 109 Region *insertRegion, 110 BlockAndValueMapping &valueMapping) { 111 for (auto &block : *src) { 112 for (auto &op : block) { 113 // Check this operation. 114 if (!interface.isLegalToInline(&op, insertRegion, valueMapping)) { 115 LLVM_DEBUG({ 116 llvm::dbgs() << "* Illegal to inline because of op: "; 117 op.dump(); 118 }); 119 return false; 120 } 121 // Check any nested regions. 
122 if (interface.shouldAnalyzeRecursively(&op) && 123 llvm::any_of(op.getRegions(), [&](Region ®ion) { 124 return !isLegalToInline(interface, ®ion, insertRegion, 125 valueMapping); 126 })) 127 return false; 128 } 129 } 130 return true; 131 } 132 133 //===----------------------------------------------------------------------===// 134 // Inline Methods 135 //===----------------------------------------------------------------------===// 136 137 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, 138 Operation *inlinePoint, 139 BlockAndValueMapping &mapper, 140 ArrayRef<ValuePtr> resultsToReplace, 141 Optional<Location> inlineLoc, 142 bool shouldCloneInlinedRegion) { 143 // We expect the region to have at least one block. 144 if (src->empty()) 145 return failure(); 146 147 // Check that all of the region arguments have been mapped. 148 auto *srcEntryBlock = &src->front(); 149 if (llvm::any_of(srcEntryBlock->getArguments(), 150 [&](BlockArgumentPtr arg) { return !mapper.contains(arg); })) 151 return failure(); 152 153 // The insertion point must be within a block. 154 Block *insertBlock = inlinePoint->getBlock(); 155 if (!insertBlock) 156 return failure(); 157 Region *insertRegion = insertBlock->getParent(); 158 159 // Check that the operations within the source region are valid to inline. 160 if (!interface.isLegalToInline(insertRegion, src, mapper) || 161 !isLegalToInline(interface, src, insertRegion, mapper)) 162 return failure(); 163 164 // Split the insertion block. 165 Block *postInsertBlock = 166 insertBlock->splitBlock(++inlinePoint->getIterator()); 167 168 // Check to see if the region is being cloned, or moved inline. In either 169 // case, move the new blocks after the 'insertBlock' to improve IR 170 // readability. 
171 if (shouldCloneInlinedRegion) 172 src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper); 173 else 174 insertRegion->getBlocks().splice(postInsertBlock->getIterator(), 175 src->getBlocks(), src->begin(), 176 src->end()); 177 178 // Get the range of newly inserted blocks. 179 auto newBlocks = llvm::make_range(std::next(insertBlock->getIterator()), 180 postInsertBlock->getIterator()); 181 Block *firstNewBlock = &*newBlocks.begin(); 182 183 // Remap the locations of the inlined operations if a valid source location 184 // was provided. 185 if (inlineLoc && !inlineLoc->isa<UnknownLoc>()) 186 remapInlinedLocations(newBlocks, *inlineLoc); 187 188 // If the blocks were moved in-place, make sure to remap any necessary 189 // operands. 190 if (!shouldCloneInlinedRegion) 191 remapInlinedOperands(newBlocks, mapper); 192 193 // Process the newly inlined blocks. 194 interface.processInlinedBlocks(newBlocks); 195 196 // Handle the case where only a single block was inlined. 197 if (std::next(newBlocks.begin()) == newBlocks.end()) { 198 // Have the interface handle the terminator of this block. 199 auto *firstBlockTerminator = firstNewBlock->getTerminator(); 200 interface.handleTerminator(firstBlockTerminator, resultsToReplace); 201 firstBlockTerminator->erase(); 202 203 // Merge the post insert block into the cloned entry block. 204 firstNewBlock->getOperations().splice(firstNewBlock->end(), 205 postInsertBlock->getOperations()); 206 postInsertBlock->erase(); 207 } else { 208 // Otherwise, there were multiple blocks inlined. Add arguments to the post 209 // insertion block to represent the results to replace. 210 for (ValuePtr resultToRepl : resultsToReplace) { 211 resultToRepl->replaceAllUsesWith( 212 postInsertBlock->addArgument(resultToRepl->getType())); 213 } 214 215 /// Handle the terminators for each of the new blocks. 
216 for (auto &newBlock : newBlocks) 217 interface.handleTerminator(newBlock.getTerminator(), postInsertBlock); 218 } 219 220 // Splice the instructions of the inlined entry block into the insert block. 221 insertBlock->getOperations().splice(insertBlock->end(), 222 firstNewBlock->getOperations()); 223 firstNewBlock->erase(); 224 return success(); 225 } 226 227 /// This function is an overload of the above 'inlineRegion' that allows for 228 /// providing the set of operands ('inlinedOperands') that should be used 229 /// in-favor of the region arguments when inlining. 230 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, 231 Operation *inlinePoint, 232 ArrayRef<ValuePtr> inlinedOperands, 233 ArrayRef<ValuePtr> resultsToReplace, 234 Optional<Location> inlineLoc, 235 bool shouldCloneInlinedRegion) { 236 // We expect the region to have at least one block. 237 if (src->empty()) 238 return failure(); 239 240 auto *entryBlock = &src->front(); 241 if (inlinedOperands.size() != entryBlock->getNumArguments()) 242 return failure(); 243 244 // Map the provided call operands to the arguments of the region. 245 BlockAndValueMapping mapper; 246 for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) { 247 // Verify that the types of the provided values match the function argument 248 // types. 249 BlockArgumentPtr regionArg = entryBlock->getArgument(i); 250 if (inlinedOperands[i]->getType() != regionArg->getType()) 251 return failure(); 252 mapper.map(regionArg, inlinedOperands[i]); 253 } 254 255 // Call into the main region inliner function. 256 return inlineRegion(interface, src, inlinePoint, mapper, resultsToReplace, 257 inlineLoc, shouldCloneInlinedRegion); 258 } 259 260 /// Utility function used to generate a cast operation from the given interface, 261 /// or return nullptr if a cast could not be generated. 
262 static ValuePtr materializeConversion(const DialectInlinerInterface *interface, 263 SmallVectorImpl<Operation *> &castOps, 264 OpBuilder &castBuilder, ValuePtr arg, 265 Type type, Location conversionLoc) { 266 if (!interface) 267 return nullptr; 268 269 // Check to see if the interface for the call can materialize a conversion. 270 Operation *castOp = interface->materializeCallConversion(castBuilder, arg, 271 type, conversionLoc); 272 if (!castOp) 273 return nullptr; 274 castOps.push_back(castOp); 275 276 // Ensure that the generated cast is correct. 277 assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg && 278 castOp->getNumResults() == 1 && *castOp->result_type_begin() == type); 279 return castOp->getResult(0); 280 } 281 282 /// This function inlines a given region, 'src', of a callable operation, 283 /// 'callable', into the location defined by the given call operation. This 284 /// function returns failure if inlining is not possible, success otherwise. On 285 /// failure, no changes are made to the module. 'shouldCloneInlinedRegion' 286 /// corresponds to whether the source region should be cloned into the 'call' or 287 /// spliced directly. 288 LogicalResult mlir::inlineCall(InlinerInterface &interface, 289 CallOpInterface call, 290 CallableOpInterface callable, Region *src, 291 bool shouldCloneInlinedRegion) { 292 // We expect the region to have at least one block. 293 if (src->empty()) 294 return failure(); 295 auto *entryBlock = &src->front(); 296 ArrayRef<Type> callableResultTypes = callable.getCallableResults(src); 297 298 // Make sure that the number of arguments and results matchup between the call 299 // and the region. 
300 SmallVector<ValuePtr, 8> callOperands(call.getArgOperands()); 301 SmallVector<ValuePtr, 8> callResults(call.getOperation()->getResults()); 302 if (callOperands.size() != entryBlock->getNumArguments() || 303 callResults.size() != callableResultTypes.size()) 304 return failure(); 305 306 // A set of cast operations generated to matchup the signature of the region 307 // with the signature of the call. 308 SmallVector<Operation *, 4> castOps; 309 castOps.reserve(callOperands.size() + callResults.size()); 310 311 // Functor used to cleanup generated state on failure. 312 auto cleanupState = [&] { 313 for (auto *op : castOps) { 314 op->getResult(0)->replaceAllUsesWith(op->getOperand(0)); 315 op->erase(); 316 } 317 return failure(); 318 }; 319 320 // Builder used for any conversion operations that need to be materialized. 321 OpBuilder castBuilder(call); 322 Location castLoc = call.getLoc(); 323 auto *callInterface = interface.getInterfaceFor(call.getDialect()); 324 325 // Map the provided call operands to the arguments of the region. 326 BlockAndValueMapping mapper; 327 for (unsigned i = 0, e = callOperands.size(); i != e; ++i) { 328 BlockArgumentPtr regionArg = entryBlock->getArgument(i); 329 ValuePtr operand = callOperands[i]; 330 331 // If the call operand doesn't match the expected region argument, try to 332 // generate a cast. 333 Type regionArgType = regionArg->getType(); 334 if (operand->getType() != regionArgType) { 335 if (!(operand = materializeConversion(callInterface, castOps, castBuilder, 336 operand, regionArgType, castLoc))) 337 return cleanupState(); 338 } 339 mapper.map(regionArg, operand); 340 } 341 342 // Ensure that the resultant values of the call, match the callable. 
343 castBuilder.setInsertionPointAfter(call); 344 for (unsigned i = 0, e = callResults.size(); i != e; ++i) { 345 ValuePtr callResult = callResults[i]; 346 if (callResult->getType() == callableResultTypes[i]) 347 continue; 348 349 // Generate a conversion that will produce the original type, so that the IR 350 // is still valid after the original call gets replaced. 351 ValuePtr castResult = 352 materializeConversion(callInterface, castOps, castBuilder, callResult, 353 callResult->getType(), castLoc); 354 if (!castResult) 355 return cleanupState(); 356 callResult->replaceAllUsesWith(castResult); 357 castResult->getDefiningOp()->replaceUsesOfWith(castResult, callResult); 358 } 359 360 // Attempt to inline the call. 361 if (failed(inlineRegion(interface, src, call, mapper, callResults, 362 call.getLoc(), shouldCloneInlinedRegion))) 363 return cleanupState(); 364 return success(); 365 } 366