1 //===- DialectConversion.h - MLIR dialect conversion pass -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares a generic pass for converting between MLIR dialects.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_
14 #define MLIR_TRANSFORMS_DIALECTCONVERSION_H_
15 
16 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/StringMap.h"
19 
20 namespace mlir {
21 
22 // Forward declarations.
23 class Block;
24 class ConversionPatternRewriter;
25 class MLIRContext;
26 class Operation;
27 class Type;
28 class Value;
29 
30 //===----------------------------------------------------------------------===//
31 // Type Conversion
32 //===----------------------------------------------------------------------===//
33 
34 /// Type conversion class. Specific conversions and materializations can be
35 /// registered using addConversion and addMaterialization, respectively.
36 class TypeConverter {
37 public:
38   /// This class provides all of the information necessary to convert a type
39   /// signature.
40   class SignatureConversion {
41   public:
SignatureConversion(unsigned numOrigInputs)42     SignatureConversion(unsigned numOrigInputs)
43         : remappedInputs(numOrigInputs) {}
44 
45     /// This struct represents a range of new types or a single value that
46     /// remaps an existing signature input.
47     struct InputMapping {
48       size_t inputNo, size;
49       Value replacementValue;
50     };
51 
52     /// Return the argument types for the new signature.
getConvertedTypes()53     ArrayRef<Type> getConvertedTypes() const { return argTypes; }
54 
55     /// Get the input mapping for the given argument.
getInputMapping(unsigned input)56     Optional<InputMapping> getInputMapping(unsigned input) const {
57       return remappedInputs[input];
58     }
59 
60     //===------------------------------------------------------------------===//
61     // Conversion Hooks
62     //===------------------------------------------------------------------===//
63 
64     /// Remap an input of the original signature with a new set of types. The
65     /// new types are appended to the new signature conversion.
66     void addInputs(unsigned origInputNo, ArrayRef<Type> types);
67 
68     /// Append new input types to the signature conversion, this should only be
69     /// used if the new types are not intended to remap an existing input.
70     void addInputs(ArrayRef<Type> types);
71 
72     /// Remap an input of the original signature to another `replacement`
73     /// value. This drops the original argument.
74     void remapInput(unsigned origInputNo, Value replacement);
75 
76   private:
77     /// Remap an input of the original signature with a range of types in the
78     /// new signature.
79     void remapInput(unsigned origInputNo, unsigned newInputNo,
80                     unsigned newInputCount = 1);
81 
82     /// The remapping information for each of the original arguments.
83     SmallVector<Optional<InputMapping>, 4> remappedInputs;
84 
85     /// The set of new argument types.
86     SmallVector<Type, 4> argTypes;
87   };
88 
89   /// Register a conversion function. A conversion function must be convertible
90   /// to any of the following forms(where `T` is a class derived from `Type`:
91   ///   * Optional<Type>(T)
92   ///     - This form represents a 1-1 type conversion. It should return nullptr
93   ///       or `llvm::None` to signify failure. If `llvm::None` is returned, the
94   ///       converter is allowed to try another conversion function to perform
95   ///       the conversion.
96   ///   * Optional<LogicalResult>(T, SmallVectorImpl<Type> &)
97   ///     - This form represents a 1-N type conversion. It should return
98   ///       `failure` or `llvm::None` to signify a failed conversion. If the new
99   ///       set of types is empty, the type is removed and any usages of the
100   ///       existing value are expected to be removed during conversion. If
101   ///       `llvm::None` is returned, the converter is allowed to try another
102   ///       conversion function to perform the conversion.
103   ///   * Optional<LogicalResult>(T, SmallVectorImpl<Type> &, ArrayRef<Type>)
104   ///     - This form represents a 1-N type conversion supporting recursive
105   ///       types. The first two arguments and the return value are the same as
106   ///       for the regular 1-N form. The third argument is contains is the
107   ///       "call stack" of the recursive conversion: it contains the list of
108   ///       types currently being converted, with the current type being the
109   ///       last one. If it is present more than once in the list, the
110   ///       conversion concerns a recursive type.
111   /// Note: When attempting to convert a type, e.g. via 'convertType', the
112   ///       mostly recently added conversions will be invoked first.
113   template <typename FnT, typename T = typename llvm::function_traits<
114                               std::decay_t<FnT>>::template arg_t<0>>
addConversion(FnT && callback)115   void addConversion(FnT &&callback) {
116     registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
117   }
118 
119   /// Register a materialization function, which must be convertible to the
120   /// following form:
121   ///   `Optional<Value>(OpBuilder &, T, ValueRange, Location)`,
122   /// where `T` is any subclass of `Type`. This function is responsible for
123   /// creating an operation, using the OpBuilder and Location provided, that
124   /// "casts" a range of values into a single value of the given type `T`. It
125   /// must return a Value of the converted type on success, an `llvm::None` if
126   /// it failed but other materialization can be attempted, and `nullptr` on
127   /// unrecoverable failure. It will only be called for (sub)types of `T`.
128   /// Materialization functions must be provided when a type conversion may
129   /// persist after the conversion has finished.
130   ///
131   /// This method registers a materialization that will be called when
132   /// converting an illegal block argument type, to a legal type.
133   template <typename FnT, typename T = typename llvm::function_traits<
134                               std::decay_t<FnT>>::template arg_t<1>>
addArgumentMaterialization(FnT && callback)135   void addArgumentMaterialization(FnT &&callback) {
136     argumentMaterializations.emplace_back(
137         wrapMaterialization<T>(std::forward<FnT>(callback)));
138   }
139   /// This method registers a materialization that will be called when
140   /// converting a legal type to an illegal source type. This is used when
141   /// conversions to an illegal type must persist beyond the main conversion.
142   template <typename FnT, typename T = typename llvm::function_traits<
143                               std::decay_t<FnT>>::template arg_t<1>>
addSourceMaterialization(FnT && callback)144   void addSourceMaterialization(FnT &&callback) {
145     sourceMaterializations.emplace_back(
146         wrapMaterialization<T>(std::forward<FnT>(callback)));
147   }
148   /// This method registers a materialization that will be called when
149   /// converting type from an illegal, or source, type to a legal type.
150   template <typename FnT, typename T = typename llvm::function_traits<
151                               std::decay_t<FnT>>::template arg_t<1>>
addTargetMaterialization(FnT && callback)152   void addTargetMaterialization(FnT &&callback) {
153     targetMaterializations.emplace_back(
154         wrapMaterialization<T>(std::forward<FnT>(callback)));
155   }
156 
157   /// Convert the given type. This function should return failure if no valid
158   /// conversion exists, success otherwise. If the new set of types is empty,
159   /// the type is removed and any usages of the existing value are expected to
160   /// be removed during conversion.
161   LogicalResult convertType(Type t, SmallVectorImpl<Type> &results);
162 
163   /// This hook simplifies defining 1-1 type conversions. This function returns
164   /// the type to convert to on success, and a null type on failure.
165   Type convertType(Type t);
166 
167   /// Convert the given set of types, filling 'results' as necessary. This
168   /// returns failure if the conversion of any of the types fails, success
169   /// otherwise.
170   LogicalResult convertTypes(TypeRange types, SmallVectorImpl<Type> &results);
171 
172   /// Return true if the given type is legal for this type converter, i.e. the
173   /// type converts to itself.
174   bool isLegal(Type type);
175   /// Return true if all of the given types are legal for this type converter.
176   template <typename RangeT>
177   std::enable_if_t<!std::is_convertible<RangeT, Type>::value &&
178                        !std::is_convertible<RangeT, Operation *>::value,
179                    bool>
isLegal(RangeT && range)180   isLegal(RangeT &&range) {
181     return llvm::all_of(range, [this](Type type) { return isLegal(type); });
182   }
183   /// Return true if the given operation has legal operand and result types.
184   bool isLegal(Operation *op);
185 
186   /// Return true if the types of block arguments within the region are legal.
187   bool isLegal(Region *region);
188 
189   /// Return true if the inputs and outputs of the given function type are
190   /// legal.
191   bool isSignatureLegal(FunctionType ty);
192 
193   /// This method allows for converting a specific argument of a signature. It
194   /// takes as inputs the original argument input number, type.
195   /// On success, it populates 'result' with any new mappings.
196   LogicalResult convertSignatureArg(unsigned inputNo, Type type,
197                                     SignatureConversion &result);
198   LogicalResult convertSignatureArgs(TypeRange types,
199                                      SignatureConversion &result,
200                                      unsigned origInputOffset = 0);
201 
202   /// This function converts the type signature of the given block, by invoking
203   /// 'convertSignatureArg' for each argument. This function should return a
204   /// valid conversion for the signature on success, None otherwise.
205   Optional<SignatureConversion> convertBlockSignature(Block *block);
206 
207   /// Materialize a conversion from a set of types into one result type by
208   /// generating a cast sequence of some kind. See the respective
209   /// `add*Materialization` for more information on the context for these
210   /// methods.
materializeArgumentConversion(OpBuilder & builder,Location loc,Type resultType,ValueRange inputs)211   Value materializeArgumentConversion(OpBuilder &builder, Location loc,
212                                       Type resultType, ValueRange inputs) {
213     return materializeConversion(argumentMaterializations, builder, loc,
214                                  resultType, inputs);
215   }
materializeSourceConversion(OpBuilder & builder,Location loc,Type resultType,ValueRange inputs)216   Value materializeSourceConversion(OpBuilder &builder, Location loc,
217                                     Type resultType, ValueRange inputs) {
218     return materializeConversion(sourceMaterializations, builder, loc,
219                                  resultType, inputs);
220   }
materializeTargetConversion(OpBuilder & builder,Location loc,Type resultType,ValueRange inputs)221   Value materializeTargetConversion(OpBuilder &builder, Location loc,
222                                     Type resultType, ValueRange inputs) {
223     return materializeConversion(targetMaterializations, builder, loc,
224                                  resultType, inputs);
225   }
226 
227 private:
228   /// The signature of the callback used to convert a type. If the new set of
229   /// types is empty, the type is removed and any usages of the existing value
230   /// are expected to be removed during conversion.
231   using ConversionCallbackFn = std::function<Optional<LogicalResult>(
232       Type, SmallVectorImpl<Type> &, ArrayRef<Type>)>;
233 
234   /// The signature of the callback used to materialize a conversion.
235   using MaterializationCallbackFn =
236       std::function<Optional<Value>(OpBuilder &, Type, ValueRange, Location)>;
237 
238   /// Attempt to materialize a conversion using one of the provided
239   /// materialization functions.
240   Value materializeConversion(
241       MutableArrayRef<MaterializationCallbackFn> materializations,
242       OpBuilder &builder, Location loc, Type resultType, ValueRange inputs);
243 
244   /// Generate a wrapper for the given callback. This allows for accepting
245   /// different callback forms, that all compose into a single version.
246   /// With callback of form: `Optional<Type>(T)`
247   template <typename T, typename FnT>
248   std::enable_if_t<llvm::is_invocable<FnT, T>::value, ConversionCallbackFn>
wrapCallback(FnT && callback)249   wrapCallback(FnT &&callback) {
250     return wrapCallback<T>(
251         [callback = std::forward<FnT>(callback)](
252             T type, SmallVectorImpl<Type> &results, ArrayRef<Type>) {
253           if (Optional<Type> resultOpt = callback(type)) {
254             bool wasSuccess = static_cast<bool>(resultOpt.value());
255             if (wasSuccess)
256               results.push_back(resultOpt.value());
257             return Optional<LogicalResult>(success(wasSuccess));
258           }
259           return Optional<LogicalResult>();
260         });
261   }
262   /// With callback of form: `Optional<LogicalResult>(T, SmallVectorImpl<Type>
263   /// &)`
264   template <typename T, typename FnT>
265   std::enable_if_t<llvm::is_invocable<FnT, T, SmallVectorImpl<Type> &>::value,
266                    ConversionCallbackFn>
wrapCallback(FnT && callback)267   wrapCallback(FnT &&callback) {
268     return wrapCallback<T>(
269         [callback = std::forward<FnT>(callback)](
270             T type, SmallVectorImpl<Type> &results, ArrayRef<Type>) {
271           return callback(type, results);
272         });
273   }
274   /// With callback of form: `Optional<LogicalResult>(T, SmallVectorImpl<Type>
275   /// &, ArrayRef<Type>)`.
276   template <typename T, typename FnT>
277   std::enable_if_t<llvm::is_invocable<FnT, T, SmallVectorImpl<Type> &,
278                                       ArrayRef<Type>>::value,
279                    ConversionCallbackFn>
wrapCallback(FnT && callback)280   wrapCallback(FnT &&callback) {
281     return [callback = std::forward<FnT>(callback)](
282                Type type, SmallVectorImpl<Type> &results,
283                ArrayRef<Type> callStack) -> Optional<LogicalResult> {
284       T derivedType = type.dyn_cast<T>();
285       if (!derivedType)
286         return llvm::None;
287       return callback(derivedType, results, callStack);
288     };
289   }
290 
291   /// Register a type conversion.
registerConversion(ConversionCallbackFn callback)292   void registerConversion(ConversionCallbackFn callback) {
293     conversions.emplace_back(std::move(callback));
294     cachedDirectConversions.clear();
295     cachedMultiConversions.clear();
296   }
297 
298   /// Generate a wrapper for the given materialization callback. The callback
299   /// may take any subclass of `Type` and the wrapper will check for the target
300   /// type to be of the expected class before calling the callback.
301   template <typename T, typename FnT>
wrapMaterialization(FnT && callback)302   MaterializationCallbackFn wrapMaterialization(FnT &&callback) {
303     return [callback = std::forward<FnT>(callback)](
304                OpBuilder &builder, Type resultType, ValueRange inputs,
305                Location loc) -> Optional<Value> {
306       if (T derivedType = resultType.dyn_cast<T>())
307         return callback(builder, derivedType, inputs, loc);
308       return llvm::None;
309     };
310   }
311 
312   /// The set of registered conversion functions.
313   SmallVector<ConversionCallbackFn, 4> conversions;
314 
315   /// The list of registered materialization functions.
316   SmallVector<MaterializationCallbackFn, 2> argumentMaterializations;
317   SmallVector<MaterializationCallbackFn, 2> sourceMaterializations;
318   SmallVector<MaterializationCallbackFn, 2> targetMaterializations;
319 
320   /// A set of cached conversions to avoid recomputing in the common case.
321   /// Direct 1-1 conversions are the most common, so this cache stores the
322   /// successful 1-1 conversions as well as all failed conversions.
323   DenseMap<Type, Type> cachedDirectConversions;
324   /// This cache stores the successful 1->N conversions, where N != 1.
325   DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
326 
327   /// Stores the types that are being converted in the case when convertType
328   /// is being called recursively to convert nested types.
329   SmallVector<Type, 2> conversionCallStack;
330 };
331 
332 //===----------------------------------------------------------------------===//
333 // Conversion Patterns
334 //===----------------------------------------------------------------------===//
335 
336 /// Base class for the conversion patterns. This pattern class enables type
337 /// conversions, and other uses specific to the conversion framework. As such,
338 /// patterns of this type can only be used with the 'apply*' methods below.
339 class ConversionPattern : public RewritePattern {
340 public:
341   /// Hook for derived classes to implement rewriting. `op` is the (first)
342   /// operation matched by the pattern, `operands` is a list of the rewritten
343   /// operand values that are passed to `op`, `rewriter` can be used to emit the
344   /// new operations. This function should not fail. If some specific cases of
345   /// the operation are not supported, these cases should not be matched.
rewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)346   virtual void rewrite(Operation *op, ArrayRef<Value> operands,
347                        ConversionPatternRewriter &rewriter) const {
348     llvm_unreachable("unimplemented rewrite");
349   }
350 
351   /// Hook for derived classes to implement combined matching and rewriting.
352   virtual LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)353   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
354                   ConversionPatternRewriter &rewriter) const {
355     if (failed(match(op)))
356       return failure();
357     rewrite(op, operands, rewriter);
358     return success();
359   }
360 
361   /// Attempt to match and rewrite the IR root at the specified operation.
362   LogicalResult matchAndRewrite(Operation *op,
363                                 PatternRewriter &rewriter) const final;
364 
365   /// Return the type converter held by this pattern, or nullptr if the pattern
366   /// does not require type conversion.
getTypeConverter()367   TypeConverter *getTypeConverter() const { return typeConverter; }
368 
369   template <typename ConverterTy>
370   std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
371                    ConverterTy *>
getTypeConverter()372   getTypeConverter() const {
373     return static_cast<ConverterTy *>(typeConverter);
374   }
375 
376 protected:
377   /// See `RewritePattern::RewritePattern` for information on the other
378   /// available constructors.
379   using RewritePattern::RewritePattern;
380   /// Construct a conversion pattern with the given converter, and forward the
381   /// remaining arguments to RewritePattern.
382   template <typename... Args>
ConversionPattern(TypeConverter & typeConverter,Args &&...args)383   ConversionPattern(TypeConverter &typeConverter, Args &&...args)
384       : RewritePattern(std::forward<Args>(args)...),
385         typeConverter(&typeConverter) {}
386 
387 protected:
388   /// An optional type converter for use by this pattern.
389   TypeConverter *typeConverter = nullptr;
390 
391 private:
392   using RewritePattern::rewrite;
393 };
394 
395 /// OpConversionPattern is a wrapper around ConversionPattern that allows for
396 /// matching and rewriting against an instance of a derived operation class as
397 /// opposed to a raw Operation.
398 template <typename SourceOp>
399 class OpConversionPattern : public ConversionPattern {
400 public:
401   using OpAdaptor = typename SourceOp::Adaptor;
402 
403   OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
ConversionPattern(SourceOp::getOperationName (),benefit,context)404       : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
405   OpConversionPattern(TypeConverter &typeConverter, MLIRContext *context,
406                       PatternBenefit benefit = 1)
ConversionPattern(typeConverter,SourceOp::getOperationName (),benefit,context)407       : ConversionPattern(typeConverter, SourceOp::getOperationName(), benefit,
408                           context) {}
409 
410   /// Wrappers around the ConversionPattern methods that pass the derived op
411   /// type.
match(Operation * op)412   LogicalResult match(Operation *op) const final {
413     return match(cast<SourceOp>(op));
414   }
rewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)415   void rewrite(Operation *op, ArrayRef<Value> operands,
416                ConversionPatternRewriter &rewriter) const final {
417     rewrite(cast<SourceOp>(op), OpAdaptor(operands, op->getAttrDictionary()),
418             rewriter);
419   }
420   LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)421   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
422                   ConversionPatternRewriter &rewriter) const final {
423     return matchAndRewrite(cast<SourceOp>(op),
424                            OpAdaptor(operands, op->getAttrDictionary()),
425                            rewriter);
426   }
427 
428   /// Rewrite and Match methods that operate on the SourceOp type. These must be
429   /// overridden by the derived pattern class.
match(SourceOp op)430   virtual LogicalResult match(SourceOp op) const {
431     llvm_unreachable("must override match or matchAndRewrite");
432   }
rewrite(SourceOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter)433   virtual void rewrite(SourceOp op, OpAdaptor adaptor,
434                        ConversionPatternRewriter &rewriter) const {
435     llvm_unreachable("must override matchAndRewrite or a rewrite method");
436   }
437   virtual LogicalResult
matchAndRewrite(SourceOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter)438   matchAndRewrite(SourceOp op, OpAdaptor adaptor,
439                   ConversionPatternRewriter &rewriter) const {
440     if (failed(match(op)))
441       return failure();
442     rewrite(op, adaptor, rewriter);
443     return success();
444   }
445 
446 private:
447   using ConversionPattern::matchAndRewrite;
448 };
449 
450 /// OpInterfaceConversionPattern is a wrapper around ConversionPattern that
451 /// allows for matching and rewriting against an instance of an OpInterface
452 /// class as opposed to a raw Operation.
453 template <typename SourceOp>
454 class OpInterfaceConversionPattern : public ConversionPattern {
455 public:
456   OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
ConversionPattern(Pattern::MatchInterfaceOpTypeTag (),SourceOp::getInterfaceID (),benefit,context)457       : ConversionPattern(Pattern::MatchInterfaceOpTypeTag(),
458                           SourceOp::getInterfaceID(), benefit, context) {}
459   OpInterfaceConversionPattern(TypeConverter &typeConverter,
460                                MLIRContext *context, PatternBenefit benefit = 1)
ConversionPattern(typeConverter,Pattern::MatchInterfaceOpTypeTag (),SourceOp::getInterfaceID (),benefit,context)461       : ConversionPattern(typeConverter, Pattern::MatchInterfaceOpTypeTag(),
462                           SourceOp::getInterfaceID(), benefit, context) {}
463 
464   /// Wrappers around the ConversionPattern methods that pass the derived op
465   /// type.
rewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)466   void rewrite(Operation *op, ArrayRef<Value> operands,
467                ConversionPatternRewriter &rewriter) const final {
468     rewrite(cast<SourceOp>(op), operands, rewriter);
469   }
470   LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)471   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
472                   ConversionPatternRewriter &rewriter) const final {
473     return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
474   }
475 
476   /// Rewrite and Match methods that operate on the SourceOp type. These must be
477   /// overridden by the derived pattern class.
rewrite(SourceOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)478   virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
479                        ConversionPatternRewriter &rewriter) const {
480     llvm_unreachable("must override matchAndRewrite or a rewrite method");
481   }
482   virtual LogicalResult
matchAndRewrite(SourceOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)483   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
484                   ConversionPatternRewriter &rewriter) const {
485     if (failed(match(op)))
486       return failure();
487     rewrite(op, operands, rewriter);
488     return success();
489   }
490 
491 private:
492   using ConversionPattern::matchAndRewrite;
493 };
494 
495 /// Add a pattern to the given pattern list to convert the signature of a
496 /// FunctionOpInterface op with the given type converter. This only supports
497 /// ops which use FunctionType to represent their type.
498 void populateFunctionOpInterfaceTypeConversionPattern(
499     StringRef functionLikeOpName, RewritePatternSet &patterns,
500     TypeConverter &converter);
501 
502 template <typename FuncOpT>
populateFunctionOpInterfaceTypeConversionPattern(RewritePatternSet & patterns,TypeConverter & converter)503 void populateFunctionOpInterfaceTypeConversionPattern(
504     RewritePatternSet &patterns, TypeConverter &converter) {
505   populateFunctionOpInterfaceTypeConversionPattern(FuncOpT::getOperationName(),
506                                                    patterns, converter);
507 }
508 
509 //===----------------------------------------------------------------------===//
510 // Conversion PatternRewriter
511 //===----------------------------------------------------------------------===//
512 
513 namespace detail {
514 struct ConversionPatternRewriterImpl;
515 } // namespace detail
516 
517 /// This class implements a pattern rewriter for use with ConversionPatterns. It
518 /// extends the base PatternRewriter and provides special conversion specific
519 /// hooks.
520 class ConversionPatternRewriter final : public PatternRewriter {
521 public:
522   explicit ConversionPatternRewriter(MLIRContext *ctx);
523   ~ConversionPatternRewriter() override;
524 
525   /// Apply a signature conversion to the entry block of the given region. This
526   /// replaces the entry block with a new block containing the updated
527   /// signature. The new entry block to the region is returned for convenience.
528   ///
529   /// If provided, `converter` will be used for any materializations.
530   Block *
531   applySignatureConversion(Region *region,
532                            TypeConverter::SignatureConversion &conversion,
533                            TypeConverter *converter = nullptr);
534 
535   /// Convert the types of block arguments within the given region. This
536   /// replaces each block with a new block containing the updated signature. The
537   /// entry block may have a special conversion if `entryConversion` is
538   /// provided. On success, the new entry block to the region is returned for
539   /// convenience. Otherwise, failure is returned.
540   FailureOr<Block *> convertRegionTypes(
541       Region *region, TypeConverter &converter,
542       TypeConverter::SignatureConversion *entryConversion = nullptr);
543 
544   /// Convert the types of block arguments within the given region except for
545   /// the entry region. This replaces each non-entry block with a new block
546   /// containing the updated signature.
547   ///
548   /// If special conversion behavior is needed for the non-entry blocks (for
549   /// example, we need to convert only a subset of a BB arguments), such
550   /// behavior can be specified in blockConversions.
551   LogicalResult convertNonEntryRegionTypes(
552       Region *region, TypeConverter &converter,
553       ArrayRef<TypeConverter::SignatureConversion> blockConversions);
554 
555   /// Replace all the uses of the block argument `from` with value `to`.
556   void replaceUsesOfBlockArgument(BlockArgument from, Value to);
557 
558   /// Return the converted value of 'key' with a type defined by the type
559   /// converter of the currently executing pattern. Return nullptr in the case
560   /// of failure, the remapped value otherwise.
561   Value getRemappedValue(Value key);
562 
563   /// Return the converted values that replace 'keys' with types defined by the
564   /// type converter of the currently executing pattern. Returns failure if the
565   /// remap failed, success otherwise.
566   LogicalResult getRemappedValues(ValueRange keys,
567                                   SmallVectorImpl<Value> &results);
568 
569   //===--------------------------------------------------------------------===//
570   // PatternRewriter Hooks
571   //===--------------------------------------------------------------------===//
572 
573   /// PatternRewriter hook for replacing the results of an operation when the
574   /// given functor returns true.
575   void replaceOpWithIf(
576       Operation *op, ValueRange newValues, bool *allUsesReplaced,
577       llvm::unique_function<bool(OpOperand &) const> functor) override;
578 
579   /// PatternRewriter hook for replacing the results of an operation.
580   void replaceOp(Operation *op, ValueRange newValues) override;
581   using PatternRewriter::replaceOp;
582 
583   /// PatternRewriter hook for erasing a dead operation. The uses of this
584   /// operation *must* be made dead by the end of the conversion process,
585   /// otherwise an assert will be issued.
586   void eraseOp(Operation *op) override;
587 
588   /// PatternRewriter hook for erase all operations in a block. This is not yet
589   /// implemented for dialect conversion.
590   void eraseBlock(Block *block) override;
591 
592   /// PatternRewriter hook creating a new block.
593   void notifyBlockCreated(Block *block) override;
594 
595   /// PatternRewriter hook for splitting a block into two parts.
596   Block *splitBlock(Block *block, Block::iterator before) override;
597 
598   /// PatternRewriter hook for merging a block into another.
599   void mergeBlocks(Block *source, Block *dest, ValueRange argValues) override;
600 
601   /// PatternRewriter hook for moving blocks out of a region.
602   void inlineRegionBefore(Region &region, Region &parent,
603                           Region::iterator before) override;
604   using PatternRewriter::inlineRegionBefore;
605 
  /// PatternRewriter hook for cloning blocks of one region into another. The
  /// given region to clone *must* not have been modified as part of conversion
  /// yet, i.e. it must be within an operation that is either in the process of
  /// conversion, or has not yet been converted.
  ///
  /// The clones of `region`'s blocks are placed into `parent` before `before`.
  /// NOTE(review): `mapping` presumably receives the original-to-clone
  /// block/value correspondences -- confirm against the out-of-line definition.
  void cloneRegionBefore(Region &region, Region &parent,
                         Region::iterator before,
                         BlockAndValueMapping &mapping) override;
  // Keep the base-class convenience overloads visible next to the override.
  using PatternRewriter::cloneRegionBefore;
614 
  /// PatternRewriter hook for notifying that the operation `op` has been
  /// inserted.
  void notifyOperationInserted(Operation *op) override;
617 
  /// PatternRewriter hook for starting an in-place update of the root
  /// operation `op`. Presumably paired with a later finalizeRootUpdate or
  /// cancelRootUpdate on the same operation.
  /// Note: These methods only track updates to the top-level operation itself,
  /// and not nested regions. Updates to regions will still require notification
  /// through other more specific hooks above.
  void startRootUpdate(Operation *op) override;
623 
  /// PatternRewriter hook for finalizing an in-place update of the root
  /// operation `op` that was begun with startRootUpdate.
  void finalizeRootUpdate(Operation *op) override;
626 
  /// PatternRewriter hook for cancelling an in-place update of the root
  /// operation `op` that was begun with startRootUpdate.
  void cancelRootUpdate(Operation *op) override;
629 
  /// PatternRewriter hook for notifying match failure reasons. The reason is
  /// produced lazily by `reasonCallback`, which fills in a Diagnostic for the
  /// failure at `loc`. NOTE(review): whether the callback is actually invoked
  /// (e.g. only in debug builds) is decided by the out-of-line definition.
  LogicalResult
  notifyMatchFailure(Location loc,
                     function_ref<void(Diagnostic &)> reasonCallback) override;
  // Keep the base-class convenience overloads visible next to the override.
  using PatternRewriter::notifyMatchFailure;
635 
  /// Return a reference to the internal implementation (the object owned by
  /// `impl`).
  detail::ConversionPatternRewriterImpl &getImpl();
638 
private:
  /// The rewriter's implementation state; uniquely owned by this rewriter.
  std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
641 };
642 
643 //===----------------------------------------------------------------------===//
644 // ConversionTarget
645 //===----------------------------------------------------------------------===//
646 
647 /// This class describes a specific conversion target.
648 class ConversionTarget {
649 public:
650   /// This enumeration corresponds to the specific action to take when
651   /// considering an operation legal for this conversion target.
652   enum class LegalizationAction {
653     /// The target supports this operation.
654     Legal,
655 
656     /// This operation has dynamic legalization constraints that must be checked
657     /// by the target.
658     Dynamic,
659 
660     /// The target explicitly does not support this operation.
661     Illegal,
662   };
663 
664   /// A structure containing additional information describing a specific legal
665   /// operation instance.
666   struct LegalOpDetails {
667     /// A flag that indicates if this operation is 'recursively' legal. This
668     /// means that if an operation is legal, either statically or dynamically,
669     /// all of the operations nested within are also considered legal.
670     bool isRecursivelyLegal = false;
671   };
672 
673   /// The signature of the callback used to determine if an operation is
674   /// dynamically legal on the target.
675   using DynamicLegalityCallbackFn = std::function<Optional<bool>(Operation *)>;
676 
ConversionTarget(MLIRContext & ctx)677   ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
678   virtual ~ConversionTarget() = default;
679 
680   //===--------------------------------------------------------------------===//
681   // Legality Registration
682   //===--------------------------------------------------------------------===//
683 
684   /// Register a legality action for the given operation.
685   void setOpAction(OperationName op, LegalizationAction action);
686   template <typename OpT>
setOpAction(LegalizationAction action)687   void setOpAction(LegalizationAction action) {
688     setOpAction(OperationName(OpT::getOperationName(), &ctx), action);
689   }
690 
691   /// Register the given operations as legal.
addLegalOp(OperationName op)692   void addLegalOp(OperationName op) {
693     setOpAction(op, LegalizationAction::Legal);
694   }
695   template <typename OpT>
addLegalOp()696   void addLegalOp() {
697     addLegalOp(OperationName(OpT::getOperationName(), &ctx));
698   }
699   template <typename OpT, typename OpT2, typename... OpTs>
addLegalOp()700   void addLegalOp() {
701     addLegalOp<OpT>();
702     addLegalOp<OpT2, OpTs...>();
703   }
704 
705   /// Register the given operation as dynamically legal and set the dynamic
706   /// legalization callback to the one provided.
addDynamicallyLegalOp(OperationName op,const DynamicLegalityCallbackFn & callback)707   void addDynamicallyLegalOp(OperationName op,
708                              const DynamicLegalityCallbackFn &callback) {
709     setOpAction(op, LegalizationAction::Dynamic);
710     setLegalityCallback(op, callback);
711   }
712   template <typename OpT>
addDynamicallyLegalOp(const DynamicLegalityCallbackFn & callback)713   void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback) {
714     addDynamicallyLegalOp(OperationName(OpT::getOperationName(), &ctx),
715                           callback);
716   }
717   template <typename OpT, typename OpT2, typename... OpTs>
addDynamicallyLegalOp(const DynamicLegalityCallbackFn & callback)718   void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback) {
719     addDynamicallyLegalOp<OpT>(callback);
720     addDynamicallyLegalOp<OpT2, OpTs...>(callback);
721   }
722   template <typename OpT, class Callable>
723   typename std::enable_if<
724       !llvm::is_invocable<Callable, Operation *>::value>::type
addDynamicallyLegalOp(Callable && callback)725   addDynamicallyLegalOp(Callable &&callback) {
726     addDynamicallyLegalOp<OpT>(
727         [=](Operation *op) { return callback(cast<OpT>(op)); });
728   }
729 
730   /// Register the given operation as illegal, i.e. this operation is known to
731   /// not be supported by this target.
addIllegalOp(OperationName op)732   void addIllegalOp(OperationName op) {
733     setOpAction(op, LegalizationAction::Illegal);
734   }
735   template <typename OpT>
addIllegalOp()736   void addIllegalOp() {
737     addIllegalOp(OperationName(OpT::getOperationName(), &ctx));
738   }
739   template <typename OpT, typename OpT2, typename... OpTs>
addIllegalOp()740   void addIllegalOp() {
741     addIllegalOp<OpT>();
742     addIllegalOp<OpT2, OpTs...>();
743   }
744 
745   /// Mark an operation, that *must* have either been set as `Legal` or
746   /// `DynamicallyLegal`, as being recursively legal. This means that in
747   /// addition to the operation itself, all of the operations nested within are
748   /// also considered legal. An optional dynamic legality callback may be
749   /// provided to mark subsets of legal instances as recursively legal.
750   void markOpRecursivelyLegal(OperationName name,
751                               const DynamicLegalityCallbackFn &callback);
752   template <typename OpT>
753   void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback = {}) {
754     OperationName opName(OpT::getOperationName(), &ctx);
755     markOpRecursivelyLegal(opName, callback);
756   }
757   template <typename OpT, typename OpT2, typename... OpTs>
758   void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback = {}) {
759     markOpRecursivelyLegal<OpT>(callback);
760     markOpRecursivelyLegal<OpT2, OpTs...>(callback);
761   }
762   template <typename OpT, class Callable>
763   typename std::enable_if<
764       !llvm::is_invocable<Callable, Operation *>::value>::type
markOpRecursivelyLegal(Callable && callback)765   markOpRecursivelyLegal(Callable &&callback) {
766     markOpRecursivelyLegal<OpT>(
767         [=](Operation *op) { return callback(cast<OpT>(op)); });
768   }
769 
770   /// Register a legality action for the given dialects.
771   void setDialectAction(ArrayRef<StringRef> dialectNames,
772                         LegalizationAction action);
773 
774   /// Register the operations of the given dialects as legal.
775   template <typename... Names>
addLegalDialect(StringRef name,Names...names)776   void addLegalDialect(StringRef name, Names... names) {
777     SmallVector<StringRef, 2> dialectNames({name, names...});
778     setDialectAction(dialectNames, LegalizationAction::Legal);
779   }
780   template <typename... Args>
addLegalDialect()781   void addLegalDialect() {
782     SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
783     setDialectAction(dialectNames, LegalizationAction::Legal);
784   }
785 
786   /// Register the operations of the given dialects as dynamically legal, i.e.
787   /// requiring custom handling by the callback.
788   template <typename... Names>
addDynamicallyLegalDialect(const DynamicLegalityCallbackFn & callback,StringRef name,Names...names)789   void addDynamicallyLegalDialect(const DynamicLegalityCallbackFn &callback,
790                                   StringRef name, Names... names) {
791     SmallVector<StringRef, 2> dialectNames({name, names...});
792     setDialectAction(dialectNames, LegalizationAction::Dynamic);
793     setLegalityCallback(dialectNames, callback);
794   }
795   template <typename... Args>
addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback)796   void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback) {
797     addDynamicallyLegalDialect(std::move(callback),
798                                Args::getDialectNamespace()...);
799   }
800 
801   /// Register unknown operations as dynamically legal. For operations(and
802   /// dialects) that do not have a set legalization action, treat them as
803   /// dynamically legal and invoke the given callback.
markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn & fn)804   void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn) {
805     setLegalityCallback(fn);
806   }
807 
808   /// Register the operations of the given dialects as illegal, i.e.
809   /// operations of this dialect are not supported by the target.
810   template <typename... Names>
addIllegalDialect(StringRef name,Names...names)811   void addIllegalDialect(StringRef name, Names... names) {
812     SmallVector<StringRef, 2> dialectNames({name, names...});
813     setDialectAction(dialectNames, LegalizationAction::Illegal);
814   }
815   template <typename... Args>
addIllegalDialect()816   void addIllegalDialect() {
817     SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
818     setDialectAction(dialectNames, LegalizationAction::Illegal);
819   }
820 
821   //===--------------------------------------------------------------------===//
822   // Legality Querying
823   //===--------------------------------------------------------------------===//
824 
825   /// Get the legality action for the given operation.
826   Optional<LegalizationAction> getOpAction(OperationName op) const;
827 
828   /// If the given operation instance is legal on this target, a structure
829   /// containing legality information is returned. If the operation is not
830   /// legal, None is returned. Also returns None is operation legality wasn't
831   /// registered by user or dynamic legality callbacks returned None.
832   ///
833   /// Note: Legality is actually a 4-state: Legal(recursive=true),
834   /// Legal(recursive=false), Illegal or Unknown, where Unknown is treated
835   /// either as Legal or Illegal depending on context.
836   Optional<LegalOpDetails> isLegal(Operation *op) const;
837 
838   /// Returns true is operation instance is illegal on this target. Returns
839   /// false if operation is legal, operation legality wasn't registered by user
840   /// or dynamic legality callbacks returned None.
841   bool isIllegal(Operation *op) const;
842 
843 private:
844   /// Set the dynamic legality callback for the given operation.
845   void setLegalityCallback(OperationName name,
846                            const DynamicLegalityCallbackFn &callback);
847 
848   /// Set the dynamic legality callback for the given dialects.
849   void setLegalityCallback(ArrayRef<StringRef> dialects,
850                            const DynamicLegalityCallbackFn &callback);
851 
852   /// Set the dynamic legality callback for the unknown ops.
853   void setLegalityCallback(const DynamicLegalityCallbackFn &callback);
854 
855   /// The set of information that configures the legalization of an operation.
856   struct LegalizationInfo {
857     /// The legality action this operation was given.
858     LegalizationAction action = LegalizationAction::Illegal;
859 
860     /// If some legal instances of this operation may also be recursively legal.
861     bool isRecursivelyLegal = false;
862 
863     /// The legality callback if this operation is dynamically legal.
864     DynamicLegalityCallbackFn legalityFn;
865   };
866 
867   /// Get the legalization information for the given operation.
868   Optional<LegalizationInfo> getOpInfo(OperationName op) const;
869 
870   /// A deterministic mapping of operation name and its respective legality
871   /// information.
872   llvm::MapVector<OperationName, LegalizationInfo> legalOperations;
873 
874   /// A set of legality callbacks for given operation names that are used to
875   /// check if an operation instance is recursively legal.
876   DenseMap<OperationName, DynamicLegalityCallbackFn> opRecursiveLegalityFns;
877 
878   /// A deterministic mapping of dialect name to the specific legality action to
879   /// take.
880   llvm::StringMap<LegalizationAction> legalDialects;
881 
882   /// A set of dynamic legality callbacks for given dialect names.
883   llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns;
884 
885   /// An optional legality callback for unknown operations.
886   DynamicLegalityCallbackFn unknownLegalityFn;
887 
888   /// The current context this target applies to.
889   MLIRContext &ctx;
890 };
891 
892 //===----------------------------------------------------------------------===//
893 // Op Conversion Entry Points
894 //===----------------------------------------------------------------------===//
895 
896 /// Below we define several entry points for operation conversion. It is
897 /// important to note that the patterns provided to the conversion framework may
898 /// have additional constraints. See the `PatternRewriter Hooks` section of the
899 /// ConversionPatternRewriter, to see what additional constraints are imposed on
900 /// the use of the PatternRewriter.
901 
/// Apply a partial conversion on the given operations and all nested
/// operations. This method converts as many operations to the target as
/// possible, ignoring operations that failed to legalize. This method only
/// returns failure if there are ops explicitly marked as illegal. If an
/// `unconvertedOps` set is provided, all operations that are found not to be
/// legalizable to the given `target` are placed within that set. (Note that if
/// there is an op explicitly marked as illegal, the conversion terminates and
/// the `unconvertedOps` set will not necessarily be complete.)
LogicalResult
applyPartialConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
                       const FrozenRewritePatternSet &patterns,
                       DenseSet<Operation *> *unconvertedOps = nullptr);
/// Overload of the above for a single root operation `op`.
LogicalResult
applyPartialConversion(Operation *op, ConversionTarget &target,
                       const FrozenRewritePatternSet &patterns,
                       DenseSet<Operation *> *unconvertedOps = nullptr);
918 
/// Apply a complete conversion on the given operations, and all nested
/// operations. This method returns failure if the conversion of any operation
/// fails, or if there are unreachable blocks in any of the regions nested
/// within 'ops'.
LogicalResult applyFullConversion(ArrayRef<Operation *> ops,
                                  ConversionTarget &target,
                                  const FrozenRewritePatternSet &patterns);
/// Overload of the above for a single root operation `op`.
LogicalResult applyFullConversion(Operation *op, ConversionTarget &target,
                                  const FrozenRewritePatternSet &patterns);
928 
/// Apply an analysis conversion on the given operations, and all nested
/// operations. This method analyzes which operations would be successfully
/// converted to the target if a conversion was applied. All operations that
/// were found to be legalizable to the given 'target' are placed within the
/// provided 'convertedOps' set; note that no actual rewrites are applied to the
/// operations on success and only pre-existing operations are added to the set.
/// This method only returns failure if there are unreachable blocks in any of
/// the regions nested within 'ops'. There's an additional argument
/// `notifyCallback` which is used for collecting match failure diagnostics
/// generated during the conversion. Diagnostics reported to this callback may
/// only be available in debug mode.
LogicalResult applyAnalysisConversion(
    ArrayRef<Operation *> ops, ConversionTarget &target,
    const FrozenRewritePatternSet &patterns,
    DenseSet<Operation *> &convertedOps,
    function_ref<void(Diagnostic &)> notifyCallback = nullptr);
/// Overload of the above for a single root operation `op`.
LogicalResult applyAnalysisConversion(
    Operation *op, ConversionTarget &target,
    const FrozenRewritePatternSet &patterns,
    DenseSet<Operation *> &convertedOps,
    function_ref<void(Diagnostic &)> notifyCallback = nullptr);
950 } // namespace mlir
951 
952 #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
953