1 //===- TestTilingInterface.cpp - Test tiling using `TilingInterface` -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass for testing tiling operations using
10 // `TilingInterface`.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/SCF/SCF.h"
20 #include "mlir/Dialect/SCF/TileUsingInterface.h"
21 #include "mlir/Dialect/Tensor/IR/Tensor.h"
22 #include "mlir/Interfaces/TilingInterface.h"
23 #include "mlir/Pass/Pass.h"
24 #include "mlir/Pass/PassManager.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 
28 using namespace mlir;
29 
30 namespace {
31 
32 /// Construct a generic pattern applied to all TilingInterface ops that verify
33 /// `filter`.
34 struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp {
35   TestTileUsingSCFForOpWithFilter(MLIRContext *context,
36                                   scf::SCFTilingOptions options,
37                                   linalg::LinalgTransformationFilter filter =
38                                       linalg::LinalgTransformationFilter(),
39                                   PatternBenefit benefit = 1)
40       : scf::TileUsingSCFForOp(context, options, benefit), filter(filter) {}
41 
42   /// Construct a generic pattern applied to `opName`.
43   TestTileUsingSCFForOpWithFilter(StringRef opName, MLIRContext *context,
44                                   scf::SCFTilingOptions options,
45                                   linalg::LinalgTransformationFilter filter =
46                                       linalg::LinalgTransformationFilter(),
47                                   PatternBenefit benefit = 1)
48       : scf::TileUsingSCFForOp(context, options, benefit), filter(filter) {}
49 
50   LogicalResult matchAndRewrite(TilingInterface op,
51                                 PatternRewriter &rewriter) const override {
52     if (failed(filter.checkAndNotify(rewriter, op)))
53       return failure();
54 
55     FailureOr<scf::SCFTilingResult> tilingResult =
56         returningMatchAndRewrite(op, rewriter);
57     if (failed(tilingResult)) {
58       return failure();
59     }
60     filter.replaceLinalgTransformationFilter(rewriter, tilingResult->tiledOp);
61     return success();
62   }
63 
64 private:
65   linalg::LinalgTransformationFilter filter;
66 };
67 
68 struct TestTilingInterfacePass
69     : public PassWrapper<TestTilingInterfacePass, OperationPass<func::FuncOp>> {
70   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTilingInterfacePass)
71 
72   TestTilingInterfacePass() = default;
73   TestTilingInterfacePass(const TestTilingInterfacePass &pass)
74       : PassWrapper(pass) {}
75   void getDependentDialects(DialectRegistry &registry) const override {
76     registry.insert<AffineDialect, memref::MemRefDialect, scf::SCFDialect,
77                     tensor::TensorDialect>();
78     linalg::registerTilingInterfaceExternalModels(registry);
79   }
80   StringRef getArgument() const final { return "test-tiling-interface"; }
81   StringRef getDescription() const final {
82     return "Test tiling using TilingInterface";
83   }
84 
85   void runOnOperation() override;
86 };
87 } // namespace
88 
89 static void addTestPatterns(MLIRContext *context, RewritePatternSet &patterns) {
90   auto addPatternForTiling = [&](ArrayRef<int64_t> tileSizes,
91                                  StringRef filterName) {
92     scf::SCFTilingOptions tilingOptions;
93     tilingOptions.setTileSizes(tileSizes);
94     linalg::LinalgTransformationFilter filter(
95         StringAttr::get(context, filterName),
96         StringAttr::get(context, "tiled"));
97     patterns.add<TestTileUsingSCFForOpWithFilter>(context, tilingOptions,
98                                                   filter);
99   };
100   // 1. Tiling M and N dims of `linalg.matmul` on tensors.
101   addPatternForTiling({10, 20}, "simple_gemm");
102   // 2. Tiling M, N and K of `linalg.matmul` on buffers.
103   addPatternForTiling({10, 20, 30}, "simple_gemm_memref");
104   // 3. Tiling 3D parallel generic op which implements a transpose
105   addPatternForTiling({10, 0, 20}, "parallel_generic_transpose");
106   // 4. Tiling 2D conv op.
107   addPatternForTiling({0, 0, 0, 0, 10, 20, 30}, "simple_conv");
108 }
109 
110 void TestTilingInterfacePass::runOnOperation() {
111   MLIRContext *context = &getContext();
112 
113   RewritePatternSet tilingPatterns(context);
114   addTestPatterns(context, tilingPatterns);
115   if (failed(applyPatternsAndFoldGreedily(getOperation(),
116                                           std::move(tilingPatterns))))
117     return signalPassFailure();
118 }
119 
120 namespace mlir {
121 namespace test {
122 void registerTestTilingInterface() {
123   PassRegistration<TestTilingInterfacePass>();
124 }
125 } // namespace test
126 } // namespace mlir
127