1 //===- TestFunc.cpp - Pass to test helpers on function utilities ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinOps.h"
10 #include "mlir/IR/FunctionInterfaces.h"
11 #include "mlir/Pass/Pass.h"
12 
13 using namespace mlir;
14 
15 namespace {
16 /// This is a test pass for verifying FunctionOpInterface's insertArgument
17 /// method.
18 struct TestFuncInsertArg
19     : public PassWrapper<TestFuncInsertArg, OperationPass<ModuleOp>> {
20   StringRef getArgument() const final { return "test-func-insert-arg"; }
21   StringRef getDescription() const final { return "Test inserting func args."; }
22   void runOnOperation() override {
23     auto module = getOperation();
24 
25     UnknownLoc unknownLoc = UnknownLoc::get(module.getContext());
26     for (auto func : module.getOps<FunctionOpInterface>()) {
27       auto inserts = func->getAttrOfType<ArrayAttr>("test.insert_args");
28       if (!inserts || inserts.empty())
29         continue;
30       SmallVector<unsigned, 4> indicesToInsert;
31       SmallVector<Type, 4> typesToInsert;
32       SmallVector<DictionaryAttr, 4> attrsToInsert;
33       SmallVector<Location, 4> locsToInsert;
34       for (auto insert : inserts.getAsRange<ArrayAttr>()) {
35         indicesToInsert.push_back(
36             insert[0].cast<IntegerAttr>().getValue().getZExtValue());
37         typesToInsert.push_back(insert[1].cast<TypeAttr>().getValue());
38         attrsToInsert.push_back(insert.size() > 2
39                                     ? insert[2].cast<DictionaryAttr>()
40                                     : DictionaryAttr::get(&getContext()));
41         locsToInsert.push_back(insert.size() > 3
42                                    ? Location(insert[3].cast<LocationAttr>())
43                                    : unknownLoc);
44       }
45       func->removeAttr("test.insert_args");
46       func.insertArguments(indicesToInsert, typesToInsert, attrsToInsert,
47                            locsToInsert);
48     }
49   }
50 };
51 
52 /// This is a test pass for verifying FunctionOpInterface's insertResult method.
53 struct TestFuncInsertResult
54     : public PassWrapper<TestFuncInsertResult, OperationPass<ModuleOp>> {
55   StringRef getArgument() const final { return "test-func-insert-result"; }
56   StringRef getDescription() const final {
57     return "Test inserting func results.";
58   }
59   void runOnOperation() override {
60     auto module = getOperation();
61 
62     for (auto func : module.getOps<FunctionOpInterface>()) {
63       auto inserts = func->getAttrOfType<ArrayAttr>("test.insert_results");
64       if (!inserts || inserts.empty())
65         continue;
66       SmallVector<unsigned, 4> indicesToInsert;
67       SmallVector<Type, 4> typesToInsert;
68       SmallVector<DictionaryAttr, 4> attrsToInsert;
69       for (auto insert : inserts.getAsRange<ArrayAttr>()) {
70         indicesToInsert.push_back(
71             insert[0].cast<IntegerAttr>().getValue().getZExtValue());
72         typesToInsert.push_back(insert[1].cast<TypeAttr>().getValue());
73         attrsToInsert.push_back(insert.size() > 2
74                                     ? insert[2].cast<DictionaryAttr>()
75                                     : DictionaryAttr::get(&getContext()));
76       }
77       func->removeAttr("test.insert_results");
78       func.insertResults(indicesToInsert, typesToInsert, attrsToInsert);
79     }
80   }
81 };
82 
83 /// This is a test pass for verifying FunctionOpInterface's eraseArgument
84 /// method.
85 struct TestFuncEraseArg
86     : public PassWrapper<TestFuncEraseArg, OperationPass<ModuleOp>> {
87   StringRef getArgument() const final { return "test-func-erase-arg"; }
88   StringRef getDescription() const final { return "Test erasing func args."; }
89   void runOnOperation() override {
90     auto module = getOperation();
91 
92     for (auto func : module.getOps<FunctionOpInterface>()) {
93       BitVector indicesToErase(func.getNumArguments());
94       for (auto argIndex : llvm::seq<int>(0, func.getNumArguments()))
95         if (func.getArgAttr(argIndex, "test.erase_this_arg"))
96           indicesToErase.set(argIndex);
97       func.eraseArguments(indicesToErase);
98     }
99   }
100 };
101 
102 /// This is a test pass for verifying FunctionOpInterface's eraseResult method.
103 struct TestFuncEraseResult
104     : public PassWrapper<TestFuncEraseResult, OperationPass<ModuleOp>> {
105   StringRef getArgument() const final { return "test-func-erase-result"; }
106   StringRef getDescription() const final {
107     return "Test erasing func results.";
108   }
109   void runOnOperation() override {
110     auto module = getOperation();
111 
112     for (auto func : module.getOps<FunctionOpInterface>()) {
113       BitVector indicesToErase(func.getNumResults());
114       for (auto resultIndex : llvm::seq<int>(0, func.getNumResults()))
115         if (func.getResultAttr(resultIndex, "test.erase_this_result"))
116           indicesToErase.set(resultIndex);
117       func.eraseResults(indicesToErase);
118     }
119   }
120 };
121 
122 /// This is a test pass for verifying FunctionOpInterface's setType method.
123 struct TestFuncSetType
124     : public PassWrapper<TestFuncSetType, OperationPass<ModuleOp>> {
125   StringRef getArgument() const final { return "test-func-set-type"; }
126   StringRef getDescription() const final {
127     return "Test FunctionOpInterface::setType.";
128   }
129   void runOnOperation() override {
130     auto module = getOperation();
131     SymbolTable symbolTable(module);
132 
133     for (auto func : module.getOps<FunctionOpInterface>()) {
134       auto sym = func->getAttrOfType<FlatSymbolRefAttr>("test.set_type_from");
135       if (!sym)
136         continue;
137       func.setType(
138           symbolTable.lookup<FunctionOpInterface>(sym.getValue()).getType());
139     }
140   }
141 };
142 } // namespace
143 
144 namespace mlir {
145 void registerTestFunc() {
146   PassRegistration<TestFuncInsertArg>();
147 
148   PassRegistration<TestFuncInsertResult>();
149 
150   PassRegistration<TestFuncEraseArg>();
151 
152   PassRegistration<TestFuncEraseResult>();
153 
154   PassRegistration<TestFuncSetType>();
155 }
156 } // namespace mlir
157