1 //===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // This file implements miscellaneous inlining utilities.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/Transforms/InliningUtils.h"
23 
24 #include "mlir/IR/BlockAndValueMapping.h"
25 #include "mlir/IR/Builders.h"
26 #include "mlir/IR/Function.h"
27 #include "mlir/IR/Operation.h"
28 #include "llvm/ADT/MapVector.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/raw_ostream.h"
31 
32 #define DEBUG_TYPE "inlining"
33 
34 using namespace mlir;
35 
36 /// Remap locations from the inlined blocks with CallSiteLoc locations with the
37 /// provided caller location.
38 static void
39 remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks,
40                       Location callerLoc) {
41   DenseMap<Location, Location> mappedLocations;
42   auto remapOpLoc = [&](Operation *op) {
43     auto it = mappedLocations.find(op->getLoc());
44     if (it == mappedLocations.end()) {
45       auto newLoc = CallSiteLoc::get(op->getLoc(), callerLoc);
46       it = mappedLocations.try_emplace(op->getLoc(), newLoc).first;
47     }
48     op->setLoc(it->second);
49   };
50   for (auto &block : inlinedBlocks)
51     block.walk(remapOpLoc);
52 }
53 
54 static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks,
55                                  BlockAndValueMapping &mapper) {
56   auto remapOperands = [&](Operation *op) {
57     for (auto &operand : op->getOpOperands())
58       if (auto mappedOp = mapper.lookupOrNull(operand.get()))
59         operand.set(mappedOp);
60   };
61   for (auto &block : inlinedBlocks)
62     block.walk(remapOperands);
63 }
64 
65 //===----------------------------------------------------------------------===//
66 // InlinerInterface
67 //===----------------------------------------------------------------------===//
68 
69 bool InlinerInterface::isLegalToInline(
70     Region *dest, Region *src, BlockAndValueMapping &valueMapping) const {
71   // Regions can always be inlined into functions.
72   if (isa<FuncOp>(dest->getParentOp()))
73     return true;
74 
75   auto *handler = getInterfaceFor(dest->getParentOp());
76   return handler ? handler->isLegalToInline(dest, src, valueMapping) : false;
77 }
78 
79 bool InlinerInterface::isLegalToInline(
80     Operation *op, Region *dest, BlockAndValueMapping &valueMapping) const {
81   auto *handler = getInterfaceFor(op);
82   return handler ? handler->isLegalToInline(op, dest, valueMapping) : false;
83 }
84 
85 bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const {
86   auto *handler = getInterfaceFor(op);
87   return handler ? handler->shouldAnalyzeRecursively(op) : true;
88 }
89 
90 /// Handle the given inlined terminator by replacing it with a new operation
91 /// as necessary.
92 void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const {
93   auto *handler = getInterfaceFor(op);
94   assert(handler && "expected valid dialect handler");
95   handler->handleTerminator(op, newDest);
96 }
97 
98 /// Handle the given inlined terminator by replacing it with a new operation
99 /// as necessary.
100 void InlinerInterface::handleTerminator(Operation *op,
101                                         ArrayRef<ValuePtr> valuesToRepl) const {
102   auto *handler = getInterfaceFor(op);
103   assert(handler && "expected valid dialect handler");
104   handler->handleTerminator(op, valuesToRepl);
105 }
106 
107 /// Utility to check that all of the operations within 'src' can be inlined.
108 static bool isLegalToInline(InlinerInterface &interface, Region *src,
109                             Region *insertRegion,
110                             BlockAndValueMapping &valueMapping) {
111   for (auto &block : *src) {
112     for (auto &op : block) {
113       // Check this operation.
114       if (!interface.isLegalToInline(&op, insertRegion, valueMapping)) {
115         LLVM_DEBUG({
116           llvm::dbgs() << "* Illegal to inline because of op: ";
117           op.dump();
118         });
119         return false;
120       }
121       // Check any nested regions.
122       if (interface.shouldAnalyzeRecursively(&op) &&
123           llvm::any_of(op.getRegions(), [&](Region &region) {
124             return !isLegalToInline(interface, &region, insertRegion,
125                                     valueMapping);
126           }))
127         return false;
128     }
129   }
130   return true;
131 }
132 
133 //===----------------------------------------------------------------------===//
134 // Inline Methods
135 //===----------------------------------------------------------------------===//
136 
/// Inline the blocks of region 'src' into the block containing 'inlinePoint',
/// immediately after 'inlinePoint'.
///
/// 'mapper' must already map every argument of the entry block of 'src'. The
/// mapped values replace their uses inside the inlined operations. When
/// cloning, 'mapper' is passed to Region::cloneInto; when moving the blocks,
/// it is applied by remapInlinedOperands.
///
/// 'resultsToReplace' are the values that the inlined region's results should
/// replace. If a single block is inlined, they are passed to the interface's
/// handleTerminator(Operation *, ArrayRef<ValuePtr>) hook. If several blocks
/// are inlined, their uses are redirected to new arguments of the block that
/// follows the inlined code.
///
/// If 'inlineLoc' is present and is not an UnknownLoc, each inlined
/// operation's location is wrapped in a CallSiteLoc with 'inlineLoc' as the
/// caller.
///
/// If 'shouldCloneInlinedRegion' is true, the blocks of 'src' are cloned.
/// Otherwise they are spliced out of 'src'.
///
/// Returns failure in any of these cases: 'src' is empty, an entry argument is
/// unmapped, 'inlinePoint' is not in a block, or the interface rejects the
/// region or one of its operations. All of these checks happen before this
/// function modifies any IR.
LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
                                 Operation *inlinePoint,
                                 BlockAndValueMapping &mapper,
                                 ArrayRef<ValuePtr> resultsToReplace,
                                 Optional<Location> inlineLoc,
                                 bool shouldCloneInlinedRegion) {
  // We expect the region to have at least one block.
  if (src->empty())
    return failure();

  // Check that all of the region arguments have been mapped.
  auto *srcEntryBlock = &src->front();
  if (llvm::any_of(srcEntryBlock->getArguments(),
                   [&](BlockArgumentPtr arg) { return !mapper.contains(arg); }))
    return failure();

  // The insertion point must be within a block.
  Block *insertBlock = inlinePoint->getBlock();
  if (!insertBlock)
    return failure();
  Region *insertRegion = insertBlock->getParent();

  // Check that the operations within the source region are valid to inline.
  // The region-level check consults the op owning 'insertRegion'. The
  // per-operation check walks 'src', recursing into nested regions where the
  // interface requests it.
  if (!interface.isLegalToInline(insertRegion, src, mapper) ||
      !isLegalToInline(interface, src, insertRegion, mapper))
    return failure();

  // Split the insertion block: every operation after 'inlinePoint' moves into
  // 'postInsertBlock'. The inlined blocks are placed between the two.
  Block *postInsertBlock =
      insertBlock->splitBlock(++inlinePoint->getIterator());

  // Check to see if the region is being cloned, or moved inline. In either
  // case, move the new blocks after the 'insertBlock' to improve IR
  // readability.
  if (shouldCloneInlinedRegion)
    src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
  else
    insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
                                     src->getBlocks(), src->begin(),
                                     src->end());

  // Get the range of newly inserted blocks, i.e. every block strictly between
  // 'insertBlock' and 'postInsertBlock'.
  auto newBlocks = llvm::make_range(std::next(insertBlock->getIterator()),
                                    postInsertBlock->getIterator());
  Block *firstNewBlock = &*newBlocks.begin();

  // Remap the locations of the inlined operations if a valid source location
  // was provided.
  if (inlineLoc && !inlineLoc->isa<UnknownLoc>())
    remapInlinedLocations(newBlocks, *inlineLoc);

  // If the blocks were moved in-place, make sure to remap any necessary
  // operands. (When cloning, 'mapper' was already passed to cloneInto above.)
  if (!shouldCloneInlinedRegion)
    remapInlinedOperands(newBlocks, mapper);

  // Process the newly inlined blocks.
  interface.processInlinedBlocks(newBlocks);

  // Handle the case where only a single block was inlined.
  if (std::next(newBlocks.begin()) == newBlocks.end()) {
    // Have the interface handle the terminator of this block, then erase it.
    // The block's remaining operations flow directly into the code that
    // followed 'inlinePoint'.
    auto *firstBlockTerminator = firstNewBlock->getTerminator();
    interface.handleTerminator(firstBlockTerminator, resultsToReplace);
    firstBlockTerminator->erase();

    // Merge the post insert block into the inlined entry block.
    firstNewBlock->getOperations().splice(firstNewBlock->end(),
                                          postInsertBlock->getOperations());
    postInsertBlock->erase();
  } else {
    // Otherwise, there were multiple blocks inlined. Add arguments to the post
    // insertion block to represent the results to replace.
    // NOTE(review): each new argument takes the type of the value it replaces.
    // The inlined terminators may forward values of a different type. For
    // example, inlineCall inserts result conversions when the callable's
    // result types differ from the call's. In that case the forwarded values
    // would not match these arguments. Confirm; a fix likely needs the
    // region's result types passed in.
    for (ValuePtr resultToRepl : resultsToReplace) {
      resultToRepl->replaceAllUsesWith(
          postInsertBlock->addArgument(resultToRepl->getType()));
    }

    // Handle the terminators for each of the new blocks. These terminators are
    // not erased here; any rewriting is left to the interface hook, which
    // presumably branches to 'postInsertBlock'.
    for (auto &newBlock : newBlocks)
      interface.handleTerminator(newBlock.getTerminator(), postInsertBlock);
  }

  // Splice the instructions of the inlined entry block into the insert block,
  // then drop the now-empty entry block.
  insertBlock->getOperations().splice(insertBlock->end(),
                                      firstNewBlock->getOperations());
  firstNewBlock->erase();
  return success();
}
226 
227 /// This function is an overload of the above 'inlineRegion' that allows for
228 /// providing the set of operands ('inlinedOperands') that should be used
229 /// in-favor of the region arguments when inlining.
230 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
231                                  Operation *inlinePoint,
232                                  ArrayRef<ValuePtr> inlinedOperands,
233                                  ArrayRef<ValuePtr> resultsToReplace,
234                                  Optional<Location> inlineLoc,
235                                  bool shouldCloneInlinedRegion) {
236   // We expect the region to have at least one block.
237   if (src->empty())
238     return failure();
239 
240   auto *entryBlock = &src->front();
241   if (inlinedOperands.size() != entryBlock->getNumArguments())
242     return failure();
243 
244   // Map the provided call operands to the arguments of the region.
245   BlockAndValueMapping mapper;
246   for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) {
247     // Verify that the types of the provided values match the function argument
248     // types.
249     BlockArgumentPtr regionArg = entryBlock->getArgument(i);
250     if (inlinedOperands[i]->getType() != regionArg->getType())
251       return failure();
252     mapper.map(regionArg, inlinedOperands[i]);
253   }
254 
255   // Call into the main region inliner function.
256   return inlineRegion(interface, src, inlinePoint, mapper, resultsToReplace,
257                       inlineLoc, shouldCloneInlinedRegion);
258 }
259 
260 /// Utility function used to generate a cast operation from the given interface,
261 /// or return nullptr if a cast could not be generated.
262 static ValuePtr materializeConversion(const DialectInlinerInterface *interface,
263                                       SmallVectorImpl<Operation *> &castOps,
264                                       OpBuilder &castBuilder, ValuePtr arg,
265                                       Type type, Location conversionLoc) {
266   if (!interface)
267     return nullptr;
268 
269   // Check to see if the interface for the call can materialize a conversion.
270   Operation *castOp = interface->materializeCallConversion(castBuilder, arg,
271                                                            type, conversionLoc);
272   if (!castOp)
273     return nullptr;
274   castOps.push_back(castOp);
275 
276   // Ensure that the generated cast is correct.
277   assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg &&
278          castOp->getNumResults() == 1 && *castOp->result_type_begin() == type);
279   return castOp->getResult(0);
280 }
281 
/// This function inlines a given region, 'src', of a callable operation,
/// 'callable', into the location defined by the given call operation. This
/// function returns failure if inlining is not possible, success otherwise. On
/// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
/// corresponds to whether the source region should be cloned into the 'call' or
/// spliced directly.
///
/// Conversions are requested from the dialect interface of 'call' in two
/// cases: a call operand's type differs from the matching entry-block argument
/// type, or a call result's type differs from the callable's result type. If
/// any conversion cannot be built, inlining fails and the conversions created
/// so far are removed. On success, 'call' itself is not erased by this
/// function. Inlined operation locations are wrapped with the call's location
/// as the caller, unless it is an UnknownLoc.
LogicalResult mlir::inlineCall(InlinerInterface &interface,
                               CallOpInterface call,
                               CallableOpInterface callable, Region *src,
                               bool shouldCloneInlinedRegion) {
  // We expect the region to have at least one block.
  if (src->empty())
    return failure();
  auto *entryBlock = &src->front();
  ArrayRef<Type> callableResultTypes = callable.getCallableResults(src);

  // Make sure that the number of arguments and results match up between the
  // call and the region.
  SmallVector<ValuePtr, 8> callOperands(call.getArgOperands());
  SmallVector<ValuePtr, 8> callResults(call.getOperation()->getResults());
  if (callOperands.size() != entryBlock->getNumArguments() ||
      callResults.size() != callableResultTypes.size())
    return failure();

  // The cast operations generated to match up the signature of the region
  // with the signature of the call. Every cast created below is recorded here
  // so that it can be undone on failure.
  SmallVector<Operation *, 4> castOps;
  castOps.reserve(callOperands.size() + callResults.size());

  // Functor used to clean up generated state on failure. For each generated
  // cast, route the uses of its result back to its operand and erase the
  // cast; then return failure.
  auto cleanupState = [&] {
    for (auto *op : castOps) {
      op->getResult(0)->replaceAllUsesWith(op->getOperand(0));
      op->erase();
    }
    return failure();
  };

  // Builder used for any conversion operations that need to be materialized.
  // Operand conversions use its initial insertion point. It is moved after
  // the call before result conversions are built.
  OpBuilder castBuilder(call);
  Location castLoc = call.getLoc();
  // Null if the call's dialect has no inliner interface. materializeConversion
  // then builds no casts, so any type mismatch below makes inlining fail.
  auto *callInterface = interface.getInterfaceFor(call.getDialect());

  // Map the provided call operands to the arguments of the region.
  BlockAndValueMapping mapper;
  for (unsigned i = 0, e = callOperands.size(); i != e; ++i) {
    BlockArgumentPtr regionArg = entryBlock->getArgument(i);
    ValuePtr operand = callOperands[i];

    // If the call operand doesn't match the expected region argument, try to
    // generate a cast. The cast's result is then mapped in place of the
    // original operand.
    Type regionArgType = regionArg->getType();
    if (operand->getType() != regionArgType) {
      if (!(operand = materializeConversion(callInterface, castOps, castBuilder,
                                            operand, regionArgType, castLoc)))
        return cleanupState();
    }
    mapper.map(regionArg, operand);
  }

  // Ensure that the resultant values of the call match the callable.
  castBuilder.setInsertionPointAfter(call);
  for (unsigned i = 0, e = callResults.size(); i != e; ++i) {
    ValuePtr callResult = callResults[i];
    if (callResult->getType() == callableResultTypes[i])
      continue;

    // Generate a conversion that will produce the original type, so that the IR
    // is still valid after the original call gets replaced.
    //
    // The cast is created with 'callResult' itself as its operand. All
    // existing users of 'callResult' are then redirected to the cast's result.
    // That RAUW also rewrites the cast's own operand, so the operand is pointed
    // back at 'callResult'. Afterwards the cast is the only user of
    // 'callResult'. When inlineRegion later replaces 'callResult', the cast
    // consumes the replacement and still provides the call's result type to
    // the original users.
    // NOTE(review): the interface is asked to convert to
    // callResult->getType() a value that already has that type. It is never
    // told the callable's result type; confirm this is what dialect hooks
    // expect.
    ValuePtr castResult =
        materializeConversion(callInterface, castOps, castBuilder, callResult,
                              callResult->getType(), castLoc);
    if (!castResult)
      return cleanupState();
    callResult->replaceAllUsesWith(castResult);
    castResult->getDefiningOp()->replaceUsesOfWith(castResult, callResult);
  }

  // Attempt to inline the call. The call's results are the values to replace,
  // and the call's location is the caller location for the inlined ops.
  if (failed(inlineRegion(interface, src, call, mapper, callResults,
                          call.getLoc(), shouldCloneInlinedRegion)))
    return cleanupState();
  return success();
}
366