//===- ReducePatternInterface.h - Collecting Reduce Patterns ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8c484c7ddSChia-hung Duan 
#ifndef MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H
#define MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H

#include "mlir/IR/DialectInterface.h"
13c484c7ddSChia-hung Duan 
14c484c7ddSChia-hung Duan namespace mlir {
15c484c7ddSChia-hung Duan 
16c484c7ddSChia-hung Duan class RewritePatternSet;
17c484c7ddSChia-hung Duan 
18c484c7ddSChia-hung Duan /// This is used to report the reduction patterns for a Dialect. While using
19c484c7ddSChia-hung Duan /// mlir-reduce to reduce a module, we may want to transform certain cases into
20c484c7ddSChia-hung Duan /// simpler forms by applying certain rewrite patterns. Implement the
21c484c7ddSChia-hung Duan /// `populateReductionPatterns` to report those patterns by adding them to the
22c484c7ddSChia-hung Duan /// RewritePatternSet.
23c484c7ddSChia-hung Duan ///
24c484c7ddSChia-hung Duan /// Example:
25c484c7ddSChia-hung Duan ///   MyDialectReductionPattern::populateReductionPatterns(
26c484c7ddSChia-hung Duan ///       RewritePatternSet &patterns) {
27c484c7ddSChia-hung Duan ///       patterns.add<TensorOpReduction>(patterns.getContext());
28c484c7ddSChia-hung Duan ///   }
29c484c7ddSChia-hung Duan ///
30c484c7ddSChia-hung Duan /// For DRR, mlir-tblgen will generate a helper function
31c484c7ddSChia-hung Duan /// `populateWithGenerated` which has the same signature therefore you can
32c484c7ddSChia-hung Duan /// delegate to the helper function as well.
33c484c7ddSChia-hung Duan ///
34c484c7ddSChia-hung Duan /// Example:
35c484c7ddSChia-hung Duan ///   MyDialectReductionPattern::populateReductionPatterns(
36c484c7ddSChia-hung Duan ///       RewritePatternSet &patterns) {
37c484c7ddSChia-hung Duan ///       // Include the autogen file somewhere above.
38c484c7ddSChia-hung Duan ///       populateWithGenerated(patterns);
39c484c7ddSChia-hung Duan ///   }
40c484c7ddSChia-hung Duan class DialectReductionPatternInterface
41c484c7ddSChia-hung Duan     : public DialectInterface::Base<DialectReductionPatternInterface> {
42c484c7ddSChia-hung Duan public:
43c484c7ddSChia-hung Duan   /// Patterns provided here are intended to transform operations from a complex
44c484c7ddSChia-hung Duan   /// form to a simpler form, without breaking the semantics of the program
45c484c7ddSChia-hung Duan   /// being reduced. For example, you may want to replace the
46c484c7ddSChia-hung Duan   /// tensor<?xindex> with a known rank and type, e.g. tensor<1xi32>, or
47c484c7ddSChia-hung Duan   /// replacing an operation with a constant.
48c484c7ddSChia-hung Duan   virtual void populateReductionPatterns(RewritePatternSet &patterns) const = 0;
49c484c7ddSChia-hung Duan 
50c484c7ddSChia-hung Duan protected:
DialectReductionPatternInterface(Dialect * dialect)51c484c7ddSChia-hung Duan   DialectReductionPatternInterface(Dialect *dialect) : Base(dialect) {}
52c484c7ddSChia-hung Duan };
53c484c7ddSChia-hung Duan 
54*be0a7e9fSMehdi Amini } // namespace mlir
55c484c7ddSChia-hung Duan 
#endif // MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H