1 //===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous inlining utilities.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Transforms/InliningUtils.h"
14 
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/Function.h"
18 #include "mlir/IR/Operation.h"
19 #include "llvm/ADT/MapVector.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/raw_ostream.h"
22 
23 #define DEBUG_TYPE "inlining"
24 
25 using namespace mlir;
26 
27 /// Remap locations from the inlined blocks with CallSiteLoc locations with the
28 /// provided caller location.
29 static void
30 remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks,
31                       Location callerLoc) {
32   DenseMap<Location, Location> mappedLocations;
33   auto remapOpLoc = [&](Operation *op) {
34     auto it = mappedLocations.find(op->getLoc());
35     if (it == mappedLocations.end()) {
36       auto newLoc = CallSiteLoc::get(op->getLoc(), callerLoc);
37       it = mappedLocations.try_emplace(op->getLoc(), newLoc).first;
38     }
39     op->setLoc(it->second);
40   };
41   for (auto &block : inlinedBlocks)
42     block.walk(remapOpLoc);
43 }
44 
45 static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks,
46                                  BlockAndValueMapping &mapper) {
47   auto remapOperands = [&](Operation *op) {
48     for (auto &operand : op->getOpOperands())
49       if (auto mappedOp = mapper.lookupOrNull(operand.get()))
50         operand.set(mappedOp);
51   };
52   for (auto &block : inlinedBlocks)
53     block.walk(remapOperands);
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // InlinerInterface
58 //===----------------------------------------------------------------------===//
59 
60 bool InlinerInterface::isLegalToInline(Operation *call,
61                                        Operation *callable) const {
62   auto *handler = getInterfaceFor(call);
63   return handler ? handler->isLegalToInline(call, callable) : false;
64 }
65 
66 bool InlinerInterface::isLegalToInline(
67     Region *dest, Region *src, BlockAndValueMapping &valueMapping) const {
68   // Regions can always be inlined into functions.
69   if (isa<FuncOp>(dest->getParentOp()))
70     return true;
71 
72   auto *handler = getInterfaceFor(dest->getParentOp());
73   return handler ? handler->isLegalToInline(dest, src, valueMapping) : false;
74 }
75 
76 bool InlinerInterface::isLegalToInline(
77     Operation *op, Region *dest, BlockAndValueMapping &valueMapping) const {
78   auto *handler = getInterfaceFor(op);
79   return handler ? handler->isLegalToInline(op, dest, valueMapping) : false;
80 }
81 
82 bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const {
83   auto *handler = getInterfaceFor(op);
84   return handler ? handler->shouldAnalyzeRecursively(op) : true;
85 }
86 
87 /// Handle the given inlined terminator by replacing it with a new operation
88 /// as necessary.
89 void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const {
90   auto *handler = getInterfaceFor(op);
91   assert(handler && "expected valid dialect handler");
92   handler->handleTerminator(op, newDest);
93 }
94 
95 /// Handle the given inlined terminator by replacing it with a new operation
96 /// as necessary.
97 void InlinerInterface::handleTerminator(Operation *op,
98                                         ArrayRef<Value> valuesToRepl) const {
99   auto *handler = getInterfaceFor(op);
100   assert(handler && "expected valid dialect handler");
101   handler->handleTerminator(op, valuesToRepl);
102 }
103 
104 /// Utility to check that all of the operations within 'src' can be inlined.
105 static bool isLegalToInline(InlinerInterface &interface, Region *src,
106                             Region *insertRegion,
107                             BlockAndValueMapping &valueMapping) {
108   for (auto &block : *src) {
109     for (auto &op : block) {
110       // Check this operation.
111       if (!interface.isLegalToInline(&op, insertRegion, valueMapping)) {
112         LLVM_DEBUG({
113           llvm::dbgs() << "* Illegal to inline because of op: ";
114           op.dump();
115         });
116         return false;
117       }
118       // Check any nested regions.
119       if (interface.shouldAnalyzeRecursively(&op) &&
120           llvm::any_of(op.getRegions(), [&](Region &region) {
121             return !isLegalToInline(interface, &region, insertRegion,
122                                     valueMapping);
123           }))
124         return false;
125     }
126   }
127   return true;
128 }
129 
130 //===----------------------------------------------------------------------===//
131 // Inline Methods
132 //===----------------------------------------------------------------------===//
133 
134 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
135                                  Operation *inlinePoint,
136                                  BlockAndValueMapping &mapper,
137                                  ValueRange resultsToReplace,
138                                  TypeRange regionResultTypes,
139                                  Optional<Location> inlineLoc,
140                                  bool shouldCloneInlinedRegion) {
141   assert(resultsToReplace.size() == regionResultTypes.size());
142   // We expect the region to have at least one block.
143   if (src->empty())
144     return failure();
145 
146   // Check that all of the region arguments have been mapped.
147   auto *srcEntryBlock = &src->front();
148   if (llvm::any_of(srcEntryBlock->getArguments(),
149                    [&](BlockArgument arg) { return !mapper.contains(arg); }))
150     return failure();
151 
152   // The insertion point must be within a block.
153   Block *insertBlock = inlinePoint->getBlock();
154   if (!insertBlock)
155     return failure();
156   Region *insertRegion = insertBlock->getParent();
157 
158   // Check that the operations within the source region are valid to inline.
159   if (!interface.isLegalToInline(insertRegion, src, mapper) ||
160       !isLegalToInline(interface, src, insertRegion, mapper))
161     return failure();
162 
163   // Split the insertion block.
164   Block *postInsertBlock =
165       insertBlock->splitBlock(++inlinePoint->getIterator());
166 
167   // Check to see if the region is being cloned, or moved inline. In either
168   // case, move the new blocks after the 'insertBlock' to improve IR
169   // readability.
170   if (shouldCloneInlinedRegion)
171     src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
172   else
173     insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
174                                      src->getBlocks(), src->begin(),
175                                      src->end());
176 
177   // Get the range of newly inserted blocks.
178   auto newBlocks = llvm::make_range(std::next(insertBlock->getIterator()),
179                                     postInsertBlock->getIterator());
180   Block *firstNewBlock = &*newBlocks.begin();
181 
182   // Remap the locations of the inlined operations if a valid source location
183   // was provided.
184   if (inlineLoc && !inlineLoc->isa<UnknownLoc>())
185     remapInlinedLocations(newBlocks, *inlineLoc);
186 
187   // If the blocks were moved in-place, make sure to remap any necessary
188   // operands.
189   if (!shouldCloneInlinedRegion)
190     remapInlinedOperands(newBlocks, mapper);
191 
192   // Process the newly inlined blocks.
193   interface.processInlinedBlocks(newBlocks);
194 
195   // Handle the case where only a single block was inlined.
196   if (std::next(newBlocks.begin()) == newBlocks.end()) {
197     // Have the interface handle the terminator of this block.
198     auto *firstBlockTerminator = firstNewBlock->getTerminator();
199     interface.handleTerminator(firstBlockTerminator,
200                                llvm::to_vector<6>(resultsToReplace));
201     firstBlockTerminator->erase();
202 
203     // Merge the post insert block into the cloned entry block.
204     firstNewBlock->getOperations().splice(firstNewBlock->end(),
205                                           postInsertBlock->getOperations());
206     postInsertBlock->erase();
207   } else {
208     // Otherwise, there were multiple blocks inlined. Add arguments to the post
209     // insertion block to represent the results to replace.
210     for (auto resultToRepl : llvm::enumerate(resultsToReplace)) {
211       resultToRepl.value().replaceAllUsesWith(postInsertBlock->addArgument(
212           regionResultTypes[resultToRepl.index()]));
213     }
214 
215     /// Handle the terminators for each of the new blocks.
216     for (auto &newBlock : newBlocks)
217       interface.handleTerminator(newBlock.getTerminator(), postInsertBlock);
218   }
219 
220   // Splice the instructions of the inlined entry block into the insert block.
221   insertBlock->getOperations().splice(insertBlock->end(),
222                                       firstNewBlock->getOperations());
223   firstNewBlock->erase();
224   return success();
225 }
226 
227 /// This function is an overload of the above 'inlineRegion' that allows for
228 /// providing the set of operands ('inlinedOperands') that should be used
229 /// in-favor of the region arguments when inlining.
230 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
231                                  Operation *inlinePoint,
232                                  ValueRange inlinedOperands,
233                                  ValueRange resultsToReplace,
234                                  Optional<Location> inlineLoc,
235                                  bool shouldCloneInlinedRegion) {
236   // We expect the region to have at least one block.
237   if (src->empty())
238     return failure();
239 
240   auto *entryBlock = &src->front();
241   if (inlinedOperands.size() != entryBlock->getNumArguments())
242     return failure();
243 
244   // Map the provided call operands to the arguments of the region.
245   BlockAndValueMapping mapper;
246   for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) {
247     // Verify that the types of the provided values match the function argument
248     // types.
249     BlockArgument regionArg = entryBlock->getArgument(i);
250     if (inlinedOperands[i].getType() != regionArg.getType())
251       return failure();
252     mapper.map(regionArg, inlinedOperands[i]);
253   }
254 
255   // Call into the main region inliner function.
256   return inlineRegion(interface, src, inlinePoint, mapper, resultsToReplace,
257                       resultsToReplace.getTypes(), inlineLoc,
258                       shouldCloneInlinedRegion);
259 }
260 
261 /// Utility function used to generate a cast operation from the given interface,
262 /// or return nullptr if a cast could not be generated.
263 static Value materializeConversion(const DialectInlinerInterface *interface,
264                                    SmallVectorImpl<Operation *> &castOps,
265                                    OpBuilder &castBuilder, Value arg, Type type,
266                                    Location conversionLoc) {
267   if (!interface)
268     return nullptr;
269 
270   // Check to see if the interface for the call can materialize a conversion.
271   Operation *castOp = interface->materializeCallConversion(castBuilder, arg,
272                                                            type, conversionLoc);
273   if (!castOp)
274     return nullptr;
275   castOps.push_back(castOp);
276 
277   // Ensure that the generated cast is correct.
278   assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg &&
279          castOp->getNumResults() == 1 && *castOp->result_type_begin() == type);
280   return castOp->getResult(0);
281 }
282 
/// This function inlines a given region, 'src', of a callable operation,
/// 'callable', into the location defined by the given call operation. This
/// function returns failure if inlining is not possible, success otherwise. On
/// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
/// corresponds to whether the source region should be cloned into the 'call' or
/// spliced directly.
///
/// A conversion is requested from the dialect interface registered for the
/// call's dialect in two cases:
///  - a call operand's type differs from the type of the corresponding entry
///    block argument;
///  - a call result's type differs from the callable's result type.
/// Every conversion created here is removed again if inlining fails.
///
/// The call's results are passed to inlineRegion as the values to replace. The
/// call operation itself is not erased by this function; that is left to the
/// caller.
LogicalResult mlir::inlineCall(InlinerInterface &interface,
                               CallOpInterface call,
                               CallableOpInterface callable, Region *src,
                               bool shouldCloneInlinedRegion) {
  // We expect the region to have at least one block.
  if (src->empty())
    return failure();
  auto *entryBlock = &src->front();
  ArrayRef<Type> callableResultTypes = callable.getCallableResults();

  // Make sure that the number of arguments and results match up between the
  // call and the region.
  SmallVector<Value, 8> callOperands(call.getArgOperands());
  SmallVector<Value, 8> callResults(call.getOperation()->getResults());
  if (callOperands.size() != entryBlock->getNumArguments() ||
      callResults.size() != callableResultTypes.size())
    return failure();

  // A set of cast operations generated to match up the signature of the region
  // with the signature of the call.
  SmallVector<Operation *, 4> castOps;
  castOps.reserve(callOperands.size() + callResults.size());

  // Functor used to cleanup generated state on failure. For each generated
  // cast, any uses of its result are forwarded back to its single operand and
  // the cast is erased. It always returns failure so that call sites can write
  // 'return cleanupState();'.
  auto cleanupState = [&] {
    for (auto *op : castOps) {
      op->getResult(0).replaceAllUsesWith(op->getOperand(0));
      op->erase();
    }
    return failure();
  };

  // Builder used for any conversion operations that need to be materialized.
  // It starts positioned at the call, so operand casts are created before it.
  OpBuilder castBuilder(call);
  Location castLoc = call.getLoc();
  // May be null if the call's dialect has no inliner interface. In that case
  // materializeConversion returns nullptr and any needed conversion fails.
  auto *callInterface = interface.getInterfaceFor(call.getDialect());

  // Map the provided call operands to the arguments of the region.
  BlockAndValueMapping mapper;
  for (unsigned i = 0, e = callOperands.size(); i != e; ++i) {
    BlockArgument regionArg = entryBlock->getArgument(i);
    Value operand = callOperands[i];

    // If the call operand doesn't match the expected region argument, try to
    // generate a cast.
    Type regionArgType = regionArg.getType();
    if (operand.getType() != regionArgType) {
      if (!(operand = materializeConversion(callInterface, castOps, castBuilder,
                                            operand, regionArgType, castLoc)))
        return cleanupState();
    }
    mapper.map(regionArg, operand);
  }

  // Ensure that the resultant values of the call match the callable.
  // Result casts are created after the call.
  castBuilder.setInsertionPointAfter(call);
  for (unsigned i = 0, e = callResults.size(); i != e; ++i) {
    Value callResult = callResults[i];
    if (callResult.getType() == callableResultTypes[i])
      continue;

    // Generate a conversion that will produce the original type, so that the IR
    // is still valid after the original call gets replaced.
    //
    // The cast takes 'callResult' as its operand and produces the call's
    // original result type. The steps are:
    //  1. All uses of 'callResult' are redirected to the cast. This also turns
    //     the cast's own operand into a self-use.
    //  2. replaceUsesOfWith points that operand back at 'callResult'.
    // When inlineRegion later replaces 'callResult' with the value produced by
    // the callable, the cast's operand becomes that value, while the existing
    // users keep seeing the original type.
    Value castResult =
        materializeConversion(callInterface, castOps, castBuilder, callResult,
                              callResult.getType(), castLoc);
    if (!castResult)
      return cleanupState();
    callResult.replaceAllUsesWith(castResult);
    castResult.getDefiningOp()->replaceUsesOfWith(castResult, callResult);
  }

  // Check that it is legal to inline the callable into the call.
  // NOTE(review): this check runs after the casts above were materialized, so
  // a rejected call pays for creating and then erasing them. Hoisting it above
  // the cast generation looks possible; confirm that no interface
  // implementation depends on seeing the casts first.
  if (!interface.isLegalToInline(call, callable))
    return cleanupState();

  // Attempt to inline the call. inlineRegion itself leaves the IR untouched
  // when it fails, so only the casts generated here need to be undone.
  if (failed(inlineRegion(interface, src, call, mapper, callResults,
                          callableResultTypes, call.getLoc(),
                          shouldCloneInlinedRegion)))
    return cleanupState();
  return success();
}
372