1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/DialectConversion.h"
10 #include "mlir/IR/Block.h"
11 #include "mlir/IR/BlockAndValueMapping.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/Rewrite/PatternApplicator.h"
15 #include "mlir/Transforms/Utils.h"
16 #include "llvm/ADT/SetVector.h"
17 #include "llvm/ADT/SmallPtrSet.h"
18 #include "llvm/Support/Debug.h"
19 #include "llvm/Support/FormatVariadic.h"
20 #include "llvm/Support/SaveAndRestore.h"
21 #include "llvm/Support/ScopedPrinter.h"
22 
23 using namespace mlir;
24 using namespace mlir::detail;
25 
26 #define DEBUG_TYPE "dialect-conversion"
27 
28 /// Recursively collect all of the operations to convert from within 'region'.
29 /// If 'target' is nonnull, operations that are recursively legal have their
30 /// regions pre-filtered to avoid considering them for legalization.
31 static LogicalResult
32 computeConversionSet(iterator_range<Region::iterator> region,
33                      Location regionLoc, std::vector<Operation *> &toConvert,
34                      ConversionTarget *target = nullptr) {
35   if (llvm::empty(region))
36     return success();
37 
38   // Traverse starting from the entry block.
39   SmallVector<Block *, 16> worklist(1, &*region.begin());
40   DenseSet<Block *> visitedBlocks;
41   visitedBlocks.insert(worklist.front());
42   while (!worklist.empty()) {
43     Block *block = worklist.pop_back_val();
44 
45     // Compute the conversion set of each of the nested operations.
46     for (Operation &op : *block) {
47       toConvert.emplace_back(&op);
48 
49       // Don't check this operation's children for conversion if the operation
50       // is recursively legal.
51       auto legalityInfo = target ? target->isLegal(&op)
52                                  : Optional<ConversionTarget::LegalOpDetails>();
53       if (legalityInfo && legalityInfo->isRecursivelyLegal)
54         continue;
55       for (auto &region : op.getRegions()) {
56         if (failed(computeConversionSet(region.getBlocks(), region.getLoc(),
57                                         toConvert, target)))
58           return failure();
59       }
60     }
61 
62     // Recurse to children that haven't been visited.
63     for (Block *succ : block->getSuccessors())
64       if (visitedBlocks.insert(succ).second)
65         worklist.push_back(succ);
66   }
67 
68   // Check that all blocks in the region were visited.
69   if (llvm::any_of(llvm::drop_begin(region, 1),
70                    [&](Block &block) { return !visitedBlocks.count(&block); }))
71     return emitError(regionLoc, "unreachable blocks were not converted");
72   return success();
73 }
74 
75 /// A utility function to log a successful result for the given reason.
76 template <typename... Args>
77 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
78   LLVM_DEBUG({
79     os.unindent();
80     os.startLine() << "} -> SUCCESS";
81     if (!fmt.empty())
82       os.getOStream() << " : "
83                       << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
84     os.getOStream() << "\n";
85   });
86 }
87 
88 /// A utility function to log a failure result for the given reason.
89 template <typename... Args>
90 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
91   LLVM_DEBUG({
92     os.unindent();
93     os.startLine() << "} -> FAILURE : "
94                    << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
95                    << "\n";
96   });
97 }
98 
99 //===----------------------------------------------------------------------===//
100 // ConversionValueMapping
101 //===----------------------------------------------------------------------===//
102 
103 namespace {
104 /// This class wraps a BlockAndValueMapping to provide recursive lookup
105 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
106 struct ConversionValueMapping {
107   /// Lookup a mapped value within the map. If a mapping for the provided value
108   /// does not exist then return the provided value.
109   Value lookupOrDefault(Value from) const;
110 
111   /// Lookup the latest legal value within the map. If a mapping for the
112   /// provided value does not exist then return the provided value. If
113   /// `converter` is non-null, returns the most recently mapped value with the
114   /// legal type. If an operand of that type does not exist, defaults to normal
115   /// behavior.
116   Value lookupLatestLegal(Value from, TypeConverter *converter) const;
117 
118   /// Lookup a mapped value within the map, or return null if a mapping does not
119   /// exist. If a mapping exists, this follows the same behavior of
120   /// `lookupOrDefault`.
121   Value lookupOrNull(Value from) const;
122 
123   /// Map a value to the one provided.
124   void map(Value oldVal, Value newVal) { mapping.map(oldVal, newVal); }
125 
126   /// Drop the last mapping for the given value.
127   void erase(Value value) { mapping.erase(value); }
128 
129 private:
130   /// Current value mappings.
131   BlockAndValueMapping mapping;
132 };
133 } // end anonymous namespace
134 
135 Value ConversionValueMapping::lookupOrDefault(Value from) const {
136   // If this value had a valid mapping, unmap that value as well in the case
137   // that it was also replaced.
138   while (auto mappedValue = mapping.lookupOrNull(from))
139     from = mappedValue;
140   return from;
141 }
142 
143 Value ConversionValueMapping::lookupLatestLegal(
144     Value from, TypeConverter *converter) const {
145   if (!converter)
146     return lookupOrDefault(from);
147 
148   // Otherwise, try to find the deepest value that has the legal type.
149   Value legalValue;
150   do {
151     if (converter->isLegal(from.getType()))
152       legalValue = from;
153 
154     Value mappedValue = mapping.lookupOrNull(from);
155     if (!mappedValue)
156       break;
157     from = mappedValue;
158   } while (true);
159 
160   // If the desired value was found use it, otherwise default to the leaf value.
161   return legalValue ? legalValue : from;
162 }
163 
164 Value ConversionValueMapping::lookupOrNull(Value from) const {
165   Value result = lookupOrDefault(from);
166   return result == from ? nullptr : result;
167 }
168 
169 //===----------------------------------------------------------------------===//
170 // ArgConverter
171 //===----------------------------------------------------------------------===//
172 namespace {
173 /// This class provides a simple interface for converting the types of block
174 /// arguments. This is done by creating a new block that contains the new legal
175 /// types and extracting the block that contains the old illegal types to allow
176 /// for undoing pending rewrites in the case of failure.
177 struct ArgConverter {
178   ArgConverter(PatternRewriter &rewriter) : rewriter(rewriter) {}
179 
180   /// This structure contains the information pertaining to an argument that has
181   /// been converted.
182   struct ConvertedArgInfo {
183     ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
184                      Value castValue = nullptr)
185         : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
186 
187     /// The start index of in the new argument list that contains arguments that
188     /// replace the original.
189     unsigned newArgIdx;
190 
191     /// The number of arguments that replaced the original argument.
192     unsigned newArgSize;
193 
194     /// The cast value that was created to cast from the new arguments to the
195     /// old. This only used if 'newArgSize' > 1.
196     Value castValue;
197   };
198 
199   /// This structure contains information pertaining to a block that has had its
200   /// signature converted.
201   struct ConvertedBlockInfo {
202     ConvertedBlockInfo(Block *origBlock, TypeConverter &converter)
203         : origBlock(origBlock), converter(&converter) {}
204 
205     /// The original block that was requested to have its signature converted.
206     Block *origBlock;
207 
208     /// The conversion information for each of the arguments. The information is
209     /// None if the argument was dropped during conversion.
210     SmallVector<Optional<ConvertedArgInfo>, 1> argInfo;
211 
212     /// The type converter used to convert the arguments.
213     TypeConverter *converter;
214   };
215 
216   /// Return if the signature of the given block has already been converted.
217   bool hasBeenConverted(Block *block) const {
218     return conversionInfo.count(block) || convertedBlocks.count(block);
219   }
220 
221   /// Set the type converter to use for the given region.
222   void setConverter(Region *region, TypeConverter *typeConverter) {
223     assert(typeConverter && "expected valid type converter");
224     regionToConverter[region] = typeConverter;
225   }
226 
227   /// Return the type converter to use for the given region, or null if there
228   /// isn't one.
229   TypeConverter *getConverter(Region *region) {
230     return regionToConverter.lookup(region);
231   }
232 
233   //===--------------------------------------------------------------------===//
234   // Rewrite Application
235   //===--------------------------------------------------------------------===//
236 
237   /// Erase any rewrites registered for the blocks within the given operation
238   /// which is about to be removed. This merely drops the rewrites without
239   /// undoing them.
240   void notifyOpRemoved(Operation *op);
241 
242   /// Cleanup and undo any generated conversions for the arguments of block.
243   /// This method replaces the new block with the original, reverting the IR to
244   /// its original state.
245   void discardRewrites(Block *block);
246 
247   /// Fully replace uses of the old arguments with the new.
248   void applyRewrites(ConversionValueMapping &mapping);
249 
250   /// Materialize any necessary conversions for converted arguments that have
251   /// live users, using the provided `findLiveUser` to search for a user that
252   /// survives the conversion process.
253   LogicalResult
254   materializeLiveConversions(ConversionValueMapping &mapping,
255                              OpBuilder &builder,
256                              function_ref<Operation *(Value)> findLiveUser);
257 
258   //===--------------------------------------------------------------------===//
259   // Conversion
260   //===--------------------------------------------------------------------===//
261 
262   /// Attempt to convert the signature of the given block, if successful a new
263   /// block is returned containing the new arguments. Returns `block` if it did
264   /// not require conversion.
265   FailureOr<Block *> convertSignature(Block *block, TypeConverter &converter,
266                                       ConversionValueMapping &mapping);
267 
268   /// Apply the given signature conversion on the given block. The new block
269   /// containing the updated signature is returned. If no conversions were
270   /// necessary, e.g. if the block has no arguments, `block` is returned.
271   /// `converter` is used to generate any necessary cast operations that
272   /// translate between the origin argument types and those specified in the
273   /// signature conversion.
274   Block *applySignatureConversion(
275       Block *block, TypeConverter &converter,
276       TypeConverter::SignatureConversion &signatureConversion,
277       ConversionValueMapping &mapping);
278 
279   /// Insert a new conversion into the cache.
280   void insertConversion(Block *newBlock, ConvertedBlockInfo &&info);
281 
282   /// A collection of blocks that have had their arguments converted. This is a
283   /// map from the new replacement block, back to the original block.
284   llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;
285 
286   /// The set of original blocks that were converted.
287   DenseSet<Block *> convertedBlocks;
288 
289   /// A mapping from valid regions, to those containing the original blocks of a
290   /// conversion.
291   DenseMap<Region *, std::unique_ptr<Region>> regionMapping;
292 
293   /// A mapping of regions to type converters that should be used when
294   /// converting the arguments of blocks within that region.
295   DenseMap<Region *, TypeConverter *> regionToConverter;
296 
297   /// The pattern rewriter to use when materializing conversions.
298   PatternRewriter &rewriter;
299 };
300 } // end anonymous namespace
301 
302 //===----------------------------------------------------------------------===//
303 // Rewrite Application
304 
305 void ArgConverter::notifyOpRemoved(Operation *op) {
306   if (conversionInfo.empty())
307     return;
308 
309   for (Region &region : op->getRegions()) {
310     for (Block &block : region) {
311       // Drop any rewrites from within.
312       for (Operation &nestedOp : block)
313         if (nestedOp.getNumRegions())
314           notifyOpRemoved(&nestedOp);
315 
316       // Check if this block was converted.
317       auto it = conversionInfo.find(&block);
318       if (it == conversionInfo.end())
319         continue;
320 
321       // Drop all uses of the original arguments and delete the original block.
322       Block *origBlock = it->second.origBlock;
323       for (BlockArgument arg : origBlock->getArguments())
324         arg.dropAllUses();
325       conversionInfo.erase(it);
326     }
327   }
328 }
329 
330 void ArgConverter::discardRewrites(Block *block) {
331   auto it = conversionInfo.find(block);
332   if (it == conversionInfo.end())
333     return;
334   Block *origBlock = it->second.origBlock;
335 
336   // Drop all uses of the new block arguments and replace uses of the new block.
337   for (int i = block->getNumArguments() - 1; i >= 0; --i)
338     block->getArgument(i).dropAllUses();
339   block->replaceAllUsesWith(origBlock);
340 
341   // Move the operations back the original block and the delete the new block.
342   origBlock->getOperations().splice(origBlock->end(), block->getOperations());
343   origBlock->moveBefore(block);
344   block->erase();
345 
346   convertedBlocks.erase(origBlock);
347   conversionInfo.erase(it);
348 }
349 
350 void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
351   for (auto &info : conversionInfo) {
352     ConvertedBlockInfo &blockInfo = info.second;
353     Block *origBlock = blockInfo.origBlock;
354 
355     // Process the remapping for each of the original arguments.
356     for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
357       Optional<ConvertedArgInfo> &argInfo = blockInfo.argInfo[i];
358       BlockArgument origArg = origBlock->getArgument(i);
359 
360       // Handle the case of a 1->0 value mapping.
361       if (!argInfo) {
362         if (Value newArg = mapping.lookupOrNull(origArg))
363           origArg.replaceAllUsesWith(newArg);
364         continue;
365       }
366 
367       // Otherwise this is a 1->1+ value mapping.
368       Value castValue = argInfo->castValue;
369       assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
370 
371       // If the argument is still used, replace it with the generated cast.
372       if (!origArg.use_empty())
373         origArg.replaceAllUsesWith(mapping.lookupOrDefault(castValue));
374     }
375   }
376 }
377 
378 LogicalResult ArgConverter::materializeLiveConversions(
379     ConversionValueMapping &mapping, OpBuilder &builder,
380     function_ref<Operation *(Value)> findLiveUser) {
381   for (auto &info : conversionInfo) {
382     Block *newBlock = info.first;
383     ConvertedBlockInfo &blockInfo = info.second;
384     Block *origBlock = blockInfo.origBlock;
385 
386     // Process the remapping for each of the original arguments.
387     for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
388       // FIXME: We should run the below checks even if the type conversion was
389       // 1->N, but a lot of existing lowering rely on the block argument being
390       // blindly replaced. Those usages should be updated, and this if should be
391       // removed.
392       if (blockInfo.argInfo[i])
393         continue;
394 
395       // If the type of this argument changed and the argument is still live, we
396       // need to materialize a conversion.
397       BlockArgument origArg = origBlock->getArgument(i);
398       auto argReplacementValue = mapping.lookupOrDefault(origArg);
399       bool isDroppedArg = argReplacementValue == origArg;
400       if (argReplacementValue.getType() == origArg.getType() && !isDroppedArg)
401         continue;
402       Operation *liveUser = findLiveUser(origArg);
403       if (!liveUser)
404         continue;
405 
406       if (OpResult result = argReplacementValue.dyn_cast<OpResult>())
407         rewriter.setInsertionPointAfter(result.getOwner());
408       else
409         rewriter.setInsertionPointToStart(newBlock);
410       Value newArg = blockInfo.converter->materializeSourceConversion(
411           rewriter, origArg.getLoc(), origArg.getType(),
412           isDroppedArg ? ValueRange() : ValueRange(argReplacementValue));
413       if (!newArg) {
414         InFlightDiagnostic diag =
415             emitError(origArg.getLoc())
416             << "failed to materialize conversion for block argument #" << i
417             << " that remained live after conversion, type was "
418             << origArg.getType();
419         if (!isDroppedArg)
420           diag << ", with target type " << argReplacementValue.getType();
421         diag.attachNote(liveUser->getLoc())
422             << "see existing live user here: " << *liveUser;
423         return failure();
424       }
425       mapping.map(origArg, newArg);
426     }
427   }
428   return success();
429 }
430 
431 //===----------------------------------------------------------------------===//
432 // Conversion
433 
434 FailureOr<Block *>
435 ArgConverter::convertSignature(Block *block, TypeConverter &converter,
436                                ConversionValueMapping &mapping) {
437   // Check if the block was already converted. If the block is detached,
438   // conservatively assume it is going to be deleted.
439   if (hasBeenConverted(block) || !block->getParent())
440     return block;
441 
442   // Try to convert the signature for the block with the provided converter.
443   if (auto conversion = converter.convertBlockSignature(block))
444     return applySignatureConversion(block, converter, *conversion, mapping);
445   return failure();
446 }
447 
448 Block *ArgConverter::applySignatureConversion(
449     Block *block, TypeConverter &converter,
450     TypeConverter::SignatureConversion &signatureConversion,
451     ConversionValueMapping &mapping) {
452   // If no arguments are being changed or added, there is nothing to do.
453   unsigned origArgCount = block->getNumArguments();
454   auto convertedTypes = signatureConversion.getConvertedTypes();
455   if (origArgCount == 0 && convertedTypes.empty())
456     return block;
457 
458   // Split the block at the beginning to get a new block to use for the updated
459   // signature.
460   Block *newBlock = block->splitBlock(block->begin());
461   block->replaceAllUsesWith(newBlock);
462 
463   SmallVector<Value, 4> newArgRange(newBlock->addArguments(convertedTypes));
464   ArrayRef<Value> newArgs(newArgRange);
465 
466   // Remap each of the original arguments as determined by the signature
467   // conversion.
468   ConvertedBlockInfo info(block, converter);
469   info.argInfo.resize(origArgCount);
470 
471   OpBuilder::InsertionGuard guard(rewriter);
472   rewriter.setInsertionPointToStart(newBlock);
473   for (unsigned i = 0; i != origArgCount; ++i) {
474     auto inputMap = signatureConversion.getInputMapping(i);
475     if (!inputMap)
476       continue;
477     BlockArgument origArg = block->getArgument(i);
478 
479     // If inputMap->replacementValue is not nullptr, then the argument is
480     // dropped and a replacement value is provided to be the remappedValue.
481     if (inputMap->replacementValue) {
482       assert(inputMap->size == 0 &&
483              "invalid to provide a replacement value when the argument isn't "
484              "dropped");
485       mapping.map(origArg, inputMap->replacementValue);
486       continue;
487     }
488 
489     // Otherwise, this is a 1->1+ mapping. Call into the provided type converter
490     // to pack the new values. For 1->1 mappings, if there is no materialization
491     // provided, use the argument directly instead.
492     auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
493     Value newArg = converter.materializeArgumentConversion(
494         rewriter, origArg.getLoc(), origArg.getType(), replArgs);
495     if (!newArg) {
496       assert(replArgs.size() == 1 &&
497              "couldn't materialize the result of 1->N conversion");
498       newArg = replArgs.front();
499     }
500     mapping.map(origArg, newArg);
501     info.argInfo[i] =
502         ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
503   }
504 
505   // Remove the original block from the region and return the new one.
506   insertConversion(newBlock, std::move(info));
507   return newBlock;
508 }
509 
510 void ArgConverter::insertConversion(Block *newBlock,
511                                     ConvertedBlockInfo &&info) {
512   // Get a region to insert the old block.
513   Region *region = newBlock->getParent();
514   std::unique_ptr<Region> &mappedRegion = regionMapping[region];
515   if (!mappedRegion)
516     mappedRegion = std::make_unique<Region>(region->getParentOp());
517 
518   // Move the original block to the mapped region and emplace the conversion.
519   mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(),
520                                    info.origBlock->getIterator());
521   convertedBlocks.insert(info.origBlock);
522   conversionInfo.insert({newBlock, std::move(info)});
523 }
524 
525 //===----------------------------------------------------------------------===//
526 // Rewriter and Translation State
527 //===----------------------------------------------------------------------===//
528 namespace {
/// This class contains a snapshot of the current conversion rewriter state.
/// This is useful when saving and undoing a set of rewrites. The snapshot is a
/// plain set of counters that the constructor stores as given.
/// NOTE(review): each counter is presumably the size of the matching list in
/// the rewriter implementation at the time of the snapshot — confirm against
/// the code that creates and restores snapshots.
struct RewriterState {
  /// Store the given counters without modification.
  RewriterState(unsigned numCreatedOps, unsigned numReplacements,
                unsigned numArgReplacements, unsigned numBlockActions,
                unsigned numIgnoredOperations, unsigned numRootUpdates)
      : numCreatedOps(numCreatedOps), numReplacements(numReplacements),
        numArgReplacements(numArgReplacements),
        numBlockActions(numBlockActions),
        numIgnoredOperations(numIgnoredOperations),
        numRootUpdates(numRootUpdates) {}

  /// The current number of created operations.
  unsigned numCreatedOps;

  /// The current number of replacements queued.
  unsigned numReplacements;

  /// The current number of argument replacements queued.
  unsigned numArgReplacements;

  /// The current number of block actions performed.
  unsigned numBlockActions;

  /// The current number of ignored operations.
  unsigned numIgnoredOperations;

  /// The current number of operations that were updated in place.
  unsigned numRootUpdates;
};
559 
560 /// The state of an operation that was updated by a pattern in-place. This
561 /// contains all of the necessary information to reconstruct an operation that
562 /// was updated in place.
563 class OperationTransactionState {
564 public:
565   OperationTransactionState() = default;
566   OperationTransactionState(Operation *op)
567       : op(op), loc(op->getLoc()), attrs(op->getMutableAttrDict()),
568         operands(op->operand_begin(), op->operand_end()),
569         successors(op->successor_begin(), op->successor_end()) {}
570 
571   /// Discard the transaction state and reset the state of the original
572   /// operation.
573   void resetOperation() const {
574     op->setLoc(loc);
575     op->setAttrs(attrs);
576     op->setOperands(operands);
577     for (auto it : llvm::enumerate(successors))
578       op->setSuccessor(it.value(), it.index());
579   }
580 
581   /// Return the original operation of this state.
582   Operation *getOperation() const { return op; }
583 
584 private:
585   Operation *op;
586   LocationAttr loc;
587   MutableDictionaryAttr attrs;
588   SmallVector<Value, 8> operands;
589   SmallVector<Block *, 2> successors;
590 };
591 
/// This class represents one requested operation replacement via 'replaceOp' or
/// 'eraseOp'. The only state it carries is an optional type converter.
struct OpReplacement {
  /// Construct a replacement without a type converter.
  OpReplacement() = default;
  /// Construct a replacement that carries `converter`, which may be null.
  OpReplacement(TypeConverter *converter) : converter(converter) {}

  /// An optional type converter that can be used to materialize conversions
  /// between the new and old values if necessary. Null when none was given.
  TypeConverter *converter = nullptr;
};
602 
/// The kind of the block action performed during the rewrite.  Actions can be
/// undone if the conversion fails. Each kind has a matching factory function
/// on BlockAction (getCreate, getErase, getMerge, getMove, getSplit,
/// getTypeConversion).
enum class BlockActionKind {
  Create,
  Erase,
  Merge,
  Move,
  Split,
  TypeConversion
};
613 
/// Original position of the given block in its parent region. During undo
/// actions, the block needs to be placed after `insertAfterBlock`.
struct BlockPosition {
  // The region that originally contained the block.
  Region *region;
  // The block after which the block has to be re-inserted when undoing.
  // NOTE(review): presumably null when the block was first in its region —
  // confirm against the code that creates and consumes positions.
  Block *insertAfterBlock;
};
620 
/// Information needed to undo the merge actions.
/// - the source block, and
/// - the Operation that was the last operation in the dest block before the
///   merge (could be null if the dest block was empty).
struct MergeInfo {
  // The block that was merged into the destination block.
  Block *sourceBlock;
  // The last operation of the destination block before the merge, or null if
  // the destination block was empty at that time.
  Operation *destBlockLastInst;
};
629 
630 /// The storage class for an undoable block action (one of BlockActionKind),
631 /// contains the information necessary to undo this action.
632 struct BlockAction {
633   static BlockAction getCreate(Block *block) {
634     return {BlockActionKind::Create, block, {}};
635   }
636   static BlockAction getErase(Block *block, BlockPosition originalPosition) {
637     return {BlockActionKind::Erase, block, {originalPosition}};
638   }
639   static BlockAction getMerge(Block *block, Block *sourceBlock) {
640     BlockAction action{BlockActionKind::Merge, block, {}};
641     action.mergeInfo = {sourceBlock, block->empty() ? nullptr : &block->back()};
642     return action;
643   }
644   static BlockAction getMove(Block *block, BlockPosition originalPosition) {
645     return {BlockActionKind::Move, block, {originalPosition}};
646   }
647   static BlockAction getSplit(Block *block, Block *originalBlock) {
648     BlockAction action{BlockActionKind::Split, block, {}};
649     action.originalBlock = originalBlock;
650     return action;
651   }
652   static BlockAction getTypeConversion(Block *block) {
653     return BlockAction{BlockActionKind::TypeConversion, block, {}};
654   }
655 
656   // The action kind.
657   BlockActionKind kind;
658 
659   // A pointer to the block that was created by the action.
660   Block *block;
661 
662   union {
663     // In use if kind == BlockActionKind::Move or BlockActionKind::Erase, and
664     // contains a pointer to the region that originally contained the block as
665     // well as the position of the block in that region.
666     BlockPosition originalPosition;
667     // In use if kind == BlockActionKind::Split and contains a pointer to the
668     // block that was split into two parts.
669     Block *originalBlock;
670     // In use if kind == BlockActionKind::Merge, and contains the information
671     // needed to undo the merge.
672     MergeInfo mergeInfo;
673   };
674 };
675 } // end anonymous namespace
676 
677 //===----------------------------------------------------------------------===//
678 // ConversionPatternRewriterImpl
679 //===----------------------------------------------------------------------===//
680 namespace mlir {
681 namespace detail {
682 struct ConversionPatternRewriterImpl {
683   ConversionPatternRewriterImpl(PatternRewriter &rewriter)
684       : argConverter(rewriter) {}
685 
686   /// Cleanup and destroy any generated rewrite operations. This method is
687   /// invoked when the conversion process fails.
688   void discardRewrites();
689 
690   /// Apply all requested operation rewrites. This method is invoked when the
691   /// conversion process succeeds.
692   void applyRewrites();
693 
694   //===--------------------------------------------------------------------===//
695   // State Management
696   //===--------------------------------------------------------------------===//
697 
698   /// Return the current state of the rewriter.
699   RewriterState getCurrentState();
700 
701   /// Reset the state of the rewriter to a previously saved point.
702   void resetState(RewriterState state);
703 
704   /// Erase any blocks that were unlinked from their regions and stored in block
705   /// actions.
706   void eraseDanglingBlocks();
707 
  /// Undo the block actions (motions, splits) one by one in reverse order until
  /// "numActionsToKeep" actions remain.
710   void undoBlockActions(unsigned numActionsToKeep = 0);
711 
712   /// Remap the given operands to those with potentially different types. The
713   /// provided type converter is used to ensure that the remapped types are
714   /// legal. Returns success if the operands could be remapped, failure
715   /// otherwise.
716   LogicalResult remapValues(Location loc, PatternRewriter &rewriter,
717                             TypeConverter *converter,
718                             Operation::operand_range operands,
719                             SmallVectorImpl<Value> &remapped);
720 
721   /// Returns true if the given operation is ignored, and does not need to be
722   /// converted.
723   bool isOpIgnored(Operation *op) const;
724 
725   /// Recursively marks the nested operations under 'op' as ignored. This
726   /// removes them from being considered for legalization.
727   void markNestedOpsIgnored(Operation *op);
728 
729   //===--------------------------------------------------------------------===//
730   // Type Conversion
731   //===--------------------------------------------------------------------===//
732 
733   /// Convert the signature of the given block.
734   FailureOr<Block *> convertBlockSignature(
735       Block *block, TypeConverter &converter,
736       TypeConverter::SignatureConversion *conversion = nullptr);
737 
738   /// Apply a signature conversion on the given region.
739   Block *
740   applySignatureConversion(Region *region,
741                            TypeConverter::SignatureConversion &conversion);
742 
743   /// Convert the types of block arguments within the given region.
744   FailureOr<Block *>
745   convertRegionTypes(Region *region, TypeConverter &converter,
746                      TypeConverter::SignatureConversion *entryConversion);
747 
748   //===--------------------------------------------------------------------===//
749   // Rewriter Notification Hooks
750   //===--------------------------------------------------------------------===//
751 
752   /// PatternRewriter hook for replacing the results of an operation.
753   void notifyOpReplaced(Operation *op, ValueRange newValues);
754 
755   /// Notifies that a block is about to be erased.
756   void notifyBlockIsBeingErased(Block *block);
757 
758   /// Notifies that a block was created.
759   void notifyCreatedBlock(Block *block);
760 
761   /// Notifies that a block was split.
762   void notifySplitBlock(Block *block, Block *continuation);
763 
764   /// Notifies that `block` is being merged with `srcBlock`.
765   void notifyBlocksBeingMerged(Block *block, Block *srcBlock);
766 
767   /// Notifies that the blocks of a region are about to be moved.
768   void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent,
769                                         Region::iterator before);
770 
771   /// Notifies that the blocks of a region were cloned into another.
772   void notifyRegionWasClonedBefore(iterator_range<Region::iterator> &blocks,
773                                    Location origRegionLoc);
774 
775   /// Notifies that a pattern match failed for the given reason.
776   LogicalResult
777   notifyMatchFailure(Location loc,
778                      function_ref<void(Diagnostic &)> reasonCallback);
779 
780   //===--------------------------------------------------------------------===//
781   // State
782   //===--------------------------------------------------------------------===//
783 
784   // Mapping between replaced values that differ in type. This happens when
785   // replacing a value with one of a different type.
786   ConversionValueMapping mapping;
787 
788   /// Utility used to convert block arguments.
789   ArgConverter argConverter;
790 
791   /// Ordered vector of all of the newly created operations during conversion.
792   std::vector<Operation *> createdOps;
793 
794   /// Ordered map of requested operation replacements.
795   llvm::MapVector<Operation *, OpReplacement> replacements;
796 
797   /// Ordered vector of any requested block argument replacements.
798   SmallVector<BlockArgument, 4> argReplacements;
799 
800   /// Ordered list of block operations (creations, splits, motions).
801   SmallVector<BlockAction, 4> blockActions;
802 
803   /// A set of operations that should no longer be considered for legalization,
  /// but were not directly replaced/erased/etc. by a pattern. These are
  /// generally child operations of other operations that were
806   /// replaced/erased/etc. This is not meant to be an exhaustive list of all
807   /// operations, but the minimal set that can be used to detect if a given
808   /// operation should be `ignored`. For example, we may add the operations that
809   /// define non-empty regions to the set, but not any of the others. This
810   /// simplifies the amount of memory needed as we can query if the parent
811   /// operation was ignored.
812   llvm::SetVector<Operation *> ignoredOps;
813 
814   /// A transaction state for each of operations that were updated in-place.
815   SmallVector<OperationTransactionState, 4> rootUpdates;
816 
817   /// A vector of indices into `replacements` of operations that were replaced
818   /// with values with different result types than the original operation, e.g.
819   /// 1->N conversion of some kind.
820   SmallVector<unsigned, 4> operationsWithChangedResults;
821 
822   /// A default type converter, used when block conversions do not have one
823   /// explicitly provided.
824   TypeConverter defaultTypeConverter;
825 
826   /// The current conversion pattern that is being rewritten, or nullptr if
827   /// called from outside of a conversion pattern rewrite.
828   const ConversionPattern *currentConversionPattern = nullptr;
829 
830 #ifndef NDEBUG
831   /// A set of operations that have pending updates. This tracking isn't
832   /// strictly necessary, and is thus only active during debug builds for extra
833   /// verification.
834   SmallPtrSet<Operation *, 1> pendingRootUpdates;
835 
836   /// A logger used to emit diagnostics during the conversion process.
837   llvm::ScopedPrinter logger{llvm::dbgs()};
838 #endif
839 };
840 } // end namespace detail
841 } // end namespace mlir
842 
843 /// Detach any operations nested in the given operation from their parent
844 /// blocks, and erase the given operation. This can be used when the nested
845 /// operations are scheduled for erasure themselves, so deleting the regions of
846 /// the given operation together with their content would result in double-free.
847 /// This happens, for example, when rolling back op creation in the reverse
848 /// order and if the nested ops were created before the parent op. This function
849 /// does not need to collect nested ops recursively because it is expected to
850 /// also be called for each nested op when it is about to be deleted.
851 static void detachNestedAndErase(Operation *op) {
852   for (Region &region : op->getRegions()) {
853     for (Block &block : region.getBlocks()) {
854       while (!block.getOperations().empty())
855         block.getOperations().remove(block.getOperations().begin());
856       block.dropAllDefinedValueUses();
857     }
858   }
859   op->erase();
860 }
861 
862 void ConversionPatternRewriterImpl::discardRewrites() {
863   // Reset any operations that were updated in place.
864   for (auto &state : rootUpdates)
865     state.resetOperation();
866 
867   undoBlockActions();
868 
869   // Remove any newly created ops.
870   for (auto *op : llvm::reverse(createdOps))
871     detachNestedAndErase(op);
872 }
873 
874 void ConversionPatternRewriterImpl::applyRewrites() {
875   // Apply all of the rewrites replacements requested during conversion.
876   for (auto &repl : replacements) {
877     for (OpResult result : repl.first->getResults())
878       if (Value newValue = mapping.lookupOrNull(result))
879         result.replaceAllUsesWith(newValue);
880 
881     // If this operation defines any regions, drop any pending argument
882     // rewrites.
883     if (repl.first->getNumRegions())
884       argConverter.notifyOpRemoved(repl.first);
885   }
886 
887   // Apply all of the requested argument replacements.
888   for (BlockArgument arg : argReplacements) {
889     Value repl = mapping.lookupOrDefault(arg);
890     if (repl.isa<BlockArgument>()) {
891       arg.replaceAllUsesWith(repl);
892       continue;
893     }
894 
895     // If the replacement value is an operation, we check to make sure that we
896     // don't replace uses that are within the parent operation of the
897     // replacement value.
898     Operation *replOp = repl.cast<OpResult>().getOwner();
899     Block *replBlock = replOp->getBlock();
900     arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
901       Operation *user = operand.getOwner();
902       return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
903     });
904   }
905 
906   // In a second pass, erase all of the replaced operations in reverse. This
907   // allows processing nested operations before their parent region is
908   // destroyed.
909   for (auto &repl : llvm::reverse(replacements))
910     repl.first->erase();
911 
912   argConverter.applyRewrites(mapping);
913 
914   // Now that the ops have been erased, also erase dangling blocks.
915   eraseDanglingBlocks();
916 }
917 
918 //===----------------------------------------------------------------------===//
919 // State Management
920 
921 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
922   return RewriterState(createdOps.size(), replacements.size(),
923                        argReplacements.size(), blockActions.size(),
924                        ignoredOps.size(), rootUpdates.size());
925 }
926 
927 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
928   // Reset any operations that were updated in place.
929   for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i)
930     rootUpdates[i].resetOperation();
931   rootUpdates.resize(state.numRootUpdates);
932 
933   // Reset any replaced arguments.
934   for (BlockArgument replacedArg :
935        llvm::drop_begin(argReplacements, state.numArgReplacements))
936     mapping.erase(replacedArg);
937   argReplacements.resize(state.numArgReplacements);
938 
939   // Undo any block actions.
940   undoBlockActions(state.numBlockActions);
941 
942   // Reset any replaced operations and undo any saved mappings.
943   for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
944     for (auto result : repl.first->getResults())
945       mapping.erase(result);
946   while (replacements.size() != state.numReplacements)
947     replacements.pop_back();
948 
949   // Pop all of the newly created operations.
950   while (createdOps.size() != state.numCreatedOps) {
951     detachNestedAndErase(createdOps.back());
952     createdOps.pop_back();
953   }
954 
955   // Pop all of the recorded ignored operations that are no longer valid.
956   while (ignoredOps.size() != state.numIgnoredOperations)
957     ignoredOps.pop_back();
958 
959   // Reset operations with changed results.
960   while (!operationsWithChangedResults.empty() &&
961          operationsWithChangedResults.back() >= state.numReplacements)
962     operationsWithChangedResults.pop_back();
963 }
964 
965 void ConversionPatternRewriterImpl::eraseDanglingBlocks() {
966   for (auto &action : blockActions)
967     if (action.kind == BlockActionKind::Erase)
968       delete action.block;
969 }
970 
971 void ConversionPatternRewriterImpl::undoBlockActions(
972     unsigned numActionsToKeep) {
973   for (auto &action :
974        llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) {
975     switch (action.kind) {
976     // Delete the created block.
977     case BlockActionKind::Create: {
978       // Unlink all of the operations within this block, they will be deleted
979       // separately.
980       auto &blockOps = action.block->getOperations();
981       while (!blockOps.empty())
982         blockOps.remove(blockOps.begin());
983       action.block->dropAllDefinedValueUses();
984       action.block->erase();
985       break;
986     }
987     // Put the block (owned by action) back into its original position.
988     case BlockActionKind::Erase: {
989       auto &blockList = action.originalPosition.region->getBlocks();
990       Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
991       blockList.insert((insertAfterBlock
992                             ? std::next(Region::iterator(insertAfterBlock))
993                             : blockList.begin()),
994                        action.block);
995       break;
996     }
997     // Split the block at the position which was originally the end of the
998     // destination block (owned by action), and put the instructions back into
999     // the block used before the merge.
1000     case BlockActionKind::Merge: {
1001       Block *sourceBlock = action.mergeInfo.sourceBlock;
1002       Block::iterator splitPoint =
1003           (action.mergeInfo.destBlockLastInst
1004                ? ++Block::iterator(action.mergeInfo.destBlockLastInst)
1005                : action.block->begin());
1006       sourceBlock->getOperations().splice(sourceBlock->begin(),
1007                                           action.block->getOperations(),
1008                                           splitPoint, action.block->end());
1009       break;
1010     }
1011     // Move the block back to its original position.
1012     case BlockActionKind::Move: {
1013       Region *originalRegion = action.originalPosition.region;
1014       Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
1015       originalRegion->getBlocks().splice(
1016           (insertAfterBlock ? std::next(Region::iterator(insertAfterBlock))
1017                             : originalRegion->end()),
1018           action.block->getParent()->getBlocks(), action.block);
1019       break;
1020     }
1021     // Merge back the block that was split out.
1022     case BlockActionKind::Split: {
1023       action.originalBlock->getOperations().splice(
1024           action.originalBlock->end(), action.block->getOperations());
1025       action.block->dropAllDefinedValueUses();
1026       action.block->erase();
1027       break;
1028     }
1029     // Undo the type conversion.
1030     case BlockActionKind::TypeConversion: {
1031       argConverter.discardRewrites(action.block);
1032       break;
1033     }
1034     }
1035   }
1036   blockActions.resize(numActionsToKeep);
1037 }
1038 
1039 LogicalResult ConversionPatternRewriterImpl::remapValues(
1040     Location loc, PatternRewriter &rewriter, TypeConverter *converter,
1041     Operation::operand_range operands, SmallVectorImpl<Value> &remapped) {
1042   remapped.reserve(llvm::size(operands));
1043 
1044   SmallVector<Type, 1> legalTypes;
1045   for (auto it : llvm::enumerate(operands)) {
1046     Value operand = it.value();
1047     Type origType = operand.getType();
1048 
1049     Value newOperand = mapping.lookupLatestLegal(operand, converter);
1050 
1051     // Handle the case where the conversion was 1->1 and the new operand type
1052     // isn't legal.
1053     Type newOperandType = newOperand.getType();
1054     if (converter) {
1055       if (!converter->isLegal(newOperandType)) {
1056         legalTypes.clear();
1057 
1058         // If there is no legal conversion, fail to match this pattern.
1059         if (failed(converter->convertType(origType, legalTypes))) {
1060           return notifyMatchFailure(loc, [=](Diagnostic &diag) {
1061             diag << "unable to convert type for operand #" << it.index()
1062                  << ", type was " << origType;
1063           });
1064         }
1065         // TODO: There currently isn't any mechanism to do 1->N type conversion
1066         // via the PatternRewriter replacement API, so for now we just ignore
1067         // it.
1068         if (legalTypes.size() != 1) {
1069           remapped.push_back(newOperand);
1070           continue;
1071         }
1072         Type desiredType = legalTypes.front();
1073         newOperand = converter->materializeTargetConversion(
1074             rewriter, loc, desiredType, newOperand);
1075         if (!newOperand) {
1076           return notifyMatchFailure(loc, [=](Diagnostic &diag) {
1077             diag << "unable to materialize a conversion for "
1078                     "operand #"
1079                  << it.index() << ", from " << newOperandType << " to "
1080                  << desiredType;
1081           });
1082         }
1083       }
1084     } else {
1085       // TODO: What we should do here is just set `desiredType` to `origType`
1086       // and then handle the necessary type conversions after the conversion
1087       // process has finished. Unfortunately a lot of patterns currently rely on
1088       // receiving the new operands even if the types change, so we keep the
1089       // original behavior here for now until all of the patterns relying on
1090       // this get updated.
1091       // Attempt to materialize a conversion for this new value.
1092     }
1093     remapped.push_back(newOperand);
1094   }
1095   return success();
1096 }
1097 
1098 bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
1099   // Check to see if this operation was replaced or its parent ignored.
1100   return replacements.count(op) || ignoredOps.count(op->getParentOp());
1101 }
1102 
1103 void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
1104   // Walk this operation and collect nested operations that define non-empty
1105   // regions. We mark such operations as 'ignored' so that we know we don't have
1106   // to convert them, or their nested ops.
1107   if (op->getNumRegions() == 0)
1108     return;
1109   op->walk([&](Operation *op) {
1110     if (llvm::any_of(op->getRegions(),
1111                      [](Region &region) { return !region.empty(); }))
1112       ignoredOps.insert(op);
1113   });
1114 }
1115 
1116 //===----------------------------------------------------------------------===//
1117 // Type Conversion
1118 
1119 FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
1120     Block *block, TypeConverter &converter,
1121     TypeConverter::SignatureConversion *conversion) {
1122   FailureOr<Block *> result =
1123       conversion ? argConverter.applySignatureConversion(block, converter,
1124                                                          *conversion, mapping)
1125                  : argConverter.convertSignature(block, converter, mapping);
1126   if (Block *newBlock = result.getValue()) {
1127     if (newBlock != block)
1128       blockActions.push_back(BlockAction::getTypeConversion(newBlock));
1129   }
1130   return result;
1131 }
1132 
1133 Block *ConversionPatternRewriterImpl::applySignatureConversion(
1134     Region *region, TypeConverter::SignatureConversion &conversion) {
1135   if (!region->empty()) {
1136     return *convertBlockSignature(&region->front(), defaultTypeConverter,
1137                                   &conversion);
1138   }
1139   return nullptr;
1140 }
1141 
1142 FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
1143     Region *region, TypeConverter &converter,
1144     TypeConverter::SignatureConversion *entryConversion) {
1145   argConverter.setConverter(region, &converter);
1146   if (region->empty())
1147     return nullptr;
1148 
1149   // Convert the arguments of each block within the region.
1150   FailureOr<Block *> newEntry =
1151       convertBlockSignature(&region->front(), converter, entryConversion);
1152   for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(*region, 1)))
1153     if (failed(convertBlockSignature(&block, converter)))
1154       return failure();
1155   return newEntry;
1156 }
1157 
1158 //===----------------------------------------------------------------------===//
1159 // Rewriter Notification Hooks
1160 
1161 void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
1162                                                      ValueRange newValues) {
1163   assert(newValues.size() == op->getNumResults());
1164   assert(!replacements.count(op) && "operation was already replaced");
1165 
1166   // Track if any of the results changed, e.g. erased and replaced with null.
1167   bool resultChanged = false;
1168 
1169   // Create mappings for each of the new result values.
1170   Value newValue, result;
1171   for (auto it : llvm::zip(newValues, op->getResults())) {
1172     std::tie(newValue, result) = it;
1173     if (!newValue) {
1174       resultChanged = true;
1175       continue;
1176     }
1177     // Remap, and check for any result type changes.
1178     mapping.map(result, newValue);
1179     resultChanged |= (newValue.getType() != result.getType());
1180   }
1181   if (resultChanged)
1182     operationsWithChangedResults.push_back(replacements.size());
1183 
1184   // Record the requested operation replacement.
1185   TypeConverter *converter = nullptr;
1186   if (currentConversionPattern)
1187     converter = currentConversionPattern->getTypeConverter();
1188   replacements.insert(std::make_pair(op, OpReplacement(converter)));
1189 
1190   // Mark this operation as recursively ignored so that we don't need to
1191   // convert any nested operations.
1192   markNestedOpsIgnored(op);
1193 }
1194 
1195 void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
1196   Region *region = block->getParent();
1197   Block *origPrevBlock = block->getPrevNode();
1198   blockActions.push_back(BlockAction::getErase(block, {region, origPrevBlock}));
1199 }
1200 
1201 void ConversionPatternRewriterImpl::notifyCreatedBlock(Block *block) {
1202   blockActions.push_back(BlockAction::getCreate(block));
1203 }
1204 
1205 void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
1206                                                      Block *continuation) {
1207   blockActions.push_back(BlockAction::getSplit(continuation, block));
1208 }
1209 
1210 void ConversionPatternRewriterImpl::notifyBlocksBeingMerged(Block *block,
1211                                                             Block *srcBlock) {
1212   blockActions.push_back(BlockAction::getMerge(block, srcBlock));
1213 }
1214 
1215 void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore(
1216     Region &region, Region &parent, Region::iterator before) {
1217   if (region.empty())
1218     return;
1219   Block *laterBlock = &region.back();
1220   for (auto &earlierBlock : llvm::drop_begin(llvm::reverse(region), 1)) {
1221     blockActions.push_back(
1222         BlockAction::getMove(laterBlock, {&region, &earlierBlock}));
1223     laterBlock = &earlierBlock;
1224   }
1225   blockActions.push_back(BlockAction::getMove(laterBlock, {&region, nullptr}));
1226 }
1227 
1228 void ConversionPatternRewriterImpl::notifyRegionWasClonedBefore(
1229     iterator_range<Region::iterator> &blocks, Location origRegionLoc) {
1230   for (Block &block : blocks)
1231     blockActions.push_back(BlockAction::getCreate(&block));
1232 
1233   // Compute the conversion set for the inlined region.
1234   auto result = computeConversionSet(blocks, origRegionLoc, createdOps);
1235 
1236   // This original region has already had its conversion set computed, so there
1237   // shouldn't be any new failures.
1238   (void)result;
1239   assert(succeeded(result) && "expected region to have no unreachable blocks");
1240 }
1241 
1242 LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
1243     Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1244   LLVM_DEBUG({
1245     Diagnostic diag(loc, DiagnosticSeverity::Remark);
1246     reasonCallback(diag);
1247     logger.startLine() << "** Failure : " << diag.str() << "\n";
1248   });
1249   return failure();
1250 }
1251 
1252 //===----------------------------------------------------------------------===//
1253 // ConversionPatternRewriter
1254 //===----------------------------------------------------------------------===//
1255 
1256 ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
1257     : PatternRewriter(ctx),
1258       impl(new detail::ConversionPatternRewriterImpl(*this)) {}
1259 ConversionPatternRewriter::~ConversionPatternRewriter() {}
1260 
1261 /// PatternRewriter hook for replacing the results of an operation.
1262 void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
1263   LLVM_DEBUG({
1264     impl->logger.startLine()
1265         << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1266   });
1267   impl->notifyOpReplaced(op, newValues);
1268 }
1269 
1270 /// PatternRewriter hook for erasing a dead operation. The uses of this
1271 /// operation *must* be made dead by the end of the conversion process,
1272 /// otherwise an assert will be issued.
1273 void ConversionPatternRewriter::eraseOp(Operation *op) {
1274   LLVM_DEBUG({
1275     impl->logger.startLine()
1276         << "** Erase   : '" << op->getName() << "'(" << op << ")\n";
1277   });
1278   SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr);
1279   impl->notifyOpReplaced(op, nullRepls);
1280 }
1281 
1282 void ConversionPatternRewriter::eraseBlock(Block *block) {
1283   impl->notifyBlockIsBeingErased(block);
1284 
1285   // Mark all ops for erasure.
1286   for (Operation &op : *block)
1287     eraseOp(&op);
1288 
1289   // Unlink the block from its parent region. The block is kept in the block
1290   // action and will be actually destroyed when rewrites are applied. This
1291   // allows us to keep the operations in the block live and undo the removal by
1292   // re-inserting the block.
1293   block->getParent()->getBlocks().remove(block);
1294 }
1295 
1296 Block *ConversionPatternRewriter::applySignatureConversion(
1297     Region *region, TypeConverter::SignatureConversion &conversion) {
1298   return impl->applySignatureConversion(region, conversion);
1299 }
1300 
1301 FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
1302     Region *region, TypeConverter &converter,
1303     TypeConverter::SignatureConversion *entryConversion) {
1304   return impl->convertRegionTypes(region, converter, entryConversion);
1305 }
1306 
1307 void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
1308                                                            Value to) {
1309   LLVM_DEBUG({
1310     Operation *parentOp = from.getOwner()->getParentOp();
1311     impl->logger.startLine() << "** Replace Argument : '" << from
1312                              << "'(in region of '" << parentOp->getName()
1313                              << "'(" << from.getOwner()->getParentOp() << ")\n";
1314   });
1315   impl->argReplacements.push_back(from);
1316   impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
1317 }
1318 
1319 /// Return the converted value that replaces 'key'. Return 'key' if there is
1320 /// no such a converted value.
1321 Value ConversionPatternRewriter::getRemappedValue(Value key) {
1322   return impl->mapping.lookupOrDefault(key);
1323 }
1324 
1325 /// PatternRewriter hook for creating a new block with the given arguments.
1326 void ConversionPatternRewriter::notifyBlockCreated(Block *block) {
1327   impl->notifyCreatedBlock(block);
1328 }
1329 
1330 /// PatternRewriter hook for splitting a block into two parts.
1331 Block *ConversionPatternRewriter::splitBlock(Block *block,
1332                                              Block::iterator before) {
1333   auto *continuation = PatternRewriter::splitBlock(block, before);
1334   impl->notifySplitBlock(block, continuation);
1335   return continuation;
1336 }
1337 
1338 /// PatternRewriter hook for merging a block into another.
1339 void ConversionPatternRewriter::mergeBlocks(Block *source, Block *dest,
1340                                             ValueRange argValues) {
1341   impl->notifyBlocksBeingMerged(dest, source);
1342   assert(llvm::all_of(source->getPredecessors(),
1343                       [dest](Block *succ) { return succ == dest; }) &&
1344          "expected 'source' to have no predecessors or only 'dest'");
1345   assert(argValues.size() == source->getNumArguments() &&
1346          "incorrect # of argument replacement values");
1347   for (auto it : llvm::zip(source->getArguments(), argValues))
1348     replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1349   dest->getOperations().splice(dest->end(), source->getOperations());
1350   eraseBlock(source);
1351 }
1352 
1353 /// PatternRewriter hook for moving blocks out of a region.
1354 void ConversionPatternRewriter::inlineRegionBefore(Region &region,
1355                                                    Region &parent,
1356                                                    Region::iterator before) {
1357   impl->notifyRegionIsBeingInlinedBefore(region, parent, before);
1358   PatternRewriter::inlineRegionBefore(region, parent, before);
1359 }
1360 
1361 /// PatternRewriter hook for cloning blocks of one region into another.
1362 void ConversionPatternRewriter::cloneRegionBefore(
1363     Region &region, Region &parent, Region::iterator before,
1364     BlockAndValueMapping &mapping) {
1365   if (region.empty())
1366     return;
1367   PatternRewriter::cloneRegionBefore(region, parent, before, mapping);
1368 
1369   // Collect the range of the cloned blocks.
1370   auto clonedBeginIt = mapping.lookup(&region.front())->getIterator();
1371   auto clonedBlocks = llvm::make_range(clonedBeginIt, before);
1372   impl->notifyRegionWasClonedBefore(clonedBlocks, region.getLoc());
1373 }
1374 
1375 /// PatternRewriter hook for creating a new operation.
1376 void ConversionPatternRewriter::notifyOperationInserted(Operation *op) {
1377   LLVM_DEBUG({
1378     impl->logger.startLine()
1379         << "** Insert  : '" << op->getName() << "'(" << op << ")\n";
1380   });
1381   impl->createdOps.push_back(op);
1382 }
1383 
1384 /// PatternRewriter hook for updating the root operation in-place.
1385 void ConversionPatternRewriter::startRootUpdate(Operation *op) {
1386 #ifndef NDEBUG
1387   impl->pendingRootUpdates.insert(op);
1388 #endif
1389   impl->rootUpdates.emplace_back(op);
1390 }
1391 
1392 /// PatternRewriter hook for updating the root operation in-place.
1393 void ConversionPatternRewriter::finalizeRootUpdate(Operation *op) {
1394   // There is nothing to do here, we only need to track the operation at the
1395   // start of the update.
1396 #ifndef NDEBUG
1397   assert(impl->pendingRootUpdates.erase(op) &&
1398          "operation did not have a pending in-place update");
1399 #endif
1400 }
1401 
1402 /// PatternRewriter hook for updating the root operation in-place.
1403 void ConversionPatternRewriter::cancelRootUpdate(Operation *op) {
1404 #ifndef NDEBUG
1405   assert(impl->pendingRootUpdates.erase(op) &&
1406          "operation did not have a pending in-place update");
1407 #endif
1408   // Erase the last update for this operation.
1409   auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; };
1410   auto &rootUpdates = impl->rootUpdates;
1411   auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp);
1412   rootUpdates.erase(rootUpdates.begin() + (rootUpdates.rend() - it));
1413 }
1414 
1415 /// PatternRewriter hook for notifying match failure reasons.
1416 LogicalResult ConversionPatternRewriter::notifyMatchFailure(
1417     Operation *op, function_ref<void(Diagnostic &)> reasonCallback) {
1418   return impl->notifyMatchFailure(op->getLoc(), reasonCallback);
1419 }
1420 
1421 /// Return a reference to the internal implementation.
1422 detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
1423   return *impl;
1424 }
1425 
1426 //===----------------------------------------------------------------------===//
1427 // ConversionPattern
1428 //===----------------------------------------------------------------------===//
1429 
1430 /// Attempt to match and rewrite the IR root at the specified operation.
1431 LogicalResult
1432 ConversionPattern::matchAndRewrite(Operation *op,
1433                                    PatternRewriter &rewriter) const {
1434   auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
1435   auto &rewriterImpl = dialectRewriter.getImpl();
1436 
1437   // Track the current conversion pattern in the rewriter.
1438   assert(!rewriterImpl.currentConversionPattern &&
1439          "already inside of a pattern rewrite");
1440   llvm::SaveAndRestore<const ConversionPattern *> currentPatternGuard(
1441       rewriterImpl.currentConversionPattern, this);
1442 
1443   // Remap the operands of the operation.
1444   SmallVector<Value, 4> operands;
1445   if (failed(rewriterImpl.remapValues(op->getLoc(), rewriter,
1446                                       getTypeConverter(), op->getOperands(),
1447                                       operands))) {
1448     return failure();
1449   }
1450   return matchAndRewrite(op, operands, dialectRewriter);
1451 }
1452 
1453 //===----------------------------------------------------------------------===//
1454 // OperationLegalizer
1455 //===----------------------------------------------------------------------===//
1456 
1457 namespace {
1458 /// A set of rewrite patterns that can be used to legalize a given operation.
1459 using LegalizationPatterns = SmallVector<const Pattern *, 1>;
1460 
1461 /// This class defines a recursive operation legalizer.
1462 class OperationLegalizer {
1463 public:
1464   using LegalizationAction = ConversionTarget::LegalizationAction;
1465 
1466   OperationLegalizer(ConversionTarget &targetInfo,
1467                      const FrozenRewritePatternList &patterns);
1468 
1469   /// Returns true if the given operation is known to be illegal on the target.
1470   bool isIllegal(Operation *op) const;
1471 
1472   /// Attempt to legalize the given operation. Returns success if the operation
1473   /// was legalized, failure otherwise.
1474   LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
1475 
1476   /// Returns the conversion target in use by the legalizer.
1477   ConversionTarget &getTarget() { return target; }
1478 
1479 private:
1480   /// Attempt to legalize the given operation by folding it.
1481   LogicalResult legalizeWithFold(Operation *op,
1482                                  ConversionPatternRewriter &rewriter);
1483 
1484   /// Attempt to legalize the given operation by applying a pattern. Returns
1485   /// success if the operation was legalized, failure otherwise.
1486   LogicalResult legalizeWithPattern(Operation *op,
1487                                     ConversionPatternRewriter &rewriter);
1488 
1489   /// Return true if the given pattern may be applied to the given operation,
1490   /// false otherwise.
1491   bool canApplyPattern(Operation *op, const Pattern &pattern,
1492                        ConversionPatternRewriter &rewriter);
1493 
1494   /// Legalize the resultant IR after successfully applying the given pattern.
1495   LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
1496                                       ConversionPatternRewriter &rewriter,
1497                                       RewriterState &curState);
1498 
1499   /// Legalizes the actions registered during the execution of a pattern.
1500   LogicalResult legalizePatternBlockActions(Operation *op,
1501                                             ConversionPatternRewriter &rewriter,
1502                                             ConversionPatternRewriterImpl &impl,
1503                                             RewriterState &state,
1504                                             RewriterState &newState);
1505   LogicalResult legalizePatternCreatedOperations(
1506       ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
1507       RewriterState &state, RewriterState &newState);
1508   LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
1509                                            ConversionPatternRewriterImpl &impl,
1510                                            RewriterState &state,
1511                                            RewriterState &newState);
1512 
1513   //===--------------------------------------------------------------------===//
1514   // Cost Model
1515   //===--------------------------------------------------------------------===//
1516 
1517   /// Build an optimistic legalization graph given the provided patterns. This
1518   /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
1519   /// patterns for operations that are not directly legal, but may be
1520   /// transitively legal for the current target given the provided patterns.
1521   void buildLegalizationGraph(
1522       LegalizationPatterns &anyOpLegalizerPatterns,
1523       DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
1524 
1525   /// Compute the benefit of each node within the computed legalization graph.
1526   /// This orders the patterns within 'legalizerPatterns' based upon two
1527   /// criteria:
1528   ///  1) Prefer patterns that have the lowest legalization depth, i.e.
1529   ///     represent the more direct mapping to the target.
1530   ///  2) When comparing patterns with the same legalization depth, prefer the
1531   ///     pattern with the highest PatternBenefit. This allows for users to
1532   ///     prefer specific legalizations over others.
1533   void computeLegalizationGraphBenefit(
1534       LegalizationPatterns &anyOpLegalizerPatterns,
1535       DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
1536 
1537   /// Compute the legalization depth when legalizing an operation of the given
1538   /// type.
1539   unsigned computeOpLegalizationDepth(
1540       OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
1541       DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
1542 
1543   /// Apply the conversion cost model to the given set of patterns, and return
1544   /// the smallest legalization depth of any of the patterns. See
1545   /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
1546   unsigned applyCostModelToPatterns(
1547       LegalizationPatterns &patterns,
1548       DenseMap<OperationName, unsigned> &minOpPatternDepth,
1549       DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
1550 
1551   /// The current set of patterns that have been applied.
1552   SmallPtrSet<const Pattern *, 8> appliedPatterns;
1553 
1554   /// The legalization information provided by the target.
1555   ConversionTarget &target;
1556 
1557   /// The pattern applicator to use for conversions.
1558   PatternApplicator applicator;
1559 };
1560 } // namespace
1561 
1562 OperationLegalizer::OperationLegalizer(ConversionTarget &targetInfo,
1563                                        const FrozenRewritePatternList &patterns)
1564     : target(targetInfo), applicator(patterns) {
1565   // The set of patterns that can be applied to illegal operations to transform
1566   // them into legal ones.
1567   DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
1568   LegalizationPatterns anyOpLegalizerPatterns;
1569 
1570   buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
1571   computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
1572 }
1573 
1574 bool OperationLegalizer::isIllegal(Operation *op) const {
1575   // Check if the target explicitly marked this operation as illegal.
1576   return target.getOpAction(op->getName()) == LegalizationAction::Illegal;
1577 }
1578 
1579 LogicalResult
1580 OperationLegalizer::legalize(Operation *op,
1581                              ConversionPatternRewriter &rewriter) {
1582 #ifndef NDEBUG
1583   const char *logLineComment =
1584       "//===-------------------------------------------===//\n";
1585 
1586   auto &rewriterImpl = rewriter.getImpl();
1587 #endif
1588   LLVM_DEBUG({
1589     auto &os = rewriterImpl.logger;
1590     os.getOStream() << "\n";
1591     os.startLine() << logLineComment;
1592     os.startLine() << "Legalizing operation : '" << op->getName() << "'(" << op
1593                    << ") {\n";
1594     os.indent();
1595 
1596     // If the operation has no regions, just print it here.
1597     if (op->getNumRegions() == 0) {
1598       op->print(os.startLine(), OpPrintingFlags().printGenericOpForm());
1599       os.getOStream() << "\n\n";
1600     }
1601   });
1602 
1603   // Check if this operation is legal on the target.
1604   if (auto legalityInfo = target.isLegal(op)) {
1605     LLVM_DEBUG({
1606       logSuccess(
1607           rewriterImpl.logger, "operation marked legal by the target{0}",
1608           legalityInfo->isRecursivelyLegal
1609               ? "; NOTE: operation is recursively legal; skipping internals"
1610               : "");
1611       rewriterImpl.logger.startLine() << logLineComment;
1612     });
1613 
1614     // If this operation is recursively legal, mark its children as ignored so
1615     // that we don't consider them for legalization.
1616     if (legalityInfo->isRecursivelyLegal)
1617       rewriter.getImpl().markNestedOpsIgnored(op);
1618     return success();
1619   }
1620 
1621   // Check to see if the operation is ignored and doesn't need to be converted.
1622   if (rewriter.getImpl().isOpIgnored(op)) {
1623     LLVM_DEBUG({
1624       logSuccess(rewriterImpl.logger,
1625                  "operation marked 'ignored' during conversion");
1626       rewriterImpl.logger.startLine() << logLineComment;
1627     });
1628     return success();
1629   }
1630 
1631   // If the operation isn't legal, try to fold it in-place.
1632   // TODO: Should we always try to do this, even if the op is
1633   // already legal?
1634   if (succeeded(legalizeWithFold(op, rewriter))) {
1635     LLVM_DEBUG({
1636       logSuccess(rewriterImpl.logger, "operation was folded");
1637       rewriterImpl.logger.startLine() << logLineComment;
1638     });
1639     return success();
1640   }
1641 
1642   // Otherwise, we need to apply a legalization pattern to this operation.
1643   if (succeeded(legalizeWithPattern(op, rewriter))) {
1644     LLVM_DEBUG({
1645       logSuccess(rewriterImpl.logger, "");
1646       rewriterImpl.logger.startLine() << logLineComment;
1647     });
1648     return success();
1649   }
1650 
1651   LLVM_DEBUG({
1652     logFailure(rewriterImpl.logger, "no matched legalization pattern");
1653     rewriterImpl.logger.startLine() << logLineComment;
1654   });
1655   return failure();
1656 }
1657 
1658 LogicalResult
1659 OperationLegalizer::legalizeWithFold(Operation *op,
1660                                      ConversionPatternRewriter &rewriter) {
1661   auto &rewriterImpl = rewriter.getImpl();
1662   RewriterState curState = rewriterImpl.getCurrentState();
1663 
1664   LLVM_DEBUG({
1665     rewriterImpl.logger.startLine() << "* Fold {\n";
1666     rewriterImpl.logger.indent();
1667   });
1668 
1669   // Try to fold the operation.
1670   SmallVector<Value, 2> replacementValues;
1671   rewriter.setInsertionPoint(op);
1672   if (failed(rewriter.tryFold(op, replacementValues))) {
1673     LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
1674     return failure();
1675   }
1676 
1677   // Insert a replacement for 'op' with the folded replacement values.
1678   rewriter.replaceOp(op, replacementValues);
1679 
1680   // Recursively legalize any new constant operations.
1681   for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
1682        i != e; ++i) {
1683     Operation *cstOp = rewriterImpl.createdOps[i];
1684     if (failed(legalize(cstOp, rewriter))) {
1685       LLVM_DEBUG(logFailure(rewriterImpl.logger,
1686                             "generated constant '{0}' was illegal",
1687                             cstOp->getName()));
1688       rewriterImpl.resetState(curState);
1689       return failure();
1690     }
1691   }
1692 
1693   LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
1694   return success();
1695 }
1696 
1697 LogicalResult
1698 OperationLegalizer::legalizeWithPattern(Operation *op,
1699                                         ConversionPatternRewriter &rewriter) {
1700   auto &rewriterImpl = rewriter.getImpl();
1701 
1702   // Functor that returns if the given pattern may be applied.
1703   auto canApply = [&](const Pattern &pattern) {
1704     return canApplyPattern(op, pattern, rewriter);
1705   };
1706 
1707   // Functor that cleans up the rewriter state after a pattern failed to match.
1708   RewriterState curState = rewriterImpl.getCurrentState();
1709   auto onFailure = [&](const Pattern &pattern) {
1710     LLVM_DEBUG(logFailure(rewriterImpl.logger, "pattern failed to match"));
1711     rewriterImpl.resetState(curState);
1712     appliedPatterns.erase(&pattern);
1713   };
1714 
1715   // Functor that performs additional legalization when a pattern is
1716   // successfully applied.
1717   auto onSuccess = [&](const Pattern &pattern) {
1718     auto result = legalizePatternResult(op, pattern, rewriter, curState);
1719     appliedPatterns.erase(&pattern);
1720     if (failed(result))
1721       rewriterImpl.resetState(curState);
1722     return result;
1723   };
1724 
1725   // Try to match and rewrite a pattern on this operation.
1726   return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
1727                                     onSuccess);
1728 }
1729 
1730 bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
1731                                          ConversionPatternRewriter &rewriter) {
1732   LLVM_DEBUG({
1733     auto &os = rewriter.getImpl().logger;
1734     os.getOStream() << "\n";
1735     os.startLine() << "* Pattern : '" << op->getName() << " -> (";
1736     llvm::interleaveComma(pattern.getGeneratedOps(), llvm::dbgs());
1737     os.getOStream() << ")' {\n";
1738     os.indent();
1739   });
1740 
1741   // Ensure that we don't cycle by not allowing the same pattern to be
1742   // applied twice in the same recursion stack if it is not known to be safe.
1743   if (!pattern.hasBoundedRewriteRecursion() &&
1744       !appliedPatterns.insert(&pattern).second) {
1745     LLVM_DEBUG(
1746         logFailure(rewriter.getImpl().logger, "pattern was already applied"));
1747     return false;
1748   }
1749   return true;
1750 }
1751 
1752 LogicalResult
1753 OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
1754                                           ConversionPatternRewriter &rewriter,
1755                                           RewriterState &curState) {
1756   auto &impl = rewriter.getImpl();
1757 
1758 #ifndef NDEBUG
1759   assert(impl.pendingRootUpdates.empty() && "dangling root updates");
1760 #endif
1761 
1762   // Check that the root was either replaced or updated in place.
1763   auto replacedRoot = [&] {
1764     return llvm::any_of(
1765         llvm::drop_begin(impl.replacements, curState.numReplacements),
1766         [op](auto &it) { return it.first == op; });
1767   };
1768   auto updatedRootInPlace = [&] {
1769     return llvm::any_of(
1770         llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates),
1771         [op](auto &state) { return state.getOperation() == op; });
1772   };
1773   (void)replacedRoot;
1774   (void)updatedRootInPlace;
1775   assert((replacedRoot() || updatedRootInPlace()) &&
1776          "expected pattern to replace the root operation");
1777 
1778   // Legalize each of the actions registered during application.
1779   RewriterState newState = impl.getCurrentState();
1780   if (failed(legalizePatternBlockActions(op, rewriter, impl, curState,
1781                                          newState)) ||
1782       failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
1783       failed(legalizePatternCreatedOperations(rewriter, impl, curState,
1784                                               newState))) {
1785     return failure();
1786   }
1787 
1788   LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
1789   return success();
1790 }
1791 
1792 LogicalResult OperationLegalizer::legalizePatternBlockActions(
1793     Operation *op, ConversionPatternRewriter &rewriter,
1794     ConversionPatternRewriterImpl &impl, RewriterState &state,
1795     RewriterState &newState) {
1796   SmallPtrSet<Operation *, 16> operationsToIgnore;
1797 
1798   // If the pattern moved or created any blocks, make sure the types of block
1799   // arguments get legalized.
1800   for (int i = state.numBlockActions, e = newState.numBlockActions; i != e;
1801        ++i) {
1802     auto &action = impl.blockActions[i];
1803     if (action.kind == BlockActionKind::TypeConversion ||
1804         action.kind == BlockActionKind::Erase)
1805       continue;
1806     // Only check blocks outside of the current operation.
1807     Operation *parentOp = action.block->getParentOp();
1808     if (!parentOp || parentOp == op || action.block->getNumArguments() == 0)
1809       continue;
1810 
1811     // If the region of the block has a type converter, try to convert the block
1812     // directly.
1813     if (auto *converter =
1814             impl.argConverter.getConverter(action.block->getParent())) {
1815       if (failed(impl.convertBlockSignature(action.block, *converter))) {
1816         LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
1817                                            "block"));
1818         return failure();
1819       }
1820       continue;
1821     }
1822 
1823     // Otherwise, check that this operation isn't one generated by this pattern.
1824     // This is because we will attempt to legalize the parent operation, and
1825     // blocks in regions created by this pattern will already be legalized later
1826     // on. If we haven't built the set yet, build it now.
1827     if (operationsToIgnore.empty()) {
1828       auto createdOps = ArrayRef<Operation *>(impl.createdOps)
1829                             .drop_front(state.numCreatedOps);
1830       operationsToIgnore.insert(createdOps.begin(), createdOps.end());
1831     }
1832 
1833     // If this operation should be considered for re-legalization, try it.
1834     if (operationsToIgnore.insert(parentOp).second &&
1835         failed(legalize(parentOp, rewriter))) {
1836       LLVM_DEBUG(logFailure(
1837           impl.logger, "operation '{0}'({1}) became illegal after block action",
1838           parentOp->getName(), parentOp));
1839       return failure();
1840     }
1841   }
1842   return success();
1843 }
1844 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
1845     ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
1846     RewriterState &state, RewriterState &newState) {
1847   for (int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) {
1848     Operation *op = impl.createdOps[i];
1849     if (failed(legalize(op, rewriter))) {
1850       LLVM_DEBUG(logFailure(impl.logger,
1851                             "generated operation '{0}'({1}) was illegal",
1852                             op->getName(), op));
1853       return failure();
1854     }
1855   }
1856   return success();
1857 }
1858 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
1859     ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
1860     RewriterState &state, RewriterState &newState) {
1861   for (int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) {
1862     Operation *op = impl.rootUpdates[i].getOperation();
1863     if (failed(legalize(op, rewriter))) {
1864       LLVM_DEBUG(logFailure(impl.logger,
1865                             "operation updated in-place '{0}' was illegal",
1866                             op->getName()));
1867       return failure();
1868     }
1869   }
1870   return success();
1871 }
1872 
//===----------------------------------------------------------------------===//
// Cost Model
//===----------------------------------------------------------------------===//
1875 
1876 void OperationLegalizer::buildLegalizationGraph(
1877     LegalizationPatterns &anyOpLegalizerPatterns,
1878     DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
1879   // A mapping between an operation and a set of operations that can be used to
1880   // generate it.
1881   DenseMap<OperationName, SmallPtrSet<OperationName, 2>> parentOps;
1882   // A mapping between an operation and any currently invalid patterns it has.
1883   DenseMap<OperationName, SmallPtrSet<const Pattern *, 2>> invalidPatterns;
1884   // A worklist of patterns to consider for legality.
1885   llvm::SetVector<const Pattern *> patternWorklist;
1886 
1887   // Build the mapping from operations to the parent ops that may generate them.
1888   applicator.walkAllPatterns([&](const Pattern &pattern) {
1889     Optional<OperationName> root = pattern.getRootKind();
1890 
1891     // If the pattern has no specific root, we can't analyze the relationship
1892     // between the root op and generated operations. Given that, add all such
1893     // patterns to the legalization set.
1894     if (!root) {
1895       anyOpLegalizerPatterns.push_back(&pattern);
1896       return;
1897     }
1898 
1899     // Skip operations that are always known to be legal.
1900     if (target.getOpAction(*root) == LegalizationAction::Legal)
1901       return;
1902 
1903     // Add this pattern to the invalid set for the root op and record this root
1904     // as a parent for any generated operations.
1905     invalidPatterns[*root].insert(&pattern);
1906     for (auto op : pattern.getGeneratedOps())
1907       parentOps[op].insert(*root);
1908 
1909     // Add this pattern to the worklist.
1910     patternWorklist.insert(&pattern);
1911   });
1912 
1913   // If there are any patterns that don't have a specific root kind, we can't
1914   // make direct assumptions about what operations will never be legalized.
1915   // Note: Technically we could, but it would require an analysis that may
1916   // recurse into itself. It would be better to perform this kind of filtering
1917   // at a higher level than here anyways.
1918   if (!anyOpLegalizerPatterns.empty()) {
1919     for (const Pattern *pattern : patternWorklist)
1920       legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
1921     return;
1922   }
1923 
1924   while (!patternWorklist.empty()) {
1925     auto *pattern = patternWorklist.pop_back_val();
1926 
1927     // Check to see if any of the generated operations are invalid.
1928     if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
1929           Optional<LegalizationAction> action = target.getOpAction(op);
1930           return !legalizerPatterns.count(op) &&
1931                  (!action || action == LegalizationAction::Illegal);
1932         }))
1933       continue;
1934 
1935     // Otherwise, if all of the generated operation are valid, this op is now
1936     // legal so add all of the child patterns to the worklist.
1937     legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
1938     invalidPatterns[*pattern->getRootKind()].erase(pattern);
1939 
1940     // Add any invalid patterns of the parent operations to see if they have now
1941     // become legal.
1942     for (auto op : parentOps[*pattern->getRootKind()])
1943       patternWorklist.set_union(invalidPatterns[op]);
1944   }
1945 }
1946 
1947 void OperationLegalizer::computeLegalizationGraphBenefit(
1948     LegalizationPatterns &anyOpLegalizerPatterns,
1949     DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
1950   // The smallest pattern depth, when legalizing an operation.
1951   DenseMap<OperationName, unsigned> minOpPatternDepth;
1952 
1953   // For each operation that is transitively legal, compute a cost for it.
1954   for (auto &opIt : legalizerPatterns)
1955     if (!minOpPatternDepth.count(opIt.first))
1956       computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
1957                                  legalizerPatterns);
1958 
1959   // Apply the cost model to the patterns that can match any operation. Those
1960   // with a specific operation type are already resolved when computing the op
1961   // legalization depth.
1962   if (!anyOpLegalizerPatterns.empty())
1963     applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
1964                              legalizerPatterns);
1965 
1966   // Apply a cost model to the pattern applicator. We order patterns first by
1967   // depth then benefit. `legalizerPatterns` contains per-op patterns by
1968   // decreasing benefit.
1969   applicator.applyCostModel([&](const Pattern &pattern) {
1970     ArrayRef<const Pattern *> orderedPatternList;
1971     if (Optional<OperationName> rootName = pattern.getRootKind())
1972       orderedPatternList = legalizerPatterns[*rootName];
1973     else
1974       orderedPatternList = anyOpLegalizerPatterns;
1975 
1976     // If the pattern is not found, then it was removed and cannot be matched.
1977     auto it = llvm::find(orderedPatternList, &pattern);
1978     if (it == orderedPatternList.end())
1979       return PatternBenefit::impossibleToMatch();
1980 
1981     // Patterns found earlier in the list have higher benefit.
1982     return PatternBenefit(std::distance(it, orderedPatternList.end()));
1983   });
1984 }
1985 
1986 unsigned OperationLegalizer::computeOpLegalizationDepth(
1987     OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
1988     DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
1989   // Check for existing depth.
1990   auto depthIt = minOpPatternDepth.find(op);
1991   if (depthIt != minOpPatternDepth.end())
1992     return depthIt->second;
1993 
1994   // If a mapping for this operation does not exist, then this operation
1995   // is always legal. Return 0 as the depth for a directly legal operation.
1996   auto opPatternsIt = legalizerPatterns.find(op);
1997   if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
1998     return 0u;
1999 
2000   // Record this initial depth in case we encounter this op again when
2001   // recursively computing the depth.
2002   minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
2003 
2004   // Apply the cost model to the operation patterns, and update the minimum
2005   // depth.
2006   unsigned minDepth = applyCostModelToPatterns(
2007       opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2008   minOpPatternDepth[op] = minDepth;
2009   return minDepth;
2010 }
2011 
2012 unsigned OperationLegalizer::applyCostModelToPatterns(
2013     LegalizationPatterns &patterns,
2014     DenseMap<OperationName, unsigned> &minOpPatternDepth,
2015     DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2016   unsigned minDepth = std::numeric_limits<unsigned>::max();
2017 
2018   // Compute the depth for each pattern within the set.
2019   SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
2020   patternsByDepth.reserve(patterns.size());
2021   for (const Pattern *pattern : patterns) {
2022     unsigned depth = 0;
2023     for (auto generatedOp : pattern->getGeneratedOps()) {
2024       unsigned generatedOpDepth = computeOpLegalizationDepth(
2025           generatedOp, minOpPatternDepth, legalizerPatterns);
2026       depth = std::max(depth, generatedOpDepth + 1);
2027     }
2028     patternsByDepth.emplace_back(pattern, depth);
2029 
2030     // Update the minimum depth of the pattern list.
2031     minDepth = std::min(minDepth, depth);
2032   }
2033 
2034   // If the operation only has one legalization pattern, there is no need to
2035   // sort them.
2036   if (patternsByDepth.size() == 1)
2037     return minDepth;
2038 
2039   // Sort the patterns by those likely to be the most beneficial.
2040   llvm::array_pod_sort(patternsByDepth.begin(), patternsByDepth.end(),
2041                        [](const std::pair<const Pattern *, unsigned> *lhs,
2042                           const std::pair<const Pattern *, unsigned> *rhs) {
2043                          // First sort by the smaller pattern legalization
2044                          // depth.
2045                          if (lhs->second != rhs->second)
2046                            return llvm::array_pod_sort_comparator<unsigned>(
2047                                &lhs->second, &rhs->second);
2048 
2049                          // Then sort by the larger pattern benefit.
2050                          auto lhsBenefit = lhs->first->getBenefit();
2051                          auto rhsBenefit = rhs->first->getBenefit();
2052                          return llvm::array_pod_sort_comparator<PatternBenefit>(
2053                              &rhsBenefit, &lhsBenefit);
2054                        });
2055 
2056   // Update the legalization pattern to use the new sorted list.
2057   patterns.clear();
2058   for (auto &patternIt : patternsByDepth)
2059     patterns.push_back(patternIt.first);
2060   return minDepth;
2061 }
2062 
2063 //===----------------------------------------------------------------------===//
2064 // OperationConverter
2065 //===----------------------------------------------------------------------===//
2066 namespace {
/// The modes an operation conversion can run in.
enum OpConversionMode {
  /// Failed conversions are ignored, which allows illegal operations to
  /// co-exist in the IR with converted ones.
  Partial,

  /// Every operation has to be legal for the given target, otherwise the
  /// conversion does not succeed.
  Full,

  /// Operations are only analyzed for legality; on success no actual rewrite
  /// is applied to them.
  Analysis,
};
2080 
2081 // This class converts operations to a given conversion target via a set of
2082 // rewrite patterns. The conversion behaves differently depending on the
2083 // conversion mode.
2084 struct OperationConverter {
2085   explicit OperationConverter(ConversionTarget &target,
2086                               const FrozenRewritePatternList &patterns,
2087                               OpConversionMode mode,
2088                               DenseSet<Operation *> *trackedOps = nullptr)
2089       : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}
2090 
2091   /// Converts the given operations to the conversion target.
2092   LogicalResult convertOperations(ArrayRef<Operation *> ops);
2093 
2094 private:
2095   /// Converts an operation with the given rewriter.
2096   LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
2097 
2098   /// This method is called after the conversion process to legalize any
2099   /// remaining artifacts and complete the conversion.
2100   LogicalResult finalize(ConversionPatternRewriter &rewriter);
2101 
2102   /// Legalize the types of converted block arguments.
2103   LogicalResult
2104   legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
2105                                  ConversionPatternRewriterImpl &rewriterImpl);
2106 
2107   /// Legalize an operation result that was marked as "erased".
2108   LogicalResult
2109   legalizeErasedResult(Operation *op, OpResult result,
2110                        ConversionPatternRewriterImpl &rewriterImpl);
2111 
2112   /// Legalize an operation result that was replaced with a value of a different
2113   /// type.
2114   LogicalResult
2115   legalizeChangedResultType(Operation *op, OpResult result, Value newValue,
2116                             TypeConverter *replConverter,
2117                             ConversionPatternRewriter &rewriter,
2118                             ConversionPatternRewriterImpl &rewriterImpl);
2119 
2120   /// The legalizer to use when converting operations.
2121   OperationLegalizer opLegalizer;
2122 
2123   /// The conversion mode to use when legalizing operations.
2124   OpConversionMode mode;
2125 
2126   /// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
2127   /// this is populated with ops found to be legalizable to the target.
2128   /// When mode == OpConversionMode::Partial, this is populated with ops found
2129   /// *not* to be legalizable to the target.
2130   DenseSet<Operation *> *trackedOps;
2131 };
2132 } // end anonymous namespace
2133 
2134 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2135                                           Operation *op) {
2136   // Legalize the given operation.
2137   if (failed(opLegalizer.legalize(op, rewriter))) {
2138     // Handle the case of a failed conversion for each of the different modes.
2139     // Full conversions expect all operations to be converted.
2140     if (mode == OpConversionMode::Full)
2141       return op->emitError()
2142              << "failed to legalize operation '" << op->getName() << "'";
2143     // Partial conversions allow conversions to fail iff the operation was not
2144     // explicitly marked as illegal. If the user provided a nonlegalizableOps
2145     // set, non-legalizable ops are included.
2146     if (mode == OpConversionMode::Partial) {
2147       if (opLegalizer.isIllegal(op))
2148         return op->emitError()
2149                << "failed to legalize operation '" << op->getName()
2150                << "' that was explicitly marked illegal";
2151       if (trackedOps)
2152         trackedOps->insert(op);
2153     }
2154   } else if (mode == OpConversionMode::Analysis) {
2155     // Analysis conversions don't fail if any operations fail to legalize,
2156     // they are only interested in the operations that were successfully
2157     // legalized.
2158     trackedOps->insert(op);
2159   }
2160   return success();
2161 }
2162 
2163 LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
2164   if (ops.empty())
2165     return success();
2166   ConversionTarget &target = opLegalizer.getTarget();
2167 
2168   // Compute the set of operations and blocks to convert.
2169   std::vector<Operation *> toConvert;
2170   for (auto *op : ops) {
2171     toConvert.emplace_back(op);
2172     for (auto &region : op->getRegions())
2173       if (failed(computeConversionSet(region.getBlocks(), region.getLoc(),
2174                                       toConvert, &target)))
2175         return failure();
2176   }
2177 
2178   // Convert each operation and discard rewrites on failure.
2179   ConversionPatternRewriter rewriter(ops.front()->getContext());
2180   ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2181   for (auto *op : toConvert)
2182     if (failed(convert(rewriter, op)))
2183       return rewriterImpl.discardRewrites(), failure();
2184 
2185   // Now that all of the operations have been converted, finalize the conversion
2186   // process to ensure any lingering conversion artifacts are cleaned up and
2187   // legalized.
2188   if (failed(finalize(rewriter)))
2189     return rewriterImpl.discardRewrites(), failure();
2190 
2191   // After a successful conversion, apply rewrites if this is not an analysis
2192   // conversion.
2193   if (mode == OpConversionMode::Analysis)
2194     rewriterImpl.discardRewrites();
2195   else
2196     rewriterImpl.applyRewrites();
2197   return success();
2198 }
2199 
2200 LogicalResult
2201 OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
2202   ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2203 
2204   // Legalize converted block arguments.
2205   if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
2206     return failure();
2207 
2208   // Process requested operation replacements.
2209   for (unsigned i = 0, e = rewriterImpl.operationsWithChangedResults.size();
2210        i != e; ++i) {
2211     unsigned replIdx = rewriterImpl.operationsWithChangedResults[i];
2212     auto &repl = *(rewriterImpl.replacements.begin() + replIdx);
2213     for (OpResult result : repl.first->getResults()) {
2214       Value newValue = rewriterImpl.mapping.lookupOrNull(result);
2215 
2216       // If the operation result was replaced with null, all of the uses of this
2217       // value should be replaced.
2218       if (!newValue) {
2219         if (failed(legalizeErasedResult(repl.first, result, rewriterImpl)))
2220           return failure();
2221         continue;
2222       }
2223 
2224       // Otherwise, check to see if the type of the result changed.
2225       if (result.getType() == newValue.getType())
2226         continue;
2227 
2228       // Legalize this result.
2229       rewriter.setInsertionPoint(repl.first);
2230       if (failed(legalizeChangedResultType(repl.first, result, newValue,
2231                                            repl.second.converter, rewriter,
2232                                            rewriterImpl)))
2233         return failure();
2234 
2235       // Update the end iterator for this loop in the case it was updated
2236       // when legalizing generated conversion operations.
2237       e = rewriterImpl.operationsWithChangedResults.size();
2238     }
2239   }
2240   return success();
2241 }
2242 
2243 LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
2244     ConversionPatternRewriter &rewriter,
2245     ConversionPatternRewriterImpl &rewriterImpl) {
2246   // Functor used to check if all users of a value will be dead after
2247   // conversion.
2248   auto findLiveUser = [&](Value val) {
2249     auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
2250       return rewriterImpl.isOpIgnored(user);
2251     });
2252     return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
2253   };
2254 
2255   // Materialize any necessary conversions for converted block arguments that
2256   // are still live.
2257   size_t numCreatedOps = rewriterImpl.createdOps.size();
2258   if (failed(rewriterImpl.argConverter.materializeLiveConversions(
2259           rewriterImpl.mapping, rewriter, findLiveUser)))
2260     return failure();
2261 
2262   // Legalize any newly created operations during argument materialization.
2263   for (int i : llvm::seq<int>(numCreatedOps, rewriterImpl.createdOps.size())) {
2264     if (failed(opLegalizer.legalize(rewriterImpl.createdOps[i], rewriter))) {
2265       return rewriterImpl.createdOps[i]->emitError()
2266              << "failed to legalize conversion operation generated for block "
2267                 "argument that remained live after conversion";
2268     }
2269   }
2270   return success();
2271 }
2272 
2273 LogicalResult OperationConverter::legalizeErasedResult(
2274     Operation *op, OpResult result,
2275     ConversionPatternRewriterImpl &rewriterImpl) {
2276   // If the operation result was replaced with null, all of the uses of this
2277   // value should be replaced.
2278   auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
2279     return rewriterImpl.isOpIgnored(user);
2280   });
2281   if (liveUserIt != result.user_end()) {
2282     InFlightDiagnostic diag = op->emitError("failed to legalize operation '")
2283                               << op->getName() << "' marked as erased";
2284     diag.attachNote(liveUserIt->getLoc())
2285         << "found live user of result #" << result.getResultNumber() << ": "
2286         << *liveUserIt;
2287     return failure();
2288   }
2289   return success();
2290 }
2291 
2292 LogicalResult OperationConverter::legalizeChangedResultType(
2293     Operation *op, OpResult result, Value newValue,
2294     TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
2295     ConversionPatternRewriterImpl &rewriterImpl) {
2296   // Walk the users of this value to see if there are any live users that
2297   // weren't replaced during conversion.
2298   auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
2299     return rewriterImpl.isOpIgnored(user);
2300   });
2301   if (liveUserIt == result.user_end())
2302     return success();
2303 
2304   // If the replacement has a type converter, attempt to materialize a
2305   // conversion back to the original type.
2306   if (!replConverter) {
2307     // TODO: We should emit an error here, similarly to the case where the
2308     // result is replaced with null. Unfortunately a lot of existing
2309     // patterns rely on this behavior, so until those patterns are updated
2310     // we keep the legacy behavior here of just forwarding the new value.
2311     return success();
2312   }
2313 
2314   // Track the number of created operations so that new ones can be legalized.
2315   size_t numCreatedOps = rewriterImpl.createdOps.size();
2316 
2317   // Materialize a conversion for this live result value.
2318   Type resultType = result.getType();
2319   Value convertedValue = replConverter->materializeSourceConversion(
2320       rewriter, op->getLoc(), resultType, newValue);
2321   if (!convertedValue) {
2322     InFlightDiagnostic diag = op->emitError()
2323                               << "failed to materialize conversion for result #"
2324                               << result.getResultNumber() << " of operation '"
2325                               << op->getName()
2326                               << "' that remained live after conversion";
2327     diag.attachNote(liveUserIt->getLoc())
2328         << "see existing live user here: " << *liveUserIt;
2329     return failure();
2330   }
2331 
2332   // Legalize all of the newly created conversion operations.
2333   for (int i : llvm::seq<int>(numCreatedOps, rewriterImpl.createdOps.size())) {
2334     if (failed(opLegalizer.legalize(rewriterImpl.createdOps[i], rewriter))) {
2335       return op->emitError("failed to legalize conversion operation generated ")
2336              << "for result #" << result.getResultNumber() << " of operation '"
2337              << op->getName() << "' that remained live after conversion";
2338     }
2339   }
2340 
2341   rewriterImpl.mapping.map(result, convertedValue);
2342   return success();
2343 }
2344 
2345 //===----------------------------------------------------------------------===//
2346 // Type Conversion
2347 //===----------------------------------------------------------------------===//
2348 
2349 /// Remap an input of the original signature with a new set of types. The
2350 /// new types are appended to the new signature conversion.
2351 void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
2352                                                    ArrayRef<Type> types) {
2353   assert(!types.empty() && "expected valid types");
2354   remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
2355   addInputs(types);
2356 }
2357 
2358 /// Append new input types to the signature conversion, this should only be
2359 /// used if the new types are not intended to remap an existing input.
2360 void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) {
2361   assert(!types.empty() &&
2362          "1->0 type remappings don't need to be added explicitly");
2363   argTypes.append(types.begin(), types.end());
2364 }
2365 
2366 /// Remap an input of the original signature with a range of types in the
2367 /// new signature.
2368 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2369                                                     unsigned newInputNo,
2370                                                     unsigned newInputCount) {
2371   assert(!remappedInputs[origInputNo] && "input has already been remapped");
2372   assert(newInputCount != 0 && "expected valid input count");
2373   remappedInputs[origInputNo] =
2374       InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr};
2375 }
2376 
2377 /// Remap an input of the original signature to another `replacementValue`
2378 /// value. This would make the signature converter drop this argument.
2379 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2380                                                     Value replacementValue) {
2381   assert(!remappedInputs[origInputNo] && "input has already been remapped");
2382   remappedInputs[origInputNo] =
2383       InputMapping{origInputNo, /*size=*/0, replacementValue};
2384 }
2385 
2386 /// This hooks allows for converting a type.
2387 LogicalResult TypeConverter::convertType(Type t,
2388                                          SmallVectorImpl<Type> &results) {
2389   auto existingIt = cachedDirectConversions.find(t);
2390   if (existingIt != cachedDirectConversions.end()) {
2391     if (existingIt->second)
2392       results.push_back(existingIt->second);
2393     return success(existingIt->second != nullptr);
2394   }
2395   auto multiIt = cachedMultiConversions.find(t);
2396   if (multiIt != cachedMultiConversions.end()) {
2397     results.append(multiIt->second.begin(), multiIt->second.end());
2398     return success();
2399   }
2400 
2401   // Walk the added converters in reverse order to apply the most recently
2402   // registered first.
2403   size_t currentCount = results.size();
2404   for (ConversionCallbackFn &converter : llvm::reverse(conversions)) {
2405     if (Optional<LogicalResult> result = converter(t, results)) {
2406       if (!succeeded(*result)) {
2407         cachedDirectConversions.try_emplace(t, nullptr);
2408         return failure();
2409       }
2410       auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
2411       if (newTypes.size() == 1)
2412         cachedDirectConversions.try_emplace(t, newTypes.front());
2413       else
2414         cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
2415       return success();
2416     }
2417   }
2418   return failure();
2419 }
2420 
2421 /// This hook simplifies defining 1-1 type conversions. This function returns
2422 /// the type to convert to on success, and a null type on failure.
2423 Type TypeConverter::convertType(Type t) {
2424   // Use the multi-type result version to convert the type.
2425   SmallVector<Type, 1> results;
2426   if (failed(convertType(t, results)))
2427     return nullptr;
2428 
2429   // Check to ensure that only one type was produced.
2430   return results.size() == 1 ? results.front() : nullptr;
2431 }
2432 
2433 /// Convert the given set of types, filling 'results' as necessary. This
2434 /// returns failure if the conversion of any of the types fails, success
2435 /// otherwise.
2436 LogicalResult TypeConverter::convertTypes(ArrayRef<Type> types,
2437                                           SmallVectorImpl<Type> &results) {
2438   for (auto type : types)
2439     if (failed(convertType(type, results)))
2440       return failure();
2441   return success();
2442 }
2443 
2444 /// Return true if the given type is legal for this type converter, i.e. the
2445 /// type converts to itself.
2446 bool TypeConverter::isLegal(Type type) { return convertType(type) == type; }
2447 /// Return true if the given operation has legal operand and result types.
2448 bool TypeConverter::isLegal(Operation *op) {
2449   return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
2450 }
2451 
2452 /// Return true if the types of block arguments within the region are legal.
2453 bool TypeConverter::isLegal(Region *region) {
2454   return llvm::all_of(*region, [this](Block &block) {
2455     return isLegal(block.getArgumentTypes());
2456   });
2457 }
2458 
2459 /// Return true if the inputs and outputs of the given function type are
2460 /// legal.
2461 bool TypeConverter::isSignatureLegal(FunctionType ty) {
2462   return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
2463 }
2464 
2465 /// This hook allows for converting a specific argument of a signature.
2466 LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
2467                                                  SignatureConversion &result) {
2468   // Try to convert the given input type.
2469   SmallVector<Type, 1> convertedTypes;
2470   if (failed(convertType(type, convertedTypes)))
2471     return failure();
2472 
2473   // If this argument is being dropped, there is nothing left to do.
2474   if (convertedTypes.empty())
2475     return success();
2476 
2477   // Otherwise, add the new inputs.
2478   result.addInputs(inputNo, convertedTypes);
2479   return success();
2480 }
2481 LogicalResult TypeConverter::convertSignatureArgs(TypeRange types,
2482                                                   SignatureConversion &result,
2483                                                   unsigned origInputOffset) {
2484   for (unsigned i = 0, e = types.size(); i != e; ++i)
2485     if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
2486       return failure();
2487   return success();
2488 }
2489 
2490 Value TypeConverter::materializeConversion(
2491     MutableArrayRef<MaterializationCallbackFn> materializations,
2492     OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) {
2493   for (MaterializationCallbackFn &fn : llvm::reverse(materializations))
2494     if (Optional<Value> result = fn(builder, resultType, inputs, loc))
2495       return result.getValue();
2496   return nullptr;
2497 }
2498 
2499 /// This function converts the type signature of the given block, by invoking
2500 /// 'convertSignatureArg' for each argument. This function should return a valid
2501 /// conversion for the signature on success, None otherwise.
2502 auto TypeConverter::convertBlockSignature(Block *block)
2503     -> Optional<SignatureConversion> {
2504   SignatureConversion conversion(block->getNumArguments());
2505   if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
2506     return llvm::None;
2507   return conversion;
2508 }
2509 
2510 /// Create a default conversion pattern that rewrites the type signature of a
2511 /// FuncOp.
2512 namespace {
2513 struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
2514   FuncOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
2515       : OpConversionPattern(converter, ctx) {}
2516 
2517   /// Hook for derived classes to implement combined matching and rewriting.
2518   LogicalResult
2519   matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
2520                   ConversionPatternRewriter &rewriter) const override {
2521     FunctionType type = funcOp.getType();
2522 
2523     // Convert the original function types.
2524     TypeConverter::SignatureConversion result(type.getNumInputs());
2525     SmallVector<Type, 1> newResults;
2526     if (failed(typeConverter->convertSignatureArgs(type.getInputs(), result)) ||
2527         failed(typeConverter->convertTypes(type.getResults(), newResults)) ||
2528         failed(rewriter.convertRegionTypes(&funcOp.getBody(), *typeConverter,
2529                                            &result)))
2530       return failure();
2531 
2532     // Update the function signature in-place.
2533     rewriter.updateRootInPlace(funcOp, [&] {
2534       funcOp.setType(FunctionType::get(result.getConvertedTypes(), newResults,
2535                                        funcOp.getContext()));
2536     });
2537     return success();
2538   }
2539 };
2540 } // end anonymous namespace
2541 
/// Add the default FuncOp signature conversion pattern to `patterns`. The
/// pattern is constructed with the given context `ctx` and type converter
/// `converter`.
void mlir::populateFuncOpTypeConversionPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx,
    TypeConverter &converter) {
  patterns.insert<FuncOpSignatureConversion>(ctx, converter);
}
2547 
2548 //===----------------------------------------------------------------------===//
2549 // ConversionTarget
2550 //===----------------------------------------------------------------------===//
2551 
2552 /// Register a legality action for the given operation.
2553 void ConversionTarget::setOpAction(OperationName op,
2554                                    LegalizationAction action) {
2555   legalOperations[op] = {action, /*isRecursivelyLegal=*/false, llvm::None};
2556 }
2557 
2558 /// Register a legality action for the given dialects.
2559 void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
2560                                         LegalizationAction action) {
2561   for (StringRef dialect : dialectNames)
2562     legalDialects[dialect] = action;
2563 }
2564 
2565 /// Get the legality action for the given operation.
2566 auto ConversionTarget::getOpAction(OperationName op) const
2567     -> Optional<LegalizationAction> {
2568   Optional<LegalizationInfo> info = getOpInfo(op);
2569   return info ? info->action : Optional<LegalizationAction>();
2570 }
2571 
2572 /// If the given operation instance is legal on this target, a structure
2573 /// containing legality information is returned. If the operation is not legal,
2574 /// None is returned.
2575 auto ConversionTarget::isLegal(Operation *op) const
2576     -> Optional<LegalOpDetails> {
2577   Optional<LegalizationInfo> info = getOpInfo(op->getName());
2578   if (!info)
2579     return llvm::None;
2580 
2581   // Returns true if this operation instance is known to be legal.
2582   auto isOpLegal = [&] {
2583     // Handle dynamic legality either with the provided legality function, or
2584     // the default hook on the derived instance.
2585     if (info->action == LegalizationAction::Dynamic)
2586       return info->legalityFn ? (*info->legalityFn)(op)
2587                               : isDynamicallyLegal(op);
2588 
2589     // Otherwise, the operation is only legal if it was marked 'Legal'.
2590     return info->action == LegalizationAction::Legal;
2591   };
2592   if (!isOpLegal())
2593     return llvm::None;
2594 
2595   // This operation is legal, compute any additional legality information.
2596   LegalOpDetails legalityDetails;
2597   if (info->isRecursivelyLegal) {
2598     auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
2599     if (legalityFnIt != opRecursiveLegalityFns.end())
2600       legalityDetails.isRecursivelyLegal = legalityFnIt->second(op);
2601     else
2602       legalityDetails.isRecursivelyLegal = true;
2603   }
2604   return legalityDetails;
2605 }
2606 
2607 /// Set the dynamic legality callback for the given operation.
2608 void ConversionTarget::setLegalityCallback(
2609     OperationName name, const DynamicLegalityCallbackFn &callback) {
2610   assert(callback && "expected valid legality callback");
2611   auto infoIt = legalOperations.find(name);
2612   assert(infoIt != legalOperations.end() &&
2613          infoIt->second.action == LegalizationAction::Dynamic &&
2614          "expected operation to already be marked as dynamically legal");
2615   infoIt->second.legalityFn = callback;
2616 }
2617 
2618 /// Set the recursive legality callback for the given operation and mark the
2619 /// operation as recursively legal.
2620 void ConversionTarget::markOpRecursivelyLegal(
2621     OperationName name, const DynamicLegalityCallbackFn &callback) {
2622   auto infoIt = legalOperations.find(name);
2623   assert(infoIt != legalOperations.end() &&
2624          infoIt->second.action != LegalizationAction::Illegal &&
2625          "expected operation to already be marked as legal");
2626   infoIt->second.isRecursivelyLegal = true;
2627   if (callback)
2628     opRecursiveLegalityFns[name] = callback;
2629   else
2630     opRecursiveLegalityFns.erase(name);
2631 }
2632 
2633 /// Set the dynamic legality callback for the given dialects.
2634 void ConversionTarget::setLegalityCallback(
2635     ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
2636   assert(callback && "expected valid legality callback");
2637   for (StringRef dialect : dialects)
2638     dialectLegalityFns[dialect] = callback;
2639 }
2640 
2641 /// Get the legalization information for the given operation.
2642 auto ConversionTarget::getOpInfo(OperationName op) const
2643     -> Optional<LegalizationInfo> {
2644   // Check for info for this specific operation.
2645   auto it = legalOperations.find(op);
2646   if (it != legalOperations.end())
2647     return it->second;
2648   // Check for info for the parent dialect.
2649   auto dialectIt = legalDialects.find(op.getDialect());
2650   if (dialectIt != legalDialects.end()) {
2651     Optional<DynamicLegalityCallbackFn> callback;
2652     auto dialectFn = dialectLegalityFns.find(op.getDialect());
2653     if (dialectFn != dialectLegalityFns.end())
2654       callback = dialectFn->second;
2655     return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
2656                             callback};
2657   }
2658   // Otherwise, check if we mark unknown operations as dynamic.
2659   if (unknownOpsDynamicallyLegal)
2660     return LegalizationInfo{LegalizationAction::Dynamic,
2661                             /*isRecursivelyLegal=*/false, unknownLegalityFn};
2662   return llvm::None;
2663 }
2664 
2665 //===----------------------------------------------------------------------===//
2666 // Op Conversion Entry Points
2667 //===----------------------------------------------------------------------===//
2668 
2669 /// Apply a partial conversion on the given operations and all nested
2670 /// operations. This method converts as many operations to the target as
2671 /// possible, ignoring operations that failed to legalize. This method only
2672 /// returns failure if there ops explicitly marked as illegal.
2673 /// If an `unconvertedOps` set is provided, all operations that are found not
2674 /// to be legalizable to the given `target` are placed within that set. (Note
2675 /// that if there is an op explicitly marked as illegal, the conversion
2676 /// terminates and the `unconvertedOps` set will not necessarily be complete.)
2677 LogicalResult
2678 mlir::applyPartialConversion(ArrayRef<Operation *> ops,
2679                              ConversionTarget &target,
2680                              const FrozenRewritePatternList &patterns,
2681                              DenseSet<Operation *> *unconvertedOps) {
2682   OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
2683                                  unconvertedOps);
2684   return opConverter.convertOperations(ops);
2685 }
/// Single-operation overload of the partial conversion entry point: wraps
/// `op` in a one-element array and forwards to the overload taking a list of
/// operations, passing `unconvertedOps` through unchanged.
LogicalResult
mlir::applyPartialConversion(Operation *op, ConversionTarget &target,
                             const FrozenRewritePatternList &patterns,
                             DenseSet<Operation *> *unconvertedOps) {
  return applyPartialConversion(llvm::makeArrayRef(op), target, patterns,
                                unconvertedOps);
}
2693 
2694 /// Apply a complete conversion on the given operations, and all nested
2695 /// operations. This method will return failure if the conversion of any
2696 /// operation fails.
2697 LogicalResult
2698 mlir::applyFullConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
2699                           const FrozenRewritePatternList &patterns) {
2700   OperationConverter opConverter(target, patterns, OpConversionMode::Full);
2701   return opConverter.convertOperations(ops);
2702 }
/// Single-operation overload of the full conversion entry point: wraps `op`
/// in a one-element array and forwards to the overload taking a list of
/// operations.
LogicalResult
mlir::applyFullConversion(Operation *op, ConversionTarget &target,
                          const FrozenRewritePatternList &patterns) {
  return applyFullConversion(llvm::makeArrayRef(op), target, patterns);
}
2708 
2709 /// Apply an analysis conversion on the given operations, and all nested
2710 /// operations. This method analyzes which operations would be successfully
2711 /// converted to the target if a conversion was applied. All operations that
2712 /// were found to be legalizable to the given 'target' are placed within the
2713 /// provided 'convertedOps' set; note that no actual rewrites are applied to the
2714 /// operations on success and only pre-existing operations are added to the set.
2715 LogicalResult
2716 mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
2717                               ConversionTarget &target,
2718                               const FrozenRewritePatternList &patterns,
2719                               DenseSet<Operation *> &convertedOps) {
2720   OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
2721                                  &convertedOps);
2722   return opConverter.convertOperations(ops);
2723 }
/// Single-operation overload of the analysis conversion entry point: wraps
/// `op` in a one-element array and forwards to the overload taking a list of
/// operations, passing `convertedOps` through unchanged.
LogicalResult
mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
                              const FrozenRewritePatternList &patterns,
                              DenseSet<Operation *> &convertedOps) {
  return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns,
                                 convertedOps);
}
2731