1 //===- TestLinalgHoisting.cpp - Test Linalg hoisting functions ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg hoisting functions. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Func/IR/FuncOps.h" 15 #include "mlir/Dialect/Linalg/IR/Linalg.h" 16 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 17 #include "mlir/Pass/Pass.h" 18 19 using namespace mlir; 20 using namespace mlir::linalg; 21 22 namespace { 23 struct TestLinalgHoisting 24 : public PassWrapper<TestLinalgHoisting, OperationPass<func::FuncOp>> { 25 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgHoisting) 26 27 TestLinalgHoisting() = default; TestLinalgHoisting__anon933c194f0111::TestLinalgHoisting28 TestLinalgHoisting(const TestLinalgHoisting &pass) : PassWrapper(pass) {} getDependentDialects__anon933c194f0111::TestLinalgHoisting29 void getDependentDialects(DialectRegistry ®istry) const override { 30 registry.insert<AffineDialect>(); 31 } getArgument__anon933c194f0111::TestLinalgHoisting32 StringRef getArgument() const final { return "test-linalg-hoisting"; } getDescription__anon933c194f0111::TestLinalgHoisting33 StringRef getDescription() const final { 34 return "Test Linalg hoisting functions."; 35 } 36 37 void runOnOperation() override; 38 39 Option<bool> testHoistRedundantTransfers{ 40 *this, "test-hoist-redundant-transfers", 41 llvm::cl::desc("Test hoisting transfer_read/transfer_write pairs"), 42 llvm::cl::init(false)}; 43 }; 44 } // namespace 45 runOnOperation()46void TestLinalgHoisting::runOnOperation() { 47 if (testHoistRedundantTransfers) { 48 hoistRedundantVectorTransfers(getOperation()); 49 hoistRedundantVectorTransfersOnTensor(getOperation()); 50 return; 51 } 52 } 53 54 namespace mlir { 55 namespace test { registerTestLinalgHoisting()56void registerTestLinalgHoisting() { PassRegistration<TestLinalgHoisting>(); } 57 } // namespace test 58 } // namespace mlir 59