1 //===- TestFunctionLike.cpp - Pass to test helpers on FunctionLike --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinOps.h" 10 #include "mlir/Pass/Pass.h" 11 12 using namespace mlir; 13 14 namespace { 15 /// This is a test pass for verifying FuncOp's insertArgument method. 16 struct TestFuncInsertArg 17 : public PassWrapper<TestFuncInsertArg, OperationPass<ModuleOp>> { 18 StringRef getArgument() const final { return "test-func-insert-arg"; } 19 StringRef getDescription() const final { return "Test inserting func args."; } 20 void runOnOperation() override { 21 auto module = getOperation(); 22 23 for (FuncOp func : module.getOps<FuncOp>()) { 24 auto inserts = func->getAttrOfType<ArrayAttr>("test.insert_args"); 25 if (!inserts || inserts.empty()) 26 continue; 27 SmallVector<unsigned, 4> indicesToInsert; 28 SmallVector<Type, 4> typesToInsert; 29 SmallVector<DictionaryAttr, 4> attrsToInsert; 30 SmallVector<Optional<Location>, 4> locsToInsert; 31 for (auto insert : inserts.getAsRange<ArrayAttr>()) { 32 indicesToInsert.push_back( 33 insert[0].cast<IntegerAttr>().getValue().getZExtValue()); 34 typesToInsert.push_back(insert[1].cast<TypeAttr>().getValue()); 35 attrsToInsert.push_back(insert.size() > 2 36 ? insert[2].cast<DictionaryAttr>() 37 : DictionaryAttr::get(&getContext())); 38 locsToInsert.push_back( 39 insert.size() > 3 40 ? Optional<Location>(insert[3].cast<LocationAttr>()) 41 : Optional<Location>{}); 42 } 43 func->removeAttr("test.insert_args"); 44 func.insertArguments(indicesToInsert, typesToInsert, attrsToInsert, 45 locsToInsert); 46 } 47 } 48 }; 49 50 /// This is a test pass for verifying FuncOp's insertResult method. 51 struct TestFuncInsertResult 52 : public PassWrapper<TestFuncInsertResult, OperationPass<ModuleOp>> { 53 StringRef getArgument() const final { return "test-func-insert-result"; } 54 StringRef getDescription() const final { 55 return "Test inserting func results."; 56 } 57 void runOnOperation() override { 58 auto module = getOperation(); 59 60 for (FuncOp func : module.getOps<FuncOp>()) { 61 auto inserts = func->getAttrOfType<ArrayAttr>("test.insert_results"); 62 if (!inserts || inserts.empty()) 63 continue; 64 SmallVector<unsigned, 4> indicesToInsert; 65 SmallVector<Type, 4> typesToInsert; 66 SmallVector<DictionaryAttr, 4> attrsToInsert; 67 for (auto insert : inserts.getAsRange<ArrayAttr>()) { 68 indicesToInsert.push_back( 69 insert[0].cast<IntegerAttr>().getValue().getZExtValue()); 70 typesToInsert.push_back(insert[1].cast<TypeAttr>().getValue()); 71 attrsToInsert.push_back(insert.size() > 2 72 ? insert[2].cast<DictionaryAttr>() 73 : DictionaryAttr::get(&getContext())); 74 } 75 func->removeAttr("test.insert_results"); 76 func.insertResults(indicesToInsert, typesToInsert, attrsToInsert); 77 } 78 } 79 }; 80 81 /// This is a test pass for verifying FuncOp's eraseArgument method. 82 struct TestFuncEraseArg 83 : public PassWrapper<TestFuncEraseArg, OperationPass<ModuleOp>> { 84 StringRef getArgument() const final { return "test-func-erase-arg"; } 85 StringRef getDescription() const final { return "Test erasing func args."; } 86 void runOnOperation() override { 87 auto module = getOperation(); 88 89 for (FuncOp func : module.getOps<FuncOp>()) { 90 SmallVector<unsigned, 4> indicesToErase; 91 for (auto argIndex : llvm::seq<int>(0, func.getNumArguments())) { 92 if (func.getArgAttr(argIndex, "test.erase_this_arg")) { 93 // Push back twice to test that duplicate arg indices are handled 94 // correctly. 95 indicesToErase.push_back(argIndex); 96 indicesToErase.push_back(argIndex); 97 } 98 } 99 // Reverse the order to test that unsorted index lists are handled 100 // correctly. 101 std::reverse(indicesToErase.begin(), indicesToErase.end()); 102 func.eraseArguments(indicesToErase); 103 } 104 } 105 }; 106 107 /// This is a test pass for verifying FuncOp's eraseResult method. 108 struct TestFuncEraseResult 109 : public PassWrapper<TestFuncEraseResult, OperationPass<ModuleOp>> { 110 StringRef getArgument() const final { return "test-func-erase-result"; } 111 StringRef getDescription() const final { 112 return "Test erasing func results."; 113 } 114 void runOnOperation() override { 115 auto module = getOperation(); 116 117 for (FuncOp func : module.getOps<FuncOp>()) { 118 SmallVector<unsigned, 4> indicesToErase; 119 for (auto resultIndex : llvm::seq<int>(0, func.getNumResults())) { 120 if (func.getResultAttr(resultIndex, "test.erase_this_result")) { 121 // Push back twice to test that duplicate indices are handled 122 // correctly. 123 indicesToErase.push_back(resultIndex); 124 indicesToErase.push_back(resultIndex); 125 } 126 } 127 // Reverse the order to test that unsorted index lists are handled 128 // correctly. 129 std::reverse(indicesToErase.begin(), indicesToErase.end()); 130 func.eraseResults(indicesToErase); 131 } 132 } 133 }; 134 135 /// This is a test pass for verifying FuncOp's setType method. 136 struct TestFuncSetType 137 : public PassWrapper<TestFuncSetType, OperationPass<ModuleOp>> { 138 StringRef getArgument() const final { return "test-func-set-type"; } 139 StringRef getDescription() const final { return "Test FuncOp::setType."; } 140 void runOnOperation() override { 141 auto module = getOperation(); 142 SymbolTable symbolTable(module); 143 144 for (FuncOp func : module.getOps<FuncOp>()) { 145 auto sym = func->getAttrOfType<FlatSymbolRefAttr>("test.set_type_from"); 146 if (!sym) 147 continue; 148 func.setType(symbolTable.lookup<FuncOp>(sym.getValue()).getType()); 149 } 150 } 151 }; 152 } // end anonymous namespace 153 154 namespace mlir { 155 void registerTestFunc() { 156 PassRegistration<TestFuncInsertArg>(); 157 158 PassRegistration<TestFuncInsertResult>(); 159 160 PassRegistration<TestFuncEraseArg>(); 161 162 PassRegistration<TestFuncEraseResult>(); 163 164 PassRegistration<TestFuncSetType>(); 165 } 166 } // namespace mlir 167