1 //===- TestFunc.cpp - Pass to test helpers on function utilities ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinOps.h" 10 #include "mlir/IR/FunctionInterfaces.h" 11 #include "mlir/Pass/Pass.h" 12 13 using namespace mlir; 14 15 namespace { 16 /// This is a test pass for verifying FunctionOpInterface's insertArgument 17 /// method. 18 struct TestFuncInsertArg 19 : public PassWrapper<TestFuncInsertArg, OperationPass<ModuleOp>> { 20 StringRef getArgument() const final { return "test-func-insert-arg"; } 21 StringRef getDescription() const final { return "Test inserting func args."; } 22 void runOnOperation() override { 23 auto module = getOperation(); 24 25 UnknownLoc unknownLoc = UnknownLoc::get(module.getContext()); 26 for (auto func : module.getOps<FunctionOpInterface>()) { 27 auto inserts = func->getAttrOfType<ArrayAttr>("test.insert_args"); 28 if (!inserts || inserts.empty()) 29 continue; 30 SmallVector<unsigned, 4> indicesToInsert; 31 SmallVector<Type, 4> typesToInsert; 32 SmallVector<DictionaryAttr, 4> attrsToInsert; 33 SmallVector<Location, 4> locsToInsert; 34 for (auto insert : inserts.getAsRange<ArrayAttr>()) { 35 indicesToInsert.push_back( 36 insert[0].cast<IntegerAttr>().getValue().getZExtValue()); 37 typesToInsert.push_back(insert[1].cast<TypeAttr>().getValue()); 38 attrsToInsert.push_back(insert.size() > 2 39 ? insert[2].cast<DictionaryAttr>() 40 : DictionaryAttr::get(&getContext())); 41 locsToInsert.push_back(insert.size() > 3 42 ? Location(insert[3].cast<LocationAttr>()) 43 : unknownLoc); 44 } 45 func->removeAttr("test.insert_args"); 46 func.insertArguments(indicesToInsert, typesToInsert, attrsToInsert, 47 locsToInsert); 48 } 49 } 50 }; 51 52 /// This is a test pass for verifying FunctionOpInterface's insertResult method. 53 struct TestFuncInsertResult 54 : public PassWrapper<TestFuncInsertResult, OperationPass<ModuleOp>> { 55 StringRef getArgument() const final { return "test-func-insert-result"; } 56 StringRef getDescription() const final { 57 return "Test inserting func results."; 58 } 59 void runOnOperation() override { 60 auto module = getOperation(); 61 62 for (auto func : module.getOps<FunctionOpInterface>()) { 63 auto inserts = func->getAttrOfType<ArrayAttr>("test.insert_results"); 64 if (!inserts || inserts.empty()) 65 continue; 66 SmallVector<unsigned, 4> indicesToInsert; 67 SmallVector<Type, 4> typesToInsert; 68 SmallVector<DictionaryAttr, 4> attrsToInsert; 69 for (auto insert : inserts.getAsRange<ArrayAttr>()) { 70 indicesToInsert.push_back( 71 insert[0].cast<IntegerAttr>().getValue().getZExtValue()); 72 typesToInsert.push_back(insert[1].cast<TypeAttr>().getValue()); 73 attrsToInsert.push_back(insert.size() > 2 74 ? insert[2].cast<DictionaryAttr>() 75 : DictionaryAttr::get(&getContext())); 76 } 77 func->removeAttr("test.insert_results"); 78 func.insertResults(indicesToInsert, typesToInsert, attrsToInsert); 79 } 80 } 81 }; 82 83 /// This is a test pass for verifying FunctionOpInterface's eraseArgument 84 /// method. 85 struct TestFuncEraseArg 86 : public PassWrapper<TestFuncEraseArg, OperationPass<ModuleOp>> { 87 StringRef getArgument() const final { return "test-func-erase-arg"; } 88 StringRef getDescription() const final { return "Test erasing func args."; } 89 void runOnOperation() override { 90 auto module = getOperation(); 91 92 for (auto func : module.getOps<FunctionOpInterface>()) { 93 BitVector indicesToErase(func.getNumArguments()); 94 for (auto argIndex : llvm::seq<int>(0, func.getNumArguments())) 95 if (func.getArgAttr(argIndex, "test.erase_this_arg")) 96 indicesToErase.set(argIndex); 97 func.eraseArguments(indicesToErase); 98 } 99 } 100 }; 101 102 /// This is a test pass for verifying FunctionOpInterface's eraseResult method. 103 struct TestFuncEraseResult 104 : public PassWrapper<TestFuncEraseResult, OperationPass<ModuleOp>> { 105 StringRef getArgument() const final { return "test-func-erase-result"; } 106 StringRef getDescription() const final { 107 return "Test erasing func results."; 108 } 109 void runOnOperation() override { 110 auto module = getOperation(); 111 112 for (auto func : module.getOps<FunctionOpInterface>()) { 113 BitVector indicesToErase(func.getNumResults()); 114 for (auto resultIndex : llvm::seq<int>(0, func.getNumResults())) 115 if (func.getResultAttr(resultIndex, "test.erase_this_result")) 116 indicesToErase.set(resultIndex); 117 func.eraseResults(indicesToErase); 118 } 119 } 120 }; 121 122 /// This is a test pass for verifying FunctionOpInterface's setType method. 123 struct TestFuncSetType 124 : public PassWrapper<TestFuncSetType, OperationPass<ModuleOp>> { 125 StringRef getArgument() const final { return "test-func-set-type"; } 126 StringRef getDescription() const final { 127 return "Test FunctionOpInterface::setType."; 128 } 129 void runOnOperation() override { 130 auto module = getOperation(); 131 SymbolTable symbolTable(module); 132 133 for (auto func : module.getOps<FunctionOpInterface>()) { 134 auto sym = func->getAttrOfType<FlatSymbolRefAttr>("test.set_type_from"); 135 if (!sym) 136 continue; 137 func.setType( 138 symbolTable.lookup<FunctionOpInterface>(sym.getValue()).getType()); 139 } 140 } 141 }; 142 } // namespace 143 144 namespace mlir { 145 void registerTestFunc() { 146 PassRegistration<TestFuncInsertArg>(); 147 148 PassRegistration<TestFuncInsertResult>(); 149 150 PassRegistration<TestFuncEraseArg>(); 151 152 PassRegistration<TestFuncEraseResult>(); 153 154 PassRegistration<TestFuncSetType>(); 155 } 156 } // namespace mlir 157