1cf6a7c19SMahesh Ravishankar //===- TestTilingInterface.cpp - Test tiling using `TilingInterface` -----===// 2cf6a7c19SMahesh Ravishankar // 3cf6a7c19SMahesh Ravishankar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4cf6a7c19SMahesh Ravishankar // See https://llvm.org/LICENSE.txt for license information. 5cf6a7c19SMahesh Ravishankar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6cf6a7c19SMahesh Ravishankar // 7cf6a7c19SMahesh Ravishankar //===----------------------------------------------------------------------===// 8cf6a7c19SMahesh Ravishankar // 9cf6a7c19SMahesh Ravishankar // This file implements a pass for testing tiling operations using 10cf6a7c19SMahesh Ravishankar // `TilingInterface`. 11cf6a7c19SMahesh Ravishankar // 12cf6a7c19SMahesh Ravishankar //===----------------------------------------------------------------------===// 13cf6a7c19SMahesh Ravishankar 14cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Affine/IR/AffineOps.h" 15cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Func/IR/FuncOps.h" 16cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" 17cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 18cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/MemRef/IR/MemRef.h" 19*8b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h" 20*8b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" 21cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Tensor/IR/Tensor.h" 22cf6a7c19SMahesh Ravishankar #include "mlir/Interfaces/TilingInterface.h" 23cf6a7c19SMahesh Ravishankar #include "mlir/Pass/Pass.h" 24cf6a7c19SMahesh Ravishankar #include "mlir/Pass/PassManager.h" 25cf6a7c19SMahesh Ravishankar #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 26cf6a7c19SMahesh Ravishankar #include "llvm/ADT/TypeSwitch.h" 27cf6a7c19SMahesh Ravishankar 28cf6a7c19SMahesh Ravishankar using namespace mlir; 29cf6a7c19SMahesh Ravishankar 30cf6a7c19SMahesh Ravishankar namespace { 31cf6a7c19SMahesh Ravishankar 32cf6a7c19SMahesh Ravishankar /// Construct a generic pattern applied to all TilingInterface ops that verify 33cf6a7c19SMahesh Ravishankar /// `filter`. 34cf6a7c19SMahesh Ravishankar struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp { 35cf6a7c19SMahesh Ravishankar TestTileUsingSCFForOpWithFilter(MLIRContext *context, 36cf6a7c19SMahesh Ravishankar scf::SCFTilingOptions options, 37cf6a7c19SMahesh Ravishankar linalg::LinalgTransformationFilter filter = 38cf6a7c19SMahesh Ravishankar linalg::LinalgTransformationFilter(), 39cf6a7c19SMahesh Ravishankar PatternBenefit benefit = 1) 40cf6a7c19SMahesh Ravishankar : scf::TileUsingSCFForOp(context, options, benefit), filter(filter) {} 41cf6a7c19SMahesh Ravishankar 42cf6a7c19SMahesh Ravishankar /// Construct a generic pattern applied to `opName`. 43cf6a7c19SMahesh Ravishankar TestTileUsingSCFForOpWithFilter(StringRef opName, MLIRContext *context, 44cf6a7c19SMahesh Ravishankar scf::SCFTilingOptions options, 45cf6a7c19SMahesh Ravishankar linalg::LinalgTransformationFilter filter = 46cf6a7c19SMahesh Ravishankar linalg::LinalgTransformationFilter(), 47cf6a7c19SMahesh Ravishankar PatternBenefit benefit = 1) 48cf6a7c19SMahesh Ravishankar : scf::TileUsingSCFForOp(context, options, benefit), filter(filter) {} 49cf6a7c19SMahesh Ravishankar 50cf6a7c19SMahesh Ravishankar LogicalResult matchAndRewrite(TilingInterface op, 51cf6a7c19SMahesh Ravishankar PatternRewriter &rewriter) const override { 52cf6a7c19SMahesh Ravishankar if (failed(filter.checkAndNotify(rewriter, op))) 53cf6a7c19SMahesh Ravishankar return failure(); 54cf6a7c19SMahesh Ravishankar 55cf6a7c19SMahesh Ravishankar FailureOr<scf::SCFTilingResult> tilingResult = 56cf6a7c19SMahesh Ravishankar returningMatchAndRewrite(op, rewriter); 57cf6a7c19SMahesh Ravishankar if (failed(tilingResult)) { 58cf6a7c19SMahesh Ravishankar return failure(); 59cf6a7c19SMahesh Ravishankar } 60cf6a7c19SMahesh Ravishankar filter.replaceLinalgTransformationFilter(rewriter, tilingResult->tiledOp); 61cf6a7c19SMahesh Ravishankar return success(); 62cf6a7c19SMahesh Ravishankar } 63cf6a7c19SMahesh Ravishankar 64cf6a7c19SMahesh Ravishankar private: 65cf6a7c19SMahesh Ravishankar linalg::LinalgTransformationFilter filter; 66cf6a7c19SMahesh Ravishankar }; 67cf6a7c19SMahesh Ravishankar 68cf6a7c19SMahesh Ravishankar struct TestTilingInterfacePass 69cf6a7c19SMahesh Ravishankar : public PassWrapper<TestTilingInterfacePass, OperationPass<func::FuncOp>> { 70cf6a7c19SMahesh Ravishankar MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTilingInterfacePass) 71cf6a7c19SMahesh Ravishankar 72cf6a7c19SMahesh Ravishankar TestTilingInterfacePass() = default; 73cf6a7c19SMahesh Ravishankar TestTilingInterfacePass(const TestTilingInterfacePass &pass) 74cf6a7c19SMahesh Ravishankar : PassWrapper(pass) {} 75cf6a7c19SMahesh Ravishankar void getDependentDialects(DialectRegistry ®istry) const override { 76cf6a7c19SMahesh Ravishankar registry.insert<AffineDialect, memref::MemRefDialect, scf::SCFDialect, 77cf6a7c19SMahesh Ravishankar tensor::TensorDialect>(); 78cf6a7c19SMahesh Ravishankar linalg::registerTilingInterfaceExternalModels(registry); 79cf6a7c19SMahesh Ravishankar } 80cf6a7c19SMahesh Ravishankar StringRef getArgument() const final { return "test-tiling-interface"; } 81cf6a7c19SMahesh Ravishankar StringRef getDescription() const final { 82cf6a7c19SMahesh Ravishankar return "Test tiling using TilingInterface"; 83cf6a7c19SMahesh Ravishankar } 84cf6a7c19SMahesh Ravishankar 85cf6a7c19SMahesh Ravishankar void runOnOperation() override; 86cf6a7c19SMahesh Ravishankar }; 87cf6a7c19SMahesh Ravishankar } // namespace 88cf6a7c19SMahesh Ravishankar 89cf6a7c19SMahesh Ravishankar static void addTestPatterns(MLIRContext *context, RewritePatternSet &patterns) { 90cf6a7c19SMahesh Ravishankar auto addPatternForTiling = [&](ArrayRef<int64_t> tileSizes, 91cf6a7c19SMahesh Ravishankar StringRef filterName) { 92cf6a7c19SMahesh Ravishankar scf::SCFTilingOptions tilingOptions; 93cf6a7c19SMahesh Ravishankar tilingOptions.setTileSizes(tileSizes); 94cf6a7c19SMahesh Ravishankar linalg::LinalgTransformationFilter filter( 95cf6a7c19SMahesh Ravishankar StringAttr::get(context, filterName), 96cf6a7c19SMahesh Ravishankar StringAttr::get(context, "tiled")); 97cf6a7c19SMahesh Ravishankar patterns.add<TestTileUsingSCFForOpWithFilter>(context, tilingOptions, 98cf6a7c19SMahesh Ravishankar filter); 99cf6a7c19SMahesh Ravishankar }; 100cf6a7c19SMahesh Ravishankar // 1. Tiling M and N dims of `linalg.matmul` on tensors. 101cf6a7c19SMahesh Ravishankar addPatternForTiling({10, 20}, "simple_gemm"); 102cf6a7c19SMahesh Ravishankar // 2. Tiling M, N and K of `linalg.matmul` on buffers. 103cf6a7c19SMahesh Ravishankar addPatternForTiling({10, 20, 30}, "simple_gemm_memref"); 104cf6a7c19SMahesh Ravishankar // 3. Tiling 3D parallel generic op which implements a transpose 105cf6a7c19SMahesh Ravishankar addPatternForTiling({10, 0, 20}, "parallel_generic_transpose"); 106cf6a7c19SMahesh Ravishankar // 4. Tiling 2D conv op. 107cf6a7c19SMahesh Ravishankar addPatternForTiling({0, 0, 0, 0, 10, 20, 30}, "simple_conv"); 108cf6a7c19SMahesh Ravishankar } 109cf6a7c19SMahesh Ravishankar 110cf6a7c19SMahesh Ravishankar void TestTilingInterfacePass::runOnOperation() { 111cf6a7c19SMahesh Ravishankar MLIRContext *context = &getContext(); 112cf6a7c19SMahesh Ravishankar 113cf6a7c19SMahesh Ravishankar RewritePatternSet tilingPatterns(context); 114cf6a7c19SMahesh Ravishankar addTestPatterns(context, tilingPatterns); 115cf6a7c19SMahesh Ravishankar if (failed(applyPatternsAndFoldGreedily(getOperation(), 116cf6a7c19SMahesh Ravishankar std::move(tilingPatterns)))) 117cf6a7c19SMahesh Ravishankar return signalPassFailure(); 118cf6a7c19SMahesh Ravishankar } 119cf6a7c19SMahesh Ravishankar 120cf6a7c19SMahesh Ravishankar namespace mlir { 121cf6a7c19SMahesh Ravishankar namespace test { 122cf6a7c19SMahesh Ravishankar void registerTestTilingInterface() { 123cf6a7c19SMahesh Ravishankar PassRegistration<TestTilingInterfacePass>(); 124cf6a7c19SMahesh Ravishankar } 125cf6a7c19SMahesh Ravishankar } // namespace test 126cf6a7c19SMahesh Ravishankar } // namespace mlir 127