10ba00878SRiver Riddle //===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===//
20ba00878SRiver Riddle //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60ba00878SRiver Riddle //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
80ba00878SRiver Riddle //
90ba00878SRiver Riddle // This file implements miscellaneous inlining utilities.
100ba00878SRiver Riddle //
110ba00878SRiver Riddle //===----------------------------------------------------------------------===//
120ba00878SRiver Riddle 
130ba00878SRiver Riddle #include "mlir/Transforms/InliningUtils.h"
140ba00878SRiver Riddle 
150ba00878SRiver Riddle #include "mlir/IR/BlockAndValueMapping.h"
165830f71aSRiver Riddle #include "mlir/IR/Builders.h"
170ba00878SRiver Riddle #include "mlir/IR/Operation.h"
18*36550692SRiver Riddle #include "mlir/Interfaces/CallInterfaces.h"
190ba00878SRiver Riddle #include "llvm/ADT/MapVector.h"
20553f794bSSean Silva #include "llvm/Support/Debug.h"
210ba00878SRiver Riddle #include "llvm/Support/raw_ostream.h"
220ba00878SRiver Riddle 
230ba00878SRiver Riddle #define DEBUG_TYPE "inlining"
240ba00878SRiver Riddle 
250ba00878SRiver Riddle using namespace mlir;
260ba00878SRiver Riddle 
270ba00878SRiver Riddle /// Remap locations from the inlined blocks with CallSiteLoc locations with the
280ba00878SRiver Riddle /// provided caller location.
290ba00878SRiver Riddle static void
remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks,Location callerLoc)304562e389SRiver Riddle remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks,
310ba00878SRiver Riddle                       Location callerLoc) {
320ba00878SRiver Riddle   DenseMap<Location, Location> mappedLocations;
330ba00878SRiver Riddle   auto remapOpLoc = [&](Operation *op) {
340ba00878SRiver Riddle     auto it = mappedLocations.find(op->getLoc());
350ba00878SRiver Riddle     if (it == mappedLocations.end()) {
360ba00878SRiver Riddle       auto newLoc = CallSiteLoc::get(op->getLoc(), callerLoc);
370ba00878SRiver Riddle       it = mappedLocations.try_emplace(op->getLoc(), newLoc).first;
380ba00878SRiver Riddle     }
390ba00878SRiver Riddle     op->setLoc(it->second);
400ba00878SRiver Riddle   };
410ba00878SRiver Riddle   for (auto &block : inlinedBlocks)
420ba00878SRiver Riddle     block.walk(remapOpLoc);
430ba00878SRiver Riddle }
440ba00878SRiver Riddle 
remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks,BlockAndValueMapping & mapper)454562e389SRiver Riddle static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks,
460ba00878SRiver Riddle                                  BlockAndValueMapping &mapper) {
470ba00878SRiver Riddle   auto remapOperands = [&](Operation *op) {
480ba00878SRiver Riddle     for (auto &operand : op->getOpOperands())
4935807bc4SRiver Riddle       if (auto mappedOp = mapper.lookupOrNull(operand.get()))
500ba00878SRiver Riddle         operand.set(mappedOp);
510ba00878SRiver Riddle   };
520ba00878SRiver Riddle   for (auto &block : inlinedBlocks)
530ba00878SRiver Riddle     block.walk(remapOperands);
540ba00878SRiver Riddle }
550ba00878SRiver Riddle 
560ba00878SRiver Riddle //===----------------------------------------------------------------------===//
570ba00878SRiver Riddle // InlinerInterface
580ba00878SRiver Riddle //===----------------------------------------------------------------------===//
590ba00878SRiver Riddle 
isLegalToInline(Operation * call,Operation * callable,bool wouldBeCloned) const60fa417479SRiver Riddle bool InlinerInterface::isLegalToInline(Operation *call, Operation *callable,
61fa417479SRiver Riddle                                        bool wouldBeCloned) const {
62fa417479SRiver Riddle   if (auto *handler = getInterfaceFor(call))
63fa417479SRiver Riddle     return handler->isLegalToInline(call, callable, wouldBeCloned);
64fa417479SRiver Riddle   return false;
65501fda01SRiver Riddle }
66501fda01SRiver Riddle 
isLegalToInline(Region * dest,Region * src,bool wouldBeCloned,BlockAndValueMapping & valueMapping) const670ba00878SRiver Riddle bool InlinerInterface::isLegalToInline(
68fa417479SRiver Riddle     Region *dest, Region *src, bool wouldBeCloned,
69fa417479SRiver Riddle     BlockAndValueMapping &valueMapping) const {
70fa417479SRiver Riddle   if (auto *handler = getInterfaceFor(dest->getParentOp()))
71fa417479SRiver Riddle     return handler->isLegalToInline(dest, src, wouldBeCloned, valueMapping);
72fa417479SRiver Riddle   return false;
730ba00878SRiver Riddle }
740ba00878SRiver Riddle 
isLegalToInline(Operation * op,Region * dest,bool wouldBeCloned,BlockAndValueMapping & valueMapping) const750ba00878SRiver Riddle bool InlinerInterface::isLegalToInline(
76fa417479SRiver Riddle     Operation *op, Region *dest, bool wouldBeCloned,
77fa417479SRiver Riddle     BlockAndValueMapping &valueMapping) const {
78fa417479SRiver Riddle   if (auto *handler = getInterfaceFor(op))
79fa417479SRiver Riddle     return handler->isLegalToInline(op, dest, wouldBeCloned, valueMapping);
80fa417479SRiver Riddle   return false;
810ba00878SRiver Riddle }
820ba00878SRiver Riddle 
shouldAnalyzeRecursively(Operation * op) const830ba00878SRiver Riddle bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const {
840ba00878SRiver Riddle   auto *handler = getInterfaceFor(op);
850ba00878SRiver Riddle   return handler ? handler->shouldAnalyzeRecursively(op) : true;
860ba00878SRiver Riddle }
870ba00878SRiver Riddle 
880ba00878SRiver Riddle /// Handle the given inlined terminator by replacing it with a new operation
890ba00878SRiver Riddle /// as necessary.
handleTerminator(Operation * op,Block * newDest) const900ba00878SRiver Riddle void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const {
910ba00878SRiver Riddle   auto *handler = getInterfaceFor(op);
920ba00878SRiver Riddle   assert(handler && "expected valid dialect handler");
930ba00878SRiver Riddle   handler->handleTerminator(op, newDest);
940ba00878SRiver Riddle }
950ba00878SRiver Riddle 
960ba00878SRiver Riddle /// Handle the given inlined terminator by replacing it with a new operation
970ba00878SRiver Riddle /// as necessary.
handleTerminator(Operation * op,ArrayRef<Value> valuesToRepl) const980ba00878SRiver Riddle void InlinerInterface::handleTerminator(Operation *op,
99e62a6956SRiver Riddle                                         ArrayRef<Value> valuesToRepl) const {
1000ba00878SRiver Riddle   auto *handler = getInterfaceFor(op);
1010ba00878SRiver Riddle   assert(handler && "expected valid dialect handler");
1020ba00878SRiver Riddle   handler->handleTerminator(op, valuesToRepl);
1030ba00878SRiver Riddle }
1040ba00878SRiver Riddle 
processInlinedCallBlocks(Operation * call,iterator_range<Region::iterator> inlinedBlocks) const1050e760a08SJacques Pienaar void InlinerInterface::processInlinedCallBlocks(
1060e760a08SJacques Pienaar     Operation *call, iterator_range<Region::iterator> inlinedBlocks) const {
1070e760a08SJacques Pienaar   auto *handler = getInterfaceFor(call);
1080e760a08SJacques Pienaar   assert(handler && "expected valid dialect handler");
1090e760a08SJacques Pienaar   handler->processInlinedCallBlocks(call, inlinedBlocks);
1100e760a08SJacques Pienaar }
1110e760a08SJacques Pienaar 
1120ba00878SRiver Riddle /// Utility to check that all of the operations within 'src' can be inlined.
isLegalToInline(InlinerInterface & interface,Region * src,Region * insertRegion,bool shouldCloneInlinedRegion,BlockAndValueMapping & valueMapping)1130ba00878SRiver Riddle static bool isLegalToInline(InlinerInterface &interface, Region *src,
114fa417479SRiver Riddle                             Region *insertRegion, bool shouldCloneInlinedRegion,
1150ba00878SRiver Riddle                             BlockAndValueMapping &valueMapping) {
1160ba00878SRiver Riddle   for (auto &block : *src) {
1170ba00878SRiver Riddle     for (auto &op : block) {
1180ba00878SRiver Riddle       // Check this operation.
119fa417479SRiver Riddle       if (!interface.isLegalToInline(&op, insertRegion,
120fa417479SRiver Riddle                                      shouldCloneInlinedRegion, valueMapping)) {
121553f794bSSean Silva         LLVM_DEBUG({
122553f794bSSean Silva           llvm::dbgs() << "* Illegal to inline because of op: ";
123553f794bSSean Silva           op.dump();
124553f794bSSean Silva         });
1250ba00878SRiver Riddle         return false;
126553f794bSSean Silva       }
1270ba00878SRiver Riddle       // Check any nested regions.
1280ba00878SRiver Riddle       if (interface.shouldAnalyzeRecursively(&op) &&
1290ba00878SRiver Riddle           llvm::any_of(op.getRegions(), [&](Region &region) {
1300ba00878SRiver Riddle             return !isLegalToInline(interface, &region, insertRegion,
131fa417479SRiver Riddle                                     shouldCloneInlinedRegion, valueMapping);
1320ba00878SRiver Riddle           }))
1330ba00878SRiver Riddle         return false;
1340ba00878SRiver Riddle     }
1350ba00878SRiver Riddle   }
1360ba00878SRiver Riddle   return true;
1370ba00878SRiver Riddle }
1380ba00878SRiver Riddle 
1390ba00878SRiver Riddle //===----------------------------------------------------------------------===//
1400ba00878SRiver Riddle // Inline Methods
1410ba00878SRiver Riddle //===----------------------------------------------------------------------===//
1420ba00878SRiver Riddle 
1430e760a08SJacques Pienaar static LogicalResult
inlineRegionImpl(InlinerInterface & interface,Region * src,Block * inlineBlock,Block::iterator inlinePoint,BlockAndValueMapping & mapper,ValueRange resultsToReplace,TypeRange regionResultTypes,Optional<Location> inlineLoc,bool shouldCloneInlinedRegion,Operation * call=nullptr)144da12d88bSRiver Riddle inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
145da12d88bSRiver Riddle                  Block::iterator inlinePoint, BlockAndValueMapping &mapper,
1460e760a08SJacques Pienaar                  ValueRange resultsToReplace, TypeRange regionResultTypes,
1470e760a08SJacques Pienaar                  Optional<Location> inlineLoc, bool shouldCloneInlinedRegion,
148da12d88bSRiver Riddle                  Operation *call = nullptr) {
14922219cfcSSean Silva   assert(resultsToReplace.size() == regionResultTypes.size());
1500ba00878SRiver Riddle   // We expect the region to have at least one block.
1510ba00878SRiver Riddle   if (src->empty())
1520ba00878SRiver Riddle     return failure();
1530ba00878SRiver Riddle 
1540ba00878SRiver Riddle   // Check that all of the region arguments have been mapped.
1550ba00878SRiver Riddle   auto *srcEntryBlock = &src->front();
1560ba00878SRiver Riddle   if (llvm::any_of(srcEntryBlock->getArguments(),
157e62a6956SRiver Riddle                    [&](BlockArgument arg) { return !mapper.contains(arg); }))
1580ba00878SRiver Riddle     return failure();
1590ba00878SRiver Riddle 
1600ba00878SRiver Riddle   // Check that the operations within the source region are valid to inline.
161da12d88bSRiver Riddle   Region *insertRegion = inlineBlock->getParent();
162fa417479SRiver Riddle   if (!interface.isLegalToInline(insertRegion, src, shouldCloneInlinedRegion,
163fa417479SRiver Riddle                                  mapper) ||
164fa417479SRiver Riddle       !isLegalToInline(interface, src, insertRegion, shouldCloneInlinedRegion,
165fa417479SRiver Riddle                        mapper))
1660ba00878SRiver Riddle     return failure();
1670ba00878SRiver Riddle 
1680ba00878SRiver Riddle   // Check to see if the region is being cloned, or moved inline. In either
1690ba00878SRiver Riddle   // case, move the new blocks after the 'insertBlock' to improve IR
1700ba00878SRiver Riddle   // readability.
171da12d88bSRiver Riddle   Block *postInsertBlock = inlineBlock->splitBlock(inlinePoint);
1720ba00878SRiver Riddle   if (shouldCloneInlinedRegion)
1730ba00878SRiver Riddle     src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
1740ba00878SRiver Riddle   else
1750ba00878SRiver Riddle     insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
1760ba00878SRiver Riddle                                      src->getBlocks(), src->begin(),
1770ba00878SRiver Riddle                                      src->end());
1780ba00878SRiver Riddle 
1790ba00878SRiver Riddle   // Get the range of newly inserted blocks.
180da12d88bSRiver Riddle   auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()),
1810ba00878SRiver Riddle                                     postInsertBlock->getIterator());
1820ba00878SRiver Riddle   Block *firstNewBlock = &*newBlocks.begin();
1830ba00878SRiver Riddle 
1840ba00878SRiver Riddle   // Remap the locations of the inlined operations if a valid source location
1850ba00878SRiver Riddle   // was provided.
1860ba00878SRiver Riddle   if (inlineLoc && !inlineLoc->isa<UnknownLoc>())
1870ba00878SRiver Riddle     remapInlinedLocations(newBlocks, *inlineLoc);
1880ba00878SRiver Riddle 
1890ba00878SRiver Riddle   // If the blocks were moved in-place, make sure to remap any necessary
1900ba00878SRiver Riddle   // operands.
1910ba00878SRiver Riddle   if (!shouldCloneInlinedRegion)
1920ba00878SRiver Riddle     remapInlinedOperands(newBlocks, mapper);
1930ba00878SRiver Riddle 
194a20d96e4SRiver Riddle   // Process the newly inlined blocks.
1950e760a08SJacques Pienaar   if (call)
1960e760a08SJacques Pienaar     interface.processInlinedCallBlocks(call, newBlocks);
197a20d96e4SRiver Riddle   interface.processInlinedBlocks(newBlocks);
198a20d96e4SRiver Riddle 
1990ba00878SRiver Riddle   // Handle the case where only a single block was inlined.
2000ba00878SRiver Riddle   if (std::next(newBlocks.begin()) == newBlocks.end()) {
2010ba00878SRiver Riddle     // Have the interface handle the terminator of this block.
2020ba00878SRiver Riddle     auto *firstBlockTerminator = firstNewBlock->getTerminator();
20322219cfcSSean Silva     interface.handleTerminator(firstBlockTerminator,
20422219cfcSSean Silva                                llvm::to_vector<6>(resultsToReplace));
2050ba00878SRiver Riddle     firstBlockTerminator->erase();
2060ba00878SRiver Riddle 
2070ba00878SRiver Riddle     // Merge the post insert block into the cloned entry block.
2080ba00878SRiver Riddle     firstNewBlock->getOperations().splice(firstNewBlock->end(),
2090ba00878SRiver Riddle                                           postInsertBlock->getOperations());
2100ba00878SRiver Riddle     postInsertBlock->erase();
2110ba00878SRiver Riddle   } else {
2120ba00878SRiver Riddle     // Otherwise, there were multiple blocks inlined. Add arguments to the post
2130ba00878SRiver Riddle     // insertion block to represent the results to replace.
214e4853be2SMehdi Amini     for (const auto &resultToRepl : llvm::enumerate(resultsToReplace)) {
215e084679fSRiver Riddle       resultToRepl.value().replaceAllUsesWith(
216e084679fSRiver Riddle           postInsertBlock->addArgument(regionResultTypes[resultToRepl.index()],
217e084679fSRiver Riddle                                        resultToRepl.value().getLoc()));
2180ba00878SRiver Riddle     }
2190ba00878SRiver Riddle 
2200ba00878SRiver Riddle     /// Handle the terminators for each of the new blocks.
2210ba00878SRiver Riddle     for (auto &newBlock : newBlocks)
2220ba00878SRiver Riddle       interface.handleTerminator(newBlock.getTerminator(), postInsertBlock);
2230ba00878SRiver Riddle   }
2240ba00878SRiver Riddle 
2250ba00878SRiver Riddle   // Splice the instructions of the inlined entry block into the insert block.
226da12d88bSRiver Riddle   inlineBlock->getOperations().splice(inlineBlock->end(),
2270ba00878SRiver Riddle                                       firstNewBlock->getOperations());
2280ba00878SRiver Riddle   firstNewBlock->erase();
2290ba00878SRiver Riddle   return success();
2300ba00878SRiver Riddle }
2310ba00878SRiver Riddle 
2320e760a08SJacques Pienaar static LogicalResult
inlineRegionImpl(InlinerInterface & interface,Region * src,Block * inlineBlock,Block::iterator inlinePoint,ValueRange inlinedOperands,ValueRange resultsToReplace,Optional<Location> inlineLoc,bool shouldCloneInlinedRegion,Operation * call=nullptr)233da12d88bSRiver Riddle inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
234da12d88bSRiver Riddle                  Block::iterator inlinePoint, ValueRange inlinedOperands,
2350e760a08SJacques Pienaar                  ValueRange resultsToReplace, Optional<Location> inlineLoc,
236da12d88bSRiver Riddle                  bool shouldCloneInlinedRegion, Operation *call = nullptr) {
2370ba00878SRiver Riddle   // We expect the region to have at least one block.
2380ba00878SRiver Riddle   if (src->empty())
2390ba00878SRiver Riddle     return failure();
2400ba00878SRiver Riddle 
2410ba00878SRiver Riddle   auto *entryBlock = &src->front();
2420ba00878SRiver Riddle   if (inlinedOperands.size() != entryBlock->getNumArguments())
2430ba00878SRiver Riddle     return failure();
2440ba00878SRiver Riddle 
2450ba00878SRiver Riddle   // Map the provided call operands to the arguments of the region.
2460ba00878SRiver Riddle   BlockAndValueMapping mapper;
2470ba00878SRiver Riddle   for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) {
2480ba00878SRiver Riddle     // Verify that the types of the provided values match the function argument
2490ba00878SRiver Riddle     // types.
250e62a6956SRiver Riddle     BlockArgument regionArg = entryBlock->getArgument(i);
2512bdf33ccSRiver Riddle     if (inlinedOperands[i].getType() != regionArg.getType())
2520ba00878SRiver Riddle       return failure();
2530ba00878SRiver Riddle     mapper.map(regionArg, inlinedOperands[i]);
2540ba00878SRiver Riddle   }
2550ba00878SRiver Riddle 
2560ba00878SRiver Riddle   // Call into the main region inliner function.
257da12d88bSRiver Riddle   return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
258da12d88bSRiver Riddle                           resultsToReplace, resultsToReplace.getTypes(),
259da12d88bSRiver Riddle                           inlineLoc, shouldCloneInlinedRegion, call);
2600e760a08SJacques Pienaar }
2610e760a08SJacques Pienaar 
inlineRegion(InlinerInterface & interface,Region * src,Operation * inlinePoint,BlockAndValueMapping & mapper,ValueRange resultsToReplace,TypeRange regionResultTypes,Optional<Location> inlineLoc,bool shouldCloneInlinedRegion)2620e760a08SJacques Pienaar LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
2630e760a08SJacques Pienaar                                  Operation *inlinePoint,
2640e760a08SJacques Pienaar                                  BlockAndValueMapping &mapper,
2650e760a08SJacques Pienaar                                  ValueRange resultsToReplace,
2660e760a08SJacques Pienaar                                  TypeRange regionResultTypes,
2670e760a08SJacques Pienaar                                  Optional<Location> inlineLoc,
2680e760a08SJacques Pienaar                                  bool shouldCloneInlinedRegion) {
269da12d88bSRiver Riddle   return inlineRegion(interface, src, inlinePoint->getBlock(),
270da12d88bSRiver Riddle                       ++inlinePoint->getIterator(), mapper, resultsToReplace,
271da12d88bSRiver Riddle                       regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
272da12d88bSRiver Riddle }
273da12d88bSRiver Riddle LogicalResult
inlineRegion(InlinerInterface & interface,Region * src,Block * inlineBlock,Block::iterator inlinePoint,BlockAndValueMapping & mapper,ValueRange resultsToReplace,TypeRange regionResultTypes,Optional<Location> inlineLoc,bool shouldCloneInlinedRegion)274da12d88bSRiver Riddle mlir::inlineRegion(InlinerInterface &interface, Region *src, Block *inlineBlock,
275da12d88bSRiver Riddle                    Block::iterator inlinePoint, BlockAndValueMapping &mapper,
276da12d88bSRiver Riddle                    ValueRange resultsToReplace, TypeRange regionResultTypes,
277da12d88bSRiver Riddle                    Optional<Location> inlineLoc,
278da12d88bSRiver Riddle                    bool shouldCloneInlinedRegion) {
279da12d88bSRiver Riddle   return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
280da12d88bSRiver Riddle                           resultsToReplace, regionResultTypes, inlineLoc,
281da12d88bSRiver Riddle                           shouldCloneInlinedRegion);
2820e760a08SJacques Pienaar }
2830e760a08SJacques Pienaar 
inlineRegion(InlinerInterface & interface,Region * src,Operation * inlinePoint,ValueRange inlinedOperands,ValueRange resultsToReplace,Optional<Location> inlineLoc,bool shouldCloneInlinedRegion)2840e760a08SJacques Pienaar LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
2850e760a08SJacques Pienaar                                  Operation *inlinePoint,
2860e760a08SJacques Pienaar                                  ValueRange inlinedOperands,
2870e760a08SJacques Pienaar                                  ValueRange resultsToReplace,
2880e760a08SJacques Pienaar                                  Optional<Location> inlineLoc,
2890e760a08SJacques Pienaar                                  bool shouldCloneInlinedRegion) {
290da12d88bSRiver Riddle   return inlineRegion(interface, src, inlinePoint->getBlock(),
291da12d88bSRiver Riddle                       ++inlinePoint->getIterator(), inlinedOperands,
292da12d88bSRiver Riddle                       resultsToReplace, inlineLoc, shouldCloneInlinedRegion);
293da12d88bSRiver Riddle }
294da12d88bSRiver Riddle LogicalResult
inlineRegion(InlinerInterface & interface,Region * src,Block * inlineBlock,Block::iterator inlinePoint,ValueRange inlinedOperands,ValueRange resultsToReplace,Optional<Location> inlineLoc,bool shouldCloneInlinedRegion)295da12d88bSRiver Riddle mlir::inlineRegion(InlinerInterface &interface, Region *src, Block *inlineBlock,
296da12d88bSRiver Riddle                    Block::iterator inlinePoint, ValueRange inlinedOperands,
297da12d88bSRiver Riddle                    ValueRange resultsToReplace, Optional<Location> inlineLoc,
298da12d88bSRiver Riddle                    bool shouldCloneInlinedRegion) {
299da12d88bSRiver Riddle   return inlineRegionImpl(interface, src, inlineBlock, inlinePoint,
300da12d88bSRiver Riddle                           inlinedOperands, resultsToReplace, inlineLoc,
301da12d88bSRiver Riddle                           shouldCloneInlinedRegion);
3020ba00878SRiver Riddle }
3030ba00878SRiver Riddle 
3045830f71aSRiver Riddle /// Utility function used to generate a cast operation from the given interface,
3055830f71aSRiver Riddle /// or return nullptr if a cast could not be generated.
materializeConversion(const DialectInlinerInterface * interface,SmallVectorImpl<Operation * > & castOps,OpBuilder & castBuilder,Value arg,Type type,Location conversionLoc)306e62a6956SRiver Riddle static Value materializeConversion(const DialectInlinerInterface *interface,
3075830f71aSRiver Riddle                                    SmallVectorImpl<Operation *> &castOps,
308e62a6956SRiver Riddle                                    OpBuilder &castBuilder, Value arg, Type type,
309e62a6956SRiver Riddle                                    Location conversionLoc) {
3105830f71aSRiver Riddle   if (!interface)
3115830f71aSRiver Riddle     return nullptr;
3125830f71aSRiver Riddle 
3135830f71aSRiver Riddle   // Check to see if the interface for the call can materialize a conversion.
3145830f71aSRiver Riddle   Operation *castOp = interface->materializeCallConversion(castBuilder, arg,
3155830f71aSRiver Riddle                                                            type, conversionLoc);
3165830f71aSRiver Riddle   if (!castOp)
3175830f71aSRiver Riddle     return nullptr;
3185830f71aSRiver Riddle   castOps.push_back(castOp);
3195830f71aSRiver Riddle 
3205830f71aSRiver Riddle   // Ensure that the generated cast is correct.
3215830f71aSRiver Riddle   assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg &&
3225830f71aSRiver Riddle          castOp->getNumResults() == 1 && *castOp->result_type_begin() == type);
3235830f71aSRiver Riddle   return castOp->getResult(0);
3245830f71aSRiver Riddle }
3255830f71aSRiver Riddle 
3265830f71aSRiver Riddle /// This function inlines a given region, 'src', of a callable operation,
3275830f71aSRiver Riddle /// 'callable', into the location defined by the given call operation. This
3285830f71aSRiver Riddle /// function returns failure if inlining is not possible, success otherwise. On
3295830f71aSRiver Riddle /// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
3305830f71aSRiver Riddle /// corresponds to whether the source region should be cloned into the 'call' or
3315830f71aSRiver Riddle /// spliced directly.
inlineCall(InlinerInterface & interface,CallOpInterface call,CallableOpInterface callable,Region * src,bool shouldCloneInlinedRegion)3325830f71aSRiver Riddle LogicalResult mlir::inlineCall(InlinerInterface &interface,
3335830f71aSRiver Riddle                                CallOpInterface call,
3345830f71aSRiver Riddle                                CallableOpInterface callable, Region *src,
3355830f71aSRiver Riddle                                bool shouldCloneInlinedRegion) {
3365830f71aSRiver Riddle   // We expect the region to have at least one block.
3375830f71aSRiver Riddle   if (src->empty())
3385830f71aSRiver Riddle     return failure();
3395830f71aSRiver Riddle   auto *entryBlock = &src->front();
340c7748404SRiver Riddle   ArrayRef<Type> callableResultTypes = callable.getCallableResults();
3415830f71aSRiver Riddle 
3425830f71aSRiver Riddle   // Make sure that the number of arguments and results matchup between the call
3435830f71aSRiver Riddle   // and the region.
344e62a6956SRiver Riddle   SmallVector<Value, 8> callOperands(call.getArgOperands());
345c4a04059SChristian Sigg   SmallVector<Value, 8> callResults(call->getResults());
3465830f71aSRiver Riddle   if (callOperands.size() != entryBlock->getNumArguments() ||
3475830f71aSRiver Riddle       callResults.size() != callableResultTypes.size())
3480ba00878SRiver Riddle     return failure();
3490ba00878SRiver Riddle 
3505830f71aSRiver Riddle   // A set of cast operations generated to matchup the signature of the region
3515830f71aSRiver Riddle   // with the signature of the call.
3525830f71aSRiver Riddle   SmallVector<Operation *, 4> castOps;
3535830f71aSRiver Riddle   castOps.reserve(callOperands.size() + callResults.size());
3540ba00878SRiver Riddle 
3555830f71aSRiver Riddle   // Functor used to cleanup generated state on failure.
3565830f71aSRiver Riddle   auto cleanupState = [&] {
3575830f71aSRiver Riddle     for (auto *op : castOps) {
3582bdf33ccSRiver Riddle       op->getResult(0).replaceAllUsesWith(op->getOperand(0));
3595830f71aSRiver Riddle       op->erase();
3605830f71aSRiver Riddle     }
3610ba00878SRiver Riddle     return failure();
3625830f71aSRiver Riddle   };
3630ba00878SRiver Riddle 
3645830f71aSRiver Riddle   // Builder used for any conversion operations that need to be materialized.
3655830f71aSRiver Riddle   OpBuilder castBuilder(call);
3665830f71aSRiver Riddle   Location castLoc = call.getLoc();
3670bf4a82aSChristian Sigg   const auto *callInterface = interface.getInterfaceFor(call->getDialect());
3685830f71aSRiver Riddle 
3695830f71aSRiver Riddle   // Map the provided call operands to the arguments of the region.
3705830f71aSRiver Riddle   BlockAndValueMapping mapper;
3715830f71aSRiver Riddle   for (unsigned i = 0, e = callOperands.size(); i != e; ++i) {
372e62a6956SRiver Riddle     BlockArgument regionArg = entryBlock->getArgument(i);
373e62a6956SRiver Riddle     Value operand = callOperands[i];
3745830f71aSRiver Riddle 
3755830f71aSRiver Riddle     // If the call operand doesn't match the expected region argument, try to
3765830f71aSRiver Riddle     // generate a cast.
3772bdf33ccSRiver Riddle     Type regionArgType = regionArg.getType();
3782bdf33ccSRiver Riddle     if (operand.getType() != regionArgType) {
3795830f71aSRiver Riddle       if (!(operand = materializeConversion(callInterface, castOps, castBuilder,
3805830f71aSRiver Riddle                                             operand, regionArgType, castLoc)))
3815830f71aSRiver Riddle         return cleanupState();
3825830f71aSRiver Riddle     }
3835830f71aSRiver Riddle     mapper.map(regionArg, operand);
3845830f71aSRiver Riddle   }
3855830f71aSRiver Riddle 
386706d992cSRahul Joshi   // Ensure that the resultant values of the call match the callable.
3875830f71aSRiver Riddle   castBuilder.setInsertionPointAfter(call);
3885830f71aSRiver Riddle   for (unsigned i = 0, e = callResults.size(); i != e; ++i) {
389e62a6956SRiver Riddle     Value callResult = callResults[i];
3902bdf33ccSRiver Riddle     if (callResult.getType() == callableResultTypes[i])
3915830f71aSRiver Riddle       continue;
3925830f71aSRiver Riddle 
3935830f71aSRiver Riddle     // Generate a conversion that will produce the original type, so that the IR
3945830f71aSRiver Riddle     // is still valid after the original call gets replaced.
395e62a6956SRiver Riddle     Value castResult =
3965830f71aSRiver Riddle         materializeConversion(callInterface, castOps, castBuilder, callResult,
3972bdf33ccSRiver Riddle                               callResult.getType(), castLoc);
3985830f71aSRiver Riddle     if (!castResult)
3995830f71aSRiver Riddle       return cleanupState();
4002bdf33ccSRiver Riddle     callResult.replaceAllUsesWith(castResult);
4012bdf33ccSRiver Riddle     castResult.getDefiningOp()->replaceUsesOfWith(castResult, callResult);
4025830f71aSRiver Riddle   }
4035830f71aSRiver Riddle 
404501fda01SRiver Riddle   // Check that it is legal to inline the callable into the call.
405fa417479SRiver Riddle   if (!interface.isLegalToInline(call, callable, shouldCloneInlinedRegion))
406501fda01SRiver Riddle     return cleanupState();
407501fda01SRiver Riddle 
4085830f71aSRiver Riddle   // Attempt to inline the call.
409da12d88bSRiver Riddle   if (failed(inlineRegionImpl(interface, src, call->getBlock(),
410da12d88bSRiver Riddle                               ++call->getIterator(), mapper, callResults,
41122219cfcSSean Silva                               callableResultTypes, call.getLoc(),
4120e760a08SJacques Pienaar                               shouldCloneInlinedRegion, call)))
4135830f71aSRiver Riddle     return cleanupState();
4145830f71aSRiver Riddle   return success();
4150ba00878SRiver Riddle }
416