1 //===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // This file implements miscellaneous inlining utilities.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/Transforms/InliningUtils.h"
23 
24 #include "mlir/IR/BlockAndValueMapping.h"
25 #include "mlir/IR/Builders.h"
26 #include "mlir/IR/Function.h"
27 #include "mlir/IR/Operation.h"
28 #include "llvm/ADT/MapVector.h"
29 #include "llvm/Support/raw_ostream.h"
30 
31 #define DEBUG_TYPE "inlining"
32 
33 using namespace mlir;
34 
35 /// Remap locations from the inlined blocks with CallSiteLoc locations with the
36 /// provided caller location.
37 static void
38 remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks,
39                       Location callerLoc) {
40   DenseMap<Location, Location> mappedLocations;
41   auto remapOpLoc = [&](Operation *op) {
42     auto it = mappedLocations.find(op->getLoc());
43     if (it == mappedLocations.end()) {
44       auto newLoc = CallSiteLoc::get(op->getLoc(), callerLoc);
45       it = mappedLocations.try_emplace(op->getLoc(), newLoc).first;
46     }
47     op->setLoc(it->second);
48   };
49   for (auto &block : inlinedBlocks)
50     block.walk(remapOpLoc);
51 }
52 
53 static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks,
54                                  BlockAndValueMapping &mapper) {
55   auto remapOperands = [&](Operation *op) {
56     for (auto &operand : op->getOpOperands())
57       if (auto *mappedOp = mapper.lookupOrNull(operand.get()))
58         operand.set(mappedOp);
59   };
60   for (auto &block : inlinedBlocks)
61     block.walk(remapOperands);
62 }
63 
64 //===----------------------------------------------------------------------===//
65 // InlinerInterface
66 //===----------------------------------------------------------------------===//
67 
68 bool InlinerInterface::isLegalToInline(
69     Region *dest, Region *src, BlockAndValueMapping &valueMapping) const {
70   // Regions can always be inlined into functions.
71   if (isa<FuncOp>(dest->getParentOp()))
72     return true;
73 
74   auto *handler = getInterfaceFor(dest->getParentOp());
75   return handler ? handler->isLegalToInline(dest, src, valueMapping) : false;
76 }
77 
78 bool InlinerInterface::isLegalToInline(
79     Operation *op, Region *dest, BlockAndValueMapping &valueMapping) const {
80   auto *handler = getInterfaceFor(op);
81   return handler ? handler->isLegalToInline(op, dest, valueMapping) : false;
82 }
83 
84 bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const {
85   auto *handler = getInterfaceFor(op);
86   return handler ? handler->shouldAnalyzeRecursively(op) : true;
87 }
88 
89 /// Handle the given inlined terminator by replacing it with a new operation
90 /// as necessary.
91 void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const {
92   auto *handler = getInterfaceFor(op);
93   assert(handler && "expected valid dialect handler");
94   handler->handleTerminator(op, newDest);
95 }
96 
97 /// Handle the given inlined terminator by replacing it with a new operation
98 /// as necessary.
99 void InlinerInterface::handleTerminator(Operation *op,
100                                         ArrayRef<Value *> valuesToRepl) const {
101   auto *handler = getInterfaceFor(op);
102   assert(handler && "expected valid dialect handler");
103   handler->handleTerminator(op, valuesToRepl);
104 }
105 
106 /// Utility to check that all of the operations within 'src' can be inlined.
107 static bool isLegalToInline(InlinerInterface &interface, Region *src,
108                             Region *insertRegion,
109                             BlockAndValueMapping &valueMapping) {
110   for (auto &block : *src) {
111     for (auto &op : block) {
112       // Check this operation.
113       if (!interface.isLegalToInline(&op, insertRegion, valueMapping))
114         return false;
115       // Check any nested regions.
116       if (interface.shouldAnalyzeRecursively(&op) &&
117           llvm::any_of(op.getRegions(), [&](Region &region) {
118             return !isLegalToInline(interface, &region, insertRegion,
119                                     valueMapping);
120           }))
121         return false;
122     }
123   }
124   return true;
125 }
126 
127 //===----------------------------------------------------------------------===//
128 // Inline Methods
129 //===----------------------------------------------------------------------===//
130 
/// Inline the blocks of region 'src' immediately after the operation
/// 'inlinePoint', inside the block and region that contain 'inlinePoint'.
///
/// 'mapper' must already map every argument of the entry block of 'src' to the
/// value that should replace it; otherwise failure is returned.
///
/// If 'inlineLoc' is present and not UnknownLoc, every inlined operation's
/// location is wrapped in a CallSiteLoc with 'inlineLoc' as the caller.
///
/// If 'shouldCloneInlinedRegion' is true, 'src' is cloned and left intact.
/// Otherwise its blocks are moved out of 'src' and their operands remapped
/// through 'mapper'.
///
/// Replacement of 'resultsToReplace' is delegated to the dialect terminator
/// handler when a single block is inlined. When several blocks are inlined,
/// their uses are redirected to new arguments of the post-insertion block.
///
/// All failure paths return before any IR is modified. 'inlinePoint' itself is
/// not erased here.
LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
                                 Operation *inlinePoint,
                                 BlockAndValueMapping &mapper,
                                 ArrayRef<Value *> resultsToReplace,
                                 Optional<Location> inlineLoc,
                                 bool shouldCloneInlinedRegion) {
  // We expect the region to have at least one block.
  if (src->empty())
    return failure();

  // Check that all of the region arguments have been mapped.
  auto *srcEntryBlock = &src->front();
  if (llvm::any_of(srcEntryBlock->getArguments(),
                   [&](BlockArgument *arg) { return !mapper.contains(arg); }))
    return failure();

  // The insertion point must be within a block.
  Block *insertBlock = inlinePoint->getBlock();
  if (!insertBlock)
    return failure();
  Region *insertRegion = insertBlock->getParent();

  // Check that the operations within the source region are valid to inline.
  if (!interface.isLegalToInline(insertRegion, src, mapper) ||
      !isLegalToInline(interface, src, insertRegion, mapper))
    return failure();

  // Split the insertion block. Everything after 'inlinePoint' moves into
  // 'postInsertBlock'; 'inlinePoint' stays as the last op of 'insertBlock'.
  // From here on the IR is being mutated and the function only succeeds.
  Block *postInsertBlock =
      insertBlock->splitBlock(++inlinePoint->getIterator());

  // Check to see if the region is being cloned, or moved inline. In either
  // case, place the new blocks directly after 'insertBlock' (i.e. before
  // 'postInsertBlock') to improve IR readability.
  if (shouldCloneInlinedRegion)
    src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
  else
    insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
                                     src->getBlocks(), src->begin(),
                                     src->end());

  // Get the range of newly inserted blocks: everything strictly between
  // 'insertBlock' and 'postInsertBlock'.
  auto newBlocks = llvm::make_range(std::next(insertBlock->getIterator()),
                                    postInsertBlock->getIterator());
  Block *firstNewBlock = &*newBlocks.begin();

  // Remap the locations of the inlined operations if a valid source location
  // was provided.
  if (inlineLoc && !inlineLoc->isa<UnknownLoc>())
    remapInlinedLocations(newBlocks, *inlineLoc);

  // If the blocks were moved in-place, make sure to remap any necessary
  // operands. (When cloning, 'cloneInto' above already received 'mapper'.)
  if (!shouldCloneInlinedRegion)
    remapInlinedOperands(newBlocks, mapper);

  // Process the newly inlined blocks.
  interface.processInlinedBlocks(newBlocks);

  // Handle the case where only a single block was inlined.
  if (std::next(newBlocks.begin()) == newBlocks.end()) {
    // Have the interface handle the terminator of this block, then drop the
    // terminator so the block can fall through into the rest of the code.
    auto *firstBlockTerminator = firstNewBlock->getTerminator();
    interface.handleTerminator(firstBlockTerminator, resultsToReplace);
    firstBlockTerminator->erase();

    // Merge the post insert block into the cloned entry block.
    firstNewBlock->getOperations().splice(firstNewBlock->end(),
                                          postInsertBlock->getOperations());
    postInsertBlock->erase();
  } else {
    // Otherwise, there were multiple blocks inlined. Add arguments to the post
    // insertion block to represent the results to replace.
    for (Value *resultToRepl : resultsToReplace) {
      resultToRepl->replaceAllUsesWith(
          postInsertBlock->addArgument(resultToRepl->getType()));
    }

    // Handle the terminators for each of the new blocks; the interface is
    // given 'postInsertBlock' as the block control should continue to.
    for (auto &newBlock : newBlocks)
      interface.handleTerminator(newBlock.getTerminator(), postInsertBlock);
  }

  // Splice the instructions of the inlined entry block into the insert block,
  // so they directly follow 'inlinePoint'; the now-empty entry block is erased.
  insertBlock->getOperations().splice(insertBlock->end(),
                                      firstNewBlock->getOperations());
  firstNewBlock->erase();
  return success();
}
220 
221 /// This function is an overload of the above 'inlineRegion' that allows for
222 /// providing the set of operands ('inlinedOperands') that should be used
223 /// in-favor of the region arguments when inlining.
224 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
225                                  Operation *inlinePoint,
226                                  ArrayRef<Value *> inlinedOperands,
227                                  ArrayRef<Value *> resultsToReplace,
228                                  Optional<Location> inlineLoc,
229                                  bool shouldCloneInlinedRegion) {
230   // We expect the region to have at least one block.
231   if (src->empty())
232     return failure();
233 
234   auto *entryBlock = &src->front();
235   if (inlinedOperands.size() != entryBlock->getNumArguments())
236     return failure();
237 
238   // Map the provided call operands to the arguments of the region.
239   BlockAndValueMapping mapper;
240   for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) {
241     // Verify that the types of the provided values match the function argument
242     // types.
243     BlockArgument *regionArg = entryBlock->getArgument(i);
244     if (inlinedOperands[i]->getType() != regionArg->getType())
245       return failure();
246     mapper.map(regionArg, inlinedOperands[i]);
247   }
248 
249   // Call into the main region inliner function.
250   return inlineRegion(interface, src, inlinePoint, mapper, resultsToReplace,
251                       inlineLoc, shouldCloneInlinedRegion);
252 }
253 
254 /// Utility function used to generate a cast operation from the given interface,
255 /// or return nullptr if a cast could not be generated.
256 static Value *materializeConversion(const DialectInlinerInterface *interface,
257                                     SmallVectorImpl<Operation *> &castOps,
258                                     OpBuilder &castBuilder, Value *arg,
259                                     Type type, Location conversionLoc) {
260   if (!interface)
261     return nullptr;
262 
263   // Check to see if the interface for the call can materialize a conversion.
264   Operation *castOp = interface->materializeCallConversion(castBuilder, arg,
265                                                            type, conversionLoc);
266   if (!castOp)
267     return nullptr;
268   castOps.push_back(castOp);
269 
270   // Ensure that the generated cast is correct.
271   assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg &&
272          castOp->getNumResults() == 1 && *castOp->result_type_begin() == type);
273   return castOp->getResult(0);
274 }
275 
/// This function inlines a given region, 'src', of a callable operation,
/// 'callable', into the location defined by the given call operation. This
/// function returns failure if inlining is not possible, success otherwise. On
/// failure, any cast operations created here are removed again and the call's
/// result uses are restored, so the IR is left unchanged.
/// 'shouldCloneInlinedRegion' corresponds to whether the source region should
/// be cloned into the 'call' or spliced directly.
///
/// When a call operand's type differs from the matching region argument type,
/// or a call result's type differs from the matching callable result type, the
/// call's dialect interface is asked to materialize a conversion. If it cannot,
/// inlining fails.
///
/// NOTE(review): the call operation itself is not erased here; presumably the
/// caller removes it after a successful inline — confirm against callers.
LogicalResult mlir::inlineCall(InlinerInterface &interface,
                               CallOpInterface call,
                               CallableOpInterface callable, Region *src,
                               bool shouldCloneInlinedRegion) {
  // We expect the region to have at least one block.
  if (src->empty())
    return failure();
  auto *entryBlock = &src->front();
  ArrayRef<Type> callableResultTypes = callable.getCallableResults(src);

  // Make sure that the number of arguments and results match up between the
  // call and the region.
  SmallVector<Value *, 8> callOperands(call.getArgOperands());
  SmallVector<Value *, 8> callResults(call.getOperation()->getResults());
  if (callOperands.size() != entryBlock->getNumArguments() ||
      callResults.size() != callableResultTypes.size())
    return failure();

  // A set of cast operations generated to match up the signature of the region
  // with the signature of the call.
  SmallVector<Operation *, 4> castOps;
  castOps.reserve(callOperands.size() + callResults.size());

  // Functor used to cleanup generated state on failure: redirect any uses of
  // each cast's result back to its input, then erase the cast.
  auto cleanupState = [&] {
    for (auto *op : castOps) {
      op->getResult(0)->replaceAllUsesWith(op->getOperand(0));
      op->erase();
    }
    return failure();
  };

  // Builder used for any conversion operations that need to be materialized.
  // Operand casts are inserted before the call.
  OpBuilder castBuilder(call);
  Location castLoc = call.getLoc();
  auto *callInterface = interface.getInterfaceFor(call.getDialect());

  // Map the provided call operands to the arguments of the region.
  BlockAndValueMapping mapper;
  for (unsigned i = 0, e = callOperands.size(); i != e; ++i) {
    BlockArgument *regionArg = entryBlock->getArgument(i);
    Value *operand = callOperands[i];

    // If the call operand doesn't match the expected region argument, try to
    // generate a cast.
    Type regionArgType = regionArg->getType();
    if (operand->getType() != regionArgType) {
      if (!(operand = materializeConversion(callInterface, castOps, castBuilder,
                                            operand, regionArgType, castLoc)))
        return cleanupState();
    }
    mapper.map(regionArg, operand);
  }

  // Ensure that the resultant values of the call, match the callable. Result
  // casts are inserted after the call.
  castBuilder.setInsertionPointAfter(call);
  for (unsigned i = 0, e = callResults.size(); i != e; ++i) {
    Value *callResult = callResults[i];
    if (callResult->getType() == callableResultTypes[i])
      continue;

    // Generate a conversion that will produce the original type, so that the IR
    // is still valid after the original call gets replaced. The cast's operand
    // is 'callResult' itself; presumably inlining later replaces it with a
    // value of the callable's result type, turning this into a real conversion.
    Value *castResult =
        materializeConversion(callInterface, castOps, castBuilder, callResult,
                              callResult->getType(), castLoc);
    if (!castResult)
      return cleanupState();
    // Redirect all users of the call result to the cast. This also rewrites
    // the cast's own operand to 'castResult', so restore it to 'callResult'.
    callResult->replaceAllUsesWith(castResult);
    castResult->getDefiningOp()->replaceUsesOfWith(castResult, callResult);
  }

  // Attempt to inline the call.
  if (failed(inlineRegion(interface, src, call, mapper, callResults,
                          call.getLoc(), shouldCloneInlinedRegion)))
    return cleanupState();
  return success();
}
360