1 //===- DialectConversion.h - MLIR dialect conversion pass -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares a generic pass for converting between MLIR dialects.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_
14 #define MLIR_TRANSFORMS_DIALECTCONVERSION_H_
15
16 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/StringMap.h"
19
20 namespace mlir {
21
22 // Forward declarations.
23 class Block;
24 class ConversionPatternRewriter;
25 class MLIRContext;
26 class Operation;
27 class Type;
28 class Value;
29
30 //===----------------------------------------------------------------------===//
31 // Type Conversion
32 //===----------------------------------------------------------------------===//
33
34 /// Type conversion class. Specific conversions and materializations can be
35 /// registered using addConversion and addMaterialization, respectively.
36 class TypeConverter {
37 public:
38 /// This class provides all of the information necessary to convert a type
39 /// signature.
40 class SignatureConversion {
41 public:
SignatureConversion(unsigned numOrigInputs)42 SignatureConversion(unsigned numOrigInputs)
43 : remappedInputs(numOrigInputs) {}
44
45 /// This struct represents a range of new types or a single value that
46 /// remaps an existing signature input.
47 struct InputMapping {
48 size_t inputNo, size;
49 Value replacementValue;
50 };
51
52 /// Return the argument types for the new signature.
getConvertedTypes()53 ArrayRef<Type> getConvertedTypes() const { return argTypes; }
54
55 /// Get the input mapping for the given argument.
getInputMapping(unsigned input)56 Optional<InputMapping> getInputMapping(unsigned input) const {
57 return remappedInputs[input];
58 }
59
60 //===------------------------------------------------------------------===//
61 // Conversion Hooks
62 //===------------------------------------------------------------------===//
63
64 /// Remap an input of the original signature with a new set of types. The
65 /// new types are appended to the new signature conversion.
66 void addInputs(unsigned origInputNo, ArrayRef<Type> types);
67
68 /// Append new input types to the signature conversion, this should only be
69 /// used if the new types are not intended to remap an existing input.
70 void addInputs(ArrayRef<Type> types);
71
72 /// Remap an input of the original signature to another `replacement`
73 /// value. This drops the original argument.
74 void remapInput(unsigned origInputNo, Value replacement);
75
76 private:
77 /// Remap an input of the original signature with a range of types in the
78 /// new signature.
79 void remapInput(unsigned origInputNo, unsigned newInputNo,
80 unsigned newInputCount = 1);
81
82 /// The remapping information for each of the original arguments.
83 SmallVector<Optional<InputMapping>, 4> remappedInputs;
84
85 /// The set of new argument types.
86 SmallVector<Type, 4> argTypes;
87 };
88
89 /// Register a conversion function. A conversion function must be convertible
90 /// to any of the following forms(where `T` is a class derived from `Type`:
91 /// * Optional<Type>(T)
92 /// - This form represents a 1-1 type conversion. It should return nullptr
93 /// or `llvm::None` to signify failure. If `llvm::None` is returned, the
94 /// converter is allowed to try another conversion function to perform
95 /// the conversion.
96 /// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &)
97 /// - This form represents a 1-N type conversion. It should return
98 /// `failure` or `llvm::None` to signify a failed conversion. If the new
99 /// set of types is empty, the type is removed and any usages of the
100 /// existing value are expected to be removed during conversion. If
101 /// `llvm::None` is returned, the converter is allowed to try another
102 /// conversion function to perform the conversion.
103 /// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &, ArrayRef<Type>)
104 /// - This form represents a 1-N type conversion supporting recursive
105 /// types. The first two arguments and the return value are the same as
106 /// for the regular 1-N form. The third argument is contains is the
107 /// "call stack" of the recursive conversion: it contains the list of
108 /// types currently being converted, with the current type being the
109 /// last one. If it is present more than once in the list, the
110 /// conversion concerns a recursive type.
111 /// Note: When attempting to convert a type, e.g. via 'convertType', the
112 /// mostly recently added conversions will be invoked first.
113 template <typename FnT, typename T = typename llvm::function_traits<
114 std::decay_t<FnT>>::template arg_t<0>>
addConversion(FnT && callback)115 void addConversion(FnT &&callback) {
116 registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
117 }
118
119 /// Register a materialization function, which must be convertible to the
120 /// following form:
121 /// `Optional<Value>(OpBuilder &, T, ValueRange, Location)`,
122 /// where `T` is any subclass of `Type`. This function is responsible for
123 /// creating an operation, using the OpBuilder and Location provided, that
124 /// "casts" a range of values into a single value of the given type `T`. It
125 /// must return a Value of the converted type on success, an `llvm::None` if
126 /// it failed but other materialization can be attempted, and `nullptr` on
127 /// unrecoverable failure. It will only be called for (sub)types of `T`.
128 /// Materialization functions must be provided when a type conversion may
129 /// persist after the conversion has finished.
130 ///
131 /// This method registers a materialization that will be called when
132 /// converting an illegal block argument type, to a legal type.
133 template <typename FnT, typename T = typename llvm::function_traits<
134 std::decay_t<FnT>>::template arg_t<1>>
addArgumentMaterialization(FnT && callback)135 void addArgumentMaterialization(FnT &&callback) {
136 argumentMaterializations.emplace_back(
137 wrapMaterialization<T>(std::forward<FnT>(callback)));
138 }
139 /// This method registers a materialization that will be called when
140 /// converting a legal type to an illegal source type. This is used when
141 /// conversions to an illegal type must persist beyond the main conversion.
142 template <typename FnT, typename T = typename llvm::function_traits<
143 std::decay_t<FnT>>::template arg_t<1>>
addSourceMaterialization(FnT && callback)144 void addSourceMaterialization(FnT &&callback) {
145 sourceMaterializations.emplace_back(
146 wrapMaterialization<T>(std::forward<FnT>(callback)));
147 }
148 /// This method registers a materialization that will be called when
149 /// converting type from an illegal, or source, type to a legal type.
150 template <typename FnT, typename T = typename llvm::function_traits<
151 std::decay_t<FnT>>::template arg_t<1>>
addTargetMaterialization(FnT && callback)152 void addTargetMaterialization(FnT &&callback) {
153 targetMaterializations.emplace_back(
154 wrapMaterialization<T>(std::forward<FnT>(callback)));
155 }
156
157 /// Convert the given type. This function should return failure if no valid
158 /// conversion exists, success otherwise. If the new set of types is empty,
159 /// the type is removed and any usages of the existing value are expected to
160 /// be removed during conversion.
161 LogicalResult convertType(Type t, SmallVectorImpl<Type> &results);
162
163 /// This hook simplifies defining 1-1 type conversions. This function returns
164 /// the type to convert to on success, and a null type on failure.
165 Type convertType(Type t);
166
167 /// Convert the given set of types, filling 'results' as necessary. This
168 /// returns failure if the conversion of any of the types fails, success
169 /// otherwise.
170 LogicalResult convertTypes(TypeRange types, SmallVectorImpl<Type> &results);
171
172 /// Return true if the given type is legal for this type converter, i.e. the
173 /// type converts to itself.
174 bool isLegal(Type type);
175 /// Return true if all of the given types are legal for this type converter.
176 template <typename RangeT>
177 std::enable_if_t<!std::is_convertible<RangeT, Type>::value &&
178 !std::is_convertible<RangeT, Operation *>::value,
179 bool>
isLegal(RangeT && range)180 isLegal(RangeT &&range) {
181 return llvm::all_of(range, [this](Type type) { return isLegal(type); });
182 }
183 /// Return true if the given operation has legal operand and result types.
184 bool isLegal(Operation *op);
185
186 /// Return true if the types of block arguments within the region are legal.
187 bool isLegal(Region *region);
188
189 /// Return true if the inputs and outputs of the given function type are
190 /// legal.
191 bool isSignatureLegal(FunctionType ty);
192
193 /// This method allows for converting a specific argument of a signature. It
194 /// takes as inputs the original argument input number, type.
195 /// On success, it populates 'result' with any new mappings.
196 LogicalResult convertSignatureArg(unsigned inputNo, Type type,
197 SignatureConversion &result);
198 LogicalResult convertSignatureArgs(TypeRange types,
199 SignatureConversion &result,
200 unsigned origInputOffset = 0);
201
202 /// This function converts the type signature of the given block, by invoking
203 /// 'convertSignatureArg' for each argument. This function should return a
204 /// valid conversion for the signature on success, None otherwise.
205 Optional<SignatureConversion> convertBlockSignature(Block *block);
206
207 /// Materialize a conversion from a set of types into one result type by
208 /// generating a cast sequence of some kind. See the respective
209 /// `add*Materialization` for more information on the context for these
210 /// methods.
materializeArgumentConversion(OpBuilder & builder,Location loc,Type resultType,ValueRange inputs)211 Value materializeArgumentConversion(OpBuilder &builder, Location loc,
212 Type resultType, ValueRange inputs) {
213 return materializeConversion(argumentMaterializations, builder, loc,
214 resultType, inputs);
215 }
materializeSourceConversion(OpBuilder & builder,Location loc,Type resultType,ValueRange inputs)216 Value materializeSourceConversion(OpBuilder &builder, Location loc,
217 Type resultType, ValueRange inputs) {
218 return materializeConversion(sourceMaterializations, builder, loc,
219 resultType, inputs);
220 }
materializeTargetConversion(OpBuilder & builder,Location loc,Type resultType,ValueRange inputs)221 Value materializeTargetConversion(OpBuilder &builder, Location loc,
222 Type resultType, ValueRange inputs) {
223 return materializeConversion(targetMaterializations, builder, loc,
224 resultType, inputs);
225 }
226
227 private:
228 /// The signature of the callback used to convert a type. If the new set of
229 /// types is empty, the type is removed and any usages of the existing value
230 /// are expected to be removed during conversion.
231 using ConversionCallbackFn = std::function<Optional<LogicalResult>(
232 Type, SmallVectorImpl<Type> &, ArrayRef<Type>)>;
233
234 /// The signature of the callback used to materialize a conversion.
235 using MaterializationCallbackFn =
236 std::function<Optional<Value>(OpBuilder &, Type, ValueRange, Location)>;
237
238 /// Attempt to materialize a conversion using one of the provided
239 /// materialization functions.
240 Value materializeConversion(
241 MutableArrayRef<MaterializationCallbackFn> materializations,
242 OpBuilder &builder, Location loc, Type resultType, ValueRange inputs);
243
244 /// Generate a wrapper for the given callback. This allows for accepting
245 /// different callback forms, that all compose into a single version.
246 /// With callback of form: `Optional<Type>(T)`
247 template <typename T, typename FnT>
248 std::enable_if_t<llvm::is_invocable<FnT, T>::value, ConversionCallbackFn>
wrapCallback(FnT && callback)249 wrapCallback(FnT &&callback) {
250 return wrapCallback<T>(
251 [callback = std::forward<FnT>(callback)](
252 T type, SmallVectorImpl<Type> &results, ArrayRef<Type>) {
253 if (Optional<Type> resultOpt = callback(type)) {
254 bool wasSuccess = static_cast<bool>(resultOpt.value());
255 if (wasSuccess)
256 results.push_back(resultOpt.value());
257 return Optional<LogicalResult>(success(wasSuccess));
258 }
259 return Optional<LogicalResult>();
260 });
261 }
262 /// With callback of form: `Optional<LogicalResult>(T, SmallVectorImpl<Type>
263 /// &)`
264 template <typename T, typename FnT>
265 std::enable_if_t<llvm::is_invocable<FnT, T, SmallVectorImpl<Type> &>::value,
266 ConversionCallbackFn>
wrapCallback(FnT && callback)267 wrapCallback(FnT &&callback) {
268 return wrapCallback<T>(
269 [callback = std::forward<FnT>(callback)](
270 T type, SmallVectorImpl<Type> &results, ArrayRef<Type>) {
271 return callback(type, results);
272 });
273 }
274 /// With callback of form: `Optional<LogicalResult>(T, SmallVectorImpl<Type>
275 /// &, ArrayRef<Type>)`.
276 template <typename T, typename FnT>
277 std::enable_if_t<llvm::is_invocable<FnT, T, SmallVectorImpl<Type> &,
278 ArrayRef<Type>>::value,
279 ConversionCallbackFn>
wrapCallback(FnT && callback)280 wrapCallback(FnT &&callback) {
281 return [callback = std::forward<FnT>(callback)](
282 Type type, SmallVectorImpl<Type> &results,
283 ArrayRef<Type> callStack) -> Optional<LogicalResult> {
284 T derivedType = type.dyn_cast<T>();
285 if (!derivedType)
286 return llvm::None;
287 return callback(derivedType, results, callStack);
288 };
289 }
290
291 /// Register a type conversion.
registerConversion(ConversionCallbackFn callback)292 void registerConversion(ConversionCallbackFn callback) {
293 conversions.emplace_back(std::move(callback));
294 cachedDirectConversions.clear();
295 cachedMultiConversions.clear();
296 }
297
298 /// Generate a wrapper for the given materialization callback. The callback
299 /// may take any subclass of `Type` and the wrapper will check for the target
300 /// type to be of the expected class before calling the callback.
301 template <typename T, typename FnT>
wrapMaterialization(FnT && callback)302 MaterializationCallbackFn wrapMaterialization(FnT &&callback) {
303 return [callback = std::forward<FnT>(callback)](
304 OpBuilder &builder, Type resultType, ValueRange inputs,
305 Location loc) -> Optional<Value> {
306 if (T derivedType = resultType.dyn_cast<T>())
307 return callback(builder, derivedType, inputs, loc);
308 return llvm::None;
309 };
310 }
311
312 /// The set of registered conversion functions.
313 SmallVector<ConversionCallbackFn, 4> conversions;
314
315 /// The list of registered materialization functions.
316 SmallVector<MaterializationCallbackFn, 2> argumentMaterializations;
317 SmallVector<MaterializationCallbackFn, 2> sourceMaterializations;
318 SmallVector<MaterializationCallbackFn, 2> targetMaterializations;
319
320 /// A set of cached conversions to avoid recomputing in the common case.
321 /// Direct 1-1 conversions are the most common, so this cache stores the
322 /// successful 1-1 conversions as well as all failed conversions.
323 DenseMap<Type, Type> cachedDirectConversions;
324 /// This cache stores the successful 1->N conversions, where N != 1.
325 DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
326
327 /// Stores the types that are being converted in the case when convertType
328 /// is being called recursively to convert nested types.
329 SmallVector<Type, 2> conversionCallStack;
330 };
331
332 //===----------------------------------------------------------------------===//
333 // Conversion Patterns
334 //===----------------------------------------------------------------------===//
335
336 /// Base class for the conversion patterns. This pattern class enables type
337 /// conversions, and other uses specific to the conversion framework. As such,
338 /// patterns of this type can only be used with the 'apply*' methods below.
339 class ConversionPattern : public RewritePattern {
340 public:
341 /// Hook for derived classes to implement rewriting. `op` is the (first)
342 /// operation matched by the pattern, `operands` is a list of the rewritten
343 /// operand values that are passed to `op`, `rewriter` can be used to emit the
344 /// new operations. This function should not fail. If some specific cases of
345 /// the operation are not supported, these cases should not be matched.
rewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)346 virtual void rewrite(Operation *op, ArrayRef<Value> operands,
347 ConversionPatternRewriter &rewriter) const {
348 llvm_unreachable("unimplemented rewrite");
349 }
350
351 /// Hook for derived classes to implement combined matching and rewriting.
352 virtual LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)353 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
354 ConversionPatternRewriter &rewriter) const {
355 if (failed(match(op)))
356 return failure();
357 rewrite(op, operands, rewriter);
358 return success();
359 }
360
361 /// Attempt to match and rewrite the IR root at the specified operation.
362 LogicalResult matchAndRewrite(Operation *op,
363 PatternRewriter &rewriter) const final;
364
365 /// Return the type converter held by this pattern, or nullptr if the pattern
366 /// does not require type conversion.
getTypeConverter()367 TypeConverter *getTypeConverter() const { return typeConverter; }
368
369 template <typename ConverterTy>
370 std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
371 ConverterTy *>
getTypeConverter()372 getTypeConverter() const {
373 return static_cast<ConverterTy *>(typeConverter);
374 }
375
376 protected:
377 /// See `RewritePattern::RewritePattern` for information on the other
378 /// available constructors.
379 using RewritePattern::RewritePattern;
380 /// Construct a conversion pattern with the given converter, and forward the
381 /// remaining arguments to RewritePattern.
382 template <typename... Args>
ConversionPattern(TypeConverter & typeConverter,Args &&...args)383 ConversionPattern(TypeConverter &typeConverter, Args &&...args)
384 : RewritePattern(std::forward<Args>(args)...),
385 typeConverter(&typeConverter) {}
386
387 protected:
388 /// An optional type converter for use by this pattern.
389 TypeConverter *typeConverter = nullptr;
390
391 private:
392 using RewritePattern::rewrite;
393 };
394
395 /// OpConversionPattern is a wrapper around ConversionPattern that allows for
396 /// matching and rewriting against an instance of a derived operation class as
397 /// opposed to a raw Operation.
398 template <typename SourceOp>
399 class OpConversionPattern : public ConversionPattern {
400 public:
401 using OpAdaptor = typename SourceOp::Adaptor;
402
403 OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
ConversionPattern(SourceOp::getOperationName (),benefit,context)404 : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
405 OpConversionPattern(TypeConverter &typeConverter, MLIRContext *context,
406 PatternBenefit benefit = 1)
ConversionPattern(typeConverter,SourceOp::getOperationName (),benefit,context)407 : ConversionPattern(typeConverter, SourceOp::getOperationName(), benefit,
408 context) {}
409
410 /// Wrappers around the ConversionPattern methods that pass the derived op
411 /// type.
match(Operation * op)412 LogicalResult match(Operation *op) const final {
413 return match(cast<SourceOp>(op));
414 }
rewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)415 void rewrite(Operation *op, ArrayRef<Value> operands,
416 ConversionPatternRewriter &rewriter) const final {
417 rewrite(cast<SourceOp>(op), OpAdaptor(operands, op->getAttrDictionary()),
418 rewriter);
419 }
420 LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)421 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
422 ConversionPatternRewriter &rewriter) const final {
423 return matchAndRewrite(cast<SourceOp>(op),
424 OpAdaptor(operands, op->getAttrDictionary()),
425 rewriter);
426 }
427
428 /// Rewrite and Match methods that operate on the SourceOp type. These must be
429 /// overridden by the derived pattern class.
match(SourceOp op)430 virtual LogicalResult match(SourceOp op) const {
431 llvm_unreachable("must override match or matchAndRewrite");
432 }
rewrite(SourceOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter)433 virtual void rewrite(SourceOp op, OpAdaptor adaptor,
434 ConversionPatternRewriter &rewriter) const {
435 llvm_unreachable("must override matchAndRewrite or a rewrite method");
436 }
437 virtual LogicalResult
matchAndRewrite(SourceOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter)438 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
439 ConversionPatternRewriter &rewriter) const {
440 if (failed(match(op)))
441 return failure();
442 rewrite(op, adaptor, rewriter);
443 return success();
444 }
445
446 private:
447 using ConversionPattern::matchAndRewrite;
448 };
449
450 /// OpInterfaceConversionPattern is a wrapper around ConversionPattern that
451 /// allows for matching and rewriting against an instance of an OpInterface
452 /// class as opposed to a raw Operation.
453 template <typename SourceOp>
454 class OpInterfaceConversionPattern : public ConversionPattern {
455 public:
456 OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
ConversionPattern(Pattern::MatchInterfaceOpTypeTag (),SourceOp::getInterfaceID (),benefit,context)457 : ConversionPattern(Pattern::MatchInterfaceOpTypeTag(),
458 SourceOp::getInterfaceID(), benefit, context) {}
459 OpInterfaceConversionPattern(TypeConverter &typeConverter,
460 MLIRContext *context, PatternBenefit benefit = 1)
ConversionPattern(typeConverter,Pattern::MatchInterfaceOpTypeTag (),SourceOp::getInterfaceID (),benefit,context)461 : ConversionPattern(typeConverter, Pattern::MatchInterfaceOpTypeTag(),
462 SourceOp::getInterfaceID(), benefit, context) {}
463
464 /// Wrappers around the ConversionPattern methods that pass the derived op
465 /// type.
rewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)466 void rewrite(Operation *op, ArrayRef<Value> operands,
467 ConversionPatternRewriter &rewriter) const final {
468 rewrite(cast<SourceOp>(op), operands, rewriter);
469 }
470 LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)471 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
472 ConversionPatternRewriter &rewriter) const final {
473 return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
474 }
475
476 /// Rewrite and Match methods that operate on the SourceOp type. These must be
477 /// overridden by the derived pattern class.
rewrite(SourceOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)478 virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
479 ConversionPatternRewriter &rewriter) const {
480 llvm_unreachable("must override matchAndRewrite or a rewrite method");
481 }
482 virtual LogicalResult
matchAndRewrite(SourceOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)483 matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
484 ConversionPatternRewriter &rewriter) const {
485 if (failed(match(op)))
486 return failure();
487 rewrite(op, operands, rewriter);
488 return success();
489 }
490
491 private:
492 using ConversionPattern::matchAndRewrite;
493 };
494
495 /// Add a pattern to the given pattern list to convert the signature of a
496 /// FunctionOpInterface op with the given type converter. This only supports
497 /// ops which use FunctionType to represent their type.
498 void populateFunctionOpInterfaceTypeConversionPattern(
499 StringRef functionLikeOpName, RewritePatternSet &patterns,
500 TypeConverter &converter);
501
502 template <typename FuncOpT>
populateFunctionOpInterfaceTypeConversionPattern(RewritePatternSet & patterns,TypeConverter & converter)503 void populateFunctionOpInterfaceTypeConversionPattern(
504 RewritePatternSet &patterns, TypeConverter &converter) {
505 populateFunctionOpInterfaceTypeConversionPattern(FuncOpT::getOperationName(),
506 patterns, converter);
507 }
508
509 //===----------------------------------------------------------------------===//
510 // Conversion PatternRewriter
511 //===----------------------------------------------------------------------===//
512
namespace detail {
// Only forward-declared here: ConversionPatternRewriter owns its state through
// a std::unique_ptr to this type (pimpl), so the definition is not needed in
// this header.
struct ConversionPatternRewriterImpl;
} // namespace detail
516
517 /// This class implements a pattern rewriter for use with ConversionPatterns. It
518 /// extends the base PatternRewriter and provides special conversion specific
519 /// hooks.
520 class ConversionPatternRewriter final : public PatternRewriter {
521 public:
522 explicit ConversionPatternRewriter(MLIRContext *ctx);
523 ~ConversionPatternRewriter() override;
524
525 /// Apply a signature conversion to the entry block of the given region. This
526 /// replaces the entry block with a new block containing the updated
527 /// signature. The new entry block to the region is returned for convenience.
528 ///
529 /// If provided, `converter` will be used for any materializations.
530 Block *
531 applySignatureConversion(Region *region,
532 TypeConverter::SignatureConversion &conversion,
533 TypeConverter *converter = nullptr);
534
535 /// Convert the types of block arguments within the given region. This
536 /// replaces each block with a new block containing the updated signature. The
537 /// entry block may have a special conversion if `entryConversion` is
538 /// provided. On success, the new entry block to the region is returned for
539 /// convenience. Otherwise, failure is returned.
540 FailureOr<Block *> convertRegionTypes(
541 Region *region, TypeConverter &converter,
542 TypeConverter::SignatureConversion *entryConversion = nullptr);
543
544 /// Convert the types of block arguments within the given region except for
545 /// the entry region. This replaces each non-entry block with a new block
546 /// containing the updated signature.
547 ///
548 /// If special conversion behavior is needed for the non-entry blocks (for
549 /// example, we need to convert only a subset of a BB arguments), such
550 /// behavior can be specified in blockConversions.
551 LogicalResult convertNonEntryRegionTypes(
552 Region *region, TypeConverter &converter,
553 ArrayRef<TypeConverter::SignatureConversion> blockConversions);
554
555 /// Replace all the uses of the block argument `from` with value `to`.
556 void replaceUsesOfBlockArgument(BlockArgument from, Value to);
557
558 /// Return the converted value of 'key' with a type defined by the type
559 /// converter of the currently executing pattern. Return nullptr in the case
560 /// of failure, the remapped value otherwise.
561 Value getRemappedValue(Value key);
562
563 /// Return the converted values that replace 'keys' with types defined by the
564 /// type converter of the currently executing pattern. Returns failure if the
565 /// remap failed, success otherwise.
566 LogicalResult getRemappedValues(ValueRange keys,
567 SmallVectorImpl<Value> &results);
568
569 //===--------------------------------------------------------------------===//
570 // PatternRewriter Hooks
571 //===--------------------------------------------------------------------===//
572
573 /// PatternRewriter hook for replacing the results of an operation when the
574 /// given functor returns true.
575 void replaceOpWithIf(
576 Operation *op, ValueRange newValues, bool *allUsesReplaced,
577 llvm::unique_function<bool(OpOperand &) const> functor) override;
578
579 /// PatternRewriter hook for replacing the results of an operation.
580 void replaceOp(Operation *op, ValueRange newValues) override;
581 using PatternRewriter::replaceOp;
582
583 /// PatternRewriter hook for erasing a dead operation. The uses of this
584 /// operation *must* be made dead by the end of the conversion process,
585 /// otherwise an assert will be issued.
586 void eraseOp(Operation *op) override;
587
588 /// PatternRewriter hook for erase all operations in a block. This is not yet
589 /// implemented for dialect conversion.
590 void eraseBlock(Block *block) override;
591
592 /// PatternRewriter hook creating a new block.
593 void notifyBlockCreated(Block *block) override;
594
595 /// PatternRewriter hook for splitting a block into two parts.
596 Block *splitBlock(Block *block, Block::iterator before) override;
597
598 /// PatternRewriter hook for merging a block into another.
599 void mergeBlocks(Block *source, Block *dest, ValueRange argValues) override;
600
601 /// PatternRewriter hook for moving blocks out of a region.
602 void inlineRegionBefore(Region ®ion, Region &parent,
603 Region::iterator before) override;
604 using PatternRewriter::inlineRegionBefore;
605
606 /// PatternRewriter hook for cloning blocks of one region into another. The
607 /// given region to clone *must* not have been modified as part of conversion
608 /// yet, i.e. it must be within an operation that is either in the process of
609 /// conversion, or has not yet been converted.
610 void cloneRegionBefore(Region ®ion, Region &parent,
611 Region::iterator before,
612 BlockAndValueMapping &mapping) override;
613 using PatternRewriter::cloneRegionBefore;
614
615 /// PatternRewriter hook for inserting a new operation.
616 void notifyOperationInserted(Operation *op) override;
617
618 /// PatternRewriter hook for updating the root operation in-place.
619 /// Note: These methods only track updates to the top-level operation itself,
620 /// and not nested regions. Updates to regions will still require notification
621 /// through other more specific hooks above.
622 void startRootUpdate(Operation *op) override;
623
624 /// PatternRewriter hook for updating the root operation in-place.
625 void finalizeRootUpdate(Operation *op) override;
626
627 /// PatternRewriter hook for updating the root operation in-place.
628 void cancelRootUpdate(Operation *op) override;
629
630 /// PatternRewriter hook for notifying match failure reasons.
631 LogicalResult
632 notifyMatchFailure(Location loc,
633 function_ref<void(Diagnostic &)> reasonCallback) override;
634 using PatternRewriter::notifyMatchFailure;
635
636 /// Return a reference to the internal implementation.
637 detail::ConversionPatternRewriterImpl &getImpl();
638
639 private:
640 std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
641 };
642
643 //===----------------------------------------------------------------------===//
644 // ConversionTarget
645 //===----------------------------------------------------------------------===//
646
647 /// This class describes a specific conversion target.
648 class ConversionTarget {
649 public:
650 /// This enumeration corresponds to the specific action to take when
651 /// considering an operation legal for this conversion target.
652 enum class LegalizationAction {
653 /// The target supports this operation.
654 Legal,
655
656 /// This operation has dynamic legalization constraints that must be checked
657 /// by the target.
658 Dynamic,
659
660 /// The target explicitly does not support this operation.
661 Illegal,
662 };
663
664 /// A structure containing additional information describing a specific legal
665 /// operation instance.
666 struct LegalOpDetails {
667 /// A flag that indicates if this operation is 'recursively' legal. This
668 /// means that if an operation is legal, either statically or dynamically,
669 /// all of the operations nested within are also considered legal.
670 bool isRecursivelyLegal = false;
671 };
672
673 /// The signature of the callback used to determine if an operation is
674 /// dynamically legal on the target.
675 using DynamicLegalityCallbackFn = std::function<Optional<bool>(Operation *)>;
676
ConversionTarget(MLIRContext & ctx)677 ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
678 virtual ~ConversionTarget() = default;
679
680 //===--------------------------------------------------------------------===//
681 // Legality Registration
682 //===--------------------------------------------------------------------===//
683
684 /// Register a legality action for the given operation.
685 void setOpAction(OperationName op, LegalizationAction action);
686 template <typename OpT>
setOpAction(LegalizationAction action)687 void setOpAction(LegalizationAction action) {
688 setOpAction(OperationName(OpT::getOperationName(), &ctx), action);
689 }
690
691 /// Register the given operations as legal.
addLegalOp(OperationName op)692 void addLegalOp(OperationName op) {
693 setOpAction(op, LegalizationAction::Legal);
694 }
695 template <typename OpT>
addLegalOp()696 void addLegalOp() {
697 addLegalOp(OperationName(OpT::getOperationName(), &ctx));
698 }
699 template <typename OpT, typename OpT2, typename... OpTs>
addLegalOp()700 void addLegalOp() {
701 addLegalOp<OpT>();
702 addLegalOp<OpT2, OpTs...>();
703 }
704
705 /// Register the given operation as dynamically legal and set the dynamic
706 /// legalization callback to the one provided.
addDynamicallyLegalOp(OperationName op,const DynamicLegalityCallbackFn & callback)707 void addDynamicallyLegalOp(OperationName op,
708 const DynamicLegalityCallbackFn &callback) {
709 setOpAction(op, LegalizationAction::Dynamic);
710 setLegalityCallback(op, callback);
711 }
712 template <typename OpT>
addDynamicallyLegalOp(const DynamicLegalityCallbackFn & callback)713 void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback) {
714 addDynamicallyLegalOp(OperationName(OpT::getOperationName(), &ctx),
715 callback);
716 }
717 template <typename OpT, typename OpT2, typename... OpTs>
addDynamicallyLegalOp(const DynamicLegalityCallbackFn & callback)718 void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback) {
719 addDynamicallyLegalOp<OpT>(callback);
720 addDynamicallyLegalOp<OpT2, OpTs...>(callback);
721 }
722 template <typename OpT, class Callable>
723 typename std::enable_if<
724 !llvm::is_invocable<Callable, Operation *>::value>::type
addDynamicallyLegalOp(Callable && callback)725 addDynamicallyLegalOp(Callable &&callback) {
726 addDynamicallyLegalOp<OpT>(
727 [=](Operation *op) { return callback(cast<OpT>(op)); });
728 }
729
730 /// Register the given operation as illegal, i.e. this operation is known to
731 /// not be supported by this target.
addIllegalOp(OperationName op)732 void addIllegalOp(OperationName op) {
733 setOpAction(op, LegalizationAction::Illegal);
734 }
735 template <typename OpT>
addIllegalOp()736 void addIllegalOp() {
737 addIllegalOp(OperationName(OpT::getOperationName(), &ctx));
738 }
739 template <typename OpT, typename OpT2, typename... OpTs>
addIllegalOp()740 void addIllegalOp() {
741 addIllegalOp<OpT>();
742 addIllegalOp<OpT2, OpTs...>();
743 }
744
745 /// Mark an operation, that *must* have either been set as `Legal` or
746 /// `DynamicallyLegal`, as being recursively legal. This means that in
747 /// addition to the operation itself, all of the operations nested within are
748 /// also considered legal. An optional dynamic legality callback may be
749 /// provided to mark subsets of legal instances as recursively legal.
750 void markOpRecursivelyLegal(OperationName name,
751 const DynamicLegalityCallbackFn &callback);
752 template <typename OpT>
753 void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback = {}) {
754 OperationName opName(OpT::getOperationName(), &ctx);
755 markOpRecursivelyLegal(opName, callback);
756 }
757 template <typename OpT, typename OpT2, typename... OpTs>
758 void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback = {}) {
759 markOpRecursivelyLegal<OpT>(callback);
760 markOpRecursivelyLegal<OpT2, OpTs...>(callback);
761 }
762 template <typename OpT, class Callable>
763 typename std::enable_if<
764 !llvm::is_invocable<Callable, Operation *>::value>::type
markOpRecursivelyLegal(Callable && callback)765 markOpRecursivelyLegal(Callable &&callback) {
766 markOpRecursivelyLegal<OpT>(
767 [=](Operation *op) { return callback(cast<OpT>(op)); });
768 }
769
770 /// Register a legality action for the given dialects.
771 void setDialectAction(ArrayRef<StringRef> dialectNames,
772 LegalizationAction action);
773
774 /// Register the operations of the given dialects as legal.
775 template <typename... Names>
addLegalDialect(StringRef name,Names...names)776 void addLegalDialect(StringRef name, Names... names) {
777 SmallVector<StringRef, 2> dialectNames({name, names...});
778 setDialectAction(dialectNames, LegalizationAction::Legal);
779 }
780 template <typename... Args>
addLegalDialect()781 void addLegalDialect() {
782 SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
783 setDialectAction(dialectNames, LegalizationAction::Legal);
784 }
785
786 /// Register the operations of the given dialects as dynamically legal, i.e.
787 /// requiring custom handling by the callback.
788 template <typename... Names>
addDynamicallyLegalDialect(const DynamicLegalityCallbackFn & callback,StringRef name,Names...names)789 void addDynamicallyLegalDialect(const DynamicLegalityCallbackFn &callback,
790 StringRef name, Names... names) {
791 SmallVector<StringRef, 2> dialectNames({name, names...});
792 setDialectAction(dialectNames, LegalizationAction::Dynamic);
793 setLegalityCallback(dialectNames, callback);
794 }
795 template <typename... Args>
addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback)796 void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback) {
797 addDynamicallyLegalDialect(std::move(callback),
798 Args::getDialectNamespace()...);
799 }
800
801 /// Register unknown operations as dynamically legal. For operations(and
802 /// dialects) that do not have a set legalization action, treat them as
803 /// dynamically legal and invoke the given callback.
markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn & fn)804 void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn) {
805 setLegalityCallback(fn);
806 }
807
808 /// Register the operations of the given dialects as illegal, i.e.
809 /// operations of this dialect are not supported by the target.
810 template <typename... Names>
addIllegalDialect(StringRef name,Names...names)811 void addIllegalDialect(StringRef name, Names... names) {
812 SmallVector<StringRef, 2> dialectNames({name, names...});
813 setDialectAction(dialectNames, LegalizationAction::Illegal);
814 }
815 template <typename... Args>
addIllegalDialect()816 void addIllegalDialect() {
817 SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
818 setDialectAction(dialectNames, LegalizationAction::Illegal);
819 }
820
821 //===--------------------------------------------------------------------===//
822 // Legality Querying
823 //===--------------------------------------------------------------------===//
824
825 /// Get the legality action for the given operation.
826 Optional<LegalizationAction> getOpAction(OperationName op) const;
827
828 /// If the given operation instance is legal on this target, a structure
829 /// containing legality information is returned. If the operation is not
830 /// legal, None is returned. Also returns None is operation legality wasn't
831 /// registered by user or dynamic legality callbacks returned None.
832 ///
833 /// Note: Legality is actually a 4-state: Legal(recursive=true),
834 /// Legal(recursive=false), Illegal or Unknown, where Unknown is treated
835 /// either as Legal or Illegal depending on context.
836 Optional<LegalOpDetails> isLegal(Operation *op) const;
837
838 /// Returns true is operation instance is illegal on this target. Returns
839 /// false if operation is legal, operation legality wasn't registered by user
840 /// or dynamic legality callbacks returned None.
841 bool isIllegal(Operation *op) const;
842
843 private:
844 /// Set the dynamic legality callback for the given operation.
845 void setLegalityCallback(OperationName name,
846 const DynamicLegalityCallbackFn &callback);
847
848 /// Set the dynamic legality callback for the given dialects.
849 void setLegalityCallback(ArrayRef<StringRef> dialects,
850 const DynamicLegalityCallbackFn &callback);
851
852 /// Set the dynamic legality callback for the unknown ops.
853 void setLegalityCallback(const DynamicLegalityCallbackFn &callback);
854
855 /// The set of information that configures the legalization of an operation.
856 struct LegalizationInfo {
857 /// The legality action this operation was given.
858 LegalizationAction action = LegalizationAction::Illegal;
859
860 /// If some legal instances of this operation may also be recursively legal.
861 bool isRecursivelyLegal = false;
862
863 /// The legality callback if this operation is dynamically legal.
864 DynamicLegalityCallbackFn legalityFn;
865 };
866
867 /// Get the legalization information for the given operation.
868 Optional<LegalizationInfo> getOpInfo(OperationName op) const;
869
870 /// A deterministic mapping of operation name and its respective legality
871 /// information.
872 llvm::MapVector<OperationName, LegalizationInfo> legalOperations;
873
874 /// A set of legality callbacks for given operation names that are used to
875 /// check if an operation instance is recursively legal.
876 DenseMap<OperationName, DynamicLegalityCallbackFn> opRecursiveLegalityFns;
877
878 /// A deterministic mapping of dialect name to the specific legality action to
879 /// take.
880 llvm::StringMap<LegalizationAction> legalDialects;
881
882 /// A set of dynamic legality callbacks for given dialect names.
883 llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns;
884
885 /// An optional legality callback for unknown operations.
886 DynamicLegalityCallbackFn unknownLegalityFn;
887
888 /// The current context this target applies to.
889 MLIRContext &ctx;
890 };
891
892 //===----------------------------------------------------------------------===//
893 // Op Conversion Entry Points
894 //===----------------------------------------------------------------------===//
895
/// Below we define several entry points for operation conversion. It is
/// important to note that the patterns provided to the conversion framework may
/// have additional constraints. See the `PatternRewriter Hooks` section of
/// ConversionPatternRewriter for the additional constraints imposed on the use
/// of the PatternRewriter.
901
/// Apply a partial conversion on the given operations and all nested
/// operations. This method converts as many operations to the target as
/// possible, ignoring operations that failed to legalize. This method only
/// returns failure if there are ops explicitly marked as illegal. If an
/// `unconvertedOps` set is provided, all operations that are found not to be
/// legalizable to the given `target` are placed within that set. (Note that if
/// there is an op explicitly marked as illegal, the conversion terminates and
/// the `unconvertedOps` set will not necessarily be complete.)
LogicalResult
applyPartialConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
                       const FrozenRewritePatternSet &patterns,
                       DenseSet<Operation *> *unconvertedOps = nullptr);
/// Overload of `applyPartialConversion` taking a single root operation `op`
/// instead of a list of operations.
LogicalResult
applyPartialConversion(Operation *op, ConversionTarget &target,
                       const FrozenRewritePatternSet &patterns,
                       DenseSet<Operation *> *unconvertedOps = nullptr);
918
/// Apply a complete conversion on the given operations, and all nested
/// operations. This method returns failure if the conversion of any operation
/// fails, or if there are unreachable blocks in any of the regions nested
/// within 'ops'.
LogicalResult applyFullConversion(ArrayRef<Operation *> ops,
                                  ConversionTarget &target,
                                  const FrozenRewritePatternSet &patterns);
/// Overload of `applyFullConversion` taking a single root operation `op`
/// instead of a list of operations.
LogicalResult applyFullConversion(Operation *op, ConversionTarget &target,
                                  const FrozenRewritePatternSet &patterns);
928
/// Apply an analysis conversion on the given operations, and all nested
/// operations. This method analyzes which operations would be successfully
/// converted to the target if a conversion was applied. All operations that
/// were found to be legalizable to the given 'target' are placed within the
/// provided 'convertedOps' set; note that no actual rewrites are applied to the
/// operations on success and only pre-existing operations are added to the set.
/// This method only returns failure if there are unreachable blocks in any of
/// the regions nested within 'ops'. The optional `notifyCallback` argument
/// (defaulting to null) is used for collecting match failure diagnostics
/// generated during the conversion. Diagnostics reported to this callback may
/// only be available in debug mode.
LogicalResult applyAnalysisConversion(
    ArrayRef<Operation *> ops, ConversionTarget &target,
    const FrozenRewritePatternSet &patterns,
    DenseSet<Operation *> &convertedOps,
    function_ref<void(Diagnostic &)> notifyCallback = nullptr);
/// Overload of `applyAnalysisConversion` taking a single root operation `op`
/// instead of a list of operations.
LogicalResult applyAnalysisConversion(
    Operation *op, ConversionTarget &target,
    const FrozenRewritePatternSet &patterns,
    DenseSet<Operation *> &convertedOps,
    function_ref<void(Diagnostic &)> notifyCallback = nullptr);
950 } // namespace mlir
951
952 #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
953