1*cf6a7c19SMahesh Ravishankar //===- TestTilingInterface.cpp - Test tiling using `TilingInterface` -----===//
2*cf6a7c19SMahesh Ravishankar //
3*cf6a7c19SMahesh Ravishankar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*cf6a7c19SMahesh Ravishankar // See https://llvm.org/LICENSE.txt for license information.
5*cf6a7c19SMahesh Ravishankar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*cf6a7c19SMahesh Ravishankar //
7*cf6a7c19SMahesh Ravishankar //===----------------------------------------------------------------------===//
8*cf6a7c19SMahesh Ravishankar //
9*cf6a7c19SMahesh Ravishankar // This file implements a pass for testing tiling operations using
10*cf6a7c19SMahesh Ravishankar // `TilingInterface`.
11*cf6a7c19SMahesh Ravishankar //
12*cf6a7c19SMahesh Ravishankar //===----------------------------------------------------------------------===//
13*cf6a7c19SMahesh Ravishankar 
14*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Affine/IR/AffineOps.h"
15*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Func/IR/FuncOps.h"
16*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
17*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/MemRef/IR/MemRef.h"
19*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/SCF/SCF.h"
20*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/SCF/TileUsingInterface.h"
21*cf6a7c19SMahesh Ravishankar #include "mlir/Dialect/Tensor/IR/Tensor.h"
22*cf6a7c19SMahesh Ravishankar #include "mlir/Interfaces/TilingInterface.h"
23*cf6a7c19SMahesh Ravishankar #include "mlir/Pass/Pass.h"
24*cf6a7c19SMahesh Ravishankar #include "mlir/Pass/PassManager.h"
25*cf6a7c19SMahesh Ravishankar #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26*cf6a7c19SMahesh Ravishankar #include "llvm/ADT/TypeSwitch.h"
27*cf6a7c19SMahesh Ravishankar 
28*cf6a7c19SMahesh Ravishankar using namespace mlir;
29*cf6a7c19SMahesh Ravishankar 
30*cf6a7c19SMahesh Ravishankar namespace {
31*cf6a7c19SMahesh Ravishankar 
32*cf6a7c19SMahesh Ravishankar /// Construct a generic pattern applied to all TilingInterface ops that verify
33*cf6a7c19SMahesh Ravishankar /// `filter`.
34*cf6a7c19SMahesh Ravishankar struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp {
35*cf6a7c19SMahesh Ravishankar   TestTileUsingSCFForOpWithFilter(MLIRContext *context,
36*cf6a7c19SMahesh Ravishankar                                   scf::SCFTilingOptions options,
37*cf6a7c19SMahesh Ravishankar                                   linalg::LinalgTransformationFilter filter =
38*cf6a7c19SMahesh Ravishankar                                       linalg::LinalgTransformationFilter(),
39*cf6a7c19SMahesh Ravishankar                                   PatternBenefit benefit = 1)
40*cf6a7c19SMahesh Ravishankar       : scf::TileUsingSCFForOp(context, options, benefit), filter(filter) {}
41*cf6a7c19SMahesh Ravishankar 
42*cf6a7c19SMahesh Ravishankar   /// Construct a generic pattern applied to `opName`.
43*cf6a7c19SMahesh Ravishankar   TestTileUsingSCFForOpWithFilter(StringRef opName, MLIRContext *context,
44*cf6a7c19SMahesh Ravishankar                                   scf::SCFTilingOptions options,
45*cf6a7c19SMahesh Ravishankar                                   linalg::LinalgTransformationFilter filter =
46*cf6a7c19SMahesh Ravishankar                                       linalg::LinalgTransformationFilter(),
47*cf6a7c19SMahesh Ravishankar                                   PatternBenefit benefit = 1)
48*cf6a7c19SMahesh Ravishankar       : scf::TileUsingSCFForOp(context, options, benefit), filter(filter) {}
49*cf6a7c19SMahesh Ravishankar 
50*cf6a7c19SMahesh Ravishankar   LogicalResult matchAndRewrite(TilingInterface op,
51*cf6a7c19SMahesh Ravishankar                                 PatternRewriter &rewriter) const override {
52*cf6a7c19SMahesh Ravishankar     if (failed(filter.checkAndNotify(rewriter, op)))
53*cf6a7c19SMahesh Ravishankar       return failure();
54*cf6a7c19SMahesh Ravishankar 
55*cf6a7c19SMahesh Ravishankar     FailureOr<scf::SCFTilingResult> tilingResult =
56*cf6a7c19SMahesh Ravishankar         returningMatchAndRewrite(op, rewriter);
57*cf6a7c19SMahesh Ravishankar     if (failed(tilingResult)) {
58*cf6a7c19SMahesh Ravishankar       return failure();
59*cf6a7c19SMahesh Ravishankar     }
60*cf6a7c19SMahesh Ravishankar     filter.replaceLinalgTransformationFilter(rewriter, tilingResult->tiledOp);
61*cf6a7c19SMahesh Ravishankar     return success();
62*cf6a7c19SMahesh Ravishankar   }
63*cf6a7c19SMahesh Ravishankar 
64*cf6a7c19SMahesh Ravishankar private:
65*cf6a7c19SMahesh Ravishankar   linalg::LinalgTransformationFilter filter;
66*cf6a7c19SMahesh Ravishankar };
67*cf6a7c19SMahesh Ravishankar 
68*cf6a7c19SMahesh Ravishankar struct TestTilingInterfacePass
69*cf6a7c19SMahesh Ravishankar     : public PassWrapper<TestTilingInterfacePass, OperationPass<func::FuncOp>> {
70*cf6a7c19SMahesh Ravishankar   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTilingInterfacePass)
71*cf6a7c19SMahesh Ravishankar 
72*cf6a7c19SMahesh Ravishankar   TestTilingInterfacePass() = default;
73*cf6a7c19SMahesh Ravishankar   TestTilingInterfacePass(const TestTilingInterfacePass &pass)
74*cf6a7c19SMahesh Ravishankar       : PassWrapper(pass) {}
75*cf6a7c19SMahesh Ravishankar   void getDependentDialects(DialectRegistry &registry) const override {
76*cf6a7c19SMahesh Ravishankar     registry.insert<AffineDialect, memref::MemRefDialect, scf::SCFDialect,
77*cf6a7c19SMahesh Ravishankar                     tensor::TensorDialect>();
78*cf6a7c19SMahesh Ravishankar     linalg::registerTilingInterfaceExternalModels(registry);
79*cf6a7c19SMahesh Ravishankar   }
80*cf6a7c19SMahesh Ravishankar   StringRef getArgument() const final { return "test-tiling-interface"; }
81*cf6a7c19SMahesh Ravishankar   StringRef getDescription() const final {
82*cf6a7c19SMahesh Ravishankar     return "Test tiling using TilingInterface";
83*cf6a7c19SMahesh Ravishankar   }
84*cf6a7c19SMahesh Ravishankar 
85*cf6a7c19SMahesh Ravishankar   void runOnOperation() override;
86*cf6a7c19SMahesh Ravishankar };
87*cf6a7c19SMahesh Ravishankar } // namespace
88*cf6a7c19SMahesh Ravishankar 
89*cf6a7c19SMahesh Ravishankar static void addTestPatterns(MLIRContext *context, RewritePatternSet &patterns) {
90*cf6a7c19SMahesh Ravishankar   auto addPatternForTiling = [&](ArrayRef<int64_t> tileSizes,
91*cf6a7c19SMahesh Ravishankar                                  StringRef filterName) {
92*cf6a7c19SMahesh Ravishankar     scf::SCFTilingOptions tilingOptions;
93*cf6a7c19SMahesh Ravishankar     tilingOptions.setTileSizes(tileSizes);
94*cf6a7c19SMahesh Ravishankar     linalg::LinalgTransformationFilter filter(
95*cf6a7c19SMahesh Ravishankar         StringAttr::get(context, filterName),
96*cf6a7c19SMahesh Ravishankar         StringAttr::get(context, "tiled"));
97*cf6a7c19SMahesh Ravishankar     patterns.add<TestTileUsingSCFForOpWithFilter>(context, tilingOptions,
98*cf6a7c19SMahesh Ravishankar                                                   filter);
99*cf6a7c19SMahesh Ravishankar   };
100*cf6a7c19SMahesh Ravishankar   // 1. Tiling M and N dims of `linalg.matmul` on tensors.
101*cf6a7c19SMahesh Ravishankar   addPatternForTiling({10, 20}, "simple_gemm");
102*cf6a7c19SMahesh Ravishankar   // 2. Tiling M, N and K of `linalg.matmul` on buffers.
103*cf6a7c19SMahesh Ravishankar   addPatternForTiling({10, 20, 30}, "simple_gemm_memref");
104*cf6a7c19SMahesh Ravishankar   // 3. Tiling 3D parallel generic op which implements a transpose
105*cf6a7c19SMahesh Ravishankar   addPatternForTiling({10, 0, 20}, "parallel_generic_transpose");
106*cf6a7c19SMahesh Ravishankar   // 4. Tiling 2D conv op.
107*cf6a7c19SMahesh Ravishankar   addPatternForTiling({0, 0, 0, 0, 10, 20, 30}, "simple_conv");
108*cf6a7c19SMahesh Ravishankar }
109*cf6a7c19SMahesh Ravishankar 
110*cf6a7c19SMahesh Ravishankar void TestTilingInterfacePass::runOnOperation() {
111*cf6a7c19SMahesh Ravishankar   MLIRContext *context = &getContext();
112*cf6a7c19SMahesh Ravishankar 
113*cf6a7c19SMahesh Ravishankar   RewritePatternSet tilingPatterns(context);
114*cf6a7c19SMahesh Ravishankar   addTestPatterns(context, tilingPatterns);
115*cf6a7c19SMahesh Ravishankar   if (failed(applyPatternsAndFoldGreedily(getOperation(),
116*cf6a7c19SMahesh Ravishankar                                           std::move(tilingPatterns))))
117*cf6a7c19SMahesh Ravishankar     return signalPassFailure();
118*cf6a7c19SMahesh Ravishankar }
119*cf6a7c19SMahesh Ravishankar 
120*cf6a7c19SMahesh Ravishankar namespace mlir {
121*cf6a7c19SMahesh Ravishankar namespace test {
122*cf6a7c19SMahesh Ravishankar void registerTestTilingInterface() {
123*cf6a7c19SMahesh Ravishankar   PassRegistration<TestTilingInterfacePass>();
124*cf6a7c19SMahesh Ravishankar }
125*cf6a7c19SMahesh Ravishankar } // namespace test
126*cf6a7c19SMahesh Ravishankar } // namespace mlir
127