//===- TestTilingInterface.cpp - Test tiling using `TilingInterface` -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass for testing tiling operations using
// `TilingInterface`.
//
//===----------------------------------------------------------------------===//
13cf6a7c19SMahesh Ravishankar 
14cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Affine/IR/AffineOps.h"
15cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Func/IR/FuncOps.h"
16cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
17cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/MemRef/IR/MemRef.h"
19*8b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
20*8b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
21cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Tensor/IR/Tensor.h"
22cf6a7c19SMahesh Ravishankar #include "mlir/Interfaces/TilingInterface.h"
23cf6a7c19SMahesh Ravishankar #include "mlir/Pass/Pass.h"
24cf6a7c19SMahesh Ravishankar #include "mlir/Pass/PassManager.h"
25cf6a7c19SMahesh Ravishankar #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26cf6a7c19SMahesh Ravishankar #include "llvm/ADT/TypeSwitch.h"
27cf6a7c19SMahesh Ravishankar 
28cf6a7c19SMahesh Ravishankar using namespace mlir;
29cf6a7c19SMahesh Ravishankar 
30cf6a7c19SMahesh Ravishankar namespace {
31cf6a7c19SMahesh Ravishankar 
32cf6a7c19SMahesh Ravishankar /// Construct a generic pattern applied to all TilingInterface ops that verify
33cf6a7c19SMahesh Ravishankar /// `filter`.
34cf6a7c19SMahesh Ravishankar struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp {
35cf6a7c19SMahesh Ravishankar   TestTileUsingSCFForOpWithFilter(MLIRContext *context,
36cf6a7c19SMahesh Ravishankar                                   scf::SCFTilingOptions options,
37cf6a7c19SMahesh Ravishankar                                   linalg::LinalgTransformationFilter filter =
38cf6a7c19SMahesh Ravishankar                                       linalg::LinalgTransformationFilter(),
39cf6a7c19SMahesh Ravishankar                                   PatternBenefit benefit = 1)
40cf6a7c19SMahesh Ravishankar       : scf::TileUsingSCFForOp(context, options, benefit), filter(filter) {}
41cf6a7c19SMahesh Ravishankar 
42cf6a7c19SMahesh Ravishankar   /// Construct a generic pattern applied to `opName`.
43cf6a7c19SMahesh Ravishankar   TestTileUsingSCFForOpWithFilter(StringRef opName, MLIRContext *context,
44cf6a7c19SMahesh Ravishankar                                   scf::SCFTilingOptions options,
45cf6a7c19SMahesh Ravishankar                                   linalg::LinalgTransformationFilter filter =
46cf6a7c19SMahesh Ravishankar                                       linalg::LinalgTransformationFilter(),
47cf6a7c19SMahesh Ravishankar                                   PatternBenefit benefit = 1)
48cf6a7c19SMahesh Ravishankar       : scf::TileUsingSCFForOp(context, options, benefit), filter(filter) {}
49cf6a7c19SMahesh Ravishankar 
50cf6a7c19SMahesh Ravishankar   LogicalResult matchAndRewrite(TilingInterface op,
51cf6a7c19SMahesh Ravishankar                                 PatternRewriter &rewriter) const override {
52cf6a7c19SMahesh Ravishankar     if (failed(filter.checkAndNotify(rewriter, op)))
53cf6a7c19SMahesh Ravishankar       return failure();
54cf6a7c19SMahesh Ravishankar 
55cf6a7c19SMahesh Ravishankar     FailureOr<scf::SCFTilingResult> tilingResult =
56cf6a7c19SMahesh Ravishankar         returningMatchAndRewrite(op, rewriter);
57cf6a7c19SMahesh Ravishankar     if (failed(tilingResult)) {
58cf6a7c19SMahesh Ravishankar       return failure();
59cf6a7c19SMahesh Ravishankar     }
60cf6a7c19SMahesh Ravishankar     filter.replaceLinalgTransformationFilter(rewriter, tilingResult->tiledOp);
61cf6a7c19SMahesh Ravishankar     return success();
62cf6a7c19SMahesh Ravishankar   }
63cf6a7c19SMahesh Ravishankar 
64cf6a7c19SMahesh Ravishankar private:
65cf6a7c19SMahesh Ravishankar   linalg::LinalgTransformationFilter filter;
66cf6a7c19SMahesh Ravishankar };
67cf6a7c19SMahesh Ravishankar 
68cf6a7c19SMahesh Ravishankar struct TestTilingInterfacePass
69cf6a7c19SMahesh Ravishankar     : public PassWrapper<TestTilingInterfacePass, OperationPass<func::FuncOp>> {
70cf6a7c19SMahesh Ravishankar   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTilingInterfacePass)
71cf6a7c19SMahesh Ravishankar 
72cf6a7c19SMahesh Ravishankar   TestTilingInterfacePass() = default;
73cf6a7c19SMahesh Ravishankar   TestTilingInterfacePass(const TestTilingInterfacePass &pass)
74cf6a7c19SMahesh Ravishankar       : PassWrapper(pass) {}
75cf6a7c19SMahesh Ravishankar   void getDependentDialects(DialectRegistry &registry) const override {
76cf6a7c19SMahesh Ravishankar     registry.insert<AffineDialect, memref::MemRefDialect, scf::SCFDialect,
77cf6a7c19SMahesh Ravishankar                     tensor::TensorDialect>();
78cf6a7c19SMahesh Ravishankar     linalg::registerTilingInterfaceExternalModels(registry);
79cf6a7c19SMahesh Ravishankar   }
80cf6a7c19SMahesh Ravishankar   StringRef getArgument() const final { return "test-tiling-interface"; }
81cf6a7c19SMahesh Ravishankar   StringRef getDescription() const final {
82cf6a7c19SMahesh Ravishankar     return "Test tiling using TilingInterface";
83cf6a7c19SMahesh Ravishankar   }
84cf6a7c19SMahesh Ravishankar 
85cf6a7c19SMahesh Ravishankar   void runOnOperation() override;
86cf6a7c19SMahesh Ravishankar };
87cf6a7c19SMahesh Ravishankar } // namespace
88cf6a7c19SMahesh Ravishankar 
89cf6a7c19SMahesh Ravishankar static void addTestPatterns(MLIRContext *context, RewritePatternSet &patterns) {
90cf6a7c19SMahesh Ravishankar   auto addPatternForTiling = [&](ArrayRef<int64_t> tileSizes,
91cf6a7c19SMahesh Ravishankar                                  StringRef filterName) {
92cf6a7c19SMahesh Ravishankar     scf::SCFTilingOptions tilingOptions;
93cf6a7c19SMahesh Ravishankar     tilingOptions.setTileSizes(tileSizes);
94cf6a7c19SMahesh Ravishankar     linalg::LinalgTransformationFilter filter(
95cf6a7c19SMahesh Ravishankar         StringAttr::get(context, filterName),
96cf6a7c19SMahesh Ravishankar         StringAttr::get(context, "tiled"));
97cf6a7c19SMahesh Ravishankar     patterns.add<TestTileUsingSCFForOpWithFilter>(context, tilingOptions,
98cf6a7c19SMahesh Ravishankar                                                   filter);
99cf6a7c19SMahesh Ravishankar   };
100cf6a7c19SMahesh Ravishankar   // 1. Tiling M and N dims of `linalg.matmul` on tensors.
101cf6a7c19SMahesh Ravishankar   addPatternForTiling({10, 20}, "simple_gemm");
102cf6a7c19SMahesh Ravishankar   // 2. Tiling M, N and K of `linalg.matmul` on buffers.
103cf6a7c19SMahesh Ravishankar   addPatternForTiling({10, 20, 30}, "simple_gemm_memref");
104cf6a7c19SMahesh Ravishankar   // 3. Tiling 3D parallel generic op which implements a transpose
105cf6a7c19SMahesh Ravishankar   addPatternForTiling({10, 0, 20}, "parallel_generic_transpose");
106cf6a7c19SMahesh Ravishankar   // 4. Tiling 2D conv op.
107cf6a7c19SMahesh Ravishankar   addPatternForTiling({0, 0, 0, 0, 10, 20, 30}, "simple_conv");
108cf6a7c19SMahesh Ravishankar }
109cf6a7c19SMahesh Ravishankar 
110cf6a7c19SMahesh Ravishankar void TestTilingInterfacePass::runOnOperation() {
111cf6a7c19SMahesh Ravishankar   MLIRContext *context = &getContext();
112cf6a7c19SMahesh Ravishankar 
113cf6a7c19SMahesh Ravishankar   RewritePatternSet tilingPatterns(context);
114cf6a7c19SMahesh Ravishankar   addTestPatterns(context, tilingPatterns);
115cf6a7c19SMahesh Ravishankar   if (failed(applyPatternsAndFoldGreedily(getOperation(),
116cf6a7c19SMahesh Ravishankar                                           std::move(tilingPatterns))))
117cf6a7c19SMahesh Ravishankar     return signalPassFailure();
118cf6a7c19SMahesh Ravishankar }
119cf6a7c19SMahesh Ravishankar 
120cf6a7c19SMahesh Ravishankar namespace mlir {
121cf6a7c19SMahesh Ravishankar namespace test {
122cf6a7c19SMahesh Ravishankar void registerTestTilingInterface() {
123cf6a7c19SMahesh Ravishankar   PassRegistration<TestTilingInterfacePass>();
124cf6a7c19SMahesh Ravishankar }
125cf6a7c19SMahesh Ravishankar } // namespace test
126cf6a7c19SMahesh Ravishankar } // namespace mlir
127