1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/DialectConversion.h"
10 #include "mlir/IR/Block.h"
11 #include "mlir/IR/BlockAndValueMapping.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/FunctionSupport.h"
15 #include "mlir/Rewrite/PatternApplicator.h"
16 #include "mlir/Transforms/Utils.h"
17 #include "llvm/ADT/SetVector.h"
18 #include "llvm/ADT/SmallPtrSet.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/Support/SaveAndRestore.h"
22 #include "llvm/Support/ScopedPrinter.h"
23 
24 using namespace mlir;
25 using namespace mlir::detail;
26 
27 #define DEBUG_TYPE "dialect-conversion"
28 
29 /// Recursively collect all of the operations to convert from within 'region'.
30 /// If 'target' is nonnull, operations that are recursively legal have their
31 /// regions pre-filtered to avoid considering them for legalization.
32 static LogicalResult
33 computeConversionSet(iterator_range<Region::iterator> region,
34                      Location regionLoc, std::vector<Operation *> &toConvert,
35                      ConversionTarget *target = nullptr) {
36   if (llvm::empty(region))
37     return success();
38 
39   // Traverse starting from the entry block.
40   SmallVector<Block *, 16> worklist(1, &*region.begin());
41   DenseSet<Block *> visitedBlocks;
42   visitedBlocks.insert(worklist.front());
43   while (!worklist.empty()) {
44     Block *block = worklist.pop_back_val();
45 
46     // Compute the conversion set of each of the nested operations.
47     for (Operation &op : *block) {
48       toConvert.emplace_back(&op);
49 
50       // Don't check this operation's children for conversion if the operation
51       // is recursively legal.
52       auto legalityInfo = target ? target->isLegal(&op)
53                                  : Optional<ConversionTarget::LegalOpDetails>();
54       if (legalityInfo && legalityInfo->isRecursivelyLegal)
55         continue;
56       for (auto &region : op.getRegions()) {
57         if (failed(computeConversionSet(region.getBlocks(), region.getLoc(),
58                                         toConvert, target)))
59           return failure();
60       }
61     }
62 
63     // Recurse to children that haven't been visited.
64     for (Block *succ : block->getSuccessors())
65       if (visitedBlocks.insert(succ).second)
66         worklist.push_back(succ);
67   }
68 
69   // Check that all blocks in the region were visited.
70   if (llvm::any_of(llvm::drop_begin(region, 1),
71                    [&](Block &block) { return !visitedBlocks.count(&block); }))
72     return emitError(regionLoc, "unreachable blocks were not converted");
73   return success();
74 }
75 
76 /// A utility function to log a successful result for the given reason.
77 template <typename... Args>
78 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
79   LLVM_DEBUG({
80     os.unindent();
81     os.startLine() << "} -> SUCCESS";
82     if (!fmt.empty())
83       os.getOStream() << " : "
84                       << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
85     os.getOStream() << "\n";
86   });
87 }
88 
89 /// A utility function to log a failure result for the given reason.
90 template <typename... Args>
91 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
92   LLVM_DEBUG({
93     os.unindent();
94     os.startLine() << "} -> FAILURE : "
95                    << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
96                    << "\n";
97   });
98 }
99 
100 //===----------------------------------------------------------------------===//
101 // ConversionValueMapping
102 //===----------------------------------------------------------------------===//
103 
104 namespace {
105 /// This class wraps a BlockAndValueMapping to provide recursive lookup
106 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
107 struct ConversionValueMapping {
108   /// Lookup a mapped value within the map. If a mapping for the provided value
109   /// does not exist then return the provided value. If `desiredType` is
110   /// non-null, returns the most recently mapped value with that type. If an
111   /// operand of that type does not exist, defaults to normal behavior.
112   Value lookupOrDefault(Value from, Type desiredType = nullptr) const;
113 
114   /// Lookup a mapped value within the map, or return null if a mapping does not
115   /// exist. If a mapping exists, this follows the same behavior of
116   /// `lookupOrDefault`.
117   Value lookupOrNull(Value from) const;
118 
119   /// Map a value to the one provided.
120   void map(Value oldVal, Value newVal) { mapping.map(oldVal, newVal); }
121 
122   /// Drop the last mapping for the given value.
123   void erase(Value value) { mapping.erase(value); }
124 
125 private:
126   /// Current value mappings.
127   BlockAndValueMapping mapping;
128 };
129 } // end anonymous namespace
130 
131 Value ConversionValueMapping::lookupOrDefault(Value from,
132                                               Type desiredType) const {
133   // If there was no desired type, simply find the leaf value.
134   if (!desiredType) {
135     // If this value had a valid mapping, unmap that value as well in the case
136     // that it was also replaced.
137     while (auto mappedValue = mapping.lookupOrNull(from))
138       from = mappedValue;
139     return from;
140   }
141 
142   // Otherwise, try to find the deepest value that has the desired type.
143   Value desiredValue;
144   do {
145     if (from.getType() == desiredType)
146       desiredValue = from;
147 
148     Value mappedValue = mapping.lookupOrNull(from);
149     if (!mappedValue)
150       break;
151     from = mappedValue;
152   } while (true);
153 
154   // If the desired value was found use it, otherwise default to the leaf value.
155   return desiredValue ? desiredValue : from;
156 }
157 
158 Value ConversionValueMapping::lookupOrNull(Value from) const {
159   Value result = lookupOrDefault(from);
160   return result == from ? nullptr : result;
161 }
162 
163 //===----------------------------------------------------------------------===//
164 // ArgConverter
165 //===----------------------------------------------------------------------===//
166 namespace {
167 /// This class provides a simple interface for converting the types of block
168 /// arguments. This is done by creating a new block that contains the new legal
169 /// types and extracting the block that contains the old illegal types to allow
170 /// for undoing pending rewrites in the case of failure.
171 struct ArgConverter {
172   ArgConverter(PatternRewriter &rewriter) : rewriter(rewriter) {}
173 
174   /// This structure contains the information pertaining to an argument that has
175   /// been converted.
176   struct ConvertedArgInfo {
177     ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
178                      Value castValue = nullptr)
179         : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
180 
181     /// The start index of in the new argument list that contains arguments that
182     /// replace the original.
183     unsigned newArgIdx;
184 
185     /// The number of arguments that replaced the original argument.
186     unsigned newArgSize;
187 
188     /// The cast value that was created to cast from the new arguments to the
189     /// old. This only used if 'newArgSize' > 1.
190     Value castValue;
191   };
192 
193   /// This structure contains information pertaining to a block that has had its
194   /// signature converted.
195   struct ConvertedBlockInfo {
196     ConvertedBlockInfo(Block *origBlock, TypeConverter &converter)
197         : origBlock(origBlock), converter(&converter) {}
198 
199     /// The original block that was requested to have its signature converted.
200     Block *origBlock;
201 
202     /// The conversion information for each of the arguments. The information is
203     /// None if the argument was dropped during conversion.
204     SmallVector<Optional<ConvertedArgInfo>, 1> argInfo;
205 
206     /// The type converter used to convert the arguments.
207     TypeConverter *converter;
208   };
209 
210   /// Return if the signature of the given block has already been converted.
211   bool hasBeenConverted(Block *block) const {
212     return conversionInfo.count(block) || convertedBlocks.count(block);
213   }
214 
215   /// Set the type converter to use for the given region.
216   void setConverter(Region *region, TypeConverter *typeConverter) {
217     assert(typeConverter && "expected valid type converter");
218     regionToConverter[region] = typeConverter;
219   }
220 
221   /// Return the type converter to use for the given region, or null if there
222   /// isn't one.
223   TypeConverter *getConverter(Region *region) {
224     return regionToConverter.lookup(region);
225   }
226 
227   //===--------------------------------------------------------------------===//
228   // Rewrite Application
229   //===--------------------------------------------------------------------===//
230 
231   /// Erase any rewrites registered for the blocks within the given operation
232   /// which is about to be removed. This merely drops the rewrites without
233   /// undoing them.
234   void notifyOpRemoved(Operation *op);
235 
236   /// Cleanup and undo any generated conversions for the arguments of block.
237   /// This method replaces the new block with the original, reverting the IR to
238   /// its original state.
239   void discardRewrites(Block *block);
240 
241   /// Fully replace uses of the old arguments with the new.
242   void applyRewrites(ConversionValueMapping &mapping);
243 
244   /// Materialize any necessary conversions for converted arguments that have
245   /// live users, using the provided `findLiveUser` to search for a user that
246   /// survives the conversion process.
247   LogicalResult
248   materializeLiveConversions(ConversionValueMapping &mapping,
249                              OpBuilder &builder,
250                              function_ref<Operation *(Value)> findLiveUser);
251 
252   //===--------------------------------------------------------------------===//
253   // Conversion
254   //===--------------------------------------------------------------------===//
255 
256   /// Attempt to convert the signature of the given block, if successful a new
257   /// block is returned containing the new arguments. Returns `block` if it did
258   /// not require conversion.
259   FailureOr<Block *>
260   convertSignature(Block *block, TypeConverter &converter,
261                    ConversionValueMapping &mapping,
262                    SmallVectorImpl<BlockArgument> &argReplacements);
263 
264   /// Apply the given signature conversion on the given block. The new block
265   /// containing the updated signature is returned. If no conversions were
266   /// necessary, e.g. if the block has no arguments, `block` is returned.
267   /// `converter` is used to generate any necessary cast operations that
268   /// translate between the origin argument types and those specified in the
269   /// signature conversion.
270   Block *applySignatureConversion(
271       Block *block, TypeConverter &converter,
272       TypeConverter::SignatureConversion &signatureConversion,
273       ConversionValueMapping &mapping,
274       SmallVectorImpl<BlockArgument> &argReplacements);
275 
276   /// Insert a new conversion into the cache.
277   void insertConversion(Block *newBlock, ConvertedBlockInfo &&info);
278 
279   /// A collection of blocks that have had their arguments converted. This is a
280   /// map from the new replacement block, back to the original block.
281   llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;
282 
283   /// The set of original blocks that were converted.
284   DenseSet<Block *> convertedBlocks;
285 
286   /// A mapping from valid regions, to those containing the original blocks of a
287   /// conversion.
288   DenseMap<Region *, std::unique_ptr<Region>> regionMapping;
289 
290   /// A mapping of regions to type converters that should be used when
291   /// converting the arguments of blocks within that region.
292   DenseMap<Region *, TypeConverter *> regionToConverter;
293 
294   /// The pattern rewriter to use when materializing conversions.
295   PatternRewriter &rewriter;
296 };
297 } // end anonymous namespace
298 
299 //===----------------------------------------------------------------------===//
300 // Rewrite Application
301 
302 void ArgConverter::notifyOpRemoved(Operation *op) {
303   if (conversionInfo.empty())
304     return;
305 
306   for (Region &region : op->getRegions()) {
307     for (Block &block : region) {
308       // Drop any rewrites from within.
309       for (Operation &nestedOp : block)
310         if (nestedOp.getNumRegions())
311           notifyOpRemoved(&nestedOp);
312 
313       // Check if this block was converted.
314       auto it = conversionInfo.find(&block);
315       if (it == conversionInfo.end())
316         continue;
317 
318       // Drop all uses of the original arguments and delete the original block.
319       Block *origBlock = it->second.origBlock;
320       for (BlockArgument arg : origBlock->getArguments())
321         arg.dropAllUses();
322       conversionInfo.erase(it);
323     }
324   }
325 }
326 
327 void ArgConverter::discardRewrites(Block *block) {
328   auto it = conversionInfo.find(block);
329   if (it == conversionInfo.end())
330     return;
331   Block *origBlock = it->second.origBlock;
332 
333   // Drop all uses of the new block arguments and replace uses of the new block.
334   for (int i = block->getNumArguments() - 1; i >= 0; --i)
335     block->getArgument(i).dropAllUses();
336   block->replaceAllUsesWith(origBlock);
337 
338   // Move the operations back the original block and the delete the new block.
339   origBlock->getOperations().splice(origBlock->end(), block->getOperations());
340   origBlock->moveBefore(block);
341   block->erase();
342 
343   convertedBlocks.erase(origBlock);
344   conversionInfo.erase(it);
345 }
346 
347 void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
348   for (auto &info : conversionInfo) {
349     ConvertedBlockInfo &blockInfo = info.second;
350     Block *origBlock = blockInfo.origBlock;
351 
352     // Process the remapping for each of the original arguments.
353     for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
354       Optional<ConvertedArgInfo> &argInfo = blockInfo.argInfo[i];
355       BlockArgument origArg = origBlock->getArgument(i);
356 
357       // Handle the case of a 1->0 value mapping.
358       if (!argInfo) {
359         if (Value newArg = mapping.lookupOrNull(origArg))
360           origArg.replaceAllUsesWith(newArg);
361         continue;
362       }
363 
364       // Otherwise this is a 1->1+ value mapping.
365       Value castValue = argInfo->castValue;
366       assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
367 
368       // If the argument is still used, replace it with the generated cast.
369       if (!origArg.use_empty())
370         origArg.replaceAllUsesWith(mapping.lookupOrDefault(castValue));
371     }
372   }
373 }
374 
375 LogicalResult ArgConverter::materializeLiveConversions(
376     ConversionValueMapping &mapping, OpBuilder &builder,
377     function_ref<Operation *(Value)> findLiveUser) {
378   for (auto &info : conversionInfo) {
379     Block *newBlock = info.first;
380     ConvertedBlockInfo &blockInfo = info.second;
381     Block *origBlock = blockInfo.origBlock;
382 
383     // Process the remapping for each of the original arguments.
384     for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
385       // FIXME: We should run the below checks even if the type conversion was
386       // 1->N, but a lot of existing lowering rely on the block argument being
387       // blindly replaced. Those usages should be updated, and this if should be
388       // removed.
389       if (blockInfo.argInfo[i])
390         continue;
391 
392       // If the type of this argument changed and the argument is still live, we
393       // need to materialize a conversion.
394       BlockArgument origArg = origBlock->getArgument(i);
395       auto argReplacementValue = mapping.lookupOrDefault(origArg);
396       bool isDroppedArg = argReplacementValue == origArg;
397       if (argReplacementValue.getType() == origArg.getType() && !isDroppedArg)
398         continue;
399       Operation *liveUser = findLiveUser(origArg);
400       if (!liveUser)
401         continue;
402 
403       if (OpResult result = argReplacementValue.dyn_cast<OpResult>())
404         rewriter.setInsertionPointAfter(result.getOwner());
405       else
406         rewriter.setInsertionPointToStart(newBlock);
407       Value newArg = blockInfo.converter->materializeSourceConversion(
408           rewriter, origArg.getLoc(), origArg.getType(),
409           isDroppedArg ? ValueRange() : ValueRange(argReplacementValue));
410       if (!newArg) {
411         InFlightDiagnostic diag =
412             emitError(origArg.getLoc())
413             << "failed to materialize conversion for block argument #" << i
414             << " that remained live after conversion, type was "
415             << origArg.getType();
416         if (!isDroppedArg)
417           diag << ", with target type " << argReplacementValue.getType();
418         diag.attachNote(liveUser->getLoc())
419             << "see existing live user here: " << *liveUser;
420         return failure();
421       }
422       mapping.map(origArg, newArg);
423     }
424   }
425   return success();
426 }
427 
428 //===----------------------------------------------------------------------===//
429 // Conversion
430 
431 FailureOr<Block *> ArgConverter::convertSignature(
432     Block *block, TypeConverter &converter, ConversionValueMapping &mapping,
433     SmallVectorImpl<BlockArgument> &argReplacements) {
434   // Check if the block was already converted. If the block is detached,
435   // conservatively assume it is going to be deleted.
436   if (hasBeenConverted(block) || !block->getParent())
437     return block;
438 
439   // Try to convert the signature for the block with the provided converter.
440   if (auto conversion = converter.convertBlockSignature(block))
441     return applySignatureConversion(block, converter, *conversion, mapping,
442                                     argReplacements);
443   return failure();
444 }
445 
446 Block *ArgConverter::applySignatureConversion(
447     Block *block, TypeConverter &converter,
448     TypeConverter::SignatureConversion &signatureConversion,
449     ConversionValueMapping &mapping,
450     SmallVectorImpl<BlockArgument> &argReplacements) {
451   // If no arguments are being changed or added, there is nothing to do.
452   unsigned origArgCount = block->getNumArguments();
453   auto convertedTypes = signatureConversion.getConvertedTypes();
454   if (origArgCount == 0 && convertedTypes.empty())
455     return block;
456 
457   // Split the block at the beginning to get a new block to use for the updated
458   // signature.
459   Block *newBlock = block->splitBlock(block->begin());
460   block->replaceAllUsesWith(newBlock);
461 
462   SmallVector<Value, 4> newArgRange(newBlock->addArguments(convertedTypes));
463   ArrayRef<Value> newArgs(newArgRange);
464 
465   // Remap each of the original arguments as determined by the signature
466   // conversion.
467   ConvertedBlockInfo info(block, converter);
468   info.argInfo.resize(origArgCount);
469 
470   OpBuilder::InsertionGuard guard(rewriter);
471   rewriter.setInsertionPointToStart(newBlock);
472   for (unsigned i = 0; i != origArgCount; ++i) {
473     auto inputMap = signatureConversion.getInputMapping(i);
474     if (!inputMap)
475       continue;
476     BlockArgument origArg = block->getArgument(i);
477 
478     // If inputMap->replacementValue is not nullptr, then the argument is
479     // dropped and a replacement value is provided to be the remappedValue.
480     if (inputMap->replacementValue) {
481       assert(inputMap->size == 0 &&
482              "invalid to provide a replacement value when the argument isn't "
483              "dropped");
484       mapping.map(origArg, inputMap->replacementValue);
485       argReplacements.push_back(origArg);
486       continue;
487     }
488 
489     // Otherwise, this is a 1->1+ mapping. Call into the provided type converter
490     // to pack the new values. For 1->1 mappings, if there is no materialization
491     // provided, use the argument directly instead.
492     auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
493     Value newArg = converter.materializeArgumentConversion(
494         rewriter, origArg.getLoc(), origArg.getType(), replArgs);
495     if (!newArg) {
496       assert(replArgs.size() == 1 &&
497              "couldn't materialize the result of 1->N conversion");
498       newArg = replArgs.front();
499     }
500     mapping.map(origArg, newArg);
501     argReplacements.push_back(origArg);
502     info.argInfo[i] =
503         ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
504   }
505 
506   // Remove the original block from the region and return the new one.
507   insertConversion(newBlock, std::move(info));
508   return newBlock;
509 }
510 
511 void ArgConverter::insertConversion(Block *newBlock,
512                                     ConvertedBlockInfo &&info) {
513   // Get a region to insert the old block.
514   Region *region = newBlock->getParent();
515   std::unique_ptr<Region> &mappedRegion = regionMapping[region];
516   if (!mappedRegion)
517     mappedRegion = std::make_unique<Region>(region->getParentOp());
518 
519   // Move the original block to the mapped region and emplace the conversion.
520   mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(),
521                                    info.origBlock->getIterator());
522   convertedBlocks.insert(info.origBlock);
523   conversionInfo.insert({newBlock, std::move(info)});
524 }
525 
526 //===----------------------------------------------------------------------===//
527 // Rewriter and Translation State
528 //===----------------------------------------------------------------------===//
529 namespace {
/// A snapshot of the sizes of the rewriter's bookkeeping lists, taken so that
/// the rewriter can later be reset to this point and the rewrites performed
/// since then undone.
struct RewriterState {
  RewriterState(unsigned numCreatedOps, unsigned numReplacements,
                unsigned numArgReplacements, unsigned numBlockActions,
                unsigned numIgnoredOperations, unsigned numRootUpdates)
      : numCreatedOps(numCreatedOps), numReplacements(numReplacements),
        numArgReplacements(numArgReplacements),
        numBlockActions(numBlockActions),
        numIgnoredOperations(numIgnoredOperations),
        numRootUpdates(numRootUpdates) {}

  unsigned numCreatedOps;        ///< Operations created so far.
  unsigned numReplacements;      ///< Operation replacements queued so far.
  unsigned numArgReplacements;   ///< Block argument replacements queued.
  unsigned numBlockActions;      ///< Block actions performed so far.
  unsigned numIgnoredOperations; ///< Operations marked as ignored so far.
  unsigned numRootUpdates;       ///< Operations updated in place so far.
};
560 
561 /// The state of an operation that was updated by a pattern in-place. This
562 /// contains all of the necessary information to reconstruct an operation that
563 /// was updated in place.
564 class OperationTransactionState {
565 public:
566   OperationTransactionState() = default;
567   OperationTransactionState(Operation *op)
568       : op(op), loc(op->getLoc()), attrs(op->getAttrDictionary()),
569         operands(op->operand_begin(), op->operand_end()),
570         successors(op->successor_begin(), op->successor_end()) {}
571 
572   /// Discard the transaction state and reset the state of the original
573   /// operation.
574   void resetOperation() const {
575     op->setLoc(loc);
576     op->setAttrs(attrs);
577     op->setOperands(operands);
578     for (auto it : llvm::enumerate(successors))
579       op->setSuccessor(it.value(), it.index());
580   }
581 
582   /// Return the original operation of this state.
583   Operation *getOperation() const { return op; }
584 
585 private:
586   Operation *op;
587   LocationAttr loc;
588   DictionaryAttr attrs;
589   SmallVector<Value, 8> operands;
590   SmallVector<Block *, 2> successors;
591 };
592 
593 /// This class represents one requested operation replacement via 'replaceOp' or
594 /// 'eraseOp`.
595 struct OpReplacement {
596   OpReplacement() = default;
597   OpReplacement(TypeConverter *converter) : converter(converter) {}
598 
599   /// An optional type converter that can be used to materialize conversions
600   /// between the new and old values if necessary.
601   TypeConverter *converter = nullptr;
602 };
603 
604 /// The kind of the block action performed during the rewrite.  Actions can be
605 /// undone if the conversion fails.
606 enum class BlockActionKind {
607   Create,
608   Erase,
609   Merge,
610   Move,
611   Split,
612   TypeConversion
613 };
614 
615 /// Original position of the given block in its parent region. During undo
616 /// actions, the block needs to be placed after `insertAfterBlock`.
617 struct BlockPosition {
618   Region *region;
619   Block *insertAfterBlock;
620 };
621 
622 /// Information needed to undo the merge actions.
623 /// - the source block, and
624 /// - the Operation that was the last operation in the dest block before the
625 ///   merge (could be null if the dest block was empty).
626 struct MergeInfo {
627   Block *sourceBlock;
628   Operation *destBlockLastInst;
629 };
630 
631 /// The storage class for an undoable block action (one of BlockActionKind),
632 /// contains the information necessary to undo this action.
633 struct BlockAction {
634   static BlockAction getCreate(Block *block) {
635     return {BlockActionKind::Create, block, {}};
636   }
637   static BlockAction getErase(Block *block, BlockPosition originalPosition) {
638     return {BlockActionKind::Erase, block, {originalPosition}};
639   }
640   static BlockAction getMerge(Block *block, Block *sourceBlock) {
641     BlockAction action{BlockActionKind::Merge, block, {}};
642     action.mergeInfo = {sourceBlock, block->empty() ? nullptr : &block->back()};
643     return action;
644   }
645   static BlockAction getMove(Block *block, BlockPosition originalPosition) {
646     return {BlockActionKind::Move, block, {originalPosition}};
647   }
648   static BlockAction getSplit(Block *block, Block *originalBlock) {
649     BlockAction action{BlockActionKind::Split, block, {}};
650     action.originalBlock = originalBlock;
651     return action;
652   }
653   static BlockAction getTypeConversion(Block *block) {
654     return BlockAction{BlockActionKind::TypeConversion, block, {}};
655   }
656 
657   // The action kind.
658   BlockActionKind kind;
659 
660   // A pointer to the block that was created by the action.
661   Block *block;
662 
663   union {
664     // In use if kind == BlockActionKind::Move or BlockActionKind::Erase, and
665     // contains a pointer to the region that originally contained the block as
666     // well as the position of the block in that region.
667     BlockPosition originalPosition;
668     // In use if kind == BlockActionKind::Split and contains a pointer to the
669     // block that was split into two parts.
670     Block *originalBlock;
671     // In use if kind == BlockActionKind::Merge, and contains the information
672     // needed to undo the merge.
673     MergeInfo mergeInfo;
674   };
675 };
676 } // end anonymous namespace
677 
678 //===----------------------------------------------------------------------===//
679 // ConversionPatternRewriterImpl
680 //===----------------------------------------------------------------------===//
681 namespace mlir {
682 namespace detail {
683 struct ConversionPatternRewriterImpl {
684   ConversionPatternRewriterImpl(PatternRewriter &rewriter)
685       : argConverter(rewriter) {}
686 
687   /// Cleanup and destroy any generated rewrite operations. This method is
688   /// invoked when the conversion process fails.
689   void discardRewrites();
690 
691   /// Apply all requested operation rewrites. This method is invoked when the
692   /// conversion process succeeds.
693   void applyRewrites();
694 
695   //===--------------------------------------------------------------------===//
696   // State Management
697   //===--------------------------------------------------------------------===//
698 
699   /// Return the current state of the rewriter.
700   RewriterState getCurrentState();
701 
702   /// Reset the state of the rewriter to a previously saved point.
703   void resetState(RewriterState state);
704 
705   /// Erase any blocks that were unlinked from their regions and stored in block
706   /// actions.
707   void eraseDanglingBlocks();
708 
709   /// Undo the block actions (motions, splits) one by one in reverse order until
710   /// "numActionsToKeep" actions remains.
711   void undoBlockActions(unsigned numActionsToKeep = 0);
712 
713   /// Remap the given operands to those with potentially different types. The
714   /// provided type converter is used to ensure that the remapped types are
715   /// legal. Returns success if the operands could be remapped, failure
716   /// otherwise.
717   LogicalResult remapValues(Location loc, PatternRewriter &rewriter,
718                             TypeConverter *converter,
719                             Operation::operand_range operands,
720                             SmallVectorImpl<Value> &remapped);
721 
722   /// Returns true if the given operation is ignored, and does not need to be
723   /// converted.
724   bool isOpIgnored(Operation *op) const;
725 
726   /// Recursively marks the nested operations under 'op' as ignored. This
727   /// removes them from being considered for legalization.
728   void markNestedOpsIgnored(Operation *op);
729 
730   //===--------------------------------------------------------------------===//
731   // Type Conversion
732   //===--------------------------------------------------------------------===//
733 
734   /// Convert the signature of the given block.
735   FailureOr<Block *> convertBlockSignature(
736       Block *block, TypeConverter &converter,
737       TypeConverter::SignatureConversion *conversion = nullptr);
738 
739   /// Apply a signature conversion on the given region.
740   Block *
741   applySignatureConversion(Region *region,
742                            TypeConverter::SignatureConversion &conversion);
743 
744   /// Convert the types of block arguments within the given region.
745   FailureOr<Block *>
746   convertRegionTypes(Region *region, TypeConverter &converter,
747                      TypeConverter::SignatureConversion *entryConversion);
748 
749   //===--------------------------------------------------------------------===//
750   // Rewriter Notification Hooks
751   //===--------------------------------------------------------------------===//
752 
753   /// PatternRewriter hook for replacing the results of an operation.
754   void notifyOpReplaced(Operation *op, ValueRange newValues);
755 
756   /// Notifies that a block is about to be erased.
757   void notifyBlockIsBeingErased(Block *block);
758 
759   /// Notifies that a block was created.
760   void notifyCreatedBlock(Block *block);
761 
762   /// Notifies that a block was split.
763   void notifySplitBlock(Block *block, Block *continuation);
764 
765   /// Notifies that `block` is being merged with `srcBlock`.
766   void notifyBlocksBeingMerged(Block *block, Block *srcBlock);
767 
768   /// Notifies that the blocks of a region are about to be moved.
769   void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent,
770                                         Region::iterator before);
771 
772   /// Notifies that the blocks of a region were cloned into another.
773   void notifyRegionWasClonedBefore(iterator_range<Region::iterator> &blocks,
774                                    Location origRegionLoc);
775 
776   /// Notifies that a pattern match failed for the given reason.
777   LogicalResult
778   notifyMatchFailure(Location loc,
779                      function_ref<void(Diagnostic &)> reasonCallback);
780 
781   //===--------------------------------------------------------------------===//
782   // State
783   //===--------------------------------------------------------------------===//
784 
785   // Mapping between replaced values that differ in type. This happens when
786   // replacing a value with one of a different type.
787   ConversionValueMapping mapping;
788 
789   /// Utility used to convert block arguments.
790   ArgConverter argConverter;
791 
792   /// Ordered vector of all of the newly created operations during conversion.
793   std::vector<Operation *> createdOps;
794 
795   /// Ordered map of requested operation replacements.
796   llvm::MapVector<Operation *, OpReplacement> replacements;
797 
798   /// Ordered vector of any requested block argument replacements.
799   SmallVector<BlockArgument, 4> argReplacements;
800 
801   /// Ordered list of block operations (creations, splits, motions).
802   SmallVector<BlockAction, 4> blockActions;
803 
804   /// A set of operations that should no longer be considered for legalization,
805   /// but were not directly replace/erased/etc. by a pattern. These are
806   /// generally child operations of other operations who were
807   /// replaced/erased/etc. This is not meant to be an exhaustive list of all
808   /// operations, but the minimal set that can be used to detect if a given
809   /// operation should be `ignored`. For example, we may add the operations that
810   /// define non-empty regions to the set, but not any of the others. This
811   /// simplifies the amount of memory needed as we can query if the parent
812   /// operation was ignored.
813   llvm::SetVector<Operation *> ignoredOps;
814 
815   /// A transaction state for each of operations that were updated in-place.
816   SmallVector<OperationTransactionState, 4> rootUpdates;
817 
818   /// A vector of indices into `replacements` of operations that were replaced
819   /// with values with different result types than the original operation, e.g.
820   /// 1->N conversion of some kind.
821   SmallVector<unsigned, 4> operationsWithChangedResults;
822 
823   /// A default type converter, used when block conversions do not have one
824   /// explicitly provided.
825   TypeConverter defaultTypeConverter;
826 
827   /// The current conversion pattern that is being rewritten, or nullptr if
828   /// called from outside of a conversion pattern rewrite.
829   const ConversionPattern *currentConversionPattern = nullptr;
830 
831 #ifndef NDEBUG
832   /// A set of operations that have pending updates. This tracking isn't
833   /// strictly necessary, and is thus only active during debug builds for extra
834   /// verification.
835   SmallPtrSet<Operation *, 1> pendingRootUpdates;
836 
837   /// A logger used to emit diagnostics during the conversion process.
838   llvm::ScopedPrinter logger{llvm::dbgs()};
839 #endif
840 };
841 } // end namespace detail
842 } // end namespace mlir
843 
844 /// Detach any operations nested in the given operation from their parent
845 /// blocks, and erase the given operation. This can be used when the nested
846 /// operations are scheduled for erasure themselves, so deleting the regions of
847 /// the given operation together with their content would result in double-free.
848 /// This happens, for example, when rolling back op creation in the reverse
849 /// order and if the nested ops were created before the parent op. This function
850 /// does not need to collect nested ops recursively because it is expected to
851 /// also be called for each nested op when it is about to be deleted.
852 static void detachNestedAndErase(Operation *op) {
853   for (Region &region : op->getRegions()) {
854     for (Block &block : region.getBlocks()) {
855       while (!block.getOperations().empty())
856         block.getOperations().remove(block.getOperations().begin());
857       block.dropAllDefinedValueUses();
858     }
859   }
860   op->dropAllUses();
861   op->erase();
862 }
863 
864 void ConversionPatternRewriterImpl::discardRewrites() {
865   // Reset any operations that were updated in place.
866   for (auto &state : rootUpdates)
867     state.resetOperation();
868 
869   undoBlockActions();
870 
871   // Remove any newly created ops.
872   for (auto *op : llvm::reverse(createdOps))
873     detachNestedAndErase(op);
874 }
875 
/// Commit every change recorded during conversion, in this order:
///   1. Redirect uses of replaced operation results to their mapped values.
///   2. Redirect uses of replaced block arguments.
///   3. Erase the replaced operations.
///   4. Let the argument converter finalize its rewrites.
///   5. Free blocks that were erased.
void ConversionPatternRewriterImpl::applyRewrites() {
  // Apply all of the replacements requested during conversion.
  for (auto &repl : replacements) {
    // Results that notifyOpReplaced never mapped (e.g. results of erased ops)
    // presumably come back null from `lookupOrNull` and are left untouched.
    for (OpResult result : repl.first->getResults())
      if (Value newValue = mapping.lookupOrNull(result))
        result.replaceAllUsesWith(newValue);

    // If this operation defines any regions, drop any pending argument
    // rewrites.
    if (repl.first->getNumRegions())
      argConverter.notifyOpRemoved(repl.first);
  }

  // Apply all of the requested argument replacements.
  for (BlockArgument arg : argReplacements) {
    Value repl = mapping.lookupOrDefault(arg);
    if (repl.isa<BlockArgument>()) {
      arg.replaceAllUsesWith(repl);
      continue;
    }

    // If the replacement value is an operation, we check to make sure that we
    // don't replace uses that are within the parent operation of the
    // replacement value.
    // Concretely, a use is replaced only if it is in a different block from
    // the defining op, or comes after that op in the same block. Uses at or
    // before the defining op (including the op itself) keep the argument, so
    // the replacement is never used before it is defined.
    Operation *replOp = repl.cast<OpResult>().getOwner();
    Block *replBlock = replOp->getBlock();
    arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
      Operation *user = operand.getOwner();
      return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
    });
  }

  // In a second pass, erase all of the replaced operations in reverse. This
  // allows processing nested operations before their parent region is
  // destroyed.
  for (auto &repl : llvm::reverse(replacements))
    repl.first->erase();

  argConverter.applyRewrites(mapping);

  // Now that the ops have been erased, also erase dangling blocks.
  eraseDanglingBlocks();
}
919 
920 //===----------------------------------------------------------------------===//
921 // State Management
922 
923 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
924   return RewriterState(createdOps.size(), replacements.size(),
925                        argReplacements.size(), blockActions.size(),
926                        ignoredOps.size(), rootUpdates.size());
927 }
928 
/// Roll the rewriter back to a snapshot taken by getCurrentState(). Each undo
/// log is truncated to the size recorded in `state`, and the entries being
/// dropped are undone.
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
  // Reset any operations that were updated in place.
  for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i)
    rootUpdates[i].resetOperation();
  rootUpdates.resize(state.numRootUpdates);

  // Reset any replaced arguments.
  // NOTE(review): replaceUsesOfBlockArgument maps
  // `mapping.lookupOrDefault(from)` -> `to`, and that key may differ from
  // `from`, but only `replacedArg` itself is erased here. Confirm that the
  // other entry cannot outlive the rollback.
  for (BlockArgument replacedArg :
       llvm::drop_begin(argReplacements, state.numArgReplacements))
    mapping.erase(replacedArg);
  argReplacements.resize(state.numArgReplacements);

  // Undo any block actions.
  undoBlockActions(state.numBlockActions);

  // Reset any replaced operations and undo any saved mappings.
  for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
    for (auto result : repl.first->getResults())
      mapping.erase(result);
  while (replacements.size() != state.numReplacements)
    replacements.pop_back();

  // Pop all of the newly created operations, newest first. Nested ops are
  // detached rather than destroyed with their parent, since they have their
  // own entries in `createdOps`.
  while (createdOps.size() != state.numCreatedOps) {
    detachNestedAndErase(createdOps.back());
    createdOps.pop_back();
  }

  // Pop all of the recorded ignored operations that are no longer valid.
  while (ignoredOps.size() != state.numIgnoredOperations)
    ignoredOps.pop_back();

  // Reset operations with changed results. Entries are indices into
  // `replacements` (pushed by notifyOpReplaced), so any index at or beyond the
  // restored size refers to a replacement that was just popped.
  while (!operationsWithChangedResults.empty() &&
         operationsWithChangedResults.back() >= state.numReplacements)
    operationsWithChangedResults.pop_back();
}
966 
967 void ConversionPatternRewriterImpl::eraseDanglingBlocks() {
968   for (auto &action : blockActions)
969     if (action.kind == BlockActionKind::Erase)
970       delete action.block;
971 }
972 
/// Undo, newest first, every block action recorded after the first
/// `numActionsToKeep` entries of `blockActions`, then truncate the log to that
/// size.
void ConversionPatternRewriterImpl::undoBlockActions(
    unsigned numActionsToKeep) {
  for (auto &action :
       llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) {
    switch (action.kind) {
    // Delete the created block.
    case BlockActionKind::Create: {
      // Unlink all of the operations within this block, they will be deleted
      // separately.
      auto &blockOps = action.block->getOperations();
      while (!blockOps.empty())
        blockOps.remove(blockOps.begin());
      action.block->dropAllDefinedValueUses();
      action.block->erase();
      break;
    }
    // Put the block (owned by action) back into its original position. A null
    // `insertAfterBlock` means the block was first in its region.
    case BlockActionKind::Erase: {
      auto &blockList = action.originalPosition.region->getBlocks();
      Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
      blockList.insert((insertAfterBlock
                            ? std::next(Region::iterator(insertAfterBlock))
                            : blockList.begin()),
                       action.block);
      break;
    }
    // Split the block at the position which was originally the end of the
    // destination block (owned by action), and put the instructions back into
    // the block used before the merge.
    case BlockActionKind::Merge: {
      Block *sourceBlock = action.mergeInfo.sourceBlock;
      Block::iterator splitPoint =
          (action.mergeInfo.destBlockLastInst
               ? ++Block::iterator(action.mergeInfo.destBlockLastInst)
               : action.block->begin());
      sourceBlock->getOperations().splice(sourceBlock->begin(),
                                          action.block->getOperations(),
                                          splitPoint, action.block->end());
      break;
    }
    // Move the block back to its original position.
    // NOTE(review): unlike the Erase case, a null `insertAfterBlock` inserts
    // at `end()` rather than `begin()`. This appears to rely on the moves
    // recorded by notifyRegionIsBeingInlinedBefore being undone front block
    // first, into an original region that is empty at that point. Confirm
    // before adding other producers of Move actions.
    case BlockActionKind::Move: {
      Region *originalRegion = action.originalPosition.region;
      Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
      originalRegion->getBlocks().splice(
          (insertAfterBlock ? std::next(Region::iterator(insertAfterBlock))
                            : originalRegion->end()),
          action.block->getParent()->getBlocks(), action.block);
      break;
    }
    // Merge back the block that was split out.
    case BlockActionKind::Split: {
      action.originalBlock->getOperations().splice(
          action.originalBlock->end(), action.block->getOperations());
      action.block->dropAllDefinedValueUses();
      action.block->erase();
      break;
    }
    // Undo the type conversion.
    case BlockActionKind::TypeConversion: {
      argConverter.discardRewrites(action.block);
      break;
    }
    }
  }
  blockActions.resize(numActionsToKeep);
}
1040 
1041 LogicalResult ConversionPatternRewriterImpl::remapValues(
1042     Location loc, PatternRewriter &rewriter, TypeConverter *converter,
1043     Operation::operand_range operands, SmallVectorImpl<Value> &remapped) {
1044   remapped.reserve(llvm::size(operands));
1045 
1046   SmallVector<Type, 1> legalTypes;
1047   for (auto it : llvm::enumerate(operands)) {
1048     Value operand = it.value();
1049     Type origType = operand.getType();
1050 
1051     // If a converter was provided, get the desired legal types for this
1052     // operand.
1053     Type desiredType;
1054     if (converter) {
1055       // If there is no legal conversion, fail to match this pattern.
1056       legalTypes.clear();
1057       if (failed(converter->convertType(origType, legalTypes))) {
1058         return notifyMatchFailure(loc, [=](Diagnostic &diag) {
1059           diag << "unable to convert type for operand #" << it.index()
1060                << ", type was " << origType;
1061         });
1062       }
1063       // TODO: There currently isn't any mechanism to do 1->N type conversion
1064       // via the PatternRewriter replacement API, so for now we just ignore it.
1065       if (legalTypes.size() == 1)
1066         desiredType = legalTypes.front();
1067     } else {
1068       // TODO: What we should do here is just set `desiredType` to `origType`
1069       // and then handle the necessary type conversions after the conversion
1070       // process has finished. Unfortunately a lot of patterns currently rely on
1071       // receiving the new operands even if the types change, so we keep the
1072       // original behavior here for now until all of the patterns relying on
1073       // this get updated.
1074     }
1075     Value newOperand = mapping.lookupOrDefault(operand, desiredType);
1076 
1077     // Handle the case where the conversion was 1->1 and the new operand type
1078     // isn't legal.
1079     Type newOperandType = newOperand.getType();
1080     if (converter && desiredType && newOperandType != desiredType) {
1081       // Attempt to materialize a conversion for this new value.
1082       newOperand = converter->materializeTargetConversion(
1083           rewriter, loc, desiredType, newOperand);
1084       if (!newOperand) {
1085         return notifyMatchFailure(loc, [=](Diagnostic &diag) {
1086           diag << "unable to materialize a conversion for "
1087                   "operand #"
1088                << it.index() << ", from " << newOperandType << " to "
1089                << desiredType;
1090         });
1091       }
1092     }
1093     remapped.push_back(newOperand);
1094   }
1095   return success();
1096 }
1097 
1098 bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
1099   // Check to see if this operation was replaced or its parent ignored.
1100   return replacements.count(op) || ignoredOps.count(op->getParentOp());
1101 }
1102 
1103 void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
1104   // Walk this operation and collect nested operations that define non-empty
1105   // regions. We mark such operations as 'ignored' so that we know we don't have
1106   // to convert them, or their nested ops.
1107   if (op->getNumRegions() == 0)
1108     return;
1109   op->walk([&](Operation *op) {
1110     if (llvm::any_of(op->getRegions(),
1111                      [](Region &region) { return !region.empty(); }))
1112       ignoredOps.insert(op);
1113   });
1114 }
1115 
1116 //===----------------------------------------------------------------------===//
1117 // Type Conversion
1118 
1119 FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
1120     Block *block, TypeConverter &converter,
1121     TypeConverter::SignatureConversion *conversion) {
1122   FailureOr<Block *> result =
1123       conversion ? argConverter.applySignatureConversion(
1124                        block, converter, *conversion, mapping, argReplacements)
1125                  : argConverter.convertSignature(block, converter, mapping,
1126                                                  argReplacements);
1127   if (Block *newBlock = result.getValue()) {
1128     if (newBlock != block)
1129       blockActions.push_back(BlockAction::getTypeConversion(newBlock));
1130   }
1131   return result;
1132 }
1133 
1134 Block *ConversionPatternRewriterImpl::applySignatureConversion(
1135     Region *region, TypeConverter::SignatureConversion &conversion) {
1136   if (!region->empty()) {
1137     return *convertBlockSignature(&region->front(), defaultTypeConverter,
1138                                   &conversion);
1139   }
1140   return nullptr;
1141 }
1142 
1143 FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
1144     Region *region, TypeConverter &converter,
1145     TypeConverter::SignatureConversion *entryConversion) {
1146   argConverter.setConverter(region, &converter);
1147   if (region->empty())
1148     return nullptr;
1149 
1150   // Convert the arguments of each block within the region.
1151   FailureOr<Block *> newEntry =
1152       convertBlockSignature(&region->front(), converter, entryConversion);
1153   for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(*region, 1)))
1154     if (failed(convertBlockSignature(&block, converter)))
1155       return failure();
1156   return newEntry;
1157 }
1158 
1159 //===----------------------------------------------------------------------===//
1160 // Rewriter Notification Hooks
1161 
1162 void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
1163                                                      ValueRange newValues) {
1164   assert(newValues.size() == op->getNumResults());
1165   assert(!replacements.count(op) && "operation was already replaced");
1166 
1167   // Track if any of the results changed, e.g. erased and replaced with null.
1168   bool resultChanged = false;
1169 
1170   // Create mappings for each of the new result values.
1171   Value newValue, result;
1172   for (auto it : llvm::zip(newValues, op->getResults())) {
1173     std::tie(newValue, result) = it;
1174     if (!newValue) {
1175       resultChanged = true;
1176       continue;
1177     }
1178     // Remap, and check for any result type changes.
1179     mapping.map(result, newValue);
1180     resultChanged |= (newValue.getType() != result.getType());
1181   }
1182   if (resultChanged)
1183     operationsWithChangedResults.push_back(replacements.size());
1184 
1185   // Record the requested operation replacement.
1186   TypeConverter *converter = nullptr;
1187   if (currentConversionPattern)
1188     converter = currentConversionPattern->getTypeConverter();
1189   replacements.insert(std::make_pair(op, OpReplacement(converter)));
1190 
1191   // Mark this operation as recursively ignored so that we don't need to
1192   // convert any nested operations.
1193   markNestedOpsIgnored(op);
1194 }
1195 
1196 void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
1197   Region *region = block->getParent();
1198   Block *origPrevBlock = block->getPrevNode();
1199   blockActions.push_back(BlockAction::getErase(block, {region, origPrevBlock}));
1200 }
1201 
1202 void ConversionPatternRewriterImpl::notifyCreatedBlock(Block *block) {
1203   blockActions.push_back(BlockAction::getCreate(block));
1204 }
1205 
1206 void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
1207                                                      Block *continuation) {
1208   blockActions.push_back(BlockAction::getSplit(continuation, block));
1209 }
1210 
1211 void ConversionPatternRewriterImpl::notifyBlocksBeingMerged(Block *block,
1212                                                             Block *srcBlock) {
1213   blockActions.push_back(BlockAction::getMerge(block, srcBlock));
1214 }
1215 
1216 void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore(
1217     Region &region, Region &parent, Region::iterator before) {
1218   if (region.empty())
1219     return;
1220   Block *laterBlock = &region.back();
1221   for (auto &earlierBlock : llvm::drop_begin(llvm::reverse(region), 1)) {
1222     blockActions.push_back(
1223         BlockAction::getMove(laterBlock, {&region, &earlierBlock}));
1224     laterBlock = &earlierBlock;
1225   }
1226   blockActions.push_back(BlockAction::getMove(laterBlock, {&region, nullptr}));
1227 }
1228 
1229 void ConversionPatternRewriterImpl::notifyRegionWasClonedBefore(
1230     iterator_range<Region::iterator> &blocks, Location origRegionLoc) {
1231   for (Block &block : blocks)
1232     blockActions.push_back(BlockAction::getCreate(&block));
1233 
1234   // Compute the conversion set for the inlined region.
1235   auto result = computeConversionSet(blocks, origRegionLoc, createdOps);
1236 
1237   // This original region has already had its conversion set computed, so there
1238   // shouldn't be any new failures.
1239   (void)result;
1240   assert(succeeded(result) && "expected region to have no unreachable blocks");
1241 }
1242 
1243 LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
1244     Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1245   LLVM_DEBUG({
1246     Diagnostic diag(loc, DiagnosticSeverity::Remark);
1247     reasonCallback(diag);
1248     logger.startLine() << "** Failure : " << diag.str() << "\n";
1249   });
1250   return failure();
1251 }
1252 
1253 //===----------------------------------------------------------------------===//
1254 // ConversionPatternRewriter
1255 //===----------------------------------------------------------------------===//
1256 
1257 ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
1258     : PatternRewriter(ctx),
1259       impl(new detail::ConversionPatternRewriterImpl(*this)) {}
1260 ConversionPatternRewriter::~ConversionPatternRewriter() {}
1261 
/// PatternRewriter hook for replacing the results of an operation when the
/// given functor returns true. Dialect conversion does not support this
/// hook; calling it reaches `llvm_unreachable`.
void ConversionPatternRewriter::replaceOpWithIf(
    Operation *op, ValueRange newValues, bool *allUsesReplaced,
    llvm::unique_function<bool(OpOperand &) const> functor) {
  // TODO: To support this we will need to rework a bit of how replacements are
  // tracked, given that this isn't guaranteed to replace all of the uses of an
  // operation. The main change is that now an operation can be replaced
  // multiple times, in parts. The current "set" based tracking is mainly useful
  // for tracking if a replaced operation should be ignored, i.e. if all of the
  // uses will be replaced.
  llvm_unreachable(
      "replaceOpWithIf is currently not supported by DialectConversion");
}
1276 
1277 /// PatternRewriter hook for replacing the results of an operation.
1278 void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
1279   LLVM_DEBUG({
1280     impl->logger.startLine()
1281         << "** Replace : '" << op->getName() << "'(" << op << ")\n";
1282   });
1283   impl->notifyOpReplaced(op, newValues);
1284 }
1285 
1286 /// PatternRewriter hook for erasing a dead operation. The uses of this
1287 /// operation *must* be made dead by the end of the conversion process,
1288 /// otherwise an assert will be issued.
1289 void ConversionPatternRewriter::eraseOp(Operation *op) {
1290   LLVM_DEBUG({
1291     impl->logger.startLine()
1292         << "** Erase   : '" << op->getName() << "'(" << op << ")\n";
1293   });
1294   SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr);
1295   impl->notifyOpReplaced(op, nullRepls);
1296 }
1297 
1298 void ConversionPatternRewriter::eraseBlock(Block *block) {
1299   impl->notifyBlockIsBeingErased(block);
1300 
1301   // Mark all ops for erasure.
1302   for (Operation &op : *block)
1303     eraseOp(&op);
1304 
1305   // Unlink the block from its parent region. The block is kept in the block
1306   // action and will be actually destroyed when rewrites are applied. This
1307   // allows us to keep the operations in the block live and undo the removal by
1308   // re-inserting the block.
1309   block->getParent()->getBlocks().remove(block);
1310 }
1311 
1312 Block *ConversionPatternRewriter::applySignatureConversion(
1313     Region *region, TypeConverter::SignatureConversion &conversion) {
1314   return impl->applySignatureConversion(region, conversion);
1315 }
1316 
1317 FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
1318     Region *region, TypeConverter &converter,
1319     TypeConverter::SignatureConversion *entryConversion) {
1320   return impl->convertRegionTypes(region, converter, entryConversion);
1321 }
1322 
1323 void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
1324                                                            Value to) {
1325   LLVM_DEBUG({
1326     Operation *parentOp = from.getOwner()->getParentOp();
1327     impl->logger.startLine() << "** Replace Argument : '" << from
1328                              << "'(in region of '" << parentOp->getName()
1329                              << "'(" << from.getOwner()->getParentOp() << ")\n";
1330   });
1331   impl->argReplacements.push_back(from);
1332   impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
1333 }
1334 
1335 /// Return the converted value that replaces 'key'. Return 'key' if there is
1336 /// no such a converted value.
1337 Value ConversionPatternRewriter::getRemappedValue(Value key) {
1338   return impl->mapping.lookupOrDefault(key);
1339 }
1340 
1341 /// PatternRewriter hook for creating a new block with the given arguments.
1342 void ConversionPatternRewriter::notifyBlockCreated(Block *block) {
1343   impl->notifyCreatedBlock(block);
1344 }
1345 
1346 /// PatternRewriter hook for splitting a block into two parts.
1347 Block *ConversionPatternRewriter::splitBlock(Block *block,
1348                                              Block::iterator before) {
1349   auto *continuation = PatternRewriter::splitBlock(block, before);
1350   impl->notifySplitBlock(block, continuation);
1351   return continuation;
1352 }
1353 
1354 /// PatternRewriter hook for merging a block into another.
1355 void ConversionPatternRewriter::mergeBlocks(Block *source, Block *dest,
1356                                             ValueRange argValues) {
1357   impl->notifyBlocksBeingMerged(dest, source);
1358   assert(llvm::all_of(source->getPredecessors(),
1359                       [dest](Block *succ) { return succ == dest; }) &&
1360          "expected 'source' to have no predecessors or only 'dest'");
1361   assert(argValues.size() == source->getNumArguments() &&
1362          "incorrect # of argument replacement values");
1363   for (auto it : llvm::zip(source->getArguments(), argValues))
1364     replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1365   dest->getOperations().splice(dest->end(), source->getOperations());
1366   eraseBlock(source);
1367 }
1368 
1369 /// PatternRewriter hook for moving blocks out of a region.
1370 void ConversionPatternRewriter::inlineRegionBefore(Region &region,
1371                                                    Region &parent,
1372                                                    Region::iterator before) {
1373   impl->notifyRegionIsBeingInlinedBefore(region, parent, before);
1374   PatternRewriter::inlineRegionBefore(region, parent, before);
1375 }
1376 
1377 /// PatternRewriter hook for cloning blocks of one region into another.
1378 void ConversionPatternRewriter::cloneRegionBefore(
1379     Region &region, Region &parent, Region::iterator before,
1380     BlockAndValueMapping &mapping) {
1381   if (region.empty())
1382     return;
1383   PatternRewriter::cloneRegionBefore(region, parent, before, mapping);
1384 
1385   // Collect the range of the cloned blocks.
1386   auto clonedBeginIt = mapping.lookup(&region.front())->getIterator();
1387   auto clonedBlocks = llvm::make_range(clonedBeginIt, before);
1388   impl->notifyRegionWasClonedBefore(clonedBlocks, region.getLoc());
1389 }
1390 
1391 /// PatternRewriter hook for creating a new operation.
1392 void ConversionPatternRewriter::notifyOperationInserted(Operation *op) {
1393   LLVM_DEBUG({
1394     impl->logger.startLine()
1395         << "** Insert  : '" << op->getName() << "'(" << op << ")\n";
1396   });
1397   impl->createdOps.push_back(op);
1398 }
1399 
1400 /// PatternRewriter hook for updating the root operation in-place.
1401 void ConversionPatternRewriter::startRootUpdate(Operation *op) {
1402 #ifndef NDEBUG
1403   impl->pendingRootUpdates.insert(op);
1404 #endif
1405   impl->rootUpdates.emplace_back(op);
1406 }
1407 
1408 /// PatternRewriter hook for updating the root operation in-place.
1409 void ConversionPatternRewriter::finalizeRootUpdate(Operation *op) {
1410   // There is nothing to do here, we only need to track the operation at the
1411   // start of the update.
1412 #ifndef NDEBUG
1413   assert(impl->pendingRootUpdates.erase(op) &&
1414          "operation did not have a pending in-place update");
1415 #endif
1416 }
1417 
1418 /// PatternRewriter hook for updating the root operation in-place.
1419 void ConversionPatternRewriter::cancelRootUpdate(Operation *op) {
1420 #ifndef NDEBUG
1421   assert(impl->pendingRootUpdates.erase(op) &&
1422          "operation did not have a pending in-place update");
1423 #endif
1424   // Erase the last update for this operation.
1425   auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; };
1426   auto &rootUpdates = impl->rootUpdates;
1427   auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp);
1428   rootUpdates.erase(rootUpdates.begin() + (rootUpdates.rend() - it));
1429 }
1430 
1431 /// PatternRewriter hook for notifying match failure reasons.
1432 LogicalResult ConversionPatternRewriter::notifyMatchFailure(
1433     Operation *op, function_ref<void(Diagnostic &)> reasonCallback) {
1434   return impl->notifyMatchFailure(op->getLoc(), reasonCallback);
1435 }
1436 
/// Return a reference to the internal implementation, which records the
/// changes made through this rewriter so they can be applied or undone.
detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
  return *impl;
}
1441 
1442 //===----------------------------------------------------------------------===//
1443 // ConversionPattern
1444 //===----------------------------------------------------------------------===//
1445 
1446 /// Attempt to match and rewrite the IR root at the specified operation.
1447 LogicalResult
1448 ConversionPattern::matchAndRewrite(Operation *op,
1449                                    PatternRewriter &rewriter) const {
1450   auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
1451   auto &rewriterImpl = dialectRewriter.getImpl();
1452 
1453   // Track the current conversion pattern in the rewriter.
1454   assert(!rewriterImpl.currentConversionPattern &&
1455          "already inside of a pattern rewrite");
1456   llvm::SaveAndRestore<const ConversionPattern *> currentPatternGuard(
1457       rewriterImpl.currentConversionPattern, this);
1458 
1459   // Remap the operands of the operation.
1460   SmallVector<Value, 4> operands;
1461   if (failed(rewriterImpl.remapValues(op->getLoc(), rewriter,
1462                                       getTypeConverter(), op->getOperands(),
1463                                       operands))) {
1464     return failure();
1465   }
1466   return matchAndRewrite(op, operands, dialectRewriter);
1467 }
1468 
1469 //===----------------------------------------------------------------------===//
1470 // OperationLegalizer
1471 //===----------------------------------------------------------------------===//
1472 
namespace {
/// A set of rewrite patterns that can be used to legalize a given operation.
using LegalizationPatterns = SmallVector<const Pattern *, 1>;

/// This class defines a recursive operation legalizer.
///
/// Legalizing an operation may apply a pattern whose results must themselves
/// be legalized. `legalize` therefore re-enters itself for operations that a
/// pattern created, updated in place, or whose blocks it moved. At
/// construction, the provided patterns are analyzed (see "Cost Model" below)
/// and the pattern applicator is ordered according to that analysis.
class OperationLegalizer {
public:
  using LegalizationAction = ConversionTarget::LegalizationAction;

  /// Build a legalizer for `targetInfo` using `patterns`. The legalization
  /// graph and the pattern ordering are computed eagerly in the constructor.
  OperationLegalizer(ConversionTarget &targetInfo,
                     const FrozenRewritePatternList &patterns);

  /// Returns true if the given operation is known to be illegal on the target.
  /// Only an explicit `Illegal` action registered with the target counts.
  bool isIllegal(Operation *op) const;

  /// Attempt to legalize the given operation. Returns success if the operation
  /// was legalized, failure otherwise.
  LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);

  /// Returns the conversion target in use by the legalizer.
  ConversionTarget &getTarget() { return target; }

private:
  /// Attempt to legalize the given operation by folding it. On a successful
  /// fold, `op` is replaced by the folded values and any newly created
  /// operations are legalized. If those cannot be legalized, the rewriter
  /// state is rolled back and failure is returned.
  LogicalResult legalizeWithFold(Operation *op,
                                 ConversionPatternRewriter &rewriter);

  /// Attempt to legalize the given operation by applying a pattern. Returns
  /// success if the operation was legalized, failure otherwise.
  LogicalResult legalizeWithPattern(Operation *op,
                                    ConversionPatternRewriter &rewriter);

  /// Return true if the given pattern may be applied to the given operation,
  /// false otherwise.
  bool canApplyPattern(Operation *op, const Pattern &pattern,
                       ConversionPatternRewriter &rewriter);

  /// Legalize the resultant IR after successfully applying the given pattern.
  /// `curState` is the rewriter state captured before the pattern ran.
  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
                                      ConversionPatternRewriter &rewriter,
                                      RewriterState &curState);

  /// Legalizes the actions registered during the execution of a pattern.
  /// For blocks moved or created outside of `op`, either convert their
  /// signature with the region's type converter or re-legalize their parent
  /// operation.
  LogicalResult legalizePatternBlockActions(Operation *op,
                                            ConversionPatternRewriter &rewriter,
                                            ConversionPatternRewriterImpl &impl,
                                            RewriterState &state,
                                            RewriterState &newState);
  /// Recursively legalize every operation the pattern created, i.e. those
  /// recorded between `state` and `newState`.
  LogicalResult legalizePatternCreatedOperations(
      ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
      RewriterState &state, RewriterState &newState);
  /// Re-legalize every operation the pattern updated in place, i.e. the root
  /// updates recorded between `state` and `newState`.
  LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
                                           ConversionPatternRewriterImpl &impl,
                                           RewriterState &state,
                                           RewriterState &newState);

  //===--------------------------------------------------------------------===//
  // Cost Model
  //===--------------------------------------------------------------------===//

  /// Build an optimistic legalization graph given the provided patterns. This
  /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
  /// patterns for operations that are not directly legal, but may be
  /// transitively legal for the current target given the provided patterns.
  void buildLegalizationGraph(
      LegalizationPatterns &anyOpLegalizerPatterns,
      DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);

  /// Compute the benefit of each node within the computed legalization graph.
  /// This orders the patterns within 'legalizerPatterns' based upon two
  /// criteria:
  ///  1) Prefer patterns that have the lowest legalization depth, i.e.
  ///     represent the more direct mapping to the target.
  ///  2) When comparing patterns with the same legalization depth, prefer the
  ///     pattern with the highest PatternBenefit. This allows for users to
  ///     prefer specific legalizations over others.
  void computeLegalizationGraphBenefit(
      LegalizationPatterns &anyOpLegalizerPatterns,
      DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);

  /// Compute the legalization depth when legalizing an operation of the given
  /// type. Results are memoized in 'minOpPatternDepth'.
  unsigned computeOpLegalizationDepth(
      OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
      DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);

  /// Apply the conversion cost model to the given set of patterns, and return
  /// the smallest legalization depth of any of the patterns. See
  /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
  /// 'patterns' is reordered in place.
  unsigned applyCostModelToPatterns(
      LegalizationPatterns &patterns,
      DenseMap<OperationName, unsigned> &minOpPatternDepth,
      DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);

  /// The current set of patterns that have been applied, i.e. those active on
  /// the current legalization recursion stack. A pattern without bounded
  /// rewrite recursion is refused if it is already present, which prevents
  /// infinite recursion.
  SmallPtrSet<const Pattern *, 8> appliedPatterns;

  /// The legalization information provided by the target.
  ConversionTarget &target;

  /// The pattern applicator to use for conversions.
  PatternApplicator applicator;
};
} // namespace
1577 
1578 OperationLegalizer::OperationLegalizer(ConversionTarget &targetInfo,
1579                                        const FrozenRewritePatternList &patterns)
1580     : target(targetInfo), applicator(patterns) {
1581   // The set of patterns that can be applied to illegal operations to transform
1582   // them into legal ones.
1583   DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
1584   LegalizationPatterns anyOpLegalizerPatterns;
1585 
1586   buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
1587   computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
1588 }
1589 
1590 bool OperationLegalizer::isIllegal(Operation *op) const {
1591   // Check if the target explicitly marked this operation as illegal.
1592   return target.getOpAction(op->getName()) == LegalizationAction::Illegal;
1593 }
1594 
1595 LogicalResult
1596 OperationLegalizer::legalize(Operation *op,
1597                              ConversionPatternRewriter &rewriter) {
1598 #ifndef NDEBUG
1599   const char *logLineComment =
1600       "//===-------------------------------------------===//\n";
1601 
1602   auto &rewriterImpl = rewriter.getImpl();
1603 #endif
1604   LLVM_DEBUG({
1605     auto &os = rewriterImpl.logger;
1606     os.getOStream() << "\n";
1607     os.startLine() << logLineComment;
1608     os.startLine() << "Legalizing operation : '" << op->getName() << "'(" << op
1609                    << ") {\n";
1610     os.indent();
1611 
1612     // If the operation has no regions, just print it here.
1613     if (op->getNumRegions() == 0) {
1614       op->print(os.startLine(), OpPrintingFlags().printGenericOpForm());
1615       os.getOStream() << "\n\n";
1616     }
1617   });
1618 
1619   // Check if this operation is legal on the target.
1620   if (auto legalityInfo = target.isLegal(op)) {
1621     LLVM_DEBUG({
1622       logSuccess(
1623           rewriterImpl.logger, "operation marked legal by the target{0}",
1624           legalityInfo->isRecursivelyLegal
1625               ? "; NOTE: operation is recursively legal; skipping internals"
1626               : "");
1627       rewriterImpl.logger.startLine() << logLineComment;
1628     });
1629 
1630     // If this operation is recursively legal, mark its children as ignored so
1631     // that we don't consider them for legalization.
1632     if (legalityInfo->isRecursivelyLegal)
1633       rewriter.getImpl().markNestedOpsIgnored(op);
1634     return success();
1635   }
1636 
1637   // Check to see if the operation is ignored and doesn't need to be converted.
1638   if (rewriter.getImpl().isOpIgnored(op)) {
1639     LLVM_DEBUG({
1640       logSuccess(rewriterImpl.logger,
1641                  "operation marked 'ignored' during conversion");
1642       rewriterImpl.logger.startLine() << logLineComment;
1643     });
1644     return success();
1645   }
1646 
1647   // If the operation isn't legal, try to fold it in-place.
1648   // TODO: Should we always try to do this, even if the op is
1649   // already legal?
1650   if (succeeded(legalizeWithFold(op, rewriter))) {
1651     LLVM_DEBUG({
1652       logSuccess(rewriterImpl.logger, "operation was folded");
1653       rewriterImpl.logger.startLine() << logLineComment;
1654     });
1655     return success();
1656   }
1657 
1658   // Otherwise, we need to apply a legalization pattern to this operation.
1659   if (succeeded(legalizeWithPattern(op, rewriter))) {
1660     LLVM_DEBUG({
1661       logSuccess(rewriterImpl.logger, "");
1662       rewriterImpl.logger.startLine() << logLineComment;
1663     });
1664     return success();
1665   }
1666 
1667   LLVM_DEBUG({
1668     logFailure(rewriterImpl.logger, "no matched legalization pattern");
1669     rewriterImpl.logger.startLine() << logLineComment;
1670   });
1671   return failure();
1672 }
1673 
1674 LogicalResult
1675 OperationLegalizer::legalizeWithFold(Operation *op,
1676                                      ConversionPatternRewriter &rewriter) {
1677   auto &rewriterImpl = rewriter.getImpl();
1678   RewriterState curState = rewriterImpl.getCurrentState();
1679 
1680   LLVM_DEBUG({
1681     rewriterImpl.logger.startLine() << "* Fold {\n";
1682     rewriterImpl.logger.indent();
1683   });
1684 
1685   // Try to fold the operation.
1686   SmallVector<Value, 2> replacementValues;
1687   rewriter.setInsertionPoint(op);
1688   if (failed(rewriter.tryFold(op, replacementValues))) {
1689     LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
1690     return failure();
1691   }
1692 
1693   // Insert a replacement for 'op' with the folded replacement values.
1694   rewriter.replaceOp(op, replacementValues);
1695 
1696   // Recursively legalize any new constant operations.
1697   for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
1698        i != e; ++i) {
1699     Operation *cstOp = rewriterImpl.createdOps[i];
1700     if (failed(legalize(cstOp, rewriter))) {
1701       LLVM_DEBUG(logFailure(rewriterImpl.logger,
1702                             "generated constant '{0}' was illegal",
1703                             cstOp->getName()));
1704       rewriterImpl.resetState(curState);
1705       return failure();
1706     }
1707   }
1708 
1709   LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
1710   return success();
1711 }
1712 
1713 LogicalResult
1714 OperationLegalizer::legalizeWithPattern(Operation *op,
1715                                         ConversionPatternRewriter &rewriter) {
1716   auto &rewriterImpl = rewriter.getImpl();
1717 
1718   // Functor that returns if the given pattern may be applied.
1719   auto canApply = [&](const Pattern &pattern) {
1720     return canApplyPattern(op, pattern, rewriter);
1721   };
1722 
1723   // Functor that cleans up the rewriter state after a pattern failed to match.
1724   RewriterState curState = rewriterImpl.getCurrentState();
1725   auto onFailure = [&](const Pattern &pattern) {
1726     LLVM_DEBUG(logFailure(rewriterImpl.logger, "pattern failed to match"));
1727     rewriterImpl.resetState(curState);
1728     appliedPatterns.erase(&pattern);
1729   };
1730 
1731   // Functor that performs additional legalization when a pattern is
1732   // successfully applied.
1733   auto onSuccess = [&](const Pattern &pattern) {
1734     auto result = legalizePatternResult(op, pattern, rewriter, curState);
1735     appliedPatterns.erase(&pattern);
1736     if (failed(result))
1737       rewriterImpl.resetState(curState);
1738     return result;
1739   };
1740 
1741   // Try to match and rewrite a pattern on this operation.
1742   return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
1743                                     onSuccess);
1744 }
1745 
1746 bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
1747                                          ConversionPatternRewriter &rewriter) {
1748   LLVM_DEBUG({
1749     auto &os = rewriter.getImpl().logger;
1750     os.getOStream() << "\n";
1751     os.startLine() << "* Pattern : '" << op->getName() << " -> (";
1752     llvm::interleaveComma(pattern.getGeneratedOps(), llvm::dbgs());
1753     os.getOStream() << ")' {\n";
1754     os.indent();
1755   });
1756 
1757   // Ensure that we don't cycle by not allowing the same pattern to be
1758   // applied twice in the same recursion stack if it is not known to be safe.
1759   if (!pattern.hasBoundedRewriteRecursion() &&
1760       !appliedPatterns.insert(&pattern).second) {
1761     LLVM_DEBUG(
1762         logFailure(rewriter.getImpl().logger, "pattern was already applied"));
1763     return false;
1764   }
1765   return true;
1766 }
1767 
1768 LogicalResult
1769 OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
1770                                           ConversionPatternRewriter &rewriter,
1771                                           RewriterState &curState) {
1772   auto &impl = rewriter.getImpl();
1773 
1774 #ifndef NDEBUG
1775   assert(impl.pendingRootUpdates.empty() && "dangling root updates");
1776 #endif
1777 
1778   // Check that the root was either replaced or updated in place.
1779   auto replacedRoot = [&] {
1780     return llvm::any_of(
1781         llvm::drop_begin(impl.replacements, curState.numReplacements),
1782         [op](auto &it) { return it.first == op; });
1783   };
1784   auto updatedRootInPlace = [&] {
1785     return llvm::any_of(
1786         llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates),
1787         [op](auto &state) { return state.getOperation() == op; });
1788   };
1789   (void)replacedRoot;
1790   (void)updatedRootInPlace;
1791   assert((replacedRoot() || updatedRootInPlace()) &&
1792          "expected pattern to replace the root operation");
1793 
1794   // Legalize each of the actions registered during application.
1795   RewriterState newState = impl.getCurrentState();
1796   if (failed(legalizePatternBlockActions(op, rewriter, impl, curState,
1797                                          newState)) ||
1798       failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
1799       failed(legalizePatternCreatedOperations(rewriter, impl, curState,
1800                                               newState))) {
1801     return failure();
1802   }
1803 
1804   LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
1805   return success();
1806 }
1807 
1808 LogicalResult OperationLegalizer::legalizePatternBlockActions(
1809     Operation *op, ConversionPatternRewriter &rewriter,
1810     ConversionPatternRewriterImpl &impl, RewriterState &state,
1811     RewriterState &newState) {
1812   SmallPtrSet<Operation *, 16> operationsToIgnore;
1813 
1814   // If the pattern moved or created any blocks, make sure the types of block
1815   // arguments get legalized.
1816   for (int i = state.numBlockActions, e = newState.numBlockActions; i != e;
1817        ++i) {
1818     auto &action = impl.blockActions[i];
1819     if (action.kind == BlockActionKind::TypeConversion ||
1820         action.kind == BlockActionKind::Erase)
1821       continue;
1822     // Only check blocks outside of the current operation.
1823     Operation *parentOp = action.block->getParentOp();
1824     if (!parentOp || parentOp == op || action.block->getNumArguments() == 0)
1825       continue;
1826 
1827     // If the region of the block has a type converter, try to convert the block
1828     // directly.
1829     if (auto *converter =
1830             impl.argConverter.getConverter(action.block->getParent())) {
1831       if (failed(impl.convertBlockSignature(action.block, *converter))) {
1832         LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
1833                                            "block"));
1834         return failure();
1835       }
1836       continue;
1837     }
1838 
1839     // Otherwise, check that this operation isn't one generated by this pattern.
1840     // This is because we will attempt to legalize the parent operation, and
1841     // blocks in regions created by this pattern will already be legalized later
1842     // on. If we haven't built the set yet, build it now.
1843     if (operationsToIgnore.empty()) {
1844       auto createdOps = ArrayRef<Operation *>(impl.createdOps)
1845                             .drop_front(state.numCreatedOps);
1846       operationsToIgnore.insert(createdOps.begin(), createdOps.end());
1847     }
1848 
1849     // If this operation should be considered for re-legalization, try it.
1850     if (operationsToIgnore.insert(parentOp).second &&
1851         failed(legalize(parentOp, rewriter))) {
1852       LLVM_DEBUG(logFailure(
1853           impl.logger, "operation '{0}'({1}) became illegal after block action",
1854           parentOp->getName(), parentOp));
1855       return failure();
1856     }
1857   }
1858   return success();
1859 }
1860 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
1861     ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
1862     RewriterState &state, RewriterState &newState) {
1863   for (int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) {
1864     Operation *op = impl.createdOps[i];
1865     if (failed(legalize(op, rewriter))) {
1866       LLVM_DEBUG(logFailure(impl.logger,
1867                             "generated operation '{0}'({1}) was illegal",
1868                             op->getName(), op));
1869       return failure();
1870     }
1871   }
1872   return success();
1873 }
1874 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
1875     ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
1876     RewriterState &state, RewriterState &newState) {
1877   for (int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) {
1878     Operation *op = impl.rootUpdates[i].getOperation();
1879     if (failed(legalize(op, rewriter))) {
1880       LLVM_DEBUG(logFailure(impl.logger,
1881                             "operation updated in-place '{0}' was illegal",
1882                             op->getName()));
1883       return failure();
1884     }
1885   }
1886   return success();
1887 }
1888 
1889 //===----------------------------------------------------------------------===//
1890 // Cost Model
1891 
1892 void OperationLegalizer::buildLegalizationGraph(
1893     LegalizationPatterns &anyOpLegalizerPatterns,
1894     DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
1895   // A mapping between an operation and a set of operations that can be used to
1896   // generate it.
1897   DenseMap<OperationName, SmallPtrSet<OperationName, 2>> parentOps;
1898   // A mapping between an operation and any currently invalid patterns it has.
1899   DenseMap<OperationName, SmallPtrSet<const Pattern *, 2>> invalidPatterns;
1900   // A worklist of patterns to consider for legality.
1901   llvm::SetVector<const Pattern *> patternWorklist;
1902 
1903   // Build the mapping from operations to the parent ops that may generate them.
1904   applicator.walkAllPatterns([&](const Pattern &pattern) {
1905     Optional<OperationName> root = pattern.getRootKind();
1906 
1907     // If the pattern has no specific root, we can't analyze the relationship
1908     // between the root op and generated operations. Given that, add all such
1909     // patterns to the legalization set.
1910     if (!root) {
1911       anyOpLegalizerPatterns.push_back(&pattern);
1912       return;
1913     }
1914 
1915     // Skip operations that are always known to be legal.
1916     if (target.getOpAction(*root) == LegalizationAction::Legal)
1917       return;
1918 
1919     // Add this pattern to the invalid set for the root op and record this root
1920     // as a parent for any generated operations.
1921     invalidPatterns[*root].insert(&pattern);
1922     for (auto op : pattern.getGeneratedOps())
1923       parentOps[op].insert(*root);
1924 
1925     // Add this pattern to the worklist.
1926     patternWorklist.insert(&pattern);
1927   });
1928 
1929   // If there are any patterns that don't have a specific root kind, we can't
1930   // make direct assumptions about what operations will never be legalized.
1931   // Note: Technically we could, but it would require an analysis that may
1932   // recurse into itself. It would be better to perform this kind of filtering
1933   // at a higher level than here anyways.
1934   if (!anyOpLegalizerPatterns.empty()) {
1935     for (const Pattern *pattern : patternWorklist)
1936       legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
1937     return;
1938   }
1939 
1940   while (!patternWorklist.empty()) {
1941     auto *pattern = patternWorklist.pop_back_val();
1942 
1943     // Check to see if any of the generated operations are invalid.
1944     if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
1945           Optional<LegalizationAction> action = target.getOpAction(op);
1946           return !legalizerPatterns.count(op) &&
1947                  (!action || action == LegalizationAction::Illegal);
1948         }))
1949       continue;
1950 
1951     // Otherwise, if all of the generated operation are valid, this op is now
1952     // legal so add all of the child patterns to the worklist.
1953     legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
1954     invalidPatterns[*pattern->getRootKind()].erase(pattern);
1955 
1956     // Add any invalid patterns of the parent operations to see if they have now
1957     // become legal.
1958     for (auto op : parentOps[*pattern->getRootKind()])
1959       patternWorklist.set_union(invalidPatterns[op]);
1960   }
1961 }
1962 
1963 void OperationLegalizer::computeLegalizationGraphBenefit(
1964     LegalizationPatterns &anyOpLegalizerPatterns,
1965     DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
1966   // The smallest pattern depth, when legalizing an operation.
1967   DenseMap<OperationName, unsigned> minOpPatternDepth;
1968 
1969   // For each operation that is transitively legal, compute a cost for it.
1970   for (auto &opIt : legalizerPatterns)
1971     if (!minOpPatternDepth.count(opIt.first))
1972       computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
1973                                  legalizerPatterns);
1974 
1975   // Apply the cost model to the patterns that can match any operation. Those
1976   // with a specific operation type are already resolved when computing the op
1977   // legalization depth.
1978   if (!anyOpLegalizerPatterns.empty())
1979     applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
1980                              legalizerPatterns);
1981 
1982   // Apply a cost model to the pattern applicator. We order patterns first by
1983   // depth then benefit. `legalizerPatterns` contains per-op patterns by
1984   // decreasing benefit.
1985   applicator.applyCostModel([&](const Pattern &pattern) {
1986     ArrayRef<const Pattern *> orderedPatternList;
1987     if (Optional<OperationName> rootName = pattern.getRootKind())
1988       orderedPatternList = legalizerPatterns[*rootName];
1989     else
1990       orderedPatternList = anyOpLegalizerPatterns;
1991 
1992     // If the pattern is not found, then it was removed and cannot be matched.
1993     auto it = llvm::find(orderedPatternList, &pattern);
1994     if (it == orderedPatternList.end())
1995       return PatternBenefit::impossibleToMatch();
1996 
1997     // Patterns found earlier in the list have higher benefit.
1998     return PatternBenefit(std::distance(it, orderedPatternList.end()));
1999   });
2000 }
2001 
2002 unsigned OperationLegalizer::computeOpLegalizationDepth(
2003     OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2004     DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2005   // Check for existing depth.
2006   auto depthIt = minOpPatternDepth.find(op);
2007   if (depthIt != minOpPatternDepth.end())
2008     return depthIt->second;
2009 
2010   // If a mapping for this operation does not exist, then this operation
2011   // is always legal. Return 0 as the depth for a directly legal operation.
2012   auto opPatternsIt = legalizerPatterns.find(op);
2013   if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2014     return 0u;
2015 
2016   // Record this initial depth in case we encounter this op again when
2017   // recursively computing the depth.
2018   minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
2019 
2020   // Apply the cost model to the operation patterns, and update the minimum
2021   // depth.
2022   unsigned minDepth = applyCostModelToPatterns(
2023       opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2024   minOpPatternDepth[op] = minDepth;
2025   return minDepth;
2026 }
2027 
2028 unsigned OperationLegalizer::applyCostModelToPatterns(
2029     LegalizationPatterns &patterns,
2030     DenseMap<OperationName, unsigned> &minOpPatternDepth,
2031     DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
2032   unsigned minDepth = std::numeric_limits<unsigned>::max();
2033 
2034   // Compute the depth for each pattern within the set.
2035   SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
2036   patternsByDepth.reserve(patterns.size());
2037   for (const Pattern *pattern : patterns) {
2038     unsigned depth = 0;
2039     for (auto generatedOp : pattern->getGeneratedOps()) {
2040       unsigned generatedOpDepth = computeOpLegalizationDepth(
2041           generatedOp, minOpPatternDepth, legalizerPatterns);
2042       depth = std::max(depth, generatedOpDepth + 1);
2043     }
2044     patternsByDepth.emplace_back(pattern, depth);
2045 
2046     // Update the minimum depth of the pattern list.
2047     minDepth = std::min(minDepth, depth);
2048   }
2049 
2050   // If the operation only has one legalization pattern, there is no need to
2051   // sort them.
2052   if (patternsByDepth.size() == 1)
2053     return minDepth;
2054 
2055   // Sort the patterns by those likely to be the most beneficial.
2056   llvm::array_pod_sort(patternsByDepth.begin(), patternsByDepth.end(),
2057                        [](const std::pair<const Pattern *, unsigned> *lhs,
2058                           const std::pair<const Pattern *, unsigned> *rhs) {
2059                          // First sort by the smaller pattern legalization
2060                          // depth.
2061                          if (lhs->second != rhs->second)
2062                            return llvm::array_pod_sort_comparator<unsigned>(
2063                                &lhs->second, &rhs->second);
2064 
2065                          // Then sort by the larger pattern benefit.
2066                          auto lhsBenefit = lhs->first->getBenefit();
2067                          auto rhsBenefit = rhs->first->getBenefit();
2068                          return llvm::array_pod_sort_comparator<PatternBenefit>(
2069                              &rhsBenefit, &lhsBenefit);
2070                        });
2071 
2072   // Update the legalization pattern to use the new sorted list.
2073   patterns.clear();
2074   for (auto &patternIt : patternsByDepth)
2075     patterns.push_back(patternIt.first);
2076   return minDepth;
2077 }
2078 
2079 //===----------------------------------------------------------------------===//
2080 // OperationConverter
2081 //===----------------------------------------------------------------------===//
namespace {
/// Controls how OperationConverter reacts to operations that fail to legalize.
enum OpConversionMode {
  // In this mode, the conversion will ignore failed conversions to allow
  // illegal operations to co-exist in the IR. A failure is still reported for
  // an operation that the target explicitly marked as illegal.
  Partial,

  // In this mode, all operations must be legal for the given target for the
  // conversion to succeed.
  Full,

  // In this mode, operations are analyzed for legality. No actual rewrites are
  // applied to the operations on success.
  Analysis,
};

// This class converts operations to a given conversion target via a set of
// rewrite patterns. The conversion behaves differently depending on the
// conversion mode.
struct OperationConverter {
  /// Create a converter for `target` using `patterns` in the given `mode`.
  /// `trackedOps` is optional; see the member documentation below for what is
  /// recorded in it in each mode.
  explicit OperationConverter(ConversionTarget &target,
                              const FrozenRewritePatternList &patterns,
                              OpConversionMode mode,
                              DenseSet<Operation *> *trackedOps = nullptr)
      : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}

  /// Converts the given operations to the conversion target.
  LogicalResult convertOperations(ArrayRef<Operation *> ops);

private:
  /// Converts an operation with the given rewriter. Whether a legalization
  /// failure is reported as an error depends on `mode`.
  LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);

  /// This method is called after the conversion process to legalize any
  /// remaining artifacts and complete the conversion.
  LogicalResult finalize(ConversionPatternRewriter &rewriter);

  /// Legalize the types of converted block arguments.
  LogicalResult
  legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
                                 ConversionPatternRewriterImpl &rewriterImpl);

  /// Legalize an operation result that was marked as "erased".
  LogicalResult
  legalizeErasedResult(Operation *op, OpResult result,
                       ConversionPatternRewriterImpl &rewriterImpl);

  /// Legalize an operation result that was replaced with a value of a different
  /// type.
  LogicalResult
  legalizeChangedResultType(Operation *op, OpResult result, Value newValue,
                            TypeConverter *replConverter,
                            ConversionPatternRewriter &rewriter,
                            ConversionPatternRewriterImpl &rewriterImpl);

  /// The legalizer to use when converting operations.
  OperationLegalizer opLegalizer;

  /// The conversion mode to use when legalizing operations.
  OpConversionMode mode;

  /// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
  /// this is populated with ops found to be legalizable to the target.
  /// When mode == OpConversionMode::Partial, this is populated with ops found
  /// *not* to be legalizable to the target. May be null.
  DenseSet<Operation *> *trackedOps;
};
} // end anonymous namespace
2149 
2150 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2151                                           Operation *op) {
2152   // Legalize the given operation.
2153   if (failed(opLegalizer.legalize(op, rewriter))) {
2154     // Handle the case of a failed conversion for each of the different modes.
2155     // Full conversions expect all operations to be converted.
2156     if (mode == OpConversionMode::Full)
2157       return op->emitError()
2158              << "failed to legalize operation '" << op->getName() << "'";
2159     // Partial conversions allow conversions to fail iff the operation was not
2160     // explicitly marked as illegal. If the user provided a nonlegalizableOps
2161     // set, non-legalizable ops are included.
2162     if (mode == OpConversionMode::Partial) {
2163       if (opLegalizer.isIllegal(op))
2164         return op->emitError()
2165                << "failed to legalize operation '" << op->getName()
2166                << "' that was explicitly marked illegal";
2167       if (trackedOps)
2168         trackedOps->insert(op);
2169     }
2170   } else if (mode == OpConversionMode::Analysis) {
2171     // Analysis conversions don't fail if any operations fail to legalize,
2172     // they are only interested in the operations that were successfully
2173     // legalized.
2174     trackedOps->insert(op);
2175   }
2176   return success();
2177 }
2178 
2179 LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
2180   if (ops.empty())
2181     return success();
2182   ConversionTarget &target = opLegalizer.getTarget();
2183 
2184   // Compute the set of operations and blocks to convert.
2185   std::vector<Operation *> toConvert;
2186   for (auto *op : ops) {
2187     toConvert.emplace_back(op);
2188     for (auto &region : op->getRegions())
2189       if (failed(computeConversionSet(region.getBlocks(), region.getLoc(),
2190                                       toConvert, &target)))
2191         return failure();
2192   }
2193 
2194   // Convert each operation and discard rewrites on failure.
2195   ConversionPatternRewriter rewriter(ops.front()->getContext());
2196   ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2197   for (auto *op : toConvert)
2198     if (failed(convert(rewriter, op)))
2199       return rewriterImpl.discardRewrites(), failure();
2200 
2201   // Now that all of the operations have been converted, finalize the conversion
2202   // process to ensure any lingering conversion artifacts are cleaned up and
2203   // legalized.
2204   if (failed(finalize(rewriter)))
2205     return rewriterImpl.discardRewrites(), failure();
2206 
2207   // After a successful conversion, apply rewrites if this is not an analysis
2208   // conversion.
2209   if (mode == OpConversionMode::Analysis)
2210     rewriterImpl.discardRewrites();
2211   else
2212     rewriterImpl.applyRewrites();
2213   return success();
2214 }
2215 
2216 LogicalResult
2217 OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
2218   ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2219 
2220   // Legalize converted block arguments.
2221   if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
2222     return failure();
2223 
2224   // Process requested operation replacements.
2225   for (unsigned i = 0, e = rewriterImpl.operationsWithChangedResults.size();
2226        i != e; ++i) {
2227     unsigned replIdx = rewriterImpl.operationsWithChangedResults[i];
2228     auto &repl = *(rewriterImpl.replacements.begin() + replIdx);
2229     for (OpResult result : repl.first->getResults()) {
2230       Value newValue = rewriterImpl.mapping.lookupOrNull(result);
2231 
2232       // If the operation result was replaced with null, all of the uses of this
2233       // value should be replaced.
2234       if (!newValue) {
2235         if (failed(legalizeErasedResult(repl.first, result, rewriterImpl)))
2236           return failure();
2237         continue;
2238       }
2239 
2240       // Otherwise, check to see if the type of the result changed.
2241       if (result.getType() == newValue.getType())
2242         continue;
2243 
2244       // Legalize this result.
2245       rewriter.setInsertionPoint(repl.first);
2246       if (failed(legalizeChangedResultType(repl.first, result, newValue,
2247                                            repl.second.converter, rewriter,
2248                                            rewriterImpl)))
2249         return failure();
2250 
2251       // Update the end iterator for this loop in the case it was updated
2252       // when legalizing generated conversion operations.
2253       e = rewriterImpl.operationsWithChangedResults.size();
2254     }
2255   }
2256   return success();
2257 }
2258 
2259 LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
2260     ConversionPatternRewriter &rewriter,
2261     ConversionPatternRewriterImpl &rewriterImpl) {
2262   // Functor used to check if all users of a value will be dead after
2263   // conversion.
2264   auto findLiveUser = [&](Value val) {
2265     auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
2266       return rewriterImpl.isOpIgnored(user);
2267     });
2268     return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
2269   };
2270 
2271   // Materialize any necessary conversions for converted block arguments that
2272   // are still live.
2273   size_t numCreatedOps = rewriterImpl.createdOps.size();
2274   if (failed(rewriterImpl.argConverter.materializeLiveConversions(
2275           rewriterImpl.mapping, rewriter, findLiveUser)))
2276     return failure();
2277 
2278   // Legalize any newly created operations during argument materialization.
2279   for (int i : llvm::seq<int>(numCreatedOps, rewriterImpl.createdOps.size())) {
2280     if (failed(opLegalizer.legalize(rewriterImpl.createdOps[i], rewriter))) {
2281       return rewriterImpl.createdOps[i]->emitError()
2282              << "failed to legalize conversion operation generated for block "
2283                 "argument that remained live after conversion";
2284     }
2285   }
2286   return success();
2287 }
2288 
2289 LogicalResult OperationConverter::legalizeErasedResult(
2290     Operation *op, OpResult result,
2291     ConversionPatternRewriterImpl &rewriterImpl) {
2292   // If the operation result was replaced with null, all of the uses of this
2293   // value should be replaced.
2294   auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
2295     return rewriterImpl.isOpIgnored(user);
2296   });
2297   if (liveUserIt != result.user_end()) {
2298     InFlightDiagnostic diag = op->emitError("failed to legalize operation '")
2299                               << op->getName() << "' marked as erased";
2300     diag.attachNote(liveUserIt->getLoc())
2301         << "found live user of result #" << result.getResultNumber() << ": "
2302         << *liveUserIt;
2303     return failure();
2304   }
2305   return success();
2306 }
2307 
2308 LogicalResult OperationConverter::legalizeChangedResultType(
2309     Operation *op, OpResult result, Value newValue,
2310     TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
2311     ConversionPatternRewriterImpl &rewriterImpl) {
2312   // Walk the users of this value to see if there are any live users that
2313   // weren't replaced during conversion.
2314   auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
2315     return rewriterImpl.isOpIgnored(user);
2316   });
2317   if (liveUserIt == result.user_end())
2318     return success();
2319 
2320   // If the replacement has a type converter, attempt to materialize a
2321   // conversion back to the original type.
2322   if (!replConverter) {
2323     // TODO: We should emit an error here, similarly to the case where the
2324     // result is replaced with null. Unfortunately a lot of existing
2325     // patterns rely on this behavior, so until those patterns are updated
2326     // we keep the legacy behavior here of just forwarding the new value.
2327     return success();
2328   }
2329 
2330   // Track the number of created operations so that new ones can be legalized.
2331   size_t numCreatedOps = rewriterImpl.createdOps.size();
2332 
2333   // Materialize a conversion for this live result value.
2334   Type resultType = result.getType();
2335   Value convertedValue = replConverter->materializeSourceConversion(
2336       rewriter, op->getLoc(), resultType, newValue);
2337   if (!convertedValue) {
2338     InFlightDiagnostic diag = op->emitError()
2339                               << "failed to materialize conversion for result #"
2340                               << result.getResultNumber() << " of operation '"
2341                               << op->getName()
2342                               << "' that remained live after conversion";
2343     diag.attachNote(liveUserIt->getLoc())
2344         << "see existing live user here: " << *liveUserIt;
2345     return failure();
2346   }
2347 
2348   // Legalize all of the newly created conversion operations.
2349   for (int i : llvm::seq<int>(numCreatedOps, rewriterImpl.createdOps.size())) {
2350     if (failed(opLegalizer.legalize(rewriterImpl.createdOps[i], rewriter))) {
2351       return op->emitError("failed to legalize conversion operation generated ")
2352              << "for result #" << result.getResultNumber() << " of operation '"
2353              << op->getName() << "' that remained live after conversion";
2354     }
2355   }
2356 
2357   rewriterImpl.mapping.map(result, convertedValue);
2358   return success();
2359 }
2360 
2361 //===----------------------------------------------------------------------===//
2362 // Type Conversion
2363 //===----------------------------------------------------------------------===//
2364 
2365 /// Remap an input of the original signature with a new set of types. The
2366 /// new types are appended to the new signature conversion.
2367 void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
2368                                                    ArrayRef<Type> types) {
2369   assert(!types.empty() && "expected valid types");
2370   remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
2371   addInputs(types);
2372 }
2373 
2374 /// Append new input types to the signature conversion, this should only be
2375 /// used if the new types are not intended to remap an existing input.
2376 void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) {
2377   assert(!types.empty() &&
2378          "1->0 type remappings don't need to be added explicitly");
2379   argTypes.append(types.begin(), types.end());
2380 }
2381 
2382 /// Remap an input of the original signature with a range of types in the
2383 /// new signature.
2384 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2385                                                     unsigned newInputNo,
2386                                                     unsigned newInputCount) {
2387   assert(!remappedInputs[origInputNo] && "input has already been remapped");
2388   assert(newInputCount != 0 && "expected valid input count");
2389   remappedInputs[origInputNo] =
2390       InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr};
2391 }
2392 
2393 /// Remap an input of the original signature to another `replacementValue`
2394 /// value. This would make the signature converter drop this argument.
2395 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2396                                                     Value replacementValue) {
2397   assert(!remappedInputs[origInputNo] && "input has already been remapped");
2398   remappedInputs[origInputNo] =
2399       InputMapping{origInputNo, /*size=*/0, replacementValue};
2400 }
2401 
2402 /// This hooks allows for converting a type.
2403 LogicalResult TypeConverter::convertType(Type t,
2404                                          SmallVectorImpl<Type> &results) {
2405   auto existingIt = cachedDirectConversions.find(t);
2406   if (existingIt != cachedDirectConversions.end()) {
2407     if (existingIt->second)
2408       results.push_back(existingIt->second);
2409     return success(existingIt->second != nullptr);
2410   }
2411   auto multiIt = cachedMultiConversions.find(t);
2412   if (multiIt != cachedMultiConversions.end()) {
2413     results.append(multiIt->second.begin(), multiIt->second.end());
2414     return success();
2415   }
2416 
2417   // Walk the added converters in reverse order to apply the most recently
2418   // registered first.
2419   size_t currentCount = results.size();
2420   for (ConversionCallbackFn &converter : llvm::reverse(conversions)) {
2421     if (Optional<LogicalResult> result = converter(t, results)) {
2422       if (!succeeded(*result)) {
2423         cachedDirectConversions.try_emplace(t, nullptr);
2424         return failure();
2425       }
2426       auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
2427       if (newTypes.size() == 1)
2428         cachedDirectConversions.try_emplace(t, newTypes.front());
2429       else
2430         cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
2431       return success();
2432     }
2433   }
2434   return failure();
2435 }
2436 
2437 /// This hook simplifies defining 1-1 type conversions. This function returns
2438 /// the type to convert to on success, and a null type on failure.
2439 Type TypeConverter::convertType(Type t) {
2440   // Use the multi-type result version to convert the type.
2441   SmallVector<Type, 1> results;
2442   if (failed(convertType(t, results)))
2443     return nullptr;
2444 
2445   // Check to ensure that only one type was produced.
2446   return results.size() == 1 ? results.front() : nullptr;
2447 }
2448 
2449 /// Convert the given set of types, filling 'results' as necessary. This
2450 /// returns failure if the conversion of any of the types fails, success
2451 /// otherwise.
2452 LogicalResult TypeConverter::convertTypes(ArrayRef<Type> types,
2453                                           SmallVectorImpl<Type> &results) {
2454   for (auto type : types)
2455     if (failed(convertType(type, results)))
2456       return failure();
2457   return success();
2458 }
2459 
2460 /// Return true if the given type is legal for this type converter, i.e. the
2461 /// type converts to itself.
2462 bool TypeConverter::isLegal(Type type) { return convertType(type) == type; }
2463 /// Return true if the given operation has legal operand and result types.
2464 bool TypeConverter::isLegal(Operation *op) {
2465   return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
2466 }
2467 
2468 /// Return true if the types of block arguments within the region are legal.
2469 bool TypeConverter::isLegal(Region *region) {
2470   return llvm::all_of(*region, [this](Block &block) {
2471     return isLegal(block.getArgumentTypes());
2472   });
2473 }
2474 
2475 /// Return true if the inputs and outputs of the given function type are
2476 /// legal.
2477 bool TypeConverter::isSignatureLegal(FunctionType ty) {
2478   return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
2479 }
2480 
2481 /// This hook allows for converting a specific argument of a signature.
2482 LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
2483                                                  SignatureConversion &result) {
2484   // Try to convert the given input type.
2485   SmallVector<Type, 1> convertedTypes;
2486   if (failed(convertType(type, convertedTypes)))
2487     return failure();
2488 
2489   // If this argument is being dropped, there is nothing left to do.
2490   if (convertedTypes.empty())
2491     return success();
2492 
2493   // Otherwise, add the new inputs.
2494   result.addInputs(inputNo, convertedTypes);
2495   return success();
2496 }
2497 LogicalResult TypeConverter::convertSignatureArgs(TypeRange types,
2498                                                   SignatureConversion &result,
2499                                                   unsigned origInputOffset) {
2500   for (unsigned i = 0, e = types.size(); i != e; ++i)
2501     if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
2502       return failure();
2503   return success();
2504 }
2505 
2506 Value TypeConverter::materializeConversion(
2507     MutableArrayRef<MaterializationCallbackFn> materializations,
2508     OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) {
2509   for (MaterializationCallbackFn &fn : llvm::reverse(materializations))
2510     if (Optional<Value> result = fn(builder, resultType, inputs, loc))
2511       return result.getValue();
2512   return nullptr;
2513 }
2514 
2515 /// This function converts the type signature of the given block, by invoking
2516 /// 'convertSignatureArg' for each argument. This function should return a valid
2517 /// conversion for the signature on success, None otherwise.
2518 auto TypeConverter::convertBlockSignature(Block *block)
2519     -> Optional<SignatureConversion> {
2520   SignatureConversion conversion(block->getNumArguments());
2521   if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
2522     return llvm::None;
2523   return conversion;
2524 }
2525 
2526 /// Create a default conversion pattern that rewrites the type signature of a
2527 /// FunctionLike op. This only supports FunctionLike ops which use FunctionType
2528 /// to represent their type.
2529 namespace {
2530 struct FunctionLikeSignatureConversion : public ConversionPattern {
2531   FunctionLikeSignatureConversion(StringRef functionLikeOpName,
2532                                   MLIRContext *ctx, TypeConverter &converter)
2533       : ConversionPattern(functionLikeOpName, /*benefit=*/1, converter, ctx) {}
2534 
2535   /// Hook to implement combined matching and rewriting for FunctionLike ops.
2536   LogicalResult
2537   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
2538                   ConversionPatternRewriter &rewriter) const override {
2539     FunctionType type = mlir::impl::getFunctionType(op);
2540 
2541     // Convert the original function types.
2542     TypeConverter::SignatureConversion result(type.getNumInputs());
2543     SmallVector<Type, 1> newResults;
2544     if (failed(typeConverter->convertSignatureArgs(type.getInputs(), result)) ||
2545         failed(typeConverter->convertTypes(type.getResults(), newResults)) ||
2546         failed(rewriter.convertRegionTypes(&mlir::impl::getFunctionBody(op),
2547                                            *typeConverter, &result)))
2548       return failure();
2549 
2550     // Update the function signature in-place.
2551     auto newType = FunctionType::get(rewriter.getContext(),
2552                                      result.getConvertedTypes(), newResults);
2553 
2554     rewriter.updateRootInPlace(
2555         op, [&] { mlir::impl::setFunctionType(op, newType); });
2556 
2557     return success();
2558   }
2559 };
2560 } // end anonymous namespace
2561 
2562 void mlir::populateFunctionLikeTypeConversionPattern(
2563     StringRef functionLikeOpName, OwningRewritePatternList &patterns,
2564     MLIRContext *ctx, TypeConverter &converter) {
2565   patterns.insert<FunctionLikeSignatureConversion>(functionLikeOpName, ctx,
2566                                                    converter);
2567 }
2568 
2569 void mlir::populateFuncOpTypeConversionPattern(
2570     OwningRewritePatternList &patterns, MLIRContext *ctx,
2571     TypeConverter &converter) {
2572   populateFunctionLikeTypeConversionPattern<FuncOp>(patterns, ctx, converter);
2573 }
2574 
2575 //===----------------------------------------------------------------------===//
2576 // ConversionTarget
2577 //===----------------------------------------------------------------------===//
2578 
2579 /// Register a legality action for the given operation.
2580 void ConversionTarget::setOpAction(OperationName op,
2581                                    LegalizationAction action) {
2582   legalOperations[op] = {action, /*isRecursivelyLegal=*/false, llvm::None};
2583 }
2584 
2585 /// Register a legality action for the given dialects.
2586 void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
2587                                         LegalizationAction action) {
2588   for (StringRef dialect : dialectNames)
2589     legalDialects[dialect] = action;
2590 }
2591 
2592 /// Get the legality action for the given operation.
2593 auto ConversionTarget::getOpAction(OperationName op) const
2594     -> Optional<LegalizationAction> {
2595   Optional<LegalizationInfo> info = getOpInfo(op);
2596   return info ? info->action : Optional<LegalizationAction>();
2597 }
2598 
2599 /// If the given operation instance is legal on this target, a structure
2600 /// containing legality information is returned. If the operation is not legal,
2601 /// None is returned.
2602 auto ConversionTarget::isLegal(Operation *op) const
2603     -> Optional<LegalOpDetails> {
2604   Optional<LegalizationInfo> info = getOpInfo(op->getName());
2605   if (!info)
2606     return llvm::None;
2607 
2608   // Returns true if this operation instance is known to be legal.
2609   auto isOpLegal = [&] {
2610     // Handle dynamic legality either with the provided legality function, or
2611     // the default hook on the derived instance.
2612     if (info->action == LegalizationAction::Dynamic)
2613       return info->legalityFn ? (*info->legalityFn)(op)
2614                               : isDynamicallyLegal(op);
2615 
2616     // Otherwise, the operation is only legal if it was marked 'Legal'.
2617     return info->action == LegalizationAction::Legal;
2618   };
2619   if (!isOpLegal())
2620     return llvm::None;
2621 
2622   // This operation is legal, compute any additional legality information.
2623   LegalOpDetails legalityDetails;
2624   if (info->isRecursivelyLegal) {
2625     auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
2626     if (legalityFnIt != opRecursiveLegalityFns.end())
2627       legalityDetails.isRecursivelyLegal = legalityFnIt->second(op);
2628     else
2629       legalityDetails.isRecursivelyLegal = true;
2630   }
2631   return legalityDetails;
2632 }
2633 
2634 /// Set the dynamic legality callback for the given operation.
2635 void ConversionTarget::setLegalityCallback(
2636     OperationName name, const DynamicLegalityCallbackFn &callback) {
2637   assert(callback && "expected valid legality callback");
2638   auto infoIt = legalOperations.find(name);
2639   assert(infoIt != legalOperations.end() &&
2640          infoIt->second.action == LegalizationAction::Dynamic &&
2641          "expected operation to already be marked as dynamically legal");
2642   infoIt->second.legalityFn = callback;
2643 }
2644 
2645 /// Set the recursive legality callback for the given operation and mark the
2646 /// operation as recursively legal.
2647 void ConversionTarget::markOpRecursivelyLegal(
2648     OperationName name, const DynamicLegalityCallbackFn &callback) {
2649   auto infoIt = legalOperations.find(name);
2650   assert(infoIt != legalOperations.end() &&
2651          infoIt->second.action != LegalizationAction::Illegal &&
2652          "expected operation to already be marked as legal");
2653   infoIt->second.isRecursivelyLegal = true;
2654   if (callback)
2655     opRecursiveLegalityFns[name] = callback;
2656   else
2657     opRecursiveLegalityFns.erase(name);
2658 }
2659 
2660 /// Set the dynamic legality callback for the given dialects.
2661 void ConversionTarget::setLegalityCallback(
2662     ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
2663   assert(callback && "expected valid legality callback");
2664   for (StringRef dialect : dialects)
2665     dialectLegalityFns[dialect] = callback;
2666 }
2667 
2668 /// Get the legalization information for the given operation.
2669 auto ConversionTarget::getOpInfo(OperationName op) const
2670     -> Optional<LegalizationInfo> {
2671   // Check for info for this specific operation.
2672   auto it = legalOperations.find(op);
2673   if (it != legalOperations.end())
2674     return it->second;
2675   // Check for info for the parent dialect.
2676   auto dialectIt = legalDialects.find(op.getDialect());
2677   if (dialectIt != legalDialects.end()) {
2678     Optional<DynamicLegalityCallbackFn> callback;
2679     auto dialectFn = dialectLegalityFns.find(op.getDialect());
2680     if (dialectFn != dialectLegalityFns.end())
2681       callback = dialectFn->second;
2682     return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
2683                             callback};
2684   }
2685   // Otherwise, check if we mark unknown operations as dynamic.
2686   if (unknownOpsDynamicallyLegal)
2687     return LegalizationInfo{LegalizationAction::Dynamic,
2688                             /*isRecursivelyLegal=*/false, unknownLegalityFn};
2689   return llvm::None;
2690 }
2691 
2692 //===----------------------------------------------------------------------===//
2693 // Op Conversion Entry Points
2694 //===----------------------------------------------------------------------===//
2695 
2696 /// Apply a partial conversion on the given operations and all nested
2697 /// operations. This method converts as many operations to the target as
2698 /// possible, ignoring operations that failed to legalize. This method only
2699 /// returns failure if there ops explicitly marked as illegal.
2700 /// If an `unconvertedOps` set is provided, all operations that are found not
2701 /// to be legalizable to the given `target` are placed within that set. (Note
2702 /// that if there is an op explicitly marked as illegal, the conversion
2703 /// terminates and the `unconvertedOps` set will not necessarily be complete.)
2704 LogicalResult
2705 mlir::applyPartialConversion(ArrayRef<Operation *> ops,
2706                              ConversionTarget &target,
2707                              const FrozenRewritePatternList &patterns,
2708                              DenseSet<Operation *> *unconvertedOps) {
2709   OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
2710                                  unconvertedOps);
2711   return opConverter.convertOperations(ops);
2712 }
2713 LogicalResult
2714 mlir::applyPartialConversion(Operation *op, ConversionTarget &target,
2715                              const FrozenRewritePatternList &patterns,
2716                              DenseSet<Operation *> *unconvertedOps) {
2717   return applyPartialConversion(llvm::makeArrayRef(op), target, patterns,
2718                                 unconvertedOps);
2719 }
2720 
2721 /// Apply a complete conversion on the given operations, and all nested
2722 /// operations. This method will return failure if the conversion of any
2723 /// operation fails.
2724 LogicalResult
2725 mlir::applyFullConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
2726                           const FrozenRewritePatternList &patterns) {
2727   OperationConverter opConverter(target, patterns, OpConversionMode::Full);
2728   return opConverter.convertOperations(ops);
2729 }
2730 LogicalResult
2731 mlir::applyFullConversion(Operation *op, ConversionTarget &target,
2732                           const FrozenRewritePatternList &patterns) {
2733   return applyFullConversion(llvm::makeArrayRef(op), target, patterns);
2734 }
2735 
2736 /// Apply an analysis conversion on the given operations, and all nested
2737 /// operations. This method analyzes which operations would be successfully
2738 /// converted to the target if a conversion was applied. All operations that
2739 /// were found to be legalizable to the given 'target' are placed within the
2740 /// provided 'convertedOps' set; note that no actual rewrites are applied to the
2741 /// operations on success and only pre-existing operations are added to the set.
2742 LogicalResult
2743 mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
2744                               ConversionTarget &target,
2745                               const FrozenRewritePatternList &patterns,
2746                               DenseSet<Operation *> &convertedOps) {
2747   OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
2748                                  &convertedOps);
2749   return opConverter.convertOperations(ops);
2750 }
2751 LogicalResult
2752 mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
2753                               const FrozenRewritePatternList &patterns,
2754                               DenseSet<Operation *> &convertedOps) {
2755   return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns,
2756                                  convertedOps);
2757 }
2758