10ba00878SRiver Riddle //===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===//
20ba00878SRiver Riddle //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60ba00878SRiver Riddle //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
80ba00878SRiver Riddle //
90ba00878SRiver Riddle // This file implements miscellaneous inlining utilities.
100ba00878SRiver Riddle //
110ba00878SRiver Riddle //===----------------------------------------------------------------------===//
120ba00878SRiver Riddle
130ba00878SRiver Riddle #include "mlir/Transforms/InliningUtils.h"
140ba00878SRiver Riddle
150ba00878SRiver Riddle #include "mlir/IR/BlockAndValueMapping.h"
165830f71aSRiver Riddle #include "mlir/IR/Builders.h"
170ba00878SRiver Riddle #include "mlir/IR/Operation.h"
18*36550692SRiver Riddle #include "mlir/Interfaces/CallInterfaces.h"
190ba00878SRiver Riddle #include "llvm/ADT/MapVector.h"
20553f794bSSean Silva #include "llvm/Support/Debug.h"
210ba00878SRiver Riddle #include "llvm/Support/raw_ostream.h"
220ba00878SRiver Riddle
230ba00878SRiver Riddle #define DEBUG_TYPE "inlining"
240ba00878SRiver Riddle
250ba00878SRiver Riddle using namespace mlir;
260ba00878SRiver Riddle
270ba00878SRiver Riddle /// Remap locations from the inlined blocks with CallSiteLoc locations with the
280ba00878SRiver Riddle /// provided caller location.
290ba00878SRiver Riddle static void
remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks,Location callerLoc)304562e389SRiver Riddle remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks,
310ba00878SRiver Riddle Location callerLoc) {
320ba00878SRiver Riddle DenseMap<Location, Location> mappedLocations;
330ba00878SRiver Riddle auto remapOpLoc = [&](Operation *op) {
340ba00878SRiver Riddle auto it = mappedLocations.find(op->getLoc());
350ba00878SRiver Riddle if (it == mappedLocations.end()) {
360ba00878SRiver Riddle auto newLoc = CallSiteLoc::get(op->getLoc(), callerLoc);
370ba00878SRiver Riddle it = mappedLocations.try_emplace(op->getLoc(), newLoc).first;
380ba00878SRiver Riddle }
390ba00878SRiver Riddle op->setLoc(it->second);
400ba00878SRiver Riddle };
410ba00878SRiver Riddle for (auto &block : inlinedBlocks)
420ba00878SRiver Riddle block.walk(remapOpLoc);
430ba00878SRiver Riddle }
440ba00878SRiver Riddle
remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks,BlockAndValueMapping & mapper)454562e389SRiver Riddle static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks,
460ba00878SRiver Riddle BlockAndValueMapping &mapper) {
470ba00878SRiver Riddle auto remapOperands = [&](Operation *op) {
480ba00878SRiver Riddle for (auto &operand : op->getOpOperands())
4935807bc4SRiver Riddle if (auto mappedOp = mapper.lookupOrNull(operand.get()))
500ba00878SRiver Riddle operand.set(mappedOp);
510ba00878SRiver Riddle };
520ba00878SRiver Riddle for (auto &block : inlinedBlocks)
530ba00878SRiver Riddle block.walk(remapOperands);
540ba00878SRiver Riddle }
550ba00878SRiver Riddle
560ba00878SRiver Riddle //===----------------------------------------------------------------------===//
570ba00878SRiver Riddle // InlinerInterface
580ba00878SRiver Riddle //===----------------------------------------------------------------------===//
590ba00878SRiver Riddle
isLegalToInline(Operation * call,Operation * callable,bool wouldBeCloned) const60fa417479SRiver Riddle bool InlinerInterface::isLegalToInline(Operation *call, Operation *callable,
61fa417479SRiver Riddle bool wouldBeCloned) const {
62fa417479SRiver Riddle if (auto *handler = getInterfaceFor(call))
63fa417479SRiver Riddle return handler->isLegalToInline(call, callable, wouldBeCloned);
64fa417479SRiver Riddle return false;
65501fda01SRiver Riddle }
66501fda01SRiver Riddle
isLegalToInline(Region * dest,Region * src,bool wouldBeCloned,BlockAndValueMapping & valueMapping) const670ba00878SRiver Riddle bool InlinerInterface::isLegalToInline(
68fa417479SRiver Riddle Region *dest, Region *src, bool wouldBeCloned,
69fa417479SRiver Riddle BlockAndValueMapping &valueMapping) const {
70fa417479SRiver Riddle if (auto *handler = getInterfaceFor(dest->getParentOp()))
71fa417479SRiver Riddle return handler->isLegalToInline(dest, src, wouldBeCloned, valueMapping);
72fa417479SRiver Riddle return false;
730ba00878SRiver Riddle }
740ba00878SRiver Riddle
isLegalToInline(Operation * op,Region * dest,bool wouldBeCloned,BlockAndValueMapping & valueMapping) const750ba00878SRiver Riddle bool InlinerInterface::isLegalToInline(
76fa417479SRiver Riddle Operation *op, Region *dest, bool wouldBeCloned,
77fa417479SRiver Riddle BlockAndValueMapping &valueMapping) const {
78fa417479SRiver Riddle if (auto *handler = getInterfaceFor(op))
79fa417479SRiver Riddle return handler->isLegalToInline(op, dest, wouldBeCloned, valueMapping);
80fa417479SRiver Riddle return false;
810ba00878SRiver Riddle }
820ba00878SRiver Riddle
shouldAnalyzeRecursively(Operation * op) const830ba00878SRiver Riddle bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const {
840ba00878SRiver Riddle auto *handler = getInterfaceFor(op);
850ba00878SRiver Riddle return handler ? handler->shouldAnalyzeRecursively(op) : true;
860ba00878SRiver Riddle }
870ba00878SRiver Riddle
880ba00878SRiver Riddle /// Handle the given inlined terminator by replacing it with a new operation
890ba00878SRiver Riddle /// as necessary.
handleTerminator(Operation * op,Block * newDest) const900ba00878SRiver Riddle void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const {
910ba00878SRiver Riddle auto *handler = getInterfaceFor(op);
920ba00878SRiver Riddle assert(handler && "expected valid dialect handler");
930ba00878SRiver Riddle handler->handleTerminator(op, newDest);
940ba00878SRiver Riddle }
950ba00878SRiver Riddle
960ba00878SRiver Riddle /// Handle the given inlined terminator by replacing it with a new operation
970ba00878SRiver Riddle /// as necessary.
handleTerminator(Operation * op,ArrayRef<Value> valuesToRepl) const980ba00878SRiver Riddle void InlinerInterface::handleTerminator(Operation *op,
99e62a6956SRiver Riddle ArrayRef<Value> valuesToRepl) const {
1000ba00878SRiver Riddle auto *handler = getInterfaceFor(op);
1010ba00878SRiver Riddle assert(handler && "expected valid dialect handler");
1020ba00878SRiver Riddle handler->handleTerminator(op, valuesToRepl);
1030ba00878SRiver Riddle }
1040ba00878SRiver Riddle
processInlinedCallBlocks(Operation * call,iterator_range<Region::iterator> inlinedBlocks) const1050e760a08SJacques Pienaar void InlinerInterface::processInlinedCallBlocks(
1060e760a08SJacques Pienaar Operation *call, iterator_range<Region::iterator> inlinedBlocks) const {
1070e760a08SJacques Pienaar auto *handler = getInterfaceFor(call);
1080e760a08SJacques Pienaar assert(handler && "expected valid dialect handler");
1090e760a08SJacques Pienaar handler->processInlinedCallBlocks(call, inlinedBlocks);
1100e760a08SJacques Pienaar }
1110e760a08SJacques Pienaar
1120ba00878SRiver Riddle /// Utility to check that all of the operations within 'src' can be inlined.
isLegalToInline(InlinerInterface & interface,Region * src,Region * insertRegion,bool shouldCloneInlinedRegion,BlockAndValueMapping & valueMapping)1130ba00878SRiver Riddle static bool isLegalToInline(InlinerInterface &interface, Region *src,
114fa417479SRiver Riddle Region *insertRegion, bool shouldCloneInlinedRegion,
1150ba00878SRiver Riddle BlockAndValueMapping &valueMapping) {
1160ba00878SRiver Riddle for (auto &block : *src) {
1170ba00878SRiver Riddle for (auto &op : block) {
1180ba00878SRiver Riddle // Check this operation.
119fa417479SRiver Riddle if (!interface.isLegalToInline(&op, insertRegion,
120fa417479SRiver Riddle shouldCloneInlinedRegion, valueMapping)) {
121553f794bSSean Silva LLVM_DEBUG({
122553f794bSSean Silva llvm::dbgs() << "* Illegal to inline because of op: ";
123553f794bSSean Silva op.dump();
124553f794bSSean Silva });
1250ba00878SRiver Riddle return false;
126553f794bSSean Silva }
1270ba00878SRiver Riddle // Check any nested regions.
1280ba00878SRiver Riddle if (interface.shouldAnalyzeRecursively(&op) &&
1290ba00878SRiver Riddle llvm::any_of(op.getRegions(), [&](Region ®ion) {
1300ba00878SRiver Riddle return !isLegalToInline(interface, ®ion, insertRegion,
131fa417479SRiver Riddle shouldCloneInlinedRegion, valueMapping);
1320ba00878SRiver Riddle }))
1330ba00878SRiver Riddle return false;
1340ba00878SRiver Riddle }
1350ba00878SRiver Riddle }
1360ba00878SRiver Riddle return true;
1370ba00878SRiver Riddle }
1380ba00878SRiver Riddle
1390ba00878SRiver Riddle //===----------------------------------------------------------------------===//
1400ba00878SRiver Riddle // Inline Methods
1410ba00878SRiver Riddle //===----------------------------------------------------------------------===//
1420ba00878SRiver Riddle
1430e760a08SJacques Pienaar static LogicalResult
inlineRegionImpl(InlinerInterface & interface,Region * src,Block * inlineBlock,Block::iterator inlinePoint,BlockAndValueMapping & mapper,ValueRange resultsToReplace,TypeRange regionResultTypes,Optional<Location> inlineLoc,bool shouldCloneInlinedRegion,Operation * call=nullptr)144da12d88bSRiver Riddle inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
145da12d88bSRiver Riddle Block::iterator inlinePoint, BlockAndValueMapping &mapper,
1460e760a08SJacques Pienaar ValueRange resultsToReplace, TypeRange regionResultTypes,
1470e760a08SJacques Pienaar Optional<Location> inlineLoc, bool shouldCloneInlinedRegion,
148da12d88bSRiver Riddle Operation *call = nullptr) {
14922219cfcSSean Silva assert(resultsToReplace.size() == regionResultTypes.size());
1500ba00878SRiver Riddle // We expect the region to have at least one block.
1510ba00878SRiver Riddle if (src->empty())
1520ba00878SRiver Riddle return failure();
1530ba00878SRiver Riddle
1540ba00878SRiver Riddle // Check that all of the region arguments have been mapped.
1550ba00878SRiver Riddle auto *srcEntryBlock = &src->front();
1560ba00878SRiver Riddle if (llvm::any_of(srcEntryBlock->getArguments(),
157e62a6956SRiver Riddle [&](BlockArgument arg) { return !mapper.contains(arg); }))
1580ba00878SRiver Riddle return failure();
1590ba00878SRiver Riddle
1600ba00878SRiver Riddle // Check that the operations within the source region are valid to inline.
161da12d88bSRiver Riddle Region *insertRegion = inlineBlock->getParent();
162fa417479SRiver Riddle if (!interface.isLegalToInline(insertRegion, src, shouldCloneInlinedRegion,
163fa417479SRiver Riddle mapper) ||
164fa417479SRiver Riddle !isLegalToInline(interface, src, insertRegion, shouldCloneInlinedRegion,
165fa417479SRiver Riddle mapper))
1660ba00878SRiver Riddle return failure();
1670ba00878SRiver Riddle
1680ba00878SRiver Riddle // Check to see if the region is being cloned, or moved inline. In either
1690ba00878SRiver Riddle // case, move the new blocks after the 'insertBlock' to improve IR
1700ba00878SRiver Riddle // readability.
171da12d88bSRiver Riddle Block *postInsertBlock = inlineBlock->splitBlock(inlinePoint);
1720ba00878SRiver Riddle if (shouldCloneInlinedRegion)
1730ba00878SRiver Riddle src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
1740ba00878SRiver Riddle else
1750ba00878SRiver Riddle insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
1760ba00878SRiver Riddle src->getBlocks(), src->begin(),
1770ba00878SRiver Riddle src->end());
1780ba00878SRiver Riddle
1790ba00878SRiver Riddle // Get the range of newly inserted blocks.
180da12d88bSRiver Riddle auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()),
1810ba00878SRiver Riddle postInsertBlock->getIterator());
1820ba00878SRiver Riddle Block *firstNewBlock = &*newBlocks.begin();
1830ba00878SRiver Riddle
1840ba00878SRiver Riddle // Remap the locations of the inlined operations if a valid source location
1850ba00878SRiver Riddle // was provided.
1860ba00878SRiver Riddle if (inlineLoc && !inlineLoc->isa<UnknownLoc>())
1870ba00878SRiver Riddle remapInlinedLocations(newBlocks, *inlineLoc);
1880ba00878SRiver Riddle
1890ba00878SRiver Riddle // If the blocks were moved in-place, make sure to remap any necessary
1900ba00878SRiver Riddle // operands.
1910ba00878SRiver Riddle if (!shouldCloneInlinedRegion)
1920ba00878SRiver Riddle remapInlinedOperands(newBlocks, mapper);
1930ba00878SRiver Riddle
194a20d96e4SRiver Riddle // Process the newly inlined blocks.
1950e760a08SJacques Pienaar if (call)
1960e760a08SJacques Pienaar interface.processInlinedCallBlocks(call, newBlocks);
197a20d96e4SRiver Riddle interface.processInlinedBlocks(newBlocks);
198a20d96e4SRiver Riddle
1990ba00878SRiver Riddle // Handle the case where only a single block was inlined.
2000ba00878SRiver Riddle if (std::next(newBlocks.begin()) == newBlocks.end()) {
2010ba00878SRiver Riddle // Have the interface handle the terminator of this block.
2020ba00878SRiver Riddle auto *firstBlockTerminator = firstNewBlock->getTerminator();
20322219cfcSSean Silva interface.handleTerminator(firstBlockTerminator,
20422219cfcSSean Silva llvm::to_vector<6>(resultsToReplace));
2050ba00878SRiver Riddle firstBlockTerminator->erase();
2060ba00878SRiver Riddle
2070ba00878SRiver Riddle // Merge the post insert block into the cloned entry block.
2080ba00878SRiver Riddle firstNewBlock->getOperations().splice(firstNewBlock->end(),
2090ba00878SRiver Riddle postInsertBlock->getOperations());
2100ba00878SRiver Riddle postInsertBlock->erase();
2110ba00878SRiver Riddle } else {
2120ba00878SRiver Riddle // Otherwise, there were multiple blocks inlined. Add arguments to the post
2130ba00878SRiver Riddle // insertion block to represent the results to replace.
214e4853be2SMehdi Amini for (const auto &resultToRepl : llvm::enumerate(resultsToReplace)) {
215e084679fSRiver Riddle resultToRepl.value().replaceAllUsesWith(
216e084679fSRiver Riddle postInsertBlock->addArgument(regionResultTypes[resultToRepl.index()],
217e084679fSRiver Riddle resultToRepl.value().getLoc()));
2180ba00878SRiver Riddle }
2190ba00878SRiver Riddle
2200ba00878SRiver Riddle /// Handle the terminators for each of the new blocks.
2210ba00878SRiver Riddle for (auto &newBlock : newBlocks)
2220ba00878SRiver Riddle interface.handleTerminator(newBlock.getTerminator(), postInsertBlock);
2230ba00878SRiver Riddle }
2240ba00878SRiver Riddle
2250ba00878SRiver Riddle // Splice the instructions of the inlined entry block into the insert block.
226da12d88bSRiver Riddle inlineBlock->getOperations().splice(inlineBlock->end(),
2270ba00878SRiver Riddle firstNewBlock->getOperations());
2280ba00878SRiver Riddle firstNewBlock->erase();
2290ba00878SRiver Riddle return success();
2300ba00878SRiver Riddle }
2310ba00878SRiver Riddle
2320e760a08SJacques Pienaar static LogicalResult
inlineRegionImpl(InlinerInterface & interface,Region * src,Block * inlineBlock,Block::iterator inlinePoint,ValueRange inlinedOperands,ValueRange resultsToReplace,Optional<Location> inlineLoc,bool shouldCloneInlinedRegion,Operation * call=nullptr)233da12d88bSRiver Riddle inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
234da12d88bSRiver Riddle Block::iterator inlinePoint, ValueRange inlinedOperands,
2350e760a08SJacques Pienaar ValueRange resultsToReplace, Optional<Location> inlineLoc,
236da12d88bSRiver Riddle bool shouldCloneInlinedRegion, Operation *call = nullptr) {
2370ba00878SRiver Riddle // We expect the region to have at least one block.
2380ba00878SRiver Riddle if (src->empty())
2390ba00878SRiver Riddle return failure();
2400ba00878SRiver Riddle
2410ba00878SRiver Riddle auto *entryBlock = &src->front();
2420ba00878SRiver Riddle if (inlinedOperands.size() != entryBlock->getNumArguments())
2430ba00878SRiver Riddle return failure();
2440ba00878SRiver Riddle
2450ba00878SRiver Riddle // Map the provided call operands to the arguments of the region.
2460ba00878SRiver Riddle BlockAndValueMapping mapper;
2470ba00878SRiver Riddle for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) {
2480ba00878SRiver Riddle // Verify that the types of the provided values match the function argument
2490ba00878SRiver Riddle // types.
250e62a6956SRiver Riddle BlockArgument regionArg = entryBlock->getArgument(i);
2512bdf33ccSRiver Riddle if (inlinedOperands[i].getType() != regionArg.getType())
2520ba00878SRiver Riddle return failure();
2530ba00878SRiver Riddle mapper.map(regionArg, inlinedOperands[i]);
2540ba00878SRiver Riddle }
2550ba00878SRiver Riddle
2560ba00878SRiver Riddle // Call into the main region inliner function.
257da12d88bSRiver Riddle return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
258da12d88bSRiver Riddle resultsToReplace, resultsToReplace.getTypes(),
259da12d88bSRiver Riddle inlineLoc, shouldCloneInlinedRegion, call);
2600e760a08SJacques Pienaar }
2610e760a08SJacques Pienaar
inlineRegion(InlinerInterface & interface,Region * src,Operation * inlinePoint,BlockAndValueMapping & mapper,ValueRange resultsToReplace,TypeRange regionResultTypes,Optional<Location> inlineLoc,bool shouldCloneInlinedRegion)2620e760a08SJacques Pienaar LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
2630e760a08SJacques Pienaar Operation *inlinePoint,
2640e760a08SJacques Pienaar BlockAndValueMapping &mapper,
2650e760a08SJacques Pienaar ValueRange resultsToReplace,
2660e760a08SJacques Pienaar TypeRange regionResultTypes,
2670e760a08SJacques Pienaar Optional<Location> inlineLoc,
2680e760a08SJacques Pienaar bool shouldCloneInlinedRegion) {
269da12d88bSRiver Riddle return inlineRegion(interface, src, inlinePoint->getBlock(),
270da12d88bSRiver Riddle ++inlinePoint->getIterator(), mapper, resultsToReplace,
271da12d88bSRiver Riddle regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
272da12d88bSRiver Riddle }
273da12d88bSRiver Riddle LogicalResult
inlineRegion(InlinerInterface & interface,Region * src,Block * inlineBlock,Block::iterator inlinePoint,BlockAndValueMapping & mapper,ValueRange resultsToReplace,TypeRange regionResultTypes,Optional<Location> inlineLoc,bool shouldCloneInlinedRegion)274da12d88bSRiver Riddle mlir::inlineRegion(InlinerInterface &interface, Region *src, Block *inlineBlock,
275da12d88bSRiver Riddle Block::iterator inlinePoint, BlockAndValueMapping &mapper,
276da12d88bSRiver Riddle ValueRange resultsToReplace, TypeRange regionResultTypes,
277da12d88bSRiver Riddle Optional<Location> inlineLoc,
278da12d88bSRiver Riddle bool shouldCloneInlinedRegion) {
279da12d88bSRiver Riddle return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
280da12d88bSRiver Riddle resultsToReplace, regionResultTypes, inlineLoc,
281da12d88bSRiver Riddle shouldCloneInlinedRegion);
2820e760a08SJacques Pienaar }
2830e760a08SJacques Pienaar
inlineRegion(InlinerInterface & interface,Region * src,Operation * inlinePoint,ValueRange inlinedOperands,ValueRange resultsToReplace,Optional<Location> inlineLoc,bool shouldCloneInlinedRegion)2840e760a08SJacques Pienaar LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
2850e760a08SJacques Pienaar Operation *inlinePoint,
2860e760a08SJacques Pienaar ValueRange inlinedOperands,
2870e760a08SJacques Pienaar ValueRange resultsToReplace,
2880e760a08SJacques Pienaar Optional<Location> inlineLoc,
2890e760a08SJacques Pienaar bool shouldCloneInlinedRegion) {
290da12d88bSRiver Riddle return inlineRegion(interface, src, inlinePoint->getBlock(),
291da12d88bSRiver Riddle ++inlinePoint->getIterator(), inlinedOperands,
292da12d88bSRiver Riddle resultsToReplace, inlineLoc, shouldCloneInlinedRegion);
293da12d88bSRiver Riddle }
294da12d88bSRiver Riddle LogicalResult
inlineRegion(InlinerInterface & interface,Region * src,Block * inlineBlock,Block::iterator inlinePoint,ValueRange inlinedOperands,ValueRange resultsToReplace,Optional<Location> inlineLoc,bool shouldCloneInlinedRegion)295da12d88bSRiver Riddle mlir::inlineRegion(InlinerInterface &interface, Region *src, Block *inlineBlock,
296da12d88bSRiver Riddle Block::iterator inlinePoint, ValueRange inlinedOperands,
297da12d88bSRiver Riddle ValueRange resultsToReplace, Optional<Location> inlineLoc,
298da12d88bSRiver Riddle bool shouldCloneInlinedRegion) {
299da12d88bSRiver Riddle return inlineRegionImpl(interface, src, inlineBlock, inlinePoint,
300da12d88bSRiver Riddle inlinedOperands, resultsToReplace, inlineLoc,
301da12d88bSRiver Riddle shouldCloneInlinedRegion);
3020ba00878SRiver Riddle }
3030ba00878SRiver Riddle
3045830f71aSRiver Riddle /// Utility function used to generate a cast operation from the given interface,
3055830f71aSRiver Riddle /// or return nullptr if a cast could not be generated.
materializeConversion(const DialectInlinerInterface * interface,SmallVectorImpl<Operation * > & castOps,OpBuilder & castBuilder,Value arg,Type type,Location conversionLoc)306e62a6956SRiver Riddle static Value materializeConversion(const DialectInlinerInterface *interface,
3075830f71aSRiver Riddle SmallVectorImpl<Operation *> &castOps,
308e62a6956SRiver Riddle OpBuilder &castBuilder, Value arg, Type type,
309e62a6956SRiver Riddle Location conversionLoc) {
3105830f71aSRiver Riddle if (!interface)
3115830f71aSRiver Riddle return nullptr;
3125830f71aSRiver Riddle
3135830f71aSRiver Riddle // Check to see if the interface for the call can materialize a conversion.
3145830f71aSRiver Riddle Operation *castOp = interface->materializeCallConversion(castBuilder, arg,
3155830f71aSRiver Riddle type, conversionLoc);
3165830f71aSRiver Riddle if (!castOp)
3175830f71aSRiver Riddle return nullptr;
3185830f71aSRiver Riddle castOps.push_back(castOp);
3195830f71aSRiver Riddle
3205830f71aSRiver Riddle // Ensure that the generated cast is correct.
3215830f71aSRiver Riddle assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg &&
3225830f71aSRiver Riddle castOp->getNumResults() == 1 && *castOp->result_type_begin() == type);
3235830f71aSRiver Riddle return castOp->getResult(0);
3245830f71aSRiver Riddle }
3255830f71aSRiver Riddle
3265830f71aSRiver Riddle /// This function inlines a given region, 'src', of a callable operation,
3275830f71aSRiver Riddle /// 'callable', into the location defined by the given call operation. This
3285830f71aSRiver Riddle /// function returns failure if inlining is not possible, success otherwise. On
3295830f71aSRiver Riddle /// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
3305830f71aSRiver Riddle /// corresponds to whether the source region should be cloned into the 'call' or
3315830f71aSRiver Riddle /// spliced directly.
inlineCall(InlinerInterface & interface,CallOpInterface call,CallableOpInterface callable,Region * src,bool shouldCloneInlinedRegion)3325830f71aSRiver Riddle LogicalResult mlir::inlineCall(InlinerInterface &interface,
3335830f71aSRiver Riddle CallOpInterface call,
3345830f71aSRiver Riddle CallableOpInterface callable, Region *src,
3355830f71aSRiver Riddle bool shouldCloneInlinedRegion) {
3365830f71aSRiver Riddle // We expect the region to have at least one block.
3375830f71aSRiver Riddle if (src->empty())
3385830f71aSRiver Riddle return failure();
3395830f71aSRiver Riddle auto *entryBlock = &src->front();
340c7748404SRiver Riddle ArrayRef<Type> callableResultTypes = callable.getCallableResults();
3415830f71aSRiver Riddle
3425830f71aSRiver Riddle // Make sure that the number of arguments and results matchup between the call
3435830f71aSRiver Riddle // and the region.
344e62a6956SRiver Riddle SmallVector<Value, 8> callOperands(call.getArgOperands());
345c4a04059SChristian Sigg SmallVector<Value, 8> callResults(call->getResults());
3465830f71aSRiver Riddle if (callOperands.size() != entryBlock->getNumArguments() ||
3475830f71aSRiver Riddle callResults.size() != callableResultTypes.size())
3480ba00878SRiver Riddle return failure();
3490ba00878SRiver Riddle
3505830f71aSRiver Riddle // A set of cast operations generated to matchup the signature of the region
3515830f71aSRiver Riddle // with the signature of the call.
3525830f71aSRiver Riddle SmallVector<Operation *, 4> castOps;
3535830f71aSRiver Riddle castOps.reserve(callOperands.size() + callResults.size());
3540ba00878SRiver Riddle
3555830f71aSRiver Riddle // Functor used to cleanup generated state on failure.
3565830f71aSRiver Riddle auto cleanupState = [&] {
3575830f71aSRiver Riddle for (auto *op : castOps) {
3582bdf33ccSRiver Riddle op->getResult(0).replaceAllUsesWith(op->getOperand(0));
3595830f71aSRiver Riddle op->erase();
3605830f71aSRiver Riddle }
3610ba00878SRiver Riddle return failure();
3625830f71aSRiver Riddle };
3630ba00878SRiver Riddle
3645830f71aSRiver Riddle // Builder used for any conversion operations that need to be materialized.
3655830f71aSRiver Riddle OpBuilder castBuilder(call);
3665830f71aSRiver Riddle Location castLoc = call.getLoc();
3670bf4a82aSChristian Sigg const auto *callInterface = interface.getInterfaceFor(call->getDialect());
3685830f71aSRiver Riddle
3695830f71aSRiver Riddle // Map the provided call operands to the arguments of the region.
3705830f71aSRiver Riddle BlockAndValueMapping mapper;
3715830f71aSRiver Riddle for (unsigned i = 0, e = callOperands.size(); i != e; ++i) {
372e62a6956SRiver Riddle BlockArgument regionArg = entryBlock->getArgument(i);
373e62a6956SRiver Riddle Value operand = callOperands[i];
3745830f71aSRiver Riddle
3755830f71aSRiver Riddle // If the call operand doesn't match the expected region argument, try to
3765830f71aSRiver Riddle // generate a cast.
3772bdf33ccSRiver Riddle Type regionArgType = regionArg.getType();
3782bdf33ccSRiver Riddle if (operand.getType() != regionArgType) {
3795830f71aSRiver Riddle if (!(operand = materializeConversion(callInterface, castOps, castBuilder,
3805830f71aSRiver Riddle operand, regionArgType, castLoc)))
3815830f71aSRiver Riddle return cleanupState();
3825830f71aSRiver Riddle }
3835830f71aSRiver Riddle mapper.map(regionArg, operand);
3845830f71aSRiver Riddle }
3855830f71aSRiver Riddle
386706d992cSRahul Joshi // Ensure that the resultant values of the call match the callable.
3875830f71aSRiver Riddle castBuilder.setInsertionPointAfter(call);
3885830f71aSRiver Riddle for (unsigned i = 0, e = callResults.size(); i != e; ++i) {
389e62a6956SRiver Riddle Value callResult = callResults[i];
3902bdf33ccSRiver Riddle if (callResult.getType() == callableResultTypes[i])
3915830f71aSRiver Riddle continue;
3925830f71aSRiver Riddle
3935830f71aSRiver Riddle // Generate a conversion that will produce the original type, so that the IR
3945830f71aSRiver Riddle // is still valid after the original call gets replaced.
395e62a6956SRiver Riddle Value castResult =
3965830f71aSRiver Riddle materializeConversion(callInterface, castOps, castBuilder, callResult,
3972bdf33ccSRiver Riddle callResult.getType(), castLoc);
3985830f71aSRiver Riddle if (!castResult)
3995830f71aSRiver Riddle return cleanupState();
4002bdf33ccSRiver Riddle callResult.replaceAllUsesWith(castResult);
4012bdf33ccSRiver Riddle castResult.getDefiningOp()->replaceUsesOfWith(castResult, callResult);
4025830f71aSRiver Riddle }
4035830f71aSRiver Riddle
404501fda01SRiver Riddle // Check that it is legal to inline the callable into the call.
405fa417479SRiver Riddle if (!interface.isLegalToInline(call, callable, shouldCloneInlinedRegion))
406501fda01SRiver Riddle return cleanupState();
407501fda01SRiver Riddle
4085830f71aSRiver Riddle // Attempt to inline the call.
409da12d88bSRiver Riddle if (failed(inlineRegionImpl(interface, src, call->getBlock(),
410da12d88bSRiver Riddle ++call->getIterator(), mapper, callResults,
41122219cfcSSean Silva callableResultTypes, call.getLoc(),
4120e760a08SJacques Pienaar shouldCloneInlinedRegion, call)))
4135830f71aSRiver Riddle return cleanupState();
4145830f71aSRiver Riddle return success();
4150ba00878SRiver Riddle }
416