1 //===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous inlining utilities.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Transforms/InliningUtils.h"
14 
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/Operation.h"
19 #include "llvm/ADT/MapVector.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/raw_ostream.h"
22 
23 #define DEBUG_TYPE "inlining"
24 
25 using namespace mlir;
26 
27 /// Remap locations from the inlined blocks with CallSiteLoc locations with the
28 /// provided caller location.
29 static void
30 remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks,
31                       Location callerLoc) {
32   DenseMap<Location, Location> mappedLocations;
33   auto remapOpLoc = [&](Operation *op) {
34     auto it = mappedLocations.find(op->getLoc());
35     if (it == mappedLocations.end()) {
36       auto newLoc = CallSiteLoc::get(op->getLoc(), callerLoc);
37       it = mappedLocations.try_emplace(op->getLoc(), newLoc).first;
38     }
39     op->setLoc(it->second);
40   };
41   for (auto &block : inlinedBlocks)
42     block.walk(remapOpLoc);
43 }
44 
45 static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks,
46                                  BlockAndValueMapping &mapper) {
47   auto remapOperands = [&](Operation *op) {
48     for (auto &operand : op->getOpOperands())
49       if (auto mappedOp = mapper.lookupOrNull(operand.get()))
50         operand.set(mappedOp);
51   };
52   for (auto &block : inlinedBlocks)
53     block.walk(remapOperands);
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // InlinerInterface
58 //===----------------------------------------------------------------------===//
59 
60 bool InlinerInterface::isLegalToInline(Operation *call, Operation *callable,
61                                        bool wouldBeCloned) const {
62   if (auto *handler = getInterfaceFor(call))
63     return handler->isLegalToInline(call, callable, wouldBeCloned);
64   return false;
65 }
66 
67 bool InlinerInterface::isLegalToInline(
68     Region *dest, Region *src, bool wouldBeCloned,
69     BlockAndValueMapping &valueMapping) const {
70   // Regions can always be inlined into functions.
71   if (isa<FuncOp>(dest->getParentOp()))
72     return true;
73 
74   if (auto *handler = getInterfaceFor(dest->getParentOp()))
75     return handler->isLegalToInline(dest, src, wouldBeCloned, valueMapping);
76   return false;
77 }
78 
79 bool InlinerInterface::isLegalToInline(
80     Operation *op, Region *dest, bool wouldBeCloned,
81     BlockAndValueMapping &valueMapping) const {
82   if (auto *handler = getInterfaceFor(op))
83     return handler->isLegalToInline(op, dest, wouldBeCloned, valueMapping);
84   return false;
85 }
86 
87 bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const {
88   auto *handler = getInterfaceFor(op);
89   return handler ? handler->shouldAnalyzeRecursively(op) : true;
90 }
91 
92 /// Handle the given inlined terminator by replacing it with a new operation
93 /// as necessary.
94 void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const {
95   auto *handler = getInterfaceFor(op);
96   assert(handler && "expected valid dialect handler");
97   handler->handleTerminator(op, newDest);
98 }
99 
100 /// Handle the given inlined terminator by replacing it with a new operation
101 /// as necessary.
102 void InlinerInterface::handleTerminator(Operation *op,
103                                         ArrayRef<Value> valuesToRepl) const {
104   auto *handler = getInterfaceFor(op);
105   assert(handler && "expected valid dialect handler");
106   handler->handleTerminator(op, valuesToRepl);
107 }
108 
109 void InlinerInterface::processInlinedCallBlocks(
110     Operation *call, iterator_range<Region::iterator> inlinedBlocks) const {
111   auto *handler = getInterfaceFor(call);
112   assert(handler && "expected valid dialect handler");
113   handler->processInlinedCallBlocks(call, inlinedBlocks);
114 }
115 
116 /// Utility to check that all of the operations within 'src' can be inlined.
117 static bool isLegalToInline(InlinerInterface &interface, Region *src,
118                             Region *insertRegion, bool shouldCloneInlinedRegion,
119                             BlockAndValueMapping &valueMapping) {
120   for (auto &block : *src) {
121     for (auto &op : block) {
122       // Check this operation.
123       if (!interface.isLegalToInline(&op, insertRegion,
124                                      shouldCloneInlinedRegion, valueMapping)) {
125         LLVM_DEBUG({
126           llvm::dbgs() << "* Illegal to inline because of op: ";
127           op.dump();
128         });
129         return false;
130       }
131       // Check any nested regions.
132       if (interface.shouldAnalyzeRecursively(&op) &&
133           llvm::any_of(op.getRegions(), [&](Region &region) {
134             return !isLegalToInline(interface, &region, insertRegion,
135                                     shouldCloneInlinedRegion, valueMapping);
136           }))
137         return false;
138     }
139   }
140   return true;
141 }
142 
143 //===----------------------------------------------------------------------===//
144 // Inline Methods
145 //===----------------------------------------------------------------------===//
146 
/// Core region inliner. Moves the blocks of 'src' into the parent region of
/// 'inlineBlock' at 'inlinePoint', or clones them when
/// 'shouldCloneInlinedRegion' is set. It then remaps operands and locations and
/// lets 'interface' rewrite the inlined terminators, so that
/// 'resultsToReplace' are replaced by the values the region produces.
///
/// Returns failure, before any IR is modified, when:
///   - 'src' has no blocks;
///   - some entry-block argument of 'src' has no mapping in 'mapper';
///   - the interface reports the region, or any operation in it, as illegal
///     to inline into the destination region.
/// 'regionResultTypes' must have one entry per value in 'resultsToReplace'
/// (asserted). If 'inlineLoc' is provided and is not an UnknownLoc, inlined
/// operation locations are wrapped in CallSiteLocs with 'inlineLoc' as the
/// caller. If 'call' is non-null, the interface is also notified through
/// processInlinedCallBlocks.
static LogicalResult
inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
                 Block::iterator inlinePoint, BlockAndValueMapping &mapper,
                 ValueRange resultsToReplace, TypeRange regionResultTypes,
                 Optional<Location> inlineLoc, bool shouldCloneInlinedRegion,
                 Operation *call = nullptr) {
  assert(resultsToReplace.size() == regionResultTypes.size());
  // We expect the region to have at least one block.
  if (src->empty())
    return failure();

  // Check that all of the region arguments have been mapped.
  auto *srcEntryBlock = &src->front();
  if (llvm::any_of(srcEntryBlock->getArguments(),
                   [&](BlockArgument arg) { return !mapper.contains(arg); }))
    return failure();

  // Check that the operations within the source region are valid to inline.
  Region *insertRegion = inlineBlock->getParent();
  if (!interface.isLegalToInline(insertRegion, src, shouldCloneInlinedRegion,
                                 mapper) ||
      !isLegalToInline(interface, src, insertRegion, shouldCloneInlinedRegion,
                       mapper))
    return failure();

  // Split 'inlineBlock' at 'inlinePoint'. Then clone or move the source blocks
  // so that they sit directly after 'inlineBlock' and before the split-off
  // 'postInsertBlock', which keeps the IR readable. There are no failure
  // paths past this point, so the IR is only mutated once all checks passed.
  Block *postInsertBlock = inlineBlock->splitBlock(inlinePoint);
  if (shouldCloneInlinedRegion)
    src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
  else
    insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
                                     src->getBlocks(), src->begin(),
                                     src->end());

  // Get the range of newly inserted blocks: everything strictly between
  // 'inlineBlock' and 'postInsertBlock'.
  auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()),
                                    postInsertBlock->getIterator());
  Block *firstNewBlock = &*newBlocks.begin();

  // Remap the locations of the inlined operations if a valid source location
  // was provided.
  if (inlineLoc && !inlineLoc->isa<UnknownLoc>())
    remapInlinedLocations(newBlocks, *inlineLoc);

  // If the blocks were moved in-place, make sure to remap any necessary
  // operands. (Cloning already applied 'mapper' while copying.)
  if (!shouldCloneInlinedRegion)
    remapInlinedOperands(newBlocks, mapper);

  // Process the newly inlined blocks.
  if (call)
    interface.processInlinedCallBlocks(call, newBlocks);
  interface.processInlinedBlocks(newBlocks);

  // Handle the case where only a single block was inlined. The interface
  // handles the values to replace directly from the terminator, so no block
  // arguments are needed and the split-off tail can be merged back.
  if (std::next(newBlocks.begin()) == newBlocks.end()) {
    // Have the interface handle the terminator of this block.
    auto *firstBlockTerminator = firstNewBlock->getTerminator();
    interface.handleTerminator(firstBlockTerminator,
                               llvm::to_vector<6>(resultsToReplace));
    firstBlockTerminator->erase();

    // Merge the post insert block into the cloned entry block.
    firstNewBlock->getOperations().splice(firstNewBlock->end(),
                                          postInsertBlock->getOperations());
    postInsertBlock->erase();
  } else {
    // Otherwise, there were multiple blocks inlined. Add arguments to the post
    // insertion block to represent the results to replace.
    for (const auto &resultToRepl : llvm::enumerate(resultsToReplace)) {
      resultToRepl.value().replaceAllUsesWith(
          postInsertBlock->addArgument(regionResultTypes[resultToRepl.index()],
                                       resultToRepl.value().getLoc()));
    }

    // Handle the terminators for each of the new blocks, giving the interface
    // 'postInsertBlock' as the destination control should continue to.
    for (auto &newBlock : newBlocks)
      interface.handleTerminator(newBlock.getTerminator(), postInsertBlock);
  }

  // Splice the operations of the inlined entry block onto the end of
  // 'inlineBlock', then erase the now-empty entry block.
  inlineBlock->getOperations().splice(inlineBlock->end(),
                                      firstNewBlock->getOperations());
  firstNewBlock->erase();
  return success();
}
235 
236 static LogicalResult
237 inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
238                  Block::iterator inlinePoint, ValueRange inlinedOperands,
239                  ValueRange resultsToReplace, Optional<Location> inlineLoc,
240                  bool shouldCloneInlinedRegion, Operation *call = nullptr) {
241   // We expect the region to have at least one block.
242   if (src->empty())
243     return failure();
244 
245   auto *entryBlock = &src->front();
246   if (inlinedOperands.size() != entryBlock->getNumArguments())
247     return failure();
248 
249   // Map the provided call operands to the arguments of the region.
250   BlockAndValueMapping mapper;
251   for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) {
252     // Verify that the types of the provided values match the function argument
253     // types.
254     BlockArgument regionArg = entryBlock->getArgument(i);
255     if (inlinedOperands[i].getType() != regionArg.getType())
256       return failure();
257     mapper.map(regionArg, inlinedOperands[i]);
258   }
259 
260   // Call into the main region inliner function.
261   return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
262                           resultsToReplace, resultsToReplace.getTypes(),
263                           inlineLoc, shouldCloneInlinedRegion, call);
264 }
265 
266 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
267                                  Operation *inlinePoint,
268                                  BlockAndValueMapping &mapper,
269                                  ValueRange resultsToReplace,
270                                  TypeRange regionResultTypes,
271                                  Optional<Location> inlineLoc,
272                                  bool shouldCloneInlinedRegion) {
273   return inlineRegion(interface, src, inlinePoint->getBlock(),
274                       ++inlinePoint->getIterator(), mapper, resultsToReplace,
275                       regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
276 }
277 LogicalResult
278 mlir::inlineRegion(InlinerInterface &interface, Region *src, Block *inlineBlock,
279                    Block::iterator inlinePoint, BlockAndValueMapping &mapper,
280                    ValueRange resultsToReplace, TypeRange regionResultTypes,
281                    Optional<Location> inlineLoc,
282                    bool shouldCloneInlinedRegion) {
283   return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
284                           resultsToReplace, regionResultTypes, inlineLoc,
285                           shouldCloneInlinedRegion);
286 }
287 
288 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
289                                  Operation *inlinePoint,
290                                  ValueRange inlinedOperands,
291                                  ValueRange resultsToReplace,
292                                  Optional<Location> inlineLoc,
293                                  bool shouldCloneInlinedRegion) {
294   return inlineRegion(interface, src, inlinePoint->getBlock(),
295                       ++inlinePoint->getIterator(), inlinedOperands,
296                       resultsToReplace, inlineLoc, shouldCloneInlinedRegion);
297 }
298 LogicalResult
299 mlir::inlineRegion(InlinerInterface &interface, Region *src, Block *inlineBlock,
300                    Block::iterator inlinePoint, ValueRange inlinedOperands,
301                    ValueRange resultsToReplace, Optional<Location> inlineLoc,
302                    bool shouldCloneInlinedRegion) {
303   return inlineRegionImpl(interface, src, inlineBlock, inlinePoint,
304                           inlinedOperands, resultsToReplace, inlineLoc,
305                           shouldCloneInlinedRegion);
306 }
307 
308 /// Utility function used to generate a cast operation from the given interface,
309 /// or return nullptr if a cast could not be generated.
310 static Value materializeConversion(const DialectInlinerInterface *interface,
311                                    SmallVectorImpl<Operation *> &castOps,
312                                    OpBuilder &castBuilder, Value arg, Type type,
313                                    Location conversionLoc) {
314   if (!interface)
315     return nullptr;
316 
317   // Check to see if the interface for the call can materialize a conversion.
318   Operation *castOp = interface->materializeCallConversion(castBuilder, arg,
319                                                            type, conversionLoc);
320   if (!castOp)
321     return nullptr;
322   castOps.push_back(castOp);
323 
324   // Ensure that the generated cast is correct.
325   assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg &&
326          castOp->getNumResults() == 1 && *castOp->result_type_begin() == type);
327   return castOp->getResult(0);
328 }
329 
/// This function inlines a given region, 'src', of a callable operation,
/// 'callable', into the location defined by the given call operation. It
/// returns failure if inlining is not possible, and success otherwise. On
/// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
/// controls whether the source region is cloned into the 'call' or spliced
/// directly.
///
/// Call operands and results whose types differ from the callable's signature
/// are bridged with conversion operations requested from the call's dialect
/// interface (see materializeConversion). Every such conversion is recorded in
/// 'castOps' so it can be undone if a later step fails.
/// NOTE(review): 'call' itself is not erased here; presumably the caller is
/// responsible for that after a successful inline.
LogicalResult mlir::inlineCall(InlinerInterface &interface,
                               CallOpInterface call,
                               CallableOpInterface callable, Region *src,
                               bool shouldCloneInlinedRegion) {
  // We expect the region to have at least one block.
  if (src->empty())
    return failure();
  auto *entryBlock = &src->front();
  ArrayRef<Type> callableResultTypes = callable.getCallableResults();

  // Make sure that the number of arguments and results matchup between the call
  // and the region.
  SmallVector<Value, 8> callOperands(call.getArgOperands());
  SmallVector<Value, 8> callResults(call->getResults());
  if (callOperands.size() != entryBlock->getNumArguments() ||
      callResults.size() != callableResultTypes.size())
    return failure();

  // A set of cast operations generated to matchup the signature of the region
  // with the signature of the call.
  SmallVector<Operation *, 4> castOps;
  castOps.reserve(callOperands.size() + callResults.size());

  // Functor used to cleanup generated state on failure: each cast is bypassed
  // by forwarding its result's uses back to its input, then erased.
  auto cleanupState = [&] {
    for (auto *op : castOps) {
      op->getResult(0).replaceAllUsesWith(op->getOperand(0));
      op->erase();
    }
    return failure();
  };

  // Builder used for any conversion operations that need to be materialized.
  // Operand conversions are created before 'call'.
  OpBuilder castBuilder(call);
  Location castLoc = call.getLoc();
  const auto *callInterface = interface.getInterfaceFor(call->getDialect());

  // Map the provided call operands to the arguments of the region.
  BlockAndValueMapping mapper;
  for (unsigned i = 0, e = callOperands.size(); i != e; ++i) {
    BlockArgument regionArg = entryBlock->getArgument(i);
    Value operand = callOperands[i];

    // If the call operand doesn't match the expected region argument, try to
    // generate a cast.
    Type regionArgType = regionArg.getType();
    if (operand.getType() != regionArgType) {
      if (!(operand = materializeConversion(callInterface, castOps, castBuilder,
                                            operand, regionArgType, castLoc)))
        return cleanupState();
    }
    mapper.map(regionArg, operand);
  }

  // Ensure that the resultant values of the call match the callable. Result
  // conversions are created after 'call'.
  castBuilder.setInsertionPointAfter(call);
  for (unsigned i = 0, e = callResults.size(); i != e; ++i) {
    Value callResult = callResults[i];
    if (callResult.getType() == callableResultTypes[i])
      continue;

    // Generate a conversion that will produce the original type, so that the IR
    // is still valid after the original call gets replaced.
    Value castResult =
        materializeConversion(callInterface, castOps, castBuilder, callResult,
                              callResult.getType(), castLoc);
    if (!castResult)
      return cleanupState();
    // Route all existing users of 'callResult' through the cast. The RAUW also
    // rewrites the cast's own operand (which was 'callResult') to
    // 'castResult', so restore that operand to 'callResult'. Afterwards, the
    // cast is the only user of 'callResult'.
    callResult.replaceAllUsesWith(castResult);
    castResult.getDefiningOp()->replaceUsesOfWith(castResult, callResult);
  }

  // Check that it is legal to inline the callable into the call.
  if (!interface.isLegalToInline(call, callable, shouldCloneInlinedRegion))
    return cleanupState();

  // Attempt to inline the call.
  if (failed(inlineRegionImpl(interface, src, call->getBlock(),
                              ++call->getIterator(), mapper, callResults,
                              callableResultTypes, call.getLoc(),
                              shouldCloneInlinedRegion, call)))
    return cleanupState();
  return success();
}
420