1 //===- TestTilingInterface.cpp - Test tiling using `TilingInterface` -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass for testing tiling operations using 10 // `TilingInterface`. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" 17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 18 #include "mlir/Dialect/MemRef/IR/MemRef.h" 19 #include "mlir/Dialect/SCF/SCF.h" 20 #include "mlir/Dialect/SCF/TileUsingInterface.h" 21 #include "mlir/Dialect/Tensor/IR/Tensor.h" 22 #include "mlir/Interfaces/TilingInterface.h" 23 #include "mlir/Pass/Pass.h" 24 #include "mlir/Pass/PassManager.h" 25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 26 #include "llvm/ADT/TypeSwitch.h" 27 28 using namespace mlir; 29 30 namespace { 31 32 /// Construct a generic pattern applied to all TilingInterface ops that verify 33 /// `filter`. 34 struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp { 35 TestTileUsingSCFForOpWithFilter(MLIRContext *context, 36 scf::SCFTilingOptions options, 37 linalg::LinalgTransformationFilter filter = 38 linalg::LinalgTransformationFilter(), 39 PatternBenefit benefit = 1) 40 : scf::TileUsingSCFForOp(context, options, benefit), filter(filter) {} 41 42 /// Construct a generic pattern applied to `opName`. 43 TestTileUsingSCFForOpWithFilter(StringRef opName, MLIRContext *context, 44 scf::SCFTilingOptions options, 45 linalg::LinalgTransformationFilter filter = 46 linalg::LinalgTransformationFilter(), 47 PatternBenefit benefit = 1) 48 : scf::TileUsingSCFForOp(context, options, benefit), filter(filter) {} 49 50 LogicalResult matchAndRewrite(TilingInterface op, 51 PatternRewriter &rewriter) const override { 52 if (failed(filter.checkAndNotify(rewriter, op))) 53 return failure(); 54 55 FailureOr<scf::SCFTilingResult> tilingResult = 56 returningMatchAndRewrite(op, rewriter); 57 if (failed(tilingResult)) { 58 return failure(); 59 } 60 filter.replaceLinalgTransformationFilter(rewriter, tilingResult->tiledOp); 61 return success(); 62 } 63 64 private: 65 linalg::LinalgTransformationFilter filter; 66 }; 67 68 struct TestTilingInterfacePass 69 : public PassWrapper<TestTilingInterfacePass, OperationPass<func::FuncOp>> { 70 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTilingInterfacePass) 71 72 TestTilingInterfacePass() = default; 73 TestTilingInterfacePass(const TestTilingInterfacePass &pass) 74 : PassWrapper(pass) {} 75 void getDependentDialects(DialectRegistry ®istry) const override { 76 registry.insert<AffineDialect, memref::MemRefDialect, scf::SCFDialect, 77 tensor::TensorDialect>(); 78 linalg::registerTilingInterfaceExternalModels(registry); 79 } 80 StringRef getArgument() const final { return "test-tiling-interface"; } 81 StringRef getDescription() const final { 82 return "Test tiling using TilingInterface"; 83 } 84 85 void runOnOperation() override; 86 }; 87 } // namespace 88 89 static void addTestPatterns(MLIRContext *context, RewritePatternSet &patterns) { 90 auto addPatternForTiling = [&](ArrayRef<int64_t> tileSizes, 91 StringRef filterName) { 92 scf::SCFTilingOptions tilingOptions; 93 tilingOptions.setTileSizes(tileSizes); 94 linalg::LinalgTransformationFilter filter( 95 StringAttr::get(context, filterName), 96 StringAttr::get(context, "tiled")); 97 patterns.add<TestTileUsingSCFForOpWithFilter>(context, tilingOptions, 98 filter); 99 }; 100 // 1. Tiling M and N dims of `linalg.matmul` on tensors. 101 addPatternForTiling({10, 20}, "simple_gemm"); 102 // 2. Tiling M, N and K of `linalg.matmul` on buffers. 103 addPatternForTiling({10, 20, 30}, "simple_gemm_memref"); 104 // 3. Tiling 3D parallel generic op which implements a transpose 105 addPatternForTiling({10, 0, 20}, "parallel_generic_transpose"); 106 // 4. Tiling 2D conv op. 107 addPatternForTiling({0, 0, 0, 0, 10, 20, 30}, "simple_conv"); 108 } 109 110 void TestTilingInterfacePass::runOnOperation() { 111 MLIRContext *context = &getContext(); 112 113 RewritePatternSet tilingPatterns(context); 114 addTestPatterns(context, tilingPatterns); 115 if (failed(applyPatternsAndFoldGreedily(getOperation(), 116 std::move(tilingPatterns)))) 117 return signalPassFailure(); 118 } 119 120 namespace mlir { 121 namespace test { 122 void registerTestTilingInterface() { 123 PassRegistration<TestTilingInterfacePass>(); 124 } 125 } // namespace test 126 } // namespace mlir 127