1 //===- TestLinalgHoisting.cpp - Test Linalg hoisting functions ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg hoisting functions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/Linalg/IR/Linalg.h"
15 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
16 #include "mlir/Pass/Pass.h"
17 
18 using namespace mlir;
19 using namespace mlir::linalg;
20 
21 namespace {
22 struct TestLinalgHoisting
23     : public PassWrapper<TestLinalgHoisting, FunctionPass> {
24   TestLinalgHoisting() = default;
25   TestLinalgHoisting(const TestLinalgHoisting &pass) : PassWrapper(pass) {}
26   void getDependentDialects(DialectRegistry &registry) const override {
27     registry.insert<AffineDialect>();
28   }
29   StringRef getArgument() const final { return "test-linalg-hoisting"; }
30   StringRef getDescription() const final {
31     return "Test Linalg hoisting functions.";
32   }
33 
34   void runOnFunction() override;
35 
36   Option<bool> testHoistRedundantTransfers{
37       *this, "test-hoist-redundant-transfers",
38       llvm::cl::desc("Test hoisting transfer_read/transfer_write pairs"),
39       llvm::cl::init(false)};
40 };
41 } // namespace
42 
43 void TestLinalgHoisting::runOnFunction() {
44   if (testHoistRedundantTransfers) {
45     hoistRedundantVectorTransfers(getFunction());
46     hoistRedundantVectorTransfersOnTensor(getFunction());
47     return;
48   }
49 }
50 
51 namespace mlir {
52 namespace test {
53 void registerTestLinalgHoisting() { PassRegistration<TestLinalgHoisting>(); }
54 } // namespace test
55 } // namespace mlir
56