1 //===- TestFunctionLike.cpp - Pass to test helpers on FunctionLike --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinOps.h"
10 #include "mlir/Pass/Pass.h"
11 
12 using namespace mlir;
13 
14 namespace {
15 /// This is a test pass for verifying FuncOp's eraseArgument method.
16 struct TestFuncEraseArg
17     : public PassWrapper<TestFuncEraseArg, OperationPass<ModuleOp>> {
18   StringRef getArgument() const final { return "test-func-erase-arg"; }
19   StringRef getDescription() const final { return "Test erasing func args."; }
20   void runOnOperation() override {
21     auto module = getOperation();
22 
23     for (FuncOp func : module.getOps<FuncOp>()) {
24       SmallVector<unsigned, 4> indicesToErase;
25       for (auto argIndex : llvm::seq<int>(0, func.getNumArguments())) {
26         if (func.getArgAttr(argIndex, "test.erase_this_arg")) {
27           // Push back twice to test that duplicate arg indices are handled
28           // correctly.
29           indicesToErase.push_back(argIndex);
30           indicesToErase.push_back(argIndex);
31         }
32       }
33       // Reverse the order to test that unsorted index lists are handled
34       // correctly.
35       std::reverse(indicesToErase.begin(), indicesToErase.end());
36       func.eraseArguments(indicesToErase);
37     }
38   }
39 };
40 
41 /// This is a test pass for verifying FuncOp's eraseResult method.
42 struct TestFuncEraseResult
43     : public PassWrapper<TestFuncEraseResult, OperationPass<ModuleOp>> {
44   StringRef getArgument() const final { return "test-func-erase-result"; }
45   StringRef getDescription() const final {
46     return "Test erasing func results.";
47   }
48   void runOnOperation() override {
49     auto module = getOperation();
50 
51     for (FuncOp func : module.getOps<FuncOp>()) {
52       SmallVector<unsigned, 4> indicesToErase;
53       for (auto resultIndex : llvm::seq<int>(0, func.getNumResults())) {
54         if (func.getResultAttr(resultIndex, "test.erase_this_"
55                                             "result")) {
56           // Push back twice to test
57           // that duplicate indices
58           // are handled correctly.
59           indicesToErase.push_back(resultIndex);
60           indicesToErase.push_back(resultIndex);
61         }
62       }
63       // Reverse the order to test
64       // that unsorted index lists are
65       // handled correctly.
66       std::reverse(indicesToErase.begin(), indicesToErase.end());
67       func.eraseResults(indicesToErase);
68     }
69   }
70 };
71 
72 /// This is a test pass for verifying FuncOp's setType method.
73 struct TestFuncSetType
74     : public PassWrapper<TestFuncSetType, OperationPass<ModuleOp>> {
75   StringRef getArgument() const final { return "test-func-set-type"; }
76   StringRef getDescription() const final { return "Test FuncOp::setType."; }
77   void runOnOperation() override {
78     auto module = getOperation();
79     SymbolTable symbolTable(module);
80 
81     for (FuncOp func : module.getOps<FuncOp>()) {
82       auto sym = func->getAttrOfType<FlatSymbolRefAttr>("test.set_type_from");
83       if (!sym)
84         continue;
85       func.setType(symbolTable.lookup<FuncOp>(sym.getValue()).getType());
86     }
87   }
88 };
89 } // end anonymous namespace
90 
91 namespace mlir {
92 void registerTestFunc() {
93   PassRegistration<TestFuncEraseArg>();
94 
95   PassRegistration<TestFuncEraseResult>();
96 
97   PassRegistration<TestFuncSetType>();
98 }
99 } // namespace mlir
100