1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/DialectConversion.h"
10 #include "mlir/IR/Block.h"
11 #include "mlir/IR/BlockAndValueMapping.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinDialect.h"
14 #include "mlir/Rewrite/PatternApplicator.h"
15 #include "mlir/Transforms/Utils.h"
16 #include "llvm/ADT/SetVector.h"
17 #include "llvm/ADT/SmallPtrSet.h"
18 #include "llvm/Support/Debug.h"
19 #include "llvm/Support/FormatVariadic.h"
20 #include "llvm/Support/SaveAndRestore.h"
21 #include "llvm/Support/ScopedPrinter.h"
22 
23 using namespace mlir;
24 using namespace mlir::detail;
25 
26 #define DEBUG_TYPE "dialect-conversion"
27 
28 /// Recursively collect all of the operations to convert from within 'region'.
29 /// If 'target' is nonnull, operations that are recursively legal have their
30 /// regions pre-filtered to avoid considering them for legalization.
31 static LogicalResult
32 computeConversionSet(iterator_range<Region::iterator> region,
33                      Location regionLoc, std::vector<Operation *> &toConvert,
34                      ConversionTarget *target = nullptr) {
35   if (llvm::empty(region))
36     return success();
37 
38   // Traverse starting from the entry block.
39   SmallVector<Block *, 16> worklist(1, &*region.begin());
40   DenseSet<Block *> visitedBlocks;
41   visitedBlocks.insert(worklist.front());
42   while (!worklist.empty()) {
43     Block *block = worklist.pop_back_val();
44 
45     // Compute the conversion set of each of the nested operations.
46     for (Operation &op : *block) {
47       toConvert.emplace_back(&op);
48 
49       // Don't check this operation's children for conversion if the operation
50       // is recursively legal.
51       auto legalityInfo = target ? target->isLegal(&op)
52                                  : Optional<ConversionTarget::LegalOpDetails>();
53       if (legalityInfo && legalityInfo->isRecursivelyLegal)
54         continue;
55       for (auto &region : op.getRegions()) {
56         if (failed(computeConversionSet(region.getBlocks(), region.getLoc(),
57                                         toConvert, target)))
58           return failure();
59       }
60     }
61 
62     // Recurse to children that haven't been visited.
63     for (Block *succ : block->getSuccessors())
64       if (visitedBlocks.insert(succ).second)
65         worklist.push_back(succ);
66   }
67 
68   // Check that all blocks in the region were visited.
69   if (llvm::any_of(llvm::drop_begin(region, 1),
70                    [&](Block &block) { return !visitedBlocks.count(&block); }))
71     return emitError(regionLoc, "unreachable blocks were not converted");
72   return success();
73 }
74 
75 /// A utility function to log a successful result for the given reason.
76 template <typename... Args>
77 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
78   LLVM_DEBUG({
79     os.unindent();
80     os.startLine() << "} -> SUCCESS";
81     if (!fmt.empty())
82       os.getOStream() << " : "
83                       << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
84     os.getOStream() << "\n";
85   });
86 }
87 
88 /// A utility function to log a failure result for the given reason.
89 template <typename... Args>
90 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
91   LLVM_DEBUG({
92     os.unindent();
93     os.startLine() << "} -> FAILURE : "
94                    << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
95                    << "\n";
96   });
97 }
98 
99 //===----------------------------------------------------------------------===//
100 // ConversionValueMapping
101 //===----------------------------------------------------------------------===//
102 
103 namespace {
/// This class wraps a BlockAndValueMapping to provide recursive lookup
/// functionality: when a mapped value itself has a mapping, lookups keep
/// following the chain of mappings until reaching an unmapped value.
struct ConversionValueMapping {
  /// Lookup a mapped value within the map. If a mapping for the provided value
  /// does not exist then return the provided value. Otherwise the chain of
  /// mappings is followed to its end (the "leaf" value). If `desiredType` is
  /// non-null, returns the deepest value along that chain (including `from`
  /// itself) that has that type; if no value of that type exists on the chain,
  /// defaults to returning the leaf value.
  Value lookupOrDefault(Value from, Type desiredType = nullptr) const;

  /// Lookup a mapped value within the map, or return null if a mapping does not
  /// exist. If a mapping exists, the chain is followed exactly as in
  /// `lookupOrDefault` (without a desired type); null is also returned if that
  /// chain resolves back to `from` itself.
  Value lookupOrNull(Value from) const;

  /// Map a value to the one provided, overwriting any existing direct mapping
  /// for `oldVal`.
  void map(Value oldVal, Value newVal) { mapping.map(oldVal, newVal); }

  /// Drop the direct mapping recorded for the given value. Mappings of the
  /// value it pointed to are left untouched.
  void erase(Value value) { mapping.erase(value); }

private:
  /// Current value mappings (one direct entry per mapped value).
  BlockAndValueMapping mapping;
};
128 } // end anonymous namespace
129 
130 Value ConversionValueMapping::lookupOrDefault(Value from,
131                                               Type desiredType) const {
132   // If there was no desired type, simply find the leaf value.
133   if (!desiredType) {
134     // If this value had a valid mapping, unmap that value as well in the case
135     // that it was also replaced.
136     while (auto mappedValue = mapping.lookupOrNull(from))
137       from = mappedValue;
138     return from;
139   }
140 
141   // Otherwise, try to find the deepest value that has the desired type.
142   Value desiredValue;
143   do {
144     if (from.getType() == desiredType)
145       desiredValue = from;
146 
147     Value mappedValue = mapping.lookupOrNull(from);
148     if (!mappedValue)
149       break;
150     from = mappedValue;
151   } while (true);
152 
153   // If the desired value was found use it, otherwise default to the leaf value.
154   return desiredValue ? desiredValue : from;
155 }
156 
157 Value ConversionValueMapping::lookupOrNull(Value from) const {
158   Value result = lookupOrDefault(from);
159   return result == from ? nullptr : result;
160 }
161 
162 //===----------------------------------------------------------------------===//
163 // ArgConverter
164 //===----------------------------------------------------------------------===//
165 namespace {
/// This class provides a simple interface for converting the types of block
/// arguments. This is done by creating a new block that contains the new legal
/// types and extracting the block that contains the old illegal types to allow
/// for undoing pending rewrites in the case of failure.
struct ArgConverter {
  ArgConverter(PatternRewriter &rewriter) : rewriter(rewriter) {}

  /// This structure contains the information pertaining to an argument that has
  /// been converted.
  struct ConvertedArgInfo {
    ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
                     Value castValue = nullptr)
        : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}

    /// The start index in the new argument list of the arguments that replace
    /// the original.
    unsigned newArgIdx;

    /// The number of arguments that replaced the original argument.
    unsigned newArgSize;

    /// The value standing in for the original argument: the result of the
    /// type converter's argument materialization, or, for a 1->1 mapping with
    /// no materialization, the new block argument itself. It is recorded for
    /// every 1->1+ mapping, not only when 'newArgSize' > 1.
    Value castValue;
  };

  /// This structure contains information pertaining to a block that has had its
  /// signature converted.
  struct ConvertedBlockInfo {
    ConvertedBlockInfo(Block *origBlock, TypeConverter &converter)
        : origBlock(origBlock), converter(&converter) {}

    /// The original block that was requested to have its signature converted.
    Block *origBlock;

    /// The conversion information for each of the original arguments, indexed
    /// by original argument number. The information is None if the argument
    /// was dropped during conversion (including when it was given an explicit
    /// replacement value).
    SmallVector<Optional<ConvertedArgInfo>, 1> argInfo;

    /// The type converter used to convert the arguments.
    TypeConverter *converter;
  };

  /// Return true if the given block already took part in a signature
  /// conversion, either as a replacement block or as an original block.
  bool hasBeenConverted(Block *block) const {
    return conversionInfo.count(block) || convertedBlocks.count(block);
  }

  /// Set the type converter to use for the given region.
  void setConverter(Region *region, TypeConverter *typeConverter) {
    assert(typeConverter && "expected valid type converter");
    regionToConverter[region] = typeConverter;
  }

  /// Return the type converter to use for the given region, or null if there
  /// isn't one.
  TypeConverter *getConverter(Region *region) {
    return regionToConverter.lookup(region);
  }

  //===--------------------------------------------------------------------===//
  // Rewrite Application
  //===--------------------------------------------------------------------===//

  /// Erase any rewrites registered for the blocks within the given operation
  /// which is about to be removed. This merely drops the rewrites without
  /// undoing them.
  void notifyOpRemoved(Operation *op);

  /// Cleanup and undo any generated conversions for the arguments of block.
  /// This method replaces the new block with the original, reverting the IR to
  /// its original state.
  void discardRewrites(Block *block);

  /// Fully replace uses of the old arguments with the new.
  void applyRewrites(ConversionValueMapping &mapping);

  /// Materialize any necessary conversions for converted arguments that have
  /// live users, using the provided `findLiveUser` to search for a user that
  /// survives the conversion process.
  LogicalResult
  materializeLiveConversions(ConversionValueMapping &mapping,
                             OpBuilder &builder,
                             function_ref<Operation *(Value)> findLiveUser);

  //===--------------------------------------------------------------------===//
  // Conversion
  //===--------------------------------------------------------------------===//

  /// Attempt to convert the signature of the given block, if successful a new
  /// block is returned containing the new arguments. Returns `block` if it did
  /// not require conversion.
  FailureOr<Block *> convertSignature(Block *block, TypeConverter &converter,
                                      ConversionValueMapping &mapping);

  /// Apply the given signature conversion on the given block. The new block
  /// containing the updated signature is returned. If no conversions were
  /// necessary, e.g. if the block has no arguments, `block` is returned.
  /// `converter` is used to generate any necessary cast operations that
  /// translate between the origin argument types and those specified in the
  /// signature conversion.
  Block *applySignatureConversion(
      Block *block, TypeConverter &converter,
      TypeConverter::SignatureConversion &signatureConversion,
      ConversionValueMapping &mapping);

  /// Insert a new conversion into the cache.
  void insertConversion(Block *newBlock, ConvertedBlockInfo &&info);

  /// A collection of blocks that have had their arguments converted. This is a
  /// map from the new replacement block to the information about the
  /// conversion, which includes the original block.
  llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;

  /// The set of original blocks that were converted.
  DenseSet<Block *> convertedBlocks;

  /// A mapping from each region containing replacement blocks to a detached
  /// region that holds the original blocks spliced out of it.
  DenseMap<Region *, std::unique_ptr<Region>> regionMapping;

  /// A mapping of regions to type converters that should be used when
  /// converting the arguments of blocks within that region.
  DenseMap<Region *, TypeConverter *> regionToConverter;

  /// The pattern rewriter to use when materializing conversions; its
  /// insertion point is moved while doing so.
  PatternRewriter &rewriter;
};
293 } // end anonymous namespace
294 
295 //===----------------------------------------------------------------------===//
296 // Rewrite Application
297 
298 void ArgConverter::notifyOpRemoved(Operation *op) {
299   if (conversionInfo.empty())
300     return;
301 
302   for (Region &region : op->getRegions()) {
303     for (Block &block : region) {
304       // Drop any rewrites from within.
305       for (Operation &nestedOp : block)
306         if (nestedOp.getNumRegions())
307           notifyOpRemoved(&nestedOp);
308 
309       // Check if this block was converted.
310       auto it = conversionInfo.find(&block);
311       if (it == conversionInfo.end())
312         continue;
313 
314       // Drop all uses of the original arguments and delete the original block.
315       Block *origBlock = it->second.origBlock;
316       for (BlockArgument arg : origBlock->getArguments())
317         arg.dropAllUses();
318       conversionInfo.erase(it);
319     }
320   }
321 }
322 
323 void ArgConverter::discardRewrites(Block *block) {
324   auto it = conversionInfo.find(block);
325   if (it == conversionInfo.end())
326     return;
327   Block *origBlock = it->second.origBlock;
328 
329   // Drop all uses of the new block arguments and replace uses of the new block.
330   for (int i = block->getNumArguments() - 1; i >= 0; --i)
331     block->getArgument(i).dropAllUses();
332   block->replaceAllUsesWith(origBlock);
333 
334   // Move the operations back the original block and the delete the new block.
335   origBlock->getOperations().splice(origBlock->end(), block->getOperations());
336   origBlock->moveBefore(block);
337   block->erase();
338 
339   convertedBlocks.erase(origBlock);
340   conversionInfo.erase(it);
341 }
342 
343 void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
344   for (auto &info : conversionInfo) {
345     ConvertedBlockInfo &blockInfo = info.second;
346     Block *origBlock = blockInfo.origBlock;
347 
348     // Process the remapping for each of the original arguments.
349     for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
350       Optional<ConvertedArgInfo> &argInfo = blockInfo.argInfo[i];
351       BlockArgument origArg = origBlock->getArgument(i);
352 
353       // Handle the case of a 1->0 value mapping.
354       if (!argInfo) {
355         if (Value newArg = mapping.lookupOrNull(origArg))
356           origArg.replaceAllUsesWith(newArg);
357         continue;
358       }
359 
360       // Otherwise this is a 1->1+ value mapping.
361       Value castValue = argInfo->castValue;
362       assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
363 
364       // If the argument is still used, replace it with the generated cast.
365       if (!origArg.use_empty())
366         origArg.replaceAllUsesWith(mapping.lookupOrDefault(castValue));
367 
368       // If all users of the cast were removed, we can drop it. Otherwise, keep
369       // the operation alive and let the user handle any remaining usages.
370       if (castValue.use_empty() && castValue.getDefiningOp())
371         castValue.getDefiningOp()->erase();
372     }
373   }
374 }
375 
376 LogicalResult ArgConverter::materializeLiveConversions(
377     ConversionValueMapping &mapping, OpBuilder &builder,
378     function_ref<Operation *(Value)> findLiveUser) {
379   for (auto &info : conversionInfo) {
380     Block *newBlock = info.first;
381     ConvertedBlockInfo &blockInfo = info.second;
382     Block *origBlock = blockInfo.origBlock;
383 
384     // Process the remapping for each of the original arguments.
385     for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
386       // FIXME: We should run the below checks even if the type conversion was
387       // 1->N, but a lot of existing lowering rely on the block argument being
388       // blindly replaced. Those usages should be updated, and this if should be
389       // removed.
390       if (blockInfo.argInfo[i])
391         continue;
392 
393       // If the type of this argument changed and the argument is still live, we
394       // need to materialize a conversion.
395       BlockArgument origArg = origBlock->getArgument(i);
396       auto argReplacementValue = mapping.lookupOrDefault(origArg);
397       bool isDroppedArg = argReplacementValue == origArg;
398       if (argReplacementValue.getType() == origArg.getType() && !isDroppedArg)
399         continue;
400       Operation *liveUser = findLiveUser(origArg);
401       if (!liveUser)
402         continue;
403 
404       if (OpResult result = argReplacementValue.dyn_cast<OpResult>())
405         rewriter.setInsertionPointAfter(result.getOwner());
406       else
407         rewriter.setInsertionPointToStart(newBlock);
408       Value newArg = blockInfo.converter->materializeSourceConversion(
409           rewriter, origArg.getLoc(), origArg.getType(),
410           isDroppedArg ? ValueRange() : ValueRange(argReplacementValue));
411       if (!newArg) {
412         InFlightDiagnostic diag =
413             emitError(origArg.getLoc())
414             << "failed to materialize conversion for block argument #" << i
415             << " that remained live after conversion, type was "
416             << origArg.getType();
417         if (!isDroppedArg)
418           diag << ", with target type " << argReplacementValue.getType();
419         diag.attachNote(liveUser->getLoc())
420             << "see existing live user here: " << *liveUser;
421         return failure();
422       }
423       mapping.map(origArg, newArg);
424     }
425   }
426   return success();
427 }
428 
429 //===----------------------------------------------------------------------===//
430 // Conversion
431 
432 FailureOr<Block *>
433 ArgConverter::convertSignature(Block *block, TypeConverter &converter,
434                                ConversionValueMapping &mapping) {
435   // Check if the block was already converted. If the block is detached,
436   // conservatively assume it is going to be deleted.
437   if (hasBeenConverted(block) || !block->getParent())
438     return block;
439 
440   // Try to convert the signature for the block with the provided converter.
441   if (auto conversion = converter.convertBlockSignature(block))
442     return applySignatureConversion(block, converter, *conversion, mapping);
443   return failure();
444 }
445 
/// Rewrites `block` to use the argument types of `signatureConversion`.
///
/// The operations of `block` are moved into a new block that carries the
/// converted argument types, and all uses of `block` are redirected to it.
/// Each original argument is then mapped in `mapping` to its replacement, and
/// the original block is handed to `insertConversion`, which detaches it so the
/// rewrite can be undone later. Returns the new block, or `block` unchanged
/// when it has no arguments and no new arguments are requested.
Block *ArgConverter::applySignatureConversion(
    Block *block, TypeConverter &converter,
    TypeConverter::SignatureConversion &signatureConversion,
    ConversionValueMapping &mapping) {
  // If no arguments are being changed or added, there is nothing to do.
  unsigned origArgCount = block->getNumArguments();
  auto convertedTypes = signatureConversion.getConvertedTypes();
  if (origArgCount == 0 && convertedTypes.empty())
    return block;

  // Split the block at the beginning to get a new block to use for the updated
  // signature. All operations move to `newBlock`; `block` keeps only its
  // original arguments, and every reference to `block` now targets `newBlock`.
  Block *newBlock = block->splitBlock(block->begin());
  block->replaceAllUsesWith(newBlock);

  SmallVector<Value, 4> newArgRange(newBlock->addArguments(convertedTypes));
  ArrayRef<Value> newArgs(newArgRange);

  // Remap each of the original arguments as determined by the signature
  // conversion. Arguments without an input mapping, or with an explicit
  // replacement value, keep a None entry in `info.argInfo`.
  ConvertedBlockInfo info(block, converter);
  info.argInfo.resize(origArgCount);

  // Argument materializations are inserted at the start of the new block; the
  // guard restores the rewriter's previous insertion point on return.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(newBlock);
  for (unsigned i = 0; i != origArgCount; ++i) {
    auto inputMap = signatureConversion.getInputMapping(i);
    if (!inputMap)
      continue;
    BlockArgument origArg = block->getArgument(i);

    // If inputMap->replacementValue is not nullptr, then the argument is
    // dropped and a replacement value is provided to be the remappedValue.
    if (inputMap->replacementValue) {
      assert(inputMap->size == 0 &&
             "invalid to provide a replacement value when the argument isn't "
             "dropped");
      mapping.map(origArg, inputMap->replacementValue);
      continue;
    }

    // Otherwise, this is a 1->1+ mapping. Call into the provided type converter
    // to pack the new values. For 1->1 mappings, if there is no materialization
    // provided, use the argument directly instead.
    // NOTE(review): a mapping with size 0 and no replacement value would make
    // `replArgs` empty; only the assert below guards `replArgs.front()`, so
    // confirm such mappings cannot reach this point in release builds.
    auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
    Value newArg = converter.materializeArgumentConversion(
        rewriter, origArg.getLoc(), origArg.getType(), replArgs);
    if (!newArg) {
      assert(replArgs.size() == 1 &&
             "couldn't materialize the result of 1->N conversion");
      newArg = replArgs.front();
    }
    mapping.map(origArg, newArg);
    info.argInfo[i] =
        ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
  }

  // Remove the original block from the region and return the new one.
  insertConversion(newBlock, std::move(info));
  return newBlock;
}
507 
508 void ArgConverter::insertConversion(Block *newBlock,
509                                     ConvertedBlockInfo &&info) {
510   // Get a region to insert the old block.
511   Region *region = newBlock->getParent();
512   std::unique_ptr<Region> &mappedRegion = regionMapping[region];
513   if (!mappedRegion)
514     mappedRegion = std::make_unique<Region>(region->getParentOp());
515 
516   // Move the original block to the mapped region and emplace the conversion.
517   mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(),
518                                    info.origBlock->getIterator());
519   convertedBlocks.insert(info.origBlock);
520   conversionInfo.insert({newBlock, std::move(info)});
521 }
522 
523 //===----------------------------------------------------------------------===//
524 // Rewriter and Translation State
525 //===----------------------------------------------------------------------===//
526 namespace {
/// This class contains a snapshot of the current conversion rewriter state,
/// expressed as the number of entries recorded in each of the rewriter's
/// bookkeeping lists at the time of the snapshot. This is useful when saving
/// and undoing a set of rewrites.
struct RewriterState {
  RewriterState(unsigned numCreatedOps, unsigned numReplacements,
                unsigned numArgReplacements, unsigned numBlockActions,
                unsigned numIgnoredOperations, unsigned numRootUpdates)
      : numCreatedOps(numCreatedOps), numReplacements(numReplacements),
        numArgReplacements(numArgReplacements),
        numBlockActions(numBlockActions),
        numIgnoredOperations(numIgnoredOperations),
        numRootUpdates(numRootUpdates) {}

  /// The current number of created operations.
  unsigned numCreatedOps;

  /// The current number of replacements queued.
  unsigned numReplacements;

  /// The current number of argument replacements queued.
  unsigned numArgReplacements;

  /// The current number of block actions performed.
  unsigned numBlockActions;

  /// The current number of ignored operations.
  unsigned numIgnoredOperations;

  /// The current number of operations that were updated in place.
  unsigned numRootUpdates;
};
557 
558 /// The state of an operation that was updated by a pattern in-place. This
559 /// contains all of the necessary information to reconstruct an operation that
560 /// was updated in place.
561 class OperationTransactionState {
562 public:
563   OperationTransactionState() = default;
564   OperationTransactionState(Operation *op)
565       : op(op), loc(op->getLoc()), attrs(op->getMutableAttrDict()),
566         operands(op->operand_begin(), op->operand_end()),
567         successors(op->successor_begin(), op->successor_end()) {}
568 
569   /// Discard the transaction state and reset the state of the original
570   /// operation.
571   void resetOperation() const {
572     op->setLoc(loc);
573     op->setAttrs(attrs);
574     op->setOperands(operands);
575     for (auto it : llvm::enumerate(successors))
576       op->setSuccessor(it.value(), it.index());
577   }
578 
579   /// Return the original operation of this state.
580   Operation *getOperation() const { return op; }
581 
582 private:
583   Operation *op;
584   LocationAttr loc;
585   MutableDictionaryAttr attrs;
586   SmallVector<Value, 8> operands;
587   SmallVector<Block *, 2> successors;
588 };
589 
/// This class represents one requested operation replacement via `replaceOp` or
/// `eraseOp`.
struct OpReplacement {
  OpReplacement() = default;
  OpReplacement(TypeConverter *converter) : converter(converter) {}

  /// An optional type converter that can be used to materialize conversions
  /// between the new and old values if necessary. Null when none was provided.
  TypeConverter *converter = nullptr;
};
600 
/// The kind of the block action performed during the rewrite. Actions can be
/// undone if the conversion fails.
enum class BlockActionKind {
  Create,
  Erase,
  Merge,
  Move,
  Split,
  TypeConversion
};
611 
/// Original position of the given block in its parent region. During undo
/// actions, the block needs to be placed after `insertAfterBlock`.
struct BlockPosition {
  /// The region that originally contained the block.
  Region *region;
  /// The block that preceded it in `region`. NOTE(review): presumably null
  /// when the block was first in its region; confirm against the undo logic.
  Block *insertAfterBlock;
};
618 
/// Information needed to undo the merge actions.
/// - the source block, and
/// - the Operation that was the last operation in the dest block before the
///   merge (could be null if the dest block was empty).
struct MergeInfo {
  /// The block whose operations were merged into the destination block.
  Block *sourceBlock;
  /// The last operation of the destination block before the merge, or null.
  Operation *destBlockLastInst;
};
627 
/// The storage class for an undoable block action (one of BlockActionKind),
/// contains the information necessary to undo this action. Only the union
/// member matching `kind` is meaningful.
struct BlockAction {
  /// Record that `block` was created.
  static BlockAction getCreate(Block *block) {
    return {BlockActionKind::Create, block, {}};
  }
  /// Record that `block` was erased from `originalPosition`.
  static BlockAction getErase(Block *block, BlockPosition originalPosition) {
    return {BlockActionKind::Erase, block, {originalPosition}};
  }
  /// Record that `sourceBlock` is being merged into `block`. Must be created
  /// before the merge happens, since it captures the current last operation
  /// of `block` (or null if `block` is empty).
  static BlockAction getMerge(Block *block, Block *sourceBlock) {
    BlockAction action{BlockActionKind::Merge, block, {}};
    action.mergeInfo = {sourceBlock, block->empty() ? nullptr : &block->back()};
    return action;
  }
  /// Record that `block` was moved away from `originalPosition`.
  static BlockAction getMove(Block *block, BlockPosition originalPosition) {
    return {BlockActionKind::Move, block, {originalPosition}};
  }
  /// Record that `originalBlock` was split, producing `block`.
  static BlockAction getSplit(Block *block, Block *originalBlock) {
    BlockAction action{BlockActionKind::Split, block, {}};
    action.originalBlock = originalBlock;
    return action;
  }
  /// Record that the signature of `block` was type-converted.
  static BlockAction getTypeConversion(Block *block) {
    return BlockAction{BlockActionKind::TypeConversion, block, {}};
  }

  // The action kind.
  BlockActionKind kind;

  // The block the action applies to: the created, erased, moved, or
  // type-converted block; for Split, the block produced by the split; for
  // Merge, the destination block.
  Block *block;

  // Create and TypeConversion actions use none of these members.
  union {
    // In use if kind == BlockActionKind::Move or BlockActionKind::Erase, and
    // contains a pointer to the region that originally contained the block as
    // well as the position of the block in that region.
    BlockPosition originalPosition;
    // In use if kind == BlockActionKind::Split and contains a pointer to the
    // block that was split into two parts.
    Block *originalBlock;
    // In use if kind == BlockActionKind::Merge, and contains the information
    // needed to undo the merge.
    MergeInfo mergeInfo;
  };
};
673 } // end anonymous namespace
674 
675 //===----------------------------------------------------------------------===//
676 // ConversionPatternRewriterImpl
677 //===----------------------------------------------------------------------===//
678 namespace mlir {
679 namespace detail {
680 struct ConversionPatternRewriterImpl {
681   ConversionPatternRewriterImpl(PatternRewriter &rewriter)
682       : argConverter(rewriter) {}
683 
684   /// Cleanup and destroy any generated rewrite operations. This method is
685   /// invoked when the conversion process fails.
686   void discardRewrites();
687 
688   /// Apply all requested operation rewrites. This method is invoked when the
689   /// conversion process succeeds.
690   void applyRewrites();
691 
692   //===--------------------------------------------------------------------===//
693   // State Management
694   //===--------------------------------------------------------------------===//
695 
696   /// Return the current state of the rewriter.
697   RewriterState getCurrentState();
698 
699   /// Reset the state of the rewriter to a previously saved point.
700   void resetState(RewriterState state);
701 
702   /// Erase any blocks that were unlinked from their regions and stored in block
703   /// actions.
704   void eraseDanglingBlocks();
705 
706   /// Undo the block actions (motions, splits) one by one in reverse order until
707   /// "numActionsToKeep" actions remains.
708   void undoBlockActions(unsigned numActionsToKeep = 0);
709 
710   /// Remap the given operands to those with potentially different types. The
711   /// provided type converter is used to ensure that the remapped types are
712   /// legal. Returns success if the operands could be remapped, failure
713   /// otherwise.
714   LogicalResult remapValues(Location loc, PatternRewriter &rewriter,
715                             TypeConverter *converter,
716                             Operation::operand_range operands,
717                             SmallVectorImpl<Value> &remapped);
718 
719   /// Returns true if the given operation is ignored, and does not need to be
720   /// converted.
721   bool isOpIgnored(Operation *op) const;
722 
723   /// Recursively marks the nested operations under 'op' as ignored. This
724   /// removes them from being considered for legalization.
725   void markNestedOpsIgnored(Operation *op);
726 
727   //===--------------------------------------------------------------------===//
728   // Type Conversion
729   //===--------------------------------------------------------------------===//
730 
731   /// Convert the signature of the given block.
732   FailureOr<Block *> convertBlockSignature(
733       Block *block, TypeConverter &converter,
734       TypeConverter::SignatureConversion *conversion = nullptr);
735 
736   /// Apply a signature conversion on the given region.
737   Block *
738   applySignatureConversion(Region *region,
739                            TypeConverter::SignatureConversion &conversion);
740 
741   /// Convert the types of block arguments within the given region.
742   FailureOr<Block *>
743   convertRegionTypes(Region *region, TypeConverter &converter,
744                      TypeConverter::SignatureConversion *entryConversion);
745 
746   //===--------------------------------------------------------------------===//
747   // Rewriter Notification Hooks
748   //===--------------------------------------------------------------------===//
749 
750   /// PatternRewriter hook for replacing the results of an operation.
751   void notifyOpReplaced(Operation *op, ValueRange newValues);
752 
753   /// Notifies that a block is about to be erased.
754   void notifyBlockIsBeingErased(Block *block);
755 
756   /// Notifies that a block was created.
757   void notifyCreatedBlock(Block *block);
758 
759   /// Notifies that a block was split.
760   void notifySplitBlock(Block *block, Block *continuation);
761 
762   /// Notifies that `block` is being merged with `srcBlock`.
763   void notifyBlocksBeingMerged(Block *block, Block *srcBlock);
764 
765   /// Notifies that the blocks of a region are about to be moved.
766   void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent,
767                                         Region::iterator before);
768 
769   /// Notifies that the blocks of a region were cloned into another.
770   void notifyRegionWasClonedBefore(iterator_range<Region::iterator> &blocks,
771                                    Location origRegionLoc);
772 
773   /// Notifies that a pattern match failed for the given reason.
774   LogicalResult
775   notifyMatchFailure(Location loc,
776                      function_ref<void(Diagnostic &)> reasonCallback);
777 
778   //===--------------------------------------------------------------------===//
779   // State
780   //===--------------------------------------------------------------------===//
781 
782   // Mapping between replaced values that differ in type. This happens when
783   // replacing a value with one of a different type.
784   ConversionValueMapping mapping;
785 
786   /// Utility used to convert block arguments.
787   ArgConverter argConverter;
788 
789   /// Ordered vector of all of the newly created operations during conversion.
790   std::vector<Operation *> createdOps;
791 
792   /// Ordered map of requested operation replacements.
793   llvm::MapVector<Operation *, OpReplacement> replacements;
794 
795   /// Ordered vector of any requested block argument replacements.
796   SmallVector<BlockArgument, 4> argReplacements;
797 
798   /// Ordered list of block operations (creations, splits, motions).
799   SmallVector<BlockAction, 4> blockActions;
800 
801   /// A set of operations that should no longer be considered for legalization,
802   /// but were not directly replace/erased/etc. by a pattern. These are
803   /// generally child operations of other operations who were
804   /// replaced/erased/etc. This is not meant to be an exhaustive list of all
805   /// operations, but the minimal set that can be used to detect if a given
806   /// operation should be `ignored`. For example, we may add the operations that
807   /// define non-empty regions to the set, but not any of the others. This
808   /// simplifies the amount of memory needed as we can query if the parent
809   /// operation was ignored.
810   llvm::SetVector<Operation *> ignoredOps;
811 
812   /// A transaction state for each of operations that were updated in-place.
813   SmallVector<OperationTransactionState, 4> rootUpdates;
814 
815   /// A vector of indices into `replacements` of operations that were replaced
816   /// with values with different result types than the original operation, e.g.
817   /// 1->N conversion of some kind.
818   SmallVector<unsigned, 4> operationsWithChangedResults;
819 
820   /// A default type converter, used when block conversions do not have one
821   /// explicitly provided.
822   TypeConverter defaultTypeConverter;
823 
824   /// The current conversion pattern that is being rewritten, or nullptr if
825   /// called from outside of a conversion pattern rewrite.
826   const ConversionPattern *currentConversionPattern = nullptr;
827 
828 #ifndef NDEBUG
829   /// A set of operations that have pending updates. This tracking isn't
830   /// strictly necessary, and is thus only active during debug builds for extra
831   /// verification.
832   SmallPtrSet<Operation *, 1> pendingRootUpdates;
833 
834   /// A logger used to emit diagnostics during the conversion process.
835   llvm::ScopedPrinter logger{llvm::dbgs()};
836 #endif
837 };
838 } // end namespace detail
839 } // end namespace mlir
840 
841 /// Detach any operations nested in the given operation from their parent
842 /// blocks, and erase the given operation. This can be used when the nested
843 /// operations are scheduled for erasure themselves, so deleting the regions of
844 /// the given operation together with their content would result in double-free.
845 /// This happens, for example, when rolling back op creation in the reverse
846 /// order and if the nested ops were created before the parent op. This function
847 /// does not need to collect nested ops recursively because it is expected to
848 /// also be called for each nested op when it is about to be deleted.
849 static void detachNestedAndErase(Operation *op) {
850   for (Region &region : op->getRegions()) {
851     for (Block &block : region.getBlocks()) {
852       while (!block.getOperations().empty())
853         block.getOperations().remove(block.getOperations().begin());
854       block.dropAllDefinedValueUses();
855     }
856   }
857   op->erase();
858 }
859 
860 void ConversionPatternRewriterImpl::discardRewrites() {
861   // Reset any operations that were updated in place.
862   for (auto &state : rootUpdates)
863     state.resetOperation();
864 
865   undoBlockActions();
866 
867   // Remove any newly created ops.
868   for (auto *op : llvm::reverse(createdOps))
869     detachNestedAndErase(op);
870 }
871 
/// Commit every change recorded during the conversion. This runs in phases:
///  1. For each replaced operation, forward uses of its results to the values
///     recorded in `mapping`, and drop pending argument rewrites for replaced
///     operations that own regions.
///  2. Forward uses of replaced block arguments to their mapped values.
///  3. Erase the replaced operations, in reverse order of replacement.
///  4. Let `argConverter` commit its block-signature rewrites.
///  5. Free blocks that were unlinked by `eraseBlock`.
void ConversionPatternRewriterImpl::applyRewrites() {
  // Apply all of the operation replacements requested during conversion.
  for (auto &repl : replacements) {
    // Results without a mapped value (e.g. results of an erased op, which
    // were never mapped) keep their existing uses.
    for (OpResult result : repl.first->getResults())
      if (Value newValue = mapping.lookupOrNull(result))
        result.replaceAllUsesWith(newValue);

    // If this operation defines any regions, drop any pending argument
    // rewrites.
    if (repl.first->getNumRegions())
      argConverter.notifyOpRemoved(repl.first);
  }

  // Apply all of the requested argument replacements.
  for (BlockArgument arg : argReplacements) {
    Value repl = mapping.lookupOrDefault(arg);
    // A block-argument replacement is visible everywhere, so every use can be
    // redirected.
    if (repl.isa<BlockArgument>()) {
      arg.replaceAllUsesWith(repl);
      continue;
    }

    // If the replacement value is an operation, we check to make sure that we
    // don't replace uses that are within the parent operation of the
    // replacement value.
    // Concretely: uses in the replacement op's own block that come before (or
    // are) the replacement op are left untouched, since the new value would
    // not dominate them.
    Operation *replOp = repl.cast<OpResult>().getOwner();
    Block *replBlock = replOp->getBlock();
    arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
      Operation *user = operand.getOwner();
      return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
    });
  }

  // In a second pass, erase all of the replaced operations in reverse. This
  // allows processing nested operations before their parent region is
  // destroyed.
  for (auto &repl : llvm::reverse(replacements))
    repl.first->erase();

  argConverter.applyRewrites(mapping);

  // Now that the ops have been erased, also erase dangling blocks.
  eraseDanglingBlocks();
}
915 
916 //===----------------------------------------------------------------------===//
917 // State Management
918 
919 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
920   return RewriterState(createdOps.size(), replacements.size(),
921                        argReplacements.size(), blockActions.size(),
922                        ignoredOps.size(), rootUpdates.size());
923 }
924 
/// Roll the rewriter back to a snapshot taken by `getCurrentState`. Every
/// undo log is truncated to the size recorded in `state`, undoing whatever the
/// dropped entries did. The order below matters: in-place updates, argument
/// mappings and block actions are reverted before the created operations they
/// may reference are destroyed.
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
  // Reset any operations that were updated in place.
  for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i)
    rootUpdates[i].resetOperation();
  rootUpdates.resize(state.numRootUpdates);

  // Reset any replaced arguments.
  for (BlockArgument replacedArg :
       llvm::drop_begin(argReplacements, state.numArgReplacements))
    mapping.erase(replacedArg);
  argReplacements.resize(state.numArgReplacements);

  // Undo any block actions.
  undoBlockActions(state.numBlockActions);

  // Reset any replaced operations and undo any saved mappings.
  for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
    for (auto result : repl.first->getResults())
      mapping.erase(result);
  while (replacements.size() != state.numReplacements)
    replacements.pop_back();

  // Pop all of the newly created operations (newest first, so nested ops
  // created before their parent are not freed twice).
  while (createdOps.size() != state.numCreatedOps) {
    detachNestedAndErase(createdOps.back());
    createdOps.pop_back();
  }

  // Pop all of the recorded ignored operations that are no longer valid.
  while (ignoredOps.size() != state.numIgnoredOperations)
    ignoredOps.pop_back();

  // Reset operations with changed results. The entries are indices into
  // `replacements`, so any index at or past the kept size is now stale.
  while (!operationsWithChangedResults.empty() &&
         operationsWithChangedResults.back() >= state.numReplacements)
    operationsWithChangedResults.pop_back();
}
962 
963 void ConversionPatternRewriterImpl::eraseDanglingBlocks() {
964   for (auto &action : blockActions)
965     if (action.kind == BlockActionKind::Erase)
966       delete action.block;
967 }
968 
969 void ConversionPatternRewriterImpl::undoBlockActions(
970     unsigned numActionsToKeep) {
971   for (auto &action :
972        llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) {
973     switch (action.kind) {
974     // Delete the created block.
975     case BlockActionKind::Create: {
976       // Unlink all of the operations within this block, they will be deleted
977       // separately.
978       auto &blockOps = action.block->getOperations();
979       while (!blockOps.empty())
980         blockOps.remove(blockOps.begin());
981       action.block->dropAllDefinedValueUses();
982       action.block->erase();
983       break;
984     }
985     // Put the block (owned by action) back into its original position.
986     case BlockActionKind::Erase: {
987       auto &blockList = action.originalPosition.region->getBlocks();
988       Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
989       blockList.insert((insertAfterBlock
990                             ? std::next(Region::iterator(insertAfterBlock))
991                             : blockList.end()),
992                        action.block);
993       break;
994     }
995     // Split the block at the position which was originally the end of the
996     // destination block (owned by action), and put the instructions back into
997     // the block used before the merge.
998     case BlockActionKind::Merge: {
999       Block *sourceBlock = action.mergeInfo.sourceBlock;
1000       Block::iterator splitPoint =
1001           (action.mergeInfo.destBlockLastInst
1002                ? ++Block::iterator(action.mergeInfo.destBlockLastInst)
1003                : action.block->begin());
1004       sourceBlock->getOperations().splice(sourceBlock->begin(),
1005                                           action.block->getOperations(),
1006                                           splitPoint, action.block->end());
1007       break;
1008     }
1009     // Move the block back to its original position.
1010     case BlockActionKind::Move: {
1011       Region *originalRegion = action.originalPosition.region;
1012       Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
1013       originalRegion->getBlocks().splice(
1014           (insertAfterBlock ? std::next(Region::iterator(insertAfterBlock))
1015                             : originalRegion->end()),
1016           action.block->getParent()->getBlocks(), action.block);
1017       break;
1018     }
1019     // Merge back the block that was split out.
1020     case BlockActionKind::Split: {
1021       action.originalBlock->getOperations().splice(
1022           action.originalBlock->end(), action.block->getOperations());
1023       action.block->dropAllDefinedValueUses();
1024       action.block->erase();
1025       break;
1026     }
1027     // Undo the type conversion.
1028     case BlockActionKind::TypeConversion: {
1029       argConverter.discardRewrites(action.block);
1030       break;
1031     }
1032     }
1033   }
1034   blockActions.resize(numActionsToKeep);
1035 }
1036 
1037 LogicalResult ConversionPatternRewriterImpl::remapValues(
1038     Location loc, PatternRewriter &rewriter, TypeConverter *converter,
1039     Operation::operand_range operands, SmallVectorImpl<Value> &remapped) {
1040   remapped.reserve(llvm::size(operands));
1041 
1042   SmallVector<Type, 1> legalTypes;
1043   for (auto it : llvm::enumerate(operands)) {
1044     Value operand = it.value();
1045     Type origType = operand.getType();
1046 
1047     // If a converter was provided, get the desired legal types for this
1048     // operand.
1049     Type desiredType;
1050     if (converter) {
1051       // If there is no legal conversion, fail to match this pattern.
1052       legalTypes.clear();
1053       if (failed(converter->convertType(origType, legalTypes))) {
1054         return notifyMatchFailure(loc, [=](Diagnostic &diag) {
1055           diag << "unable to convert type for operand #" << it.index()
1056                << ", type was " << origType;
1057         });
1058       }
1059       // TODO: There currently isn't any mechanism to do 1->N type conversion
1060       // via the PatternRewriter replacement API, so for now we just ignore it.
1061       if (legalTypes.size() == 1)
1062         desiredType = legalTypes.front();
1063     } else {
1064       // TODO: What we should do here is just set `desiredType` to `origType`
1065       // and then handle the necessary type conversions after the conversion
1066       // process has finished. Unfortunately a lot of patterns currently rely on
1067       // receiving the new operands even if the types change, so we keep the
1068       // original behavior here for now until all of the patterns relying on
1069       // this get updated.
1070     }
1071     Value newOperand = mapping.lookupOrDefault(operand, desiredType);
1072 
1073     // Handle the case where the conversion was 1->1 and the new operand type
1074     // isn't legal.
1075     Type newOperandType = newOperand.getType();
1076     if (converter && desiredType && newOperandType != desiredType) {
1077       // Attempt to materialize a conversion for this new value.
1078       newOperand = converter->materializeTargetConversion(
1079           rewriter, loc, desiredType, newOperand);
1080       if (!newOperand) {
1081         return notifyMatchFailure(loc, [=](Diagnostic &diag) {
1082           diag << "unable to materialize a conversion for "
1083                   "operand #"
1084                << it.index() << ", from " << newOperandType << " to "
1085                << desiredType;
1086         });
1087       }
1088     }
1089     remapped.push_back(newOperand);
1090   }
1091   return success();
1092 }
1093 
1094 bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
1095   // Check to see if this operation was replaced or its parent ignored.
1096   return replacements.count(op) || ignoredOps.count(op->getParentOp());
1097 }
1098 
1099 void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
1100   // Walk this operation and collect nested operations that define non-empty
1101   // regions. We mark such operations as 'ignored' so that we know we don't have
1102   // to convert them, or their nested ops.
1103   if (op->getNumRegions() == 0)
1104     return;
1105   op->walk([&](Operation *op) {
1106     if (llvm::any_of(op->getRegions(),
1107                      [](Region &region) { return !region.empty(); }))
1108       ignoredOps.insert(op);
1109   });
1110 }
1111 
1112 //===----------------------------------------------------------------------===//
1113 // Type Conversion
1114 
1115 FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
1116     Block *block, TypeConverter &converter,
1117     TypeConverter::SignatureConversion *conversion) {
1118   FailureOr<Block *> result =
1119       conversion ? argConverter.applySignatureConversion(block, converter,
1120                                                          *conversion, mapping)
1121                  : argConverter.convertSignature(block, converter, mapping);
1122   if (Block *newBlock = result.getValue()) {
1123     if (newBlock != block)
1124       blockActions.push_back(BlockAction::getTypeConversion(newBlock));
1125   }
1126   return result;
1127 }
1128 
1129 Block *ConversionPatternRewriterImpl::applySignatureConversion(
1130     Region *region, TypeConverter::SignatureConversion &conversion) {
1131   if (!region->empty()) {
1132     return *convertBlockSignature(&region->front(), defaultTypeConverter,
1133                                   &conversion);
1134   }
1135   return nullptr;
1136 }
1137 
1138 FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
1139     Region *region, TypeConverter &converter,
1140     TypeConverter::SignatureConversion *entryConversion) {
1141   argConverter.setConverter(region, &converter);
1142   if (region->empty())
1143     return nullptr;
1144 
1145   // Convert the arguments of each block within the region.
1146   FailureOr<Block *> newEntry =
1147       convertBlockSignature(&region->front(), converter, entryConversion);
1148   for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(*region, 1)))
1149     if (failed(convertBlockSignature(&block, converter)))
1150       return failure();
1151   return newEntry;
1152 }
1153 
1154 //===----------------------------------------------------------------------===//
1155 // Rewriter Notification Hooks
1156 
1157 void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
1158                                                      ValueRange newValues) {
1159   assert(newValues.size() == op->getNumResults());
1160   assert(!replacements.count(op) && "operation was already replaced");
1161 
1162   // Track if any of the results changed, e.g. erased and replaced with null.
1163   bool resultChanged = false;
1164 
1165   // Create mappings for each of the new result values.
1166   Value newValue, result;
1167   for (auto it : llvm::zip(newValues, op->getResults())) {
1168     std::tie(newValue, result) = it;
1169     if (!newValue) {
1170       resultChanged = true;
1171       continue;
1172     }
1173     // Remap, and check for any result type changes.
1174     mapping.map(result, newValue);
1175     resultChanged |= (newValue.getType() != result.getType());
1176   }
1177   if (resultChanged)
1178     operationsWithChangedResults.push_back(replacements.size());
1179 
1180   // Record the requested operation replacement.
1181   TypeConverter *converter = nullptr;
1182   if (currentConversionPattern)
1183     converter = currentConversionPattern->getTypeConverter();
1184   replacements.insert(std::make_pair(op, OpReplacement(converter)));
1185 
1186   // Mark this operation as recursively ignored so that we don't need to
1187   // convert any nested operations.
1188   markNestedOpsIgnored(op);
1189 }
1190 
1191 void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
1192   Region *region = block->getParent();
1193   Block *origPrevBlock = block->getPrevNode();
1194   blockActions.push_back(BlockAction::getErase(block, {region, origPrevBlock}));
1195 }
1196 
1197 void ConversionPatternRewriterImpl::notifyCreatedBlock(Block *block) {
1198   blockActions.push_back(BlockAction::getCreate(block));
1199 }
1200 
1201 void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
1202                                                      Block *continuation) {
1203   blockActions.push_back(BlockAction::getSplit(continuation, block));
1204 }
1205 
1206 void ConversionPatternRewriterImpl::notifyBlocksBeingMerged(Block *block,
1207                                                             Block *srcBlock) {
1208   blockActions.push_back(BlockAction::getMerge(block, srcBlock));
1209 }
1210 
1211 void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore(
1212     Region &region, Region &parent, Region::iterator before) {
1213   if (region.empty())
1214     return;
1215   Block *laterBlock = &region.back();
1216   for (auto &earlierBlock : llvm::drop_begin(llvm::reverse(region), 1)) {
1217     blockActions.push_back(
1218         BlockAction::getMove(laterBlock, {&region, &earlierBlock}));
1219     laterBlock = &earlierBlock;
1220   }
1221   blockActions.push_back(BlockAction::getMove(laterBlock, {&region, nullptr}));
1222 }
1223 
1224 void ConversionPatternRewriterImpl::notifyRegionWasClonedBefore(
1225     iterator_range<Region::iterator> &blocks, Location origRegionLoc) {
1226   for (Block &block : blocks)
1227     blockActions.push_back(BlockAction::getCreate(&block));
1228 
1229   // Compute the conversion set for the inlined region.
1230   auto result = computeConversionSet(blocks, origRegionLoc, createdOps);
1231 
1232   // This original region has already had its conversion set computed, so there
1233   // shouldn't be any new failures.
1234   (void)result;
1235   assert(succeeded(result) && "expected region to have no unreachable blocks");
1236 }
1237 
1238 LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
1239     Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1240   LLVM_DEBUG({
1241     Diagnostic diag(loc, DiagnosticSeverity::Remark);
1242     reasonCallback(diag);
1243     logger.startLine() << "** Failure : " << diag.str() << "\n";
1244   });
1245   return failure();
1246 }
1247 
1248 //===----------------------------------------------------------------------===//
1249 // ConversionPatternRewriter
1250 //===----------------------------------------------------------------------===//
1251 
1252 ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
1253     : PatternRewriter(ctx),
1254       impl(new detail::ConversionPatternRewriterImpl(*this)) {}
1255 ConversionPatternRewriter::~ConversionPatternRewriter() {}
1256 
1257 /// PatternRewriter hook for replacing the results of an operation.
1258 void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
1259   LLVM_DEBUG({
1260     impl->logger.startLine()
1261         << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1262   });
1263   impl->notifyOpReplaced(op, newValues);
1264 }
1265 
1266 /// PatternRewriter hook for erasing a dead operation. The uses of this
1267 /// operation *must* be made dead by the end of the conversion process,
1268 /// otherwise an assert will be issued.
1269 void ConversionPatternRewriter::eraseOp(Operation *op) {
1270   LLVM_DEBUG({
1271     impl->logger.startLine()
1272         << "** Erase   : '" << op->getName() << "'(" << op << ")\n";
1273   });
1274   SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr);
1275   impl->notifyOpReplaced(op, nullRepls);
1276 }
1277 
1278 void ConversionPatternRewriter::eraseBlock(Block *block) {
1279   impl->notifyBlockIsBeingErased(block);
1280 
1281   // Mark all ops for erasure.
1282   for (Operation &op : *block)
1283     eraseOp(&op);
1284 
1285   // Unlink the block from its parent region. The block is kept in the block
1286   // action and will be actually destroyed when rewrites are applied. This
1287   // allows us to keep the operations in the block live and undo the removal by
1288   // re-inserting the block.
1289   block->getParent()->getBlocks().remove(block);
1290 }
1291 
1292 Block *ConversionPatternRewriter::applySignatureConversion(
1293     Region *region, TypeConverter::SignatureConversion &conversion) {
1294   return impl->applySignatureConversion(region, conversion);
1295 }
1296 
1297 FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
1298     Region *region, TypeConverter &converter,
1299     TypeConverter::SignatureConversion *entryConversion) {
1300   return impl->convertRegionTypes(region, converter, entryConversion);
1301 }
1302 
1303 void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
1304                                                            Value to) {
1305   LLVM_DEBUG({
1306     Operation *parentOp = from.getOwner()->getParentOp();
1307     impl->logger.startLine() << "** Replace Argument : '" << from
1308                              << "'(in region of '" << parentOp->getName()
1309                              << "'(" << from.getOwner()->getParentOp() << ")\n";
1310   });
1311   impl->argReplacements.push_back(from);
1312   impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
1313 }
1314 
1315 /// Return the converted value that replaces 'key'. Return 'key' if there is
1316 /// no such a converted value.
1317 Value ConversionPatternRewriter::getRemappedValue(Value key) {
1318   return impl->mapping.lookupOrDefault(key);
1319 }
1320 
1321 /// PatternRewriter hook for creating a new block with the given arguments.
1322 void ConversionPatternRewriter::notifyBlockCreated(Block *block) {
1323   impl->notifyCreatedBlock(block);
1324 }
1325 
1326 /// PatternRewriter hook for splitting a block into two parts.
1327 Block *ConversionPatternRewriter::splitBlock(Block *block,
1328                                              Block::iterator before) {
1329   auto *continuation = PatternRewriter::splitBlock(block, before);
1330   impl->notifySplitBlock(block, continuation);
1331   return continuation;
1332 }
1333 
1334 /// PatternRewriter hook for merging a block into another.
1335 void ConversionPatternRewriter::mergeBlocks(Block *source, Block *dest,
1336                                             ValueRange argValues) {
1337   impl->notifyBlocksBeingMerged(dest, source);
1338   assert(llvm::all_of(source->getPredecessors(),
1339                       [dest](Block *succ) { return succ == dest; }) &&
1340          "expected 'source' to have no predecessors or only 'dest'");
1341   assert(argValues.size() == source->getNumArguments() &&
1342          "incorrect # of argument replacement values");
1343   for (auto it : llvm::zip(source->getArguments(), argValues))
1344     replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1345   dest->getOperations().splice(dest->end(), source->getOperations());
1346   eraseBlock(source);
1347 }
1348 
1349 /// PatternRewriter hook for moving blocks out of a region.
1350 void ConversionPatternRewriter::inlineRegionBefore(Region &region,
1351                                                    Region &parent,
1352                                                    Region::iterator before) {
1353   impl->notifyRegionIsBeingInlinedBefore(region, parent, before);
1354   PatternRewriter::inlineRegionBefore(region, parent, before);
1355 }
1356 
1357 /// PatternRewriter hook for cloning blocks of one region into another.
1358 void ConversionPatternRewriter::cloneRegionBefore(
1359     Region &region, Region &parent, Region::iterator before,
1360     BlockAndValueMapping &mapping) {
1361   if (region.empty())
1362     return;
1363   PatternRewriter::cloneRegionBefore(region, parent, before, mapping);
1364 
1365   // Collect the range of the cloned blocks.
1366   auto clonedBeginIt = mapping.lookup(&region.front())->getIterator();
1367   auto clonedBlocks = llvm::make_range(clonedBeginIt, before);
1368   impl->notifyRegionWasClonedBefore(clonedBlocks, region.getLoc());
1369 }
1370 
1371 /// PatternRewriter hook for creating a new operation.
1372 void ConversionPatternRewriter::notifyOperationInserted(Operation *op) {
1373   LLVM_DEBUG({
1374     impl->logger.startLine()
1375         << "** Insert  : '" << op->getName() << "'(" << op << ")\n";
1376   });
1377   impl->createdOps.push_back(op);
1378 }
1379 
1380 /// PatternRewriter hook for updating the root operation in-place.
1381 void ConversionPatternRewriter::startRootUpdate(Operation *op) {
1382 #ifndef NDEBUG
1383   impl->pendingRootUpdates.insert(op);
1384 #endif
1385   impl->rootUpdates.emplace_back(op);
1386 }
1387 
1388 /// PatternRewriter hook for updating the root operation in-place.
1389 void ConversionPatternRewriter::finalizeRootUpdate(Operation *op) {
1390   // There is nothing to do here, we only need to track the operation at the
1391   // start of the update.
1392 #ifndef NDEBUG
1393   assert(impl->pendingRootUpdates.erase(op) &&
1394          "operation did not have a pending in-place update");
1395 #endif
1396 }
1397 
1398 /// PatternRewriter hook for updating the root operation in-place.
1399 void ConversionPatternRewriter::cancelRootUpdate(Operation *op) {
1400 #ifndef NDEBUG
1401   assert(impl->pendingRootUpdates.erase(op) &&
1402          "operation did not have a pending in-place update");
1403 #endif
1404   // Erase the last update for this operation.
1405   auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; };
1406   auto &rootUpdates = impl->rootUpdates;
1407   auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp);
1408   rootUpdates.erase(rootUpdates.begin() + (rootUpdates.rend() - it));
1409 }
1410 
1411 /// PatternRewriter hook for notifying match failure reasons.
1412 LogicalResult ConversionPatternRewriter::notifyMatchFailure(
1413     Operation *op, function_ref<void(Diagnostic &)> reasonCallback) {
1414   return impl->notifyMatchFailure(op->getLoc(), reasonCallback);
1415 }
1416 
1417 /// Return a reference to the internal implementation.
1418 detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
1419   return *impl;
1420 }
1421 
1422 //===----------------------------------------------------------------------===//
1423 // ConversionPattern
1424 //===----------------------------------------------------------------------===//
1425 
1426 /// Attempt to match and rewrite the IR root at the specified operation.
1427 LogicalResult
1428 ConversionPattern::matchAndRewrite(Operation *op,
1429                                    PatternRewriter &rewriter) const {
1430   auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
1431   auto &rewriterImpl = dialectRewriter.getImpl();
1432 
1433   // Track the current conversion pattern in the rewriter.
1434   assert(!rewriterImpl.currentConversionPattern &&
1435          "already inside of a pattern rewrite");
1436   llvm::SaveAndRestore<const ConversionPattern *> currentPatternGuard(
1437       rewriterImpl.currentConversionPattern, this);
1438 
1439   // Remap the operands of the operation.
1440   SmallVector<Value, 4> operands;
1441   if (failed(rewriterImpl.remapValues(op->getLoc(), rewriter,
1442                                       getTypeConverter(), op->getOperands(),
1443                                       operands))) {
1444     return failure();
1445   }
1446   return matchAndRewrite(op, operands, dialectRewriter);
1447 }
1448 
1449 //===----------------------------------------------------------------------===//
1450 // OperationLegalizer
1451 //===----------------------------------------------------------------------===//
1452 
namespace {
/// A set of rewrite patterns that can be used to legalize a given operation.
/// Once the cost model has run, each list is ordered from most to least
/// preferred pattern.
using LegalizationPatterns = SmallVector<const Pattern *, 1>;

/// This class defines a recursive operation legalizer. It legalizes an
/// operation by, in order:
///  - consulting the target's legality information,
///  - folding the operation,
///  - applying a rewrite pattern.
/// Any IR produced by a successful fold or pattern is then legalized
/// recursively. On construction it precomputes which patterns can lead to
/// legal IR and ranks them (see the "Cost Model" members below).
class OperationLegalizer {
public:
  using LegalizationAction = ConversionTarget::LegalizationAction;

  /// Build a legalizer for `targetInfo` and run the cost model over
  /// `patterns`. The target is stored by reference and must outlive this
  /// legalizer.
  OperationLegalizer(ConversionTarget &targetInfo,
                     const FrozenRewritePatternList &patterns);

  /// Returns true if the given operation is known to be illegal on the target,
  /// i.e. the target's action for it is explicitly `Illegal`.
  bool isIllegal(Operation *op) const;

  /// Attempt to legalize the given operation. Returns success if the operation
  /// was legalized, failure otherwise.
  LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);

  /// Returns the conversion target in use by the legalizer.
  ConversionTarget &getTarget() { return target; }

private:
  /// Attempt to legalize the given operation by folding it. On success the
  /// operation is replaced by the folded values, and any newly created
  /// operations are legalized as well.
  LogicalResult legalizeWithFold(Operation *op,
                                 ConversionPatternRewriter &rewriter);

  /// Attempt to legalize the given operation by applying a pattern. Returns
  /// success if the operation was legalized, failure otherwise.
  LogicalResult legalizeWithPattern(Operation *op,
                                    ConversionPatternRewriter &rewriter);

  /// Return true if the given pattern may be applied to the given operation,
  /// false otherwise. A pattern without bounded rewrite recursion is refused
  /// if it is already being applied higher up the current recursion stack.
  bool canApplyPattern(Operation *op, const Pattern &pattern,
                       ConversionPatternRewriter &rewriter);

  /// Legalize the resultant IR after successfully applying the given pattern.
  /// `curState` is the rewriter state captured before the pattern ran.
  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
                                      ConversionPatternRewriter &rewriter,
                                      RewriterState &curState);

  /// Legalizes the block actions registered during the execution of a pattern
  /// (those recorded between `state` and `newState`).
  LogicalResult legalizePatternBlockActions(Operation *op,
                                            ConversionPatternRewriter &rewriter,
                                            ConversionPatternRewriterImpl &impl,
                                            RewriterState &state,
                                            RewriterState &newState);
  /// Legalizes the operations created by a pattern (those recorded between
  /// `state` and `newState`).
  LogicalResult legalizePatternCreatedOperations(
      ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
      RewriterState &state, RewriterState &newState);
  /// Legalizes the operations a pattern updated in place (those recorded
  /// between `state` and `newState`).
  LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
                                           ConversionPatternRewriterImpl &impl,
                                           RewriterState &state,
                                           RewriterState &newState);

  //===--------------------------------------------------------------------===//
  // Cost Model
  //===--------------------------------------------------------------------===//

  /// Build an optimistic legalization graph given the provided patterns. This
  /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
  /// patterns for operations that are not directly legal, but may be
  /// transitively legal for the current target given the provided patterns.
  void buildLegalizationGraph(
      LegalizationPatterns &anyOpLegalizerPatterns,
      DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);

  /// Compute the benefit of each node within the computed legalization graph.
  /// This orders the patterns within 'legalizerPatterns' based upon two
  /// criteria:
  ///  1) Prefer patterns that have the lowest legalization depth, i.e.
  ///     represent the more direct mapping to the target.
  ///  2) When comparing patterns with the same legalization depth, prefer the
  ///     pattern with the highest PatternBenefit. This allows for users to
  ///     prefer specific legalizations over others.
  void computeLegalizationGraphBenefit(
      LegalizationPatterns &anyOpLegalizerPatterns,
      DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);

  /// Compute the legalization depth when legalizing an operation of the given
  /// type. Results are memoized in 'minOpPatternDepth'.
  unsigned computeOpLegalizationDepth(
      OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
      DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);

  /// Apply the conversion cost model to the given set of patterns, and return
  /// the smallest legalization depth of any of the patterns. See
  /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
  unsigned applyCostModelToPatterns(
      LegalizationPatterns &patterns,
      DenseMap<OperationName, unsigned> &minOpPatternDepth,
      DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);

  /// The current set of patterns that have been applied, i.e. patterns in the
  /// middle of being applied on the current recursion stack. Consulted by
  /// `canApplyPattern` to stop a pattern from re-entering itself.
  SmallPtrSet<const Pattern *, 8> appliedPatterns;

  /// The legalization information provided by the target.
  ConversionTarget &target;

  /// The pattern applicator to use for conversions.
  PatternApplicator applicator;
};
} // namespace
1557 
1558 OperationLegalizer::OperationLegalizer(ConversionTarget &targetInfo,
1559                                        const FrozenRewritePatternList &patterns)
1560     : target(targetInfo), applicator(patterns) {
1561   // The set of patterns that can be applied to illegal operations to transform
1562   // them into legal ones.
1563   DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
1564   LegalizationPatterns anyOpLegalizerPatterns;
1565 
1566   buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
1567   computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
1568 }
1569 
1570 bool OperationLegalizer::isIllegal(Operation *op) const {
1571   // Check if the target explicitly marked this operation as illegal.
1572   return target.getOpAction(op->getName()) == LegalizationAction::Illegal;
1573 }
1574 
/// Legalize `op` for the target, trying in order:
///  1. the target's legality info (a recursively legal op additionally has
///     its nested operations marked as ignored),
///  2. whether `op` was previously marked as ignored,
///  3. folding `op` in place,
///  4. applying a legalization pattern.
/// Returns success as soon as one of these succeeds, failure otherwise. In
/// debug builds each attempt is logged to the rewriter's logger.
LogicalResult
OperationLegalizer::legalize(Operation *op,
                             ConversionPatternRewriter &rewriter) {
#ifndef NDEBUG
  // Separator line printed around the log output for this operation.
  const char *logLineComment =
      "//===-------------------------------------------===//\n";

  // Only referenced inside LLVM_DEBUG, so it is only declared when !NDEBUG.
  auto &rewriterImpl = rewriter.getImpl();
#endif
  LLVM_DEBUG({
    auto &os = rewriterImpl.logger;
    os.getOStream() << "\n";
    os.startLine() << logLineComment;
    os.startLine() << "Legalizing operation : '" << op->getName() << "'(" << op
                   << ") {\n";
    os.indent();

    // If the operation has no regions, just print it here.
    if (op->getNumRegions() == 0) {
      op->print(os.startLine(), OpPrintingFlags().printGenericOpForm());
      os.getOStream() << "\n\n";
    }
  });

  // Check if this operation is legal on the target.
  if (auto legalityInfo = target.isLegal(op)) {
    LLVM_DEBUG({
      logSuccess(
          rewriterImpl.logger, "operation marked legal by the target{0}",
          legalityInfo->isRecursivelyLegal
              ? "; NOTE: operation is recursively legal; skipping internals"
              : "");
      rewriterImpl.logger.startLine() << logLineComment;
    });

    // If this operation is recursively legal, mark its children as ignored so
    // that we don't consider them for legalization.
    if (legalityInfo->isRecursivelyLegal)
      rewriter.getImpl().markNestedOpsIgnored(op);
    return success();
  }

  // Check to see if the operation is ignored and doesn't need to be converted.
  if (rewriter.getImpl().isOpIgnored(op)) {
    LLVM_DEBUG({
      logSuccess(rewriterImpl.logger,
                 "operation marked 'ignored' during conversion");
      rewriterImpl.logger.startLine() << logLineComment;
    });
    return success();
  }

  // If the operation isn't legal, try to fold it in-place.
  // TODO: Should we always try to do this, even if the op is
  // already legal?
  if (succeeded(legalizeWithFold(op, rewriter))) {
    LLVM_DEBUG({
      logSuccess(rewriterImpl.logger, "operation was folded");
      rewriterImpl.logger.startLine() << logLineComment;
    });
    return success();
  }

  // Otherwise, we need to apply a legalization pattern to this operation.
  if (succeeded(legalizeWithPattern(op, rewriter))) {
    LLVM_DEBUG({
      logSuccess(rewriterImpl.logger, "");
      rewriterImpl.logger.startLine() << logLineComment;
    });
    return success();
  }

  // Neither folding nor any pattern could legalize the operation.
  LLVM_DEBUG({
    logFailure(rewriterImpl.logger, "no matched legalization pattern");
    rewriterImpl.logger.startLine() << logLineComment;
  });
  return failure();
}
1653 
1654 LogicalResult
1655 OperationLegalizer::legalizeWithFold(Operation *op,
1656                                      ConversionPatternRewriter &rewriter) {
1657   auto &rewriterImpl = rewriter.getImpl();
1658   RewriterState curState = rewriterImpl.getCurrentState();
1659 
1660   LLVM_DEBUG({
1661     rewriterImpl.logger.startLine() << "* Fold {\n";
1662     rewriterImpl.logger.indent();
1663   });
1664 
1665   // Try to fold the operation.
1666   SmallVector<Value, 2> replacementValues;
1667   rewriter.setInsertionPoint(op);
1668   if (failed(rewriter.tryFold(op, replacementValues))) {
1669     LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
1670     return failure();
1671   }
1672 
1673   // Insert a replacement for 'op' with the folded replacement values.
1674   rewriter.replaceOp(op, replacementValues);
1675 
1676   // Recursively legalize any new constant operations.
1677   for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
1678        i != e; ++i) {
1679     Operation *cstOp = rewriterImpl.createdOps[i];
1680     if (failed(legalize(cstOp, rewriter))) {
1681       LLVM_DEBUG(logFailure(rewriterImpl.logger,
1682                             "generated constant '{0}' was illegal",
1683                             cstOp->getName()));
1684       rewriterImpl.resetState(curState);
1685       return failure();
1686     }
1687   }
1688 
1689   LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
1690   return success();
1691 }
1692 
1693 LogicalResult
1694 OperationLegalizer::legalizeWithPattern(Operation *op,
1695                                         ConversionPatternRewriter &rewriter) {
1696   auto &rewriterImpl = rewriter.getImpl();
1697 
1698   // Functor that returns if the given pattern may be applied.
1699   auto canApply = [&](const Pattern &pattern) {
1700     return canApplyPattern(op, pattern, rewriter);
1701   };
1702 
1703   // Functor that cleans up the rewriter state after a pattern failed to match.
1704   RewriterState curState = rewriterImpl.getCurrentState();
1705   auto onFailure = [&](const Pattern &pattern) {
1706     LLVM_DEBUG(logFailure(rewriterImpl.logger, "pattern failed to match"));
1707     rewriterImpl.resetState(curState);
1708     appliedPatterns.erase(&pattern);
1709   };
1710 
1711   // Functor that performs additional legalization when a pattern is
1712   // successfully applied.
1713   auto onSuccess = [&](const Pattern &pattern) {
1714     auto result = legalizePatternResult(op, pattern, rewriter, curState);
1715     appliedPatterns.erase(&pattern);
1716     if (failed(result))
1717       rewriterImpl.resetState(curState);
1718     return result;
1719   };
1720 
1721   // Try to match and rewrite a pattern on this operation.
1722   return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
1723                                     onSuccess);
1724 }
1725 
1726 bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
1727                                          ConversionPatternRewriter &rewriter) {
1728   LLVM_DEBUG({
1729     auto &os = rewriter.getImpl().logger;
1730     os.getOStream() << "\n";
1731     os.startLine() << "* Pattern : '" << op->getName() << " -> (";
1732     llvm::interleaveComma(pattern.getGeneratedOps(), llvm::dbgs());
1733     os.getOStream() << ")' {\n";
1734     os.indent();
1735   });
1736 
1737   // Ensure that we don't cycle by not allowing the same pattern to be
1738   // applied twice in the same recursion stack if it is not known to be safe.
1739   if (!pattern.hasBoundedRewriteRecursion() &&
1740       !appliedPatterns.insert(&pattern).second) {
1741     LLVM_DEBUG(
1742         logFailure(rewriter.getImpl().logger, "pattern was already applied"));
1743     return false;
1744   }
1745   return true;
1746 }
1747 
1748 LogicalResult
1749 OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
1750                                           ConversionPatternRewriter &rewriter,
1751                                           RewriterState &curState) {
1752   auto &impl = rewriter.getImpl();
1753 
1754 #ifndef NDEBUG
1755   assert(impl.pendingRootUpdates.empty() && "dangling root updates");
1756 #endif
1757 
1758   // Check that the root was either replaced or updated in place.
1759   auto replacedRoot = [&] {
1760     return llvm::any_of(
1761         llvm::drop_begin(impl.replacements, curState.numReplacements),
1762         [op](auto &it) { return it.first == op; });
1763   };
1764   auto updatedRootInPlace = [&] {
1765     return llvm::any_of(
1766         llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates),
1767         [op](auto &state) { return state.getOperation() == op; });
1768   };
1769   (void)replacedRoot;
1770   (void)updatedRootInPlace;
1771   assert((replacedRoot() || updatedRootInPlace()) &&
1772          "expected pattern to replace the root operation");
1773 
1774   // Legalize each of the actions registered during application.
1775   RewriterState newState = impl.getCurrentState();
1776   if (failed(legalizePatternBlockActions(op, rewriter, impl, curState,
1777                                          newState)) ||
1778       failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
1779       failed(legalizePatternCreatedOperations(rewriter, impl, curState,
1780                                               newState))) {
1781     return failure();
1782   }
1783 
1784   LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
1785   return success();
1786 }
1787 
/// Legalize the blocks a pattern applied to `op` moved or created. Only the
/// block actions recorded between `state` and `newState` are visited;
/// TypeConversion and Erase actions are skipped.
///
/// For every remaining action on a block with arguments, whose parent
/// operation exists and is not `op`:
///  - if the block's region has a registered type converter, the block
///    signature is converted directly;
///  - otherwise the parent operation is re-legalized, unless the pattern
///    created it or it was already handled in this call.
LogicalResult OperationLegalizer::legalizePatternBlockActions(
    Operation *op, ConversionPatternRewriter &rewriter,
    ConversionPatternRewriterImpl &impl, RewriterState &state,
    RewriterState &newState) {
  // Operations that must not be re-legalized here. Lazily seeded with the ops
  // created by the pattern; parents re-legalized below are added as visited.
  SmallPtrSet<Operation *, 16> operationsToIgnore;

  // If the pattern moved or created any blocks, make sure the types of block
  // arguments get legalized.
  for (int i = state.numBlockActions, e = newState.numBlockActions; i != e;
       ++i) {
    // NOTE(review): `action` refers into impl.blockActions. The calls below
    // (convertBlockSignature / legalize) may record new block actions and
    // grow that container, so `action` must not be used after them.
    auto &action = impl.blockActions[i];
    if (action.kind == BlockActionKind::TypeConversion ||
        action.kind == BlockActionKind::Erase)
      continue;
    // Only check blocks outside of the current operation.
    Operation *parentOp = action.block->getParentOp();
    if (!parentOp || parentOp == op || action.block->getNumArguments() == 0)
      continue;

    // If the region of the block has a type converter, try to convert the block
    // directly.
    if (auto *converter =
            impl.argConverter.getConverter(action.block->getParent())) {
      if (failed(impl.convertBlockSignature(action.block, *converter))) {
        LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
                                           "block"));
        return failure();
      }
      continue;
    }

    // Otherwise, check that this operation isn't one generated by this pattern.
    // This is because we will attempt to legalize the parent operation, and
    // blocks in regions created by this pattern will already be legalized later
    // on. If we haven't built the set yet, build it now.
    if (operationsToIgnore.empty()) {
      auto createdOps = ArrayRef<Operation *>(impl.createdOps)
                            .drop_front(state.numCreatedOps);
      operationsToIgnore.insert(createdOps.begin(), createdOps.end());
    }

    // If this operation should be considered for re-legalization, try it.
    // insert().second is false when the parent was created by the pattern or
    // was already re-legalized in an earlier iteration.
    if (operationsToIgnore.insert(parentOp).second &&
        failed(legalize(parentOp, rewriter))) {
      LLVM_DEBUG(logFailure(
          impl.logger, "operation '{0}'({1}) became illegal after block action",
          parentOp->getName(), parentOp));
      return failure();
    }
  }
  return success();
}
1840 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
1841     ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
1842     RewriterState &state, RewriterState &newState) {
1843   for (int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) {
1844     Operation *op = impl.createdOps[i];
1845     if (failed(legalize(op, rewriter))) {
1846       LLVM_DEBUG(logFailure(impl.logger,
1847                             "generated operation '{0}'({1}) was illegal",
1848                             op->getName(), op));
1849       return failure();
1850     }
1851   }
1852   return success();
1853 }
1854 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
1855     ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
1856     RewriterState &state, RewriterState &newState) {
1857   for (int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) {
1858     Operation *op = impl.rootUpdates[i].getOperation();
1859     if (failed(legalize(op, rewriter))) {
1860       LLVM_DEBUG(logFailure(impl.logger,
1861                             "operation updated in-place '{0}' was illegal",
1862                             op->getName()));
1863       return failure();
1864     }
1865   }
1866   return success();
1867 }
1868 
1869 //===----------------------------------------------------------------------===//
1870 // Cost Model
1871 
/// Build the optimistic legalization graph.
///
/// Patterns without a specific root are collected in `anyOpLegalizerPatterns`.
/// Patterns whose root is always legal on the target are dropped. Each other
/// rooted pattern is added to `legalizerPatterns[root]` once every operation
/// it declares as generated satisfies one of:
///  - it has a target action other than `Illegal`;
///  - it is itself the root of an already-accepted pattern.
/// This is computed as a worklist fixpoint. If any root-less pattern exists,
/// no filtering is done and every rooted pattern is accepted.
void OperationLegalizer::buildLegalizationGraph(
    LegalizationPatterns &anyOpLegalizerPatterns,
    DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
  // A mapping between an operation and a set of operations that can be used to
  // generate it.
  DenseMap<OperationName, SmallPtrSet<OperationName, 2>> parentOps;
  // A mapping between an operation and any currently invalid patterns it has.
  DenseMap<OperationName, SmallPtrSet<const Pattern *, 2>> invalidPatterns;
  // A worklist of patterns to consider for legality.
  llvm::SetVector<const Pattern *> patternWorklist;

  // Build the mapping from operations to the parent ops that may generate them.
  applicator.walkAllPatterns([&](const Pattern &pattern) {
    Optional<OperationName> root = pattern.getRootKind();

    // If the pattern has no specific root, we can't analyze the relationship
    // between the root op and generated operations. Given that, add all such
    // patterns to the legalization set.
    if (!root) {
      anyOpLegalizerPatterns.push_back(&pattern);
      return;
    }

    // Skip operations that are always known to be legal.
    if (target.getOpAction(*root) == LegalizationAction::Legal)
      return;

    // Add this pattern to the invalid set for the root op and record this root
    // as a parent for any generated operations.
    invalidPatterns[*root].insert(&pattern);
    for (auto op : pattern.getGeneratedOps())
      parentOps[op].insert(*root);

    // Add this pattern to the worklist.
    patternWorklist.insert(&pattern);
  });

  // If there are any patterns that don't have a specific root kind, we can't
  // make direct assumptions about what operations will never be legalized.
  // Note: Technically we could, but it would require an analysis that may
  // recurse into itself. It would be better to perform this kind of filtering
  // at a higher level than here anyways.
  if (!anyOpLegalizerPatterns.empty()) {
    for (const Pattern *pattern : patternWorklist)
      legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
    return;
  }

  while (!patternWorklist.empty()) {
    auto *pattern = patternWorklist.pop_back_val();

    // Check to see if any of the generated operations are invalid. A generated
    // op is invalid if no accepted pattern produces legal IR for it yet, and
    // the target either has no action for it or marks it Illegal.
    if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
          Optional<LegalizationAction> action = target.getOpAction(op);
          return !legalizerPatterns.count(op) &&
                 (!action || action == LegalizationAction::Illegal);
        }))
      continue;

    // Otherwise, if all of the generated operation are valid, this op is now
    // legal so add all of the child patterns to the worklist.
    legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
    invalidPatterns[*pattern->getRootKind()].erase(pattern);

    // Add any invalid patterns of the parent operations to see if they have now
    // become legal.
    for (auto op : parentOps[*pattern->getRootKind()])
      patternWorklist.set_union(invalidPatterns[op]);
  }
}
1942 
1943 void OperationLegalizer::computeLegalizationGraphBenefit(
1944     LegalizationPatterns &anyOpLegalizerPatterns,
1945     DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
1946   // The smallest pattern depth, when legalizing an operation.
1947   DenseMap<OperationName, unsigned> minOpPatternDepth;
1948 
1949   // For each operation that is transitively legal, compute a cost for it.
1950   for (auto &opIt : legalizerPatterns)
1951     if (!minOpPatternDepth.count(opIt.first))
1952       computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
1953                                  legalizerPatterns);
1954 
1955   // Apply the cost model to the patterns that can match any operation. Those
1956   // with a specific operation type are already resolved when computing the op
1957   // legalization depth.
1958   if (!anyOpLegalizerPatterns.empty())
1959     applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
1960                              legalizerPatterns);
1961 
1962   // Apply a cost model to the pattern applicator. We order patterns first by
1963   // depth then benefit. `legalizerPatterns` contains per-op patterns by
1964   // decreasing benefit.
1965   applicator.applyCostModel([&](const Pattern &pattern) {
1966     ArrayRef<const Pattern *> orderedPatternList;
1967     if (Optional<OperationName> rootName = pattern.getRootKind())
1968       orderedPatternList = legalizerPatterns[*rootName];
1969     else
1970       orderedPatternList = anyOpLegalizerPatterns;
1971 
1972     // If the pattern is not found, then it was removed and cannot be matched.
1973     auto it = llvm::find(orderedPatternList, &pattern);
1974     if (it == orderedPatternList.end())
1975       return PatternBenefit::impossibleToMatch();
1976 
1977     // Patterns found earlier in the list have higher benefit.
1978     return PatternBenefit(std::distance(it, orderedPatternList.end()));
1979   });
1980 }
1981 
1982 unsigned OperationLegalizer::computeOpLegalizationDepth(
1983     OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
1984     DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
1985   // Check for existing depth.
1986   auto depthIt = minOpPatternDepth.find(op);
1987   if (depthIt != minOpPatternDepth.end())
1988     return depthIt->second;
1989 
1990   // If a mapping for this operation does not exist, then this operation
1991   // is always legal. Return 0 as the depth for a directly legal operation.
1992   auto opPatternsIt = legalizerPatterns.find(op);
1993   if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
1994     return 0u;
1995 
1996   // Record this initial depth in case we encounter this op again when
1997   // recursively computing the depth.
1998   minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
1999 
2000   // Apply the cost model to the operation patterns, and update the minimum
2001   // depth.
2002   unsigned minDepth = applyCostModelToPatterns(
2003       opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2004   minOpPatternDepth[op] = minDepth;
2005   return minDepth;
2006 }
2007 
2008 unsigned OperationLegalizer::applyCostModelToPatterns(
2009     LegalizationPatterns &patterns,
2010     DenseMap<OperationName, unsigned> &minOpPatternDepth,
2011     DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2012   unsigned minDepth = std::numeric_limits<unsigned>::max();
2013 
2014   // Compute the depth for each pattern within the set.
2015   SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
2016   patternsByDepth.reserve(patterns.size());
2017   for (const Pattern *pattern : patterns) {
2018     unsigned depth = 0;
2019     for (auto generatedOp : pattern->getGeneratedOps()) {
2020       unsigned generatedOpDepth = computeOpLegalizationDepth(
2021           generatedOp, minOpPatternDepth, legalizerPatterns);
2022       depth = std::max(depth, generatedOpDepth + 1);
2023     }
2024     patternsByDepth.emplace_back(pattern, depth);
2025 
2026     // Update the minimum depth of the pattern list.
2027     minDepth = std::min(minDepth, depth);
2028   }
2029 
2030   // If the operation only has one legalization pattern, there is no need to
2031   // sort them.
2032   if (patternsByDepth.size() == 1)
2033     return minDepth;
2034 
2035   // Sort the patterns by those likely to be the most beneficial.
2036   llvm::array_pod_sort(patternsByDepth.begin(), patternsByDepth.end(),
2037                        [](const std::pair<const Pattern *, unsigned> *lhs,
2038                           const std::pair<const Pattern *, unsigned> *rhs) {
2039                          // First sort by the smaller pattern legalization
2040                          // depth.
2041                          if (lhs->second != rhs->second)
2042                            return llvm::array_pod_sort_comparator<unsigned>(
2043                                &lhs->second, &rhs->second);
2044 
2045                          // Then sort by the larger pattern benefit.
2046                          auto lhsBenefit = lhs->first->getBenefit();
2047                          auto rhsBenefit = rhs->first->getBenefit();
2048                          return llvm::array_pod_sort_comparator<PatternBenefit>(
2049                              &rhsBenefit, &lhsBenefit);
2050                        });
2051 
2052   // Update the legalization pattern to use the new sorted list.
2053   patterns.clear();
2054   for (auto &patternIt : patternsByDepth)
2055     patterns.push_back(patternIt.first);
2056   return minDepth;
2057 }
2058 
2059 //===----------------------------------------------------------------------===//
2060 // OperationConverter
2061 //===----------------------------------------------------------------------===//
namespace {
/// Selects how a conversion reacts to operations that fail to legalize, and
/// whether the recorded rewrites are committed at the end.
enum OpConversionMode {
  // In this mode, the conversion will ignore failed conversions to allow
  // illegal operations to co-exist in the IR.
  Partial,

  // In this mode, all operations must be legal for the given target for the
  // conversion to succeed.
  Full,

  // In this mode, operations are analyzed for legality. No actual rewrites are
  // applied to the operations on success.
  Analysis,
};

// This class converts operations to a given conversion target via a set of
// rewrite patterns. The conversion behaves differently depending on the
// conversion mode.
struct OperationConverter {
  /// `trackedOps` is optional in Partial mode. It must be non-null in
  /// Analysis mode, where legalizable ops are inserted into it unconditionally.
  explicit OperationConverter(ConversionTarget &target,
                              const FrozenRewritePatternList &patterns,
                              OpConversionMode mode,
                              DenseSet<Operation *> *trackedOps = nullptr)
      : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}

  /// Converts the given operations to the conversion target.
  LogicalResult convertOperations(ArrayRef<Operation *> ops);

private:
  /// Converts an operation with the given rewriter.
  LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);

  /// This method is called after the conversion process to legalize any
  /// remaining artifacts and complete the conversion.
  LogicalResult finalize(ConversionPatternRewriter &rewriter);

  /// Legalize the types of converted block arguments.
  LogicalResult
  legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
                                 ConversionPatternRewriterImpl &rewriterImpl);

  /// Legalize an operation result that was marked as "erased".
  LogicalResult
  legalizeErasedResult(Operation *op, OpResult result,
                       ConversionPatternRewriterImpl &rewriterImpl);

  /// Legalize an operation result that was replaced with a value of a different
  /// type.
  LogicalResult
  legalizeChangedResultType(Operation *op, OpResult result, Value newValue,
                            TypeConverter *replConverter,
                            ConversionPatternRewriter &rewriter,
                            ConversionPatternRewriterImpl &rewriterImpl);

  /// The legalizer to use when converting operations.
  OperationLegalizer opLegalizer;

  /// The conversion mode to use when legalizing operations.
  OpConversionMode mode;

  /// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
  /// this is populated with ops found to be legalizable to the target.
  /// When mode == OpConversionMode::Partial, this is populated with ops found
  /// *not* to be legalizable to the target.
  DenseSet<Operation *> *trackedOps;
};
} // end anonymous namespace
2129 
2130 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2131                                           Operation *op) {
2132   // Legalize the given operation.
2133   if (failed(opLegalizer.legalize(op, rewriter))) {
2134     // Handle the case of a failed conversion for each of the different modes.
2135     // Full conversions expect all operations to be converted.
2136     if (mode == OpConversionMode::Full)
2137       return op->emitError()
2138              << "failed to legalize operation '" << op->getName() << "'";
2139     // Partial conversions allow conversions to fail iff the operation was not
2140     // explicitly marked as illegal. If the user provided a nonlegalizableOps
2141     // set, non-legalizable ops are included.
2142     if (mode == OpConversionMode::Partial) {
2143       if (opLegalizer.isIllegal(op))
2144         return op->emitError()
2145                << "failed to legalize operation '" << op->getName()
2146                << "' that was explicitly marked illegal";
2147       if (trackedOps)
2148         trackedOps->insert(op);
2149     }
2150   } else if (mode == OpConversionMode::Analysis) {
2151     // Analysis conversions don't fail if any operations fail to legalize,
2152     // they are only interested in the operations that were successfully
2153     // legalized.
2154     trackedOps->insert(op);
2155   }
2156   return success();
2157 }
2158 
2159 LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
2160   if (ops.empty())
2161     return success();
2162   ConversionTarget &target = opLegalizer.getTarget();
2163 
2164   // Compute the set of operations and blocks to convert.
2165   std::vector<Operation *> toConvert;
2166   for (auto *op : ops) {
2167     toConvert.emplace_back(op);
2168     for (auto &region : op->getRegions())
2169       if (failed(computeConversionSet(region.getBlocks(), region.getLoc(),
2170                                       toConvert, &target)))
2171         return failure();
2172   }
2173 
2174   // Convert each operation and discard rewrites on failure.
2175   ConversionPatternRewriter rewriter(ops.front()->getContext());
2176   ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2177   for (auto *op : toConvert)
2178     if (failed(convert(rewriter, op)))
2179       return rewriterImpl.discardRewrites(), failure();
2180 
2181   // Now that all of the operations have been converted, finalize the conversion
2182   // process to ensure any lingering conversion artifacts are cleaned up and
2183   // legalized.
2184   if (failed(finalize(rewriter)))
2185     return rewriterImpl.discardRewrites(), failure();
2186 
2187   // After a successful conversion, apply rewrites if this is not an analysis
2188   // conversion.
2189   if (mode == OpConversionMode::Analysis)
2190     rewriterImpl.discardRewrites();
2191   else
2192     rewriterImpl.applyRewrites();
2193   return success();
2194 }
2195 
2196 LogicalResult
2197 OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
2198   ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2199 
2200   // Legalize converted block arguments.
2201   if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
2202     return failure();
2203 
2204   // Process requested operation replacements.
2205   for (unsigned i = 0, e = rewriterImpl.operationsWithChangedResults.size();
2206        i != e; ++i) {
2207     unsigned replIdx = rewriterImpl.operationsWithChangedResults[i];
2208     auto &repl = *(rewriterImpl.replacements.begin() + replIdx);
2209     for (OpResult result : repl.first->getResults()) {
2210       Value newValue = rewriterImpl.mapping.lookupOrNull(result);
2211 
2212       // If the operation result was replaced with null, all of the uses of this
2213       // value should be replaced.
2214       if (!newValue) {
2215         if (failed(legalizeErasedResult(repl.first, result, rewriterImpl)))
2216           return failure();
2217         continue;
2218       }
2219 
2220       // Otherwise, check to see if the type of the result changed.
2221       if (result.getType() == newValue.getType())
2222         continue;
2223 
2224       // Legalize this result.
2225       rewriter.setInsertionPoint(repl.first);
2226       if (failed(legalizeChangedResultType(repl.first, result, newValue,
2227                                            repl.second.converter, rewriter,
2228                                            rewriterImpl)))
2229         return failure();
2230 
2231       // Update the end iterator for this loop in the case it was updated
2232       // when legalizing generated conversion operations.
2233       e = rewriterImpl.operationsWithChangedResults.size();
2234     }
2235   }
2236   return success();
2237 }
2238 
2239 LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
2240     ConversionPatternRewriter &rewriter,
2241     ConversionPatternRewriterImpl &rewriterImpl) {
2242   // Functor used to check if all users of a value will be dead after
2243   // conversion.
2244   auto findLiveUser = [&](Value val) {
2245     auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
2246       return rewriterImpl.isOpIgnored(user);
2247     });
2248     return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
2249   };
2250 
2251   // Materialize any necessary conversions for converted block arguments that
2252   // are still live.
2253   size_t numCreatedOps = rewriterImpl.createdOps.size();
2254   if (failed(rewriterImpl.argConverter.materializeLiveConversions(
2255           rewriterImpl.mapping, rewriter, findLiveUser)))
2256     return failure();
2257 
2258   // Legalize any newly created operations during argument materialization.
2259   for (int i : llvm::seq<int>(numCreatedOps, rewriterImpl.createdOps.size())) {
2260     if (failed(opLegalizer.legalize(rewriterImpl.createdOps[i], rewriter))) {
2261       return rewriterImpl.createdOps[i]->emitError()
2262              << "failed to legalize conversion operation generated for block "
2263                 "argument that remained live after conversion";
2264     }
2265   }
2266   return success();
2267 }
2268 
2269 LogicalResult OperationConverter::legalizeErasedResult(
2270     Operation *op, OpResult result,
2271     ConversionPatternRewriterImpl &rewriterImpl) {
2272   // If the operation result was replaced with null, all of the uses of this
2273   // value should be replaced.
2274   auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
2275     return rewriterImpl.isOpIgnored(user);
2276   });
2277   if (liveUserIt != result.user_end()) {
2278     InFlightDiagnostic diag = op->emitError("failed to legalize operation '")
2279                               << op->getName() << "' marked as erased";
2280     diag.attachNote(liveUserIt->getLoc())
2281         << "found live user of result #" << result.getResultNumber() << ": "
2282         << *liveUserIt;
2283     return failure();
2284   }
2285   return success();
2286 }
2287 
2288 LogicalResult OperationConverter::legalizeChangedResultType(
2289     Operation *op, OpResult result, Value newValue,
2290     TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
2291     ConversionPatternRewriterImpl &rewriterImpl) {
2292   // Walk the users of this value to see if there are any live users that
2293   // weren't replaced during conversion.
2294   auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
2295     return rewriterImpl.isOpIgnored(user);
2296   });
2297   if (liveUserIt == result.user_end())
2298     return success();
2299 
2300   // If the replacement has a type converter, attempt to materialize a
2301   // conversion back to the original type.
2302   if (!replConverter) {
2303     // TODO: We should emit an error here, similarly to the case where the
2304     // result is replaced with null. Unfortunately a lot of existing
2305     // patterns rely on this behavior, so until those patterns are updated
2306     // we keep the legacy behavior here of just forwarding the new value.
2307     return success();
2308   }
2309 
2310   // Track the number of created operations so that new ones can be legalized.
2311   size_t numCreatedOps = rewriterImpl.createdOps.size();
2312 
2313   // Materialize a conversion for this live result value.
2314   Type resultType = result.getType();
2315   Value convertedValue = replConverter->materializeSourceConversion(
2316       rewriter, op->getLoc(), resultType, newValue);
2317   if (!convertedValue) {
2318     InFlightDiagnostic diag = op->emitError()
2319                               << "failed to materialize conversion for result #"
2320                               << result.getResultNumber() << " of operation '"
2321                               << op->getName()
2322                               << "' that remained live after conversion";
2323     diag.attachNote(liveUserIt->getLoc())
2324         << "see existing live user here: " << *liveUserIt;
2325     return failure();
2326   }
2327 
2328   // Legalize all of the newly created conversion operations.
2329   for (int i : llvm::seq<int>(numCreatedOps, rewriterImpl.createdOps.size())) {
2330     if (failed(opLegalizer.legalize(rewriterImpl.createdOps[i], rewriter))) {
2331       return op->emitError("failed to legalize conversion operation generated ")
2332              << "for result #" << result.getResultNumber() << " of operation '"
2333              << op->getName() << "' that remained live after conversion";
2334     }
2335   }
2336 
2337   rewriterImpl.mapping.map(result, convertedValue);
2338   return success();
2339 }
2340 
2341 //===----------------------------------------------------------------------===//
2342 // Type Conversion
2343 //===----------------------------------------------------------------------===//
2344 
2345 /// Remap an input of the original signature with a new set of types. The
2346 /// new types are appended to the new signature conversion.
2347 void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
2348                                                    ArrayRef<Type> types) {
2349   assert(!types.empty() && "expected valid types");
2350   remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
2351   addInputs(types);
2352 }
2353 
2354 /// Append new input types to the signature conversion, this should only be
2355 /// used if the new types are not intended to remap an existing input.
2356 void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) {
2357   assert(!types.empty() &&
2358          "1->0 type remappings don't need to be added explicitly");
2359   argTypes.append(types.begin(), types.end());
2360 }
2361 
2362 /// Remap an input of the original signature with a range of types in the
2363 /// new signature.
2364 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2365                                                     unsigned newInputNo,
2366                                                     unsigned newInputCount) {
2367   assert(!remappedInputs[origInputNo] && "input has already been remapped");
2368   assert(newInputCount != 0 && "expected valid input count");
2369   remappedInputs[origInputNo] =
2370       InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr};
2371 }
2372 
2373 /// Remap an input of the original signature to another `replacementValue`
2374 /// value. This would make the signature converter drop this argument.
2375 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2376                                                     Value replacementValue) {
2377   assert(!remappedInputs[origInputNo] && "input has already been remapped");
2378   remappedInputs[origInputNo] =
2379       InputMapping{origInputNo, /*size=*/0, replacementValue};
2380 }
2381 
/// This hook allows for converting a type.
2383 LogicalResult TypeConverter::convertType(Type t,
2384                                          SmallVectorImpl<Type> &results) {
2385   auto existingIt = cachedDirectConversions.find(t);
2386   if (existingIt != cachedDirectConversions.end()) {
2387     if (existingIt->second)
2388       results.push_back(existingIt->second);
2389     return success(existingIt->second != nullptr);
2390   }
2391   auto multiIt = cachedMultiConversions.find(t);
2392   if (multiIt != cachedMultiConversions.end()) {
2393     results.append(multiIt->second.begin(), multiIt->second.end());
2394     return success();
2395   }
2396 
2397   // Walk the added converters in reverse order to apply the most recently
2398   // registered first.
2399   size_t currentCount = results.size();
2400   for (ConversionCallbackFn &converter : llvm::reverse(conversions)) {
2401     if (Optional<LogicalResult> result = converter(t, results)) {
2402       if (!succeeded(*result)) {
2403         cachedDirectConversions.try_emplace(t, nullptr);
2404         return failure();
2405       }
2406       auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
2407       if (newTypes.size() == 1)
2408         cachedDirectConversions.try_emplace(t, newTypes.front());
2409       else
2410         cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
2411       return success();
2412     }
2413   }
2414   return failure();
2415 }
2416 
2417 /// This hook simplifies defining 1-1 type conversions. This function returns
2418 /// the type to convert to on success, and a null type on failure.
2419 Type TypeConverter::convertType(Type t) {
2420   // Use the multi-type result version to convert the type.
2421   SmallVector<Type, 1> results;
2422   if (failed(convertType(t, results)))
2423     return nullptr;
2424 
2425   // Check to ensure that only one type was produced.
2426   return results.size() == 1 ? results.front() : nullptr;
2427 }
2428 
2429 /// Convert the given set of types, filling 'results' as necessary. This
2430 /// returns failure if the conversion of any of the types fails, success
2431 /// otherwise.
2432 LogicalResult TypeConverter::convertTypes(ArrayRef<Type> types,
2433                                           SmallVectorImpl<Type> &results) {
2434   for (auto type : types)
2435     if (failed(convertType(type, results)))
2436       return failure();
2437   return success();
2438 }
2439 
2440 /// Return true if the given type is legal for this type converter, i.e. the
2441 /// type converts to itself.
2442 bool TypeConverter::isLegal(Type type) { return convertType(type) == type; }
2443 /// Return true if the given operation has legal operand and result types.
2444 bool TypeConverter::isLegal(Operation *op) {
2445   return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
2446 }
2447 
2448 /// Return true if the types of block arguments within the region are legal.
2449 bool TypeConverter::isLegal(Region *region) {
2450   return llvm::all_of(*region, [this](Block &block) {
2451     return isLegal(block.getArgumentTypes());
2452   });
2453 }
2454 
2455 /// Return true if the inputs and outputs of the given function type are
2456 /// legal.
2457 bool TypeConverter::isSignatureLegal(FunctionType ty) {
2458   return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
2459 }
2460 
2461 /// This hook allows for converting a specific argument of a signature.
2462 LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
2463                                                  SignatureConversion &result) {
2464   // Try to convert the given input type.
2465   SmallVector<Type, 1> convertedTypes;
2466   if (failed(convertType(type, convertedTypes)))
2467     return failure();
2468 
2469   // If this argument is being dropped, there is nothing left to do.
2470   if (convertedTypes.empty())
2471     return success();
2472 
2473   // Otherwise, add the new inputs.
2474   result.addInputs(inputNo, convertedTypes);
2475   return success();
2476 }
2477 LogicalResult TypeConverter::convertSignatureArgs(TypeRange types,
2478                                                   SignatureConversion &result,
2479                                                   unsigned origInputOffset) {
2480   for (unsigned i = 0, e = types.size(); i != e; ++i)
2481     if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
2482       return failure();
2483   return success();
2484 }
2485 
2486 Value TypeConverter::materializeConversion(
2487     MutableArrayRef<MaterializationCallbackFn> materializations,
2488     OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) {
2489   for (MaterializationCallbackFn &fn : llvm::reverse(materializations))
2490     if (Optional<Value> result = fn(builder, resultType, inputs, loc))
2491       return result.getValue();
2492   return nullptr;
2493 }
2494 
2495 /// This function converts the type signature of the given block, by invoking
2496 /// 'convertSignatureArg' for each argument. This function should return a valid
2497 /// conversion for the signature on success, None otherwise.
2498 auto TypeConverter::convertBlockSignature(Block *block)
2499     -> Optional<SignatureConversion> {
2500   SignatureConversion conversion(block->getNumArguments());
2501   if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
2502     return llvm::None;
2503   return conversion;
2504 }
2505 
2506 /// Create a default conversion pattern that rewrites the type signature of a
2507 /// FuncOp.
2508 namespace {
2509 struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
2510   FuncOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
2511       : OpConversionPattern(converter, ctx) {}
2512 
2513   /// Hook for derived classes to implement combined matching and rewriting.
2514   LogicalResult
2515   matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
2516                   ConversionPatternRewriter &rewriter) const override {
2517     FunctionType type = funcOp.getType();
2518 
2519     // Convert the original function types.
2520     TypeConverter::SignatureConversion result(type.getNumInputs());
2521     SmallVector<Type, 1> newResults;
2522     if (failed(typeConverter->convertSignatureArgs(type.getInputs(), result)) ||
2523         failed(typeConverter->convertTypes(type.getResults(), newResults)) ||
2524         failed(rewriter.convertRegionTypes(&funcOp.getBody(), *typeConverter,
2525                                            &result)))
2526       return failure();
2527 
2528     // Update the function signature in-place.
2529     rewriter.updateRootInPlace(funcOp, [&] {
2530       funcOp.setType(FunctionType::get(result.getConvertedTypes(), newResults,
2531                                        funcOp.getContext()));
2532     });
2533     return success();
2534   }
2535 };
2536 } // end anonymous namespace
2537 
2538 void mlir::populateFuncOpTypeConversionPattern(
2539     OwningRewritePatternList &patterns, MLIRContext *ctx,
2540     TypeConverter &converter) {
2541   patterns.insert<FuncOpSignatureConversion>(ctx, converter);
2542 }
2543 
2544 //===----------------------------------------------------------------------===//
2545 // ConversionTarget
2546 //===----------------------------------------------------------------------===//
2547 
2548 /// Register a legality action for the given operation.
2549 void ConversionTarget::setOpAction(OperationName op,
2550                                    LegalizationAction action) {
2551   legalOperations[op] = {action, /*isRecursivelyLegal=*/false, llvm::None};
2552 }
2553 
2554 /// Register a legality action for the given dialects.
2555 void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
2556                                         LegalizationAction action) {
2557   for (StringRef dialect : dialectNames)
2558     legalDialects[dialect] = action;
2559 }
2560 
2561 /// Get the legality action for the given operation.
2562 auto ConversionTarget::getOpAction(OperationName op) const
2563     -> Optional<LegalizationAction> {
2564   Optional<LegalizationInfo> info = getOpInfo(op);
2565   return info ? info->action : Optional<LegalizationAction>();
2566 }
2567 
2568 /// If the given operation instance is legal on this target, a structure
2569 /// containing legality information is returned. If the operation is not legal,
2570 /// None is returned.
2571 auto ConversionTarget::isLegal(Operation *op) const
2572     -> Optional<LegalOpDetails> {
2573   Optional<LegalizationInfo> info = getOpInfo(op->getName());
2574   if (!info)
2575     return llvm::None;
2576 
2577   // Returns true if this operation instance is known to be legal.
2578   auto isOpLegal = [&] {
2579     // Handle dynamic legality either with the provided legality function, or
2580     // the default hook on the derived instance.
2581     if (info->action == LegalizationAction::Dynamic)
2582       return info->legalityFn ? (*info->legalityFn)(op)
2583                               : isDynamicallyLegal(op);
2584 
2585     // Otherwise, the operation is only legal if it was marked 'Legal'.
2586     return info->action == LegalizationAction::Legal;
2587   };
2588   if (!isOpLegal())
2589     return llvm::None;
2590 
2591   // This operation is legal, compute any additional legality information.
2592   LegalOpDetails legalityDetails;
2593   if (info->isRecursivelyLegal) {
2594     auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
2595     if (legalityFnIt != opRecursiveLegalityFns.end())
2596       legalityDetails.isRecursivelyLegal = legalityFnIt->second(op);
2597     else
2598       legalityDetails.isRecursivelyLegal = true;
2599   }
2600   return legalityDetails;
2601 }
2602 
2603 /// Set the dynamic legality callback for the given operation.
2604 void ConversionTarget::setLegalityCallback(
2605     OperationName name, const DynamicLegalityCallbackFn &callback) {
2606   assert(callback && "expected valid legality callback");
2607   auto infoIt = legalOperations.find(name);
2608   assert(infoIt != legalOperations.end() &&
2609          infoIt->second.action == LegalizationAction::Dynamic &&
2610          "expected operation to already be marked as dynamically legal");
2611   infoIt->second.legalityFn = callback;
2612 }
2613 
2614 /// Set the recursive legality callback for the given operation and mark the
2615 /// operation as recursively legal.
2616 void ConversionTarget::markOpRecursivelyLegal(
2617     OperationName name, const DynamicLegalityCallbackFn &callback) {
2618   auto infoIt = legalOperations.find(name);
2619   assert(infoIt != legalOperations.end() &&
2620          infoIt->second.action != LegalizationAction::Illegal &&
2621          "expected operation to already be marked as legal");
2622   infoIt->second.isRecursivelyLegal = true;
2623   if (callback)
2624     opRecursiveLegalityFns[name] = callback;
2625   else
2626     opRecursiveLegalityFns.erase(name);
2627 }
2628 
2629 /// Set the dynamic legality callback for the given dialects.
2630 void ConversionTarget::setLegalityCallback(
2631     ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
2632   assert(callback && "expected valid legality callback");
2633   for (StringRef dialect : dialects)
2634     dialectLegalityFns[dialect] = callback;
2635 }
2636 
2637 /// Get the legalization information for the given operation.
2638 auto ConversionTarget::getOpInfo(OperationName op) const
2639     -> Optional<LegalizationInfo> {
2640   // Check for info for this specific operation.
2641   auto it = legalOperations.find(op);
2642   if (it != legalOperations.end())
2643     return it->second;
2644   // Check for info for the parent dialect.
2645   auto dialectIt = legalDialects.find(op.getDialect());
2646   if (dialectIt != legalDialects.end()) {
2647     Optional<DynamicLegalityCallbackFn> callback;
2648     auto dialectFn = dialectLegalityFns.find(op.getDialect());
2649     if (dialectFn != dialectLegalityFns.end())
2650       callback = dialectFn->second;
2651     return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
2652                             callback};
2653   }
2654   // Otherwise, check if we mark unknown operations as dynamic.
2655   if (unknownOpsDynamicallyLegal)
2656     return LegalizationInfo{LegalizationAction::Dynamic,
2657                             /*isRecursivelyLegal=*/false, unknownLegalityFn};
2658   return llvm::None;
2659 }
2660 
2661 //===----------------------------------------------------------------------===//
2662 // Op Conversion Entry Points
2663 //===----------------------------------------------------------------------===//
2664 
/// Apply a partial conversion on the given operations and all nested
/// operations. This method converts as many operations to the target as
/// possible, ignoring operations that failed to legalize. This method only
/// returns failure if there are ops explicitly marked as illegal that fail to
/// legalize, or if the conversion cannot be finalized.
/// If an `unconvertedOps` set is provided, all operations that are found not
/// to be legalizable to the given `target` are placed within that set. (Note
/// that if there is an op explicitly marked as illegal, the conversion
/// terminates and the `unconvertedOps` set will not necessarily be complete.)
2673 LogicalResult
2674 mlir::applyPartialConversion(ArrayRef<Operation *> ops,
2675                              ConversionTarget &target,
2676                              const FrozenRewritePatternList &patterns,
2677                              DenseSet<Operation *> *unconvertedOps) {
2678   OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
2679                                  unconvertedOps);
2680   return opConverter.convertOperations(ops);
2681 }
2682 LogicalResult
2683 mlir::applyPartialConversion(Operation *op, ConversionTarget &target,
2684                              const FrozenRewritePatternList &patterns,
2685                              DenseSet<Operation *> *unconvertedOps) {
2686   return applyPartialConversion(llvm::makeArrayRef(op), target, patterns,
2687                                 unconvertedOps);
2688 }
2689 
2690 /// Apply a complete conversion on the given operations, and all nested
2691 /// operations. This method will return failure if the conversion of any
2692 /// operation fails.
2693 LogicalResult
2694 mlir::applyFullConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
2695                           const FrozenRewritePatternList &patterns) {
2696   OperationConverter opConverter(target, patterns, OpConversionMode::Full);
2697   return opConverter.convertOperations(ops);
2698 }
2699 LogicalResult
2700 mlir::applyFullConversion(Operation *op, ConversionTarget &target,
2701                           const FrozenRewritePatternList &patterns) {
2702   return applyFullConversion(llvm::makeArrayRef(op), target, patterns);
2703 }
2704 
2705 /// Apply an analysis conversion on the given operations, and all nested
2706 /// operations. This method analyzes which operations would be successfully
2707 /// converted to the target if a conversion was applied. All operations that
2708 /// were found to be legalizable to the given 'target' are placed within the
2709 /// provided 'convertedOps' set; note that no actual rewrites are applied to the
2710 /// operations on success and only pre-existing operations are added to the set.
2711 LogicalResult
2712 mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
2713                               ConversionTarget &target,
2714                               const FrozenRewritePatternList &patterns,
2715                               DenseSet<Operation *> &convertedOps) {
2716   OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
2717                                  &convertedOps);
2718   return opConverter.convertOperations(ops);
2719 }
2720 LogicalResult
2721 mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
2722                               const FrozenRewritePatternList &patterns,
2723                               DenseSet<Operation *> &convertedOps) {
2724   return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns,
2725                                  convertedOps);
2726 }
2727