1*cf6a7c19SMahesh Ravishankar //===- TestTilingInterface.cpp - Test tiling using `TilingInterface` -----===// 2*cf6a7c19SMahesh Ravishankar // 3*cf6a7c19SMahesh Ravishankar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*cf6a7c19SMahesh Ravishankar // See https://llvm.org/LICENSE.txt for license information. 5*cf6a7c19SMahesh Ravishankar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*cf6a7c19SMahesh Ravishankar // 7*cf6a7c19SMahesh Ravishankar //===----------------------------------------------------------------------===// 8*cf6a7c19SMahesh Ravishankar // 9*cf6a7c19SMahesh Ravishankar // This file implements a pass for testing tiling operations using 10*cf6a7c19SMahesh Ravishankar // `TilingInterface`. 11*cf6a7c19SMahesh Ravishankar // 12*cf6a7c19SMahesh Ravishankar //===----------------------------------------------------------------------===// 13*cf6a7c19SMahesh Ravishankar 14*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Affine/IR/AffineOps.h" 15*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Func/IR/FuncOps.h" 16*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" 17*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 18*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/MemRef/IR/MemRef.h" 19*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/SCF/SCF.h" 20*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/SCF/TileUsingInterface.h" 21*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Tensor/IR/Tensor.h" 22*cf6a7c19SMahesh Ravishankar #include "mlir/Interfaces/TilingInterface.h" 23*cf6a7c19SMahesh Ravishankar #include "mlir/Pass/Pass.h" 24*cf6a7c19SMahesh Ravishankar #include "mlir/Pass/PassManager.h" 25*cf6a7c19SMahesh Ravishankar #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 26*cf6a7c19SMahesh Ravishankar #include "llvm/ADT/TypeSwitch.h" 27*cf6a7c19SMahesh Ravishankar 28*cf6a7c19SMahesh Ravishankar using namespace mlir; 29*cf6a7c19SMahesh Ravishankar 30*cf6a7c19SMahesh Ravishankar namespace { 31*cf6a7c19SMahesh Ravishankar 32*cf6a7c19SMahesh Ravishankar /// Construct a generic pattern applied to all TilingInterface ops that verify 33*cf6a7c19SMahesh Ravishankar /// `filter`. 34*cf6a7c19SMahesh Ravishankar struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp { 35*cf6a7c19SMahesh Ravishankar TestTileUsingSCFForOpWithFilter(MLIRContext *context, 36*cf6a7c19SMahesh Ravishankar scf::SCFTilingOptions options, 37*cf6a7c19SMahesh Ravishankar linalg::LinalgTransformationFilter filter = 38*cf6a7c19SMahesh Ravishankar linalg::LinalgTransformationFilter(), 39*cf6a7c19SMahesh Ravishankar PatternBenefit benefit = 1) 40*cf6a7c19SMahesh Ravishankar : scf::TileUsingSCFForOp(context, options, benefit), filter(filter) {} 41*cf6a7c19SMahesh Ravishankar 42*cf6a7c19SMahesh Ravishankar /// Construct a generic pattern applied to `opName`. 43*cf6a7c19SMahesh Ravishankar TestTileUsingSCFForOpWithFilter(StringRef opName, MLIRContext *context, 44*cf6a7c19SMahesh Ravishankar scf::SCFTilingOptions options, 45*cf6a7c19SMahesh Ravishankar linalg::LinalgTransformationFilter filter = 46*cf6a7c19SMahesh Ravishankar linalg::LinalgTransformationFilter(), 47*cf6a7c19SMahesh Ravishankar PatternBenefit benefit = 1) 48*cf6a7c19SMahesh Ravishankar : scf::TileUsingSCFForOp(context, options, benefit), filter(filter) {} 49*cf6a7c19SMahesh Ravishankar 50*cf6a7c19SMahesh Ravishankar LogicalResult matchAndRewrite(TilingInterface op, 51*cf6a7c19SMahesh Ravishankar PatternRewriter &rewriter) const override { 52*cf6a7c19SMahesh Ravishankar if (failed(filter.checkAndNotify(rewriter, op))) 53*cf6a7c19SMahesh Ravishankar return failure(); 54*cf6a7c19SMahesh Ravishankar 55*cf6a7c19SMahesh Ravishankar FailureOr<scf::SCFTilingResult> tilingResult = 56*cf6a7c19SMahesh Ravishankar returningMatchAndRewrite(op, rewriter); 57*cf6a7c19SMahesh Ravishankar if (failed(tilingResult)) { 58*cf6a7c19SMahesh Ravishankar return failure(); 59*cf6a7c19SMahesh Ravishankar } 60*cf6a7c19SMahesh Ravishankar filter.replaceLinalgTransformationFilter(rewriter, tilingResult->tiledOp); 61*cf6a7c19SMahesh Ravishankar return success(); 62*cf6a7c19SMahesh Ravishankar } 63*cf6a7c19SMahesh Ravishankar 64*cf6a7c19SMahesh Ravishankar private: 65*cf6a7c19SMahesh Ravishankar linalg::LinalgTransformationFilter filter; 66*cf6a7c19SMahesh Ravishankar }; 67*cf6a7c19SMahesh Ravishankar 68*cf6a7c19SMahesh Ravishankar struct TestTilingInterfacePass 69*cf6a7c19SMahesh Ravishankar : public PassWrapper<TestTilingInterfacePass, OperationPass<func::FuncOp>> { 70*cf6a7c19SMahesh Ravishankar MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTilingInterfacePass) 71*cf6a7c19SMahesh Ravishankar 72*cf6a7c19SMahesh Ravishankar TestTilingInterfacePass() = default; 73*cf6a7c19SMahesh Ravishankar TestTilingInterfacePass(const TestTilingInterfacePass &pass) 74*cf6a7c19SMahesh Ravishankar : PassWrapper(pass) {} 75*cf6a7c19SMahesh Ravishankar void getDependentDialects(DialectRegistry ®istry) const override { 76*cf6a7c19SMahesh Ravishankar registry.insert<AffineDialect, memref::MemRefDialect, scf::SCFDialect, 77*cf6a7c19SMahesh Ravishankar tensor::TensorDialect>(); 78*cf6a7c19SMahesh Ravishankar linalg::registerTilingInterfaceExternalModels(registry); 79*cf6a7c19SMahesh Ravishankar } 80*cf6a7c19SMahesh Ravishankar StringRef getArgument() const final { return "test-tiling-interface"; } 81*cf6a7c19SMahesh Ravishankar StringRef getDescription() const final { 82*cf6a7c19SMahesh Ravishankar return "Test tiling using TilingInterface"; 83*cf6a7c19SMahesh Ravishankar } 84*cf6a7c19SMahesh Ravishankar 85*cf6a7c19SMahesh Ravishankar void runOnOperation() override; 86*cf6a7c19SMahesh Ravishankar }; 87*cf6a7c19SMahesh Ravishankar } // namespace 88*cf6a7c19SMahesh Ravishankar 89*cf6a7c19SMahesh Ravishankar static void addTestPatterns(MLIRContext *context, RewritePatternSet &patterns) { 90*cf6a7c19SMahesh Ravishankar auto addPatternForTiling = [&](ArrayRef<int64_t> tileSizes, 91*cf6a7c19SMahesh Ravishankar StringRef filterName) { 92*cf6a7c19SMahesh Ravishankar scf::SCFTilingOptions tilingOptions; 93*cf6a7c19SMahesh Ravishankar tilingOptions.setTileSizes(tileSizes); 94*cf6a7c19SMahesh Ravishankar linalg::LinalgTransformationFilter filter( 95*cf6a7c19SMahesh Ravishankar StringAttr::get(context, filterName), 96*cf6a7c19SMahesh Ravishankar StringAttr::get(context, "tiled")); 97*cf6a7c19SMahesh Ravishankar patterns.add<TestTileUsingSCFForOpWithFilter>(context, tilingOptions, 98*cf6a7c19SMahesh Ravishankar filter); 99*cf6a7c19SMahesh Ravishankar }; 100*cf6a7c19SMahesh Ravishankar // 1. Tiling M and N dims of `linalg.matmul` on tensors. 101*cf6a7c19SMahesh Ravishankar addPatternForTiling({10, 20}, "simple_gemm"); 102*cf6a7c19SMahesh Ravishankar // 2. Tiling M, N and K of `linalg.matmul` on buffers. 103*cf6a7c19SMahesh Ravishankar addPatternForTiling({10, 20, 30}, "simple_gemm_memref"); 104*cf6a7c19SMahesh Ravishankar // 3. Tiling 3D parallel generic op which implements a transpose 105*cf6a7c19SMahesh Ravishankar addPatternForTiling({10, 0, 20}, "parallel_generic_transpose"); 106*cf6a7c19SMahesh Ravishankar // 4. Tiling 2D conv op. 107*cf6a7c19SMahesh Ravishankar addPatternForTiling({0, 0, 0, 0, 10, 20, 30}, "simple_conv"); 108*cf6a7c19SMahesh Ravishankar } 109*cf6a7c19SMahesh Ravishankar 110*cf6a7c19SMahesh Ravishankar void TestTilingInterfacePass::runOnOperation() { 111*cf6a7c19SMahesh Ravishankar MLIRContext *context = &getContext(); 112*cf6a7c19SMahesh Ravishankar 113*cf6a7c19SMahesh Ravishankar RewritePatternSet tilingPatterns(context); 114*cf6a7c19SMahesh Ravishankar addTestPatterns(context, tilingPatterns); 115*cf6a7c19SMahesh Ravishankar if (failed(applyPatternsAndFoldGreedily(getOperation(), 116*cf6a7c19SMahesh Ravishankar std::move(tilingPatterns)))) 117*cf6a7c19SMahesh Ravishankar return signalPassFailure(); 118*cf6a7c19SMahesh Ravishankar } 119*cf6a7c19SMahesh Ravishankar 120*cf6a7c19SMahesh Ravishankar namespace mlir { 121*cf6a7c19SMahesh Ravishankar namespace test { 122*cf6a7c19SMahesh Ravishankar void registerTestTilingInterface() { 123*cf6a7c19SMahesh Ravishankar PassRegistration<TestTilingInterfacePass>(); 124*cf6a7c19SMahesh Ravishankar } 125*cf6a7c19SMahesh Ravishankar } // namespace test 126*cf6a7c19SMahesh Ravishankar } // namespace mlir 127