1 //===- TestLinalgHoisting.cpp - Test Linalg hoisting functions ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg hoisting functions. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Linalg/IR/Linalg.h" 15 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 16 #include "mlir/Pass/Pass.h" 17 18 using namespace mlir; 19 using namespace mlir::linalg; 20 21 namespace { 22 struct TestLinalgHoisting 23 : public PassWrapper<TestLinalgHoisting, FunctionPass> { 24 TestLinalgHoisting() = default; 25 TestLinalgHoisting(const TestLinalgHoisting &pass) : PassWrapper(pass) {} 26 void getDependentDialects(DialectRegistry ®istry) const override { 27 registry.insert<AffineDialect>(); 28 } 29 StringRef getArgument() const final { return "test-linalg-hoisting"; } 30 StringRef getDescription() const final { 31 return "Test Linalg hoisting functions."; 32 } 33 34 void runOnFunction() override; 35 36 Option<bool> testHoistRedundantTransfers{ 37 *this, "test-hoist-redundant-transfers", 38 llvm::cl::desc("Test hoisting transfer_read/transfer_write pairs"), 39 llvm::cl::init(false)}; 40 }; 41 } // namespace 42 43 void TestLinalgHoisting::runOnFunction() { 44 if (testHoistRedundantTransfers) { 45 hoistRedundantVectorTransfers(getFunction()); 46 hoistRedundantVectorTransfersOnTensor(getFunction()); 47 return; 48 } 49 } 50 51 namespace mlir { 52 namespace test { 53 void registerTestLinalgHoisting() { PassRegistration<TestLinalgHoisting>(); } 54 } // namespace test 55 } // namespace mlir 56