1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/DialectConversion.h"
10 #include "mlir/IR/Block.h"
11 #include "mlir/IR/BlockAndValueMapping.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/Function.h"
14 #include "mlir/IR/Module.h"
15 #include "mlir/Rewrite/PatternApplicator.h"
16 #include "mlir/Transforms/Utils.h"
17 #include "llvm/ADT/SetVector.h"
18 #include "llvm/ADT/SmallPtrSet.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/Support/SaveAndRestore.h"
22 #include "llvm/Support/ScopedPrinter.h"
23 
24 using namespace mlir;
25 using namespace mlir::detail;
26 
27 #define DEBUG_TYPE "dialect-conversion"
28 
29 /// Recursively collect all of the operations to convert from within 'region'.
30 /// If 'target' is nonnull, operations that are recursively legal have their
31 /// regions pre-filtered to avoid considering them for legalization.
32 static LogicalResult
33 computeConversionSet(iterator_range<Region::iterator> region,
34                      Location regionLoc, std::vector<Operation *> &toConvert,
35                      ConversionTarget *target = nullptr) {
36   if (llvm::empty(region))
37     return success();
38 
39   // Traverse starting from the entry block.
40   SmallVector<Block *, 16> worklist(1, &*region.begin());
41   DenseSet<Block *> visitedBlocks;
42   visitedBlocks.insert(worklist.front());
43   while (!worklist.empty()) {
44     Block *block = worklist.pop_back_val();
45 
46     // Compute the conversion set of each of the nested operations.
47     for (Operation &op : *block) {
48       toConvert.emplace_back(&op);
49 
50       // Don't check this operation's children for conversion if the operation
51       // is recursively legal.
52       auto legalityInfo = target ? target->isLegal(&op)
53                                  : Optional<ConversionTarget::LegalOpDetails>();
54       if (legalityInfo && legalityInfo->isRecursivelyLegal)
55         continue;
56       for (auto &region : op.getRegions()) {
57         if (failed(computeConversionSet(region.getBlocks(), region.getLoc(),
58                                         toConvert, target)))
59           return failure();
60       }
61     }
62 
63     // Recurse to children that haven't been visited.
64     for (Block *succ : block->getSuccessors())
65       if (visitedBlocks.insert(succ).second)
66         worklist.push_back(succ);
67   }
68 
69   // Check that all blocks in the region were visited.
70   if (llvm::any_of(llvm::drop_begin(region, 1),
71                    [&](Block &block) { return !visitedBlocks.count(&block); }))
72     return emitError(regionLoc, "unreachable blocks were not converted");
73   return success();
74 }
75 
76 /// A utility function to log a successful result for the given reason.
77 template <typename... Args>
78 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
79   LLVM_DEBUG({
80     os.unindent();
81     os.startLine() << "} -> SUCCESS";
82     if (!fmt.empty())
83       os.getOStream() << " : "
84                       << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
85     os.getOStream() << "\n";
86   });
87 }
88 
89 /// A utility function to log a failure result for the given reason.
90 template <typename... Args>
91 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
92   LLVM_DEBUG({
93     os.unindent();
94     os.startLine() << "} -> FAILURE : "
95                    << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
96                    << "\n";
97   });
98 }
99 
100 //===----------------------------------------------------------------------===//
101 // ConversionValueMapping
102 //===----------------------------------------------------------------------===//
103 
104 namespace {
105 /// This class wraps a BlockAndValueMapping to provide recursive lookup
106 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
107 struct ConversionValueMapping {
108   /// Lookup a mapped value within the map. If a mapping for the provided value
109   /// does not exist then return the provided value. If `desiredType` is
110   /// non-null, returns the most recently mapped value with that type. If an
111   /// operand of that type does not exist, defaults to normal behavior.
112   Value lookupOrDefault(Value from, Type desiredType = nullptr) const;
113 
114   /// Lookup a mapped value within the map, or return null if a mapping does not
115   /// exist. If a mapping exists, this follows the same behavior of
116   /// `lookupOrDefault`.
117   Value lookupOrNull(Value from) const;
118 
119   /// Map a value to the one provided.
120   void map(Value oldVal, Value newVal) { mapping.map(oldVal, newVal); }
121 
122   /// Drop the last mapping for the given value.
123   void erase(Value value) { mapping.erase(value); }
124 
125 private:
126   /// Current value mappings.
127   BlockAndValueMapping mapping;
128 };
129 } // end anonymous namespace
130 
131 Value ConversionValueMapping::lookupOrDefault(Value from,
132                                               Type desiredType) const {
133   // If there was no desired type, simply find the leaf value.
134   if (!desiredType) {
135     // If this value had a valid mapping, unmap that value as well in the case
136     // that it was also replaced.
137     while (auto mappedValue = mapping.lookupOrNull(from))
138       from = mappedValue;
139     return from;
140   }
141 
142   // Otherwise, try to find the deepest value that has the desired type.
143   Value desiredValue;
144   do {
145     if (from.getType() == desiredType)
146       desiredValue = from;
147 
148     Value mappedValue = mapping.lookupOrNull(from);
149     if (!mappedValue)
150       break;
151     from = mappedValue;
152   } while (true);
153 
154   // If the desired value was found use it, otherwise default to the leaf value.
155   return desiredValue ? desiredValue : from;
156 }
157 
158 Value ConversionValueMapping::lookupOrNull(Value from) const {
159   Value result = lookupOrDefault(from);
160   return result == from ? nullptr : result;
161 }
162 
163 //===----------------------------------------------------------------------===//
164 // ArgConverter
165 //===----------------------------------------------------------------------===//
namespace {
/// This class provides a simple interface for converting the types of block
/// arguments. This is done by creating a new block that contains the new legal
/// types and extracting the block that contains the old illegal types to allow
/// for undoing pending rewrites in the case of failure.
struct ArgConverter {
  /// `rewriter` is used to create argument materializations. Only a reference
  /// is stored, so the rewriter must outlive this converter.
  ArgConverter(PatternRewriter &rewriter) : rewriter(rewriter) {}

  /// This structure contains the information pertaining to an argument that has
  /// been converted.
  struct ConvertedArgInfo {
    ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
                     Value castValue = nullptr)
        : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}

    /// The start index, within the new block's argument list, of the
    /// arguments that replace the original argument.
    unsigned newArgIdx;

    /// The number of arguments that replaced the original argument.
    unsigned newArgSize;

    /// The value the original argument was remapped to: the result of the
    /// argument materialization produced by the type converter, or, for a 1->1
    /// mapping for which no materialization was produced, the new block
    /// argument itself. This is set for every 1->1+ mapping (not only when
    /// 'newArgSize' > 1); `applyRewrites` replaces remaining uses of the
    /// original argument with the current mapping of this value.
    Value castValue;
  };

  /// This structure contains information pertaining to a block that has had its
  /// signature converted.
  struct ConvertedBlockInfo {
    ConvertedBlockInfo(Block *origBlock, TypeConverter &converter)
        : origBlock(origBlock), converter(&converter) {}

    /// The original block that was requested to have its signature converted.
    Block *origBlock;

    /// The conversion information for each of the original arguments, indexed
    /// by original argument number. The information is None if the argument
    /// was dropped during conversion (including when a replacement value was
    /// supplied for it instead of new block arguments).
    SmallVector<Optional<ConvertedArgInfo>, 1> argInfo;

    /// The type converter used to convert the arguments.
    TypeConverter *converter;
  };

  /// Return true if `block` is either a block created by a signature
  /// conversion or an original block whose signature has been converted.
  bool hasBeenConverted(Block *block) const {
    return conversionInfo.count(block) || convertedBlocks.count(block);
  }

  /// Set the type converter to use for the given region. `typeConverter` must
  /// be non-null.
  void setConverter(Region *region, TypeConverter *typeConverter) {
    assert(typeConverter && "expected valid type converter");
    regionToConverter[region] = typeConverter;
  }

  /// Return the type converter to use for the given region, or null if there
  /// isn't one.
  TypeConverter *getConverter(Region *region) {
    return regionToConverter.lookup(region);
  }

  //===--------------------------------------------------------------------===//
  // Rewrite Application
  //===--------------------------------------------------------------------===//

  /// Erase any rewrites registered for the blocks within the given operation
  /// which is about to be removed. This merely drops the rewrites without
  /// undoing them.
  void notifyOpRemoved(Operation *op);

  /// Cleanup and undo any generated conversions for the arguments of block.
  /// This method replaces the new block with the original, reverting the IR to
  /// its original state.
  void discardRewrites(Block *block);

  /// Fully replace uses of the old arguments with the new.
  void applyRewrites(ConversionValueMapping &mapping);

  /// Materialize any necessary conversions for converted arguments that have
  /// live users, using the provided `findLiveUser` to search for a user that
  /// survives the conversion process.
  LogicalResult
  materializeLiveConversions(ConversionValueMapping &mapping,
                             OpBuilder &builder,
                             function_ref<Operation *(Value)> findLiveUser);

  //===--------------------------------------------------------------------===//
  // Conversion
  //===--------------------------------------------------------------------===//

  /// Attempt to convert the signature of the given block, if successful a new
  /// block is returned containing the new arguments. Returns `block` if it did
  /// not require conversion (already converted, or detached from any region).
  FailureOr<Block *> convertSignature(Block *block, TypeConverter &converter,
                                      ConversionValueMapping &mapping);

  /// Apply the given signature conversion on the given block. The new block
  /// containing the updated signature is returned. If no conversions were
  /// necessary, e.g. if the block has no arguments, `block` is returned.
  /// `converter` is used to generate any necessary cast operations that
  /// translate between the origin argument types and those specified in the
  /// signature conversion.
  Block *applySignatureConversion(
      Block *block, TypeConverter &converter,
      TypeConverter::SignatureConversion &signatureConversion,
      ConversionValueMapping &mapping);

  /// Record a conversion: move `info.origBlock` out of its region into a
  /// side region (so it can be restored later) and register `info` under
  /// `newBlock`.
  void insertConversion(Block *newBlock, ConvertedBlockInfo &&info);

  /// A collection of blocks that have had their arguments converted. This is a
  /// map from the new replacement block to the conversion information, which
  /// includes the original block.
  llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;

  /// The set of original blocks that were converted.
  DenseSet<Block *> convertedBlocks;

  /// A mapping from valid regions to the side regions that hold the original
  /// blocks of conversions performed within them.
  DenseMap<Region *, std::unique_ptr<Region>> regionMapping;

  /// A mapping of regions to type converters that should be used when
  /// converting the arguments of blocks within that region.
  DenseMap<Region *, TypeConverter *> regionToConverter;

  /// The pattern rewriter to use when materializing conversions.
  PatternRewriter &rewriter;
};
} // end anonymous namespace
295 
296 //===----------------------------------------------------------------------===//
297 // Rewrite Application
298 
299 void ArgConverter::notifyOpRemoved(Operation *op) {
300   if (conversionInfo.empty())
301     return;
302 
303   for (Region &region : op->getRegions()) {
304     for (Block &block : region) {
305       // Drop any rewrites from within.
306       for (Operation &nestedOp : block)
307         if (nestedOp.getNumRegions())
308           notifyOpRemoved(&nestedOp);
309 
310       // Check if this block was converted.
311       auto it = conversionInfo.find(&block);
312       if (it == conversionInfo.end())
313         continue;
314 
315       // Drop all uses of the original arguments and delete the original block.
316       Block *origBlock = it->second.origBlock;
317       for (BlockArgument arg : origBlock->getArguments())
318         arg.dropAllUses();
319       conversionInfo.erase(it);
320     }
321   }
322 }
323 
324 void ArgConverter::discardRewrites(Block *block) {
325   auto it = conversionInfo.find(block);
326   if (it == conversionInfo.end())
327     return;
328   Block *origBlock = it->second.origBlock;
329 
330   // Drop all uses of the new block arguments and replace uses of the new block.
331   for (int i = block->getNumArguments() - 1; i >= 0; --i)
332     block->getArgument(i).dropAllUses();
333   block->replaceAllUsesWith(origBlock);
334 
335   // Move the operations back the original block and the delete the new block.
336   origBlock->getOperations().splice(origBlock->end(), block->getOperations());
337   origBlock->moveBefore(block);
338   block->erase();
339 
340   convertedBlocks.erase(origBlock);
341   conversionInfo.erase(it);
342 }
343 
344 void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
345   for (auto &info : conversionInfo) {
346     ConvertedBlockInfo &blockInfo = info.second;
347     Block *origBlock = blockInfo.origBlock;
348 
349     // Process the remapping for each of the original arguments.
350     for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
351       Optional<ConvertedArgInfo> &argInfo = blockInfo.argInfo[i];
352       BlockArgument origArg = origBlock->getArgument(i);
353 
354       // Handle the case of a 1->0 value mapping.
355       if (!argInfo) {
356         if (Value newArg = mapping.lookupOrNull(origArg))
357           origArg.replaceAllUsesWith(newArg);
358         continue;
359       }
360 
361       // Otherwise this is a 1->1+ value mapping.
362       Value castValue = argInfo->castValue;
363       assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
364 
365       // If the argument is still used, replace it with the generated cast.
366       if (!origArg.use_empty())
367         origArg.replaceAllUsesWith(mapping.lookupOrDefault(castValue));
368 
369       // If all users of the cast were removed, we can drop it. Otherwise, keep
370       // the operation alive and let the user handle any remaining usages.
371       if (castValue.use_empty() && castValue.getDefiningOp())
372         castValue.getDefiningOp()->erase();
373     }
374   }
375 }
376 
377 LogicalResult ArgConverter::materializeLiveConversions(
378     ConversionValueMapping &mapping, OpBuilder &builder,
379     function_ref<Operation *(Value)> findLiveUser) {
380   for (auto &info : conversionInfo) {
381     Block *newBlock = info.first;
382     ConvertedBlockInfo &blockInfo = info.second;
383     Block *origBlock = blockInfo.origBlock;
384 
385     // Process the remapping for each of the original arguments.
386     for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
387       // FIXME: We should run the below checks even if the type conversion was
388       // 1->N, but a lot of existing lowering rely on the block argument being
389       // blindly replaced. Those usages should be updated, and this if should be
390       // removed.
391       if (blockInfo.argInfo[i])
392         continue;
393 
394       // If the type of this argument changed and the argument is still live, we
395       // need to materialize a conversion.
396       BlockArgument origArg = origBlock->getArgument(i);
397       auto argReplacementValue = mapping.lookupOrDefault(origArg);
398       bool isDroppedArg = argReplacementValue == origArg;
399       if (argReplacementValue.getType() == origArg.getType() && !isDroppedArg)
400         continue;
401       Operation *liveUser = findLiveUser(origArg);
402       if (!liveUser)
403         continue;
404 
405       if (OpResult result = argReplacementValue.dyn_cast<OpResult>())
406         rewriter.setInsertionPointAfter(result.getOwner());
407       else
408         rewriter.setInsertionPointToStart(newBlock);
409       Value newArg = blockInfo.converter->materializeSourceConversion(
410           rewriter, origArg.getLoc(), origArg.getType(),
411           isDroppedArg ? ValueRange() : ValueRange(argReplacementValue));
412       if (!newArg) {
413         InFlightDiagnostic diag =
414             emitError(origArg.getLoc())
415             << "failed to materialize conversion for block argument #" << i
416             << " that remained live after conversion, type was "
417             << origArg.getType();
418         if (!isDroppedArg)
419           diag << ", with target type " << argReplacementValue.getType();
420         diag.attachNote(liveUser->getLoc())
421             << "see existing live user here: " << *liveUser;
422         return failure();
423       }
424       mapping.map(origArg, newArg);
425     }
426   }
427   return success();
428 }
429 
430 //===----------------------------------------------------------------------===//
431 // Conversion
432 
433 FailureOr<Block *>
434 ArgConverter::convertSignature(Block *block, TypeConverter &converter,
435                                ConversionValueMapping &mapping) {
436   // Check if the block was already converted. If the block is detached,
437   // conservatively assume it is going to be deleted.
438   if (hasBeenConverted(block) || !block->getParent())
439     return block;
440 
441   // Try to convert the signature for the block with the provided converter.
442   if (auto conversion = converter.convertBlockSignature(block))
443     return applySignatureConversion(block, converter, *conversion, mapping);
444   return failure();
445 }
446 
/// Apply `signatureConversion` to `block`.
///
/// All operations of `block` are moved into a new block split off from it.
/// The new block carries arguments of the converted types, and uses of
/// `block` are redirected to it. `block` keeps the original arguments and is
/// moved aside by `insertConversion` so the rewrite can be undone.
///
/// Each original argument with an input mapping is remapped in `mapping` to
/// one of:
/// - its provided replacement value (a 1->0 mapping, left without argInfo);
/// - the value materialized by `converter`;
/// - for a 1->1 mapping without a materialization, the new argument directly.
///
/// Returns `block` unchanged when it has no arguments and the conversion adds
/// none; otherwise returns the new block.
Block *ArgConverter::applySignatureConversion(
    Block *block, TypeConverter &converter,
    TypeConverter::SignatureConversion &signatureConversion,
    ConversionValueMapping &mapping) {
  // If no arguments are being changed or added, there is nothing to do.
  unsigned origArgCount = block->getNumArguments();
  auto convertedTypes = signatureConversion.getConvertedTypes();
  if (origArgCount == 0 && convertedTypes.empty())
    return block;

  // Split the block at the beginning to get a new block to use for the updated
  // signature. All operations move to `newBlock`; `block` is left empty but
  // keeps its original arguments.
  Block *newBlock = block->splitBlock(block->begin());
  block->replaceAllUsesWith(newBlock);

  // Add the converted argument types to the new block and keep them in a
  // vector so they can be sliced per original argument below.
  SmallVector<Value, 4> newArgRange(newBlock->addArguments(convertedTypes));
  ArrayRef<Value> newArgs(newArgRange);

  // Remap each of the original arguments as determined by the signature
  // conversion.
  ConvertedBlockInfo info(block, converter);
  info.argInfo.resize(origArgCount);

  // Materializations are created at the start of the new block; the guard
  // restores the rewriter's previous insertion point on return.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(newBlock);
  for (unsigned i = 0; i != origArgCount; ++i) {
    // Arguments without an input mapping are left unmapped (argInfo stays
    // None).
    auto inputMap = signatureConversion.getInputMapping(i);
    if (!inputMap)
      continue;
    BlockArgument origArg = block->getArgument(i);

    // If inputMap->replacementValue is not nullptr, then the argument is
    // dropped and a replacement value is provided to be the remappedValue.
    if (inputMap->replacementValue) {
      assert(inputMap->size == 0 &&
             "invalid to provide a replacement value when the argument isn't "
             "dropped");
      mapping.map(origArg, inputMap->replacementValue);
      continue;
    }

    // Otherwise, this is a 1->1+ mapping. Call into the provided type converter
    // to pack the new values. For 1->1 mappings, if there is no materialization
    // provided, use the argument directly instead.
    auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
    Value newArg = converter.materializeArgumentConversion(
        rewriter, origArg.getLoc(), origArg.getType(), replArgs);
    if (!newArg) {
      assert(replArgs.size() == 1 &&
             "couldn't materialize the result of 1->N conversion");
      newArg = replArgs.front();
    }
    mapping.map(origArg, newArg);
    info.argInfo[i] =
        ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
  }

  // Move the original block out of the region (into a side region owned by
  // this converter), record the conversion, and return the new block.
  insertConversion(newBlock, std::move(info));
  return newBlock;
}
508 
509 void ArgConverter::insertConversion(Block *newBlock,
510                                     ConvertedBlockInfo &&info) {
511   // Get a region to insert the old block.
512   Region *region = newBlock->getParent();
513   std::unique_ptr<Region> &mappedRegion = regionMapping[region];
514   if (!mappedRegion)
515     mappedRegion = std::make_unique<Region>(region->getParentOp());
516 
517   // Move the original block to the mapped region and emplace the conversion.
518   mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(),
519                                    info.origBlock->getIterator());
520   convertedBlocks.insert(info.origBlock);
521   conversionInfo.insert({newBlock, std::move(info)});
522 }
523 
524 //===----------------------------------------------------------------------===//
525 // Rewriter and Translation State
526 //===----------------------------------------------------------------------===//
527 namespace {
/// A snapshot of how many rewrites of each kind the conversion rewriter has
/// recorded so far. Intended for saving a point in the rewrite sequence and
/// later undoing everything recorded after it.
struct RewriterState {
  RewriterState(unsigned numCreatedOps, unsigned numReplacements,
                unsigned numArgReplacements, unsigned numBlockActions,
                unsigned numIgnoredOperations, unsigned numRootUpdates)
      : numCreatedOps{numCreatedOps}, numReplacements{numReplacements},
        numArgReplacements{numArgReplacements},
        numBlockActions{numBlockActions},
        numIgnoredOperations{numIgnoredOperations},
        numRootUpdates{numRootUpdates} {}

  /// Count of operations created so far.
  unsigned numCreatedOps;

  /// Count of queued operation replacements.
  unsigned numReplacements;

  /// Count of queued block argument replacements.
  unsigned numArgReplacements;

  /// Count of block actions performed.
  unsigned numBlockActions;

  /// Count of operations marked as ignored.
  unsigned numIgnoredOperations;

  /// Count of operations updated in place.
  unsigned numRootUpdates;
};
558 
559 /// The state of an operation that was updated by a pattern in-place. This
560 /// contains all of the necessary information to reconstruct an operation that
561 /// was updated in place.
562 class OperationTransactionState {
563 public:
564   OperationTransactionState() = default;
565   OperationTransactionState(Operation *op)
566       : op(op), loc(op->getLoc()), attrs(op->getMutableAttrDict()),
567         operands(op->operand_begin(), op->operand_end()),
568         successors(op->successor_begin(), op->successor_end()) {}
569 
570   /// Discard the transaction state and reset the state of the original
571   /// operation.
572   void resetOperation() const {
573     op->setLoc(loc);
574     op->setAttrs(attrs);
575     op->setOperands(operands);
576     for (auto it : llvm::enumerate(successors))
577       op->setSuccessor(it.value(), it.index());
578   }
579 
580   /// Return the original operation of this state.
581   Operation *getOperation() const { return op; }
582 
583 private:
584   Operation *op;
585   LocationAttr loc;
586   MutableDictionaryAttr attrs;
587   SmallVector<Value, 8> operands;
588   SmallVector<Block *, 2> successors;
589 };
590 
591 /// This class represents one requested operation replacement via 'replaceOp' or
592 /// 'eraseOp`.
593 struct OpReplacement {
594   OpReplacement() = default;
595   OpReplacement(TypeConverter *converter) : converter(converter) {}
596 
597   /// An optional type converter that can be used to materialize conversions
598   /// between the new and old values if necessary.
599   TypeConverter *converter = nullptr;
600 };
601 
/// Kinds of block-level actions recorded during a rewrite. A recorded action
/// can be undone if the conversion fails.
enum class BlockActionKind {
  Create,        // A new block was created.
  Erase,         // A block was erased.
  Merge,         // Two blocks were merged.
  Move,          // A block was moved.
  Split,         // A block was split into two.
  TypeConversion // A block's signature was converted.
};
612 
/// Original position of the given block in its parent region. During undo
/// actions, the block needs to be placed after `insertAfterBlock`.
struct BlockPosition {
  /// The region that originally contained the block.
  Region *region;
  /// The block that preceded it in `region`. Presumably null when the block
  /// was the first block of the region — TODO confirm against the undo logic.
  Block *insertAfterBlock;
};
619 
/// Information needed to undo the merge actions.
/// - the source block, and
/// - the Operation that was the last operation in the dest block before the
///   merge (could be null if the dest block was empty).
struct MergeInfo {
  /// The block whose contents were merged into the destination block.
  Block *sourceBlock;
  /// Last operation of the destination block before the merge, or null if the
  /// destination was empty. Presumably the operations after it are the ones
  /// moved back to `sourceBlock` on undo — confirm against the undo logic.
  Operation *destBlockLastInst;
};
628 
/// The storage class for an undoable block action (one of BlockActionKind),
/// contains the information necessary to undo this action.
struct BlockAction {
  /// Record that `block` was created.
  static BlockAction getCreate(Block *block) {
    return {BlockActionKind::Create, block, {}};
  }
  /// Record that `block` was erased from `originalPosition`.
  static BlockAction getErase(Block *block, BlockPosition originalPosition) {
    return {BlockActionKind::Erase, block, {originalPosition}};
  }
  /// Record that `sourceBlock` is merged into `block`. This captures the
  /// current last operation of `block`, so it must be called before the merge
  /// takes place.
  static BlockAction getMerge(Block *block, Block *sourceBlock) {
    BlockAction action{BlockActionKind::Merge, block, {}};
    action.mergeInfo = {sourceBlock, block->empty() ? nullptr : &block->back()};
    return action;
  }
  /// Record that `block` was moved away from `originalPosition`.
  static BlockAction getMove(Block *block, BlockPosition originalPosition) {
    return {BlockActionKind::Move, block, {originalPosition}};
  }
  /// Record that `originalBlock` was split, producing `block`.
  static BlockAction getSplit(Block *block, Block *originalBlock) {
    BlockAction action{BlockActionKind::Split, block, {}};
    action.originalBlock = originalBlock;
    return action;
  }
  /// Record that the signature of `block` was converted.
  static BlockAction getTypeConversion(Block *block) {
    return BlockAction{BlockActionKind::TypeConversion, block, {}};
  }

  /// The action kind; selects which union member below is active.
  BlockActionKind kind;

  /// The block the action applies to: the created, erased, moved or
  /// type-converted block; for Merge, the destination block; for Split,
  /// presumably the newly split-off block (the block that was split is kept
  /// in `originalBlock`).
  Block *block;

  union {
    // In use if kind == BlockActionKind::Move or BlockActionKind::Erase, and
    // contains a pointer to the region that originally contained the block as
    // well as the position of the block in that region.
    BlockPosition originalPosition;
    // In use if kind == BlockActionKind::Split and contains a pointer to the
    // block that was split into two parts.
    Block *originalBlock;
    // In use if kind == BlockActionKind::Merge, and contains the information
    // needed to undo the merge.
    MergeInfo mergeInfo;
  };
};
674 } // end anonymous namespace
675 
676 //===----------------------------------------------------------------------===//
677 // ConversionPatternRewriterImpl
678 //===----------------------------------------------------------------------===//
679 namespace mlir {
680 namespace detail {
681 struct ConversionPatternRewriterImpl {
682   ConversionPatternRewriterImpl(PatternRewriter &rewriter)
683       : argConverter(rewriter) {}
684 
685   /// Cleanup and destroy any generated rewrite operations. This method is
686   /// invoked when the conversion process fails.
687   void discardRewrites();
688 
689   /// Apply all requested operation rewrites. This method is invoked when the
690   /// conversion process succeeds.
691   void applyRewrites();
692 
693   //===--------------------------------------------------------------------===//
694   // State Management
695   //===--------------------------------------------------------------------===//
696 
697   /// Return the current state of the rewriter.
698   RewriterState getCurrentState();
699 
700   /// Reset the state of the rewriter to a previously saved point.
701   void resetState(RewriterState state);
702 
703   /// Erase any blocks that were unlinked from their regions and stored in block
704   /// actions.
705   void eraseDanglingBlocks();
706 
707   /// Undo the block actions (motions, splits) one by one in reverse order until
708   /// "numActionsToKeep" actions remains.
709   void undoBlockActions(unsigned numActionsToKeep = 0);
710 
711   /// Remap the given operands to those with potentially different types. The
  /// provided type converter is used to ensure that the remapped types are
  /// legal. Returns success if the operands could be remapped, failure
  /// otherwise.
  LogicalResult remapValues(Location loc, PatternRewriter &rewriter,
                            TypeConverter *converter,
                            Operation::operand_range operands,
                            SmallVectorImpl<Value> &remapped);

  /// Returns true if the given operation is ignored, and does not need to be
  /// converted: either the operation itself was already replaced/erased, or
  /// its immediate parent operation is recorded in `ignoredOps`.
  bool isOpIgnored(Operation *op) const;

  /// Recursively marks the nested operations under 'op' as ignored. This
  /// removes them from being considered for legalization. Only operations
  /// that define at least one non-empty region are recorded in `ignoredOps`;
  /// operations directly nested in a recorded operation are then reported as
  /// ignored by `isOpIgnored`.
  void markNestedOpsIgnored(Operation *op);

  //===--------------------------------------------------------------------===//
  // Type Conversion
  //===--------------------------------------------------------------------===//

  /// Convert the signature of the given block. When 'conversion' is non-null
  /// it is applied directly; otherwise 'converter' computes the conversion.
  /// If a new block is produced, a TypeConversion block action is recorded so
  /// that the conversion can be undone.
  FailureOr<Block *> convertBlockSignature(
      Block *block, TypeConverter &converter,
      TypeConverter::SignatureConversion *conversion = nullptr);

  /// Apply a signature conversion on the entry block of the given region,
  /// using the default type converter. Returns the converted entry block, or
  /// nullptr if the region has no blocks.
  Block *
  applySignatureConversion(Region *region,
                           TypeConverter::SignatureConversion &conversion);

  /// Convert the types of block arguments within the given region. The
  /// converter is registered for 'region'; the entry block uses
  /// 'entryConversion' when it is non-null. Returns the (possibly new) entry
  /// block, nullptr for an empty region, or failure if a block signature
  /// could not be converted.
  FailureOr<Block *>
  convertRegionTypes(Region *region, TypeConverter &converter,
                     TypeConverter::SignatureConversion *entryConversion);

  //===--------------------------------------------------------------------===//
  // Rewriter Notification Hooks
  //===--------------------------------------------------------------------===//

  /// PatternRewriter hook for replacing the results of an operation. A null
  /// entry in 'newValues' means the corresponding result is dropped.
  void notifyOpReplaced(Operation *op, ValueRange newValues);

  /// Notifies that a block is about to be erased. Records the block's region
  /// and previous block so that the erasure can be undone.
  void notifyBlockIsBeingErased(Block *block);

  /// Notifies that a block was created.
  void notifyCreatedBlock(Block *block);

  /// Notifies that a block was split.
  void notifySplitBlock(Block *block, Block *continuation);

  /// Notifies that `block` is being merged with `srcBlock`.
  void notifyBlocksBeingMerged(Block *block, Block *srcBlock);

  /// Notifies that the blocks of a region are about to be moved.
  void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent,
                                        Region::iterator before);

  /// Notifies that the blocks of a region were cloned into another.
  void notifyRegionWasClonedBefore(iterator_range<Region::iterator> &blocks,
                                   Location origRegionLoc);

  /// Notifies that a pattern match failed for the given reason. Always
  /// returns failure; the reason is only materialized in debug output.
  LogicalResult
  notifyMatchFailure(Location loc,
                     function_ref<void(Diagnostic &)> reasonCallback);

  //===--------------------------------------------------------------------===//
  // State
  //===--------------------------------------------------------------------===//

  /// Mapping between replaced values that differ in type. This happens when
  /// replacing a value with one of a different type.
  ConversionValueMapping mapping;

  /// Utility used to convert block arguments.
  ArgConverter argConverter;

  /// Ordered vector of all of the newly created operations during conversion.
  /// On rollback these are erased newest-first.
  std::vector<Operation *> createdOps;

  /// Ordered map of requested operation replacements.
  llvm::MapVector<Operation *, OpReplacement> replacements;

  /// Ordered vector of any requested block argument replacements.
  SmallVector<BlockArgument, 4> argReplacements;

  /// Ordered list of block operations (creations, splits, motions).
  SmallVector<BlockAction, 4> blockActions;

  /// A set of operations that should no longer be considered for legalization,
  /// but were not directly replaced/erased/etc. by a pattern. These are
  /// generally child operations of other operations that were
  /// replaced/erased/etc. This is not meant to be an exhaustive list of all
  /// operations, but the minimal set that can be used to detect if a given
  /// operation should be `ignored`. For example, we may add the operations that
  /// define non-empty regions to the set, but not any of the others. This
  /// simplifies the amount of memory needed as we can query if the parent
  /// operation was ignored.
  llvm::SetVector<Operation *> ignoredOps;

  /// A transaction state for each of operations that were updated in-place.
  /// Entries are appended by startRootUpdate and used to reset the operations
  /// when rewrites are rolled back.
  SmallVector<OperationTransactionState, 4> rootUpdates;

  /// A vector of indices into `replacements` of operations that were replaced
  /// with values with different result types than the original operation, e.g.
  /// 1->N conversion of some kind.
  SmallVector<unsigned, 4> operationsWithChangedResults;

  /// A default type converter, used when block conversions do not have one
  /// explicitly provided.
  TypeConverter defaultTypeConverter;

  /// The current conversion pattern that is being rewritten, or nullptr if
  /// called from outside of a conversion pattern rewrite.
  const ConversionPattern *currentConversionPattern = nullptr;

#ifndef NDEBUG
  /// A set of operations that have pending updates. This tracking isn't
  /// strictly necessary, and is thus only active during debug builds for extra
  /// verification.
  SmallPtrSet<Operation *, 1> pendingRootUpdates;

  /// A logger used to emit diagnostics during the conversion process.
  llvm::ScopedPrinter logger{llvm::dbgs()};
#endif
838 };
839 } // end namespace detail
840 } // end namespace mlir
841 
842 /// Detach any operations nested in the given operation from their parent
843 /// blocks, and erase the given operation. This can be used when the nested
844 /// operations are scheduled for erasure themselves, so deleting the regions of
845 /// the given operation together with their content would result in double-free.
846 /// This happens, for example, when rolling back op creation in the reverse
847 /// order and if the nested ops were created before the parent op. This function
848 /// does not need to collect nested ops recursively because it is expected to
849 /// also be called for each nested op when it is about to be deleted.
850 static void detachNestedAndErase(Operation *op) {
851   for (Region &region : op->getRegions()) {
852     for (Block &block : region.getBlocks()) {
853       while (!block.getOperations().empty())
854         block.getOperations().remove(block.getOperations().begin());
855       block.dropAllDefinedValueUses();
856     }
857   }
858   op->erase();
859 }
860 
861 void ConversionPatternRewriterImpl::discardRewrites() {
862   // Reset any operations that were updated in place.
863   for (auto &state : rootUpdates)
864     state.resetOperation();
865 
866   undoBlockActions();
867 
868   // Remove any newly created ops.
869   for (auto *op : llvm::reverse(createdOps))
870     detachNestedAndErase(op);
871 }
872 
873 void ConversionPatternRewriterImpl::applyRewrites() {
874   // Apply all of the rewrites replacements requested during conversion.
875   for (auto &repl : replacements) {
876     for (OpResult result : repl.first->getResults())
877       if (Value newValue = mapping.lookupOrNull(result))
878         result.replaceAllUsesWith(newValue);
879 
880     // If this operation defines any regions, drop any pending argument
881     // rewrites.
882     if (repl.first->getNumRegions())
883       argConverter.notifyOpRemoved(repl.first);
884   }
885 
886   // Apply all of the requested argument replacements.
887   for (BlockArgument arg : argReplacements) {
888     Value repl = mapping.lookupOrDefault(arg);
889     if (repl.isa<BlockArgument>()) {
890       arg.replaceAllUsesWith(repl);
891       continue;
892     }
893 
894     // If the replacement value is an operation, we check to make sure that we
895     // don't replace uses that are within the parent operation of the
896     // replacement value.
897     Operation *replOp = repl.cast<OpResult>().getOwner();
898     Block *replBlock = replOp->getBlock();
899     arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
900       Operation *user = operand.getOwner();
901       return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
902     });
903   }
904 
905   // In a second pass, erase all of the replaced operations in reverse. This
906   // allows processing nested operations before their parent region is
907   // destroyed.
908   for (auto &repl : llvm::reverse(replacements))
909     repl.first->erase();
910 
911   argConverter.applyRewrites(mapping);
912 
913   // Now that the ops have been erased, also erase dangling blocks.
914   eraseDanglingBlocks();
915 }
916 
917 //===----------------------------------------------------------------------===//
918 // State Management
919 
920 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
921   return RewriterState(createdOps.size(), replacements.size(),
922                        argReplacements.size(), blockActions.size(),
923                        ignoredOps.size(), rootUpdates.size());
924 }
925 
926 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
927   // Reset any operations that were updated in place.
928   for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i)
929     rootUpdates[i].resetOperation();
930   rootUpdates.resize(state.numRootUpdates);
931 
932   // Reset any replaced arguments.
933   for (BlockArgument replacedArg :
934        llvm::drop_begin(argReplacements, state.numArgReplacements))
935     mapping.erase(replacedArg);
936   argReplacements.resize(state.numArgReplacements);
937 
938   // Undo any block actions.
939   undoBlockActions(state.numBlockActions);
940 
941   // Reset any replaced operations and undo any saved mappings.
942   for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
943     for (auto result : repl.first->getResults())
944       mapping.erase(result);
945   while (replacements.size() != state.numReplacements)
946     replacements.pop_back();
947 
948   // Pop all of the newly created operations.
949   while (createdOps.size() != state.numCreatedOps) {
950     detachNestedAndErase(createdOps.back());
951     createdOps.pop_back();
952   }
953 
954   // Pop all of the recorded ignored operations that are no longer valid.
955   while (ignoredOps.size() != state.numIgnoredOperations)
956     ignoredOps.pop_back();
957 
958   // Reset operations with changed results.
959   while (!operationsWithChangedResults.empty() &&
960          operationsWithChangedResults.back() >= state.numReplacements)
961     operationsWithChangedResults.pop_back();
962 }
963 
964 void ConversionPatternRewriterImpl::eraseDanglingBlocks() {
965   for (auto &action : blockActions)
966     if (action.kind == BlockActionKind::Erase)
967       delete action.block;
968 }
969 
970 void ConversionPatternRewriterImpl::undoBlockActions(
971     unsigned numActionsToKeep) {
972   for (auto &action :
973        llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) {
974     switch (action.kind) {
975     // Delete the created block.
976     case BlockActionKind::Create: {
977       // Unlink all of the operations within this block, they will be deleted
978       // separately.
979       auto &blockOps = action.block->getOperations();
980       while (!blockOps.empty())
981         blockOps.remove(blockOps.begin());
982       action.block->dropAllDefinedValueUses();
983       action.block->erase();
984       break;
985     }
986     // Put the block (owned by action) back into its original position.
987     case BlockActionKind::Erase: {
988       auto &blockList = action.originalPosition.region->getBlocks();
989       Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
990       blockList.insert((insertAfterBlock
991                             ? std::next(Region::iterator(insertAfterBlock))
992                             : blockList.end()),
993                        action.block);
994       break;
995     }
996     // Split the block at the position which was originally the end of the
997     // destination block (owned by action), and put the instructions back into
998     // the block used before the merge.
999     case BlockActionKind::Merge: {
1000       Block *sourceBlock = action.mergeInfo.sourceBlock;
1001       Block::iterator splitPoint =
1002           (action.mergeInfo.destBlockLastInst
1003                ? ++Block::iterator(action.mergeInfo.destBlockLastInst)
1004                : action.block->begin());
1005       sourceBlock->getOperations().splice(sourceBlock->begin(),
1006                                           action.block->getOperations(),
1007                                           splitPoint, action.block->end());
1008       break;
1009     }
1010     // Move the block back to its original position.
1011     case BlockActionKind::Move: {
1012       Region *originalRegion = action.originalPosition.region;
1013       Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
1014       originalRegion->getBlocks().splice(
1015           (insertAfterBlock ? std::next(Region::iterator(insertAfterBlock))
1016                             : originalRegion->end()),
1017           action.block->getParent()->getBlocks(), action.block);
1018       break;
1019     }
1020     // Merge back the block that was split out.
1021     case BlockActionKind::Split: {
1022       action.originalBlock->getOperations().splice(
1023           action.originalBlock->end(), action.block->getOperations());
1024       action.block->dropAllDefinedValueUses();
1025       action.block->erase();
1026       break;
1027     }
1028     // Undo the type conversion.
1029     case BlockActionKind::TypeConversion: {
1030       argConverter.discardRewrites(action.block);
1031       break;
1032     }
1033     }
1034   }
1035   blockActions.resize(numActionsToKeep);
1036 }
1037 
1038 LogicalResult ConversionPatternRewriterImpl::remapValues(
1039     Location loc, PatternRewriter &rewriter, TypeConverter *converter,
1040     Operation::operand_range operands, SmallVectorImpl<Value> &remapped) {
1041   remapped.reserve(llvm::size(operands));
1042 
1043   SmallVector<Type, 1> legalTypes;
1044   for (auto it : llvm::enumerate(operands)) {
1045     Value operand = it.value();
1046     Type origType = operand.getType();
1047 
1048     // If a converter was provided, get the desired legal types for this
1049     // operand.
1050     Type desiredType;
1051     if (converter) {
1052       // If there is no legal conversion, fail to match this pattern.
1053       legalTypes.clear();
1054       if (failed(converter->convertType(origType, legalTypes))) {
1055         return notifyMatchFailure(loc, [=](Diagnostic &diag) {
1056           diag << "unable to convert type for operand #" << it.index()
1057                << ", type was " << origType;
1058         });
1059       }
1060       // TODO: There currently isn't any mechanism to do 1->N type conversion
1061       // via the PatternRewriter replacement API, so for now we just ignore it.
1062       if (legalTypes.size() == 1)
1063         desiredType = legalTypes.front();
1064     } else {
1065       // TODO: What we should do here is just set `desiredType` to `origType`
1066       // and then handle the necessary type conversions after the conversion
1067       // process has finished. Unfortunately a lot of patterns currently rely on
1068       // receiving the new operands even if the types change, so we keep the
1069       // original behavior here for now until all of the patterns relying on
1070       // this get updated.
1071     }
1072     Value newOperand = mapping.lookupOrDefault(operand, desiredType);
1073 
1074     // Handle the case where the conversion was 1->1 and the new operand type
1075     // isn't legal.
1076     Type newOperandType = newOperand.getType();
1077     if (converter && desiredType && newOperandType != desiredType) {
1078       // Attempt to materialize a conversion for this new value.
1079       newOperand = converter->materializeTargetConversion(
1080           rewriter, loc, desiredType, newOperand);
1081       if (!newOperand) {
1082         return notifyMatchFailure(loc, [=](Diagnostic &diag) {
1083           diag << "unable to materialize a conversion for "
1084                   "operand #"
1085                << it.index() << ", from " << newOperandType << " to "
1086                << desiredType;
1087         });
1088       }
1089     }
1090     remapped.push_back(newOperand);
1091   }
1092   return success();
1093 }
1094 
1095 bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
1096   // Check to see if this operation was replaced or its parent ignored.
1097   return replacements.count(op) || ignoredOps.count(op->getParentOp());
1098 }
1099 
1100 void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
1101   // Walk this operation and collect nested operations that define non-empty
1102   // regions. We mark such operations as 'ignored' so that we know we don't have
1103   // to convert them, or their nested ops.
1104   if (op->getNumRegions() == 0)
1105     return;
1106   op->walk([&](Operation *op) {
1107     if (llvm::any_of(op->getRegions(),
1108                      [](Region &region) { return !region.empty(); }))
1109       ignoredOps.insert(op);
1110   });
1111 }
1112 
1113 //===----------------------------------------------------------------------===//
1114 // Type Conversion
1115 
1116 FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
1117     Block *block, TypeConverter &converter,
1118     TypeConverter::SignatureConversion *conversion) {
1119   FailureOr<Block *> result =
1120       conversion ? argConverter.applySignatureConversion(block, converter,
1121                                                          *conversion, mapping)
1122                  : argConverter.convertSignature(block, converter, mapping);
1123   if (Block *newBlock = result.getValue()) {
1124     if (newBlock != block)
1125       blockActions.push_back(BlockAction::getTypeConversion(newBlock));
1126   }
1127   return result;
1128 }
1129 
1130 Block *ConversionPatternRewriterImpl::applySignatureConversion(
1131     Region *region, TypeConverter::SignatureConversion &conversion) {
1132   if (!region->empty()) {
1133     return *convertBlockSignature(&region->front(), defaultTypeConverter,
1134                                   &conversion);
1135   }
1136   return nullptr;
1137 }
1138 
1139 FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
1140     Region *region, TypeConverter &converter,
1141     TypeConverter::SignatureConversion *entryConversion) {
1142   argConverter.setConverter(region, &converter);
1143   if (region->empty())
1144     return nullptr;
1145 
1146   // Convert the arguments of each block within the region.
1147   FailureOr<Block *> newEntry =
1148       convertBlockSignature(&region->front(), converter, entryConversion);
1149   for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(*region, 1)))
1150     if (failed(convertBlockSignature(&block, converter)))
1151       return failure();
1152   return newEntry;
1153 }
1154 
1155 //===----------------------------------------------------------------------===//
1156 // Rewriter Notification Hooks
1157 
1158 void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
1159                                                      ValueRange newValues) {
1160   assert(newValues.size() == op->getNumResults());
1161   assert(!replacements.count(op) && "operation was already replaced");
1162 
1163   // Track if any of the results changed, e.g. erased and replaced with null.
1164   bool resultChanged = false;
1165 
1166   // Create mappings for each of the new result values.
1167   Value newValue, result;
1168   for (auto it : llvm::zip(newValues, op->getResults())) {
1169     std::tie(newValue, result) = it;
1170     if (!newValue) {
1171       resultChanged = true;
1172       continue;
1173     }
1174     // Remap, and check for any result type changes.
1175     mapping.map(result, newValue);
1176     resultChanged |= (newValue.getType() != result.getType());
1177   }
1178   if (resultChanged)
1179     operationsWithChangedResults.push_back(replacements.size());
1180 
1181   // Record the requested operation replacement.
1182   TypeConverter *converter = nullptr;
1183   if (currentConversionPattern)
1184     converter = currentConversionPattern->getTypeConverter();
1185   replacements.insert(std::make_pair(op, OpReplacement(converter)));
1186 
1187   // Mark this operation as recursively ignored so that we don't need to
1188   // convert any nested operations.
1189   markNestedOpsIgnored(op);
1190 }
1191 
1192 void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
1193   Region *region = block->getParent();
1194   Block *origPrevBlock = block->getPrevNode();
1195   blockActions.push_back(BlockAction::getErase(block, {region, origPrevBlock}));
1196 }
1197 
1198 void ConversionPatternRewriterImpl::notifyCreatedBlock(Block *block) {
1199   blockActions.push_back(BlockAction::getCreate(block));
1200 }
1201 
1202 void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
1203                                                      Block *continuation) {
1204   blockActions.push_back(BlockAction::getSplit(continuation, block));
1205 }
1206 
1207 void ConversionPatternRewriterImpl::notifyBlocksBeingMerged(Block *block,
1208                                                             Block *srcBlock) {
1209   blockActions.push_back(BlockAction::getMerge(block, srcBlock));
1210 }
1211 
1212 void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore(
1213     Region &region, Region &parent, Region::iterator before) {
1214   if (region.empty())
1215     return;
1216   Block *laterBlock = &region.back();
1217   for (auto &earlierBlock : llvm::drop_begin(llvm::reverse(region), 1)) {
1218     blockActions.push_back(
1219         BlockAction::getMove(laterBlock, {&region, &earlierBlock}));
1220     laterBlock = &earlierBlock;
1221   }
1222   blockActions.push_back(BlockAction::getMove(laterBlock, {&region, nullptr}));
1223 }
1224 
1225 void ConversionPatternRewriterImpl::notifyRegionWasClonedBefore(
1226     iterator_range<Region::iterator> &blocks, Location origRegionLoc) {
1227   for (Block &block : blocks)
1228     blockActions.push_back(BlockAction::getCreate(&block));
1229 
1230   // Compute the conversion set for the inlined region.
1231   auto result = computeConversionSet(blocks, origRegionLoc, createdOps);
1232 
1233   // This original region has already had its conversion set computed, so there
1234   // shouldn't be any new failures.
1235   (void)result;
1236   assert(succeeded(result) && "expected region to have no unreachable blocks");
1237 }
1238 
1239 LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
1240     Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1241   LLVM_DEBUG({
1242     Diagnostic diag(loc, DiagnosticSeverity::Remark);
1243     reasonCallback(diag);
1244     logger.startLine() << "** Failure : " << diag.str() << "\n";
1245   });
1246   return failure();
1247 }
1248 
1249 //===----------------------------------------------------------------------===//
1250 // ConversionPatternRewriter
1251 //===----------------------------------------------------------------------===//
1252 
1253 ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
1254     : PatternRewriter(ctx),
1255       impl(new detail::ConversionPatternRewriterImpl(*this)) {}
1256 ConversionPatternRewriter::~ConversionPatternRewriter() {}
1257 
1258 /// PatternRewriter hook for replacing the results of an operation.
1259 void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
1260   LLVM_DEBUG({
1261     impl->logger.startLine()
1262         << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1263   });
1264   impl->notifyOpReplaced(op, newValues);
1265 }
1266 
1267 /// PatternRewriter hook for erasing a dead operation. The uses of this
1268 /// operation *must* be made dead by the end of the conversion process,
1269 /// otherwise an assert will be issued.
1270 void ConversionPatternRewriter::eraseOp(Operation *op) {
1271   LLVM_DEBUG({
1272     impl->logger.startLine()
1273         << "** Erase   : '" << op->getName() << "'(" << op << ")\n";
1274   });
1275   SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr);
1276   impl->notifyOpReplaced(op, nullRepls);
1277 }
1278 
1279 void ConversionPatternRewriter::eraseBlock(Block *block) {
1280   impl->notifyBlockIsBeingErased(block);
1281 
1282   // Mark all ops for erasure.
1283   for (Operation &op : *block)
1284     eraseOp(&op);
1285 
1286   // Unlink the block from its parent region. The block is kept in the block
1287   // action and will be actually destroyed when rewrites are applied. This
1288   // allows us to keep the operations in the block live and undo the removal by
1289   // re-inserting the block.
1290   block->getParent()->getBlocks().remove(block);
1291 }
1292 
1293 Block *ConversionPatternRewriter::applySignatureConversion(
1294     Region *region, TypeConverter::SignatureConversion &conversion) {
1295   return impl->applySignatureConversion(region, conversion);
1296 }
1297 
1298 FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
1299     Region *region, TypeConverter &converter,
1300     TypeConverter::SignatureConversion *entryConversion) {
1301   return impl->convertRegionTypes(region, converter, entryConversion);
1302 }
1303 
1304 void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
1305                                                            Value to) {
1306   LLVM_DEBUG({
1307     Operation *parentOp = from.getOwner()->getParentOp();
1308     impl->logger.startLine() << "** Replace Argument : '" << from
1309                              << "'(in region of '" << parentOp->getName()
1310                              << "'(" << from.getOwner()->getParentOp() << ")\n";
1311   });
1312   impl->argReplacements.push_back(from);
1313   impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
1314 }
1315 
1316 /// Return the converted value that replaces 'key'. Return 'key' if there is
1317 /// no such a converted value.
1318 Value ConversionPatternRewriter::getRemappedValue(Value key) {
1319   return impl->mapping.lookupOrDefault(key);
1320 }
1321 
1322 /// PatternRewriter hook for creating a new block with the given arguments.
1323 void ConversionPatternRewriter::notifyBlockCreated(Block *block) {
1324   impl->notifyCreatedBlock(block);
1325 }
1326 
1327 /// PatternRewriter hook for splitting a block into two parts.
1328 Block *ConversionPatternRewriter::splitBlock(Block *block,
1329                                              Block::iterator before) {
1330   auto *continuation = PatternRewriter::splitBlock(block, before);
1331   impl->notifySplitBlock(block, continuation);
1332   return continuation;
1333 }
1334 
1335 /// PatternRewriter hook for merging a block into another.
1336 void ConversionPatternRewriter::mergeBlocks(Block *source, Block *dest,
1337                                             ValueRange argValues) {
1338   impl->notifyBlocksBeingMerged(dest, source);
1339   assert(llvm::all_of(source->getPredecessors(),
1340                       [dest](Block *succ) { return succ == dest; }) &&
1341          "expected 'source' to have no predecessors or only 'dest'");
1342   assert(argValues.size() == source->getNumArguments() &&
1343          "incorrect # of argument replacement values");
1344   for (auto it : llvm::zip(source->getArguments(), argValues))
1345     replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1346   dest->getOperations().splice(dest->end(), source->getOperations());
1347   eraseBlock(source);
1348 }
1349 
1350 /// PatternRewriter hook for moving blocks out of a region.
1351 void ConversionPatternRewriter::inlineRegionBefore(Region &region,
1352                                                    Region &parent,
1353                                                    Region::iterator before) {
1354   impl->notifyRegionIsBeingInlinedBefore(region, parent, before);
1355   PatternRewriter::inlineRegionBefore(region, parent, before);
1356 }
1357 
1358 /// PatternRewriter hook for cloning blocks of one region into another.
1359 void ConversionPatternRewriter::cloneRegionBefore(
1360     Region &region, Region &parent, Region::iterator before,
1361     BlockAndValueMapping &mapping) {
1362   if (region.empty())
1363     return;
1364   PatternRewriter::cloneRegionBefore(region, parent, before, mapping);
1365 
1366   // Collect the range of the cloned blocks.
1367   auto clonedBeginIt = mapping.lookup(&region.front())->getIterator();
1368   auto clonedBlocks = llvm::make_range(clonedBeginIt, before);
1369   impl->notifyRegionWasClonedBefore(clonedBlocks, region.getLoc());
1370 }
1371 
1372 /// PatternRewriter hook for creating a new operation.
1373 void ConversionPatternRewriter::notifyOperationInserted(Operation *op) {
1374   LLVM_DEBUG({
1375     impl->logger.startLine()
1376         << "** Insert  : '" << op->getName() << "'(" << op << ")\n";
1377   });
1378   impl->createdOps.push_back(op);
1379 }
1380 
1381 /// PatternRewriter hook for updating the root operation in-place.
1382 void ConversionPatternRewriter::startRootUpdate(Operation *op) {
1383 #ifndef NDEBUG
1384   impl->pendingRootUpdates.insert(op);
1385 #endif
1386   impl->rootUpdates.emplace_back(op);
1387 }
1388 
1389 /// PatternRewriter hook for updating the root operation in-place.
1390 void ConversionPatternRewriter::finalizeRootUpdate(Operation *op) {
1391   // There is nothing to do here, we only need to track the operation at the
1392   // start of the update.
1393 #ifndef NDEBUG
1394   assert(impl->pendingRootUpdates.erase(op) &&
1395          "operation did not have a pending in-place update");
1396 #endif
1397 }
1398 
1399 /// PatternRewriter hook for updating the root operation in-place.
1400 void ConversionPatternRewriter::cancelRootUpdate(Operation *op) {
1401 #ifndef NDEBUG
1402   assert(impl->pendingRootUpdates.erase(op) &&
1403          "operation did not have a pending in-place update");
1404 #endif
1405   // Erase the last update for this operation.
1406   auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; };
1407   auto &rootUpdates = impl->rootUpdates;
1408   auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp);
1409   rootUpdates.erase(rootUpdates.begin() + (rootUpdates.rend() - it));
1410 }
1411 
1412 /// PatternRewriter hook for notifying match failure reasons.
1413 LogicalResult ConversionPatternRewriter::notifyMatchFailure(
1414     Operation *op, function_ref<void(Diagnostic &)> reasonCallback) {
1415   return impl->notifyMatchFailure(op->getLoc(), reasonCallback);
1416 }
1417 
1418 /// Return a reference to the internal implementation.
1419 detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
1420   return *impl;
1421 }
1422 
1423 //===----------------------------------------------------------------------===//
1424 // ConversionPattern
1425 //===----------------------------------------------------------------------===//
1426 
1427 /// Attempt to match and rewrite the IR root at the specified operation.
1428 LogicalResult
1429 ConversionPattern::matchAndRewrite(Operation *op,
1430                                    PatternRewriter &rewriter) const {
1431   auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
1432   auto &rewriterImpl = dialectRewriter.getImpl();
1433 
1434   // Track the current conversion pattern in the rewriter.
1435   assert(!rewriterImpl.currentConversionPattern &&
1436          "already inside of a pattern rewrite");
1437   llvm::SaveAndRestore<const ConversionPattern *> currentPatternGuard(
1438       rewriterImpl.currentConversionPattern, this);
1439 
1440   // Remap the operands of the operation.
1441   SmallVector<Value, 4> operands;
1442   if (failed(rewriterImpl.remapValues(op->getLoc(), rewriter,
1443                                       getTypeConverter(), op->getOperands(),
1444                                       operands))) {
1445     return failure();
1446   }
1447   return matchAndRewrite(op, operands, dialectRewriter);
1448 }
1449 
1450 //===----------------------------------------------------------------------===//
1451 // OperationLegalizer
1452 //===----------------------------------------------------------------------===//
1453 
namespace {
/// A set of rewrite patterns that can be used to legalize a given operation.
using LegalizationPatterns = SmallVector<const Pattern *, 1>;

/// This class defines a recursive operation legalizer. An operation is accepted
/// as-is when the target marks it legal or when it was marked ignored during
/// conversion; otherwise the legalizer first tries to fold it and then to
/// rewrite it with one of the provided patterns, recursively legalizing the IR
/// produced by those steps. The order in which patterns are tried is computed
/// once, at construction, by a cost model over the provided patterns.
class OperationLegalizer {
public:
  using LegalizationAction = ConversionTarget::LegalizationAction;

  /// Create a legalizer for 'targetInfo'. Construction analyzes 'patterns' to
  /// find which of them can lead to legal IR and orders them by the cost model
  /// (see `computeLegalizationGraphBenefit`).
  OperationLegalizer(ConversionTarget &targetInfo,
                     const FrozenRewritePatternList &patterns);

  /// Returns true if the given operation is known to be illegal on the target,
  /// i.e. the target explicitly marked it illegal. Operations the target has no
  /// action for are not reported as illegal.
  bool isIllegal(Operation *op) const;

  /// Attempt to legalize the given operation. Returns success if the operation
  /// was legalized, failure otherwise.
  LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);

  /// Returns the conversion target in use by the legalizer.
  ConversionTarget &getTarget() { return target; }

private:
  /// Attempt to legalize the given operation by folding it. Operations created
  /// by the fold are legalized recursively; if one of them cannot be, the
  /// rewrites made since the fold started are discarded.
  LogicalResult legalizeWithFold(Operation *op,
                                 ConversionPatternRewriter &rewriter);

  /// Attempt to legalize the given operation by applying a pattern. Returns
  /// success if the operation was legalized, failure otherwise. The rewrites
  /// of a pattern that fails to match, or whose result cannot be legalized,
  /// are rolled back before the next pattern is tried.
  LogicalResult legalizeWithPattern(Operation *op,
                                    ConversionPatternRewriter &rewriter);

  /// Return true if the given pattern may be applied to the given operation,
  /// false otherwise. A pattern without bounded rewrite recursion is refused
  /// while it is already being applied higher up the legalization stack.
  bool canApplyPattern(Operation *op, const Pattern &pattern,
                       ConversionPatternRewriter &rewriter);

  /// Legalize the resultant IR after successfully applying the given pattern.
  /// 'curState' is the rewriter state captured before the pattern was applied.
  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
                                      ConversionPatternRewriter &rewriter,
                                      RewriterState &curState);

  /// Legalizes the actions registered during the execution of a pattern, i.e.
  /// everything recorded between 'state' (before the pattern) and 'newState'
  /// (after it):
  ///  - legalizePatternBlockActions: blocks moved or created by the pattern,
  ///    by converting their signature or re-legalizing their parent op.
  ///  - legalizePatternCreatedOperations: operations created by the pattern.
  ///  - legalizePatternRootUpdates: operations updated in place.
  LogicalResult legalizePatternBlockActions(Operation *op,
                                            ConversionPatternRewriter &rewriter,
                                            ConversionPatternRewriterImpl &impl,
                                            RewriterState &state,
                                            RewriterState &newState);
  LogicalResult legalizePatternCreatedOperations(
      ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
      RewriterState &state, RewriterState &newState);
  LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
                                           ConversionPatternRewriterImpl &impl,
                                           RewriterState &state,
                                           RewriterState &newState);

  //===--------------------------------------------------------------------===//
  // Cost Model
  //===--------------------------------------------------------------------===//

  /// Build an optimistic legalization graph given the provided patterns. This
  /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
  /// patterns for operations that are not directly legal, but may be
  /// transitively legal for the current target given the provided patterns.
  void buildLegalizationGraph(
      LegalizationPatterns &anyOpLegalizerPatterns,
      DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);

  /// Compute the benefit of each node within the computed legalization graph.
  /// This orders the patterns within 'legalizerPatterns' based upon two
  /// criteria:
  ///  1) Prefer patterns that have the lowest legalization depth, i.e.
  ///     represent the more direct mapping to the target.
  ///  2) When comparing patterns with the same legalization depth, prefer the
  ///     pattern with the highest PatternBenefit. This allows for users to
  ///     prefer specific legalizations over others.
  void computeLegalizationGraphBenefit(
      LegalizationPatterns &anyOpLegalizerPatterns,
      DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);

  /// Compute the legalization depth when legalizing an operation of the given
  /// type. Results are memoized in 'minOpPatternDepth'; operations without
  /// legalization patterns have depth 0.
  unsigned computeOpLegalizationDepth(
      OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
      DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);

  /// Apply the conversion cost model to the given set of patterns, and return
  /// the smallest legalization depth of any of the patterns. See
  /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
  unsigned applyCostModelToPatterns(
      LegalizationPatterns &patterns,
      DenseMap<OperationName, unsigned> &minOpPatternDepth,
      DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);

  /// The current set of patterns that have been applied. A pattern is added in
  /// `canApplyPattern` and removed once its application finishes, which keeps
  /// patterns without bounded recursion from re-entering themselves.
  SmallPtrSet<const Pattern *, 8> appliedPatterns;

  /// The legalization information provided by the target.
  ConversionTarget &target;

  /// The pattern applicator to use for conversions.
  PatternApplicator applicator;
};
} // namespace
1558 
1559 OperationLegalizer::OperationLegalizer(ConversionTarget &targetInfo,
1560                                        const FrozenRewritePatternList &patterns)
1561     : target(targetInfo), applicator(patterns) {
1562   // The set of patterns that can be applied to illegal operations to transform
1563   // them into legal ones.
1564   DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
1565   LegalizationPatterns anyOpLegalizerPatterns;
1566 
1567   buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
1568   computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
1569 }
1570 
1571 bool OperationLegalizer::isIllegal(Operation *op) const {
1572   // Check if the target explicitly marked this operation as illegal.
1573   return target.getOpAction(op->getName()) == LegalizationAction::Illegal;
1574 }
1575 
1576 LogicalResult
1577 OperationLegalizer::legalize(Operation *op,
1578                              ConversionPatternRewriter &rewriter) {
1579 #ifndef NDEBUG
1580   const char *logLineComment =
1581       "//===-------------------------------------------===//\n";
1582 
1583   auto &rewriterImpl = rewriter.getImpl();
1584 #endif
1585   LLVM_DEBUG({
1586     auto &os = rewriterImpl.logger;
1587     os.getOStream() << "\n";
1588     os.startLine() << logLineComment;
1589     os.startLine() << "Legalizing operation : '" << op->getName() << "'(" << op
1590                    << ") {\n";
1591     os.indent();
1592 
1593     // If the operation has no regions, just print it here.
1594     if (op->getNumRegions() == 0) {
1595       op->print(os.startLine(), OpPrintingFlags().printGenericOpForm());
1596       os.getOStream() << "\n\n";
1597     }
1598   });
1599 
1600   // Check if this operation is legal on the target.
1601   if (auto legalityInfo = target.isLegal(op)) {
1602     LLVM_DEBUG({
1603       logSuccess(
1604           rewriterImpl.logger, "operation marked legal by the target{0}",
1605           legalityInfo->isRecursivelyLegal
1606               ? "; NOTE: operation is recursively legal; skipping internals"
1607               : "");
1608       rewriterImpl.logger.startLine() << logLineComment;
1609     });
1610 
1611     // If this operation is recursively legal, mark its children as ignored so
1612     // that we don't consider them for legalization.
1613     if (legalityInfo->isRecursivelyLegal)
1614       rewriter.getImpl().markNestedOpsIgnored(op);
1615     return success();
1616   }
1617 
1618   // Check to see if the operation is ignored and doesn't need to be converted.
1619   if (rewriter.getImpl().isOpIgnored(op)) {
1620     LLVM_DEBUG({
1621       logSuccess(rewriterImpl.logger,
1622                  "operation marked 'ignored' during conversion");
1623       rewriterImpl.logger.startLine() << logLineComment;
1624     });
1625     return success();
1626   }
1627 
1628   // If the operation isn't legal, try to fold it in-place.
1629   // TODO: Should we always try to do this, even if the op is
1630   // already legal?
1631   if (succeeded(legalizeWithFold(op, rewriter))) {
1632     LLVM_DEBUG({
1633       logSuccess(rewriterImpl.logger, "operation was folded");
1634       rewriterImpl.logger.startLine() << logLineComment;
1635     });
1636     return success();
1637   }
1638 
1639   // Otherwise, we need to apply a legalization pattern to this operation.
1640   if (succeeded(legalizeWithPattern(op, rewriter))) {
1641     LLVM_DEBUG({
1642       logSuccess(rewriterImpl.logger, "");
1643       rewriterImpl.logger.startLine() << logLineComment;
1644     });
1645     return success();
1646   }
1647 
1648   LLVM_DEBUG({
1649     logFailure(rewriterImpl.logger, "no matched legalization pattern");
1650     rewriterImpl.logger.startLine() << logLineComment;
1651   });
1652   return failure();
1653 }
1654 
1655 LogicalResult
1656 OperationLegalizer::legalizeWithFold(Operation *op,
1657                                      ConversionPatternRewriter &rewriter) {
1658   auto &rewriterImpl = rewriter.getImpl();
1659   RewriterState curState = rewriterImpl.getCurrentState();
1660 
1661   LLVM_DEBUG({
1662     rewriterImpl.logger.startLine() << "* Fold {\n";
1663     rewriterImpl.logger.indent();
1664   });
1665 
1666   // Try to fold the operation.
1667   SmallVector<Value, 2> replacementValues;
1668   rewriter.setInsertionPoint(op);
1669   if (failed(rewriter.tryFold(op, replacementValues))) {
1670     LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
1671     return failure();
1672   }
1673 
1674   // Insert a replacement for 'op' with the folded replacement values.
1675   rewriter.replaceOp(op, replacementValues);
1676 
1677   // Recursively legalize any new constant operations.
1678   for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
1679        i != e; ++i) {
1680     Operation *cstOp = rewriterImpl.createdOps[i];
1681     if (failed(legalize(cstOp, rewriter))) {
1682       LLVM_DEBUG(logFailure(rewriterImpl.logger,
1683                             "generated constant '{0}' was illegal",
1684                             cstOp->getName()));
1685       rewriterImpl.resetState(curState);
1686       return failure();
1687     }
1688   }
1689 
1690   LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
1691   return success();
1692 }
1693 
1694 LogicalResult
1695 OperationLegalizer::legalizeWithPattern(Operation *op,
1696                                         ConversionPatternRewriter &rewriter) {
1697   auto &rewriterImpl = rewriter.getImpl();
1698 
1699   // Functor that returns if the given pattern may be applied.
1700   auto canApply = [&](const Pattern &pattern) {
1701     return canApplyPattern(op, pattern, rewriter);
1702   };
1703 
1704   // Functor that cleans up the rewriter state after a pattern failed to match.
1705   RewriterState curState = rewriterImpl.getCurrentState();
1706   auto onFailure = [&](const Pattern &pattern) {
1707     LLVM_DEBUG(logFailure(rewriterImpl.logger, "pattern failed to match"));
1708     rewriterImpl.resetState(curState);
1709     appliedPatterns.erase(&pattern);
1710   };
1711 
1712   // Functor that performs additional legalization when a pattern is
1713   // successfully applied.
1714   auto onSuccess = [&](const Pattern &pattern) {
1715     auto result = legalizePatternResult(op, pattern, rewriter, curState);
1716     appliedPatterns.erase(&pattern);
1717     if (failed(result))
1718       rewriterImpl.resetState(curState);
1719     return result;
1720   };
1721 
1722   // Try to match and rewrite a pattern on this operation.
1723   return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
1724                                     onSuccess);
1725 }
1726 
1727 bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
1728                                          ConversionPatternRewriter &rewriter) {
1729   LLVM_DEBUG({
1730     auto &os = rewriter.getImpl().logger;
1731     os.getOStream() << "\n";
1732     os.startLine() << "* Pattern : '" << op->getName() << " -> (";
1733     llvm::interleaveComma(pattern.getGeneratedOps(), llvm::dbgs());
1734     os.getOStream() << ")' {\n";
1735     os.indent();
1736   });
1737 
1738   // Ensure that we don't cycle by not allowing the same pattern to be
1739   // applied twice in the same recursion stack if it is not known to be safe.
1740   if (!pattern.hasBoundedRewriteRecursion() &&
1741       !appliedPatterns.insert(&pattern).second) {
1742     LLVM_DEBUG(
1743         logFailure(rewriter.getImpl().logger, "pattern was already applied"));
1744     return false;
1745   }
1746   return true;
1747 }
1748 
1749 LogicalResult
1750 OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
1751                                           ConversionPatternRewriter &rewriter,
1752                                           RewriterState &curState) {
1753   auto &impl = rewriter.getImpl();
1754 
1755 #ifndef NDEBUG
1756   assert(impl.pendingRootUpdates.empty() && "dangling root updates");
1757 #endif
1758 
1759   // Check that the root was either replaced or updated in place.
1760   auto replacedRoot = [&] {
1761     return llvm::any_of(
1762         llvm::drop_begin(impl.replacements, curState.numReplacements),
1763         [op](auto &it) { return it.first == op; });
1764   };
1765   auto updatedRootInPlace = [&] {
1766     return llvm::any_of(
1767         llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates),
1768         [op](auto &state) { return state.getOperation() == op; });
1769   };
1770   (void)replacedRoot;
1771   (void)updatedRootInPlace;
1772   assert((replacedRoot() || updatedRootInPlace()) &&
1773          "expected pattern to replace the root operation");
1774 
1775   // Legalize each of the actions registered during application.
1776   RewriterState newState = impl.getCurrentState();
1777   if (failed(legalizePatternBlockActions(op, rewriter, impl, curState,
1778                                          newState)) ||
1779       failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
1780       failed(legalizePatternCreatedOperations(rewriter, impl, curState,
1781                                               newState))) {
1782     return failure();
1783   }
1784 
1785   LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
1786   return success();
1787 }
1788 
1789 LogicalResult OperationLegalizer::legalizePatternBlockActions(
1790     Operation *op, ConversionPatternRewriter &rewriter,
1791     ConversionPatternRewriterImpl &impl, RewriterState &state,
1792     RewriterState &newState) {
1793   SmallPtrSet<Operation *, 16> operationsToIgnore;
1794 
1795   // If the pattern moved or created any blocks, make sure the types of block
1796   // arguments get legalized.
1797   for (int i = state.numBlockActions, e = newState.numBlockActions; i != e;
1798        ++i) {
1799     auto &action = impl.blockActions[i];
1800     if (action.kind == BlockActionKind::TypeConversion ||
1801         action.kind == BlockActionKind::Erase)
1802       continue;
1803     // Only check blocks outside of the current operation.
1804     Operation *parentOp = action.block->getParentOp();
1805     if (!parentOp || parentOp == op || action.block->getNumArguments() == 0)
1806       continue;
1807 
1808     // If the region of the block has a type converter, try to convert the block
1809     // directly.
1810     if (auto *converter =
1811             impl.argConverter.getConverter(action.block->getParent())) {
1812       if (failed(impl.convertBlockSignature(action.block, *converter))) {
1813         LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
1814                                            "block"));
1815         return failure();
1816       }
1817       continue;
1818     }
1819 
1820     // Otherwise, check that this operation isn't one generated by this pattern.
1821     // This is because we will attempt to legalize the parent operation, and
1822     // blocks in regions created by this pattern will already be legalized later
1823     // on. If we haven't built the set yet, build it now.
1824     if (operationsToIgnore.empty()) {
1825       auto createdOps = ArrayRef<Operation *>(impl.createdOps)
1826                             .drop_front(state.numCreatedOps);
1827       operationsToIgnore.insert(createdOps.begin(), createdOps.end());
1828     }
1829 
1830     // If this operation should be considered for re-legalization, try it.
1831     if (operationsToIgnore.insert(parentOp).second &&
1832         failed(legalize(parentOp, rewriter))) {
1833       LLVM_DEBUG(logFailure(
1834           impl.logger, "operation '{0}'({1}) became illegal after block action",
1835           parentOp->getName(), parentOp));
1836       return failure();
1837     }
1838   }
1839   return success();
1840 }
1841 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
1842     ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
1843     RewriterState &state, RewriterState &newState) {
1844   for (int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) {
1845     Operation *op = impl.createdOps[i];
1846     if (failed(legalize(op, rewriter))) {
1847       LLVM_DEBUG(logFailure(impl.logger,
1848                             "generated operation '{0}'({1}) was illegal",
1849                             op->getName(), op));
1850       return failure();
1851     }
1852   }
1853   return success();
1854 }
1855 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
1856     ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
1857     RewriterState &state, RewriterState &newState) {
1858   for (int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) {
1859     Operation *op = impl.rootUpdates[i].getOperation();
1860     if (failed(legalize(op, rewriter))) {
1861       LLVM_DEBUG(logFailure(impl.logger,
1862                             "operation updated in-place '{0}' was illegal",
1863                             op->getName()));
1864       return failure();
1865     }
1866   }
1867   return success();
1868 }
1869 
1870 //===----------------------------------------------------------------------===//
1871 // Cost Model
1872 
1873 void OperationLegalizer::buildLegalizationGraph(
1874     LegalizationPatterns &anyOpLegalizerPatterns,
1875     DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
1876   // A mapping between an operation and a set of operations that can be used to
1877   // generate it.
1878   DenseMap<OperationName, SmallPtrSet<OperationName, 2>> parentOps;
1879   // A mapping between an operation and any currently invalid patterns it has.
1880   DenseMap<OperationName, SmallPtrSet<const Pattern *, 2>> invalidPatterns;
1881   // A worklist of patterns to consider for legality.
1882   llvm::SetVector<const Pattern *> patternWorklist;
1883 
1884   // Build the mapping from operations to the parent ops that may generate them.
1885   applicator.walkAllPatterns([&](const Pattern &pattern) {
1886     Optional<OperationName> root = pattern.getRootKind();
1887 
1888     // If the pattern has no specific root, we can't analyze the relationship
1889     // between the root op and generated operations. Given that, add all such
1890     // patterns to the legalization set.
1891     if (!root) {
1892       anyOpLegalizerPatterns.push_back(&pattern);
1893       return;
1894     }
1895 
1896     // Skip operations that are always known to be legal.
1897     if (target.getOpAction(*root) == LegalizationAction::Legal)
1898       return;
1899 
1900     // Add this pattern to the invalid set for the root op and record this root
1901     // as a parent for any generated operations.
1902     invalidPatterns[*root].insert(&pattern);
1903     for (auto op : pattern.getGeneratedOps())
1904       parentOps[op].insert(*root);
1905 
1906     // Add this pattern to the worklist.
1907     patternWorklist.insert(&pattern);
1908   });
1909 
1910   // If there are any patterns that don't have a specific root kind, we can't
1911   // make direct assumptions about what operations will never be legalized.
1912   // Note: Technically we could, but it would require an analysis that may
1913   // recurse into itself. It would be better to perform this kind of filtering
1914   // at a higher level than here anyways.
1915   if (!anyOpLegalizerPatterns.empty()) {
1916     for (const Pattern *pattern : patternWorklist)
1917       legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
1918     return;
1919   }
1920 
1921   while (!patternWorklist.empty()) {
1922     auto *pattern = patternWorklist.pop_back_val();
1923 
1924     // Check to see if any of the generated operations are invalid.
1925     if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
1926           Optional<LegalizationAction> action = target.getOpAction(op);
1927           return !legalizerPatterns.count(op) &&
1928                  (!action || action == LegalizationAction::Illegal);
1929         }))
1930       continue;
1931 
1932     // Otherwise, if all of the generated operation are valid, this op is now
1933     // legal so add all of the child patterns to the worklist.
1934     legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
1935     invalidPatterns[*pattern->getRootKind()].erase(pattern);
1936 
1937     // Add any invalid patterns of the parent operations to see if they have now
1938     // become legal.
1939     for (auto op : parentOps[*pattern->getRootKind()])
1940       patternWorklist.set_union(invalidPatterns[op]);
1941   }
1942 }
1943 
1944 void OperationLegalizer::computeLegalizationGraphBenefit(
1945     LegalizationPatterns &anyOpLegalizerPatterns,
1946     DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
1947   // The smallest pattern depth, when legalizing an operation.
1948   DenseMap<OperationName, unsigned> minOpPatternDepth;
1949 
1950   // For each operation that is transitively legal, compute a cost for it.
1951   for (auto &opIt : legalizerPatterns)
1952     if (!minOpPatternDepth.count(opIt.first))
1953       computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
1954                                  legalizerPatterns);
1955 
1956   // Apply the cost model to the patterns that can match any operation. Those
1957   // with a specific operation type are already resolved when computing the op
1958   // legalization depth.
1959   if (!anyOpLegalizerPatterns.empty())
1960     applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
1961                              legalizerPatterns);
1962 
1963   // Apply a cost model to the pattern applicator. We order patterns first by
1964   // depth then benefit. `legalizerPatterns` contains per-op patterns by
1965   // decreasing benefit.
1966   applicator.applyCostModel([&](const Pattern &pattern) {
1967     ArrayRef<const Pattern *> orderedPatternList;
1968     if (Optional<OperationName> rootName = pattern.getRootKind())
1969       orderedPatternList = legalizerPatterns[*rootName];
1970     else
1971       orderedPatternList = anyOpLegalizerPatterns;
1972 
1973     // If the pattern is not found, then it was removed and cannot be matched.
1974     auto it = llvm::find(orderedPatternList, &pattern);
1975     if (it == orderedPatternList.end())
1976       return PatternBenefit::impossibleToMatch();
1977 
1978     // Patterns found earlier in the list have higher benefit.
1979     return PatternBenefit(std::distance(it, orderedPatternList.end()));
1980   });
1981 }
1982 
1983 unsigned OperationLegalizer::computeOpLegalizationDepth(
1984     OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
1985     DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
1986   // Check for existing depth.
1987   auto depthIt = minOpPatternDepth.find(op);
1988   if (depthIt != minOpPatternDepth.end())
1989     return depthIt->second;
1990 
1991   // If a mapping for this operation does not exist, then this operation
1992   // is always legal. Return 0 as the depth for a directly legal operation.
1993   auto opPatternsIt = legalizerPatterns.find(op);
1994   if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
1995     return 0u;
1996 
1997   // Record this initial depth in case we encounter this op again when
1998   // recursively computing the depth.
1999   minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
2000 
2001   // Apply the cost model to the operation patterns, and update the minimum
2002   // depth.
2003   unsigned minDepth = applyCostModelToPatterns(
2004       opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2005   minOpPatternDepth[op] = minDepth;
2006   return minDepth;
2007 }
2008 
2009 unsigned OperationLegalizer::applyCostModelToPatterns(
2010     LegalizationPatterns &patterns,
2011     DenseMap<OperationName, unsigned> &minOpPatternDepth,
2012     DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2013   unsigned minDepth = std::numeric_limits<unsigned>::max();
2014 
2015   // Compute the depth for each pattern within the set.
2016   SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
2017   patternsByDepth.reserve(patterns.size());
2018   for (const Pattern *pattern : patterns) {
2019     unsigned depth = 0;
2020     for (auto generatedOp : pattern->getGeneratedOps()) {
2021       unsigned generatedOpDepth = computeOpLegalizationDepth(
2022           generatedOp, minOpPatternDepth, legalizerPatterns);
2023       depth = std::max(depth, generatedOpDepth + 1);
2024     }
2025     patternsByDepth.emplace_back(pattern, depth);
2026 
2027     // Update the minimum depth of the pattern list.
2028     minDepth = std::min(minDepth, depth);
2029   }
2030 
2031   // If the operation only has one legalization pattern, there is no need to
2032   // sort them.
2033   if (patternsByDepth.size() == 1)
2034     return minDepth;
2035 
2036   // Sort the patterns by those likely to be the most beneficial.
2037   llvm::array_pod_sort(patternsByDepth.begin(), patternsByDepth.end(),
2038                        [](const std::pair<const Pattern *, unsigned> *lhs,
2039                           const std::pair<const Pattern *, unsigned> *rhs) {
2040                          // First sort by the smaller pattern legalization
2041                          // depth.
2042                          if (lhs->second != rhs->second)
2043                            return llvm::array_pod_sort_comparator<unsigned>(
2044                                &lhs->second, &rhs->second);
2045 
2046                          // Then sort by the larger pattern benefit.
2047                          auto lhsBenefit = lhs->first->getBenefit();
2048                          auto rhsBenefit = rhs->first->getBenefit();
2049                          return llvm::array_pod_sort_comparator<PatternBenefit>(
2050                              &rhsBenefit, &lhsBenefit);
2051                        });
2052 
2053   // Update the legalization pattern to use the new sorted list.
2054   patterns.clear();
2055   for (auto &patternIt : patternsByDepth)
2056     patterns.push_back(patternIt.first);
2057   return minDepth;
2058 }
2059 
2060 //===----------------------------------------------------------------------===//
2061 // OperationConverter
2062 //===----------------------------------------------------------------------===//
namespace {
/// Controls how OperationConverter reacts when an operation cannot be
/// legalized, and whether the resulting rewrites are kept.
enum OpConversionMode {
  // In this mode, the conversion will ignore failed conversions to allow
  // illegal operations to co-exist in the IR. Operations explicitly marked
  // illegal by the target still cause the conversion to fail.
  Partial,

  // In this mode, all operations must be legal for the given target for the
  // conversion to succeed.
  Full,

  // In this mode, operations are analyzed for legality. No actual rewrites are
  // applied to the operations on success; the operations that could be
  // legalized are recorded instead.
  Analysis,
};

// This class converts operations to a given conversion target via a set of
// rewrite patterns. The conversion behaves differently depending on the
// conversion mode.
struct OperationConverter {
  /// 'trackedOps' is optional; see the member documentation below for what it
  /// collects in each mode.
  explicit OperationConverter(ConversionTarget &target,
                              const FrozenRewritePatternList &patterns,
                              OpConversionMode mode,
                              DenseSet<Operation *> *trackedOps = nullptr)
      : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}

  /// Converts the given operations, and the operations nested within them, to
  /// the conversion target. On failure every rewrite is discarded.
  LogicalResult convertOperations(ArrayRef<Operation *> ops);

private:
  /// Converts an operation with the given rewriter. Returns failure only when
  /// the legalization failure is fatal for the current mode.
  LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);

  /// This method is called after the conversion process to legalize any
  /// remaining artifacts and complete the conversion.
  LogicalResult finalize(ConversionPatternRewriter &rewriter);

  /// Legalize the types of converted block arguments.
  LogicalResult
  legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
                                 ConversionPatternRewriterImpl &rewriterImpl);

  /// Legalize an operation result that was marked as "erased".
  LogicalResult
  legalizeErasedResult(Operation *op, OpResult result,
                       ConversionPatternRewriterImpl &rewriterImpl);

  /// Legalize an operation result that was replaced with a value of a different
  /// type.
  LogicalResult
  legalizeChangedResultType(Operation *op, OpResult result, Value newValue,
                            TypeConverter *replConverter,
                            ConversionPatternRewriter &rewriter,
                            ConversionPatternRewriterImpl &rewriterImpl);

  /// The legalizer to use when converting operations.
  OperationLegalizer opLegalizer;

  /// The conversion mode to use when legalizing operations.
  OpConversionMode mode;

  /// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
  /// this is populated with ops found to be legalizable to the target.
  /// When mode == OpConversionMode::Partial, this is populated with ops found
  /// *not* to be legalizable to the target.
  DenseSet<Operation *> *trackedOps;
};
} // end anonymous namespace
2130 
2131 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2132                                           Operation *op) {
2133   // Legalize the given operation.
2134   if (failed(opLegalizer.legalize(op, rewriter))) {
2135     // Handle the case of a failed conversion for each of the different modes.
2136     // Full conversions expect all operations to be converted.
2137     if (mode == OpConversionMode::Full)
2138       return op->emitError()
2139              << "failed to legalize operation '" << op->getName() << "'";
2140     // Partial conversions allow conversions to fail iff the operation was not
2141     // explicitly marked as illegal. If the user provided a nonlegalizableOps
2142     // set, non-legalizable ops are included.
2143     if (mode == OpConversionMode::Partial) {
2144       if (opLegalizer.isIllegal(op))
2145         return op->emitError()
2146                << "failed to legalize operation '" << op->getName()
2147                << "' that was explicitly marked illegal";
2148       if (trackedOps)
2149         trackedOps->insert(op);
2150     }
2151   } else if (mode == OpConversionMode::Analysis) {
2152     // Analysis conversions don't fail if any operations fail to legalize,
2153     // they are only interested in the operations that were successfully
2154     // legalized.
2155     trackedOps->insert(op);
2156   }
2157   return success();
2158 }
2159 
2160 LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
2161   if (ops.empty())
2162     return success();
2163   ConversionTarget &target = opLegalizer.getTarget();
2164 
2165   // Compute the set of operations and blocks to convert.
2166   std::vector<Operation *> toConvert;
2167   for (auto *op : ops) {
2168     toConvert.emplace_back(op);
2169     for (auto &region : op->getRegions())
2170       if (failed(computeConversionSet(region.getBlocks(), region.getLoc(),
2171                                       toConvert, &target)))
2172         return failure();
2173   }
2174 
2175   // Convert each operation and discard rewrites on failure.
2176   ConversionPatternRewriter rewriter(ops.front()->getContext());
2177   ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2178   for (auto *op : toConvert)
2179     if (failed(convert(rewriter, op)))
2180       return rewriterImpl.discardRewrites(), failure();
2181 
2182   // Now that all of the operations have been converted, finalize the conversion
2183   // process to ensure any lingering conversion artifacts are cleaned up and
2184   // legalized.
2185   if (failed(finalize(rewriter)))
2186     return rewriterImpl.discardRewrites(), failure();
2187 
2188   // After a successful conversion, apply rewrites if this is not an analysis
2189   // conversion.
2190   if (mode == OpConversionMode::Analysis)
2191     rewriterImpl.discardRewrites();
2192   else
2193     rewriterImpl.applyRewrites();
2194   return success();
2195 }
2196 
2197 LogicalResult
2198 OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
2199   ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2200 
2201   // Legalize converted block arguments.
2202   if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
2203     return failure();
2204 
2205   // Process requested operation replacements.
2206   for (unsigned i = 0, e = rewriterImpl.operationsWithChangedResults.size();
2207        i != e; ++i) {
2208     unsigned replIdx = rewriterImpl.operationsWithChangedResults[i];
2209     auto &repl = *(rewriterImpl.replacements.begin() + replIdx);
2210     for (OpResult result : repl.first->getResults()) {
2211       Value newValue = rewriterImpl.mapping.lookupOrNull(result);
2212 
2213       // If the operation result was replaced with null, all of the uses of this
2214       // value should be replaced.
2215       if (!newValue) {
2216         if (failed(legalizeErasedResult(repl.first, result, rewriterImpl)))
2217           return failure();
2218         continue;
2219       }
2220 
2221       // Otherwise, check to see if the type of the result changed.
2222       if (result.getType() == newValue.getType())
2223         continue;
2224 
2225       // Legalize this result.
2226       rewriter.setInsertionPoint(repl.first);
2227       if (failed(legalizeChangedResultType(repl.first, result, newValue,
2228                                            repl.second.converter, rewriter,
2229                                            rewriterImpl)))
2230         return failure();
2231 
2232       // Update the end iterator for this loop in the case it was updated
2233       // when legalizing generated conversion operations.
2234       e = rewriterImpl.operationsWithChangedResults.size();
2235     }
2236   }
2237   return success();
2238 }
2239 
2240 LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
2241     ConversionPatternRewriter &rewriter,
2242     ConversionPatternRewriterImpl &rewriterImpl) {
2243   // Functor used to check if all users of a value will be dead after
2244   // conversion.
2245   auto findLiveUser = [&](Value val) {
2246     auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
2247       return rewriterImpl.isOpIgnored(user);
2248     });
2249     return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
2250   };
2251 
2252   // Materialize any necessary conversions for converted block arguments that
2253   // are still live.
2254   size_t numCreatedOps = rewriterImpl.createdOps.size();
2255   if (failed(rewriterImpl.argConverter.materializeLiveConversions(
2256           rewriterImpl.mapping, rewriter, findLiveUser)))
2257     return failure();
2258 
2259   // Legalize any newly created operations during argument materialization.
2260   for (int i : llvm::seq<int>(numCreatedOps, rewriterImpl.createdOps.size())) {
2261     if (failed(opLegalizer.legalize(rewriterImpl.createdOps[i], rewriter))) {
2262       return rewriterImpl.createdOps[i]->emitError()
2263              << "failed to legalize conversion operation generated for block "
2264                 "argument that remained live after conversion";
2265     }
2266   }
2267   return success();
2268 }
2269 
2270 LogicalResult OperationConverter::legalizeErasedResult(
2271     Operation *op, OpResult result,
2272     ConversionPatternRewriterImpl &rewriterImpl) {
2273   // If the operation result was replaced with null, all of the uses of this
2274   // value should be replaced.
2275   auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
2276     return rewriterImpl.isOpIgnored(user);
2277   });
2278   if (liveUserIt != result.user_end()) {
2279     InFlightDiagnostic diag = op->emitError("failed to legalize operation '")
2280                               << op->getName() << "' marked as erased";
2281     diag.attachNote(liveUserIt->getLoc())
2282         << "found live user of result #" << result.getResultNumber() << ": "
2283         << *liveUserIt;
2284     return failure();
2285   }
2286   return success();
2287 }
2288 
2289 LogicalResult OperationConverter::legalizeChangedResultType(
2290     Operation *op, OpResult result, Value newValue,
2291     TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
2292     ConversionPatternRewriterImpl &rewriterImpl) {
2293   // Walk the users of this value to see if there are any live users that
2294   // weren't replaced during conversion.
2295   auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
2296     return rewriterImpl.isOpIgnored(user);
2297   });
2298   if (liveUserIt == result.user_end())
2299     return success();
2300 
2301   // If the replacement has a type converter, attempt to materialize a
2302   // conversion back to the original type.
2303   if (!replConverter) {
2304     // TODO: We should emit an error here, similarly to the case where the
2305     // result is replaced with null. Unfortunately a lot of existing
2306     // patterns rely on this behavior, so until those patterns are updated
2307     // we keep the legacy behavior here of just forwarding the new value.
2308     return success();
2309   }
2310 
2311   // Track the number of created operations so that new ones can be legalized.
2312   size_t numCreatedOps = rewriterImpl.createdOps.size();
2313 
2314   // Materialize a conversion for this live result value.
2315   Type resultType = result.getType();
2316   Value convertedValue = replConverter->materializeSourceConversion(
2317       rewriter, op->getLoc(), resultType, newValue);
2318   if (!convertedValue) {
2319     InFlightDiagnostic diag = op->emitError()
2320                               << "failed to materialize conversion for result #"
2321                               << result.getResultNumber() << " of operation '"
2322                               << op->getName()
2323                               << "' that remained live after conversion";
2324     diag.attachNote(liveUserIt->getLoc())
2325         << "see existing live user here: " << *liveUserIt;
2326     return failure();
2327   }
2328 
2329   // Legalize all of the newly created conversion operations.
2330   for (int i : llvm::seq<int>(numCreatedOps, rewriterImpl.createdOps.size())) {
2331     if (failed(opLegalizer.legalize(rewriterImpl.createdOps[i], rewriter))) {
2332       return op->emitError("failed to legalize conversion operation generated ")
2333              << "for result #" << result.getResultNumber() << " of operation '"
2334              << op->getName() << "' that remained live after conversion";
2335     }
2336   }
2337 
2338   rewriterImpl.mapping.map(result, convertedValue);
2339   return success();
2340 }
2341 
2342 //===----------------------------------------------------------------------===//
2343 // Type Conversion
2344 //===----------------------------------------------------------------------===//
2345 
2346 /// Remap an input of the original signature with a new set of types. The
2347 /// new types are appended to the new signature conversion.
2348 void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
2349                                                    ArrayRef<Type> types) {
2350   assert(!types.empty() && "expected valid types");
2351   remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
2352   addInputs(types);
2353 }
2354 
2355 /// Append new input types to the signature conversion, this should only be
2356 /// used if the new types are not intended to remap an existing input.
2357 void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) {
2358   assert(!types.empty() &&
2359          "1->0 type remappings don't need to be added explicitly");
2360   argTypes.append(types.begin(), types.end());
2361 }
2362 
2363 /// Remap an input of the original signature with a range of types in the
2364 /// new signature.
2365 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2366                                                     unsigned newInputNo,
2367                                                     unsigned newInputCount) {
2368   assert(!remappedInputs[origInputNo] && "input has already been remapped");
2369   assert(newInputCount != 0 && "expected valid input count");
2370   remappedInputs[origInputNo] =
2371       InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr};
2372 }
2373 
2374 /// Remap an input of the original signature to another `replacementValue`
2375 /// value. This would make the signature converter drop this argument.
2376 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2377                                                     Value replacementValue) {
2378   assert(!remappedInputs[origInputNo] && "input has already been remapped");
2379   remappedInputs[origInputNo] =
2380       InputMapping{origInputNo, /*size=*/0, replacementValue};
2381 }
2382 
2383 /// This hooks allows for converting a type.
2384 LogicalResult TypeConverter::convertType(Type t,
2385                                          SmallVectorImpl<Type> &results) {
2386   auto existingIt = cachedDirectConversions.find(t);
2387   if (existingIt != cachedDirectConversions.end()) {
2388     if (existingIt->second)
2389       results.push_back(existingIt->second);
2390     return success(existingIt->second != nullptr);
2391   }
2392   auto multiIt = cachedMultiConversions.find(t);
2393   if (multiIt != cachedMultiConversions.end()) {
2394     results.append(multiIt->second.begin(), multiIt->second.end());
2395     return success();
2396   }
2397 
2398   // Walk the added converters in reverse order to apply the most recently
2399   // registered first.
2400   size_t currentCount = results.size();
2401   for (ConversionCallbackFn &converter : llvm::reverse(conversions)) {
2402     if (Optional<LogicalResult> result = converter(t, results)) {
2403       if (!succeeded(*result)) {
2404         cachedDirectConversions.try_emplace(t, nullptr);
2405         return failure();
2406       }
2407       auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
2408       if (newTypes.size() == 1)
2409         cachedDirectConversions.try_emplace(t, newTypes.front());
2410       else
2411         cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
2412       return success();
2413     }
2414   }
2415   return failure();
2416 }
2417 
2418 /// This hook simplifies defining 1-1 type conversions. This function returns
2419 /// the type to convert to on success, and a null type on failure.
2420 Type TypeConverter::convertType(Type t) {
2421   // Use the multi-type result version to convert the type.
2422   SmallVector<Type, 1> results;
2423   if (failed(convertType(t, results)))
2424     return nullptr;
2425 
2426   // Check to ensure that only one type was produced.
2427   return results.size() == 1 ? results.front() : nullptr;
2428 }
2429 
2430 /// Convert the given set of types, filling 'results' as necessary. This
2431 /// returns failure if the conversion of any of the types fails, success
2432 /// otherwise.
2433 LogicalResult TypeConverter::convertTypes(ArrayRef<Type> types,
2434                                           SmallVectorImpl<Type> &results) {
2435   for (auto type : types)
2436     if (failed(convertType(type, results)))
2437       return failure();
2438   return success();
2439 }
2440 
2441 /// Return true if the given type is legal for this type converter, i.e. the
2442 /// type converts to itself.
2443 bool TypeConverter::isLegal(Type type) { return convertType(type) == type; }
2444 /// Return true if the given operation has legal operand and result types.
2445 bool TypeConverter::isLegal(Operation *op) {
2446   return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
2447 }
2448 
2449 /// Return true if the types of block arguments within the region are legal.
2450 bool TypeConverter::isLegal(Region *region) {
2451   return llvm::all_of(*region, [this](Block &block) {
2452     return isLegal(block.getArgumentTypes());
2453   });
2454 }
2455 
2456 /// Return true if the inputs and outputs of the given function type are
2457 /// legal.
2458 bool TypeConverter::isSignatureLegal(FunctionType ty) {
2459   return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
2460 }
2461 
2462 /// This hook allows for converting a specific argument of a signature.
2463 LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
2464                                                  SignatureConversion &result) {
2465   // Try to convert the given input type.
2466   SmallVector<Type, 1> convertedTypes;
2467   if (failed(convertType(type, convertedTypes)))
2468     return failure();
2469 
2470   // If this argument is being dropped, there is nothing left to do.
2471   if (convertedTypes.empty())
2472     return success();
2473 
2474   // Otherwise, add the new inputs.
2475   result.addInputs(inputNo, convertedTypes);
2476   return success();
2477 }
2478 LogicalResult TypeConverter::convertSignatureArgs(TypeRange types,
2479                                                   SignatureConversion &result,
2480                                                   unsigned origInputOffset) {
2481   for (unsigned i = 0, e = types.size(); i != e; ++i)
2482     if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
2483       return failure();
2484   return success();
2485 }
2486 
2487 Value TypeConverter::materializeConversion(
2488     MutableArrayRef<MaterializationCallbackFn> materializations,
2489     OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) {
2490   for (MaterializationCallbackFn &fn : llvm::reverse(materializations))
2491     if (Optional<Value> result = fn(builder, resultType, inputs, loc))
2492       return result.getValue();
2493   return nullptr;
2494 }
2495 
2496 /// This function converts the type signature of the given block, by invoking
2497 /// 'convertSignatureArg' for each argument. This function should return a valid
2498 /// conversion for the signature on success, None otherwise.
2499 auto TypeConverter::convertBlockSignature(Block *block)
2500     -> Optional<SignatureConversion> {
2501   SignatureConversion conversion(block->getNumArguments());
2502   if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
2503     return llvm::None;
2504   return conversion;
2505 }
2506 
2507 /// Create a default conversion pattern that rewrites the type signature of a
2508 /// FuncOp.
2509 namespace {
2510 struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
2511   FuncOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
2512       : OpConversionPattern(converter, ctx) {}
2513 
2514   /// Hook for derived classes to implement combined matching and rewriting.
2515   LogicalResult
2516   matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
2517                   ConversionPatternRewriter &rewriter) const override {
2518     FunctionType type = funcOp.getType();
2519 
2520     // Convert the original function types.
2521     TypeConverter::SignatureConversion result(type.getNumInputs());
2522     SmallVector<Type, 1> newResults;
2523     if (failed(typeConverter->convertSignatureArgs(type.getInputs(), result)) ||
2524         failed(typeConverter->convertTypes(type.getResults(), newResults)) ||
2525         failed(rewriter.convertRegionTypes(&funcOp.getBody(), *typeConverter,
2526                                            &result)))
2527       return failure();
2528 
2529     // Update the function signature in-place.
2530     rewriter.updateRootInPlace(funcOp, [&] {
2531       funcOp.setType(FunctionType::get(result.getConvertedTypes(), newResults,
2532                                        funcOp.getContext()));
2533     });
2534     return success();
2535   }
2536 };
2537 } // end anonymous namespace
2538 
/// Add the default FuncOp signature conversion pattern, driven by `converter`,
/// to `patterns`.
void mlir::populateFuncOpTypeConversionPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx,
    TypeConverter &converter) {
  patterns.insert<FuncOpSignatureConversion>(ctx, converter);
}
2544 
2545 //===----------------------------------------------------------------------===//
2546 // ConversionTarget
2547 //===----------------------------------------------------------------------===//
2548 
2549 /// Register a legality action for the given operation.
2550 void ConversionTarget::setOpAction(OperationName op,
2551                                    LegalizationAction action) {
2552   legalOperations[op] = {action, /*isRecursivelyLegal=*/false, llvm::None};
2553 }
2554 
2555 /// Register a legality action for the given dialects.
2556 void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
2557                                         LegalizationAction action) {
2558   for (StringRef dialect : dialectNames)
2559     legalDialects[dialect] = action;
2560 }
2561 
2562 /// Get the legality action for the given operation.
2563 auto ConversionTarget::getOpAction(OperationName op) const
2564     -> Optional<LegalizationAction> {
2565   Optional<LegalizationInfo> info = getOpInfo(op);
2566   return info ? info->action : Optional<LegalizationAction>();
2567 }
2568 
2569 /// If the given operation instance is legal on this target, a structure
2570 /// containing legality information is returned. If the operation is not legal,
2571 /// None is returned.
2572 auto ConversionTarget::isLegal(Operation *op) const
2573     -> Optional<LegalOpDetails> {
2574   Optional<LegalizationInfo> info = getOpInfo(op->getName());
2575   if (!info)
2576     return llvm::None;
2577 
2578   // Returns true if this operation instance is known to be legal.
2579   auto isOpLegal = [&] {
2580     // Handle dynamic legality either with the provided legality function, or
2581     // the default hook on the derived instance.
2582     if (info->action == LegalizationAction::Dynamic)
2583       return info->legalityFn ? (*info->legalityFn)(op)
2584                               : isDynamicallyLegal(op);
2585 
2586     // Otherwise, the operation is only legal if it was marked 'Legal'.
2587     return info->action == LegalizationAction::Legal;
2588   };
2589   if (!isOpLegal())
2590     return llvm::None;
2591 
2592   // This operation is legal, compute any additional legality information.
2593   LegalOpDetails legalityDetails;
2594   if (info->isRecursivelyLegal) {
2595     auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
2596     if (legalityFnIt != opRecursiveLegalityFns.end())
2597       legalityDetails.isRecursivelyLegal = legalityFnIt->second(op);
2598     else
2599       legalityDetails.isRecursivelyLegal = true;
2600   }
2601   return legalityDetails;
2602 }
2603 
2604 /// Set the dynamic legality callback for the given operation.
2605 void ConversionTarget::setLegalityCallback(
2606     OperationName name, const DynamicLegalityCallbackFn &callback) {
2607   assert(callback && "expected valid legality callback");
2608   auto infoIt = legalOperations.find(name);
2609   assert(infoIt != legalOperations.end() &&
2610          infoIt->second.action == LegalizationAction::Dynamic &&
2611          "expected operation to already be marked as dynamically legal");
2612   infoIt->second.legalityFn = callback;
2613 }
2614 
2615 /// Set the recursive legality callback for the given operation and mark the
2616 /// operation as recursively legal.
2617 void ConversionTarget::markOpRecursivelyLegal(
2618     OperationName name, const DynamicLegalityCallbackFn &callback) {
2619   auto infoIt = legalOperations.find(name);
2620   assert(infoIt != legalOperations.end() &&
2621          infoIt->second.action != LegalizationAction::Illegal &&
2622          "expected operation to already be marked as legal");
2623   infoIt->second.isRecursivelyLegal = true;
2624   if (callback)
2625     opRecursiveLegalityFns[name] = callback;
2626   else
2627     opRecursiveLegalityFns.erase(name);
2628 }
2629 
2630 /// Set the dynamic legality callback for the given dialects.
2631 void ConversionTarget::setLegalityCallback(
2632     ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
2633   assert(callback && "expected valid legality callback");
2634   for (StringRef dialect : dialects)
2635     dialectLegalityFns[dialect] = callback;
2636 }
2637 
2638 /// Get the legalization information for the given operation.
2639 auto ConversionTarget::getOpInfo(OperationName op) const
2640     -> Optional<LegalizationInfo> {
2641   // Check for info for this specific operation.
2642   auto it = legalOperations.find(op);
2643   if (it != legalOperations.end())
2644     return it->second;
2645   // Check for info for the parent dialect.
2646   auto dialectIt = legalDialects.find(op.getDialect());
2647   if (dialectIt != legalDialects.end()) {
2648     Optional<DynamicLegalityCallbackFn> callback;
2649     auto dialectFn = dialectLegalityFns.find(op.getDialect());
2650     if (dialectFn != dialectLegalityFns.end())
2651       callback = dialectFn->second;
2652     return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
2653                             callback};
2654   }
2655   // Otherwise, check if we mark unknown operations as dynamic.
2656   if (unknownOpsDynamicallyLegal)
2657     return LegalizationInfo{LegalizationAction::Dynamic,
2658                             /*isRecursivelyLegal=*/false, unknownLegalityFn};
2659   return llvm::None;
2660 }
2661 
2662 //===----------------------------------------------------------------------===//
2663 // Op Conversion Entry Points
2664 //===----------------------------------------------------------------------===//
2665 
2666 /// Apply a partial conversion on the given operations and all nested
2667 /// operations. This method converts as many operations to the target as
2668 /// possible, ignoring operations that failed to legalize. This method only
2669 /// returns failure if there ops explicitly marked as illegal.
2670 /// If an `unconvertedOps` set is provided, all operations that are found not
2671 /// to be legalizable to the given `target` are placed within that set. (Note
2672 /// that if there is an op explicitly marked as illegal, the conversion
2673 /// terminates and the `unconvertedOps` set will not necessarily be complete.)
2674 LogicalResult
2675 mlir::applyPartialConversion(ArrayRef<Operation *> ops,
2676                              ConversionTarget &target,
2677                              const FrozenRewritePatternList &patterns,
2678                              DenseSet<Operation *> *unconvertedOps) {
2679   OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
2680                                  unconvertedOps);
2681   return opConverter.convertOperations(ops);
2682 }
/// Convenience overload that partially converts the single root operation
/// `op` (and its nested operations).
LogicalResult
mlir::applyPartialConversion(Operation *op, ConversionTarget &target,
                             const FrozenRewritePatternList &patterns,
                             DenseSet<Operation *> *unconvertedOps) {
  return applyPartialConversion(llvm::makeArrayRef(op), target, patterns,
                                unconvertedOps);
}
2690 
2691 /// Apply a complete conversion on the given operations, and all nested
2692 /// operations. This method will return failure if the conversion of any
2693 /// operation fails.
2694 LogicalResult
2695 mlir::applyFullConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
2696                           const FrozenRewritePatternList &patterns) {
2697   OperationConverter opConverter(target, patterns, OpConversionMode::Full);
2698   return opConverter.convertOperations(ops);
2699 }
/// Convenience overload that fully converts the single root operation `op`
/// (and its nested operations).
LogicalResult
mlir::applyFullConversion(Operation *op, ConversionTarget &target,
                          const FrozenRewritePatternList &patterns) {
  return applyFullConversion(llvm::makeArrayRef(op), target, patterns);
}
2705 
2706 /// Apply an analysis conversion on the given operations, and all nested
2707 /// operations. This method analyzes which operations would be successfully
2708 /// converted to the target if a conversion was applied. All operations that
2709 /// were found to be legalizable to the given 'target' are placed within the
2710 /// provided 'convertedOps' set; note that no actual rewrites are applied to the
2711 /// operations on success and only pre-existing operations are added to the set.
2712 LogicalResult
2713 mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
2714                               ConversionTarget &target,
2715                               const FrozenRewritePatternList &patterns,
2716                               DenseSet<Operation *> &convertedOps) {
2717   OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
2718                                  &convertedOps);
2719   return opConverter.convertOperations(ops);
2720 }
/// Convenience overload that analyzes the single root operation `op` (and its
/// nested operations).
LogicalResult
mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
                              const FrozenRewritePatternList &patterns,
                              DenseSet<Operation *> &convertedOps) {
  return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns,
                                 convertedOps);
}
2728