1 //===- TestFunc.cpp - Pass to test helpers on function utilities ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinOps.h" 10 #include "mlir/IR/FunctionInterfaces.h" 11 #include "mlir/Pass/Pass.h" 12 13 using namespace mlir; 14 15 namespace { 16 /// This is a test pass for verifying FunctionOpInterface's insertArgument 17 /// method. 18 struct TestFuncInsertArg 19 : public PassWrapper<TestFuncInsertArg, OperationPass<ModuleOp>> { 20 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncInsertArg) 21 22 StringRef getArgument() const final { return "test-func-insert-arg"; } 23 StringRef getDescription() const final { return "Test inserting func args."; } 24 void runOnOperation() override { 25 auto module = getOperation(); 26 27 UnknownLoc unknownLoc = UnknownLoc::get(module.getContext()); 28 for (auto func : module.getOps<FunctionOpInterface>()) { 29 auto inserts = func->getAttrOfType<ArrayAttr>("test.insert_args"); 30 if (!inserts || inserts.empty()) 31 continue; 32 SmallVector<unsigned, 4> indicesToInsert; 33 SmallVector<Type, 4> typesToInsert; 34 SmallVector<DictionaryAttr, 4> attrsToInsert; 35 SmallVector<Location, 4> locsToInsert; 36 for (auto insert : inserts.getAsRange<ArrayAttr>()) { 37 indicesToInsert.push_back( 38 insert[0].cast<IntegerAttr>().getValue().getZExtValue()); 39 typesToInsert.push_back(insert[1].cast<TypeAttr>().getValue()); 40 attrsToInsert.push_back(insert.size() > 2 41 ? insert[2].cast<DictionaryAttr>() 42 : DictionaryAttr::get(&getContext())); 43 locsToInsert.push_back(insert.size() > 3 44 ? Location(insert[3].cast<LocationAttr>()) 45 : unknownLoc); 46 } 47 func->removeAttr("test.insert_args"); 48 func.insertArguments(indicesToInsert, typesToInsert, attrsToInsert, 49 locsToInsert); 50 } 51 } 52 }; 53 54 /// This is a test pass for verifying FunctionOpInterface's insertResult method. 55 struct TestFuncInsertResult 56 : public PassWrapper<TestFuncInsertResult, OperationPass<ModuleOp>> { 57 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncInsertResult) 58 59 StringRef getArgument() const final { return "test-func-insert-result"; } 60 StringRef getDescription() const final { 61 return "Test inserting func results."; 62 } 63 void runOnOperation() override { 64 auto module = getOperation(); 65 66 for (auto func : module.getOps<FunctionOpInterface>()) { 67 auto inserts = func->getAttrOfType<ArrayAttr>("test.insert_results"); 68 if (!inserts || inserts.empty()) 69 continue; 70 SmallVector<unsigned, 4> indicesToInsert; 71 SmallVector<Type, 4> typesToInsert; 72 SmallVector<DictionaryAttr, 4> attrsToInsert; 73 for (auto insert : inserts.getAsRange<ArrayAttr>()) { 74 indicesToInsert.push_back( 75 insert[0].cast<IntegerAttr>().getValue().getZExtValue()); 76 typesToInsert.push_back(insert[1].cast<TypeAttr>().getValue()); 77 attrsToInsert.push_back(insert.size() > 2 78 ? insert[2].cast<DictionaryAttr>() 79 : DictionaryAttr::get(&getContext())); 80 } 81 func->removeAttr("test.insert_results"); 82 func.insertResults(indicesToInsert, typesToInsert, attrsToInsert); 83 } 84 } 85 }; 86 87 /// This is a test pass for verifying FunctionOpInterface's eraseArgument 88 /// method. 89 struct TestFuncEraseArg 90 : public PassWrapper<TestFuncEraseArg, OperationPass<ModuleOp>> { 91 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncEraseArg) 92 93 StringRef getArgument() const final { return "test-func-erase-arg"; } 94 StringRef getDescription() const final { return "Test erasing func args."; } 95 void runOnOperation() override { 96 auto module = getOperation(); 97 98 for (auto func : module.getOps<FunctionOpInterface>()) { 99 BitVector indicesToErase(func.getNumArguments()); 100 for (auto argIndex : llvm::seq<int>(0, func.getNumArguments())) 101 if (func.getArgAttr(argIndex, "test.erase_this_arg")) 102 indicesToErase.set(argIndex); 103 func.eraseArguments(indicesToErase); 104 } 105 } 106 }; 107 108 /// This is a test pass for verifying FunctionOpInterface's eraseResult method. 109 struct TestFuncEraseResult 110 : public PassWrapper<TestFuncEraseResult, OperationPass<ModuleOp>> { 111 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncEraseResult) 112 113 StringRef getArgument() const final { return "test-func-erase-result"; } 114 StringRef getDescription() const final { 115 return "Test erasing func results."; 116 } 117 void runOnOperation() override { 118 auto module = getOperation(); 119 120 for (auto func : module.getOps<FunctionOpInterface>()) { 121 BitVector indicesToErase(func.getNumResults()); 122 for (auto resultIndex : llvm::seq<int>(0, func.getNumResults())) 123 if (func.getResultAttr(resultIndex, "test.erase_this_result")) 124 indicesToErase.set(resultIndex); 125 func.eraseResults(indicesToErase); 126 } 127 } 128 }; 129 130 /// This is a test pass for verifying FunctionOpInterface's setType method. 131 struct TestFuncSetType 132 : public PassWrapper<TestFuncSetType, OperationPass<ModuleOp>> { 133 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncSetType) 134 135 StringRef getArgument() const final { return "test-func-set-type"; } 136 StringRef getDescription() const final { 137 return "Test FunctionOpInterface::setType."; 138 } 139 void runOnOperation() override { 140 auto module = getOperation(); 141 SymbolTable symbolTable(module); 142 143 for (auto func : module.getOps<FunctionOpInterface>()) { 144 auto sym = func->getAttrOfType<FlatSymbolRefAttr>("test.set_type_from"); 145 if (!sym) 146 continue; 147 func.setType(symbolTable.lookup<FunctionOpInterface>(sym.getValue()) 148 .getFunctionType()); 149 } 150 } 151 }; 152 } // namespace 153 154 namespace mlir { 155 void registerTestFunc() { 156 PassRegistration<TestFuncInsertArg>(); 157 158 PassRegistration<TestFuncInsertResult>(); 159 160 PassRegistration<TestFuncEraseArg>(); 161 162 PassRegistration<TestFuncEraseResult>(); 163 164 PassRegistration<TestFuncSetType>(); 165 } 166 } // namespace mlir 167