1 //===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // This file implements miscellaneous inlining utilities.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/Transforms/InliningUtils.h"
23 
24 #include "mlir/IR/BlockAndValueMapping.h"
25 #include "mlir/IR/Function.h"
26 #include "mlir/IR/Operation.h"
27 #include "llvm/ADT/MapVector.h"
28 #include "llvm/Support/raw_ostream.h"
29 
30 #define DEBUG_TYPE "inlining"
31 
32 using namespace mlir;
33 
34 /// Remap locations from the inlined blocks with CallSiteLoc locations with the
35 /// provided caller location.
36 static void
37 remapInlinedLocations(llvm::iterator_range<Region::iterator> inlinedBlocks,
38                       Location callerLoc) {
39   DenseMap<Location, Location> mappedLocations;
40   auto remapOpLoc = [&](Operation *op) {
41     auto it = mappedLocations.find(op->getLoc());
42     if (it == mappedLocations.end()) {
43       auto newLoc = CallSiteLoc::get(op->getLoc(), callerLoc);
44       it = mappedLocations.try_emplace(op->getLoc(), newLoc).first;
45     }
46     op->setLoc(it->second);
47   };
48   for (auto &block : inlinedBlocks)
49     block.walk(remapOpLoc);
50 }
51 
52 static void
53 remapInlinedOperands(llvm::iterator_range<Region::iterator> inlinedBlocks,
54                      BlockAndValueMapping &mapper) {
55   auto remapOperands = [&](Operation *op) {
56     for (auto &operand : op->getOpOperands())
57       if (auto *mappedOp = mapper.lookupOrNull(operand.get()))
58         operand.set(mappedOp);
59   };
60   for (auto &block : inlinedBlocks)
61     block.walk(remapOperands);
62 }
63 
64 //===----------------------------------------------------------------------===//
65 // InlinerInterface
66 //===----------------------------------------------------------------------===//
67 
68 InlinerInterface::~InlinerInterface() {}
69 
70 bool InlinerInterface::isLegalToInline(
71     Region *dest, Region *src, BlockAndValueMapping &valueMapping) const {
72   // Regions can always be inlined into functions.
73   if (isa<FuncOp>(dest->getParentOp()))
74     return true;
75 
76   auto *handler = getInterfaceFor(dest->getParentOp());
77   return handler ? handler->isLegalToInline(src, dest, valueMapping) : false;
78 }
79 
80 bool InlinerInterface::isLegalToInline(
81     Operation *op, Region *dest, BlockAndValueMapping &valueMapping) const {
82   auto *handler = getInterfaceFor(op);
83   return handler ? handler->isLegalToInline(op, dest, valueMapping) : false;
84 }
85 
86 bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const {
87   auto *handler = getInterfaceFor(op);
88   return handler ? handler->shouldAnalyzeRecursively(op) : true;
89 }
90 
91 /// Handle the given inlined terminator by replacing it with a new operation
92 /// as necessary.
93 void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const {
94   auto *handler = getInterfaceFor(op);
95   assert(handler && "expected valid dialect handler");
96   handler->handleTerminator(op, newDest);
97 }
98 
99 /// Handle the given inlined terminator by replacing it with a new operation
100 /// as necessary.
101 void InlinerInterface::handleTerminator(Operation *op,
102                                         ArrayRef<Value *> valuesToRepl) const {
103   auto *handler = getInterfaceFor(op);
104   assert(handler && "expected valid dialect handler");
105   handler->handleTerminator(op, valuesToRepl);
106 }
107 
108 /// Utility to check that all of the operations within 'src' can be inlined.
109 static bool isLegalToInline(InlinerInterface &interface, Region *src,
110                             Region *insertRegion,
111                             BlockAndValueMapping &valueMapping) {
112   for (auto &block : *src) {
113     for (auto &op : block) {
114       // Check this operation.
115       if (!interface.isLegalToInline(&op, insertRegion, valueMapping))
116         return false;
117       // Check any nested regions.
118       if (interface.shouldAnalyzeRecursively(&op) &&
119           llvm::any_of(op.getRegions(), [&](Region &region) {
120             return !isLegalToInline(interface, &region, insertRegion,
121                                     valueMapping);
122           }))
123         return false;
124     }
125   }
126   return true;
127 }
128 
129 //===----------------------------------------------------------------------===//
130 // Inline Methods
131 //===----------------------------------------------------------------------===//
132 
/// Inline region 'src' into the block containing 'inlinePoint', placing the
/// inlined operations immediately after 'inlinePoint'.
///
/// 'mapper' must already contain an entry for every argument of the entry
/// block of 'src'.
///
/// Results: when exactly one block is inlined, 'resultsToReplace' is handed to
/// the interface's terminator handling. When several blocks are inlined, every
/// use of each value in 'resultsToReplace' is replaced by a new argument of the
/// block that follows the inlined code.
///
/// Locations: if 'inlineLoc' is provided and is not an UnknownLoc, each inlined
/// operation's location is wrapped in a CallSiteLoc whose caller is
/// 'inlineLoc'.
///
/// Cloning: if 'shouldCloneInlinedRegion' is true, the blocks of 'src' are
/// cloned. Otherwise they are spliced out of 'src', leaving it empty.
///
/// Returns failure without modifying the IR if 'src' is empty, an entry block
/// argument is unmapped, 'inlinePoint' is not within a block, or 'interface'
/// reports the inlining as illegal.
LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
                                 Operation *inlinePoint,
                                 BlockAndValueMapping &mapper,
                                 ArrayRef<Value *> resultsToReplace,
                                 llvm::Optional<Location> inlineLoc,
                                 bool shouldCloneInlinedRegion) {
  // We expect the region to have at least one block.
  if (src->empty())
    return failure();

  // Check that all of the region arguments have been mapped.
  auto *srcEntryBlock = &src->front();
  if (llvm::any_of(srcEntryBlock->getArguments(),
                   [&](BlockArgument *arg) { return !mapper.contains(arg); }))
    return failure();

  // The insertion point must be within a block.
  Block *insertBlock = inlinePoint->getBlock();
  if (!insertBlock)
    return failure();
  Region *insertRegion = insertBlock->getParent();

  // Check that the operations within the source region are valid to inline.
  // This is the last failure exit; nothing below it can fail, and the IR has
  // not been modified yet.
  if (!interface.isLegalToInline(insertRegion, src, mapper) ||
      !isLegalToInline(interface, src, insertRegion, mapper))
    return failure();

  // Split the insertion block just after 'inlinePoint'. The operations that
  // followed 'inlinePoint' end up in 'postInsertBlock'.
  Block *postInsertBlock =
      insertBlock->splitBlock(++inlinePoint->getIterator());

  // Check to see if the region is being cloned, or moved inline. In either
  // case, move the new blocks after the 'insertBlock' to improve IR
  // readability.
  if (shouldCloneInlinedRegion)
    src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
  else
    insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
                                     src->getBlocks(), src->begin(),
                                     src->end());

  // Get the range of newly inserted blocks: everything strictly between
  // 'insertBlock' and 'postInsertBlock'.
  auto newBlocks = llvm::make_range(std::next(insertBlock->getIterator()),
                                    postInsertBlock->getIterator());
  Block *firstNewBlock = &*newBlocks.begin();

  // Remap the locations of the inlined operations if a valid source location
  // was provided.
  if (inlineLoc && !inlineLoc->isa<UnknownLoc>())
    remapInlinedLocations(newBlocks, *inlineLoc);

  // If the blocks were moved in-place, make sure to remap any necessary
  // operands. The cloning path passed 'mapper' to 'cloneInto' above,
  // presumably remapping operands while copying.
  if (!shouldCloneInlinedRegion)
    remapInlinedOperands(newBlocks, mapper);

  // Process the newly inlined blocks.
  interface.processInlinedBlocks(newBlocks);

  // Handle the case where only a single block was inlined.
  if (std::next(newBlocks.begin()) == newBlocks.end()) {
    // Have the interface handle the terminator of this block, then drop it so
    // the block can be merged with the code that followed 'inlinePoint'.
    auto *firstBlockTerminator = firstNewBlock->getTerminator();
    interface.handleTerminator(firstBlockTerminator, resultsToReplace);
    firstBlockTerminator->erase();

    // Merge the post insert block into the inlined entry block.
    firstNewBlock->getOperations().splice(firstNewBlock->end(),
                                          postInsertBlock->getOperations());
    postInsertBlock->erase();
  } else {
    // Otherwise, there were multiple blocks inlined. Add arguments to the post
    // insertion block to represent the results to replace.
    for (Value *resultToRepl : resultsToReplace) {
      resultToRepl->replaceAllUsesWith(
          postInsertBlock->addArgument(resultToRepl->getType()));
    }

    // Handle the terminators for each of the new blocks, passing
    // 'postInsertBlock' as the new destination.
    for (auto &newBlock : newBlocks)
      interface.handleTerminator(newBlock.getTerminator(), postInsertBlock);
  }

  // Splice the instructions of the inlined entry block into the insert block,
  // then erase the now-empty entry block.
  insertBlock->getOperations().splice(insertBlock->end(),
                                      firstNewBlock->getOperations());
  firstNewBlock->erase();
  return success();
}
222 
223 /// This function is an overload of the above 'inlineRegion' that allows for
224 /// providing the set of operands ('inlinedOperands') that should be used
225 /// in-favor of the region arguments when inlining.
226 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
227                                  Operation *inlinePoint,
228                                  ArrayRef<Value *> inlinedOperands,
229                                  ArrayRef<Value *> resultsToReplace,
230                                  llvm::Optional<Location> inlineLoc,
231                                  bool shouldCloneInlinedRegion) {
232   // We expect the region to have at least one block.
233   if (src->empty())
234     return failure();
235 
236   auto *entryBlock = &src->front();
237   if (inlinedOperands.size() != entryBlock->getNumArguments())
238     return failure();
239 
240   // Map the provided call operands to the arguments of the region.
241   BlockAndValueMapping mapper;
242   for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) {
243     // Verify that the types of the provided values match the function argument
244     // types.
245     BlockArgument *regionArg = entryBlock->getArgument(i);
246     if (inlinedOperands[i]->getType() != regionArg->getType())
247       return failure();
248     mapper.map(regionArg, inlinedOperands[i]);
249   }
250 
251   // Call into the main region inliner function.
252   return inlineRegion(interface, src, inlinePoint, mapper, resultsToReplace,
253                       inlineLoc, shouldCloneInlinedRegion);
254 }
255 
256 /// This function inlines a FuncOp into another. This function returns failure
257 /// if it is not possible to inline this FuncOp. If the function returned
258 /// failure, then no changes to the module have been made.
259 ///
260 /// Note that this only does one level of inlining.  For example, if the
261 /// instruction 'call B' is inlined, and 'B' calls 'C', then the call to 'C' now
262 /// exists in the instruction stream.  Similarly this will inline a recursive
263 /// FuncOp by one level.
264 ///
265 LogicalResult mlir::inlineFunction(InlinerInterface &interface, FuncOp callee,
266                                    Operation *inlinePoint,
267                                    ArrayRef<Value *> callOperands,
268                                    ArrayRef<Value *> callResults,
269                                    Location inlineLoc) {
270   // We don't inline if the provided callee function is a declaration.
271   assert(callee && "expected valid function to inline");
272   if (callee.isExternal())
273     return failure();
274 
275   // Verify that the provided arguments match the function arguments.
276   if (callOperands.size() != callee.getNumArguments())
277     return failure();
278 
279   // Verify that the provided values to replace match the function results.
280   auto funcResultTypes = callee.getType().getResults();
281   if (callResults.size() != funcResultTypes.size())
282     return failure();
283   for (unsigned i = 0, e = callResults.size(); i != e; ++i)
284     if (callResults[i]->getType() != funcResultTypes[i])
285       return failure();
286 
287   // Call into the main region inliner function.
288   return inlineRegion(interface, &callee.getBody(), inlinePoint, callOperands,
289                       callResults, inlineLoc);
290 }
291