1 //===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous inlining utilities.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Transforms/InliningUtils.h"
14 
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/Operation.h"
18 #include "mlir/Interfaces/CallInterfaces.h"
19 #include "llvm/ADT/MapVector.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/raw_ostream.h"
22 
23 #define DEBUG_TYPE "inlining"
24 
25 using namespace mlir;
26 
27 /// Remap locations from the inlined blocks with CallSiteLoc locations with the
28 /// provided caller location.
29 static void
30 remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks,
31                       Location callerLoc) {
32   DenseMap<Location, Location> mappedLocations;
33   auto remapOpLoc = [&](Operation *op) {
34     auto it = mappedLocations.find(op->getLoc());
35     if (it == mappedLocations.end()) {
36       auto newLoc = CallSiteLoc::get(op->getLoc(), callerLoc);
37       it = mappedLocations.try_emplace(op->getLoc(), newLoc).first;
38     }
39     op->setLoc(it->second);
40   };
41   for (auto &block : inlinedBlocks)
42     block.walk(remapOpLoc);
43 }
44 
45 static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks,
46                                  BlockAndValueMapping &mapper) {
47   auto remapOperands = [&](Operation *op) {
48     for (auto &operand : op->getOpOperands())
49       if (auto mappedOp = mapper.lookupOrNull(operand.get()))
50         operand.set(mappedOp);
51   };
52   for (auto &block : inlinedBlocks)
53     block.walk(remapOperands);
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // InlinerInterface
58 //===----------------------------------------------------------------------===//
59 
60 bool InlinerInterface::isLegalToInline(Operation *call, Operation *callable,
61                                        bool wouldBeCloned) const {
62   if (auto *handler = getInterfaceFor(call))
63     return handler->isLegalToInline(call, callable, wouldBeCloned);
64   return false;
65 }
66 
67 bool InlinerInterface::isLegalToInline(
68     Region *dest, Region *src, bool wouldBeCloned,
69     BlockAndValueMapping &valueMapping) const {
70   if (auto *handler = getInterfaceFor(dest->getParentOp()))
71     return handler->isLegalToInline(dest, src, wouldBeCloned, valueMapping);
72   return false;
73 }
74 
75 bool InlinerInterface::isLegalToInline(
76     Operation *op, Region *dest, bool wouldBeCloned,
77     BlockAndValueMapping &valueMapping) const {
78   if (auto *handler = getInterfaceFor(op))
79     return handler->isLegalToInline(op, dest, wouldBeCloned, valueMapping);
80   return false;
81 }
82 
83 bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const {
84   auto *handler = getInterfaceFor(op);
85   return handler ? handler->shouldAnalyzeRecursively(op) : true;
86 }
87 
88 /// Handle the given inlined terminator by replacing it with a new operation
89 /// as necessary.
90 void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const {
91   auto *handler = getInterfaceFor(op);
92   assert(handler && "expected valid dialect handler");
93   handler->handleTerminator(op, newDest);
94 }
95 
96 /// Handle the given inlined terminator by replacing it with a new operation
97 /// as necessary.
98 void InlinerInterface::handleTerminator(Operation *op,
99                                         ArrayRef<Value> valuesToRepl) const {
100   auto *handler = getInterfaceFor(op);
101   assert(handler && "expected valid dialect handler");
102   handler->handleTerminator(op, valuesToRepl);
103 }
104 
105 void InlinerInterface::processInlinedCallBlocks(
106     Operation *call, iterator_range<Region::iterator> inlinedBlocks) const {
107   auto *handler = getInterfaceFor(call);
108   assert(handler && "expected valid dialect handler");
109   handler->processInlinedCallBlocks(call, inlinedBlocks);
110 }
111 
112 /// Utility to check that all of the operations within 'src' can be inlined.
113 static bool isLegalToInline(InlinerInterface &interface, Region *src,
114                             Region *insertRegion, bool shouldCloneInlinedRegion,
115                             BlockAndValueMapping &valueMapping) {
116   for (auto &block : *src) {
117     for (auto &op : block) {
118       // Check this operation.
119       if (!interface.isLegalToInline(&op, insertRegion,
120                                      shouldCloneInlinedRegion, valueMapping)) {
121         LLVM_DEBUG({
122           llvm::dbgs() << "* Illegal to inline because of op: ";
123           op.dump();
124         });
125         return false;
126       }
127       // Check any nested regions.
128       if (interface.shouldAnalyzeRecursively(&op) &&
129           llvm::any_of(op.getRegions(), [&](Region &region) {
130             return !isLegalToInline(interface, &region, insertRegion,
131                                     shouldCloneInlinedRegion, valueMapping);
132           }))
133         return false;
134     }
135   }
136   return true;
137 }
138 
139 //===----------------------------------------------------------------------===//
140 // Inline Methods
141 //===----------------------------------------------------------------------===//
142 
/// Core region inliner. Inserts the blocks of `src` at `inlinePoint` inside
/// `inlineBlock`. The blocks are cloned when `shouldCloneInlinedRegion` is
/// set and moved otherwise. Uses of `resultsToReplace` are then rewired to
/// the values produced by the inlined terminators.
///
/// \param interface   Consulted for legality checks, post-inlining hooks and
///                    rewriting of the inlined terminators.
/// \param src         Region to inline; an empty region yields failure().
/// \param inlineBlock Block receiving the inlined code; it is split at
///                    `inlinePoint`.
/// \param inlinePoint Position in `inlineBlock` where the code is inserted.
/// \param mapper      Must contain a mapping for every argument of the entry
///                    block of `src`, otherwise failure() is returned.
/// \param resultsToReplace Values whose uses are replaced by the results
///                    produced by the inlined region.
/// \param regionResultTypes One type per entry of `resultsToReplace`
///                    (asserted); used for the join-block arguments created
///                    when more than one block is inlined.
/// \param inlineLoc   When present and not an UnknownLoc, inlined operation
///                    locations are wrapped in CallSiteLocs with this caller.
/// \param call        Optional call operation; when non-null the interface's
///                    processInlinedCallBlocks hook is also invoked.
/// \returns failure() if `src` is empty, an entry argument is unmapped, or a
///          legality check rejects the region. All of these exits happen
///          before any block is split or moved. Returns success() otherwise.
static LogicalResult
inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
                 Block::iterator inlinePoint, BlockAndValueMapping &mapper,
                 ValueRange resultsToReplace, TypeRange regionResultTypes,
                 Optional<Location> inlineLoc, bool shouldCloneInlinedRegion,
                 Operation *call = nullptr) {
  // Exactly one region result type is expected per value being replaced.
  assert(resultsToReplace.size() == regionResultTypes.size());
  // We expect the region to have at least one block.
  if (src->empty())
    return failure();

  // Check that all of the region arguments have been mapped.
  auto *srcEntryBlock = &src->front();
  if (llvm::any_of(srcEntryBlock->getArguments(),
                   [&](BlockArgument arg) { return !mapper.contains(arg); }))
    return failure();

  // Check that the operations within the source region are valid to inline:
  // both the region-level hook and the per-operation walk must accept it.
  Region *insertRegion = inlineBlock->getParent();
  if (!interface.isLegalToInline(insertRegion, src, shouldCloneInlinedRegion,
                                 mapper) ||
      !isLegalToInline(interface, src, insertRegion, shouldCloneInlinedRegion,
                       mapper))
    return failure();

  // Split 'inlineBlock' at the inline point. The tail of the block ends up in
  // 'postInsertBlock', and the new blocks (cloned or moved) are placed right
  // before it, i.e. directly after 'inlineBlock', to improve IR readability.
  Block *postInsertBlock = inlineBlock->splitBlock(inlinePoint);
  if (shouldCloneInlinedRegion)
    src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
  else
    insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
                                     src->getBlocks(), src->begin(),
                                     src->end());

  // Get the range of newly inserted blocks: everything strictly between
  // 'inlineBlock' and 'postInsertBlock'.
  auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()),
                                    postInsertBlock->getIterator());
  Block *firstNewBlock = &*newBlocks.begin();

  // Remap the locations of the inlined operations if a valid source location
  // was provided.
  if (inlineLoc && !inlineLoc->isa<UnknownLoc>())
    remapInlinedLocations(newBlocks, *inlineLoc);

  // If the blocks were moved in-place, make sure to remap any necessary
  // operands. The clone path above was already given 'mapper'.
  if (!shouldCloneInlinedRegion)
    remapInlinedOperands(newBlocks, mapper);

  // Process the newly inlined blocks. The call-specific hook only runs when a
  // call operation was supplied.
  if (call)
    interface.processInlinedCallBlocks(call, newBlocks);
  interface.processInlinedBlocks(newBlocks);

  // Handle the case where only a single block was inlined.
  if (std::next(newBlocks.begin()) == newBlocks.end()) {
    // Have the interface handle the terminator of this block. It receives the
    // values to replace directly, so no join block is needed. handleTerminator
    // takes an ArrayRef<Value>, hence the copy of the range into a vector.
    auto *firstBlockTerminator = firstNewBlock->getTerminator();
    interface.handleTerminator(firstBlockTerminator,
                               llvm::to_vector<6>(resultsToReplace));
    firstBlockTerminator->erase();

    // Merge the post insert block into the inlined entry block.
    firstNewBlock->getOperations().splice(firstNewBlock->end(),
                                          postInsertBlock->getOperations());
    postInsertBlock->erase();
  } else {
    // Otherwise, there were multiple blocks inlined. 'postInsertBlock' acts as
    // the join block: add one argument per result to replace, and redirect
    // all uses of that result to the new argument.
    for (const auto &resultToRepl : llvm::enumerate(resultsToReplace)) {
      resultToRepl.value().replaceAllUsesWith(
          postInsertBlock->addArgument(regionResultTypes[resultToRepl.index()],
                                       resultToRepl.value().getLoc()));
    }

    // Handle the terminators for each of the new blocks, passing the join
    // block as their new destination.
    for (auto &newBlock : newBlocks)
      interface.handleTerminator(newBlock.getTerminator(), postInsertBlock);
  }

  // Splice the instructions of the inlined entry block into the insert block.
  // In the single-block case this also carries the merged tail of the
  // original block back into 'inlineBlock'.
  inlineBlock->getOperations().splice(inlineBlock->end(),
                                      firstNewBlock->getOperations());
  firstNewBlock->erase();
  return success();
}
231 
232 static LogicalResult
233 inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
234                  Block::iterator inlinePoint, ValueRange inlinedOperands,
235                  ValueRange resultsToReplace, Optional<Location> inlineLoc,
236                  bool shouldCloneInlinedRegion, Operation *call = nullptr) {
237   // We expect the region to have at least one block.
238   if (src->empty())
239     return failure();
240 
241   auto *entryBlock = &src->front();
242   if (inlinedOperands.size() != entryBlock->getNumArguments())
243     return failure();
244 
245   // Map the provided call operands to the arguments of the region.
246   BlockAndValueMapping mapper;
247   for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) {
248     // Verify that the types of the provided values match the function argument
249     // types.
250     BlockArgument regionArg = entryBlock->getArgument(i);
251     if (inlinedOperands[i].getType() != regionArg.getType())
252       return failure();
253     mapper.map(regionArg, inlinedOperands[i]);
254   }
255 
256   // Call into the main region inliner function.
257   return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
258                           resultsToReplace, resultsToReplace.getTypes(),
259                           inlineLoc, shouldCloneInlinedRegion, call);
260 }
261 
262 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
263                                  Operation *inlinePoint,
264                                  BlockAndValueMapping &mapper,
265                                  ValueRange resultsToReplace,
266                                  TypeRange regionResultTypes,
267                                  Optional<Location> inlineLoc,
268                                  bool shouldCloneInlinedRegion) {
269   return inlineRegion(interface, src, inlinePoint->getBlock(),
270                       ++inlinePoint->getIterator(), mapper, resultsToReplace,
271                       regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
272 }
273 LogicalResult
274 mlir::inlineRegion(InlinerInterface &interface, Region *src, Block *inlineBlock,
275                    Block::iterator inlinePoint, BlockAndValueMapping &mapper,
276                    ValueRange resultsToReplace, TypeRange regionResultTypes,
277                    Optional<Location> inlineLoc,
278                    bool shouldCloneInlinedRegion) {
279   return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
280                           resultsToReplace, regionResultTypes, inlineLoc,
281                           shouldCloneInlinedRegion);
282 }
283 
284 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
285                                  Operation *inlinePoint,
286                                  ValueRange inlinedOperands,
287                                  ValueRange resultsToReplace,
288                                  Optional<Location> inlineLoc,
289                                  bool shouldCloneInlinedRegion) {
290   return inlineRegion(interface, src, inlinePoint->getBlock(),
291                       ++inlinePoint->getIterator(), inlinedOperands,
292                       resultsToReplace, inlineLoc, shouldCloneInlinedRegion);
293 }
294 LogicalResult
295 mlir::inlineRegion(InlinerInterface &interface, Region *src, Block *inlineBlock,
296                    Block::iterator inlinePoint, ValueRange inlinedOperands,
297                    ValueRange resultsToReplace, Optional<Location> inlineLoc,
298                    bool shouldCloneInlinedRegion) {
299   return inlineRegionImpl(interface, src, inlineBlock, inlinePoint,
300                           inlinedOperands, resultsToReplace, inlineLoc,
301                           shouldCloneInlinedRegion);
302 }
303 
304 /// Utility function used to generate a cast operation from the given interface,
305 /// or return nullptr if a cast could not be generated.
306 static Value materializeConversion(const DialectInlinerInterface *interface,
307                                    SmallVectorImpl<Operation *> &castOps,
308                                    OpBuilder &castBuilder, Value arg, Type type,
309                                    Location conversionLoc) {
310   if (!interface)
311     return nullptr;
312 
313   // Check to see if the interface for the call can materialize a conversion.
314   Operation *castOp = interface->materializeCallConversion(castBuilder, arg,
315                                                            type, conversionLoc);
316   if (!castOp)
317     return nullptr;
318   castOps.push_back(castOp);
319 
320   // Ensure that the generated cast is correct.
321   assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg &&
322          castOp->getNumResults() == 1 && *castOp->result_type_begin() == type);
323   return castOp->getResult(0);
324 }
325 
/// This function inlines a given region, 'src', of a callable operation,
/// 'callable', into the location defined by the given call operation. This
/// function returns failure if inlining is not possible, success otherwise. On
/// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
/// corresponds to whether the source region should be cloned into the 'call' or
/// spliced directly.
///
/// When a call operand's type differs from the matching region argument type,
/// or a call result's type differs from the matching callable result type, the
/// call's dialect interface is asked to materialize a conversion. If any
/// conversion cannot be built, or a legality check fails, every conversion
/// created so far is removed again before returning failure().
LogicalResult mlir::inlineCall(InlinerInterface &interface,
                               CallOpInterface call,
                               CallableOpInterface callable, Region *src,
                               bool shouldCloneInlinedRegion) {
  // We expect the region to have at least one block.
  if (src->empty())
    return failure();
  auto *entryBlock = &src->front();
  ArrayRef<Type> callableResultTypes = callable.getCallableResults();

  // Make sure that the number of arguments and results matchup between the call
  // and the region.
  SmallVector<Value, 8> callOperands(call.getArgOperands());
  SmallVector<Value, 8> callResults(call->getResults());
  if (callOperands.size() != entryBlock->getNumArguments() ||
      callResults.size() != callableResultTypes.size())
    return failure();

  // A set of cast operations generated to matchup the signature of the region
  // with the signature of the call.
  SmallVector<Operation *, 4> castOps;
  castOps.reserve(callOperands.size() + callResults.size());

  // Functor used to cleanup generated state on failure. Every cast has one
  // operand and one result (asserted in materializeConversion). Forwarding the
  // result's uses back to the operand undoes the rewiring before erasing the
  // cast. It returns failure() so call sites can write `return cleanupState();`.
  auto cleanupState = [&] {
    for (auto *op : castOps) {
      op->getResult(0).replaceAllUsesWith(op->getOperand(0));
      op->erase();
    }
    return failure();
  };

  // Builder used for any conversion operations that need to be materialized.
  // Operand conversions are inserted before the call.
  OpBuilder castBuilder(call);
  Location castLoc = call.getLoc();
  const auto *callInterface = interface.getInterfaceFor(call->getDialect());

  // Map the provided call operands to the arguments of the region.
  BlockAndValueMapping mapper;
  for (unsigned i = 0, e = callOperands.size(); i != e; ++i) {
    BlockArgument regionArg = entryBlock->getArgument(i);
    Value operand = callOperands[i];

    // If the call operand doesn't match the expected region argument, try to
    // generate a cast.
    Type regionArgType = regionArg.getType();
    if (operand.getType() != regionArgType) {
      if (!(operand = materializeConversion(callInterface, castOps, castBuilder,
                                            operand, regionArgType, castLoc)))
        return cleanupState();
    }
    mapper.map(regionArg, operand);
  }

  // Ensure that the resultant values of the call match the callable. Result
  // conversions are inserted after the call.
  castBuilder.setInsertionPointAfter(call);
  for (unsigned i = 0, e = callResults.size(); i != e; ++i) {
    Value callResult = callResults[i];
    if (callResult.getType() == callableResultTypes[i])
      continue;

    // Generate a conversion that will produce the original type, so that the IR
    // is still valid after the original call gets replaced.
    Value castResult =
        materializeConversion(callInterface, castOps, castBuilder, callResult,
                              callResult.getType(), castLoc);
    if (!castResult)
      return cleanupState();
    // Route every existing user of the call result through the cast. The RAUW
    // also rewrote the cast's own operand, so point that operand back at
    // 'callResult'. Once inlining replaces 'callResult' with values of the
    // callable's result type, the cast converts them back for the old users.
    callResult.replaceAllUsesWith(castResult);
    castResult.getDefiningOp()->replaceUsesOfWith(castResult, callResult);
  }

  // Check that it is legal to inline the callable into the call.
  if (!interface.isLegalToInline(call, callable, shouldCloneInlinedRegion))
    return cleanupState();

  // Attempt to inline the call. The region is inserted right after the call,
  // the call results are the values to replace, and the callable's result
  // types are used as the region result types. The call op itself is not
  // erased here.
  if (failed(inlineRegionImpl(interface, src, call->getBlock(),
                              ++call->getIterator(), mapper, callResults,
                              callableResultTypes, call.getLoc(),
                              shouldCloneInlinedRegion, call)))
    return cleanupState();
  return success();
}
416