1 //===- TestBuiltinAttributeInterfaces.cpp ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "TestAttributes.h"
10 #include "mlir/IR/BuiltinOps.h"
11 #include "mlir/Pass/Pass.h"
12 #include "llvm/Support/FormatVariadic.h"
13
14 using namespace mlir;
15 using namespace test;
16
17 // Helper to print one scalar value, force int8_t to print as integer instead of
18 // char.
19 template <typename T>
printOneElement(InFlightDiagnostic & os,T value)20 static void printOneElement(InFlightDiagnostic &os, T value) {
21 os << llvm::formatv("{0}", value).str();
22 }
23 template <>
printOneElement(InFlightDiagnostic & os,int8_t value)24 void printOneElement<int8_t>(InFlightDiagnostic &os, int8_t value) {
25 os << llvm::formatv("{0}", static_cast<int64_t>(value)).str();
26 }
27
28 namespace {
29 struct TestElementsAttrInterface
30 : public PassWrapper<TestElementsAttrInterface, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon25b185410111::TestElementsAttrInterface31 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestElementsAttrInterface)
32
33 StringRef getArgument() const final { return "test-elements-attr-interface"; }
getDescription__anon25b185410111::TestElementsAttrInterface34 StringRef getDescription() const final {
35 return "Test ElementsAttr interface support.";
36 }
runOnOperation__anon25b185410111::TestElementsAttrInterface37 void runOnOperation() override {
38 getOperation().walk([&](Operation *op) {
39 for (NamedAttribute attr : op->getAttrs()) {
40 auto elementsAttr = attr.getValue().dyn_cast<ElementsAttr>();
41 if (!elementsAttr)
42 continue;
43 if (auto concreteAttr =
44 attr.getValue().dyn_cast<DenseArrayBaseAttr>()) {
45 switch (concreteAttr.getElementType()) {
46 case DenseArrayBaseAttr::EltType::I8:
47 testElementsAttrIteration<int8_t>(op, elementsAttr, "int8_t");
48 break;
49 case DenseArrayBaseAttr::EltType::I16:
50 testElementsAttrIteration<int16_t>(op, elementsAttr, "int16_t");
51 break;
52 case DenseArrayBaseAttr::EltType::I32:
53 testElementsAttrIteration<int32_t>(op, elementsAttr, "int32_t");
54 break;
55 case DenseArrayBaseAttr::EltType::I64:
56 testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");
57 break;
58 case DenseArrayBaseAttr::EltType::F32:
59 testElementsAttrIteration<float>(op, elementsAttr, "float");
60 break;
61 case DenseArrayBaseAttr::EltType::F64:
62 testElementsAttrIteration<double>(op, elementsAttr, "double");
63 break;
64 }
65 continue;
66 }
67 testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");
68 testElementsAttrIteration<uint64_t>(op, elementsAttr, "uint64_t");
69 testElementsAttrIteration<APInt>(op, elementsAttr, "APInt");
70 testElementsAttrIteration<IntegerAttr>(op, elementsAttr, "IntegerAttr");
71 }
72 });
73 }
74
75 template <typename T>
testElementsAttrIteration__anon25b185410111::TestElementsAttrInterface76 void testElementsAttrIteration(Operation *op, ElementsAttr attr,
77 StringRef type) {
78 InFlightDiagnostic diag = op->emitError()
79 << "Test iterating `" << type << "`: ";
80
81 auto values = attr.tryGetValues<T>();
82 if (!values) {
83 diag << "unable to iterate type";
84 return;
85 }
86
87 llvm::interleaveComma(*values, diag,
88 [&](T value) { printOneElement(diag, value); });
89 }
90 };
91 } // namespace
92
93 namespace mlir {
94 namespace test {
registerTestBuiltinAttributeInterfaces()95 void registerTestBuiltinAttributeInterfaces() {
96 PassRegistration<TestElementsAttrInterface>();
97 }
98 } // namespace test
99 } // namespace mlir
100