1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/DialectConversion.h"
10 #include "mlir/IR/Block.h"
11 #include "mlir/IR/BlockAndValueMapping.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/Function.h"
14 #include "mlir/IR/Module.h"
15 #include "mlir/Rewrite/PatternApplicator.h"
16 #include "mlir/Transforms/Utils.h"
17 #include "llvm/ADT/SetVector.h"
18 #include "llvm/ADT/SmallPtrSet.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/Support/SaveAndRestore.h"
22 #include "llvm/Support/ScopedPrinter.h"
23 
24 using namespace mlir;
25 using namespace mlir::detail;
26 
27 #define DEBUG_TYPE "dialect-conversion"
28 
29 /// Recursively collect all of the operations to convert from within 'region'.
30 /// If 'target' is nonnull, operations that are recursively legal have their
31 /// regions pre-filtered to avoid considering them for legalization.
32 static LogicalResult
33 computeConversionSet(iterator_range<Region::iterator> region,
34                      Location regionLoc, std::vector<Operation *> &toConvert,
35                      ConversionTarget *target = nullptr) {
36   if (llvm::empty(region))
37     return success();
38 
39   // Traverse starting from the entry block.
40   SmallVector<Block *, 16> worklist(1, &*region.begin());
41   DenseSet<Block *> visitedBlocks;
42   visitedBlocks.insert(worklist.front());
43   while (!worklist.empty()) {
44     Block *block = worklist.pop_back_val();
45 
46     // Compute the conversion set of each of the nested operations.
47     for (Operation &op : *block) {
48       toConvert.emplace_back(&op);
49 
50       // Don't check this operation's children for conversion if the operation
51       // is recursively legal.
52       auto legalityInfo = target ? target->isLegal(&op)
53                                  : Optional<ConversionTarget::LegalOpDetails>();
54       if (legalityInfo && legalityInfo->isRecursivelyLegal)
55         continue;
56       for (auto &region : op.getRegions()) {
57         if (failed(computeConversionSet(region.getBlocks(), region.getLoc(),
58                                         toConvert, target)))
59           return failure();
60       }
61     }
62 
63     // Recurse to children that haven't been visited.
64     for (Block *succ : block->getSuccessors())
65       if (visitedBlocks.insert(succ).second)
66         worklist.push_back(succ);
67   }
68 
69   // Check that all blocks in the region were visited.
70   if (llvm::any_of(llvm::drop_begin(region, 1),
71                    [&](Block &block) { return !visitedBlocks.count(&block); }))
72     return emitError(regionLoc, "unreachable blocks were not converted");
73   return success();
74 }
75 
76 /// A utility function to log a successful result for the given reason.
77 template <typename... Args>
78 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
79   LLVM_DEBUG({
80     os.unindent();
81     os.startLine() << "} -> SUCCESS";
82     if (!fmt.empty())
83       os.getOStream() << " : "
84                       << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
85     os.getOStream() << "\n";
86   });
87 }
88 
89 /// A utility function to log a failure result for the given reason.
90 template <typename... Args>
91 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
92   LLVM_DEBUG({
93     os.unindent();
94     os.startLine() << "} -> FAILURE : "
95                    << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
96                    << "\n";
97   });
98 }
99 
100 //===----------------------------------------------------------------------===//
101 // ConversionValueMapping
102 //===----------------------------------------------------------------------===//
103 
104 namespace {
105 /// This class wraps a BlockAndValueMapping to provide recursive lookup
106 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
107 struct ConversionValueMapping {
108   /// Lookup a mapped value within the map. If a mapping for the provided value
109   /// does not exist then return the provided value. If `desiredType` is
110   /// non-null, returns the most recently mapped value with that type. If an
111   /// operand of that type does not exist, defaults to normal behavior.
112   Value lookupOrDefault(Value from, Type desiredType = nullptr) const;
113 
114   /// Lookup a mapped value within the map, or return null if a mapping does not
115   /// exist. If a mapping exists, this follows the same behavior of
116   /// `lookupOrDefault`.
117   Value lookupOrNull(Value from) const;
118 
119   /// Map a value to the one provided.
120   void map(Value oldVal, Value newVal) { mapping.map(oldVal, newVal); }
121 
122   /// Drop the last mapping for the given value.
123   void erase(Value value) { mapping.erase(value); }
124 
125 private:
126   /// Current value mappings.
127   BlockAndValueMapping mapping;
128 };
129 } // end anonymous namespace
130 
131 Value ConversionValueMapping::lookupOrDefault(Value from,
132                                               Type desiredType) const {
133   // If there was no desired type, simply find the leaf value.
134   if (!desiredType) {
135     // If this value had a valid mapping, unmap that value as well in the case
136     // that it was also replaced.
137     while (auto mappedValue = mapping.lookupOrNull(from))
138       from = mappedValue;
139     return from;
140   }
141 
142   // Otherwise, try to find the deepest value that has the desired type.
143   Value desiredValue;
144   do {
145     if (from.getType() == desiredType)
146       desiredValue = from;
147 
148     Value mappedValue = mapping.lookupOrNull(from);
149     if (!mappedValue)
150       break;
151     from = mappedValue;
152   } while (true);
153 
154   // If the desired value was found use it, otherwise default to the leaf value.
155   return desiredValue ? desiredValue : from;
156 }
157 
158 Value ConversionValueMapping::lookupOrNull(Value from) const {
159   Value result = lookupOrDefault(from);
160   return result == from ? nullptr : result;
161 }
162 
163 //===----------------------------------------------------------------------===//
164 // ArgConverter
165 //===----------------------------------------------------------------------===//
namespace {
/// This class provides a simple interface for converting the types of block
/// arguments. This is done by creating a new block that contains the new legal
/// types and extracting the block that contains the old illegal types to allow
/// for undoing pending rewrites in the case of failure.
///
/// Converted original blocks are not deleted. They are moved into a side region
/// kept in `regionMapping`, so that `discardRewrites` can splice them back if
/// the conversion is rolled back.
struct ArgConverter {
  ArgConverter(PatternRewriter &rewriter) : rewriter(rewriter) {}

  /// This structure contains the information pertaining to an argument that has
  /// been converted.
  struct ConvertedArgInfo {
    ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
                     Value castValue = nullptr)
        : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}

    /// The index, within the new block's argument list, of the first argument
    /// that replaces the original argument.
    unsigned newArgIdx;

    /// The number of arguments that replaced the original argument.
    unsigned newArgSize;

    /// The value that stands in for the original argument. It is one of:
    /// - the result of the argument materialization built from the new
    ///   arguments, or
    /// - for a 1->1 mapping where no materialization was produced, the single
    ///   new argument itself.
    /// It is recorded for every 1->1+ mapping and is expected to be non-null
    /// when rewrites are applied.
    Value castValue;
  };

  /// This structure contains information pertaining to a block that has had its
  /// signature converted.
  struct ConvertedBlockInfo {
    ConvertedBlockInfo(Block *origBlock, TypeConverter &converter)
        : origBlock(origBlock), converter(&converter) {}

    /// The original block that was requested to have its signature converted.
    Block *origBlock;

    /// The conversion information for each of the original arguments, indexed
    /// by original argument number. An entry is None if the argument was
    /// dropped during conversion (with or without a replacement value).
    SmallVector<Optional<ConvertedArgInfo>, 1> argInfo;

    /// The type converter used to convert the arguments.
    TypeConverter *converter;
  };

  /// Return true if `block` takes part in a recorded signature conversion. It
  /// may be either a replacement block (a key of `conversionInfo`) or an
  /// original block (a member of `convertedBlocks`).
  bool hasBeenConverted(Block *block) const {
    return conversionInfo.count(block) || convertedBlocks.count(block);
  }

  /// Set the type converter to use for the given region.
  void setConverter(Region *region, TypeConverter *typeConverter) {
    assert(typeConverter && "expected valid type converter");
    regionToConverter[region] = typeConverter;
  }

  /// Return the type converter to use for the given region, or null if there
  /// isn't one.
  TypeConverter *getConverter(Region *region) {
    return regionToConverter.lookup(region);
  }

  //===--------------------------------------------------------------------===//
  // Rewrite Application
  //===--------------------------------------------------------------------===//

  /// Erase any rewrites registered for the blocks within the given operation,
  /// which is about to be removed. This merely drops the rewrites (after
  /// dropping all uses of the original block arguments) without undoing them.
  void notifyOpRemoved(Operation *op);

  /// Cleanup and undo any generated conversions for the arguments of block.
  /// This method replaces the new block with the original, reverting the IR to
  /// its original state.
  void discardRewrites(Block *block);

  /// Fully replace uses of the old arguments with the new.
  void applyRewrites(ConversionValueMapping &mapping);

  /// Materialize any necessary conversions for converted arguments that have
  /// live users, using the provided `findLiveUser` to search for a user that
  /// survives the conversion process. Currently only original arguments
  /// without recorded ConvertedArgInfo are considered.
  LogicalResult
  materializeLiveConversions(ConversionValueMapping &mapping,
                             OpBuilder &builder,
                             function_ref<Operation *(Value)> findLiveUser);

  //===--------------------------------------------------------------------===//
  // Conversion
  //===--------------------------------------------------------------------===//

  /// Attempt to convert the signature of the given block. Possible results:
  /// - `block` unchanged, if it was already converted or is detached from any
  ///   region;
  /// - failure, if `converter` cannot convert its signature;
  /// - otherwise, the block returned by `applySignatureConversion`.
  FailureOr<Block *> convertSignature(Block *block, TypeConverter &converter,
                                      ConversionValueMapping &mapping);

  /// Apply the given signature conversion on the given block. The new block
  /// containing the updated signature is returned. If no conversions were
  /// necessary, i.e. the block has no arguments and gains none, `block` is
  /// returned. `converter` is used to generate any necessary cast operations
  /// that translate between the origin argument types and those specified in
  /// the signature conversion.
  Block *applySignatureConversion(
      Block *block, TypeConverter &converter,
      TypeConverter::SignatureConversion &signatureConversion,
      ConversionValueMapping &mapping);

  /// Record a conversion whose replacement block is `newBlock`. The original
  /// block (`info.origBlock`) is moved out of its region into the side region
  /// associated with that region in `regionMapping`, and `info` is cached.
  void insertConversion(Block *newBlock, ConvertedBlockInfo &&info);

  /// A collection of blocks that have had their arguments converted. This is a
  /// map from the new replacement block, back to the original block.
  llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;

  /// The set of original blocks that were converted.
  DenseSet<Block *> convertedBlocks;

  /// A mapping from valid regions to the side regions that hold the original
  /// blocks of their conversions.
  DenseMap<Region *, std::unique_ptr<Region>> regionMapping;

  /// A mapping of regions to type converters that should be used when
  /// converting the arguments of blocks within that region.
  DenseMap<Region *, TypeConverter *> regionToConverter;

  /// The pattern rewriter to use when materializing conversions.
  PatternRewriter &rewriter;
};
} // end anonymous namespace
295 
296 //===----------------------------------------------------------------------===//
297 // Rewrite Application
298 
299 void ArgConverter::notifyOpRemoved(Operation *op) {
300   if (conversionInfo.empty())
301     return;
302 
303   for (Region &region : op->getRegions()) {
304     for (Block &block : region) {
305       // Drop any rewrites from within.
306       for (Operation &nestedOp : block)
307         if (nestedOp.getNumRegions())
308           notifyOpRemoved(&nestedOp);
309 
310       // Check if this block was converted.
311       auto it = conversionInfo.find(&block);
312       if (it == conversionInfo.end())
313         continue;
314 
315       // Drop all uses of the original arguments and delete the original block.
316       Block *origBlock = it->second.origBlock;
317       for (BlockArgument arg : origBlock->getArguments())
318         arg.dropAllUses();
319       conversionInfo.erase(it);
320     }
321   }
322 }
323 
324 void ArgConverter::discardRewrites(Block *block) {
325   auto it = conversionInfo.find(block);
326   if (it == conversionInfo.end())
327     return;
328   Block *origBlock = it->second.origBlock;
329 
330   // Drop all uses of the new block arguments and replace uses of the new block.
331   for (int i = block->getNumArguments() - 1; i >= 0; --i)
332     block->getArgument(i).dropAllUses();
333   block->replaceAllUsesWith(origBlock);
334 
335   // Move the operations back the original block and the delete the new block.
336   origBlock->getOperations().splice(origBlock->end(), block->getOperations());
337   origBlock->moveBefore(block);
338   block->erase();
339 
340   convertedBlocks.erase(origBlock);
341   conversionInfo.erase(it);
342 }
343 
344 void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
345   for (auto &info : conversionInfo) {
346     ConvertedBlockInfo &blockInfo = info.second;
347     Block *origBlock = blockInfo.origBlock;
348 
349     // Process the remapping for each of the original arguments.
350     for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
351       Optional<ConvertedArgInfo> &argInfo = blockInfo.argInfo[i];
352       BlockArgument origArg = origBlock->getArgument(i);
353 
354       // Handle the case of a 1->0 value mapping.
355       if (!argInfo) {
356         if (Value newArg = mapping.lookupOrNull(origArg))
357           origArg.replaceAllUsesWith(newArg);
358         continue;
359       }
360 
361       // Otherwise this is a 1->1+ value mapping.
362       Value castValue = argInfo->castValue;
363       assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
364 
365       // If the argument is still used, replace it with the generated cast.
366       if (!origArg.use_empty())
367         origArg.replaceAllUsesWith(mapping.lookupOrDefault(castValue));
368 
369       // If all users of the cast were removed, we can drop it. Otherwise, keep
370       // the operation alive and let the user handle any remaining usages.
371       if (castValue.use_empty() && castValue.getDefiningOp())
372         castValue.getDefiningOp()->erase();
373     }
374   }
375 }
376 
/// Materialize source conversions for original block arguments that are still
/// live after conversion but whose replacement cannot be used as-is.
///
/// Only original arguments without recorded ConvertedArgInfo are examined
/// (see the FIXME below). For each such argument:
/// - if it maps to a different value of the same type, nothing is needed;
/// - otherwise, either its type changed or there is no mapping and it is
///   treated as dropped. In that case `findLiveUser` is queried.
/// If a live user exists, the block's type converter is asked for a source
/// materialization back to the original type, and the result is recorded in
/// `mapping`. If no materialization can be built, an error is emitted (with a
/// note at the live user) and failure is returned.
///
/// NOTE(review): the `builder` parameter is not used in this body. Conversions
/// are built with the member `rewriter`, whose insertion point is moved here
/// and not restored.
LogicalResult ArgConverter::materializeLiveConversions(
    ConversionValueMapping &mapping, OpBuilder &builder,
    function_ref<Operation *(Value)> findLiveUser) {
  for (auto &info : conversionInfo) {
    Block *newBlock = info.first;
    ConvertedBlockInfo &blockInfo = info.second;
    Block *origBlock = blockInfo.origBlock;

    // Process the remapping for each of the original arguments.
    for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
      // FIXME: We should run the below checks even if the type conversion was
      // 1->N, but a lot of existing lowering rely on the block argument being
      // blindly replaced. Those usages should be updated, and this if should be
      // removed.
      if (blockInfo.argInfo[i])
        continue;

      // If the type of this argument changed and the argument is still live, we
      // need to materialize a conversion.
      BlockArgument origArg = origBlock->getArgument(i);
      auto argReplacementValue = mapping.lookupOrDefault(origArg);
      // A lookup that yields the argument itself means no replacement value was
      // recorded for it, i.e. it was dropped outright.
      bool isDroppedArg = argReplacementValue == origArg;
      if (argReplacementValue.getType() == origArg.getType() && !isDroppedArg)
        continue;
      Operation *liveUser = findLiveUser(origArg);
      if (!liveUser)
        continue;

      // Build the materialization right after the op defining the replacement.
      // If the replacement is not an op result (a block argument, including
      // the dropped case where it is `origArg`), build at the start of the new
      // block instead.
      if (OpResult result = argReplacementValue.dyn_cast<OpResult>())
        rewriter.setInsertionPointAfter(result.getOwner());
      else
        rewriter.setInsertionPointToStart(newBlock);
      // A dropped argument has nothing to convert from, so the materialization
      // receives an empty input range.
      Value newArg = blockInfo.converter->materializeSourceConversion(
          rewriter, origArg.getLoc(), origArg.getType(),
          isDroppedArg ? ValueRange() : ValueRange(argReplacementValue));
      if (!newArg) {
        InFlightDiagnostic diag =
            emitError(origArg.getLoc())
            << "failed to materialize conversion for block argument #" << i
            << " that remained live after conversion, type was "
            << origArg.getType();
        if (!isDroppedArg)
          diag << ", with target type " << argReplacementValue.getType();
        diag.attachNote(liveUser->getLoc())
            << "see existing live user here: " << *liveUser;
        return failure();
      }
      // Subsequent lookups of the original argument now resolve to the
      // materialized value.
      mapping.map(origArg, newArg);
    }
  }
  return success();
}
429 
430 //===----------------------------------------------------------------------===//
431 // Conversion
432 
433 FailureOr<Block *>
434 ArgConverter::convertSignature(Block *block, TypeConverter &converter,
435                                ConversionValueMapping &mapping) {
436   // Check if the block was already converted. If the block is detached,
437   // conservatively assume it is going to be deleted.
438   if (hasBeenConverted(block) || !block->getParent())
439     return block;
440 
441   // Try to convert the signature for the block with the provided converter.
442   if (auto conversion = converter.convertBlockSignature(block))
443     return applySignatureConversion(block, converter, *conversion, mapping);
444   return failure();
445 }
446 
447 Block *ArgConverter::applySignatureConversion(
448     Block *block, TypeConverter &converter,
449     TypeConverter::SignatureConversion &signatureConversion,
450     ConversionValueMapping &mapping) {
451   // If no arguments are being changed or added, there is nothing to do.
452   unsigned origArgCount = block->getNumArguments();
453   auto convertedTypes = signatureConversion.getConvertedTypes();
454   if (origArgCount == 0 && convertedTypes.empty())
455     return block;
456 
457   // Split the block at the beginning to get a new block to use for the updated
458   // signature.
459   Block *newBlock = block->splitBlock(block->begin());
460   block->replaceAllUsesWith(newBlock);
461 
462   SmallVector<Value, 4> newArgRange(newBlock->addArguments(convertedTypes));
463   ArrayRef<Value> newArgs(newArgRange);
464 
465   // Remap each of the original arguments as determined by the signature
466   // conversion.
467   ConvertedBlockInfo info(block, converter);
468   info.argInfo.resize(origArgCount);
469 
470   OpBuilder::InsertionGuard guard(rewriter);
471   rewriter.setInsertionPointToStart(newBlock);
472   for (unsigned i = 0; i != origArgCount; ++i) {
473     auto inputMap = signatureConversion.getInputMapping(i);
474     if (!inputMap)
475       continue;
476     BlockArgument origArg = block->getArgument(i);
477 
478     // If inputMap->replacementValue is not nullptr, then the argument is
479     // dropped and a replacement value is provided to be the remappedValue.
480     if (inputMap->replacementValue) {
481       assert(inputMap->size == 0 &&
482              "invalid to provide a replacement value when the argument isn't "
483              "dropped");
484       mapping.map(origArg, inputMap->replacementValue);
485       continue;
486     }
487 
488     // Otherwise, this is a 1->1+ mapping. Call into the provided type converter
489     // to pack the new values. For 1->1 mappings, if there is no materialization
490     // provided, use the argument directly instead.
491     auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
492     Value newArg = converter.materializeArgumentConversion(
493         rewriter, origArg.getLoc(), origArg.getType(), replArgs);
494     if (!newArg) {
495       assert(replArgs.size() == 1 &&
496              "couldn't materialize the result of 1->N conversion");
497       newArg = replArgs.front();
498     }
499     mapping.map(origArg, newArg);
500     info.argInfo[i] =
501         ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
502   }
503 
504   // Remove the original block from the region and return the new one.
505   insertConversion(newBlock, std::move(info));
506   return newBlock;
507 }
508 
509 void ArgConverter::insertConversion(Block *newBlock,
510                                     ConvertedBlockInfo &&info) {
511   // Get a region to insert the old block.
512   Region *region = newBlock->getParent();
513   std::unique_ptr<Region> &mappedRegion = regionMapping[region];
514   if (!mappedRegion)
515     mappedRegion = std::make_unique<Region>(region->getParentOp());
516 
517   // Move the original block to the mapped region and emplace the conversion.
518   mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(),
519                                    info.origBlock->getIterator());
520   convertedBlocks.insert(info.origBlock);
521   conversionInfo.insert({newBlock, std::move(info)});
522 }
523 
524 //===----------------------------------------------------------------------===//
525 // Rewriter and Translation State
526 //===----------------------------------------------------------------------===//
527 namespace {
/// A snapshot of the conversion rewriter's bookkeeping counters. Capturing one
/// before a group of rewrites lets those rewrites later be rolled back to this
/// point.
struct RewriterState {
  RewriterState(unsigned numCreatedOps, unsigned numReplacements,
                unsigned numArgReplacements, unsigned numBlockActions,
                unsigned numIgnoredOperations, unsigned numRootUpdates)
      : numCreatedOps{numCreatedOps}, numReplacements{numReplacements},
        numArgReplacements{numArgReplacements},
        numBlockActions{numBlockActions},
        numIgnoredOperations{numIgnoredOperations},
        numRootUpdates{numRootUpdates} {}

  /// How many operations had been created when the snapshot was taken.
  unsigned numCreatedOps;

  /// How many operation replacements had been queued.
  unsigned numReplacements;

  /// How many block argument replacements had been queued.
  unsigned numArgReplacements;

  /// How many block actions had been performed.
  unsigned numBlockActions;

  /// How many operations had been marked as ignored.
  unsigned numIgnoredOperations;

  /// How many operations had been updated in place.
  unsigned numRootUpdates;
};
558 
559 /// The state of an operation that was updated by a pattern in-place. This
560 /// contains all of the necessary information to reconstruct an operation that
561 /// was updated in place.
562 class OperationTransactionState {
563 public:
564   OperationTransactionState() = default;
565   OperationTransactionState(Operation *op)
566       : op(op), loc(op->getLoc()), attrs(op->getMutableAttrDict()),
567         operands(op->operand_begin(), op->operand_end()),
568         successors(op->successor_begin(), op->successor_end()) {}
569 
570   /// Discard the transaction state and reset the state of the original
571   /// operation.
572   void resetOperation() const {
573     op->setLoc(loc);
574     op->setAttrs(attrs);
575     op->setOperands(operands);
576     for (auto it : llvm::enumerate(successors))
577       op->setSuccessor(it.value(), it.index());
578   }
579 
580   /// Return the original operation of this state.
581   Operation *getOperation() const { return op; }
582 
583 private:
584   Operation *op;
585   LocationAttr loc;
586   MutableDictionaryAttr attrs;
587   SmallVector<Value, 8> operands;
588   SmallVector<Block *, 2> successors;
589 };
590 
591 /// This class represents one requested operation replacement via 'replaceOp' or
592 /// 'eraseOp`.
593 struct OpReplacement {
594   OpReplacement() = default;
595   OpReplacement(TypeConverter *converter) : converter(converter) {}
596 
597   /// An optional type converter that can be used to materialize conversions
598   /// between the new and old values if necessary.
599   TypeConverter *converter = nullptr;
600 };
601 
/// Kinds of block-level actions recorded while rewriting. Each recorded action
/// can be undone if the conversion fails.
enum class BlockActionKind {
  Create,        // A block was created.
  Erase,         // A block was erased.
  Merge,         // One block was merged into another.
  Move,          // A block was moved.
  Split,         // A block was split into two parts.
  TypeConversion // A block's signature was converted.
};
612 
/// Original position of the given block in its parent region. During undo
/// actions, the block needs to be placed after `insertAfterBlock`.
/// Field order matters: this struct is aggregate-initialized positionally.
struct BlockPosition {
  // The region that originally contained the block.
  Region *region;
  // The block that preceded it in `region`. NOTE(review): presumably null when
  // the block was first in the region; confirm against the undo logic.
  Block *insertAfterBlock;
};
619 
/// Information needed to undo the merge actions.
/// - the source block, and
/// - the Operation that was the last operation in the dest block before the
///   merge (could be null if the dest block was empty).
/// Field order matters: `BlockAction::getMerge` initializes this positionally.
struct MergeInfo {
  // The block whose operations were merged into the destination block.
  Block *sourceBlock;
  // The destination block's last operation prior to the merge, or null.
  Operation *destBlockLastInst;
};
628 
/// The storage class for an undoable block action (one of BlockActionKind),
/// contains the information necessary to undo this action.
/// Instances should be built through the static factories below. Each factory
/// initializes only the union member relevant to its kind.
struct BlockAction {
  // `block` is the newly created block.
  static BlockAction getCreate(Block *block) {
    return {BlockActionKind::Create, block, {}};
  }
  // `block` is the erased block; `originalPosition` records where it lived.
  static BlockAction getErase(Block *block, BlockPosition originalPosition) {
    return {BlockActionKind::Erase, block, {originalPosition}};
  }
  // `block` is the destination block, `sourceBlock` the block merged into it.
  // The destination's current last operation (or null if it is empty) is
  // captured so the merge can later be split back apart.
  static BlockAction getMerge(Block *block, Block *sourceBlock) {
    BlockAction action{BlockActionKind::Merge, block, {}};
    action.mergeInfo = {sourceBlock, block->empty() ? nullptr : &block->back()};
    return action;
  }
  // `block` is the moved block; `originalPosition` records where it lived.
  static BlockAction getMove(Block *block, BlockPosition originalPosition) {
    return {BlockActionKind::Move, block, {originalPosition}};
  }
  // `originalBlock` is the block that was split; `block` is the other part
  // produced by the split (presumably the new continuation block).
  static BlockAction getSplit(Block *block, Block *originalBlock) {
    BlockAction action{BlockActionKind::Split, block, {}};
    action.originalBlock = originalBlock;
    return action;
  }
  // `block` is the block whose signature was converted.
  static BlockAction getTypeConversion(Block *block) {
    return BlockAction{BlockActionKind::TypeConversion, block, {}};
  }

  // The action kind.
  BlockActionKind kind;

  // The block the action was performed on. Its role depends on `kind`: for
  // example, the new block for Create and the destination block for Merge.
  Block *block;

  // Kind-specific undo information; only the member matching `kind` is valid.
  union {
    // In use if kind == BlockActionKind::Move or BlockActionKind::Erase, and
    // contains a pointer to the region that originally contained the block as
    // well as the position of the block in that region.
    BlockPosition originalPosition;
    // In use if kind == BlockActionKind::Split and contains a pointer to the
    // block that was split into two parts.
    Block *originalBlock;
    // In use if kind == BlockActionKind::Merge, and contains the information
    // needed to undo the merge.
    MergeInfo mergeInfo;
  };
};
674 } // end anonymous namespace
675 
676 //===----------------------------------------------------------------------===//
677 // ConversionPatternRewriterImpl
678 //===----------------------------------------------------------------------===//
679 namespace mlir {
680 namespace detail {
681 struct ConversionPatternRewriterImpl {
682   ConversionPatternRewriterImpl(PatternRewriter &rewriter)
683       : argConverter(rewriter) {}
684 
685   /// Cleanup and destroy any generated rewrite operations. This method is
686   /// invoked when the conversion process fails.
687   void discardRewrites();
688 
689   /// Apply all requested operation rewrites. This method is invoked when the
690   /// conversion process succeeds.
691   void applyRewrites();
692 
693   //===--------------------------------------------------------------------===//
694   // State Management
695   //===--------------------------------------------------------------------===//
696 
697   /// Return the current state of the rewriter.
698   RewriterState getCurrentState();
699 
700   /// Reset the state of the rewriter to a previously saved point.
701   void resetState(RewriterState state);
702 
703   /// Erase any blocks that were unlinked from their regions and stored in block
704   /// actions.
705   void eraseDanglingBlocks();
706 
707   /// Undo the block actions (motions, splits) one by one in reverse order until
708   /// "numActionsToKeep" actions remains.
709   void undoBlockActions(unsigned numActionsToKeep = 0);
710 
711   /// Remap the given operands to those with potentially different types. The
712   /// provided type converter is used to ensure that the remapped types are
713   /// legal. Returns success if the operands could be remapped, failure
714   /// otherwise.
715   LogicalResult remapValues(Location loc, PatternRewriter &rewriter,
716                             TypeConverter *converter,
717                             Operation::operand_range operands,
718                             SmallVectorImpl<Value> &remapped);
719 
720   /// Returns true if the given operation is ignored, and does not need to be
721   /// converted.
722   bool isOpIgnored(Operation *op) const;
723 
724   /// Recursively marks the nested operations under 'op' as ignored. This
725   /// removes them from being considered for legalization.
726   void markNestedOpsIgnored(Operation *op);
727 
728   //===--------------------------------------------------------------------===//
729   // Type Conversion
730   //===--------------------------------------------------------------------===//
731 
732   /// Convert the signature of the given block.
733   FailureOr<Block *> convertBlockSignature(
734       Block *block, TypeConverter &converter,
735       TypeConverter::SignatureConversion *conversion = nullptr);
736 
737   /// Apply a signature conversion on the given region.
738   Block *
739   applySignatureConversion(Region *region,
740                            TypeConverter::SignatureConversion &conversion);
741 
742   /// Convert the types of block arguments within the given region.
743   FailureOr<Block *>
744   convertRegionTypes(Region *region, TypeConverter &converter,
745                      TypeConverter::SignatureConversion *entryConversion);
746 
747   //===--------------------------------------------------------------------===//
748   // Rewriter Notification Hooks
749   //===--------------------------------------------------------------------===//
750 
751   /// PatternRewriter hook for replacing the results of an operation.
752   void notifyOpReplaced(Operation *op, ValueRange newValues);
753 
754   /// Notifies that a block is about to be erased.
755   void notifyBlockIsBeingErased(Block *block);
756 
757   /// Notifies that a block was created.
758   void notifyCreatedBlock(Block *block);
759 
760   /// Notifies that a block was split.
761   void notifySplitBlock(Block *block, Block *continuation);
762 
763   /// Notifies that `block` is being merged with `srcBlock`.
764   void notifyBlocksBeingMerged(Block *block, Block *srcBlock);
765 
766   /// Notifies that the blocks of a region are about to be moved.
767   void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent,
768                                         Region::iterator before);
769 
770   /// Notifies that the blocks of a region were cloned into another.
771   void notifyRegionWasClonedBefore(iterator_range<Region::iterator> &blocks,
772                                    Location origRegionLoc);
773 
774   /// Notifies that a pattern match failed for the given reason.
775   LogicalResult
776   notifyMatchFailure(Location loc,
777                      function_ref<void(Diagnostic &)> reasonCallback);
778 
779   //===--------------------------------------------------------------------===//
780   // State
781   //===--------------------------------------------------------------------===//
782 
783   // Mapping between replaced values that differ in type. This happens when
784   // replacing a value with one of a different type.
785   ConversionValueMapping mapping;
786 
787   /// Utility used to convert block arguments.
788   ArgConverter argConverter;
789 
790   /// Ordered vector of all of the newly created operations during conversion.
791   std::vector<Operation *> createdOps;
792 
793   /// Ordered map of requested operation replacements.
794   llvm::MapVector<Operation *, OpReplacement> replacements;
795 
796   /// Ordered vector of any requested block argument replacements.
797   SmallVector<BlockArgument, 4> argReplacements;
798 
799   /// Ordered list of block operations (creations, splits, motions).
800   SmallVector<BlockAction, 4> blockActions;
801 
802   /// A set of operations that should no longer be considered for legalization,
803   /// but were not directly replace/erased/etc. by a pattern. These are
804   /// generally child operations of other operations who were
805   /// replaced/erased/etc. This is not meant to be an exhaustive list of all
806   /// operations, but the minimal set that can be used to detect if a given
807   /// operation should be `ignored`. For example, we may add the operations that
808   /// define non-empty regions to the set, but not any of the others. This
809   /// simplifies the amount of memory needed as we can query if the parent
810   /// operation was ignored.
811   llvm::SetVector<Operation *> ignoredOps;
812 
813   /// A transaction state for each of operations that were updated in-place.
814   SmallVector<OperationTransactionState, 4> rootUpdates;
815 
816   /// A vector of indices into `replacements` of operations that were replaced
817   /// with values with different result types than the original operation, e.g.
818   /// 1->N conversion of some kind.
819   SmallVector<unsigned, 4> operationsWithChangedResults;
820 
821   /// A default type converter, used when block conversions do not have one
822   /// explicitly provided.
823   TypeConverter defaultTypeConverter;
824 
825   /// The current conversion pattern that is being rewritten, or nullptr if
826   /// called from outside of a conversion pattern rewrite.
827   const ConversionPattern *currentConversionPattern = nullptr;
828 
829 #ifndef NDEBUG
830   /// A set of operations that have pending updates. This tracking isn't
831   /// strictly necessary, and is thus only active during debug builds for extra
832   /// verification.
833   SmallPtrSet<Operation *, 1> pendingRootUpdates;
834 
835   /// A logger used to emit diagnostics during the conversion process.
836   llvm::ScopedPrinter logger{llvm::dbgs()};
837 #endif
838 };
839 } // end namespace detail
840 } // end namespace mlir
841 
842 /// Detach any operations nested in the given operation from their parent
843 /// blocks, and erase the given operation. This can be used when the nested
844 /// operations are scheduled for erasure themselves, so deleting the regions of
845 /// the given operation together with their content would result in double-free.
846 /// This happens, for example, when rolling back op creation in the reverse
847 /// order and if the nested ops were created before the parent op. This function
848 /// does not need to collect nested ops recursively because it is expected to
849 /// also be called for each nested op when it is about to be deleted.
850 static void detachNestedAndErase(Operation *op) {
851   for (Region &region : op->getRegions()) {
852     for (Block &block : region.getBlocks()) {
853       while (!block.getOperations().empty())
854         block.getOperations().remove(block.getOperations().begin());
855       block.dropAllDefinedValueUses();
856     }
857   }
858   op->erase();
859 }
860 
861 void ConversionPatternRewriterImpl::discardRewrites() {
862   // Reset any operations that were updated in place.
863   for (auto &state : rootUpdates)
864     state.resetOperation();
865 
866   undoBlockActions();
867 
868   // Remove any newly created ops.
869   for (auto *op : llvm::reverse(createdOps))
870     detachNestedAndErase(op);
871 }
872 
/// Commit every rewrite recorded during the conversion to the IR.
///
/// The steps run in this order:
/// 1. Rewire the uses of replaced operation results to their mapped values.
/// 2. Drop pending argument rewrites for replaced operations that own regions.
/// 3. Rewire the uses of replaced block arguments.
/// 4. Erase the replaced operations, in reverse order of replacement.
/// 5. Let the argument converter apply its pending rewrites.
/// 6. Delete the blocks that were erased during conversion.
void ConversionPatternRewriterImpl::applyRewrites() {
  // Apply all of the rewrites replacements requested during conversion.
  for (auto &repl : replacements) {
    // Results without a mapped value keep their uses here. Erased operations
    // are recorded with null replacements and therefore have no mapping.
    for (OpResult result : repl.first->getResults())
      if (Value newValue = mapping.lookupOrNull(result))
        result.replaceAllUsesWith(newValue);

    // If this operation defines any regions, drop any pending argument
    // rewrites.
    if (repl.first->getNumRegions())
      argConverter.notifyOpRemoved(repl.first);
  }

  // Apply all of the requested argument replacements.
  for (BlockArgument arg : argReplacements) {
    // The value `arg` is mapped to, or `arg` itself if it was never mapped.
    Value repl = mapping.lookupOrDefault(arg);
    // A block-argument replacement is substituted for every use.
    if (repl.isa<BlockArgument>()) {
      arg.replaceAllUsesWith(repl);
      continue;
    }

    // If the replacement value is an operation result, a use located in the
    // same block as the defining operation is only rewritten when
    // `replOp->isBeforeInBlock(user)`, i.e. when the user comes after the
    // definition. Uses in any other block are always rewritten.
    Operation *replOp = repl.cast<OpResult>().getOwner();
    Block *replBlock = replOp->getBlock();
    arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
      Operation *user = operand.getOwner();
      return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
    });
  }

  // In a second pass, erase all of the replaced operations in reverse. This
  // allows processing nested operations before their parent region is
  // destroyed.
  for (auto &repl : llvm::reverse(replacements))
    repl.first->erase();

  argConverter.applyRewrites(mapping);

  // Now that the ops have been erased, also erase dangling blocks.
  eraseDanglingBlocks();
}
916 
917 //===----------------------------------------------------------------------===//
918 // State Management
919 
920 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
921   return RewriterState(createdOps.size(), replacements.size(),
922                        argReplacements.size(), blockActions.size(),
923                        ignoredOps.size(), rootUpdates.size());
924 }
925 
926 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
927   // Reset any operations that were updated in place.
928   for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i)
929     rootUpdates[i].resetOperation();
930   rootUpdates.resize(state.numRootUpdates);
931 
932   // Reset any replaced arguments.
933   for (BlockArgument replacedArg :
934        llvm::drop_begin(argReplacements, state.numArgReplacements))
935     mapping.erase(replacedArg);
936   argReplacements.resize(state.numArgReplacements);
937 
938   // Undo any block actions.
939   undoBlockActions(state.numBlockActions);
940 
941   // Reset any replaced operations and undo any saved mappings.
942   for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
943     for (auto result : repl.first->getResults())
944       mapping.erase(result);
945   while (replacements.size() != state.numReplacements)
946     replacements.pop_back();
947 
948   // Pop all of the newly created operations.
949   while (createdOps.size() != state.numCreatedOps) {
950     detachNestedAndErase(createdOps.back());
951     createdOps.pop_back();
952   }
953 
954   // Pop all of the recorded ignored operations that are no longer valid.
955   while (ignoredOps.size() != state.numIgnoredOperations)
956     ignoredOps.pop_back();
957 
958   // Reset operations with changed results.
959   while (!operationsWithChangedResults.empty() &&
960          operationsWithChangedResults.back() >= state.numReplacements)
961     operationsWithChangedResults.pop_back();
962 }
963 
964 void ConversionPatternRewriterImpl::eraseDanglingBlocks() {
965   for (auto &action : blockActions)
966     if (action.kind == BlockActionKind::Erase)
967       delete action.block;
968 }
969 
970 void ConversionPatternRewriterImpl::undoBlockActions(
971     unsigned numActionsToKeep) {
972   for (auto &action :
973        llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) {
974     switch (action.kind) {
975     // Delete the created block.
976     case BlockActionKind::Create: {
977       // Unlink all of the operations within this block, they will be deleted
978       // separately.
979       auto &blockOps = action.block->getOperations();
980       while (!blockOps.empty())
981         blockOps.remove(blockOps.begin());
982       action.block->dropAllDefinedValueUses();
983       action.block->erase();
984       break;
985     }
986     // Put the block (owned by action) back into its original position.
987     case BlockActionKind::Erase: {
988       auto &blockList = action.originalPosition.region->getBlocks();
989       Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
990       blockList.insert((insertAfterBlock
991                             ? std::next(Region::iterator(insertAfterBlock))
992                             : blockList.end()),
993                        action.block);
994       break;
995     }
996     // Split the block at the position which was originally the end of the
997     // destination block (owned by action), and put the instructions back into
998     // the block used before the merge.
999     case BlockActionKind::Merge: {
1000       Block *sourceBlock = action.mergeInfo.sourceBlock;
1001       Block::iterator splitPoint =
1002           (action.mergeInfo.destBlockLastInst
1003                ? ++Block::iterator(action.mergeInfo.destBlockLastInst)
1004                : action.block->begin());
1005       sourceBlock->getOperations().splice(sourceBlock->begin(),
1006                                           action.block->getOperations(),
1007                                           splitPoint, action.block->end());
1008       break;
1009     }
1010     // Move the block back to its original position.
1011     case BlockActionKind::Move: {
1012       Region *originalRegion = action.originalPosition.region;
1013       Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
1014       originalRegion->getBlocks().splice(
1015           (insertAfterBlock ? std::next(Region::iterator(insertAfterBlock))
1016                             : originalRegion->end()),
1017           action.block->getParent()->getBlocks(), action.block);
1018       break;
1019     }
1020     // Merge back the block that was split out.
1021     case BlockActionKind::Split: {
1022       action.originalBlock->getOperations().splice(
1023           action.originalBlock->end(), action.block->getOperations());
1024       action.block->dropAllDefinedValueUses();
1025       action.block->erase();
1026       break;
1027     }
1028     // Undo the type conversion.
1029     case BlockActionKind::TypeConversion: {
1030       argConverter.discardRewrites(action.block);
1031       break;
1032     }
1033     }
1034   }
1035   blockActions.resize(numActionsToKeep);
1036 }
1037 
/// Remap each value in `operands` to its current replacement in `mapping`,
/// appending the results to `remapped` in operand order.
///
/// When `converter` is non-null, each operand's original type is converted.
/// If it converts to exactly one type, that type is passed as the desired
/// type to `mapping.lookupOrDefault`. If the value found has a different
/// type, a target materialization is attempted at `loc` using `rewriter`.
///
/// Returns failure if an operand type cannot be converted or a required
/// materialization cannot be built. The reason is reported through
/// `notifyMatchFailure`. On failure, `remapped` may already hold the values
/// appended for earlier operands.
LogicalResult ConversionPatternRewriterImpl::remapValues(
    Location loc, PatternRewriter &rewriter, TypeConverter *converter,
    Operation::operand_range operands, SmallVectorImpl<Value> &remapped) {
  remapped.reserve(llvm::size(operands));

  // Scratch buffer for the converted types, reused for every operand.
  SmallVector<Type, 1> legalTypes;
  for (auto it : llvm::enumerate(operands)) {
    Value operand = it.value();
    Type origType = operand.getType();

    // If a converter was provided, get the desired legal types for this
    // operand.
    Type desiredType;
    if (converter) {
      // If there is no legal conversion, fail to match this pattern.
      legalTypes.clear();
      if (failed(converter->convertType(origType, legalTypes))) {
        return notifyMatchFailure(loc, [=](Diagnostic &diag) {
          diag << "unable to convert type for operand #" << it.index()
               << ", type was " << origType;
        });
      }
      // TODO: There currently isn't any mechanism to do 1->N type conversion
      // via the PatternRewriter replacement API, so for now we just ignore it.
      if (legalTypes.size() == 1)
        desiredType = legalTypes.front();
    } else {
      // TODO: What we should do here is just set `desiredType` to `origType`
      // and then handle the necessary type conversions after the conversion
      // process has finished. Unfortunately a lot of patterns currently rely on
      // receiving the new operands even if the types change, so we keep the
      // original behavior here for now until all of the patterns relying on
      // this get updated.
    }
    // `desiredType` is null here unless a 1->1 conversion was found.
    Value newOperand = mapping.lookupOrDefault(operand, desiredType);

    // Handle the case where the conversion was 1->1 and the new operand type
    // isn't legal.
    Type newOperandType = newOperand.getType();
    if (converter && desiredType && newOperandType != desiredType) {
      // Attempt to materialize a conversion for this new value.
      newOperand = converter->materializeTargetConversion(
          rewriter, loc, desiredType, newOperand);
      if (!newOperand) {
        return notifyMatchFailure(loc, [=](Diagnostic &diag) {
          diag << "unable to materialize a conversion for "
                  "operand #"
               << it.index() << ", from " << newOperandType << " to "
               << desiredType;
        });
      }
    }
    remapped.push_back(newOperand);
  }
  return success();
}
1094 
1095 bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
1096   // Check to see if this operation was replaced or its parent ignored.
1097   return replacements.count(op) || ignoredOps.count(op->getParentOp());
1098 }
1099 
1100 void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
1101   // Walk this operation and collect nested operations that define non-empty
1102   // regions. We mark such operations as 'ignored' so that we know we don't have
1103   // to convert them, or their nested ops.
1104   if (op->getNumRegions() == 0)
1105     return;
1106   op->walk([&](Operation *op) {
1107     if (llvm::any_of(op->getRegions(),
1108                      [](Region &region) { return !region.empty(); }))
1109       ignoredOps.insert(op);
1110   });
1111 }
1112 
1113 //===----------------------------------------------------------------------===//
1114 // Type Conversion
1115 
1116 FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
1117     Block *block, TypeConverter &converter,
1118     TypeConverter::SignatureConversion *conversion) {
1119   FailureOr<Block *> result =
1120       conversion ? argConverter.applySignatureConversion(block, converter,
1121                                                          *conversion, mapping)
1122                  : argConverter.convertSignature(block, converter, mapping);
1123   if (Block *newBlock = result.getValue()) {
1124     if (newBlock != block)
1125       blockActions.push_back(BlockAction::getTypeConversion(newBlock));
1126   }
1127   return result;
1128 }
1129 
1130 Block *ConversionPatternRewriterImpl::applySignatureConversion(
1131     Region *region, TypeConverter::SignatureConversion &conversion) {
1132   if (!region->empty()) {
1133     return *convertBlockSignature(&region->front(), defaultTypeConverter,
1134                                   &conversion);
1135   }
1136   return nullptr;
1137 }
1138 
1139 FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
1140     Region *region, TypeConverter &converter,
1141     TypeConverter::SignatureConversion *entryConversion) {
1142   argConverter.setConverter(region, &converter);
1143   if (region->empty())
1144     return nullptr;
1145 
1146   // Convert the arguments of each block within the region.
1147   FailureOr<Block *> newEntry =
1148       convertBlockSignature(&region->front(), converter, entryConversion);
1149   for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(*region, 1)))
1150     if (failed(convertBlockSignature(&block, converter)))
1151       return failure();
1152   return newEntry;
1153 }
1154 
1155 //===----------------------------------------------------------------------===//
1156 // Rewriter Notification Hooks
1157 
1158 void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
1159                                                      ValueRange newValues) {
1160   assert(newValues.size() == op->getNumResults());
1161   assert(!replacements.count(op) && "operation was already replaced");
1162 
1163   // Track if any of the results changed, e.g. erased and replaced with null.
1164   bool resultChanged = false;
1165 
1166   // Create mappings for each of the new result values.
1167   Value newValue, result;
1168   for (auto it : llvm::zip(newValues, op->getResults())) {
1169     std::tie(newValue, result) = it;
1170     if (!newValue) {
1171       resultChanged = true;
1172       continue;
1173     }
1174     // Remap, and check for any result type changes.
1175     mapping.map(result, newValue);
1176     resultChanged |= (newValue.getType() != result.getType());
1177   }
1178   if (resultChanged)
1179     operationsWithChangedResults.push_back(replacements.size());
1180 
1181   // Record the requested operation replacement.
1182   TypeConverter *converter = nullptr;
1183   if (currentConversionPattern)
1184     converter = currentConversionPattern->getTypeConverter();
1185   replacements.insert(std::make_pair(op, OpReplacement(converter)));
1186 
1187   // Mark this operation as recursively ignored so that we don't need to
1188   // convert any nested operations.
1189   markNestedOpsIgnored(op);
1190 }
1191 
1192 void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
1193   Region *region = block->getParent();
1194   Block *origPrevBlock = block->getPrevNode();
1195   blockActions.push_back(BlockAction::getErase(block, {region, origPrevBlock}));
1196 }
1197 
1198 void ConversionPatternRewriterImpl::notifyCreatedBlock(Block *block) {
1199   blockActions.push_back(BlockAction::getCreate(block));
1200 }
1201 
1202 void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
1203                                                      Block *continuation) {
1204   blockActions.push_back(BlockAction::getSplit(continuation, block));
1205 }
1206 
1207 void ConversionPatternRewriterImpl::notifyBlocksBeingMerged(Block *block,
1208                                                             Block *srcBlock) {
1209   blockActions.push_back(BlockAction::getMerge(block, srcBlock));
1210 }
1211 
1212 void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore(
1213     Region &region, Region &parent, Region::iterator before) {
1214   Block *origPrevBlock = nullptr;
1215   for (auto &pair : llvm::enumerate(region)) {
1216     Block &block = pair.value();
1217     blockActions.push_back(
1218         BlockAction::getMove(&block, {&region, origPrevBlock}));
1219     origPrevBlock = &block;
1220   }
1221 }
1222 
1223 void ConversionPatternRewriterImpl::notifyRegionWasClonedBefore(
1224     iterator_range<Region::iterator> &blocks, Location origRegionLoc) {
1225   for (Block &block : blocks)
1226     blockActions.push_back(BlockAction::getCreate(&block));
1227 
1228   // Compute the conversion set for the inlined region.
1229   auto result = computeConversionSet(blocks, origRegionLoc, createdOps);
1230 
1231   // This original region has already had its conversion set computed, so there
1232   // shouldn't be any new failures.
1233   (void)result;
1234   assert(succeeded(result) && "expected region to have no unreachable blocks");
1235 }
1236 
1237 LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
1238     Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1239   LLVM_DEBUG({
1240     Diagnostic diag(loc, DiagnosticSeverity::Remark);
1241     reasonCallback(diag);
1242     logger.startLine() << "** Failure : " << diag.str() << "\n";
1243   });
1244   return failure();
1245 }
1246 
1247 //===----------------------------------------------------------------------===//
1248 // ConversionPatternRewriter
1249 //===----------------------------------------------------------------------===//
1250 
1251 ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
1252     : PatternRewriter(ctx),
1253       impl(new detail::ConversionPatternRewriterImpl(*this)) {}
1254 ConversionPatternRewriter::~ConversionPatternRewriter() {}
1255 
1256 /// PatternRewriter hook for replacing the results of an operation.
1257 void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
1258   LLVM_DEBUG({
1259     impl->logger.startLine()
1260         << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1261   });
1262   impl->notifyOpReplaced(op, newValues);
1263 }
1264 
1265 /// PatternRewriter hook for erasing a dead operation. The uses of this
1266 /// operation *must* be made dead by the end of the conversion process,
1267 /// otherwise an assert will be issued.
1268 void ConversionPatternRewriter::eraseOp(Operation *op) {
1269   LLVM_DEBUG({
1270     impl->logger.startLine()
1271         << "** Erase   : '" << op->getName() << "'(" << op << ")\n";
1272   });
1273   SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr);
1274   impl->notifyOpReplaced(op, nullRepls);
1275 }
1276 
1277 void ConversionPatternRewriter::eraseBlock(Block *block) {
1278   impl->notifyBlockIsBeingErased(block);
1279 
1280   // Mark all ops for erasure.
1281   for (Operation &op : *block)
1282     eraseOp(&op);
1283 
1284   // Unlink the block from its parent region. The block is kept in the block
1285   // action and will be actually destroyed when rewrites are applied. This
1286   // allows us to keep the operations in the block live and undo the removal by
1287   // re-inserting the block.
1288   block->getParent()->getBlocks().remove(block);
1289 }
1290 
1291 Block *ConversionPatternRewriter::applySignatureConversion(
1292     Region *region, TypeConverter::SignatureConversion &conversion) {
1293   return impl->applySignatureConversion(region, conversion);
1294 }
1295 
1296 FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
1297     Region *region, TypeConverter &converter,
1298     TypeConverter::SignatureConversion *entryConversion) {
1299   return impl->convertRegionTypes(region, converter, entryConversion);
1300 }
1301 
1302 void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
1303                                                            Value to) {
1304   LLVM_DEBUG({
1305     Operation *parentOp = from.getOwner()->getParentOp();
1306     impl->logger.startLine() << "** Replace Argument : '" << from
1307                              << "'(in region of '" << parentOp->getName()
1308                              << "'(" << from.getOwner()->getParentOp() << ")\n";
1309   });
1310   impl->argReplacements.push_back(from);
1311   impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
1312 }
1313 
1314 /// Return the converted value that replaces 'key'. Return 'key' if there is
1315 /// no such a converted value.
1316 Value ConversionPatternRewriter::getRemappedValue(Value key) {
1317   return impl->mapping.lookupOrDefault(key);
1318 }
1319 
1320 /// PatternRewriter hook for creating a new block with the given arguments.
1321 void ConversionPatternRewriter::notifyBlockCreated(Block *block) {
1322   impl->notifyCreatedBlock(block);
1323 }
1324 
1325 /// PatternRewriter hook for splitting a block into two parts.
1326 Block *ConversionPatternRewriter::splitBlock(Block *block,
1327                                              Block::iterator before) {
1328   auto *continuation = PatternRewriter::splitBlock(block, before);
1329   impl->notifySplitBlock(block, continuation);
1330   return continuation;
1331 }
1332 
1333 /// PatternRewriter hook for merging a block into another.
1334 void ConversionPatternRewriter::mergeBlocks(Block *source, Block *dest,
1335                                             ValueRange argValues) {
1336   impl->notifyBlocksBeingMerged(dest, source);
1337   assert(llvm::all_of(source->getPredecessors(),
1338                       [dest](Block *succ) { return succ == dest; }) &&
1339          "expected 'source' to have no predecessors or only 'dest'");
1340   assert(argValues.size() == source->getNumArguments() &&
1341          "incorrect # of argument replacement values");
1342   for (auto it : llvm::zip(source->getArguments(), argValues))
1343     replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1344   dest->getOperations().splice(dest->end(), source->getOperations());
1345   eraseBlock(source);
1346 }
1347 
1348 /// PatternRewriter hook for moving blocks out of a region.
1349 void ConversionPatternRewriter::inlineRegionBefore(Region &region,
1350                                                    Region &parent,
1351                                                    Region::iterator before) {
1352   impl->notifyRegionIsBeingInlinedBefore(region, parent, before);
1353   PatternRewriter::inlineRegionBefore(region, parent, before);
1354 }
1355 
1356 /// PatternRewriter hook for cloning blocks of one region into another.
1357 void ConversionPatternRewriter::cloneRegionBefore(
1358     Region &region, Region &parent, Region::iterator before,
1359     BlockAndValueMapping &mapping) {
1360   if (region.empty())
1361     return;
1362   PatternRewriter::cloneRegionBefore(region, parent, before, mapping);
1363 
1364   // Collect the range of the cloned blocks.
1365   auto clonedBeginIt = mapping.lookup(&region.front())->getIterator();
1366   auto clonedBlocks = llvm::make_range(clonedBeginIt, before);
1367   impl->notifyRegionWasClonedBefore(clonedBlocks, region.getLoc());
1368 }
1369 
1370 /// PatternRewriter hook for creating a new operation.
1371 void ConversionPatternRewriter::notifyOperationInserted(Operation *op) {
1372   LLVM_DEBUG({
1373     impl->logger.startLine()
1374         << "** Insert  : '" << op->getName() << "'(" << op << ")\n";
1375   });
1376   impl->createdOps.push_back(op);
1377 }
1378 
1379 /// PatternRewriter hook for updating the root operation in-place.
1380 void ConversionPatternRewriter::startRootUpdate(Operation *op) {
1381 #ifndef NDEBUG
1382   impl->pendingRootUpdates.insert(op);
1383 #endif
1384   impl->rootUpdates.emplace_back(op);
1385 }
1386 
1387 /// PatternRewriter hook for updating the root operation in-place.
1388 void ConversionPatternRewriter::finalizeRootUpdate(Operation *op) {
1389   // There is nothing to do here, we only need to track the operation at the
1390   // start of the update.
1391 #ifndef NDEBUG
1392   assert(impl->pendingRootUpdates.erase(op) &&
1393          "operation did not have a pending in-place update");
1394 #endif
1395 }
1396 
1397 /// PatternRewriter hook for updating the root operation in-place.
1398 void ConversionPatternRewriter::cancelRootUpdate(Operation *op) {
1399 #ifndef NDEBUG
1400   assert(impl->pendingRootUpdates.erase(op) &&
1401          "operation did not have a pending in-place update");
1402 #endif
1403   // Erase the last update for this operation.
1404   auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; };
1405   auto &rootUpdates = impl->rootUpdates;
1406   auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp);
1407   rootUpdates.erase(rootUpdates.begin() + (rootUpdates.rend() - it));
1408 }
1409 
1410 /// PatternRewriter hook for notifying match failure reasons.
1411 LogicalResult ConversionPatternRewriter::notifyMatchFailure(
1412     Operation *op, function_ref<void(Diagnostic &)> reasonCallback) {
1413   return impl->notifyMatchFailure(op->getLoc(), reasonCallback);
1414 }
1415 
1416 /// Return a reference to the internal implementation.
1417 detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
1418   return *impl;
1419 }
1420 
1421 //===----------------------------------------------------------------------===//
1422 // ConversionPattern
1423 //===----------------------------------------------------------------------===//
1424 
1425 /// Attempt to match and rewrite the IR root at the specified operation.
1426 LogicalResult
1427 ConversionPattern::matchAndRewrite(Operation *op,
1428                                    PatternRewriter &rewriter) const {
1429   auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
1430   auto &rewriterImpl = dialectRewriter.getImpl();
1431 
1432   // Track the current conversion pattern in the rewriter.
1433   assert(!rewriterImpl.currentConversionPattern &&
1434          "already inside of a pattern rewrite");
1435   llvm::SaveAndRestore<const ConversionPattern *> currentPatternGuard(
1436       rewriterImpl.currentConversionPattern, this);
1437 
1438   // Remap the operands of the operation.
1439   SmallVector<Value, 4> operands;
1440   if (failed(rewriterImpl.remapValues(op->getLoc(), rewriter,
1441                                       getTypeConverter(), op->getOperands(),
1442                                       operands))) {
1443     return failure();
1444   }
1445   return matchAndRewrite(op, operands, dialectRewriter);
1446 }
1447 
1448 //===----------------------------------------------------------------------===//
1449 // OperationLegalizer
1450 //===----------------------------------------------------------------------===//
1451 
1452 namespace {
/// A set of rewrite patterns that can be used to legalize a given operation.
/// Patterns are stored as `const Pattern *`, with inline storage for one.
using LegalizationPatterns = SmallVector<const Pattern *, 1>;
1455 
1456 /// This class defines a recursive operation legalizer.
1457 class OperationLegalizer {
1458 public:
1459   using LegalizationAction = ConversionTarget::LegalizationAction;
1460 
1461   OperationLegalizer(ConversionTarget &targetInfo,
1462                      const FrozenRewritePatternList &patterns);
1463 
1464   /// Returns true if the given operation is known to be illegal on the target.
1465   bool isIllegal(Operation *op) const;
1466 
1467   /// Attempt to legalize the given operation. Returns success if the operation
1468   /// was legalized, failure otherwise.
1469   LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
1470 
1471   /// Returns the conversion target in use by the legalizer.
1472   ConversionTarget &getTarget() { return target; }
1473 
1474 private:
1475   /// Attempt to legalize the given operation by folding it.
1476   LogicalResult legalizeWithFold(Operation *op,
1477                                  ConversionPatternRewriter &rewriter);
1478 
1479   /// Attempt to legalize the given operation by applying a pattern. Returns
1480   /// success if the operation was legalized, failure otherwise.
1481   LogicalResult legalizeWithPattern(Operation *op,
1482                                     ConversionPatternRewriter &rewriter);
1483 
1484   /// Return true if the given pattern may be applied to the given operation,
1485   /// false otherwise.
1486   bool canApplyPattern(Operation *op, const Pattern &pattern,
1487                        ConversionPatternRewriter &rewriter);
1488 
1489   /// Legalize the resultant IR after successfully applying the given pattern.
1490   LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
1491                                       ConversionPatternRewriter &rewriter,
1492                                       RewriterState &curState);
1493 
1494   /// Legalizes the actions registered during the execution of a pattern.
1495   LogicalResult legalizePatternBlockActions(Operation *op,
1496                                             ConversionPatternRewriter &rewriter,
1497                                             ConversionPatternRewriterImpl &impl,
1498                                             RewriterState &state,
1499                                             RewriterState &newState);
1500   LogicalResult legalizePatternCreatedOperations(
1501       ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
1502       RewriterState &state, RewriterState &newState);
1503   LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
1504                                            ConversionPatternRewriterImpl &impl,
1505                                            RewriterState &state,
1506                                            RewriterState &newState);
1507 
1508   //===--------------------------------------------------------------------===//
1509   // Cost Model
1510   //===--------------------------------------------------------------------===//
1511 
1512   /// Build an optimistic legalization graph given the provided patterns. This
1513   /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
1514   /// patterns for operations that are not directly legal, but may be
1515   /// transitively legal for the current target given the provided patterns.
1516   void buildLegalizationGraph(
1517       LegalizationPatterns &anyOpLegalizerPatterns,
1518       DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
1519 
1520   /// Compute the benefit of each node within the computed legalization graph.
1521   /// This orders the patterns within 'legalizerPatterns' based upon two
1522   /// criteria:
1523   ///  1) Prefer patterns that have the lowest legalization depth, i.e.
1524   ///     represent the more direct mapping to the target.
1525   ///  2) When comparing patterns with the same legalization depth, prefer the
1526   ///     pattern with the highest PatternBenefit. This allows for users to
1527   ///     prefer specific legalizations over others.
1528   void computeLegalizationGraphBenefit(
1529       LegalizationPatterns &anyOpLegalizerPatterns,
1530       DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
1531 
1532   /// Compute the legalization depth when legalizing an operation of the given
1533   /// type.
1534   unsigned computeOpLegalizationDepth(
1535       OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
1536       DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
1537 
1538   /// Apply the conversion cost model to the given set of patterns, and return
1539   /// the smallest legalization depth of any of the patterns. See
1540   /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
1541   unsigned applyCostModelToPatterns(
1542       LegalizationPatterns &patterns,
1543       DenseMap<OperationName, unsigned> &minOpPatternDepth,
1544       DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
1545 
1546   /// The current set of patterns that have been applied.
1547   SmallPtrSet<const Pattern *, 8> appliedPatterns;
1548 
1549   /// The legalization information provided by the target.
1550   ConversionTarget &target;
1551 
1552   /// The pattern applicator to use for conversions.
1553   PatternApplicator applicator;
1554 };
1555 } // namespace
1556 
1557 OperationLegalizer::OperationLegalizer(ConversionTarget &targetInfo,
1558                                        const FrozenRewritePatternList &patterns)
1559     : target(targetInfo), applicator(patterns) {
1560   // The set of patterns that can be applied to illegal operations to transform
1561   // them into legal ones.
1562   DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
1563   LegalizationPatterns anyOpLegalizerPatterns;
1564 
1565   buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
1566   computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
1567 }
1568 
1569 bool OperationLegalizer::isIllegal(Operation *op) const {
1570   // Check if the target explicitly marked this operation as illegal.
1571   return target.getOpAction(op->getName()) == LegalizationAction::Illegal;
1572 }
1573 
/// Attempt to legalize 'op' for the current target. The strategies are tried
/// in this order, stopping at the first that succeeds:
///   1. the target already considers the operation legal,
///   2. the operation was marked as ignored during this conversion,
///   3. the operation can be folded,
///   4. a legalization pattern can be applied.
/// Returns failure if none of the strategies succeed.
LogicalResult
OperationLegalizer::legalize(Operation *op,
                             ConversionPatternRewriter &rewriter) {
#ifndef NDEBUG
  // Debug-only locals: they are referenced solely from LLVM_DEBUG blocks, so
  // they are only declared in builds with assertions enabled.
  const char *logLineComment =
      "//===-------------------------------------------===//\n";

  auto &rewriterImpl = rewriter.getImpl();
#endif
  LLVM_DEBUG({
    auto &os = rewriterImpl.logger;
    os.getOStream() << "\n";
    os.startLine() << logLineComment;
    os.startLine() << "Legalizing operation : '" << op->getName() << "'(" << op
                   << ") {\n";
    os.indent();

    // If the operation has no regions, just print it here.
    if (op->getNumRegions() == 0) {
      op->print(os.startLine(), OpPrintingFlags().printGenericOpForm());
      os.getOStream() << "\n\n";
    }
  });

  // Check if this operation is legal on the target.
  if (auto legalityInfo = target.isLegal(op)) {
    LLVM_DEBUG({
      logSuccess(
          rewriterImpl.logger, "operation marked legal by the target{0}",
          legalityInfo->isRecursivelyLegal
              ? "; NOTE: operation is recursively legal; skipping internals"
              : "");
      rewriterImpl.logger.startLine() << logLineComment;
    });

    // If this operation is recursively legal, mark its children as ignored so
    // that we don't consider them for legalization.
    // (rewriter.getImpl() is used here rather than 'rewriterImpl' because the
    // latter only exists in debug builds.)
    if (legalityInfo->isRecursivelyLegal)
      rewriter.getImpl().markNestedOpsIgnored(op);
    return success();
  }

  // Check to see if the operation is ignored and doesn't need to be converted.
  if (rewriter.getImpl().isOpIgnored(op)) {
    LLVM_DEBUG({
      logSuccess(rewriterImpl.logger,
                 "operation marked 'ignored' during conversion");
      rewriterImpl.logger.startLine() << logLineComment;
    });
    return success();
  }

  // If the operation isn't legal, try to fold it in-place.
  // TODO: Should we always try to do this, even if the op is
  // already legal?
  if (succeeded(legalizeWithFold(op, rewriter))) {
    LLVM_DEBUG({
      logSuccess(rewriterImpl.logger, "operation was folded");
      rewriterImpl.logger.startLine() << logLineComment;
    });
    return success();
  }

  // Otherwise, we need to apply a legalization pattern to this operation.
  if (succeeded(legalizeWithPattern(op, rewriter))) {
    LLVM_DEBUG({
      logSuccess(rewriterImpl.logger, "");
      rewriterImpl.logger.startLine() << logLineComment;
    });
    return success();
  }

  // Every strategy failed; the caller decides whether this is fatal.
  LLVM_DEBUG({
    logFailure(rewriterImpl.logger, "no matched legalization pattern");
    rewriterImpl.logger.startLine() << logLineComment;
  });
  return failure();
}
1652 
1653 LogicalResult
1654 OperationLegalizer::legalizeWithFold(Operation *op,
1655                                      ConversionPatternRewriter &rewriter) {
1656   auto &rewriterImpl = rewriter.getImpl();
1657   RewriterState curState = rewriterImpl.getCurrentState();
1658 
1659   LLVM_DEBUG({
1660     rewriterImpl.logger.startLine() << "* Fold {\n";
1661     rewriterImpl.logger.indent();
1662   });
1663 
1664   // Try to fold the operation.
1665   SmallVector<Value, 2> replacementValues;
1666   rewriter.setInsertionPoint(op);
1667   if (failed(rewriter.tryFold(op, replacementValues))) {
1668     LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
1669     return failure();
1670   }
1671 
1672   // Insert a replacement for 'op' with the folded replacement values.
1673   rewriter.replaceOp(op, replacementValues);
1674 
1675   // Recursively legalize any new constant operations.
1676   for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
1677        i != e; ++i) {
1678     Operation *cstOp = rewriterImpl.createdOps[i];
1679     if (failed(legalize(cstOp, rewriter))) {
1680       LLVM_DEBUG(logFailure(rewriterImpl.logger,
1681                             "generated constant '{0}' was illegal",
1682                             cstOp->getName()));
1683       rewriterImpl.resetState(curState);
1684       return failure();
1685     }
1686   }
1687 
1688   LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
1689   return success();
1690 }
1691 
1692 LogicalResult
1693 OperationLegalizer::legalizeWithPattern(Operation *op,
1694                                         ConversionPatternRewriter &rewriter) {
1695   auto &rewriterImpl = rewriter.getImpl();
1696 
1697   // Functor that returns if the given pattern may be applied.
1698   auto canApply = [&](const Pattern &pattern) {
1699     return canApplyPattern(op, pattern, rewriter);
1700   };
1701 
1702   // Functor that cleans up the rewriter state after a pattern failed to match.
1703   RewriterState curState = rewriterImpl.getCurrentState();
1704   auto onFailure = [&](const Pattern &pattern) {
1705     LLVM_DEBUG(logFailure(rewriterImpl.logger, "pattern failed to match"));
1706     rewriterImpl.resetState(curState);
1707     appliedPatterns.erase(&pattern);
1708   };
1709 
1710   // Functor that performs additional legalization when a pattern is
1711   // successfully applied.
1712   auto onSuccess = [&](const Pattern &pattern) {
1713     auto result = legalizePatternResult(op, pattern, rewriter, curState);
1714     appliedPatterns.erase(&pattern);
1715     if (failed(result))
1716       rewriterImpl.resetState(curState);
1717     return result;
1718   };
1719 
1720   // Try to match and rewrite a pattern on this operation.
1721   return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
1722                                     onSuccess);
1723 }
1724 
1725 bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
1726                                          ConversionPatternRewriter &rewriter) {
1727   LLVM_DEBUG({
1728     auto &os = rewriter.getImpl().logger;
1729     os.getOStream() << "\n";
1730     os.startLine() << "* Pattern : '" << op->getName() << " -> (";
1731     llvm::interleaveComma(pattern.getGeneratedOps(), llvm::dbgs());
1732     os.getOStream() << ")' {\n";
1733     os.indent();
1734   });
1735 
1736   // Ensure that we don't cycle by not allowing the same pattern to be
1737   // applied twice in the same recursion stack if it is not known to be safe.
1738   if (!pattern.hasBoundedRewriteRecursion() &&
1739       !appliedPatterns.insert(&pattern).second) {
1740     LLVM_DEBUG(
1741         logFailure(rewriter.getImpl().logger, "pattern was already applied"));
1742     return false;
1743   }
1744   return true;
1745 }
1746 
1747 LogicalResult
1748 OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
1749                                           ConversionPatternRewriter &rewriter,
1750                                           RewriterState &curState) {
1751   auto &impl = rewriter.getImpl();
1752 
1753 #ifndef NDEBUG
1754   assert(impl.pendingRootUpdates.empty() && "dangling root updates");
1755 #endif
1756 
1757   // Check that the root was either replaced or updated in place.
1758   auto replacedRoot = [&] {
1759     return llvm::any_of(
1760         llvm::drop_begin(impl.replacements, curState.numReplacements),
1761         [op](auto &it) { return it.first == op; });
1762   };
1763   auto updatedRootInPlace = [&] {
1764     return llvm::any_of(
1765         llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates),
1766         [op](auto &state) { return state.getOperation() == op; });
1767   };
1768   (void)replacedRoot;
1769   (void)updatedRootInPlace;
1770   assert((replacedRoot() || updatedRootInPlace()) &&
1771          "expected pattern to replace the root operation");
1772 
1773   // Legalize each of the actions registered during application.
1774   RewriterState newState = impl.getCurrentState();
1775   if (failed(legalizePatternBlockActions(op, rewriter, impl, curState,
1776                                          newState)) ||
1777       failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
1778       failed(legalizePatternCreatedOperations(rewriter, impl, curState,
1779                                               newState))) {
1780     return failure();
1781   }
1782 
1783   LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
1784   return success();
1785 }
1786 
1787 LogicalResult OperationLegalizer::legalizePatternBlockActions(
1788     Operation *op, ConversionPatternRewriter &rewriter,
1789     ConversionPatternRewriterImpl &impl, RewriterState &state,
1790     RewriterState &newState) {
1791   SmallPtrSet<Operation *, 16> operationsToIgnore;
1792 
1793   // If the pattern moved or created any blocks, make sure the types of block
1794   // arguments get legalized.
1795   for (int i = state.numBlockActions, e = newState.numBlockActions; i != e;
1796        ++i) {
1797     auto &action = impl.blockActions[i];
1798     if (action.kind == BlockActionKind::TypeConversion ||
1799         action.kind == BlockActionKind::Erase)
1800       continue;
1801     // Only check blocks outside of the current operation.
1802     Operation *parentOp = action.block->getParentOp();
1803     if (!parentOp || parentOp == op || action.block->getNumArguments() == 0)
1804       continue;
1805 
1806     // If the region of the block has a type converter, try to convert the block
1807     // directly.
1808     if (auto *converter =
1809             impl.argConverter.getConverter(action.block->getParent())) {
1810       if (failed(impl.convertBlockSignature(action.block, *converter))) {
1811         LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
1812                                            "block"));
1813         return failure();
1814       }
1815       continue;
1816     }
1817 
1818     // Otherwise, check that this operation isn't one generated by this pattern.
1819     // This is because we will attempt to legalize the parent operation, and
1820     // blocks in regions created by this pattern will already be legalized later
1821     // on. If we haven't built the set yet, build it now.
1822     if (operationsToIgnore.empty()) {
1823       auto createdOps = ArrayRef<Operation *>(impl.createdOps)
1824                             .drop_front(state.numCreatedOps);
1825       operationsToIgnore.insert(createdOps.begin(), createdOps.end());
1826     }
1827 
1828     // If this operation should be considered for re-legalization, try it.
1829     if (operationsToIgnore.insert(parentOp).second &&
1830         failed(legalize(parentOp, rewriter))) {
1831       LLVM_DEBUG(logFailure(
1832           impl.logger, "operation '{0}'({1}) became illegal after block action",
1833           parentOp->getName(), parentOp));
1834       return failure();
1835     }
1836   }
1837   return success();
1838 }
1839 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
1840     ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
1841     RewriterState &state, RewriterState &newState) {
1842   for (int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) {
1843     Operation *op = impl.createdOps[i];
1844     if (failed(legalize(op, rewriter))) {
1845       LLVM_DEBUG(logFailure(impl.logger,
1846                             "generated operation '{0}'({1}) was illegal",
1847                             op->getName(), op));
1848       return failure();
1849     }
1850   }
1851   return success();
1852 }
1853 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
1854     ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
1855     RewriterState &state, RewriterState &newState) {
1856   for (int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) {
1857     Operation *op = impl.rootUpdates[i].getOperation();
1858     if (failed(legalize(op, rewriter))) {
1859       LLVM_DEBUG(logFailure(impl.logger,
1860                             "operation updated in-place '{0}' was illegal",
1861                             op->getName()));
1862       return failure();
1863     }
1864   }
1865   return success();
1866 }
1867 
1868 //===----------------------------------------------------------------------===//
1869 // Cost Model
1870 
/// Populate 'anyOpLegalizerPatterns' with root-agnostic patterns and
/// 'legalizerPatterns' with, per root operation, the patterns that can
/// (transitively) produce only legal operations. This is an optimistic
/// fix-point: a pattern becomes usable once every operation it generates is
/// either legal on the target or itself has a usable pattern. Patterns whose
/// root is always legal are dropped entirely.
void OperationLegalizer::buildLegalizationGraph(
    LegalizationPatterns &anyOpLegalizerPatterns,
    DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
  // A mapping between an operation and a set of operations that can be used to
  // generate it.
  DenseMap<OperationName, SmallPtrSet<OperationName, 2>> parentOps;
  // A mapping between an operation and any currently invalid patterns it has.
  DenseMap<OperationName, SmallPtrSet<const Pattern *, 2>> invalidPatterns;
  // A worklist of patterns to consider for legality.
  llvm::SetVector<const Pattern *> patternWorklist;

  // Build the mapping from operations to the parent ops that may generate them.
  applicator.walkAllPatterns([&](const Pattern &pattern) {
    Optional<OperationName> root = pattern.getRootKind();

    // If the pattern has no specific root, we can't analyze the relationship
    // between the root op and generated operations. Given that, add all such
    // patterns to the legalization set.
    if (!root) {
      anyOpLegalizerPatterns.push_back(&pattern);
      return;
    }

    // Skip operations that are always known to be legal.
    if (target.getOpAction(*root) == LegalizationAction::Legal)
      return;

    // Add this pattern to the invalid set for the root op and record this root
    // as a parent for any generated operations.
    invalidPatterns[*root].insert(&pattern);
    for (auto op : pattern.getGeneratedOps())
      parentOps[op].insert(*root);

    // Add this pattern to the worklist.
    patternWorklist.insert(&pattern);
  });

  // If there are any patterns that don't have a specific root kind, we can't
  // make direct assumptions about what operations will never be legalized.
  // Note: Technically we could, but it would require an analysis that may
  // recurse into itself. It would be better to perform this kind of filtering
  // at a higher level than here anyways.
  if (!anyOpLegalizerPatterns.empty()) {
    for (const Pattern *pattern : patternWorklist)
      legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
    return;
  }

  // Fix-point iteration: a pattern is accepted once none of its generated
  // operations are "invalid" (no accepted pattern of their own AND either no
  // registered action or an explicit 'Illegal' action).
  while (!patternWorklist.empty()) {
    auto *pattern = patternWorklist.pop_back_val();

    // Check to see if any of the generated operations are invalid.
    if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
          Optional<LegalizationAction> action = target.getOpAction(op);
          return !legalizerPatterns.count(op) &&
                 (!action || action == LegalizationAction::Illegal);
        }))
      continue;

    // Otherwise, if all of the generated operations are valid, this op is now
    // legal so add all of the child patterns to the worklist.
    legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
    invalidPatterns[*pattern->getRootKind()].erase(pattern);

    // Add any invalid patterns of the parent operations to see if they have now
    // become legal.
    for (auto op : parentOps[*pattern->getRootKind()])
      patternWorklist.set_union(invalidPatterns[op]);
  }
}
1941 
1942 void OperationLegalizer::computeLegalizationGraphBenefit(
1943     LegalizationPatterns &anyOpLegalizerPatterns,
1944     DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
1945   // The smallest pattern depth, when legalizing an operation.
1946   DenseMap<OperationName, unsigned> minOpPatternDepth;
1947 
1948   // For each operation that is transitively legal, compute a cost for it.
1949   for (auto &opIt : legalizerPatterns)
1950     if (!minOpPatternDepth.count(opIt.first))
1951       computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
1952                                  legalizerPatterns);
1953 
1954   // Apply the cost model to the patterns that can match any operation. Those
1955   // with a specific operation type are already resolved when computing the op
1956   // legalization depth.
1957   if (!anyOpLegalizerPatterns.empty())
1958     applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
1959                              legalizerPatterns);
1960 
1961   // Apply a cost model to the pattern applicator. We order patterns first by
1962   // depth then benefit. `legalizerPatterns` contains per-op patterns by
1963   // decreasing benefit.
1964   applicator.applyCostModel([&](const Pattern &pattern) {
1965     ArrayRef<const Pattern *> orderedPatternList;
1966     if (Optional<OperationName> rootName = pattern.getRootKind())
1967       orderedPatternList = legalizerPatterns[*rootName];
1968     else
1969       orderedPatternList = anyOpLegalizerPatterns;
1970 
1971     // If the pattern is not found, then it was removed and cannot be matched.
1972     auto it = llvm::find(orderedPatternList, &pattern);
1973     if (it == orderedPatternList.end())
1974       return PatternBenefit::impossibleToMatch();
1975 
1976     // Patterns found earlier in the list have higher benefit.
1977     return PatternBenefit(std::distance(it, orderedPatternList.end()));
1978   });
1979 }
1980 
1981 unsigned OperationLegalizer::computeOpLegalizationDepth(
1982     OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
1983     DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
1984   // Check for existing depth.
1985   auto depthIt = minOpPatternDepth.find(op);
1986   if (depthIt != minOpPatternDepth.end())
1987     return depthIt->second;
1988 
1989   // If a mapping for this operation does not exist, then this operation
1990   // is always legal. Return 0 as the depth for a directly legal operation.
1991   auto opPatternsIt = legalizerPatterns.find(op);
1992   if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
1993     return 0u;
1994 
1995   // Record this initial depth in case we encounter this op again when
1996   // recursively computing the depth.
1997   minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
1998 
1999   // Apply the cost model to the operation patterns, and update the minimum
2000   // depth.
2001   unsigned minDepth = applyCostModelToPatterns(
2002       opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2003   minOpPatternDepth[op] = minDepth;
2004   return minDepth;
2005 }
2006 
2007 unsigned OperationLegalizer::applyCostModelToPatterns(
2008     LegalizationPatterns &patterns,
2009     DenseMap<OperationName, unsigned> &minOpPatternDepth,
2010     DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2011   unsigned minDepth = std::numeric_limits<unsigned>::max();
2012 
2013   // Compute the depth for each pattern within the set.
2014   SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
2015   patternsByDepth.reserve(patterns.size());
2016   for (const Pattern *pattern : patterns) {
2017     unsigned depth = 0;
2018     for (auto generatedOp : pattern->getGeneratedOps()) {
2019       unsigned generatedOpDepth = computeOpLegalizationDepth(
2020           generatedOp, minOpPatternDepth, legalizerPatterns);
2021       depth = std::max(depth, generatedOpDepth + 1);
2022     }
2023     patternsByDepth.emplace_back(pattern, depth);
2024 
2025     // Update the minimum depth of the pattern list.
2026     minDepth = std::min(minDepth, depth);
2027   }
2028 
2029   // If the operation only has one legalization pattern, there is no need to
2030   // sort them.
2031   if (patternsByDepth.size() == 1)
2032     return minDepth;
2033 
2034   // Sort the patterns by those likely to be the most beneficial.
2035   llvm::array_pod_sort(patternsByDepth.begin(), patternsByDepth.end(),
2036                        [](const std::pair<const Pattern *, unsigned> *lhs,
2037                           const std::pair<const Pattern *, unsigned> *rhs) {
2038                          // First sort by the smaller pattern legalization
2039                          // depth.
2040                          if (lhs->second != rhs->second)
2041                            return llvm::array_pod_sort_comparator<unsigned>(
2042                                &lhs->second, &rhs->second);
2043 
2044                          // Then sort by the larger pattern benefit.
2045                          auto lhsBenefit = lhs->first->getBenefit();
2046                          auto rhsBenefit = rhs->first->getBenefit();
2047                          return llvm::array_pod_sort_comparator<PatternBenefit>(
2048                              &rhsBenefit, &lhsBenefit);
2049                        });
2050 
2051   // Update the legalization pattern to use the new sorted list.
2052   patterns.clear();
2053   for (auto &patternIt : patternsByDepth)
2054     patterns.push_back(patternIt.first);
2055   return minDepth;
2056 }
2057 
2058 //===----------------------------------------------------------------------===//
2059 // OperationConverter
2060 //===----------------------------------------------------------------------===//
namespace {
/// Controls how the OperationConverter reacts to operations that fail to
/// legalize, and whether successful rewrites are kept.
enum OpConversionMode {
  // In this mode, the conversion will ignore failed conversions to allow
  // illegal operations to co-exist in the IR.
  Partial,

  // In this mode, all operations must be legal for the given target for the
  // conversion to succeed.
  Full,

  // In this mode, operations are analyzed for legality. No actual rewrites are
  // applied to the operations on success.
  Analysis,
};

// This class converts operations to a given conversion target via a set of
// rewrite patterns. The conversion behaves differently depending on the
// conversion mode.
struct OperationConverter {
  /// 'trackedOps' is optional; its meaning depends on 'mode' (see the member
  /// documentation below).
  explicit OperationConverter(ConversionTarget &target,
                              const FrozenRewritePatternList &patterns,
                              OpConversionMode mode,
                              DenseSet<Operation *> *trackedOps = nullptr)
      : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}

  /// Converts the given operations to the conversion target.
  /// Returns failure (and discards all rewrites) if any conversion or the
  /// final cleanup fails.
  LogicalResult convertOperations(ArrayRef<Operation *> ops);

private:
  /// Converts an operation with the given rewriter.
  LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);

  /// This method is called after the conversion process to legalize any
  /// remaining artifacts and complete the conversion.
  LogicalResult finalize(ConversionPatternRewriter &rewriter);

  /// Legalize the types of converted block arguments.
  LogicalResult
  legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
                                 ConversionPatternRewriterImpl &rewriterImpl);

  /// Legalize an operation result that was marked as "erased".
  LogicalResult
  legalizeErasedResult(Operation *op, OpResult result,
                       ConversionPatternRewriterImpl &rewriterImpl);

  /// Legalize an operation result that was replaced with a value of a different
  /// type.
  LogicalResult
  legalizeChangedResultType(Operation *op, OpResult result, Value newValue,
                            TypeConverter *replConverter,
                            ConversionPatternRewriter &rewriter,
                            ConversionPatternRewriterImpl &rewriterImpl);

  /// The legalizer to use when converting operations.
  OperationLegalizer opLegalizer;

  /// The conversion mode to use when legalizing operations.
  OpConversionMode mode;

  /// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
  /// this is populated with ops found to be legalizable to the target.
  /// When mode == OpConversionMode::Partial, this is populated with ops found
  /// *not* to be legalizable to the target. Not owned; may be null.
  DenseSet<Operation *> *trackedOps;
};
} // end anonymous namespace
2128 
2129 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2130                                           Operation *op) {
2131   // Legalize the given operation.
2132   if (failed(opLegalizer.legalize(op, rewriter))) {
2133     // Handle the case of a failed conversion for each of the different modes.
2134     // Full conversions expect all operations to be converted.
2135     if (mode == OpConversionMode::Full)
2136       return op->emitError()
2137              << "failed to legalize operation '" << op->getName() << "'";
2138     // Partial conversions allow conversions to fail iff the operation was not
2139     // explicitly marked as illegal. If the user provided a nonlegalizableOps
2140     // set, non-legalizable ops are included.
2141     if (mode == OpConversionMode::Partial) {
2142       if (opLegalizer.isIllegal(op))
2143         return op->emitError()
2144                << "failed to legalize operation '" << op->getName()
2145                << "' that was explicitly marked illegal";
2146       if (trackedOps)
2147         trackedOps->insert(op);
2148     }
2149   } else if (mode == OpConversionMode::Analysis) {
2150     // Analysis conversions don't fail if any operations fail to legalize,
2151     // they are only interested in the operations that were successfully
2152     // legalized.
2153     trackedOps->insert(op);
2154   }
2155   return success();
2156 }
2157 
2158 LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
2159   if (ops.empty())
2160     return success();
2161   ConversionTarget &target = opLegalizer.getTarget();
2162 
2163   // Compute the set of operations and blocks to convert.
2164   std::vector<Operation *> toConvert;
2165   for (auto *op : ops) {
2166     toConvert.emplace_back(op);
2167     for (auto &region : op->getRegions())
2168       if (failed(computeConversionSet(region.getBlocks(), region.getLoc(),
2169                                       toConvert, &target)))
2170         return failure();
2171   }
2172 
2173   // Convert each operation and discard rewrites on failure.
2174   ConversionPatternRewriter rewriter(ops.front()->getContext());
2175   ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2176   for (auto *op : toConvert)
2177     if (failed(convert(rewriter, op)))
2178       return rewriterImpl.discardRewrites(), failure();
2179 
2180   // Now that all of the operations have been converted, finalize the conversion
2181   // process to ensure any lingering conversion artifacts are cleaned up and
2182   // legalized.
2183   if (failed(finalize(rewriter)))
2184     return rewriterImpl.discardRewrites(), failure();
2185 
2186   // After a successful conversion, apply rewrites if this is not an analysis
2187   // conversion.
2188   if (mode == OpConversionMode::Analysis)
2189     rewriterImpl.discardRewrites();
2190   else
2191     rewriterImpl.applyRewrites();
2192   return success();
2193 }
2194 
2195 LogicalResult
2196 OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
2197   ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2198 
2199   // Legalize converted block arguments.
2200   if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
2201     return failure();
2202 
2203   // Process requested operation replacements.
2204   for (unsigned i = 0, e = rewriterImpl.operationsWithChangedResults.size();
2205        i != e; ++i) {
2206     unsigned replIdx = rewriterImpl.operationsWithChangedResults[i];
2207     auto &repl = *(rewriterImpl.replacements.begin() + replIdx);
2208     for (OpResult result : repl.first->getResults()) {
2209       Value newValue = rewriterImpl.mapping.lookupOrNull(result);
2210 
2211       // If the operation result was replaced with null, all of the uses of this
2212       // value should be replaced.
2213       if (!newValue) {
2214         if (failed(legalizeErasedResult(repl.first, result, rewriterImpl)))
2215           return failure();
2216         continue;
2217       }
2218 
2219       // Otherwise, check to see if the type of the result changed.
2220       if (result.getType() == newValue.getType())
2221         continue;
2222 
2223       // Legalize this result.
2224       rewriter.setInsertionPoint(repl.first);
2225       if (failed(legalizeChangedResultType(repl.first, result, newValue,
2226                                            repl.second.converter, rewriter,
2227                                            rewriterImpl)))
2228         return failure();
2229 
2230       // Update the end iterator for this loop in the case it was updated
2231       // when legalizing generated conversion operations.
2232       e = rewriterImpl.operationsWithChangedResults.size();
2233     }
2234   }
2235   return success();
2236 }
2237 
2238 LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
2239     ConversionPatternRewriter &rewriter,
2240     ConversionPatternRewriterImpl &rewriterImpl) {
2241   // Functor used to check if all users of a value will be dead after
2242   // conversion.
2243   auto findLiveUser = [&](Value val) {
2244     auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
2245       return rewriterImpl.isOpIgnored(user);
2246     });
2247     return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
2248   };
2249 
2250   // Materialize any necessary conversions for converted block arguments that
2251   // are still live.
2252   size_t numCreatedOps = rewriterImpl.createdOps.size();
2253   if (failed(rewriterImpl.argConverter.materializeLiveConversions(
2254           rewriterImpl.mapping, rewriter, findLiveUser)))
2255     return failure();
2256 
2257   // Legalize any newly created operations during argument materialization.
2258   for (int i : llvm::seq<int>(numCreatedOps, rewriterImpl.createdOps.size())) {
2259     if (failed(opLegalizer.legalize(rewriterImpl.createdOps[i], rewriter))) {
2260       return rewriterImpl.createdOps[i]->emitError()
2261              << "failed to legalize conversion operation generated for block "
2262                 "argument that remained live after conversion";
2263     }
2264   }
2265   return success();
2266 }
2267 
2268 LogicalResult OperationConverter::legalizeErasedResult(
2269     Operation *op, OpResult result,
2270     ConversionPatternRewriterImpl &rewriterImpl) {
2271   // If the operation result was replaced with null, all of the uses of this
2272   // value should be replaced.
2273   auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
2274     return rewriterImpl.isOpIgnored(user);
2275   });
2276   if (liveUserIt != result.user_end()) {
2277     InFlightDiagnostic diag = op->emitError("failed to legalize operation '")
2278                               << op->getName() << "' marked as erased";
2279     diag.attachNote(liveUserIt->getLoc())
2280         << "found live user of result #" << result.getResultNumber() << ": "
2281         << *liveUserIt;
2282     return failure();
2283   }
2284   return success();
2285 }
2286 
2287 LogicalResult OperationConverter::legalizeChangedResultType(
2288     Operation *op, OpResult result, Value newValue,
2289     TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
2290     ConversionPatternRewriterImpl &rewriterImpl) {
2291   // Walk the users of this value to see if there are any live users that
2292   // weren't replaced during conversion.
2293   auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
2294     return rewriterImpl.isOpIgnored(user);
2295   });
2296   if (liveUserIt == result.user_end())
2297     return success();
2298 
2299   // If the replacement has a type converter, attempt to materialize a
2300   // conversion back to the original type.
2301   if (!replConverter) {
2302     // TODO: We should emit an error here, similarly to the case where the
2303     // result is replaced with null. Unfortunately a lot of existing
2304     // patterns rely on this behavior, so until those patterns are updated
2305     // we keep the legacy behavior here of just forwarding the new value.
2306     return success();
2307   }
2308 
2309   // Track the number of created operations so that new ones can be legalized.
2310   size_t numCreatedOps = rewriterImpl.createdOps.size();
2311 
2312   // Materialize a conversion for this live result value.
2313   Type resultType = result.getType();
2314   Value convertedValue = replConverter->materializeSourceConversion(
2315       rewriter, op->getLoc(), resultType, newValue);
2316   if (!convertedValue) {
2317     InFlightDiagnostic diag = op->emitError()
2318                               << "failed to materialize conversion for result #"
2319                               << result.getResultNumber() << " of operation '"
2320                               << op->getName()
2321                               << "' that remained live after conversion";
2322     diag.attachNote(liveUserIt->getLoc())
2323         << "see existing live user here: " << *liveUserIt;
2324     return failure();
2325   }
2326 
2327   // Legalize all of the newly created conversion operations.
2328   for (int i : llvm::seq<int>(numCreatedOps, rewriterImpl.createdOps.size())) {
2329     if (failed(opLegalizer.legalize(rewriterImpl.createdOps[i], rewriter))) {
2330       return op->emitError("failed to legalize conversion operation generated ")
2331              << "for result #" << result.getResultNumber() << " of operation '"
2332              << op->getName() << "' that remained live after conversion";
2333     }
2334   }
2335 
2336   rewriterImpl.mapping.map(result, convertedValue);
2337   return success();
2338 }
2339 
2340 //===----------------------------------------------------------------------===//
2341 // Type Conversion
2342 //===----------------------------------------------------------------------===//
2343 
2344 /// Remap an input of the original signature with a new set of types. The
2345 /// new types are appended to the new signature conversion.
2346 void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
2347                                                    ArrayRef<Type> types) {
2348   assert(!types.empty() && "expected valid types");
2349   remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
2350   addInputs(types);
2351 }
2352 
2353 /// Append new input types to the signature conversion, this should only be
2354 /// used if the new types are not intended to remap an existing input.
2355 void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) {
2356   assert(!types.empty() &&
2357          "1->0 type remappings don't need to be added explicitly");
2358   argTypes.append(types.begin(), types.end());
2359 }
2360 
2361 /// Remap an input of the original signature with a range of types in the
2362 /// new signature.
2363 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2364                                                     unsigned newInputNo,
2365                                                     unsigned newInputCount) {
2366   assert(!remappedInputs[origInputNo] && "input has already been remapped");
2367   assert(newInputCount != 0 && "expected valid input count");
2368   remappedInputs[origInputNo] =
2369       InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr};
2370 }
2371 
2372 /// Remap an input of the original signature to another `replacementValue`
2373 /// value. This would make the signature converter drop this argument.
2374 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2375                                                     Value replacementValue) {
2376   assert(!remappedInputs[origInputNo] && "input has already been remapped");
2377   remappedInputs[origInputNo] =
2378       InputMapping{origInputNo, /*size=*/0, replacementValue};
2379 }
2380 
2381 /// This hooks allows for converting a type.
2382 LogicalResult TypeConverter::convertType(Type t,
2383                                          SmallVectorImpl<Type> &results) {
2384   auto existingIt = cachedDirectConversions.find(t);
2385   if (existingIt != cachedDirectConversions.end()) {
2386     if (existingIt->second)
2387       results.push_back(existingIt->second);
2388     return success(existingIt->second != nullptr);
2389   }
2390   auto multiIt = cachedMultiConversions.find(t);
2391   if (multiIt != cachedMultiConversions.end()) {
2392     results.append(multiIt->second.begin(), multiIt->second.end());
2393     return success();
2394   }
2395 
2396   // Walk the added converters in reverse order to apply the most recently
2397   // registered first.
2398   size_t currentCount = results.size();
2399   for (ConversionCallbackFn &converter : llvm::reverse(conversions)) {
2400     if (Optional<LogicalResult> result = converter(t, results)) {
2401       if (!succeeded(*result)) {
2402         cachedDirectConversions.try_emplace(t, nullptr);
2403         return failure();
2404       }
2405       auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
2406       if (newTypes.size() == 1)
2407         cachedDirectConversions.try_emplace(t, newTypes.front());
2408       else
2409         cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
2410       return success();
2411     }
2412   }
2413   return failure();
2414 }
2415 
2416 /// This hook simplifies defining 1-1 type conversions. This function returns
2417 /// the type to convert to on success, and a null type on failure.
2418 Type TypeConverter::convertType(Type t) {
2419   // Use the multi-type result version to convert the type.
2420   SmallVector<Type, 1> results;
2421   if (failed(convertType(t, results)))
2422     return nullptr;
2423 
2424   // Check to ensure that only one type was produced.
2425   return results.size() == 1 ? results.front() : nullptr;
2426 }
2427 
2428 /// Convert the given set of types, filling 'results' as necessary. This
2429 /// returns failure if the conversion of any of the types fails, success
2430 /// otherwise.
2431 LogicalResult TypeConverter::convertTypes(ArrayRef<Type> types,
2432                                           SmallVectorImpl<Type> &results) {
2433   for (auto type : types)
2434     if (failed(convertType(type, results)))
2435       return failure();
2436   return success();
2437 }
2438 
2439 /// Return true if the given type is legal for this type converter, i.e. the
2440 /// type converts to itself.
2441 bool TypeConverter::isLegal(Type type) { return convertType(type) == type; }
2442 /// Return true if the given operation has legal operand and result types.
2443 bool TypeConverter::isLegal(Operation *op) {
2444   return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
2445 }
2446 
2447 /// Return true if the types of block arguments within the region are legal.
2448 bool TypeConverter::isLegal(Region *region) {
2449   return llvm::all_of(*region, [this](Block &block) {
2450     return isLegal(block.getArgumentTypes());
2451   });
2452 }
2453 
2454 /// Return true if the inputs and outputs of the given function type are
2455 /// legal.
2456 bool TypeConverter::isSignatureLegal(FunctionType ty) {
2457   return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
2458 }
2459 
2460 /// This hook allows for converting a specific argument of a signature.
2461 LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
2462                                                  SignatureConversion &result) {
2463   // Try to convert the given input type.
2464   SmallVector<Type, 1> convertedTypes;
2465   if (failed(convertType(type, convertedTypes)))
2466     return failure();
2467 
2468   // If this argument is being dropped, there is nothing left to do.
2469   if (convertedTypes.empty())
2470     return success();
2471 
2472   // Otherwise, add the new inputs.
2473   result.addInputs(inputNo, convertedTypes);
2474   return success();
2475 }
2476 LogicalResult TypeConverter::convertSignatureArgs(TypeRange types,
2477                                                   SignatureConversion &result,
2478                                                   unsigned origInputOffset) {
2479   for (unsigned i = 0, e = types.size(); i != e; ++i)
2480     if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
2481       return failure();
2482   return success();
2483 }
2484 
2485 Value TypeConverter::materializeConversion(
2486     MutableArrayRef<MaterializationCallbackFn> materializations,
2487     OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) {
2488   for (MaterializationCallbackFn &fn : llvm::reverse(materializations))
2489     if (Optional<Value> result = fn(builder, resultType, inputs, loc))
2490       return result.getValue();
2491   return nullptr;
2492 }
2493 
2494 /// This function converts the type signature of the given block, by invoking
2495 /// 'convertSignatureArg' for each argument. This function should return a valid
2496 /// conversion for the signature on success, None otherwise.
2497 auto TypeConverter::convertBlockSignature(Block *block)
2498     -> Optional<SignatureConversion> {
2499   SignatureConversion conversion(block->getNumArguments());
2500   if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
2501     return llvm::None;
2502   return conversion;
2503 }
2504 
2505 /// Create a default conversion pattern that rewrites the type signature of a
2506 /// FuncOp.
2507 namespace {
2508 struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
2509   FuncOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
2510       : OpConversionPattern(converter, ctx) {}
2511 
2512   /// Hook for derived classes to implement combined matching and rewriting.
2513   LogicalResult
2514   matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
2515                   ConversionPatternRewriter &rewriter) const override {
2516     FunctionType type = funcOp.getType();
2517 
2518     // Convert the original function types.
2519     TypeConverter::SignatureConversion result(type.getNumInputs());
2520     SmallVector<Type, 1> newResults;
2521     if (failed(typeConverter->convertSignatureArgs(type.getInputs(), result)) ||
2522         failed(typeConverter->convertTypes(type.getResults(), newResults)) ||
2523         failed(rewriter.convertRegionTypes(&funcOp.getBody(), *typeConverter,
2524                                            &result)))
2525       return failure();
2526 
2527     // Update the function signature in-place.
2528     rewriter.updateRootInPlace(funcOp, [&] {
2529       funcOp.setType(FunctionType::get(result.getConvertedTypes(), newResults,
2530                                        funcOp.getContext()));
2531     });
2532     return success();
2533   }
2534 };
2535 } // end anonymous namespace
2536 
2537 void mlir::populateFuncOpTypeConversionPattern(
2538     OwningRewritePatternList &patterns, MLIRContext *ctx,
2539     TypeConverter &converter) {
2540   patterns.insert<FuncOpSignatureConversion>(ctx, converter);
2541 }
2542 
2543 //===----------------------------------------------------------------------===//
2544 // ConversionTarget
2545 //===----------------------------------------------------------------------===//
2546 
2547 /// Register a legality action for the given operation.
2548 void ConversionTarget::setOpAction(OperationName op,
2549                                    LegalizationAction action) {
2550   legalOperations[op] = {action, /*isRecursivelyLegal=*/false, llvm::None};
2551 }
2552 
2553 /// Register a legality action for the given dialects.
2554 void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
2555                                         LegalizationAction action) {
2556   for (StringRef dialect : dialectNames)
2557     legalDialects[dialect] = action;
2558 }
2559 
2560 /// Get the legality action for the given operation.
2561 auto ConversionTarget::getOpAction(OperationName op) const
2562     -> Optional<LegalizationAction> {
2563   Optional<LegalizationInfo> info = getOpInfo(op);
2564   return info ? info->action : Optional<LegalizationAction>();
2565 }
2566 
2567 /// If the given operation instance is legal on this target, a structure
2568 /// containing legality information is returned. If the operation is not legal,
2569 /// None is returned.
2570 auto ConversionTarget::isLegal(Operation *op) const
2571     -> Optional<LegalOpDetails> {
2572   Optional<LegalizationInfo> info = getOpInfo(op->getName());
2573   if (!info)
2574     return llvm::None;
2575 
2576   // Returns true if this operation instance is known to be legal.
2577   auto isOpLegal = [&] {
2578     // Handle dynamic legality either with the provided legality function, or
2579     // the default hook on the derived instance.
2580     if (info->action == LegalizationAction::Dynamic)
2581       return info->legalityFn ? (*info->legalityFn)(op)
2582                               : isDynamicallyLegal(op);
2583 
2584     // Otherwise, the operation is only legal if it was marked 'Legal'.
2585     return info->action == LegalizationAction::Legal;
2586   };
2587   if (!isOpLegal())
2588     return llvm::None;
2589 
2590   // This operation is legal, compute any additional legality information.
2591   LegalOpDetails legalityDetails;
2592   if (info->isRecursivelyLegal) {
2593     auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
2594     if (legalityFnIt != opRecursiveLegalityFns.end())
2595       legalityDetails.isRecursivelyLegal = legalityFnIt->second(op);
2596     else
2597       legalityDetails.isRecursivelyLegal = true;
2598   }
2599   return legalityDetails;
2600 }
2601 
2602 /// Set the dynamic legality callback for the given operation.
2603 void ConversionTarget::setLegalityCallback(
2604     OperationName name, const DynamicLegalityCallbackFn &callback) {
2605   assert(callback && "expected valid legality callback");
2606   auto infoIt = legalOperations.find(name);
2607   assert(infoIt != legalOperations.end() &&
2608          infoIt->second.action == LegalizationAction::Dynamic &&
2609          "expected operation to already be marked as dynamically legal");
2610   infoIt->second.legalityFn = callback;
2611 }
2612 
2613 /// Set the recursive legality callback for the given operation and mark the
2614 /// operation as recursively legal.
2615 void ConversionTarget::markOpRecursivelyLegal(
2616     OperationName name, const DynamicLegalityCallbackFn &callback) {
2617   auto infoIt = legalOperations.find(name);
2618   assert(infoIt != legalOperations.end() &&
2619          infoIt->second.action != LegalizationAction::Illegal &&
2620          "expected operation to already be marked as legal");
2621   infoIt->second.isRecursivelyLegal = true;
2622   if (callback)
2623     opRecursiveLegalityFns[name] = callback;
2624   else
2625     opRecursiveLegalityFns.erase(name);
2626 }
2627 
2628 /// Set the dynamic legality callback for the given dialects.
2629 void ConversionTarget::setLegalityCallback(
2630     ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
2631   assert(callback && "expected valid legality callback");
2632   for (StringRef dialect : dialects)
2633     dialectLegalityFns[dialect] = callback;
2634 }
2635 
2636 /// Get the legalization information for the given operation.
2637 auto ConversionTarget::getOpInfo(OperationName op) const
2638     -> Optional<LegalizationInfo> {
2639   // Check for info for this specific operation.
2640   auto it = legalOperations.find(op);
2641   if (it != legalOperations.end())
2642     return it->second;
2643   // Check for info for the parent dialect.
2644   auto dialectIt = legalDialects.find(op.getDialect());
2645   if (dialectIt != legalDialects.end()) {
2646     Optional<DynamicLegalityCallbackFn> callback;
2647     auto dialectFn = dialectLegalityFns.find(op.getDialect());
2648     if (dialectFn != dialectLegalityFns.end())
2649       callback = dialectFn->second;
2650     return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
2651                             callback};
2652   }
2653   // Otherwise, check if we mark unknown operations as dynamic.
2654   if (unknownOpsDynamicallyLegal)
2655     return LegalizationInfo{LegalizationAction::Dynamic,
2656                             /*isRecursivelyLegal=*/false, unknownLegalityFn};
2657   return llvm::None;
2658 }
2659 
2660 //===----------------------------------------------------------------------===//
2661 // Op Conversion Entry Points
2662 //===----------------------------------------------------------------------===//
2663 
2664 /// Apply a partial conversion on the given operations and all nested
2665 /// operations. This method converts as many operations to the target as
2666 /// possible, ignoring operations that failed to legalize. This method only
2667 /// returns failure if there ops explicitly marked as illegal.
2668 /// If an `unconvertedOps` set is provided, all operations that are found not
2669 /// to be legalizable to the given `target` are placed within that set. (Note
2670 /// that if there is an op explicitly marked as illegal, the conversion
2671 /// terminates and the `unconvertedOps` set will not necessarily be complete.)
2672 LogicalResult
2673 mlir::applyPartialConversion(ArrayRef<Operation *> ops,
2674                              ConversionTarget &target,
2675                              const FrozenRewritePatternList &patterns,
2676                              DenseSet<Operation *> *unconvertedOps) {
2677   OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
2678                                  unconvertedOps);
2679   return opConverter.convertOperations(ops);
2680 }
2681 LogicalResult
2682 mlir::applyPartialConversion(Operation *op, ConversionTarget &target,
2683                              const FrozenRewritePatternList &patterns,
2684                              DenseSet<Operation *> *unconvertedOps) {
2685   return applyPartialConversion(llvm::makeArrayRef(op), target, patterns,
2686                                 unconvertedOps);
2687 }
2688 
2689 /// Apply a complete conversion on the given operations, and all nested
2690 /// operations. This method will return failure if the conversion of any
2691 /// operation fails.
2692 LogicalResult
2693 mlir::applyFullConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
2694                           const FrozenRewritePatternList &patterns) {
2695   OperationConverter opConverter(target, patterns, OpConversionMode::Full);
2696   return opConverter.convertOperations(ops);
2697 }
2698 LogicalResult
2699 mlir::applyFullConversion(Operation *op, ConversionTarget &target,
2700                           const FrozenRewritePatternList &patterns) {
2701   return applyFullConversion(llvm::makeArrayRef(op), target, patterns);
2702 }
2703 
2704 /// Apply an analysis conversion on the given operations, and all nested
2705 /// operations. This method analyzes which operations would be successfully
2706 /// converted to the target if a conversion was applied. All operations that
2707 /// were found to be legalizable to the given 'target' are placed within the
2708 /// provided 'convertedOps' set; note that no actual rewrites are applied to the
2709 /// operations on success and only pre-existing operations are added to the set.
2710 LogicalResult
2711 mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
2712                               ConversionTarget &target,
2713                               const FrozenRewritePatternList &patterns,
2714                               DenseSet<Operation *> &convertedOps) {
2715   OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
2716                                  &convertedOps);
2717   return opConverter.convertOperations(ops);
2718 }
2719 LogicalResult
2720 mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
2721                               const FrozenRewritePatternList &patterns,
2722                               DenseSet<Operation *> &convertedOps) {
2723   return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns,
2724                                  convertedOps);
2725 }
2726