1 //===- TestBuiltinAttributeInterfaces.cpp ---------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestAttributes.h" 10 #include "mlir/IR/BuiltinOps.h" 11 #include "mlir/Pass/Pass.h" 12 #include "llvm/Support/FormatVariadic.h" 13 14 using namespace mlir; 15 using namespace test; 16 17 // Helper to print one scalar value, force int8_t to print as integer instead of 18 // char. 19 template <typename T> 20 static void printOneElement(InFlightDiagnostic &os, T value) { 21 os << llvm::formatv("{0}", value).str(); 22 } 23 template <> 24 void printOneElement<int8_t>(InFlightDiagnostic &os, int8_t value) { 25 os << llvm::formatv("{0}", static_cast<int64_t>(value)).str(); 26 } 27 28 namespace { 29 struct TestElementsAttrInterface 30 : public PassWrapper<TestElementsAttrInterface, OperationPass<ModuleOp>> { 31 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestElementsAttrInterface) 32 33 StringRef getArgument() const final { return "test-elements-attr-interface"; } 34 StringRef getDescription() const final { 35 return "Test ElementsAttr interface support."; 36 } 37 void runOnOperation() override { 38 getOperation().walk([&](Operation *op) { 39 for (NamedAttribute attr : op->getAttrs()) { 40 auto elementsAttr = attr.getValue().dyn_cast<ElementsAttr>(); 41 if (!elementsAttr) 42 continue; 43 if (auto concreteAttr = 44 attr.getValue().dyn_cast<DenseArrayBaseAttr>()) { 45 switch (concreteAttr.getElementType()) { 46 case DenseArrayBaseAttr::EltType::I8: 47 testElementsAttrIteration<int8_t>(op, elementsAttr, "int8_t"); 48 break; 49 case DenseArrayBaseAttr::EltType::I16: 50 testElementsAttrIteration<int16_t>(op, elementsAttr, "int16_t"); 51 break; 52 case DenseArrayBaseAttr::EltType::I32: 53 testElementsAttrIteration<int32_t>(op, elementsAttr, "int32_t"); 54 break; 55 case DenseArrayBaseAttr::EltType::I64: 56 testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t"); 57 break; 58 case DenseArrayBaseAttr::EltType::F32: 59 testElementsAttrIteration<float>(op, elementsAttr, "float"); 60 break; 61 case DenseArrayBaseAttr::EltType::F64: 62 testElementsAttrIteration<double>(op, elementsAttr, "double"); 63 break; 64 } 65 continue; 66 } 67 testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t"); 68 testElementsAttrIteration<uint64_t>(op, elementsAttr, "uint64_t"); 69 testElementsAttrIteration<APInt>(op, elementsAttr, "APInt"); 70 testElementsAttrIteration<IntegerAttr>(op, elementsAttr, "IntegerAttr"); 71 } 72 }); 73 } 74 75 template <typename T> 76 void testElementsAttrIteration(Operation *op, ElementsAttr attr, 77 StringRef type) { 78 InFlightDiagnostic diag = op->emitError() 79 << "Test iterating `" << type << "`: "; 80 81 auto values = attr.tryGetValues<T>(); 82 if (!values) { 83 diag << "unable to iterate type"; 84 return; 85 } 86 87 llvm::interleaveComma(*values, diag, 88 [&](T value) { printOneElement(diag, value); }); 89 } 90 }; 91 } // namespace 92 93 namespace mlir { 94 namespace test { 95 void registerTestBuiltinAttributeInterfaces() { 96 PassRegistration<TestElementsAttrInterface>(); 97 } 98 } // namespace test 99 } // namespace mlir 100