1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/DialectConversion.h"
10 #include "mlir/IR/Block.h"
11 #include "mlir/IR/BlockAndValueMapping.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/FunctionSupport.h"
15 #include "mlir/Rewrite/PatternApplicator.h"
16 #include "mlir/Transforms/Utils.h"
17 #include "llvm/ADT/SetVector.h"
18 #include "llvm/ADT/SmallPtrSet.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/Support/SaveAndRestore.h"
22 #include "llvm/Support/ScopedPrinter.h"
23 
24 using namespace mlir;
25 using namespace mlir::detail;
26 
27 #define DEBUG_TYPE "dialect-conversion"
28 
29 /// Recursively collect all of the operations to convert from within 'region'.
30 /// If 'target' is nonnull, operations that are recursively legal have their
31 /// regions pre-filtered to avoid considering them for legalization.
32 static LogicalResult
33 computeConversionSet(iterator_range<Region::iterator> region,
34                      Location regionLoc, std::vector<Operation *> &toConvert,
35                      ConversionTarget *target = nullptr) {
36   if (llvm::empty(region))
37     return success();
38 
39   // Traverse starting from the entry block.
40   SmallVector<Block *, 16> worklist(1, &*region.begin());
41   DenseSet<Block *> visitedBlocks;
42   visitedBlocks.insert(worklist.front());
43   while (!worklist.empty()) {
44     Block *block = worklist.pop_back_val();
45 
46     // Compute the conversion set of each of the nested operations.
47     for (Operation &op : *block) {
48       toConvert.emplace_back(&op);
49 
50       // Don't check this operation's children for conversion if the operation
51       // is recursively legal.
52       auto legalityInfo = target ? target->isLegal(&op)
53                                  : Optional<ConversionTarget::LegalOpDetails>();
54       if (legalityInfo && legalityInfo->isRecursivelyLegal)
55         continue;
56       for (auto &region : op.getRegions()) {
57         if (failed(computeConversionSet(region.getBlocks(), region.getLoc(),
58                                         toConvert, target)))
59           return failure();
60       }
61     }
62 
63     // Recurse to children that haven't been visited.
64     for (Block *succ : block->getSuccessors())
65       if (visitedBlocks.insert(succ).second)
66         worklist.push_back(succ);
67   }
68 
69   // Check that all blocks in the region were visited.
70   if (llvm::any_of(llvm::drop_begin(region, 1),
71                    [&](Block &block) { return !visitedBlocks.count(&block); }))
72     return emitError(regionLoc, "unreachable blocks were not converted");
73   return success();
74 }
75 
76 /// A utility function to log a successful result for the given reason.
77 template <typename... Args>
78 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
79   LLVM_DEBUG({
80     os.unindent();
81     os.startLine() << "} -> SUCCESS";
82     if (!fmt.empty())
83       os.getOStream() << " : "
84                       << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
85     os.getOStream() << "\n";
86   });
87 }
88 
89 /// A utility function to log a failure result for the given reason.
90 template <typename... Args>
91 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
92   LLVM_DEBUG({
93     os.unindent();
94     os.startLine() << "} -> FAILURE : "
95                    << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
96                    << "\n";
97   });
98 }
99 
100 //===----------------------------------------------------------------------===//
101 // ConversionValueMapping
102 //===----------------------------------------------------------------------===//
103 
104 namespace {
105 /// This class wraps a BlockAndValueMapping to provide recursive lookup
106 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
107 struct ConversionValueMapping {
108   /// Lookup a mapped value within the map. If a mapping for the provided value
109   /// does not exist then return the provided value. If `desiredType` is
110   /// non-null, returns the most recently mapped value with that type. If an
111   /// operand of that type does not exist, defaults to normal behavior.
112   Value lookupOrDefault(Value from, Type desiredType = nullptr) const;
113 
114   /// Lookup a mapped value within the map, or return null if a mapping does not
115   /// exist. If a mapping exists, this follows the same behavior of
116   /// `lookupOrDefault`.
117   Value lookupOrNull(Value from) const;
118 
119   /// Map a value to the one provided.
120   void map(Value oldVal, Value newVal) { mapping.map(oldVal, newVal); }
121 
122   /// Drop the last mapping for the given value.
123   void erase(Value value) { mapping.erase(value); }
124 
125 private:
126   /// Current value mappings.
127   BlockAndValueMapping mapping;
128 };
129 } // end anonymous namespace
130 
131 Value ConversionValueMapping::lookupOrDefault(Value from,
132                                               Type desiredType) const {
133   // If there was no desired type, simply find the leaf value.
134   if (!desiredType) {
135     // If this value had a valid mapping, unmap that value as well in the case
136     // that it was also replaced.
137     while (auto mappedValue = mapping.lookupOrNull(from))
138       from = mappedValue;
139     return from;
140   }
141 
142   // Otherwise, try to find the deepest value that has the desired type.
143   Value desiredValue;
144   do {
145     if (from.getType() == desiredType)
146       desiredValue = from;
147 
148     Value mappedValue = mapping.lookupOrNull(from);
149     if (!mappedValue)
150       break;
151     from = mappedValue;
152   } while (true);
153 
154   // If the desired value was found use it, otherwise default to the leaf value.
155   return desiredValue ? desiredValue : from;
156 }
157 
158 Value ConversionValueMapping::lookupOrNull(Value from) const {
159   Value result = lookupOrDefault(from);
160   return result == from ? nullptr : result;
161 }
162 
163 //===----------------------------------------------------------------------===//
164 // ArgConverter
165 //===----------------------------------------------------------------------===//
166 namespace {
167 /// This class provides a simple interface for converting the types of block
168 /// arguments. This is done by creating a new block that contains the new legal
169 /// types and extracting the block that contains the old illegal types to allow
170 /// for undoing pending rewrites in the case of failure.
171 struct ArgConverter {
172   ArgConverter(PatternRewriter &rewriter) : rewriter(rewriter) {}
173 
174   /// This structure contains the information pertaining to an argument that has
175   /// been converted.
176   struct ConvertedArgInfo {
177     ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
178                      Value castValue = nullptr)
179         : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
180 
181     /// The start index of in the new argument list that contains arguments that
182     /// replace the original.
183     unsigned newArgIdx;
184 
185     /// The number of arguments that replaced the original argument.
186     unsigned newArgSize;
187 
188     /// The cast value that was created to cast from the new arguments to the
189     /// old. This only used if 'newArgSize' > 1.
190     Value castValue;
191   };
192 
193   /// This structure contains information pertaining to a block that has had its
194   /// signature converted.
195   struct ConvertedBlockInfo {
196     ConvertedBlockInfo(Block *origBlock, TypeConverter &converter)
197         : origBlock(origBlock), converter(&converter) {}
198 
199     /// The original block that was requested to have its signature converted.
200     Block *origBlock;
201 
202     /// The conversion information for each of the arguments. The information is
203     /// None if the argument was dropped during conversion.
204     SmallVector<Optional<ConvertedArgInfo>, 1> argInfo;
205 
206     /// The type converter used to convert the arguments.
207     TypeConverter *converter;
208   };
209 
210   /// Return if the signature of the given block has already been converted.
211   bool hasBeenConverted(Block *block) const {
212     return conversionInfo.count(block) || convertedBlocks.count(block);
213   }
214 
215   /// Set the type converter to use for the given region.
216   void setConverter(Region *region, TypeConverter *typeConverter) {
217     assert(typeConverter && "expected valid type converter");
218     regionToConverter[region] = typeConverter;
219   }
220 
221   /// Return the type converter to use for the given region, or null if there
222   /// isn't one.
223   TypeConverter *getConverter(Region *region) {
224     return regionToConverter.lookup(region);
225   }
226 
227   //===--------------------------------------------------------------------===//
228   // Rewrite Application
229   //===--------------------------------------------------------------------===//
230 
231   /// Erase any rewrites registered for the blocks within the given operation
232   /// which is about to be removed. This merely drops the rewrites without
233   /// undoing them.
234   void notifyOpRemoved(Operation *op);
235 
236   /// Cleanup and undo any generated conversions for the arguments of block.
237   /// This method replaces the new block with the original, reverting the IR to
238   /// its original state.
239   void discardRewrites(Block *block);
240 
241   /// Fully replace uses of the old arguments with the new.
242   void applyRewrites(ConversionValueMapping &mapping);
243 
244   /// Materialize any necessary conversions for converted arguments that have
245   /// live users, using the provided `findLiveUser` to search for a user that
246   /// survives the conversion process.
247   LogicalResult
248   materializeLiveConversions(ConversionValueMapping &mapping,
249                              OpBuilder &builder,
250                              function_ref<Operation *(Value)> findLiveUser);
251 
252   //===--------------------------------------------------------------------===//
253   // Conversion
254   //===--------------------------------------------------------------------===//
255 
256   /// Attempt to convert the signature of the given block, if successful a new
257   /// block is returned containing the new arguments. Returns `block` if it did
258   /// not require conversion.
259   FailureOr<Block *> convertSignature(Block *block, TypeConverter &converter,
260                                       ConversionValueMapping &mapping);
261 
262   /// Apply the given signature conversion on the given block. The new block
263   /// containing the updated signature is returned. If no conversions were
264   /// necessary, e.g. if the block has no arguments, `block` is returned.
265   /// `converter` is used to generate any necessary cast operations that
266   /// translate between the origin argument types and those specified in the
267   /// signature conversion.
268   Block *applySignatureConversion(
269       Block *block, TypeConverter &converter,
270       TypeConverter::SignatureConversion &signatureConversion,
271       ConversionValueMapping &mapping);
272 
273   /// Insert a new conversion into the cache.
274   void insertConversion(Block *newBlock, ConvertedBlockInfo &&info);
275 
276   /// A collection of blocks that have had their arguments converted. This is a
277   /// map from the new replacement block, back to the original block.
278   llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;
279 
280   /// The set of original blocks that were converted.
281   DenseSet<Block *> convertedBlocks;
282 
283   /// A mapping from valid regions, to those containing the original blocks of a
284   /// conversion.
285   DenseMap<Region *, std::unique_ptr<Region>> regionMapping;
286 
287   /// A mapping of regions to type converters that should be used when
288   /// converting the arguments of blocks within that region.
289   DenseMap<Region *, TypeConverter *> regionToConverter;
290 
291   /// The pattern rewriter to use when materializing conversions.
292   PatternRewriter &rewriter;
293 };
294 } // end anonymous namespace
295 
296 //===----------------------------------------------------------------------===//
297 // Rewrite Application
298 
299 void ArgConverter::notifyOpRemoved(Operation *op) {
300   if (conversionInfo.empty())
301     return;
302 
303   for (Region &region : op->getRegions()) {
304     for (Block &block : region) {
305       // Drop any rewrites from within.
306       for (Operation &nestedOp : block)
307         if (nestedOp.getNumRegions())
308           notifyOpRemoved(&nestedOp);
309 
310       // Check if this block was converted.
311       auto it = conversionInfo.find(&block);
312       if (it == conversionInfo.end())
313         continue;
314 
315       // Drop all uses of the original arguments and delete the original block.
316       Block *origBlock = it->second.origBlock;
317       for (BlockArgument arg : origBlock->getArguments())
318         arg.dropAllUses();
319       conversionInfo.erase(it);
320     }
321   }
322 }
323 
324 void ArgConverter::discardRewrites(Block *block) {
325   auto it = conversionInfo.find(block);
326   if (it == conversionInfo.end())
327     return;
328   Block *origBlock = it->second.origBlock;
329 
330   // Drop all uses of the new block arguments and replace uses of the new block.
331   for (int i = block->getNumArguments() - 1; i >= 0; --i)
332     block->getArgument(i).dropAllUses();
333   block->replaceAllUsesWith(origBlock);
334 
335   // Move the operations back the original block and the delete the new block.
336   origBlock->getOperations().splice(origBlock->end(), block->getOperations());
337   origBlock->moveBefore(block);
338   block->erase();
339 
340   convertedBlocks.erase(origBlock);
341   conversionInfo.erase(it);
342 }
343 
344 void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
345   for (auto &info : conversionInfo) {
346     ConvertedBlockInfo &blockInfo = info.second;
347     Block *origBlock = blockInfo.origBlock;
348 
349     // Process the remapping for each of the original arguments.
350     for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
351       Optional<ConvertedArgInfo> &argInfo = blockInfo.argInfo[i];
352       BlockArgument origArg = origBlock->getArgument(i);
353 
354       // Handle the case of a 1->0 value mapping.
355       if (!argInfo) {
356         if (Value newArg = mapping.lookupOrNull(origArg))
357           origArg.replaceAllUsesWith(newArg);
358         continue;
359       }
360 
361       // Otherwise this is a 1->1+ value mapping.
362       Value castValue = argInfo->castValue;
363       assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
364 
365       // If the argument is still used, replace it with the generated cast.
366       if (!origArg.use_empty())
367         origArg.replaceAllUsesWith(mapping.lookupOrDefault(castValue));
368     }
369   }
370 }
371 
372 LogicalResult ArgConverter::materializeLiveConversions(
373     ConversionValueMapping &mapping, OpBuilder &builder,
374     function_ref<Operation *(Value)> findLiveUser) {
375   for (auto &info : conversionInfo) {
376     Block *newBlock = info.first;
377     ConvertedBlockInfo &blockInfo = info.second;
378     Block *origBlock = blockInfo.origBlock;
379 
380     // Process the remapping for each of the original arguments.
381     for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
382       // FIXME: We should run the below checks even if the type conversion was
383       // 1->N, but a lot of existing lowering rely on the block argument being
384       // blindly replaced. Those usages should be updated, and this if should be
385       // removed.
386       if (blockInfo.argInfo[i])
387         continue;
388 
389       // If the type of this argument changed and the argument is still live, we
390       // need to materialize a conversion.
391       BlockArgument origArg = origBlock->getArgument(i);
392       auto argReplacementValue = mapping.lookupOrDefault(origArg);
393       bool isDroppedArg = argReplacementValue == origArg;
394       if (argReplacementValue.getType() == origArg.getType() && !isDroppedArg)
395         continue;
396       Operation *liveUser = findLiveUser(origArg);
397       if (!liveUser)
398         continue;
399 
400       if (OpResult result = argReplacementValue.dyn_cast<OpResult>())
401         rewriter.setInsertionPointAfter(result.getOwner());
402       else
403         rewriter.setInsertionPointToStart(newBlock);
404       Value newArg = blockInfo.converter->materializeSourceConversion(
405           rewriter, origArg.getLoc(), origArg.getType(),
406           isDroppedArg ? ValueRange() : ValueRange(argReplacementValue));
407       if (!newArg) {
408         InFlightDiagnostic diag =
409             emitError(origArg.getLoc())
410             << "failed to materialize conversion for block argument #" << i
411             << " that remained live after conversion, type was "
412             << origArg.getType();
413         if (!isDroppedArg)
414           diag << ", with target type " << argReplacementValue.getType();
415         diag.attachNote(liveUser->getLoc())
416             << "see existing live user here: " << *liveUser;
417         return failure();
418       }
419       mapping.map(origArg, newArg);
420     }
421   }
422   return success();
423 }
424 
425 //===----------------------------------------------------------------------===//
426 // Conversion
427 
428 FailureOr<Block *>
429 ArgConverter::convertSignature(Block *block, TypeConverter &converter,
430                                ConversionValueMapping &mapping) {
431   // Check if the block was already converted. If the block is detached,
432   // conservatively assume it is going to be deleted.
433   if (hasBeenConverted(block) || !block->getParent())
434     return block;
435 
436   // Try to convert the signature for the block with the provided converter.
437   if (auto conversion = converter.convertBlockSignature(block))
438     return applySignatureConversion(block, converter, *conversion, mapping);
439   return failure();
440 }
441 
/// Apply `signatureConversion` to `block`:
///   - split all of its operations into a new block carrying the converted
///     argument types;
///   - record how each original argument maps onto the new arguments;
///   - move the now-empty original block into a side region
///     (`insertConversion`).
/// Returns the new block, or `block` itself when it has no arguments and no
/// new arguments are being added.
Block *ArgConverter::applySignatureConversion(
    Block *block, TypeConverter &converter,
    TypeConverter::SignatureConversion &signatureConversion,
    ConversionValueMapping &mapping) {
  // If no arguments are being changed or added, there is nothing to do.
  unsigned origArgCount = block->getNumArguments();
  auto convertedTypes = signatureConversion.getConvertedTypes();
  if (origArgCount == 0 && convertedTypes.empty())
    return block;

  // Split the block at the beginning to get a new block to use for the updated
  // signature. All operations move to `newBlock`, and anything that referenced
  // `block` (e.g. branch successors) is redirected to `newBlock`.
  Block *newBlock = block->splitBlock(block->begin());
  block->replaceAllUsesWith(newBlock);

  // Materialize the new argument list. It is copied so that it can be sliced
  // per original argument below.
  SmallVector<Value, 4> newArgRange(newBlock->addArguments(convertedTypes));
  ArrayRef<Value> newArgs(newArgRange);

  // Remap each of the original arguments as determined by the signature
  // conversion. Entries of `info.argInfo` stay None for arguments without an
  // input mapping and for arguments given a replacement value.
  ConvertedBlockInfo info(block, converter);
  info.argInfo.resize(origArgCount);

  // Any argument materializations are created at the start of the new block;
  // the guard restores the rewriter's insertion point on return.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(newBlock);
  for (unsigned i = 0; i != origArgCount; ++i) {
    auto inputMap = signatureConversion.getInputMapping(i);
    if (!inputMap)
      continue;
    BlockArgument origArg = block->getArgument(i);

    // If inputMap->replacementValue is not nullptr, then the argument is
    // dropped and a replacement value is provided to be the remappedValue.
    if (inputMap->replacementValue) {
      assert(inputMap->size == 0 &&
             "invalid to provide a replacement value when the argument isn't "
             "dropped");
      mapping.map(origArg, inputMap->replacementValue);
      continue;
    }

    // Otherwise, this is a 1->1+ mapping. Call into the provided type converter
    // to pack the new values. For 1->1 mappings, if there is no materialization
    // provided, use the argument directly instead.
    auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
    Value newArg = converter.materializeArgumentConversion(
        rewriter, origArg.getLoc(), origArg.getType(), replArgs);
    if (!newArg) {
      assert(replArgs.size() == 1 &&
             "couldn't materialize the result of 1->N conversion");
      newArg = replArgs.front();
    }
    mapping.map(origArg, newArg);
    info.argInfo[i] =
        ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
  }

  // Remove the original block from the region and return the new one.
  insertConversion(newBlock, std::move(info));
  return newBlock;
}
503 
504 void ArgConverter::insertConversion(Block *newBlock,
505                                     ConvertedBlockInfo &&info) {
506   // Get a region to insert the old block.
507   Region *region = newBlock->getParent();
508   std::unique_ptr<Region> &mappedRegion = regionMapping[region];
509   if (!mappedRegion)
510     mappedRegion = std::make_unique<Region>(region->getParentOp());
511 
512   // Move the original block to the mapped region and emplace the conversion.
513   mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(),
514                                    info.origBlock->getIterator());
515   convertedBlocks.insert(info.origBlock);
516   conversionInfo.insert({newBlock, std::move(info)});
517 }
518 
519 //===----------------------------------------------------------------------===//
520 // Rewriter and Translation State
521 //===----------------------------------------------------------------------===//
522 namespace {
/// A snapshot of how many entries each of the conversion rewriter's undoable
/// records held at a given moment. Saving one before a set of rewrites makes
/// it possible to roll back exactly the rewrites applied after it.
struct RewriterState {
  RewriterState(unsigned numCreatedOps, unsigned numReplacements,
                unsigned numArgReplacements, unsigned numBlockActions,
                unsigned numIgnoredOperations, unsigned numRootUpdates)
      : numCreatedOps(numCreatedOps),
        numReplacements(numReplacements),
        numArgReplacements(numArgReplacements),
        numBlockActions(numBlockActions),
        numIgnoredOperations(numIgnoredOperations),
        numRootUpdates(numRootUpdates) {}

  /// Operations created so far.
  unsigned numCreatedOps;

  /// Operation replacements queued so far.
  unsigned numReplacements;

  /// Block argument replacements queued so far.
  unsigned numArgReplacements;

  /// Block actions performed so far.
  unsigned numBlockActions;

  /// Operations marked as ignored so far.
  unsigned numIgnoredOperations;

  /// Operations updated in place so far.
  unsigned numRootUpdates;
};
553 
554 /// The state of an operation that was updated by a pattern in-place. This
555 /// contains all of the necessary information to reconstruct an operation that
556 /// was updated in place.
557 class OperationTransactionState {
558 public:
559   OperationTransactionState() = default;
560   OperationTransactionState(Operation *op)
561       : op(op), loc(op->getLoc()), attrs(op->getAttrDictionary()),
562         operands(op->operand_begin(), op->operand_end()),
563         successors(op->successor_begin(), op->successor_end()) {}
564 
565   /// Discard the transaction state and reset the state of the original
566   /// operation.
567   void resetOperation() const {
568     op->setLoc(loc);
569     op->setAttrs(attrs);
570     op->setOperands(operands);
571     for (auto it : llvm::enumerate(successors))
572       op->setSuccessor(it.value(), it.index());
573   }
574 
575   /// Return the original operation of this state.
576   Operation *getOperation() const { return op; }
577 
578 private:
579   Operation *op;
580   LocationAttr loc;
581   DictionaryAttr attrs;
582   SmallVector<Value, 8> operands;
583   SmallVector<Block *, 2> successors;
584 };
585 
586 /// This class represents one requested operation replacement via 'replaceOp' or
587 /// 'eraseOp`.
588 struct OpReplacement {
589   OpReplacement() = default;
590   OpReplacement(TypeConverter *converter) : converter(converter) {}
591 
592   /// An optional type converter that can be used to materialize conversions
593   /// between the new and old values if necessary.
594   TypeConverter *converter = nullptr;
595 };
596 
/// The kinds of block action recorded while rewriting. Every recorded action
/// can be undone if the conversion fails.
enum class BlockActionKind {
  Create,        // A block was created.
  Erase,         // A block was erased.
  Merge,         // One block was merged into another.
  Move,          // A block was moved.
  Split,         // A block was split in two.
  TypeConversion // A block's signature was converted.
};
607 
608 /// Original position of the given block in its parent region. During undo
609 /// actions, the block needs to be placed after `insertAfterBlock`.
610 struct BlockPosition {
611   Region *region;
612   Block *insertAfterBlock;
613 };
614 
615 /// Information needed to undo the merge actions.
616 /// - the source block, and
617 /// - the Operation that was the last operation in the dest block before the
618 ///   merge (could be null if the dest block was empty).
619 struct MergeInfo {
620   Block *sourceBlock;
621   Operation *destBlockLastInst;
622 };
623 
/// The storage class for an undoable block action (one of BlockActionKind),
/// contains the information necessary to undo this action. Instances are built
/// through the static factory functions below. Each factory sets `kind` and
/// fills in the union member that corresponds to that kind, if any.
struct BlockAction {
  /// Record that `block` was created.
  static BlockAction getCreate(Block *block) {
    return {BlockActionKind::Create, block, {}};
  }
  /// Record that `block` was erased from `originalPosition`.
  static BlockAction getErase(Block *block, BlockPosition originalPosition) {
    return {BlockActionKind::Erase, block, {originalPosition}};
  }
  /// Record that `sourceBlock` is being merged into `block`. The current last
  /// operation of `block` (null if `block` is empty) is captured so that the
  /// merge can be undone.
  static BlockAction getMerge(Block *block, Block *sourceBlock) {
    BlockAction action{BlockActionKind::Merge, block, {}};
    action.mergeInfo = {sourceBlock, block->empty() ? nullptr : &block->back()};
    return action;
  }
  /// Record that `block` was moved away from `originalPosition`.
  static BlockAction getMove(Block *block, BlockPosition originalPosition) {
    return {BlockActionKind::Move, block, {originalPosition}};
  }
  /// Record that `originalBlock` was split, producing `block`.
  static BlockAction getSplit(Block *block, Block *originalBlock) {
    BlockAction action{BlockActionKind::Split, block, {}};
    action.originalBlock = originalBlock;
    return action;
  }
  /// Record that the signature of `block` was converted.
  static BlockAction getTypeConversion(Block *block) {
    return BlockAction{BlockActionKind::TypeConversion, block, {}};
  }

  // The action kind; it selects which union member below (if any) is active.
  BlockActionKind kind;

  // The block the action applies to:
  //   - Create: the created block;
  //   - Split: the block produced by the split;
  //   - Merge: the destination block;
  //   - other kinds: the erased, moved, or type-converted block.
  Block *block;

  // Only the member matching `kind` is meaningful. Create and TypeConversion
  // actions use none of them.
  union {
    // In use if kind == BlockActionKind::Move or BlockActionKind::Erase, and
    // contains a pointer to the region that originally contained the block as
    // well as the position of the block in that region.
    BlockPosition originalPosition;
    // In use if kind == BlockActionKind::Split and contains a pointer to the
    // block that was split into two parts.
    Block *originalBlock;
    // In use if kind == BlockActionKind::Merge, and contains the information
    // needed to undo the merge.
    MergeInfo mergeInfo;
  };
};
669 } // end anonymous namespace
670 
671 //===----------------------------------------------------------------------===//
672 // ConversionPatternRewriterImpl
673 //===----------------------------------------------------------------------===//
674 namespace mlir {
675 namespace detail {
676 struct ConversionPatternRewriterImpl {
677   ConversionPatternRewriterImpl(PatternRewriter &rewriter)
678       : argConverter(rewriter) {}
679 
680   /// Cleanup and destroy any generated rewrite operations. This method is
681   /// invoked when the conversion process fails.
682   void discardRewrites();
683 
684   /// Apply all requested operation rewrites. This method is invoked when the
685   /// conversion process succeeds.
686   void applyRewrites();
687 
688   //===--------------------------------------------------------------------===//
689   // State Management
690   //===--------------------------------------------------------------------===//
691 
692   /// Return the current state of the rewriter.
693   RewriterState getCurrentState();
694 
695   /// Reset the state of the rewriter to a previously saved point.
696   void resetState(RewriterState state);
697 
698   /// Erase any blocks that were unlinked from their regions and stored in block
699   /// actions.
700   void eraseDanglingBlocks();
701 
702   /// Undo the block actions (motions, splits) one by one in reverse order until
703   /// "numActionsToKeep" actions remains.
704   void undoBlockActions(unsigned numActionsToKeep = 0);
705 
706   /// Remap the given operands to those with potentially different types. The
707   /// provided type converter is used to ensure that the remapped types are
708   /// legal. Returns success if the operands could be remapped, failure
709   /// otherwise.
710   LogicalResult remapValues(Location loc, PatternRewriter &rewriter,
711                             TypeConverter *converter,
712                             Operation::operand_range operands,
713                             SmallVectorImpl<Value> &remapped);
714 
715   /// Returns true if the given operation is ignored, and does not need to be
716   /// converted.
717   bool isOpIgnored(Operation *op) const;
718 
719   /// Recursively marks the nested operations under 'op' as ignored. This
720   /// removes them from being considered for legalization.
721   void markNestedOpsIgnored(Operation *op);
722 
723   //===--------------------------------------------------------------------===//
724   // Type Conversion
725   //===--------------------------------------------------------------------===//
726 
727   /// Convert the signature of the given block.
728   FailureOr<Block *> convertBlockSignature(
729       Block *block, TypeConverter &converter,
730       TypeConverter::SignatureConversion *conversion = nullptr);
731 
732   /// Apply a signature conversion on the given region.
733   Block *
734   applySignatureConversion(Region *region,
735                            TypeConverter::SignatureConversion &conversion);
736 
737   /// Convert the types of block arguments within the given region.
738   FailureOr<Block *>
739   convertRegionTypes(Region *region, TypeConverter &converter,
740                      TypeConverter::SignatureConversion *entryConversion);
741 
742   //===--------------------------------------------------------------------===//
743   // Rewriter Notification Hooks
744   //===--------------------------------------------------------------------===//
745 
746   /// PatternRewriter hook for replacing the results of an operation.
747   void notifyOpReplaced(Operation *op, ValueRange newValues);
748 
749   /// Notifies that a block is about to be erased.
750   void notifyBlockIsBeingErased(Block *block);
751 
752   /// Notifies that a block was created.
753   void notifyCreatedBlock(Block *block);
754 
755   /// Notifies that a block was split.
756   void notifySplitBlock(Block *block, Block *continuation);
757 
758   /// Notifies that `block` is being merged with `srcBlock`.
759   void notifyBlocksBeingMerged(Block *block, Block *srcBlock);
760 
761   /// Notifies that the blocks of a region are about to be moved.
762   void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent,
763                                         Region::iterator before);
764 
765   /// Notifies that the blocks of a region were cloned into another.
766   void notifyRegionWasClonedBefore(iterator_range<Region::iterator> &blocks,
767                                    Location origRegionLoc);
768 
769   /// Notifies that a pattern match failed for the given reason.
770   LogicalResult
771   notifyMatchFailure(Location loc,
772                      function_ref<void(Diagnostic &)> reasonCallback);
773 
774   //===--------------------------------------------------------------------===//
775   // State
776   //===--------------------------------------------------------------------===//
777 
778   // Mapping between replaced values that differ in type. This happens when
779   // replacing a value with one of a different type.
780   ConversionValueMapping mapping;
781 
782   /// Utility used to convert block arguments.
783   ArgConverter argConverter;
784 
785   /// Ordered vector of all of the newly created operations during conversion.
786   std::vector<Operation *> createdOps;
787 
788   /// Ordered map of requested operation replacements.
789   llvm::MapVector<Operation *, OpReplacement> replacements;
790 
791   /// Ordered vector of any requested block argument replacements.
792   SmallVector<BlockArgument, 4> argReplacements;
793 
794   /// Ordered list of block operations (creations, splits, motions).
795   SmallVector<BlockAction, 4> blockActions;
796 
797   /// A set of operations that should no longer be considered for legalization,
798   /// but were not directly replace/erased/etc. by a pattern. These are
799   /// generally child operations of other operations who were
800   /// replaced/erased/etc. This is not meant to be an exhaustive list of all
801   /// operations, but the minimal set that can be used to detect if a given
802   /// operation should be `ignored`. For example, we may add the operations that
803   /// define non-empty regions to the set, but not any of the others. This
804   /// simplifies the amount of memory needed as we can query if the parent
805   /// operation was ignored.
806   llvm::SetVector<Operation *> ignoredOps;
807 
808   /// A transaction state for each of operations that were updated in-place.
809   SmallVector<OperationTransactionState, 4> rootUpdates;
810 
811   /// A vector of indices into `replacements` of operations that were replaced
812   /// with values with different result types than the original operation, e.g.
813   /// 1->N conversion of some kind.
814   SmallVector<unsigned, 4> operationsWithChangedResults;
815 
816   /// A default type converter, used when block conversions do not have one
817   /// explicitly provided.
818   TypeConverter defaultTypeConverter;
819 
820   /// The current conversion pattern that is being rewritten, or nullptr if
821   /// called from outside of a conversion pattern rewrite.
822   const ConversionPattern *currentConversionPattern = nullptr;
823 
824 #ifndef NDEBUG
825   /// A set of operations that have pending updates. This tracking isn't
826   /// strictly necessary, and is thus only active during debug builds for extra
827   /// verification.
828   SmallPtrSet<Operation *, 1> pendingRootUpdates;
829 
830   /// A logger used to emit diagnostics during the conversion process.
831   llvm::ScopedPrinter logger{llvm::dbgs()};
832 #endif
833 };
834 } // end namespace detail
835 } // end namespace mlir
836 
837 /// Detach any operations nested in the given operation from their parent
838 /// blocks, and erase the given operation. This can be used when the nested
839 /// operations are scheduled for erasure themselves, so deleting the regions of
840 /// the given operation together with their content would result in double-free.
841 /// This happens, for example, when rolling back op creation in the reverse
842 /// order and if the nested ops were created before the parent op. This function
843 /// does not need to collect nested ops recursively because it is expected to
844 /// also be called for each nested op when it is about to be deleted.
845 static void detachNestedAndErase(Operation *op) {
846   for (Region &region : op->getRegions()) {
847     for (Block &block : region.getBlocks()) {
848       while (!block.getOperations().empty())
849         block.getOperations().remove(block.getOperations().begin());
850       block.dropAllDefinedValueUses();
851     }
852   }
853   op->dropAllUses();
854   op->erase();
855 }
856 
857 void ConversionPatternRewriterImpl::discardRewrites() {
858   // Reset any operations that were updated in place.
859   for (auto &state : rootUpdates)
860     state.resetOperation();
861 
862   undoBlockActions();
863 
864   // Remove any newly created ops.
865   for (auto *op : llvm::reverse(createdOps))
866     detachNestedAndErase(op);
867 }
868 
/// Commit all of the rewrites recorded during conversion to the IR.
///
/// The commit happens in phases:
///  1. Uses of every replaced operation's results are redirected to their
///     mapped values, and pending argument rewrites of replaced region-holding
///     ops are dropped.
///  2. Uses of replaced block arguments are redirected.
///  3. The replaced operations are erased in reverse order.
///  4. The argument converter applies its remaining rewrites.
///  5. Blocks erased during conversion are deleted.
void ConversionPatternRewriterImpl::applyRewrites() {
  // Apply all of the rewrites replacements requested during conversion.
  // Results that have no mapping (e.g. those of erased operations, whose null
  // replacements are never mapped) are left untouched here.
  for (auto &repl : replacements) {
    for (OpResult result : repl.first->getResults())
      if (Value newValue = mapping.lookupOrNull(result))
        result.replaceAllUsesWith(newValue);

    // If this operation defines any regions, drop any pending argument
    // rewrites.
    if (repl.first->getNumRegions())
      argConverter.notifyOpRemoved(repl.first);
  }

  // Apply all of the requested argument replacements.
  for (BlockArgument arg : argReplacements) {
    Value repl = mapping.lookupOrDefault(arg);
    if (repl.isa<BlockArgument>()) {
      arg.replaceAllUsesWith(repl);
      continue;
    }

    // If the replacement value is an operation, we check to make sure that we
    // don't replace uses that are within the parent operation of the
    // replacement value.
    // Concretely: a use is replaced unless its user sits in the same block as
    // the replacement op and comes before it (which would create a use that
    // precedes its definition).
    Operation *replOp = repl.cast<OpResult>().getOwner();
    Block *replBlock = replOp->getBlock();
    arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
      Operation *user = operand.getOwner();
      return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
    });
  }

  // In a second pass, erase all of the replaced operations in reverse. This
  // allows processing nested operations before their parent region is
  // destroyed.
  for (auto &repl : llvm::reverse(replacements))
    repl.first->erase();

  argConverter.applyRewrites(mapping);

  // Now that the ops have been erased, also erase dangling blocks.
  eraseDanglingBlocks();
}
912 
913 //===----------------------------------------------------------------------===//
914 // State Management
915 
916 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
917   return RewriterState(createdOps.size(), replacements.size(),
918                        argReplacements.size(), blockActions.size(),
919                        ignoredOps.size(), rootUpdates.size());
920 }
921 
/// Roll the rewriter back to a previously captured `state`.
///
/// Every log that grew since `state` was captured (see `getCurrentState`) is
/// undone and truncated back to its recorded size:
///  - in-place root updates are reset,
///  - argument and operation replacement mappings are erased,
///  - block actions are undone,
///  - newly created operations are erased, newest first,
///  - ignored-op records are popped, and
///  - changed-result indices that refer to discarded replacements are dropped.
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
  // Reset any operations that were updated in place.
  for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i)
    rootUpdates[i].resetOperation();
  rootUpdates.resize(state.numRootUpdates);

  // Reset any replaced arguments.
  for (BlockArgument replacedArg :
       llvm::drop_begin(argReplacements, state.numArgReplacements))
    mapping.erase(replacedArg);
  argReplacements.resize(state.numArgReplacements);

  // Undo any block actions.
  // This runs before created ops are erased below. Undoing a `Create` action
  // only unlinks the ops inside the created block and leaves their deletion to
  // be done separately.
  undoBlockActions(state.numBlockActions);

  // Reset any replaced operations and undo any saved mappings.
  for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
    for (auto result : repl.first->getResults())
      mapping.erase(result);
  while (replacements.size() != state.numReplacements)
    replacements.pop_back();

  // Pop all of the newly created operations.
  while (createdOps.size() != state.numCreatedOps) {
    detachNestedAndErase(createdOps.back());
    createdOps.pop_back();
  }

  // Pop all of the recorded ignored operations that are no longer valid.
  while (ignoredOps.size() != state.numIgnoredOperations)
    ignoredOps.pop_back();

  // Reset operations with changed results.
  // Entries are indices into `replacements`, so any index at or beyond the
  // restored replacement count now refers to a discarded replacement.
  while (!operationsWithChangedResults.empty() &&
         operationsWithChangedResults.back() >= state.numReplacements)
    operationsWithChangedResults.pop_back();
}
959 
960 void ConversionPatternRewriterImpl::eraseDanglingBlocks() {
961   for (auto &action : blockActions)
962     if (action.kind == BlockActionKind::Erase)
963       delete action.block;
964 }
965 
/// Undo every recorded block action beyond the first `numActionsToKeep`.
///
/// Actions are undone from the most recent to the oldest, so each undo sees
/// the IR exactly as it was right after that action was performed. Afterwards
/// `blockActions` is truncated to `numActionsToKeep` entries.
void ConversionPatternRewriterImpl::undoBlockActions(
    unsigned numActionsToKeep) {
  for (auto &action :
       llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) {
    switch (action.kind) {
    // Delete the created block.
    case BlockActionKind::Create: {
      // Unlink all of the operations within this block, they will be deleted
      // separately.
      auto &blockOps = action.block->getOperations();
      while (!blockOps.empty())
        blockOps.remove(blockOps.begin());
      action.block->dropAllDefinedValueUses();
      action.block->erase();
      break;
    }
    // Put the block (owned by action) back into its original position.
    // A null `insertAfterBlock` means the block was the first in its region.
    case BlockActionKind::Erase: {
      auto &blockList = action.originalPosition.region->getBlocks();
      Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
      blockList.insert((insertAfterBlock
                            ? std::next(Region::iterator(insertAfterBlock))
                            : blockList.begin()),
                       action.block);
      break;
    }
    // Split the block at the position which was originally the end of the
    // destination block (owned by action), and put the instructions back into
    // the block used before the merge.
    case BlockActionKind::Merge: {
      Block *sourceBlock = action.mergeInfo.sourceBlock;
      Block::iterator splitPoint =
          (action.mergeInfo.destBlockLastInst
               ? ++Block::iterator(action.mergeInfo.destBlockLastInst)
               : action.block->begin());
      sourceBlock->getOperations().splice(sourceBlock->begin(),
                                          action.block->getOperations(),
                                          splitPoint, action.block->end());
      break;
    }
    // Move the block back to its original position.
    // A null `insertAfterBlock` appends to the end of the original region.
    // Because undo runs in reverse, the region's first block (recorded last by
    // `notifyRegionIsBeingInlinedBefore`) is restored first.
    case BlockActionKind::Move: {
      Region *originalRegion = action.originalPosition.region;
      Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
      originalRegion->getBlocks().splice(
          (insertAfterBlock ? std::next(Region::iterator(insertAfterBlock))
                            : originalRegion->end()),
          action.block->getParent()->getBlocks(), action.block);
      break;
    }
    // Merge back the block that was split out.
    case BlockActionKind::Split: {
      action.originalBlock->getOperations().splice(
          action.originalBlock->end(), action.block->getOperations());
      action.block->dropAllDefinedValueUses();
      action.block->erase();
      break;
    }
    // Undo the type conversion.
    case BlockActionKind::TypeConversion: {
      argConverter.discardRewrites(action.block);
      break;
    }
    }
  }
  blockActions.resize(numActionsToKeep);
}
1033 
1034 LogicalResult ConversionPatternRewriterImpl::remapValues(
1035     Location loc, PatternRewriter &rewriter, TypeConverter *converter,
1036     Operation::operand_range operands, SmallVectorImpl<Value> &remapped) {
1037   remapped.reserve(llvm::size(operands));
1038 
1039   SmallVector<Type, 1> legalTypes;
1040   for (auto it : llvm::enumerate(operands)) {
1041     Value operand = it.value();
1042     Type origType = operand.getType();
1043 
1044     // If a converter was provided, get the desired legal types for this
1045     // operand.
1046     Type desiredType;
1047     if (converter) {
1048       // If there is no legal conversion, fail to match this pattern.
1049       legalTypes.clear();
1050       if (failed(converter->convertType(origType, legalTypes))) {
1051         return notifyMatchFailure(loc, [=](Diagnostic &diag) {
1052           diag << "unable to convert type for operand #" << it.index()
1053                << ", type was " << origType;
1054         });
1055       }
1056       // TODO: There currently isn't any mechanism to do 1->N type conversion
1057       // via the PatternRewriter replacement API, so for now we just ignore it.
1058       if (legalTypes.size() == 1)
1059         desiredType = legalTypes.front();
1060     } else {
1061       // TODO: What we should do here is just set `desiredType` to `origType`
1062       // and then handle the necessary type conversions after the conversion
1063       // process has finished. Unfortunately a lot of patterns currently rely on
1064       // receiving the new operands even if the types change, so we keep the
1065       // original behavior here for now until all of the patterns relying on
1066       // this get updated.
1067     }
1068     Value newOperand = mapping.lookupOrDefault(operand, desiredType);
1069 
1070     // Handle the case where the conversion was 1->1 and the new operand type
1071     // isn't legal.
1072     Type newOperandType = newOperand.getType();
1073     if (converter && desiredType && newOperandType != desiredType) {
1074       // Attempt to materialize a conversion for this new value.
1075       newOperand = converter->materializeTargetConversion(
1076           rewriter, loc, desiredType, newOperand);
1077       if (!newOperand) {
1078         return notifyMatchFailure(loc, [=](Diagnostic &diag) {
1079           diag << "unable to materialize a conversion for "
1080                   "operand #"
1081                << it.index() << ", from " << newOperandType << " to "
1082                << desiredType;
1083         });
1084       }
1085     }
1086     remapped.push_back(newOperand);
1087   }
1088   return success();
1089 }
1090 
1091 bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
1092   // Check to see if this operation was replaced or its parent ignored.
1093   return replacements.count(op) || ignoredOps.count(op->getParentOp());
1094 }
1095 
1096 void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
1097   // Walk this operation and collect nested operations that define non-empty
1098   // regions. We mark such operations as 'ignored' so that we know we don't have
1099   // to convert them, or their nested ops.
1100   if (op->getNumRegions() == 0)
1101     return;
1102   op->walk([&](Operation *op) {
1103     if (llvm::any_of(op->getRegions(),
1104                      [](Region &region) { return !region.empty(); }))
1105       ignoredOps.insert(op);
1106   });
1107 }
1108 
1109 //===----------------------------------------------------------------------===//
1110 // Type Conversion
1111 
1112 FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
1113     Block *block, TypeConverter &converter,
1114     TypeConverter::SignatureConversion *conversion) {
1115   FailureOr<Block *> result =
1116       conversion ? argConverter.applySignatureConversion(block, converter,
1117                                                          *conversion, mapping)
1118                  : argConverter.convertSignature(block, converter, mapping);
1119   if (Block *newBlock = result.getValue()) {
1120     if (newBlock != block)
1121       blockActions.push_back(BlockAction::getTypeConversion(newBlock));
1122   }
1123   return result;
1124 }
1125 
1126 Block *ConversionPatternRewriterImpl::applySignatureConversion(
1127     Region *region, TypeConverter::SignatureConversion &conversion) {
1128   if (!region->empty()) {
1129     return *convertBlockSignature(&region->front(), defaultTypeConverter,
1130                                   &conversion);
1131   }
1132   return nullptr;
1133 }
1134 
1135 FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
1136     Region *region, TypeConverter &converter,
1137     TypeConverter::SignatureConversion *entryConversion) {
1138   argConverter.setConverter(region, &converter);
1139   if (region->empty())
1140     return nullptr;
1141 
1142   // Convert the arguments of each block within the region.
1143   FailureOr<Block *> newEntry =
1144       convertBlockSignature(&region->front(), converter, entryConversion);
1145   for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(*region, 1)))
1146     if (failed(convertBlockSignature(&block, converter)))
1147       return failure();
1148   return newEntry;
1149 }
1150 
1151 //===----------------------------------------------------------------------===//
1152 // Rewriter Notification Hooks
1153 
1154 void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
1155                                                      ValueRange newValues) {
1156   assert(newValues.size() == op->getNumResults());
1157   assert(!replacements.count(op) && "operation was already replaced");
1158 
1159   // Track if any of the results changed, e.g. erased and replaced with null.
1160   bool resultChanged = false;
1161 
1162   // Create mappings for each of the new result values.
1163   Value newValue, result;
1164   for (auto it : llvm::zip(newValues, op->getResults())) {
1165     std::tie(newValue, result) = it;
1166     if (!newValue) {
1167       resultChanged = true;
1168       continue;
1169     }
1170     // Remap, and check for any result type changes.
1171     mapping.map(result, newValue);
1172     resultChanged |= (newValue.getType() != result.getType());
1173   }
1174   if (resultChanged)
1175     operationsWithChangedResults.push_back(replacements.size());
1176 
1177   // Record the requested operation replacement.
1178   TypeConverter *converter = nullptr;
1179   if (currentConversionPattern)
1180     converter = currentConversionPattern->getTypeConverter();
1181   replacements.insert(std::make_pair(op, OpReplacement(converter)));
1182 
1183   // Mark this operation as recursively ignored so that we don't need to
1184   // convert any nested operations.
1185   markNestedOpsIgnored(op);
1186 }
1187 
1188 void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
1189   Region *region = block->getParent();
1190   Block *origPrevBlock = block->getPrevNode();
1191   blockActions.push_back(BlockAction::getErase(block, {region, origPrevBlock}));
1192 }
1193 
1194 void ConversionPatternRewriterImpl::notifyCreatedBlock(Block *block) {
1195   blockActions.push_back(BlockAction::getCreate(block));
1196 }
1197 
1198 void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
1199                                                      Block *continuation) {
1200   blockActions.push_back(BlockAction::getSplit(continuation, block));
1201 }
1202 
1203 void ConversionPatternRewriterImpl::notifyBlocksBeingMerged(Block *block,
1204                                                             Block *srcBlock) {
1205   blockActions.push_back(BlockAction::getMerge(block, srcBlock));
1206 }
1207 
1208 void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore(
1209     Region &region, Region &parent, Region::iterator before) {
1210   if (region.empty())
1211     return;
1212   Block *laterBlock = &region.back();
1213   for (auto &earlierBlock : llvm::drop_begin(llvm::reverse(region), 1)) {
1214     blockActions.push_back(
1215         BlockAction::getMove(laterBlock, {&region, &earlierBlock}));
1216     laterBlock = &earlierBlock;
1217   }
1218   blockActions.push_back(BlockAction::getMove(laterBlock, {&region, nullptr}));
1219 }
1220 
1221 void ConversionPatternRewriterImpl::notifyRegionWasClonedBefore(
1222     iterator_range<Region::iterator> &blocks, Location origRegionLoc) {
1223   for (Block &block : blocks)
1224     blockActions.push_back(BlockAction::getCreate(&block));
1225 
1226   // Compute the conversion set for the inlined region.
1227   auto result = computeConversionSet(blocks, origRegionLoc, createdOps);
1228 
1229   // This original region has already had its conversion set computed, so there
1230   // shouldn't be any new failures.
1231   (void)result;
1232   assert(succeeded(result) && "expected region to have no unreachable blocks");
1233 }
1234 
1235 LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
1236     Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1237   LLVM_DEBUG({
1238     Diagnostic diag(loc, DiagnosticSeverity::Remark);
1239     reasonCallback(diag);
1240     logger.startLine() << "** Failure : " << diag.str() << "\n";
1241   });
1242   return failure();
1243 }
1244 
1245 //===----------------------------------------------------------------------===//
1246 // ConversionPatternRewriter
1247 //===----------------------------------------------------------------------===//
1248 
1249 ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
1250     : PatternRewriter(ctx),
1251       impl(new detail::ConversionPatternRewriterImpl(*this)) {}
1252 ConversionPatternRewriter::~ConversionPatternRewriter() {}
1253 
1254 /// PatternRewriter hook for replacing the results of an operation when the
1255 /// given functor returns true.
1256 void ConversionPatternRewriter::replaceOpWithIf(
1257     Operation *op, ValueRange newValues, bool *allUsesReplaced,
1258     llvm::unique_function<bool(OpOperand &) const> functor) {
1259   // TODO: To support this we will need to rework a bit of how replacements are
1260   // tracked, given that this isn't guranteed to replace all of the uses of an
1261   // operation. The main change is that now an operation can be replaced
1262   // multiple times, in parts. The current "set" based tracking is mainly useful
1263   // for tracking if a replaced operation should be ignored, i.e. if all of the
1264   // uses will be replaced.
1265   llvm_unreachable(
1266       "replaceOpWithIf is currently not supported by DialectConversion");
1267 }
1268 
1269 /// PatternRewriter hook for replacing the results of an operation.
1270 void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
1271   LLVM_DEBUG({
1272     impl->logger.startLine()
1273         << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1274   });
1275   impl->notifyOpReplaced(op, newValues);
1276 }
1277 
1278 /// PatternRewriter hook for erasing a dead operation. The uses of this
1279 /// operation *must* be made dead by the end of the conversion process,
1280 /// otherwise an assert will be issued.
1281 void ConversionPatternRewriter::eraseOp(Operation *op) {
1282   LLVM_DEBUG({
1283     impl->logger.startLine()
1284         << "** Erase   : '" << op->getName() << "'(" << op << ")\n";
1285   });
1286   SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr);
1287   impl->notifyOpReplaced(op, nullRepls);
1288 }
1289 
1290 void ConversionPatternRewriter::eraseBlock(Block *block) {
1291   impl->notifyBlockIsBeingErased(block);
1292 
1293   // Mark all ops for erasure.
1294   for (Operation &op : *block)
1295     eraseOp(&op);
1296 
1297   // Unlink the block from its parent region. The block is kept in the block
1298   // action and will be actually destroyed when rewrites are applied. This
1299   // allows us to keep the operations in the block live and undo the removal by
1300   // re-inserting the block.
1301   block->getParent()->getBlocks().remove(block);
1302 }
1303 
1304 Block *ConversionPatternRewriter::applySignatureConversion(
1305     Region *region, TypeConverter::SignatureConversion &conversion) {
1306   return impl->applySignatureConversion(region, conversion);
1307 }
1308 
1309 FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
1310     Region *region, TypeConverter &converter,
1311     TypeConverter::SignatureConversion *entryConversion) {
1312   return impl->convertRegionTypes(region, converter, entryConversion);
1313 }
1314 
1315 void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
1316                                                            Value to) {
1317   LLVM_DEBUG({
1318     Operation *parentOp = from.getOwner()->getParentOp();
1319     impl->logger.startLine() << "** Replace Argument : '" << from
1320                              << "'(in region of '" << parentOp->getName()
1321                              << "'(" << from.getOwner()->getParentOp() << ")\n";
1322   });
1323   impl->argReplacements.push_back(from);
1324   impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
1325 }
1326 
1327 /// Return the converted value that replaces 'key'. Return 'key' if there is
1328 /// no such a converted value.
1329 Value ConversionPatternRewriter::getRemappedValue(Value key) {
1330   return impl->mapping.lookupOrDefault(key);
1331 }
1332 
1333 /// PatternRewriter hook for creating a new block with the given arguments.
1334 void ConversionPatternRewriter::notifyBlockCreated(Block *block) {
1335   impl->notifyCreatedBlock(block);
1336 }
1337 
1338 /// PatternRewriter hook for splitting a block into two parts.
1339 Block *ConversionPatternRewriter::splitBlock(Block *block,
1340                                              Block::iterator before) {
1341   auto *continuation = PatternRewriter::splitBlock(block, before);
1342   impl->notifySplitBlock(block, continuation);
1343   return continuation;
1344 }
1345 
1346 /// PatternRewriter hook for merging a block into another.
1347 void ConversionPatternRewriter::mergeBlocks(Block *source, Block *dest,
1348                                             ValueRange argValues) {
1349   impl->notifyBlocksBeingMerged(dest, source);
1350   assert(llvm::all_of(source->getPredecessors(),
1351                       [dest](Block *succ) { return succ == dest; }) &&
1352          "expected 'source' to have no predecessors or only 'dest'");
1353   assert(argValues.size() == source->getNumArguments() &&
1354          "incorrect # of argument replacement values");
1355   for (auto it : llvm::zip(source->getArguments(), argValues))
1356     replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1357   dest->getOperations().splice(dest->end(), source->getOperations());
1358   eraseBlock(source);
1359 }
1360 
1361 /// PatternRewriter hook for moving blocks out of a region.
1362 void ConversionPatternRewriter::inlineRegionBefore(Region &region,
1363                                                    Region &parent,
1364                                                    Region::iterator before) {
1365   impl->notifyRegionIsBeingInlinedBefore(region, parent, before);
1366   PatternRewriter::inlineRegionBefore(region, parent, before);
1367 }
1368 
1369 /// PatternRewriter hook for cloning blocks of one region into another.
1370 void ConversionPatternRewriter::cloneRegionBefore(
1371     Region &region, Region &parent, Region::iterator before,
1372     BlockAndValueMapping &mapping) {
1373   if (region.empty())
1374     return;
1375   PatternRewriter::cloneRegionBefore(region, parent, before, mapping);
1376 
1377   // Collect the range of the cloned blocks.
1378   auto clonedBeginIt = mapping.lookup(&region.front())->getIterator();
1379   auto clonedBlocks = llvm::make_range(clonedBeginIt, before);
1380   impl->notifyRegionWasClonedBefore(clonedBlocks, region.getLoc());
1381 }
1382 
1383 /// PatternRewriter hook for creating a new operation.
1384 void ConversionPatternRewriter::notifyOperationInserted(Operation *op) {
1385   LLVM_DEBUG({
1386     impl->logger.startLine()
1387         << "** Insert  : '" << op->getName() << "'(" << op << ")\n";
1388   });
1389   impl->createdOps.push_back(op);
1390 }
1391 
1392 /// PatternRewriter hook for updating the root operation in-place.
1393 void ConversionPatternRewriter::startRootUpdate(Operation *op) {
1394 #ifndef NDEBUG
1395   impl->pendingRootUpdates.insert(op);
1396 #endif
1397   impl->rootUpdates.emplace_back(op);
1398 }
1399 
1400 /// PatternRewriter hook for updating the root operation in-place.
1401 void ConversionPatternRewriter::finalizeRootUpdate(Operation *op) {
1402   // There is nothing to do here, we only need to track the operation at the
1403   // start of the update.
1404 #ifndef NDEBUG
1405   assert(impl->pendingRootUpdates.erase(op) &&
1406          "operation did not have a pending in-place update");
1407 #endif
1408 }
1409 
1410 /// PatternRewriter hook for updating the root operation in-place.
1411 void ConversionPatternRewriter::cancelRootUpdate(Operation *op) {
1412 #ifndef NDEBUG
1413   assert(impl->pendingRootUpdates.erase(op) &&
1414          "operation did not have a pending in-place update");
1415 #endif
1416   // Erase the last update for this operation.
1417   auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; };
1418   auto &rootUpdates = impl->rootUpdates;
1419   auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp);
1420   rootUpdates.erase(rootUpdates.begin() + (rootUpdates.rend() - it));
1421 }
1422 
1423 /// PatternRewriter hook for notifying match failure reasons.
1424 LogicalResult ConversionPatternRewriter::notifyMatchFailure(
1425     Operation *op, function_ref<void(Diagnostic &)> reasonCallback) {
1426   return impl->notifyMatchFailure(op->getLoc(), reasonCallback);
1427 }
1428 
1429 /// Return a reference to the internal implementation.
1430 detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
1431   return *impl;
1432 }
1433 
1434 //===----------------------------------------------------------------------===//
1435 // ConversionPattern
1436 //===----------------------------------------------------------------------===//
1437 
1438 /// Attempt to match and rewrite the IR root at the specified operation.
1439 LogicalResult
1440 ConversionPattern::matchAndRewrite(Operation *op,
1441                                    PatternRewriter &rewriter) const {
1442   auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
1443   auto &rewriterImpl = dialectRewriter.getImpl();
1444 
1445   // Track the current conversion pattern in the rewriter.
1446   assert(!rewriterImpl.currentConversionPattern &&
1447          "already inside of a pattern rewrite");
1448   llvm::SaveAndRestore<const ConversionPattern *> currentPatternGuard(
1449       rewriterImpl.currentConversionPattern, this);
1450 
1451   // Remap the operands of the operation.
1452   SmallVector<Value, 4> operands;
1453   if (failed(rewriterImpl.remapValues(op->getLoc(), rewriter,
1454                                       getTypeConverter(), op->getOperands(),
1455                                       operands))) {
1456     return failure();
1457   }
1458   return matchAndRewrite(op, operands, dialectRewriter);
1459 }
1460 
1461 //===----------------------------------------------------------------------===//
1462 // OperationLegalizer
1463 //===----------------------------------------------------------------------===//
1464 
namespace {
/// A set of rewrite patterns that can be used to legalize a given operation.
using LegalizationPatterns = SmallVector<const Pattern *, 1>;

/// This class defines a recursive operation legalizer.
///
/// An operation is legalized by trying the following, in order:
///  1. Accept it if the target marks it legal.
///  2. Skip it if it was marked ignored.
///  3. Fold it.
///  4. Apply one of the provided patterns.
///
/// Operations produced or updated by a fold or a pattern are legalized
/// recursively. When that fails, the rewriter state is rolled back.
class OperationLegalizer {
public:
  using LegalizationAction = ConversionTarget::LegalizationAction;

  /// Builds the legalization graph for `patterns` and applies the cost model
  /// to the pattern applicator up front.
  OperationLegalizer(ConversionTarget &targetInfo,
                     const FrozenRewritePatternList &patterns);

  /// Returns true if the given operation is known to be illegal on the target.
  bool isIllegal(Operation *op) const;

  /// Attempt to legalize the given operation. Returns success if the operation
  /// was legalized, failure otherwise.
  LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);

  /// Returns the conversion target in use by the legalizer.
  ConversionTarget &getTarget() { return target; }

private:
  /// Attempt to legalize the given operation by folding it.
  LogicalResult legalizeWithFold(Operation *op,
                                 ConversionPatternRewriter &rewriter);

  /// Attempt to legalize the given operation by applying a pattern. Returns
  /// success if the operation was legalized, failure otherwise.
  LogicalResult legalizeWithPattern(Operation *op,
                                    ConversionPatternRewriter &rewriter);

  /// Return true if the given pattern may be applied to the given operation,
  /// false otherwise.
  bool canApplyPattern(Operation *op, const Pattern &pattern,
                       ConversionPatternRewriter &rewriter);

  /// Legalize the resultant IR after successfully applying the given pattern.
  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
                                      ConversionPatternRewriter &rewriter,
                                      RewriterState &curState);

  /// Legalizes the actions registered during the execution of a pattern.
  LogicalResult legalizePatternBlockActions(Operation *op,
                                            ConversionPatternRewriter &rewriter,
                                            ConversionPatternRewriterImpl &impl,
                                            RewriterState &state,
                                            RewriterState &newState);
  /// Legalizes the operations created while a pattern ran, i.e. the created
  /// ops recorded between `state` and `newState`.
  LogicalResult legalizePatternCreatedOperations(
      ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
      RewriterState &state, RewriterState &newState);
  /// Legalizes the operations a pattern updated in place, i.e. the root
  /// updates recorded between `state` and `newState`.
  LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
                                           ConversionPatternRewriterImpl &impl,
                                           RewriterState &state,
                                           RewriterState &newState);

  //===--------------------------------------------------------------------===//
  // Cost Model
  //===--------------------------------------------------------------------===//

  /// Build an optimistic legalization graph given the provided patterns. This
  /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
  /// patterns for operations that are not directly legal, but may be
  /// transitively legal for the current target given the provided patterns.
  void buildLegalizationGraph(
      LegalizationPatterns &anyOpLegalizerPatterns,
      DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);

  /// Compute the benefit of each node within the computed legalization graph.
  /// This orders the patterns within 'legalizerPatterns' based upon two
  /// criteria:
  ///  1) Prefer patterns that have the lowest legalization depth, i.e.
  ///     represent the more direct mapping to the target.
  ///  2) When comparing patterns with the same legalization depth, prefer the
  ///     pattern with the highest PatternBenefit. This allows for users to
  ///     prefer specific legalizations over others.
  void computeLegalizationGraphBenefit(
      LegalizationPatterns &anyOpLegalizerPatterns,
      DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);

  /// Compute the legalization depth when legalizing an operation of the given
  /// type.
  unsigned computeOpLegalizationDepth(
      OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
      DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);

  /// Apply the conversion cost model to the given set of patterns, and return
  /// the smallest legalization depth of any of the patterns. See
  /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
  unsigned applyCostModelToPatterns(
      LegalizationPatterns &patterns,
      DenseMap<OperationName, unsigned> &minOpPatternDepth,
      DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);

  /// The current set of patterns that have been applied. `canApplyPattern`
  /// uses it to refuse a pattern without bounded rewrite recursion that is
  /// already on the current recursion stack.
  SmallPtrSet<const Pattern *, 8> appliedPatterns;

  /// The legalization information provided by the target.
  ConversionTarget &target;

  /// The pattern applicator to use for conversions.
  PatternApplicator applicator;
};
} // namespace
1569 
1570 OperationLegalizer::OperationLegalizer(ConversionTarget &targetInfo,
1571                                        const FrozenRewritePatternList &patterns)
1572     : target(targetInfo), applicator(patterns) {
1573   // The set of patterns that can be applied to illegal operations to transform
1574   // them into legal ones.
1575   DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
1576   LegalizationPatterns anyOpLegalizerPatterns;
1577 
1578   buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
1579   computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
1580 }
1581 
1582 bool OperationLegalizer::isIllegal(Operation *op) const {
1583   // Check if the target explicitly marked this operation as illegal.
1584   return target.getOpAction(op->getName()) == LegalizationAction::Illegal;
1585 }
1586 
/// Attempt to legalize `op` for the target. The strategies are tried in order,
/// and the first one that succeeds wins:
///  1. The target reports the op legal. If it is recursively legal, its nested
///     operations are marked ignored.
///  2. The op was previously marked ignored.
///  3. The op can be folded.
///  4. A legalization pattern applies.
/// Returns failure only if none of these succeed. In debug builds every
/// attempt is traced through the rewriter implementation's logger.
LogicalResult
OperationLegalizer::legalize(Operation *op,
                             ConversionPatternRewriter &rewriter) {
#ifndef NDEBUG
  // These locals are only referenced inside LLVM_DEBUG blocks, so they are
  // declared only when assertions are enabled.
  const char *logLineComment =
      "//===-------------------------------------------===//\n";

  auto &rewriterImpl = rewriter.getImpl();
#endif
  LLVM_DEBUG({
    auto &os = rewriterImpl.logger;
    os.getOStream() << "\n";
    os.startLine() << logLineComment;
    os.startLine() << "Legalizing operation : '" << op->getName() << "'(" << op
                   << ") {\n";
    os.indent();

    // If the operation has no regions, just print it here.
    if (op->getNumRegions() == 0) {
      op->print(os.startLine(), OpPrintingFlags().printGenericOpForm());
      os.getOStream() << "\n\n";
    }
  });

  // Check if this operation is legal on the target.
  if (auto legalityInfo = target.isLegal(op)) {
    LLVM_DEBUG({
      logSuccess(
          rewriterImpl.logger, "operation marked legal by the target{0}",
          legalityInfo->isRecursivelyLegal
              ? "; NOTE: operation is recursively legal; skipping internals"
              : "");
      rewriterImpl.logger.startLine() << logLineComment;
    });

    // If this operation is recursively legal, mark its children as ignored so
    // that we don't consider them for legalization.
    if (legalityInfo->isRecursivelyLegal)
      rewriter.getImpl().markNestedOpsIgnored(op);
    return success();
  }

  // Check to see if the operation is ignored and doesn't need to be converted.
  if (rewriter.getImpl().isOpIgnored(op)) {
    LLVM_DEBUG({
      logSuccess(rewriterImpl.logger,
                 "operation marked 'ignored' during conversion");
      rewriterImpl.logger.startLine() << logLineComment;
    });
    return success();
  }

  // If the operation isn't legal, try to fold it in-place.
  // TODO: Should we always try to do this, even if the op is
  // already legal?
  if (succeeded(legalizeWithFold(op, rewriter))) {
    LLVM_DEBUG({
      logSuccess(rewriterImpl.logger, "operation was folded");
      rewriterImpl.logger.startLine() << logLineComment;
    });
    return success();
  }

  // Otherwise, we need to apply a legalization pattern to this operation.
  if (succeeded(legalizeWithPattern(op, rewriter))) {
    LLVM_DEBUG({
      logSuccess(rewriterImpl.logger, "");
      rewriterImpl.logger.startLine() << logLineComment;
    });
    return success();
  }

  // Every strategy failed; report the failure and let the caller decide how
  // to handle it.
  LLVM_DEBUG({
    logFailure(rewriterImpl.logger, "no matched legalization pattern");
    rewriterImpl.logger.startLine() << logLineComment;
  });
  return failure();
}
1665 
/// Attempt to legalize `op` by folding it in place.
///
/// On a successful fold, `op` is replaced with the folded values. Every
/// operation created since entry is then legalized; the code treats these as
/// constants generated by the fold. If one of them cannot be legalized, the
/// rewriter is reset to the state captured on entry and failure is returned.
LogicalResult
OperationLegalizer::legalizeWithFold(Operation *op,
                                     ConversionPatternRewriter &rewriter) {
  auto &rewriterImpl = rewriter.getImpl();
  // Snapshot used both to find the newly created ops and to roll back.
  RewriterState curState = rewriterImpl.getCurrentState();

  LLVM_DEBUG({
    rewriterImpl.logger.startLine() << "* Fold {\n";
    rewriterImpl.logger.indent();
  });

  // Try to fold the operation.
  SmallVector<Value, 2> replacementValues;
  rewriter.setInsertionPoint(op);
  if (failed(rewriter.tryFold(op, replacementValues))) {
    LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
    return failure();
  }

  // Insert a replacement for 'op' with the folded replacement values.
  rewriter.replaceOp(op, replacementValues);

  // Recursively legalize any new constant operations. The upper bound is
  // computed once before the loop, so ops created by the recursive `legalize`
  // calls are not revisited here. Indexing (rather than iterators) keeps the
  // loop valid if `createdOps` grows during those calls.
  for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
       i != e; ++i) {
    Operation *cstOp = rewriterImpl.createdOps[i];
    if (failed(legalize(cstOp, rewriter))) {
      LLVM_DEBUG(logFailure(rewriterImpl.logger,
                            "generated constant '{0}' was illegal",
                            cstOp->getName()));
      rewriterImpl.resetState(curState);
      return failure();
    }
  }

  LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
  return success();
}
1704 
1705 LogicalResult
1706 OperationLegalizer::legalizeWithPattern(Operation *op,
1707                                         ConversionPatternRewriter &rewriter) {
1708   auto &rewriterImpl = rewriter.getImpl();
1709 
1710   // Functor that returns if the given pattern may be applied.
1711   auto canApply = [&](const Pattern &pattern) {
1712     return canApplyPattern(op, pattern, rewriter);
1713   };
1714 
1715   // Functor that cleans up the rewriter state after a pattern failed to match.
1716   RewriterState curState = rewriterImpl.getCurrentState();
1717   auto onFailure = [&](const Pattern &pattern) {
1718     LLVM_DEBUG(logFailure(rewriterImpl.logger, "pattern failed to match"));
1719     rewriterImpl.resetState(curState);
1720     appliedPatterns.erase(&pattern);
1721   };
1722 
1723   // Functor that performs additional legalization when a pattern is
1724   // successfully applied.
1725   auto onSuccess = [&](const Pattern &pattern) {
1726     auto result = legalizePatternResult(op, pattern, rewriter, curState);
1727     appliedPatterns.erase(&pattern);
1728     if (failed(result))
1729       rewriterImpl.resetState(curState);
1730     return result;
1731   };
1732 
1733   // Try to match and rewrite a pattern on this operation.
1734   return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
1735                                     onSuccess);
1736 }
1737 
1738 bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
1739                                          ConversionPatternRewriter &rewriter) {
1740   LLVM_DEBUG({
1741     auto &os = rewriter.getImpl().logger;
1742     os.getOStream() << "\n";
1743     os.startLine() << "* Pattern : '" << op->getName() << " -> (";
1744     llvm::interleaveComma(pattern.getGeneratedOps(), llvm::dbgs());
1745     os.getOStream() << ")' {\n";
1746     os.indent();
1747   });
1748 
1749   // Ensure that we don't cycle by not allowing the same pattern to be
1750   // applied twice in the same recursion stack if it is not known to be safe.
1751   if (!pattern.hasBoundedRewriteRecursion() &&
1752       !appliedPatterns.insert(&pattern).second) {
1753     LLVM_DEBUG(
1754         logFailure(rewriter.getImpl().logger, "pattern was already applied"));
1755     return false;
1756   }
1757   return true;
1758 }
1759 
/// Legalize everything touched by a pattern that was just applied to `op`.
///
/// `curState` is the rewriter state captured before the pattern ran. The
/// following are legalized, in this order:
///  1. Results of block actions.
///  2. Operations updated in place.
///  3. Newly created operations.
///
/// Returns failure as soon as one of these cannot be legalized. This function
/// does not roll anything back itself; legalizeWithPattern resets the state
/// when this returns failure.
LogicalResult
OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
                                          ConversionPatternRewriter &rewriter,
                                          RewriterState &curState) {
  auto &impl = rewriter.getImpl();

  // NOTE(review): the preprocessor guard suggests `pendingRootUpdates` only
  // exists in assertion-enabled builds — confirm against the impl class.
#ifndef NDEBUG
  assert(impl.pendingRootUpdates.empty() && "dangling root updates");
#endif

  // Check that the root was either replaced or updated in place.
  auto replacedRoot = [&] {
    return llvm::any_of(
        llvm::drop_begin(impl.replacements, curState.numReplacements),
        [op](auto &it) { return it.first == op; });
  };
  auto updatedRootInPlace = [&] {
    return llvm::any_of(
        llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates),
        [op](auto &state) { return state.getOperation() == op; });
  };
  // The lambdas are only used by the assert below; these casts keep them from
  // being reported as unused when assertions are compiled out.
  (void)replacedRoot;
  (void)updatedRootInPlace;
  assert((replacedRoot() || updatedRootInPlace()) &&
         "expected pattern to replace the root operation");

  // Legalize each of the actions registered during application.
  RewriterState newState = impl.getCurrentState();
  if (failed(legalizePatternBlockActions(op, rewriter, impl, curState,
                                         newState)) ||
      failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
      failed(legalizePatternCreatedOperations(rewriter, impl, curState,
                                              newState))) {
    return failure();
  }

  LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
  return success();
}
1799 
1800 LogicalResult OperationLegalizer::legalizePatternBlockActions(
1801     Operation *op, ConversionPatternRewriter &rewriter,
1802     ConversionPatternRewriterImpl &impl, RewriterState &state,
1803     RewriterState &newState) {
1804   SmallPtrSet<Operation *, 16> operationsToIgnore;
1805 
1806   // If the pattern moved or created any blocks, make sure the types of block
1807   // arguments get legalized.
1808   for (int i = state.numBlockActions, e = newState.numBlockActions; i != e;
1809        ++i) {
1810     auto &action = impl.blockActions[i];
1811     if (action.kind == BlockActionKind::TypeConversion ||
1812         action.kind == BlockActionKind::Erase)
1813       continue;
1814     // Only check blocks outside of the current operation.
1815     Operation *parentOp = action.block->getParentOp();
1816     if (!parentOp || parentOp == op || action.block->getNumArguments() == 0)
1817       continue;
1818 
1819     // If the region of the block has a type converter, try to convert the block
1820     // directly.
1821     if (auto *converter =
1822             impl.argConverter.getConverter(action.block->getParent())) {
1823       if (failed(impl.convertBlockSignature(action.block, *converter))) {
1824         LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
1825                                            "block"));
1826         return failure();
1827       }
1828       continue;
1829     }
1830 
1831     // Otherwise, check that this operation isn't one generated by this pattern.
1832     // This is because we will attempt to legalize the parent operation, and
1833     // blocks in regions created by this pattern will already be legalized later
1834     // on. If we haven't built the set yet, build it now.
1835     if (operationsToIgnore.empty()) {
1836       auto createdOps = ArrayRef<Operation *>(impl.createdOps)
1837                             .drop_front(state.numCreatedOps);
1838       operationsToIgnore.insert(createdOps.begin(), createdOps.end());
1839     }
1840 
1841     // If this operation should be considered for re-legalization, try it.
1842     if (operationsToIgnore.insert(parentOp).second &&
1843         failed(legalize(parentOp, rewriter))) {
1844       LLVM_DEBUG(logFailure(
1845           impl.logger, "operation '{0}'({1}) became illegal after block action",
1846           parentOp->getName(), parentOp));
1847       return failure();
1848     }
1849   }
1850   return success();
1851 }
1852 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
1853     ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
1854     RewriterState &state, RewriterState &newState) {
1855   for (int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) {
1856     Operation *op = impl.createdOps[i];
1857     if (failed(legalize(op, rewriter))) {
1858       LLVM_DEBUG(logFailure(impl.logger,
1859                             "generated operation '{0}'({1}) was illegal",
1860                             op->getName(), op));
1861       return failure();
1862     }
1863   }
1864   return success();
1865 }
1866 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
1867     ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
1868     RewriterState &state, RewriterState &newState) {
1869   for (int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) {
1870     Operation *op = impl.rootUpdates[i].getOperation();
1871     if (failed(legalize(op, rewriter))) {
1872       LLVM_DEBUG(logFailure(impl.logger,
1873                             "operation updated in-place '{0}' was illegal",
1874                             op->getName()));
1875       return failure();
1876     }
1877   }
1878   return success();
1879 }
1880 
1881 //===----------------------------------------------------------------------===//
1882 // Cost Model
1883 
/// Compute which patterns can transitively produce legal IR.
///
/// Rooted patterns whose root op is not always legal start out "invalid".
/// A pattern becomes valid once each of its generated ops is either:
///  - already covered by a valid pattern, or
///  - registered with the target under an action other than `Illegal`.
///
/// When a pattern becomes valid, the invalid patterns of the ops that may
/// generate its root are revisited, until a fixpoint is reached.
///
/// If any pattern has no specific root, this analysis is skipped and every
/// collected rooted pattern is treated as valid.
void OperationLegalizer::buildLegalizationGraph(
    LegalizationPatterns &anyOpLegalizerPatterns,
    DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
  // A mapping between an operation and a set of operations that can be used to
  // generate it.
  DenseMap<OperationName, SmallPtrSet<OperationName, 2>> parentOps;
  // A mapping between an operation and any currently invalid patterns it has.
  DenseMap<OperationName, SmallPtrSet<const Pattern *, 2>> invalidPatterns;
  // A worklist of patterns to consider for legality.
  llvm::SetVector<const Pattern *> patternWorklist;

  // Build the mapping from operations to the parent ops that may generate them.
  applicator.walkAllPatterns([&](const Pattern &pattern) {
    Optional<OperationName> root = pattern.getRootKind();

    // If the pattern has no specific root, we can't analyze the relationship
    // between the root op and generated operations. Given that, add all such
    // patterns to the legalization set.
    if (!root) {
      anyOpLegalizerPatterns.push_back(&pattern);
      return;
    }

    // Skip operations that are always known to be legal.
    if (target.getOpAction(*root) == LegalizationAction::Legal)
      return;

    // Add this pattern to the invalid set for the root op and record this root
    // as a parent for any generated operations.
    invalidPatterns[*root].insert(&pattern);
    for (auto op : pattern.getGeneratedOps())
      parentOps[op].insert(*root);

    // Add this pattern to the worklist.
    patternWorklist.insert(&pattern);
  });

  // If there are any patterns that don't have a specific root kind, we can't
  // make direct assumptions about what operations will never be legalized.
  // Note: Technically we could, but it would require an analysis that may
  // recurse into itself. It would be better to perform this kind of filtering
  // at a higher level than here anyways.
  if (!anyOpLegalizerPatterns.empty()) {
    for (const Pattern *pattern : patternWorklist)
      legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
    return;
  }

  // Fixpoint iteration. A pattern made valid is erased from `invalidPatterns`,
  // so it can never be re-queued and is appended to `legalizerPatterns` at
  // most once.
  while (!patternWorklist.empty()) {
    auto *pattern = patternWorklist.pop_back_val();

    // Check to see if any of the generated operations are invalid.
    // A generated op counts as invalid when no valid pattern produces it yet
    // and the target either has no registered action for it or marks it
    // Illegal.
    if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
          Optional<LegalizationAction> action = target.getOpAction(op);
          return !legalizerPatterns.count(op) &&
                 (!action || action == LegalizationAction::Illegal);
        }))
      continue;

    // Otherwise, if all of the generated operation are valid, this op is now
    // legal so add all of the child patterns to the worklist.
    legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
    invalidPatterns[*pattern->getRootKind()].erase(pattern);

    // Add any invalid patterns of the parent operations to see if they have now
    // become legal.
    for (auto op : parentOps[*pattern->getRootKind()])
      patternWorklist.set_union(invalidPatterns[op]);
  }
}
1954 
1955 void OperationLegalizer::computeLegalizationGraphBenefit(
1956     LegalizationPatterns &anyOpLegalizerPatterns,
1957     DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
1958   // The smallest pattern depth, when legalizing an operation.
1959   DenseMap<OperationName, unsigned> minOpPatternDepth;
1960 
1961   // For each operation that is transitively legal, compute a cost for it.
1962   for (auto &opIt : legalizerPatterns)
1963     if (!minOpPatternDepth.count(opIt.first))
1964       computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
1965                                  legalizerPatterns);
1966 
1967   // Apply the cost model to the patterns that can match any operation. Those
1968   // with a specific operation type are already resolved when computing the op
1969   // legalization depth.
1970   if (!anyOpLegalizerPatterns.empty())
1971     applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
1972                              legalizerPatterns);
1973 
1974   // Apply a cost model to the pattern applicator. We order patterns first by
1975   // depth then benefit. `legalizerPatterns` contains per-op patterns by
1976   // decreasing benefit.
1977   applicator.applyCostModel([&](const Pattern &pattern) {
1978     ArrayRef<const Pattern *> orderedPatternList;
1979     if (Optional<OperationName> rootName = pattern.getRootKind())
1980       orderedPatternList = legalizerPatterns[*rootName];
1981     else
1982       orderedPatternList = anyOpLegalizerPatterns;
1983 
1984     // If the pattern is not found, then it was removed and cannot be matched.
1985     auto it = llvm::find(orderedPatternList, &pattern);
1986     if (it == orderedPatternList.end())
1987       return PatternBenefit::impossibleToMatch();
1988 
1989     // Patterns found earlier in the list have higher benefit.
1990     return PatternBenefit(std::distance(it, orderedPatternList.end()));
1991   });
1992 }
1993 
/// Return the minimum legalization depth of `op`, memoized in
/// `minOpPatternDepth`.
///
/// - An op with no legalization patterns has depth 0.
/// - Otherwise the depth is the smallest pattern depth computed by
///   `applyCostModelToPatterns`, which also sorts the op's pattern list.
///
/// While an op is being computed, its entry holds the maximum unsigned value.
/// A recursive query that reaches the same op through a cycle therefore
/// returns that value instead of recursing forever.
unsigned OperationLegalizer::computeOpLegalizationDepth(
    OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
    DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
  // Check for existing depth.
  auto depthIt = minOpPatternDepth.find(op);
  if (depthIt != minOpPatternDepth.end())
    return depthIt->second;

  // If a mapping for this operation does not exist, then this operation
  // is always legal. Return 0 as the depth for a directly legal operation.
  auto opPatternsIt = legalizerPatterns.find(op);
  if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
    return 0u;

  // Record this initial depth in case we encounter this op again when
  // recursively computing the depth.
  minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());

  // Apply the cost model to the operation patterns, and update the minimum
  // depth. `opPatternsIt` stays valid across the recursion because the
  // recursive calls only perform lookups on `legalizerPatterns`.
  unsigned minDepth = applyCostModelToPatterns(
      opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
  // Re-index instead of reusing `depthIt`: the map may have grown (and
  // rehashed) during the recursion.
  minOpPatternDepth[op] = minDepth;
  return minDepth;
}
2019 
2020 unsigned OperationLegalizer::applyCostModelToPatterns(
2021     LegalizationPatterns &patterns,
2022     DenseMap<OperationName, unsigned> &minOpPatternDepth,
2023     DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2024   unsigned minDepth = std::numeric_limits<unsigned>::max();
2025 
2026   // Compute the depth for each pattern within the set.
2027   SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
2028   patternsByDepth.reserve(patterns.size());
2029   for (const Pattern *pattern : patterns) {
2030     unsigned depth = 0;
2031     for (auto generatedOp : pattern->getGeneratedOps()) {
2032       unsigned generatedOpDepth = computeOpLegalizationDepth(
2033           generatedOp, minOpPatternDepth, legalizerPatterns);
2034       depth = std::max(depth, generatedOpDepth + 1);
2035     }
2036     patternsByDepth.emplace_back(pattern, depth);
2037 
2038     // Update the minimum depth of the pattern list.
2039     minDepth = std::min(minDepth, depth);
2040   }
2041 
2042   // If the operation only has one legalization pattern, there is no need to
2043   // sort them.
2044   if (patternsByDepth.size() == 1)
2045     return minDepth;
2046 
2047   // Sort the patterns by those likely to be the most beneficial.
2048   llvm::array_pod_sort(patternsByDepth.begin(), patternsByDepth.end(),
2049                        [](const std::pair<const Pattern *, unsigned> *lhs,
2050                           const std::pair<const Pattern *, unsigned> *rhs) {
2051                          // First sort by the smaller pattern legalization
2052                          // depth.
2053                          if (lhs->second != rhs->second)
2054                            return llvm::array_pod_sort_comparator<unsigned>(
2055                                &lhs->second, &rhs->second);
2056 
2057                          // Then sort by the larger pattern benefit.
2058                          auto lhsBenefit = lhs->first->getBenefit();
2059                          auto rhsBenefit = rhs->first->getBenefit();
2060                          return llvm::array_pod_sort_comparator<PatternBenefit>(
2061                              &rhsBenefit, &lhsBenefit);
2062                        });
2063 
2064   // Update the legalization pattern to use the new sorted list.
2065   patterns.clear();
2066   for (auto &patternIt : patternsByDepth)
2067     patterns.push_back(patternIt.first);
2068   return minDepth;
2069 }
2070 
2071 //===----------------------------------------------------------------------===//
2072 // OperationConverter
2073 //===----------------------------------------------------------------------===//
namespace {
/// Controls how `OperationConverter::convert` reacts when an operation cannot
/// be legalized.
enum OpConversionMode {
  // In this mode, the conversion will ignore failed conversions to allow
  // illegal operations to co-exist in the IR.
  Partial,

  // In this mode, all operations must be legal for the given target for the
  // conversion to succeed.
  Full,

  // In this mode, operations are analyzed for legality. No actual rewrites are
  // applied to the operations on success.
  Analysis,
};

// This class converts operations to a given conversion target via a set of
// rewrite patterns. The conversion behaves differently depending on the
// conversion mode.
struct OperationConverter {
  /// Create a converter for `target` using `patterns` in the given `mode`.
  /// `trackedOps` is optional. Analysis mode records its results there, so
  /// callers using that mode should supply a set.
  explicit OperationConverter(ConversionTarget &target,
                              const FrozenRewritePatternList &patterns,
                              OpConversionMode mode,
                              DenseSet<Operation *> *trackedOps = nullptr)
      : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}

  /// Converts the given operations to the conversion target.
  LogicalResult convertOperations(ArrayRef<Operation *> ops);

private:
  /// Converts an operation with the given rewriter.
  LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);

  /// This method is called after the conversion process to legalize any
  /// remaining artifacts and complete the conversion.
  LogicalResult finalize(ConversionPatternRewriter &rewriter);

  /// Legalize the types of converted block arguments.
  LogicalResult
  legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
                                 ConversionPatternRewriterImpl &rewriterImpl);

  /// Legalize an operation result that was marked as "erased".
  LogicalResult
  legalizeErasedResult(Operation *op, OpResult result,
                       ConversionPatternRewriterImpl &rewriterImpl);

  /// Legalize an operation result that was replaced with a value of a different
  /// type.
  LogicalResult
  legalizeChangedResultType(Operation *op, OpResult result, Value newValue,
                            TypeConverter *replConverter,
                            ConversionPatternRewriter &rewriter,
                            ConversionPatternRewriterImpl &rewriterImpl);

  /// The legalizer to use when converting operations.
  OperationLegalizer opLegalizer;

  /// The conversion mode to use when legalizing operations.
  OpConversionMode mode;

  /// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
  /// this is populated with ops found to be legalizable to the target.
  /// When mode == OpConversionMode::Partial, this is populated with ops found
  /// *not* to be legalizable to the target.
  /// Not owned; may be null.
  DenseSet<Operation *> *trackedOps;
};
} // end anonymous namespace
2141 
2142 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2143                                           Operation *op) {
2144   // Legalize the given operation.
2145   if (failed(opLegalizer.legalize(op, rewriter))) {
2146     // Handle the case of a failed conversion for each of the different modes.
2147     // Full conversions expect all operations to be converted.
2148     if (mode == OpConversionMode::Full)
2149       return op->emitError()
2150              << "failed to legalize operation '" << op->getName() << "'";
2151     // Partial conversions allow conversions to fail iff the operation was not
2152     // explicitly marked as illegal. If the user provided a nonlegalizableOps
2153     // set, non-legalizable ops are included.
2154     if (mode == OpConversionMode::Partial) {
2155       if (opLegalizer.isIllegal(op))
2156         return op->emitError()
2157                << "failed to legalize operation '" << op->getName()
2158                << "' that was explicitly marked illegal";
2159       if (trackedOps)
2160         trackedOps->insert(op);
2161     }
2162   } else if (mode == OpConversionMode::Analysis) {
2163     // Analysis conversions don't fail if any operations fail to legalize,
2164     // they are only interested in the operations that were successfully
2165     // legalized.
2166     trackedOps->insert(op);
2167   }
2168   return success();
2169 }
2170 
2171 LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
2172   if (ops.empty())
2173     return success();
2174   ConversionTarget &target = opLegalizer.getTarget();
2175 
2176   // Compute the set of operations and blocks to convert.
2177   std::vector<Operation *> toConvert;
2178   for (auto *op : ops) {
2179     toConvert.emplace_back(op);
2180     for (auto &region : op->getRegions())
2181       if (failed(computeConversionSet(region.getBlocks(), region.getLoc(),
2182                                       toConvert, &target)))
2183         return failure();
2184   }
2185 
2186   // Convert each operation and discard rewrites on failure.
2187   ConversionPatternRewriter rewriter(ops.front()->getContext());
2188   ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2189   for (auto *op : toConvert)
2190     if (failed(convert(rewriter, op)))
2191       return rewriterImpl.discardRewrites(), failure();
2192 
2193   // Now that all of the operations have been converted, finalize the conversion
2194   // process to ensure any lingering conversion artifacts are cleaned up and
2195   // legalized.
2196   if (failed(finalize(rewriter)))
2197     return rewriterImpl.discardRewrites(), failure();
2198 
2199   // After a successful conversion, apply rewrites if this is not an analysis
2200   // conversion.
2201   if (mode == OpConversionMode::Analysis)
2202     rewriterImpl.discardRewrites();
2203   else
2204     rewriterImpl.applyRewrites();
2205   return success();
2206 }
2207 
2208 LogicalResult
2209 OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
2210   ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2211 
2212   // Legalize converted block arguments.
2213   if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
2214     return failure();
2215 
2216   // Process requested operation replacements.
2217   for (unsigned i = 0, e = rewriterImpl.operationsWithChangedResults.size();
2218        i != e; ++i) {
2219     unsigned replIdx = rewriterImpl.operationsWithChangedResults[i];
2220     auto &repl = *(rewriterImpl.replacements.begin() + replIdx);
2221     for (OpResult result : repl.first->getResults()) {
2222       Value newValue = rewriterImpl.mapping.lookupOrNull(result);
2223 
2224       // If the operation result was replaced with null, all of the uses of this
2225       // value should be replaced.
2226       if (!newValue) {
2227         if (failed(legalizeErasedResult(repl.first, result, rewriterImpl)))
2228           return failure();
2229         continue;
2230       }
2231 
2232       // Otherwise, check to see if the type of the result changed.
2233       if (result.getType() == newValue.getType())
2234         continue;
2235 
2236       // Legalize this result.
2237       rewriter.setInsertionPoint(repl.first);
2238       if (failed(legalizeChangedResultType(repl.first, result, newValue,
2239                                            repl.second.converter, rewriter,
2240                                            rewriterImpl)))
2241         return failure();
2242 
2243       // Update the end iterator for this loop in the case it was updated
2244       // when legalizing generated conversion operations.
2245       e = rewriterImpl.operationsWithChangedResults.size();
2246     }
2247   }
2248   return success();
2249 }
2250 
2251 LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
2252     ConversionPatternRewriter &rewriter,
2253     ConversionPatternRewriterImpl &rewriterImpl) {
2254   // Functor used to check if all users of a value will be dead after
2255   // conversion.
2256   auto findLiveUser = [&](Value val) {
2257     auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
2258       return rewriterImpl.isOpIgnored(user);
2259     });
2260     return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
2261   };
2262 
2263   // Materialize any necessary conversions for converted block arguments that
2264   // are still live.
2265   size_t numCreatedOps = rewriterImpl.createdOps.size();
2266   if (failed(rewriterImpl.argConverter.materializeLiveConversions(
2267           rewriterImpl.mapping, rewriter, findLiveUser)))
2268     return failure();
2269 
2270   // Legalize any newly created operations during argument materialization.
2271   for (int i : llvm::seq<int>(numCreatedOps, rewriterImpl.createdOps.size())) {
2272     if (failed(opLegalizer.legalize(rewriterImpl.createdOps[i], rewriter))) {
2273       return rewriterImpl.createdOps[i]->emitError()
2274              << "failed to legalize conversion operation generated for block "
2275                 "argument that remained live after conversion";
2276     }
2277   }
2278   return success();
2279 }
2280 
2281 LogicalResult OperationConverter::legalizeErasedResult(
2282     Operation *op, OpResult result,
2283     ConversionPatternRewriterImpl &rewriterImpl) {
2284   // If the operation result was replaced with null, all of the uses of this
2285   // value should be replaced.
2286   auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
2287     return rewriterImpl.isOpIgnored(user);
2288   });
2289   if (liveUserIt != result.user_end()) {
2290     InFlightDiagnostic diag = op->emitError("failed to legalize operation '")
2291                               << op->getName() << "' marked as erased";
2292     diag.attachNote(liveUserIt->getLoc())
2293         << "found live user of result #" << result.getResultNumber() << ": "
2294         << *liveUserIt;
2295     return failure();
2296   }
2297   return success();
2298 }
2299 
2300 LogicalResult OperationConverter::legalizeChangedResultType(
2301     Operation *op, OpResult result, Value newValue,
2302     TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
2303     ConversionPatternRewriterImpl &rewriterImpl) {
2304   // Walk the users of this value to see if there are any live users that
2305   // weren't replaced during conversion.
2306   auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
2307     return rewriterImpl.isOpIgnored(user);
2308   });
2309   if (liveUserIt == result.user_end())
2310     return success();
2311 
2312   // If the replacement has a type converter, attempt to materialize a
2313   // conversion back to the original type.
2314   if (!replConverter) {
2315     // TODO: We should emit an error here, similarly to the case where the
2316     // result is replaced with null. Unfortunately a lot of existing
2317     // patterns rely on this behavior, so until those patterns are updated
2318     // we keep the legacy behavior here of just forwarding the new value.
2319     return success();
2320   }
2321 
2322   // Track the number of created operations so that new ones can be legalized.
2323   size_t numCreatedOps = rewriterImpl.createdOps.size();
2324 
2325   // Materialize a conversion for this live result value.
2326   Type resultType = result.getType();
2327   Value convertedValue = replConverter->materializeSourceConversion(
2328       rewriter, op->getLoc(), resultType, newValue);
2329   if (!convertedValue) {
2330     InFlightDiagnostic diag = op->emitError()
2331                               << "failed to materialize conversion for result #"
2332                               << result.getResultNumber() << " of operation '"
2333                               << op->getName()
2334                               << "' that remained live after conversion";
2335     diag.attachNote(liveUserIt->getLoc())
2336         << "see existing live user here: " << *liveUserIt;
2337     return failure();
2338   }
2339 
2340   // Legalize all of the newly created conversion operations.
2341   for (int i : llvm::seq<int>(numCreatedOps, rewriterImpl.createdOps.size())) {
2342     if (failed(opLegalizer.legalize(rewriterImpl.createdOps[i], rewriter))) {
2343       return op->emitError("failed to legalize conversion operation generated ")
2344              << "for result #" << result.getResultNumber() << " of operation '"
2345              << op->getName() << "' that remained live after conversion";
2346     }
2347   }
2348 
2349   rewriterImpl.mapping.map(result, convertedValue);
2350   return success();
2351 }
2352 
2353 //===----------------------------------------------------------------------===//
2354 // Type Conversion
2355 //===----------------------------------------------------------------------===//
2356 
2357 /// Remap an input of the original signature with a new set of types. The
2358 /// new types are appended to the new signature conversion.
2359 void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
2360                                                    ArrayRef<Type> types) {
2361   assert(!types.empty() && "expected valid types");
2362   remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
2363   addInputs(types);
2364 }
2365 
2366 /// Append new input types to the signature conversion, this should only be
2367 /// used if the new types are not intended to remap an existing input.
2368 void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) {
2369   assert(!types.empty() &&
2370          "1->0 type remappings don't need to be added explicitly");
2371   argTypes.append(types.begin(), types.end());
2372 }
2373 
2374 /// Remap an input of the original signature with a range of types in the
2375 /// new signature.
2376 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2377                                                     unsigned newInputNo,
2378                                                     unsigned newInputCount) {
2379   assert(!remappedInputs[origInputNo] && "input has already been remapped");
2380   assert(newInputCount != 0 && "expected valid input count");
2381   remappedInputs[origInputNo] =
2382       InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr};
2383 }
2384 
2385 /// Remap an input of the original signature to another `replacementValue`
2386 /// value. This would make the signature converter drop this argument.
2387 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2388                                                     Value replacementValue) {
2389   assert(!remappedInputs[origInputNo] && "input has already been remapped");
2390   remappedInputs[origInputNo] =
2391       InputMapping{origInputNo, /*size=*/0, replacementValue};
2392 }
2393 
2394 /// This hooks allows for converting a type.
2395 LogicalResult TypeConverter::convertType(Type t,
2396                                          SmallVectorImpl<Type> &results) {
2397   auto existingIt = cachedDirectConversions.find(t);
2398   if (existingIt != cachedDirectConversions.end()) {
2399     if (existingIt->second)
2400       results.push_back(existingIt->second);
2401     return success(existingIt->second != nullptr);
2402   }
2403   auto multiIt = cachedMultiConversions.find(t);
2404   if (multiIt != cachedMultiConversions.end()) {
2405     results.append(multiIt->second.begin(), multiIt->second.end());
2406     return success();
2407   }
2408 
2409   // Walk the added converters in reverse order to apply the most recently
2410   // registered first.
2411   size_t currentCount = results.size();
2412   for (ConversionCallbackFn &converter : llvm::reverse(conversions)) {
2413     if (Optional<LogicalResult> result = converter(t, results)) {
2414       if (!succeeded(*result)) {
2415         cachedDirectConversions.try_emplace(t, nullptr);
2416         return failure();
2417       }
2418       auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
2419       if (newTypes.size() == 1)
2420         cachedDirectConversions.try_emplace(t, newTypes.front());
2421       else
2422         cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
2423       return success();
2424     }
2425   }
2426   return failure();
2427 }
2428 
2429 /// This hook simplifies defining 1-1 type conversions. This function returns
2430 /// the type to convert to on success, and a null type on failure.
2431 Type TypeConverter::convertType(Type t) {
2432   // Use the multi-type result version to convert the type.
2433   SmallVector<Type, 1> results;
2434   if (failed(convertType(t, results)))
2435     return nullptr;
2436 
2437   // Check to ensure that only one type was produced.
2438   return results.size() == 1 ? results.front() : nullptr;
2439 }
2440 
2441 /// Convert the given set of types, filling 'results' as necessary. This
2442 /// returns failure if the conversion of any of the types fails, success
2443 /// otherwise.
2444 LogicalResult TypeConverter::convertTypes(ArrayRef<Type> types,
2445                                           SmallVectorImpl<Type> &results) {
2446   for (auto type : types)
2447     if (failed(convertType(type, results)))
2448       return failure();
2449   return success();
2450 }
2451 
2452 /// Return true if the given type is legal for this type converter, i.e. the
2453 /// type converts to itself.
2454 bool TypeConverter::isLegal(Type type) { return convertType(type) == type; }
2455 /// Return true if the given operation has legal operand and result types.
2456 bool TypeConverter::isLegal(Operation *op) {
2457   return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
2458 }
2459 
2460 /// Return true if the types of block arguments within the region are legal.
2461 bool TypeConverter::isLegal(Region *region) {
2462   return llvm::all_of(*region, [this](Block &block) {
2463     return isLegal(block.getArgumentTypes());
2464   });
2465 }
2466 
2467 /// Return true if the inputs and outputs of the given function type are
2468 /// legal.
2469 bool TypeConverter::isSignatureLegal(FunctionType ty) {
2470   return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
2471 }
2472 
2473 /// This hook allows for converting a specific argument of a signature.
2474 LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
2475                                                  SignatureConversion &result) {
2476   // Try to convert the given input type.
2477   SmallVector<Type, 1> convertedTypes;
2478   if (failed(convertType(type, convertedTypes)))
2479     return failure();
2480 
2481   // If this argument is being dropped, there is nothing left to do.
2482   if (convertedTypes.empty())
2483     return success();
2484 
2485   // Otherwise, add the new inputs.
2486   result.addInputs(inputNo, convertedTypes);
2487   return success();
2488 }
2489 LogicalResult TypeConverter::convertSignatureArgs(TypeRange types,
2490                                                   SignatureConversion &result,
2491                                                   unsigned origInputOffset) {
2492   for (unsigned i = 0, e = types.size(); i != e; ++i)
2493     if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
2494       return failure();
2495   return success();
2496 }
2497 
2498 Value TypeConverter::materializeConversion(
2499     MutableArrayRef<MaterializationCallbackFn> materializations,
2500     OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) {
2501   for (MaterializationCallbackFn &fn : llvm::reverse(materializations))
2502     if (Optional<Value> result = fn(builder, resultType, inputs, loc))
2503       return result.getValue();
2504   return nullptr;
2505 }
2506 
2507 /// This function converts the type signature of the given block, by invoking
2508 /// 'convertSignatureArg' for each argument. This function should return a valid
2509 /// conversion for the signature on success, None otherwise.
2510 auto TypeConverter::convertBlockSignature(Block *block)
2511     -> Optional<SignatureConversion> {
2512   SignatureConversion conversion(block->getNumArguments());
2513   if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
2514     return llvm::None;
2515   return conversion;
2516 }
2517 
2518 /// Create a default conversion pattern that rewrites the type signature of a
2519 /// FunctionLike op. This only supports FunctionLike ops which use FunctionType
2520 /// to represent their type.
2521 namespace {
2522 struct FunctionLikeSignatureConversion : public ConversionPattern {
2523   FunctionLikeSignatureConversion(StringRef functionLikeOpName,
2524                                   MLIRContext *ctx, TypeConverter &converter)
2525       : ConversionPattern(functionLikeOpName, /*benefit=*/1, converter, ctx) {}
2526 
2527   /// Hook to implement combined matching and rewriting for FunctionLike ops.
2528   LogicalResult
2529   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
2530                   ConversionPatternRewriter &rewriter) const override {
2531     FunctionType type = mlir::impl::getFunctionType(op);
2532 
2533     // Convert the original function types.
2534     TypeConverter::SignatureConversion result(type.getNumInputs());
2535     SmallVector<Type, 1> newResults;
2536     if (failed(typeConverter->convertSignatureArgs(type.getInputs(), result)) ||
2537         failed(typeConverter->convertTypes(type.getResults(), newResults)) ||
2538         failed(rewriter.convertRegionTypes(&mlir::impl::getFunctionBody(op),
2539                                            *typeConverter, &result)))
2540       return failure();
2541 
2542     // Update the function signature in-place.
2543     auto newType = FunctionType::get(rewriter.getContext(),
2544                                      result.getConvertedTypes(), newResults);
2545 
2546     rewriter.updateRootInPlace(
2547         op, [&] { mlir::impl::setFunctionType(op, newType); });
2548 
2549     return success();
2550   }
2551 };
2552 } // end anonymous namespace
2553 
/// Add a `FunctionLikeSignatureConversion` pattern to `patterns`. The pattern
/// matches operations named `functionLikeOpName` and uses `converter` to
/// convert their signatures.
void mlir::populateFunctionLikeTypeConversionPattern(
    StringRef functionLikeOpName, OwningRewritePatternList &patterns,
    MLIRContext *ctx, TypeConverter &converter) {
  patterns.insert<FunctionLikeSignatureConversion>(functionLikeOpName, ctx,
                                                   converter);
}

/// Add the FunctionLike signature conversion pattern for `FuncOp`. This
/// forwards to the templated `populateFunctionLikeTypeConversionPattern`
/// instantiated with `FuncOp`.
void mlir::populateFuncOpTypeConversionPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx,
    TypeConverter &converter) {
  populateFunctionLikeTypeConversionPattern<FuncOp>(patterns, ctx, converter);
}
2566 
2567 //===----------------------------------------------------------------------===//
2568 // ConversionTarget
2569 //===----------------------------------------------------------------------===//
2570 
2571 /// Register a legality action for the given operation.
2572 void ConversionTarget::setOpAction(OperationName op,
2573                                    LegalizationAction action) {
2574   legalOperations[op] = {action, /*isRecursivelyLegal=*/false, llvm::None};
2575 }
2576 
2577 /// Register a legality action for the given dialects.
2578 void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
2579                                         LegalizationAction action) {
2580   for (StringRef dialect : dialectNames)
2581     legalDialects[dialect] = action;
2582 }
2583 
2584 /// Get the legality action for the given operation.
2585 auto ConversionTarget::getOpAction(OperationName op) const
2586     -> Optional<LegalizationAction> {
2587   Optional<LegalizationInfo> info = getOpInfo(op);
2588   return info ? info->action : Optional<LegalizationAction>();
2589 }
2590 
2591 /// If the given operation instance is legal on this target, a structure
2592 /// containing legality information is returned. If the operation is not legal,
2593 /// None is returned.
2594 auto ConversionTarget::isLegal(Operation *op) const
2595     -> Optional<LegalOpDetails> {
2596   Optional<LegalizationInfo> info = getOpInfo(op->getName());
2597   if (!info)
2598     return llvm::None;
2599 
2600   // Returns true if this operation instance is known to be legal.
2601   auto isOpLegal = [&] {
2602     // Handle dynamic legality either with the provided legality function, or
2603     // the default hook on the derived instance.
2604     if (info->action == LegalizationAction::Dynamic)
2605       return info->legalityFn ? (*info->legalityFn)(op)
2606                               : isDynamicallyLegal(op);
2607 
2608     // Otherwise, the operation is only legal if it was marked 'Legal'.
2609     return info->action == LegalizationAction::Legal;
2610   };
2611   if (!isOpLegal())
2612     return llvm::None;
2613 
2614   // This operation is legal, compute any additional legality information.
2615   LegalOpDetails legalityDetails;
2616   if (info->isRecursivelyLegal) {
2617     auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
2618     if (legalityFnIt != opRecursiveLegalityFns.end())
2619       legalityDetails.isRecursivelyLegal = legalityFnIt->second(op);
2620     else
2621       legalityDetails.isRecursivelyLegal = true;
2622   }
2623   return legalityDetails;
2624 }
2625 
2626 /// Set the dynamic legality callback for the given operation.
2627 void ConversionTarget::setLegalityCallback(
2628     OperationName name, const DynamicLegalityCallbackFn &callback) {
2629   assert(callback && "expected valid legality callback");
2630   auto infoIt = legalOperations.find(name);
2631   assert(infoIt != legalOperations.end() &&
2632          infoIt->second.action == LegalizationAction::Dynamic &&
2633          "expected operation to already be marked as dynamically legal");
2634   infoIt->second.legalityFn = callback;
2635 }
2636 
2637 /// Set the recursive legality callback for the given operation and mark the
2638 /// operation as recursively legal.
2639 void ConversionTarget::markOpRecursivelyLegal(
2640     OperationName name, const DynamicLegalityCallbackFn &callback) {
2641   auto infoIt = legalOperations.find(name);
2642   assert(infoIt != legalOperations.end() &&
2643          infoIt->second.action != LegalizationAction::Illegal &&
2644          "expected operation to already be marked as legal");
2645   infoIt->second.isRecursivelyLegal = true;
2646   if (callback)
2647     opRecursiveLegalityFns[name] = callback;
2648   else
2649     opRecursiveLegalityFns.erase(name);
2650 }
2651 
2652 /// Set the dynamic legality callback for the given dialects.
2653 void ConversionTarget::setLegalityCallback(
2654     ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
2655   assert(callback && "expected valid legality callback");
2656   for (StringRef dialect : dialects)
2657     dialectLegalityFns[dialect] = callback;
2658 }
2659 
2660 /// Get the legalization information for the given operation.
2661 auto ConversionTarget::getOpInfo(OperationName op) const
2662     -> Optional<LegalizationInfo> {
2663   // Check for info for this specific operation.
2664   auto it = legalOperations.find(op);
2665   if (it != legalOperations.end())
2666     return it->second;
2667   // Check for info for the parent dialect.
2668   auto dialectIt = legalDialects.find(op.getDialect());
2669   if (dialectIt != legalDialects.end()) {
2670     Optional<DynamicLegalityCallbackFn> callback;
2671     auto dialectFn = dialectLegalityFns.find(op.getDialect());
2672     if (dialectFn != dialectLegalityFns.end())
2673       callback = dialectFn->second;
2674     return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
2675                             callback};
2676   }
2677   // Otherwise, check if we mark unknown operations as dynamic.
2678   if (unknownOpsDynamicallyLegal)
2679     return LegalizationInfo{LegalizationAction::Dynamic,
2680                             /*isRecursivelyLegal=*/false, unknownLegalityFn};
2681   return llvm::None;
2682 }
2683 
2684 //===----------------------------------------------------------------------===//
2685 // Op Conversion Entry Points
2686 //===----------------------------------------------------------------------===//
2687 
2688 /// Apply a partial conversion on the given operations and all nested
2689 /// operations. This method converts as many operations to the target as
2690 /// possible, ignoring operations that failed to legalize. This method only
2691 /// returns failure if there ops explicitly marked as illegal.
2692 /// If an `unconvertedOps` set is provided, all operations that are found not
2693 /// to be legalizable to the given `target` are placed within that set. (Note
2694 /// that if there is an op explicitly marked as illegal, the conversion
2695 /// terminates and the `unconvertedOps` set will not necessarily be complete.)
2696 LogicalResult
2697 mlir::applyPartialConversion(ArrayRef<Operation *> ops,
2698                              ConversionTarget &target,
2699                              const FrozenRewritePatternList &patterns,
2700                              DenseSet<Operation *> *unconvertedOps) {
2701   OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
2702                                  unconvertedOps);
2703   return opConverter.convertOperations(ops);
2704 }
2705 LogicalResult
2706 mlir::applyPartialConversion(Operation *op, ConversionTarget &target,
2707                              const FrozenRewritePatternList &patterns,
2708                              DenseSet<Operation *> *unconvertedOps) {
2709   return applyPartialConversion(llvm::makeArrayRef(op), target, patterns,
2710                                 unconvertedOps);
2711 }
2712 
2713 /// Apply a complete conversion on the given operations, and all nested
2714 /// operations. This method will return failure if the conversion of any
2715 /// operation fails.
2716 LogicalResult
2717 mlir::applyFullConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
2718                           const FrozenRewritePatternList &patterns) {
2719   OperationConverter opConverter(target, patterns, OpConversionMode::Full);
2720   return opConverter.convertOperations(ops);
2721 }
2722 LogicalResult
2723 mlir::applyFullConversion(Operation *op, ConversionTarget &target,
2724                           const FrozenRewritePatternList &patterns) {
2725   return applyFullConversion(llvm::makeArrayRef(op), target, patterns);
2726 }
2727 
2728 /// Apply an analysis conversion on the given operations, and all nested
2729 /// operations. This method analyzes which operations would be successfully
2730 /// converted to the target if a conversion was applied. All operations that
2731 /// were found to be legalizable to the given 'target' are placed within the
2732 /// provided 'convertedOps' set; note that no actual rewrites are applied to the
2733 /// operations on success and only pre-existing operations are added to the set.
2734 LogicalResult
2735 mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
2736                               ConversionTarget &target,
2737                               const FrozenRewritePatternList &patterns,
2738                               DenseSet<Operation *> &convertedOps) {
2739   OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
2740                                  &convertedOps);
2741   return opConverter.convertOperations(ops);
2742 }
2743 LogicalResult
2744 mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
2745                               const FrozenRewritePatternList &patterns,
2746                               DenseSet<Operation *> &convertedOps) {
2747   return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns,
2748                                  convertedOps);
2749 }
2750