1 //===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous inlining utilities.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Transforms/InliningUtils.h"
14
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/Operation.h"
18 #include "mlir/Interfaces/CallInterfaces.h"
19 #include "llvm/ADT/MapVector.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/raw_ostream.h"
22
23 #define DEBUG_TYPE "inlining"
24
25 using namespace mlir;
26
27 /// Remap locations from the inlined blocks with CallSiteLoc locations with the
28 /// provided caller location.
29 static void
remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks,Location callerLoc)30 remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks,
31 Location callerLoc) {
32 DenseMap<Location, Location> mappedLocations;
33 auto remapOpLoc = [&](Operation *op) {
34 auto it = mappedLocations.find(op->getLoc());
35 if (it == mappedLocations.end()) {
36 auto newLoc = CallSiteLoc::get(op->getLoc(), callerLoc);
37 it = mappedLocations.try_emplace(op->getLoc(), newLoc).first;
38 }
39 op->setLoc(it->second);
40 };
41 for (auto &block : inlinedBlocks)
42 block.walk(remapOpLoc);
43 }
44
remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks,BlockAndValueMapping & mapper)45 static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks,
46 BlockAndValueMapping &mapper) {
47 auto remapOperands = [&](Operation *op) {
48 for (auto &operand : op->getOpOperands())
49 if (auto mappedOp = mapper.lookupOrNull(operand.get()))
50 operand.set(mappedOp);
51 };
52 for (auto &block : inlinedBlocks)
53 block.walk(remapOperands);
54 }
55
56 //===----------------------------------------------------------------------===//
57 // InlinerInterface
58 //===----------------------------------------------------------------------===//
59
isLegalToInline(Operation * call,Operation * callable,bool wouldBeCloned) const60 bool InlinerInterface::isLegalToInline(Operation *call, Operation *callable,
61 bool wouldBeCloned) const {
62 if (auto *handler = getInterfaceFor(call))
63 return handler->isLegalToInline(call, callable, wouldBeCloned);
64 return false;
65 }
66
isLegalToInline(Region * dest,Region * src,bool wouldBeCloned,BlockAndValueMapping & valueMapping) const67 bool InlinerInterface::isLegalToInline(
68 Region *dest, Region *src, bool wouldBeCloned,
69 BlockAndValueMapping &valueMapping) const {
70 if (auto *handler = getInterfaceFor(dest->getParentOp()))
71 return handler->isLegalToInline(dest, src, wouldBeCloned, valueMapping);
72 return false;
73 }
74
isLegalToInline(Operation * op,Region * dest,bool wouldBeCloned,BlockAndValueMapping & valueMapping) const75 bool InlinerInterface::isLegalToInline(
76 Operation *op, Region *dest, bool wouldBeCloned,
77 BlockAndValueMapping &valueMapping) const {
78 if (auto *handler = getInterfaceFor(op))
79 return handler->isLegalToInline(op, dest, wouldBeCloned, valueMapping);
80 return false;
81 }
82
shouldAnalyzeRecursively(Operation * op) const83 bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const {
84 auto *handler = getInterfaceFor(op);
85 return handler ? handler->shouldAnalyzeRecursively(op) : true;
86 }
87
88 /// Handle the given inlined terminator by replacing it with a new operation
89 /// as necessary.
handleTerminator(Operation * op,Block * newDest) const90 void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const {
91 auto *handler = getInterfaceFor(op);
92 assert(handler && "expected valid dialect handler");
93 handler->handleTerminator(op, newDest);
94 }
95
96 /// Handle the given inlined terminator by replacing it with a new operation
97 /// as necessary.
handleTerminator(Operation * op,ArrayRef<Value> valuesToRepl) const98 void InlinerInterface::handleTerminator(Operation *op,
99 ArrayRef<Value> valuesToRepl) const {
100 auto *handler = getInterfaceFor(op);
101 assert(handler && "expected valid dialect handler");
102 handler->handleTerminator(op, valuesToRepl);
103 }
104
processInlinedCallBlocks(Operation * call,iterator_range<Region::iterator> inlinedBlocks) const105 void InlinerInterface::processInlinedCallBlocks(
106 Operation *call, iterator_range<Region::iterator> inlinedBlocks) const {
107 auto *handler = getInterfaceFor(call);
108 assert(handler && "expected valid dialect handler");
109 handler->processInlinedCallBlocks(call, inlinedBlocks);
110 }
111
112 /// Utility to check that all of the operations within 'src' can be inlined.
isLegalToInline(InlinerInterface & interface,Region * src,Region * insertRegion,bool shouldCloneInlinedRegion,BlockAndValueMapping & valueMapping)113 static bool isLegalToInline(InlinerInterface &interface, Region *src,
114 Region *insertRegion, bool shouldCloneInlinedRegion,
115 BlockAndValueMapping &valueMapping) {
116 for (auto &block : *src) {
117 for (auto &op : block) {
118 // Check this operation.
119 if (!interface.isLegalToInline(&op, insertRegion,
120 shouldCloneInlinedRegion, valueMapping)) {
121 LLVM_DEBUG({
122 llvm::dbgs() << "* Illegal to inline because of op: ";
123 op.dump();
124 });
125 return false;
126 }
127 // Check any nested regions.
128 if (interface.shouldAnalyzeRecursively(&op) &&
129 llvm::any_of(op.getRegions(), [&](Region ®ion) {
130 return !isLegalToInline(interface, ®ion, insertRegion,
131 shouldCloneInlinedRegion, valueMapping);
132 }))
133 return false;
134 }
135 }
136 return true;
137 }
138
139 //===----------------------------------------------------------------------===//
140 // Inline Methods
141 //===----------------------------------------------------------------------===//
142
143 static LogicalResult
inlineRegionImpl(InlinerInterface & interface,Region * src,Block * inlineBlock,Block::iterator inlinePoint,BlockAndValueMapping & mapper,ValueRange resultsToReplace,TypeRange regionResultTypes,Optional<Location> inlineLoc,bool shouldCloneInlinedRegion,Operation * call=nullptr)144 inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
145 Block::iterator inlinePoint, BlockAndValueMapping &mapper,
146 ValueRange resultsToReplace, TypeRange regionResultTypes,
147 Optional<Location> inlineLoc, bool shouldCloneInlinedRegion,
148 Operation *call = nullptr) {
149 assert(resultsToReplace.size() == regionResultTypes.size());
150 // We expect the region to have at least one block.
151 if (src->empty())
152 return failure();
153
154 // Check that all of the region arguments have been mapped.
155 auto *srcEntryBlock = &src->front();
156 if (llvm::any_of(srcEntryBlock->getArguments(),
157 [&](BlockArgument arg) { return !mapper.contains(arg); }))
158 return failure();
159
160 // Check that the operations within the source region are valid to inline.
161 Region *insertRegion = inlineBlock->getParent();
162 if (!interface.isLegalToInline(insertRegion, src, shouldCloneInlinedRegion,
163 mapper) ||
164 !isLegalToInline(interface, src, insertRegion, shouldCloneInlinedRegion,
165 mapper))
166 return failure();
167
168 // Check to see if the region is being cloned, or moved inline. In either
169 // case, move the new blocks after the 'insertBlock' to improve IR
170 // readability.
171 Block *postInsertBlock = inlineBlock->splitBlock(inlinePoint);
172 if (shouldCloneInlinedRegion)
173 src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
174 else
175 insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
176 src->getBlocks(), src->begin(),
177 src->end());
178
179 // Get the range of newly inserted blocks.
180 auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()),
181 postInsertBlock->getIterator());
182 Block *firstNewBlock = &*newBlocks.begin();
183
184 // Remap the locations of the inlined operations if a valid source location
185 // was provided.
186 if (inlineLoc && !inlineLoc->isa<UnknownLoc>())
187 remapInlinedLocations(newBlocks, *inlineLoc);
188
189 // If the blocks were moved in-place, make sure to remap any necessary
190 // operands.
191 if (!shouldCloneInlinedRegion)
192 remapInlinedOperands(newBlocks, mapper);
193
194 // Process the newly inlined blocks.
195 if (call)
196 interface.processInlinedCallBlocks(call, newBlocks);
197 interface.processInlinedBlocks(newBlocks);
198
199 // Handle the case where only a single block was inlined.
200 if (std::next(newBlocks.begin()) == newBlocks.end()) {
201 // Have the interface handle the terminator of this block.
202 auto *firstBlockTerminator = firstNewBlock->getTerminator();
203 interface.handleTerminator(firstBlockTerminator,
204 llvm::to_vector<6>(resultsToReplace));
205 firstBlockTerminator->erase();
206
207 // Merge the post insert block into the cloned entry block.
208 firstNewBlock->getOperations().splice(firstNewBlock->end(),
209 postInsertBlock->getOperations());
210 postInsertBlock->erase();
211 } else {
212 // Otherwise, there were multiple blocks inlined. Add arguments to the post
213 // insertion block to represent the results to replace.
214 for (const auto &resultToRepl : llvm::enumerate(resultsToReplace)) {
215 resultToRepl.value().replaceAllUsesWith(
216 postInsertBlock->addArgument(regionResultTypes[resultToRepl.index()],
217 resultToRepl.value().getLoc()));
218 }
219
220 /// Handle the terminators for each of the new blocks.
221 for (auto &newBlock : newBlocks)
222 interface.handleTerminator(newBlock.getTerminator(), postInsertBlock);
223 }
224
225 // Splice the instructions of the inlined entry block into the insert block.
226 inlineBlock->getOperations().splice(inlineBlock->end(),
227 firstNewBlock->getOperations());
228 firstNewBlock->erase();
229 return success();
230 }
231
232 static LogicalResult
inlineRegionImpl(InlinerInterface & interface,Region * src,Block * inlineBlock,Block::iterator inlinePoint,ValueRange inlinedOperands,ValueRange resultsToReplace,Optional<Location> inlineLoc,bool shouldCloneInlinedRegion,Operation * call=nullptr)233 inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
234 Block::iterator inlinePoint, ValueRange inlinedOperands,
235 ValueRange resultsToReplace, Optional<Location> inlineLoc,
236 bool shouldCloneInlinedRegion, Operation *call = nullptr) {
237 // We expect the region to have at least one block.
238 if (src->empty())
239 return failure();
240
241 auto *entryBlock = &src->front();
242 if (inlinedOperands.size() != entryBlock->getNumArguments())
243 return failure();
244
245 // Map the provided call operands to the arguments of the region.
246 BlockAndValueMapping mapper;
247 for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) {
248 // Verify that the types of the provided values match the function argument
249 // types.
250 BlockArgument regionArg = entryBlock->getArgument(i);
251 if (inlinedOperands[i].getType() != regionArg.getType())
252 return failure();
253 mapper.map(regionArg, inlinedOperands[i]);
254 }
255
256 // Call into the main region inliner function.
257 return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
258 resultsToReplace, resultsToReplace.getTypes(),
259 inlineLoc, shouldCloneInlinedRegion, call);
260 }
261
inlineRegion(InlinerInterface & interface,Region * src,Operation * inlinePoint,BlockAndValueMapping & mapper,ValueRange resultsToReplace,TypeRange regionResultTypes,Optional<Location> inlineLoc,bool shouldCloneInlinedRegion)262 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
263 Operation *inlinePoint,
264 BlockAndValueMapping &mapper,
265 ValueRange resultsToReplace,
266 TypeRange regionResultTypes,
267 Optional<Location> inlineLoc,
268 bool shouldCloneInlinedRegion) {
269 return inlineRegion(interface, src, inlinePoint->getBlock(),
270 ++inlinePoint->getIterator(), mapper, resultsToReplace,
271 regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
272 }
273 LogicalResult
inlineRegion(InlinerInterface & interface,Region * src,Block * inlineBlock,Block::iterator inlinePoint,BlockAndValueMapping & mapper,ValueRange resultsToReplace,TypeRange regionResultTypes,Optional<Location> inlineLoc,bool shouldCloneInlinedRegion)274 mlir::inlineRegion(InlinerInterface &interface, Region *src, Block *inlineBlock,
275 Block::iterator inlinePoint, BlockAndValueMapping &mapper,
276 ValueRange resultsToReplace, TypeRange regionResultTypes,
277 Optional<Location> inlineLoc,
278 bool shouldCloneInlinedRegion) {
279 return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
280 resultsToReplace, regionResultTypes, inlineLoc,
281 shouldCloneInlinedRegion);
282 }
283
inlineRegion(InlinerInterface & interface,Region * src,Operation * inlinePoint,ValueRange inlinedOperands,ValueRange resultsToReplace,Optional<Location> inlineLoc,bool shouldCloneInlinedRegion)284 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
285 Operation *inlinePoint,
286 ValueRange inlinedOperands,
287 ValueRange resultsToReplace,
288 Optional<Location> inlineLoc,
289 bool shouldCloneInlinedRegion) {
290 return inlineRegion(interface, src, inlinePoint->getBlock(),
291 ++inlinePoint->getIterator(), inlinedOperands,
292 resultsToReplace, inlineLoc, shouldCloneInlinedRegion);
293 }
294 LogicalResult
inlineRegion(InlinerInterface & interface,Region * src,Block * inlineBlock,Block::iterator inlinePoint,ValueRange inlinedOperands,ValueRange resultsToReplace,Optional<Location> inlineLoc,bool shouldCloneInlinedRegion)295 mlir::inlineRegion(InlinerInterface &interface, Region *src, Block *inlineBlock,
296 Block::iterator inlinePoint, ValueRange inlinedOperands,
297 ValueRange resultsToReplace, Optional<Location> inlineLoc,
298 bool shouldCloneInlinedRegion) {
299 return inlineRegionImpl(interface, src, inlineBlock, inlinePoint,
300 inlinedOperands, resultsToReplace, inlineLoc,
301 shouldCloneInlinedRegion);
302 }
303
304 /// Utility function used to generate a cast operation from the given interface,
305 /// or return nullptr if a cast could not be generated.
materializeConversion(const DialectInlinerInterface * interface,SmallVectorImpl<Operation * > & castOps,OpBuilder & castBuilder,Value arg,Type type,Location conversionLoc)306 static Value materializeConversion(const DialectInlinerInterface *interface,
307 SmallVectorImpl<Operation *> &castOps,
308 OpBuilder &castBuilder, Value arg, Type type,
309 Location conversionLoc) {
310 if (!interface)
311 return nullptr;
312
313 // Check to see if the interface for the call can materialize a conversion.
314 Operation *castOp = interface->materializeCallConversion(castBuilder, arg,
315 type, conversionLoc);
316 if (!castOp)
317 return nullptr;
318 castOps.push_back(castOp);
319
320 // Ensure that the generated cast is correct.
321 assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg &&
322 castOp->getNumResults() == 1 && *castOp->result_type_begin() == type);
323 return castOp->getResult(0);
324 }
325
326 /// This function inlines a given region, 'src', of a callable operation,
327 /// 'callable', into the location defined by the given call operation. This
328 /// function returns failure if inlining is not possible, success otherwise. On
329 /// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
330 /// corresponds to whether the source region should be cloned into the 'call' or
331 /// spliced directly.
inlineCall(InlinerInterface & interface,CallOpInterface call,CallableOpInterface callable,Region * src,bool shouldCloneInlinedRegion)332 LogicalResult mlir::inlineCall(InlinerInterface &interface,
333 CallOpInterface call,
334 CallableOpInterface callable, Region *src,
335 bool shouldCloneInlinedRegion) {
336 // We expect the region to have at least one block.
337 if (src->empty())
338 return failure();
339 auto *entryBlock = &src->front();
340 ArrayRef<Type> callableResultTypes = callable.getCallableResults();
341
342 // Make sure that the number of arguments and results matchup between the call
343 // and the region.
344 SmallVector<Value, 8> callOperands(call.getArgOperands());
345 SmallVector<Value, 8> callResults(call->getResults());
346 if (callOperands.size() != entryBlock->getNumArguments() ||
347 callResults.size() != callableResultTypes.size())
348 return failure();
349
350 // A set of cast operations generated to matchup the signature of the region
351 // with the signature of the call.
352 SmallVector<Operation *, 4> castOps;
353 castOps.reserve(callOperands.size() + callResults.size());
354
355 // Functor used to cleanup generated state on failure.
356 auto cleanupState = [&] {
357 for (auto *op : castOps) {
358 op->getResult(0).replaceAllUsesWith(op->getOperand(0));
359 op->erase();
360 }
361 return failure();
362 };
363
364 // Builder used for any conversion operations that need to be materialized.
365 OpBuilder castBuilder(call);
366 Location castLoc = call.getLoc();
367 const auto *callInterface = interface.getInterfaceFor(call->getDialect());
368
369 // Map the provided call operands to the arguments of the region.
370 BlockAndValueMapping mapper;
371 for (unsigned i = 0, e = callOperands.size(); i != e; ++i) {
372 BlockArgument regionArg = entryBlock->getArgument(i);
373 Value operand = callOperands[i];
374
375 // If the call operand doesn't match the expected region argument, try to
376 // generate a cast.
377 Type regionArgType = regionArg.getType();
378 if (operand.getType() != regionArgType) {
379 if (!(operand = materializeConversion(callInterface, castOps, castBuilder,
380 operand, regionArgType, castLoc)))
381 return cleanupState();
382 }
383 mapper.map(regionArg, operand);
384 }
385
386 // Ensure that the resultant values of the call match the callable.
387 castBuilder.setInsertionPointAfter(call);
388 for (unsigned i = 0, e = callResults.size(); i != e; ++i) {
389 Value callResult = callResults[i];
390 if (callResult.getType() == callableResultTypes[i])
391 continue;
392
393 // Generate a conversion that will produce the original type, so that the IR
394 // is still valid after the original call gets replaced.
395 Value castResult =
396 materializeConversion(callInterface, castOps, castBuilder, callResult,
397 callResult.getType(), castLoc);
398 if (!castResult)
399 return cleanupState();
400 callResult.replaceAllUsesWith(castResult);
401 castResult.getDefiningOp()->replaceUsesOfWith(castResult, callResult);
402 }
403
404 // Check that it is legal to inline the callable into the call.
405 if (!interface.isLegalToInline(call, callable, shouldCloneInlinedRegion))
406 return cleanupState();
407
408 // Attempt to inline the call.
409 if (failed(inlineRegionImpl(interface, src, call->getBlock(),
410 ++call->getIterator(), mapper, callResults,
411 callableResultTypes, call.getLoc(),
412 shouldCloneInlinedRegion, call)))
413 return cleanupState();
414 return success();
415 }
416