1 //===- TestFunc.cpp - Pass to test helpers on function utilities ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinOps.h" 10 #include "mlir/Pass/Pass.h" 11 12 using namespace mlir; 13 14 namespace { 15 /// This is a test pass for verifying FuncOp's insertArgument method. 16 struct TestFuncInsertArg 17 : public PassWrapper<TestFuncInsertArg, OperationPass<ModuleOp>> { 18 StringRef getArgument() const final { return "test-func-insert-arg"; } 19 StringRef getDescription() const final { return "Test inserting func args."; } 20 void runOnOperation() override { 21 auto module = getOperation(); 22 23 UnknownLoc unknownLoc = UnknownLoc::get(module.getContext()); 24 for (FuncOp func : module.getOps<FuncOp>()) { 25 auto inserts = func->getAttrOfType<ArrayAttr>("test.insert_args"); 26 if (!inserts || inserts.empty()) 27 continue; 28 SmallVector<unsigned, 4> indicesToInsert; 29 SmallVector<Type, 4> typesToInsert; 30 SmallVector<DictionaryAttr, 4> attrsToInsert; 31 SmallVector<Location, 4> locsToInsert; 32 for (auto insert : inserts.getAsRange<ArrayAttr>()) { 33 indicesToInsert.push_back( 34 insert[0].cast<IntegerAttr>().getValue().getZExtValue()); 35 typesToInsert.push_back(insert[1].cast<TypeAttr>().getValue()); 36 attrsToInsert.push_back(insert.size() > 2 37 ? insert[2].cast<DictionaryAttr>() 38 : DictionaryAttr::get(&getContext())); 39 locsToInsert.push_back(insert.size() > 3 40 ? Location(insert[3].cast<LocationAttr>()) 41 : unknownLoc); 42 } 43 func->removeAttr("test.insert_args"); 44 func.insertArguments(indicesToInsert, typesToInsert, attrsToInsert, 45 locsToInsert); 46 } 47 } 48 }; 49 50 /// This is a test pass for verifying FuncOp's insertResult method. 51 struct TestFuncInsertResult 52 : public PassWrapper<TestFuncInsertResult, OperationPass<ModuleOp>> { 53 StringRef getArgument() const final { return "test-func-insert-result"; } 54 StringRef getDescription() const final { 55 return "Test inserting func results."; 56 } 57 void runOnOperation() override { 58 auto module = getOperation(); 59 60 for (FuncOp func : module.getOps<FuncOp>()) { 61 auto inserts = func->getAttrOfType<ArrayAttr>("test.insert_results"); 62 if (!inserts || inserts.empty()) 63 continue; 64 SmallVector<unsigned, 4> indicesToInsert; 65 SmallVector<Type, 4> typesToInsert; 66 SmallVector<DictionaryAttr, 4> attrsToInsert; 67 for (auto insert : inserts.getAsRange<ArrayAttr>()) { 68 indicesToInsert.push_back( 69 insert[0].cast<IntegerAttr>().getValue().getZExtValue()); 70 typesToInsert.push_back(insert[1].cast<TypeAttr>().getValue()); 71 attrsToInsert.push_back(insert.size() > 2 72 ? insert[2].cast<DictionaryAttr>() 73 : DictionaryAttr::get(&getContext())); 74 } 75 func->removeAttr("test.insert_results"); 76 func.insertResults(indicesToInsert, typesToInsert, attrsToInsert); 77 } 78 } 79 }; 80 81 /// This is a test pass for verifying FuncOp's eraseArgument method. 82 struct TestFuncEraseArg 83 : public PassWrapper<TestFuncEraseArg, OperationPass<ModuleOp>> { 84 StringRef getArgument() const final { return "test-func-erase-arg"; } 85 StringRef getDescription() const final { return "Test erasing func args."; } 86 void runOnOperation() override { 87 auto module = getOperation(); 88 89 for (FuncOp func : module.getOps<FuncOp>()) { 90 BitVector indicesToErase(func.getNumArguments()); 91 for (auto argIndex : llvm::seq<int>(0, func.getNumArguments())) 92 if (func.getArgAttr(argIndex, "test.erase_this_arg")) 93 indicesToErase.set(argIndex); 94 func.eraseArguments(indicesToErase); 95 } 96 } 97 }; 98 99 /// This is a test pass for verifying FuncOp's eraseResult method. 100 struct TestFuncEraseResult 101 : public PassWrapper<TestFuncEraseResult, OperationPass<ModuleOp>> { 102 StringRef getArgument() const final { return "test-func-erase-result"; } 103 StringRef getDescription() const final { 104 return "Test erasing func results."; 105 } 106 void runOnOperation() override { 107 auto module = getOperation(); 108 109 for (FuncOp func : module.getOps<FuncOp>()) { 110 BitVector indicesToErase(func.getNumResults()); 111 for (auto resultIndex : llvm::seq<int>(0, func.getNumResults())) 112 if (func.getResultAttr(resultIndex, "test.erase_this_result")) 113 indicesToErase.set(resultIndex); 114 func.eraseResults(indicesToErase); 115 } 116 } 117 }; 118 119 /// This is a test pass for verifying FuncOp's setType method. 120 struct TestFuncSetType 121 : public PassWrapper<TestFuncSetType, OperationPass<ModuleOp>> { 122 StringRef getArgument() const final { return "test-func-set-type"; } 123 StringRef getDescription() const final { return "Test FuncOp::setType."; } 124 void runOnOperation() override { 125 auto module = getOperation(); 126 SymbolTable symbolTable(module); 127 128 for (FuncOp func : module.getOps<FuncOp>()) { 129 auto sym = func->getAttrOfType<FlatSymbolRefAttr>("test.set_type_from"); 130 if (!sym) 131 continue; 132 func.setType(symbolTable.lookup<FuncOp>(sym.getValue()).getType()); 133 } 134 } 135 }; 136 } // namespace 137 138 namespace mlir { 139 void registerTestFunc() { 140 PassRegistration<TestFuncInsertArg>(); 141 142 PassRegistration<TestFuncInsertResult>(); 143 144 PassRegistration<TestFuncEraseArg>(); 145 146 PassRegistration<TestFuncEraseResult>(); 147 148 PassRegistration<TestFuncSetType>(); 149 } 150 } // namespace mlir 151