1 //===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous inlining utilities.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Transforms/InliningUtils.h"
14 
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/Function.h"
18 #include "mlir/IR/Operation.h"
19 #include "llvm/ADT/MapVector.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/raw_ostream.h"
22 
23 #define DEBUG_TYPE "inlining"
24 
25 using namespace mlir;
26 
27 /// Remap locations from the inlined blocks with CallSiteLoc locations with the
28 /// provided caller location.
29 static void
30 remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks,
31                       Location callerLoc) {
32   DenseMap<Location, Location> mappedLocations;
33   auto remapOpLoc = [&](Operation *op) {
34     auto it = mappedLocations.find(op->getLoc());
35     if (it == mappedLocations.end()) {
36       auto newLoc = CallSiteLoc::get(op->getLoc(), callerLoc);
37       it = mappedLocations.try_emplace(op->getLoc(), newLoc).first;
38     }
39     op->setLoc(it->second);
40   };
41   for (auto &block : inlinedBlocks)
42     block.walk(remapOpLoc);
43 }
44 
45 static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks,
46                                  BlockAndValueMapping &mapper) {
47   auto remapOperands = [&](Operation *op) {
48     for (auto &operand : op->getOpOperands())
49       if (auto mappedOp = mapper.lookupOrNull(operand.get()))
50         operand.set(mappedOp);
51   };
52   for (auto &block : inlinedBlocks)
53     block.walk(remapOperands);
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // InlinerInterface
58 //===----------------------------------------------------------------------===//
59 
60 bool InlinerInterface::isLegalToInline(
61     Region *dest, Region *src, BlockAndValueMapping &valueMapping) const {
62   // Regions can always be inlined into functions.
63   if (isa<FuncOp>(dest->getParentOp()))
64     return true;
65 
66   auto *handler = getInterfaceFor(dest->getParentOp());
67   return handler ? handler->isLegalToInline(dest, src, valueMapping) : false;
68 }
69 
70 bool InlinerInterface::isLegalToInline(
71     Operation *op, Region *dest, BlockAndValueMapping &valueMapping) const {
72   auto *handler = getInterfaceFor(op);
73   return handler ? handler->isLegalToInline(op, dest, valueMapping) : false;
74 }
75 
76 bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const {
77   auto *handler = getInterfaceFor(op);
78   return handler ? handler->shouldAnalyzeRecursively(op) : true;
79 }
80 
81 /// Handle the given inlined terminator by replacing it with a new operation
82 /// as necessary.
83 void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const {
84   auto *handler = getInterfaceFor(op);
85   assert(handler && "expected valid dialect handler");
86   handler->handleTerminator(op, newDest);
87 }
88 
89 /// Handle the given inlined terminator by replacing it with a new operation
90 /// as necessary.
91 void InlinerInterface::handleTerminator(Operation *op,
92                                         ArrayRef<Value> valuesToRepl) const {
93   auto *handler = getInterfaceFor(op);
94   assert(handler && "expected valid dialect handler");
95   handler->handleTerminator(op, valuesToRepl);
96 }
97 
98 /// Utility to check that all of the operations within 'src' can be inlined.
99 static bool isLegalToInline(InlinerInterface &interface, Region *src,
100                             Region *insertRegion,
101                             BlockAndValueMapping &valueMapping) {
102   for (auto &block : *src) {
103     for (auto &op : block) {
104       // Check this operation.
105       if (!interface.isLegalToInline(&op, insertRegion, valueMapping)) {
106         LLVM_DEBUG({
107           llvm::dbgs() << "* Illegal to inline because of op: ";
108           op.dump();
109         });
110         return false;
111       }
112       // Check any nested regions.
113       if (interface.shouldAnalyzeRecursively(&op) &&
114           llvm::any_of(op.getRegions(), [&](Region &region) {
115             return !isLegalToInline(interface, &region, insertRegion,
116                                     valueMapping);
117           }))
118         return false;
119     }
120   }
121   return true;
122 }
123 
124 //===----------------------------------------------------------------------===//
125 // Inline Methods
126 //===----------------------------------------------------------------------===//
127 
128 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
129                                  Operation *inlinePoint,
130                                  BlockAndValueMapping &mapper,
131                                  ArrayRef<Value> resultsToReplace,
132                                  Optional<Location> inlineLoc,
133                                  bool shouldCloneInlinedRegion) {
134   // We expect the region to have at least one block.
135   if (src->empty())
136     return failure();
137 
138   // Check that all of the region arguments have been mapped.
139   auto *srcEntryBlock = &src->front();
140   if (llvm::any_of(srcEntryBlock->getArguments(),
141                    [&](BlockArgument arg) { return !mapper.contains(arg); }))
142     return failure();
143 
144   // The insertion point must be within a block.
145   Block *insertBlock = inlinePoint->getBlock();
146   if (!insertBlock)
147     return failure();
148   Region *insertRegion = insertBlock->getParent();
149 
150   // Check that the operations within the source region are valid to inline.
151   if (!interface.isLegalToInline(insertRegion, src, mapper) ||
152       !isLegalToInline(interface, src, insertRegion, mapper))
153     return failure();
154 
155   // Split the insertion block.
156   Block *postInsertBlock =
157       insertBlock->splitBlock(++inlinePoint->getIterator());
158 
159   // Check to see if the region is being cloned, or moved inline. In either
160   // case, move the new blocks after the 'insertBlock' to improve IR
161   // readability.
162   if (shouldCloneInlinedRegion)
163     src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
164   else
165     insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
166                                      src->getBlocks(), src->begin(),
167                                      src->end());
168 
169   // Get the range of newly inserted blocks.
170   auto newBlocks = llvm::make_range(std::next(insertBlock->getIterator()),
171                                     postInsertBlock->getIterator());
172   Block *firstNewBlock = &*newBlocks.begin();
173 
174   // Remap the locations of the inlined operations if a valid source location
175   // was provided.
176   if (inlineLoc && !inlineLoc->isa<UnknownLoc>())
177     remapInlinedLocations(newBlocks, *inlineLoc);
178 
179   // If the blocks were moved in-place, make sure to remap any necessary
180   // operands.
181   if (!shouldCloneInlinedRegion)
182     remapInlinedOperands(newBlocks, mapper);
183 
184   // Process the newly inlined blocks.
185   interface.processInlinedBlocks(newBlocks);
186 
187   // Handle the case where only a single block was inlined.
188   if (std::next(newBlocks.begin()) == newBlocks.end()) {
189     // Have the interface handle the terminator of this block.
190     auto *firstBlockTerminator = firstNewBlock->getTerminator();
191     interface.handleTerminator(firstBlockTerminator, resultsToReplace);
192     firstBlockTerminator->erase();
193 
194     // Merge the post insert block into the cloned entry block.
195     firstNewBlock->getOperations().splice(firstNewBlock->end(),
196                                           postInsertBlock->getOperations());
197     postInsertBlock->erase();
198   } else {
199     // Otherwise, there were multiple blocks inlined. Add arguments to the post
200     // insertion block to represent the results to replace.
201     for (Value resultToRepl : resultsToReplace) {
202       resultToRepl.replaceAllUsesWith(
203           postInsertBlock->addArgument(resultToRepl.getType()));
204     }
205 
206     /// Handle the terminators for each of the new blocks.
207     for (auto &newBlock : newBlocks)
208       interface.handleTerminator(newBlock.getTerminator(), postInsertBlock);
209   }
210 
211   // Splice the instructions of the inlined entry block into the insert block.
212   insertBlock->getOperations().splice(insertBlock->end(),
213                                       firstNewBlock->getOperations());
214   firstNewBlock->erase();
215   return success();
216 }
217 
218 /// This function is an overload of the above 'inlineRegion' that allows for
219 /// providing the set of operands ('inlinedOperands') that should be used
220 /// in-favor of the region arguments when inlining.
221 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
222                                  Operation *inlinePoint,
223                                  ArrayRef<Value> inlinedOperands,
224                                  ArrayRef<Value> resultsToReplace,
225                                  Optional<Location> inlineLoc,
226                                  bool shouldCloneInlinedRegion) {
227   // We expect the region to have at least one block.
228   if (src->empty())
229     return failure();
230 
231   auto *entryBlock = &src->front();
232   if (inlinedOperands.size() != entryBlock->getNumArguments())
233     return failure();
234 
235   // Map the provided call operands to the arguments of the region.
236   BlockAndValueMapping mapper;
237   for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) {
238     // Verify that the types of the provided values match the function argument
239     // types.
240     BlockArgument regionArg = entryBlock->getArgument(i);
241     if (inlinedOperands[i].getType() != regionArg.getType())
242       return failure();
243     mapper.map(regionArg, inlinedOperands[i]);
244   }
245 
246   // Call into the main region inliner function.
247   return inlineRegion(interface, src, inlinePoint, mapper, resultsToReplace,
248                       inlineLoc, shouldCloneInlinedRegion);
249 }
250 
251 /// Utility function used to generate a cast operation from the given interface,
252 /// or return nullptr if a cast could not be generated.
253 static Value materializeConversion(const DialectInlinerInterface *interface,
254                                    SmallVectorImpl<Operation *> &castOps,
255                                    OpBuilder &castBuilder, Value arg, Type type,
256                                    Location conversionLoc) {
257   if (!interface)
258     return nullptr;
259 
260   // Check to see if the interface for the call can materialize a conversion.
261   Operation *castOp = interface->materializeCallConversion(castBuilder, arg,
262                                                            type, conversionLoc);
263   if (!castOp)
264     return nullptr;
265   castOps.push_back(castOp);
266 
267   // Ensure that the generated cast is correct.
268   assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg &&
269          castOp->getNumResults() == 1 && *castOp->result_type_begin() == type);
270   return castOp->getResult(0);
271 }
272 
273 /// This function inlines a given region, 'src', of a callable operation,
274 /// 'callable', into the location defined by the given call operation. This
275 /// function returns failure if inlining is not possible, success otherwise. On
276 /// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
277 /// corresponds to whether the source region should be cloned into the 'call' or
278 /// spliced directly.
279 LogicalResult mlir::inlineCall(InlinerInterface &interface,
280                                CallOpInterface call,
281                                CallableOpInterface callable, Region *src,
282                                bool shouldCloneInlinedRegion) {
283   // We expect the region to have at least one block.
284   if (src->empty())
285     return failure();
286   auto *entryBlock = &src->front();
287   ArrayRef<Type> callableResultTypes = callable.getCallableResults(src);
288 
289   // Make sure that the number of arguments and results matchup between the call
290   // and the region.
291   SmallVector<Value, 8> callOperands(call.getArgOperands());
292   SmallVector<Value, 8> callResults(call.getOperation()->getResults());
293   if (callOperands.size() != entryBlock->getNumArguments() ||
294       callResults.size() != callableResultTypes.size())
295     return failure();
296 
297   // A set of cast operations generated to matchup the signature of the region
298   // with the signature of the call.
299   SmallVector<Operation *, 4> castOps;
300   castOps.reserve(callOperands.size() + callResults.size());
301 
302   // Functor used to cleanup generated state on failure.
303   auto cleanupState = [&] {
304     for (auto *op : castOps) {
305       op->getResult(0).replaceAllUsesWith(op->getOperand(0));
306       op->erase();
307     }
308     return failure();
309   };
310 
311   // Builder used for any conversion operations that need to be materialized.
312   OpBuilder castBuilder(call);
313   Location castLoc = call.getLoc();
314   auto *callInterface = interface.getInterfaceFor(call.getDialect());
315 
316   // Map the provided call operands to the arguments of the region.
317   BlockAndValueMapping mapper;
318   for (unsigned i = 0, e = callOperands.size(); i != e; ++i) {
319     BlockArgument regionArg = entryBlock->getArgument(i);
320     Value operand = callOperands[i];
321 
322     // If the call operand doesn't match the expected region argument, try to
323     // generate a cast.
324     Type regionArgType = regionArg.getType();
325     if (operand.getType() != regionArgType) {
326       if (!(operand = materializeConversion(callInterface, castOps, castBuilder,
327                                             operand, regionArgType, castLoc)))
328         return cleanupState();
329     }
330     mapper.map(regionArg, operand);
331   }
332 
333   // Ensure that the resultant values of the call, match the callable.
334   castBuilder.setInsertionPointAfter(call);
335   for (unsigned i = 0, e = callResults.size(); i != e; ++i) {
336     Value callResult = callResults[i];
337     if (callResult.getType() == callableResultTypes[i])
338       continue;
339 
340     // Generate a conversion that will produce the original type, so that the IR
341     // is still valid after the original call gets replaced.
342     Value castResult =
343         materializeConversion(callInterface, castOps, castBuilder, callResult,
344                               callResult.getType(), castLoc);
345     if (!castResult)
346       return cleanupState();
347     callResult.replaceAllUsesWith(castResult);
348     castResult.getDefiningOp()->replaceUsesOfWith(castResult, callResult);
349   }
350 
351   // Attempt to inline the call.
352   if (failed(inlineRegion(interface, src, call, mapper, callResults,
353                           call.getLoc(), shouldCloneInlinedRegion)))
354     return cleanupState();
355   return success();
356 }
357