1 //===- TestTilingInterface.cpp - Test tiling using `TilingInterface` -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass for testing tiling operations using 10 // `TilingInterface`. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" 17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 18 #include "mlir/Dialect/MemRef/IR/MemRef.h" 19 #include "mlir/Dialect/SCF/IR/SCF.h" 20 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" 21 #include "mlir/Dialect/Tensor/IR/Tensor.h" 22 #include "mlir/Interfaces/TilingInterface.h" 23 #include "mlir/Pass/Pass.h" 24 #include "mlir/Pass/PassManager.h" 25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 26 #include "llvm/ADT/TypeSwitch.h" 27 28 using namespace mlir; 29 30 namespace { 31 32 /// Pattern for testing `TileUsingSCFForOp` pattern (that tiles operations using 33 /// the `TilingInterface` with `scf.for` ops for iterating over the tiles) while 34 /// using a `filter` to avoid recursive application. 35 struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp { 36 TestTileUsingSCFForOpWithFilter(MLIRContext *context, 37 scf::SCFTilingOptions options, 38 linalg::LinalgTransformationFilter filter = 39 linalg::LinalgTransformationFilter(), 40 PatternBenefit benefit = 1) 41 : scf::TileUsingSCFForOp(context, options, benefit), filter(filter) {} 42 43 /// Construct a generic pattern applied to `opName`. 44 TestTileUsingSCFForOpWithFilter(StringRef opName, MLIRContext *context, 45 scf::SCFTilingOptions options, 46 linalg::LinalgTransformationFilter filter = 47 linalg::LinalgTransformationFilter(), 48 PatternBenefit benefit = 1) 49 : scf::TileUsingSCFForOp(context, options, benefit), filter(filter) {} 50 51 LogicalResult matchAndRewrite(TilingInterface op, 52 PatternRewriter &rewriter) const override { 53 if (failed(filter.checkAndNotify(rewriter, op))) 54 return failure(); 55 56 auto tilingResult = returningMatchAndRewrite(op, rewriter); 57 if (failed(tilingResult)) { 58 return failure(); 59 } 60 filter.replaceLinalgTransformationFilter(rewriter, tilingResult->tiledOp); 61 return success(); 62 } 63 64 private: 65 linalg::LinalgTransformationFilter filter; 66 }; 67 68 /// Pattern for testing `TileConsumerAndFUseProducersUsingSCFForOp` pattern 69 /// (that tiles and fuses operations using the `TilingInterface` with `scf.for` 70 /// ops for iterating over the tiles) while using a `filter` to avoid recursive 71 /// application. 72 struct TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter 73 : public scf::TileConsumerAndFuseProducersUsingSCFForOp { 74 TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter( 75 MLIRContext *context, scf::SCFTilingOptions options, 76 linalg::LinalgTransformationFilter filter = 77 linalg::LinalgTransformationFilter(), 78 PatternBenefit benefit = 1) 79 : scf::TileConsumerAndFuseProducersUsingSCFForOp(context, options, 80 benefit), 81 filter(filter) {} 82 83 /// Construct a generic pattern applied to `opName`. 84 TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter( 85 StringRef opName, MLIRContext *context, scf::SCFTilingOptions options, 86 linalg::LinalgTransformationFilter filter = 87 linalg::LinalgTransformationFilter(), 88 PatternBenefit benefit = 1) 89 : scf::TileConsumerAndFuseProducersUsingSCFForOp(context, options, 90 benefit), 91 filter(filter) {} 92 93 LogicalResult matchAndRewrite(TilingInterface op, 94 PatternRewriter &rewriter) const override { 95 if (failed(filter.checkAndNotify(rewriter, op))) 96 return failure(); 97 98 auto tileAndFuseResult = returningMatchAndRewrite(op, rewriter); 99 if (failed(tileAndFuseResult)) { 100 return failure(); 101 } 102 filter.replaceLinalgTransformationFilter( 103 rewriter, tileAndFuseResult->tiledAndFusedOps.front()); 104 return success(); 105 } 106 107 private: 108 linalg::LinalgTransformationFilter filter; 109 }; 110 111 /// Test pass for testing the use of `TilingInterface`. 112 struct TestTilingInterfacePass 113 : public PassWrapper<TestTilingInterfacePass, OperationPass<func::FuncOp>> { 114 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTilingInterfacePass) 115 116 TestTilingInterfacePass() = default; 117 TestTilingInterfacePass(const TestTilingInterfacePass &pass) 118 : PassWrapper(pass) {} 119 void getDependentDialects(DialectRegistry ®istry) const override { 120 registry.insert<AffineDialect, memref::MemRefDialect, scf::SCFDialect, 121 tensor::TensorDialect>(); 122 linalg::registerTilingInterfaceExternalModels(registry); 123 } 124 StringRef getArgument() const final { return "test-tiling-interface"; } 125 StringRef getDescription() const final { 126 return "Test tiling using TilingInterface"; 127 } 128 129 Option<bool> testTiling{ 130 *this, "tile-using-scf-for", 131 llvm::cl::desc( 132 "Test tiling using TilingInterface with scf.for operations"), 133 llvm::cl::init(false)}; 134 135 Option<bool> testTileConsumerAndFuseProducer{ 136 *this, "tile-consumer-and-fuse-producer-using-scf-for", 137 llvm::cl::desc("Test tile and fuse transformation using TilingInterface " 138 "with scf.for operations"), 139 llvm::cl::init(false)}; 140 141 void runOnOperation() override; 142 143 private: 144 void addTestPatterns(MLIRContext *context, RewritePatternSet &patterns); 145 }; 146 } // namespace 147 148 template <class Pattern> 149 static void 150 addPatternForTiling(MLIRContext *context, RewritePatternSet &patterns, 151 StringRef filterName, ArrayRef<int64_t> tileSizes, 152 ArrayRef<unsigned> interchange = {}) { 153 scf::SCFTilingOptions tilingOptions; 154 tilingOptions.setTileSizes(tileSizes).setInterchange(interchange); 155 linalg::LinalgTransformationFilter filter( 156 StringAttr::get(context, filterName), StringAttr::get(context, "tiled")); 157 patterns.add<Pattern>(context, tilingOptions, filter); 158 } 159 160 void TestTilingInterfacePass::addTestPatterns(MLIRContext *context, 161 RewritePatternSet &patterns) { 162 if (testTiling) { 163 // 1. Tiling M and N dims of `linalg.matmul` on tensors. 164 addPatternForTiling<TestTileUsingSCFForOpWithFilter>( 165 context, patterns, "simple_gemm", {10, 20}); 166 // 2. Tiling M, N and K of `linalg.matmul` on buffers. 167 addPatternForTiling<TestTileUsingSCFForOpWithFilter>( 168 context, patterns, "simple_gemm_memref", {10, 20, 30}); 169 // 3. Tiling 3D parallel generic op which implements a transpose 170 addPatternForTiling<TestTileUsingSCFForOpWithFilter>( 171 context, patterns, "parallel_generic_transpose", {10, 0, 20}); 172 // 4. Tiling 2D conv op. 173 addPatternForTiling<TestTileUsingSCFForOpWithFilter>( 174 context, patterns, "simple_conv", {0, 0, 0, 0, 10, 20, 30}); 175 // 5. Tiling a simple op with `linalg.index` inside. 176 addPatternForTiling<TestTileUsingSCFForOpWithFilter>( 177 context, patterns, "indexed_semantics", {10, 20}); 178 // 6. Tiling + interchange of an operation 179 addPatternForTiling<TestTileUsingSCFForOpWithFilter>( 180 context, patterns, "gemm_interchange", {10, 20, 30}, {1, 2, 0}); 181 return; 182 } 183 if (testTileConsumerAndFuseProducer) { 184 // 1. Tile and fuse of gemm with bias-add operation. 185 addPatternForTiling< 186 TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>( 187 context, patterns, "fusion", {10, 20}); 188 addPatternForTiling< 189 TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>( 190 context, patterns, "gemm_fusion", {10}); 191 addPatternForTiling< 192 TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>( 193 context, patterns, "gemm_interchange_fusion", {10, 20}, {1, 0}); 194 addPatternForTiling< 195 TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>( 196 context, patterns, "gemm_plus_gemm_fusion", {10, 20}); 197 addPatternForTiling< 198 TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>( 199 context, patterns, "gemm_sequence_fusion", {10}); 200 return; 201 } 202 } 203 204 void TestTilingInterfacePass::runOnOperation() { 205 MLIRContext *context = &getContext(); 206 207 RewritePatternSet tilingPatterns(context); 208 addTestPatterns(context, tilingPatterns); 209 if (failed(applyPatternsAndFoldGreedily(getOperation(), 210 std::move(tilingPatterns)))) 211 return signalPassFailure(); 212 } 213 214 namespace mlir { 215 namespace test { 216 void registerTestTilingInterface() { 217 PassRegistration<TestTilingInterfacePass>(); 218 } 219 } // namespace test 220 } // namespace mlir 221