1 //===- AttributeTest.cpp - Attribute unit tests ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/IR/BuiltinAttributes.h"
10 #include "mlir/IR/BuiltinTypes.h"
11 #include "gtest/gtest.h"
12
13 using namespace mlir;
14 using namespace mlir::detail;
15
16 template <typename EltTy>
testSplat(Type eltType,const EltTy & splatElt)17 static void testSplat(Type eltType, const EltTy &splatElt) {
18 RankedTensorType shape = RankedTensorType::get({2, 1}, eltType);
19
20 // Check that the generated splat is the same for 1 element and N elements.
21 DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt);
22 EXPECT_TRUE(splat.isSplat());
23
24 auto detectedSplat =
25 DenseElementsAttr::get(shape, llvm::makeArrayRef({splatElt, splatElt}));
26 EXPECT_EQ(detectedSplat, splat);
27
28 for (auto newValue : detectedSplat.template getValues<EltTy>())
29 EXPECT_TRUE(newValue == splatElt);
30 }
31
32 namespace {
// Splat detection for i1 elements on a 2x2 tensor.
TEST(DenseSplatTest, BoolSplat) {
  MLIRContext context;
  auto i1Ty = IntegerType::get(&context, 1);
  auto tensorTy = RankedTensorType::get({2, 2}, i1Ty);

  // A single boolean produces a splat; the all-true and all-false splats are
  // distinct attributes.
  auto allTrue = DenseElementsAttr::get(tensorTy, true);
  EXPECT_TRUE(allTrue.isSplat());
  auto allFalse = DenseElementsAttr::get(tensorTy, false);
  EXPECT_TRUE(allFalse.isSplat());
  EXPECT_NE(allFalse, allTrue);

  // Fewer than 8 identical values (bool values are bit-packed) must be
  // recognised as the corresponding splat.
  DenseElementsAttr listedTrue =
      DenseElementsAttr::get(tensorTy, {true, true, true, true});
  EXPECT_EQ(listedTrue, allTrue);
  DenseElementsAttr listedFalse =
      DenseElementsAttr::get(tensorTy, {false, false, false, false});
  EXPECT_EQ(listedFalse, allFalse);
}
55
// Splat detection for a 1-D i1 tensor holding 56 elements.
TEST(DenseSplatTest, LargeBoolSplat) {
  constexpr int64_t numBools = 56;

  MLIRContext context;
  auto i1Ty = IntegerType::get(&context, 1);
  auto tensorTy = RankedTensorType::get({numBools}, i1Ty);

  // Reference splats, each built from a single true / false value.
  auto allTrue = DenseElementsAttr::get(tensorTy, true);
  auto allFalse = DenseElementsAttr::get(tensorTy, false);
  EXPECT_TRUE(allTrue.isSplat());
  EXPECT_TRUE(allFalse.isSplat());

  // A full-length array of `true` must match the true splat.
  SmallVector<bool, 64> listedTrue(numBools, true);
  DenseElementsAttr fromTrueList = DenseElementsAttr::get(tensorTy, listedTrue);
  EXPECT_EQ(fromTrueList, allTrue);

  // A full-length array of `false` must match the false splat.
  SmallVector<bool, 64> listedFalse(numBools, false);
  DenseElementsAttr fromFalseList =
      DenseElementsAttr::get(tensorTy, listedFalse);
  EXPECT_EQ(fromFalseList, allFalse);
}
80
// Mixed boolean values must not be reported as a splat.
TEST(DenseSplatTest, BoolNonSplat) {
  MLIRContext context;
  auto i1Ty = IntegerType::get(&context, 1);
  auto tensorTy = RankedTensorType::get({6}, i1Ty);

  // Six elements that are not all the same value.
  auto mixed = DenseElementsAttr::get(
      tensorTy, {false, false, true, false, false, true});
  EXPECT_FALSE(mixed.isSplat());
}
91
// Splat detection for an integer whose bitwidth (19) is not a multiple of 8.
TEST(DenseSplatTest, OddIntSplat) {
  MLIRContext context;
  constexpr size_t bitWidth = 19;
  auto oddIntTy = IntegerType::get(&context, bitWidth);
  APInt splatValue(bitWidth, 10);

  testSplat(oddIntTy, splatValue);
}
101
// Splat of a plain `int` value over an i32 element type.
TEST(DenseSplatTest, Int32Splat) {
  MLIRContext context;
  auto i32Ty = IntegerType::get(&context, 32);
  int splatValue = 64;

  testSplat(i32Ty, splatValue);
}
109
// Splat of an IntegerAttr, passed as a generic Attribute, over an 85-bit
// integer element type.
TEST(DenseSplatTest, IntAttrSplat) {
  MLIRContext context;
  auto i85Ty = IntegerType::get(&context, 85);
  // Declared as `Attribute` so that testSplat is instantiated for Attribute.
  Attribute splatValue = IntegerAttr::get(i85Ty, 109);

  testSplat(i85Ty, splatValue);
}
117
// Splat of a native `float` value over an f32 element type.
TEST(DenseSplatTest, F32Splat) {
  MLIRContext context;
  auto f32Ty = FloatType::getF32(&context);
  float splatValue = 10.0f;

  testSplat(f32Ty, splatValue);
}
125
// Splat of an APFloat, constructed from a double, over an f64 element type.
TEST(DenseSplatTest, F64Splat) {
  MLIRContext context;
  auto f64Ty = FloatType::getF64(&context);
  double rawValue = 10.0;
  APFloat splatValue(rawValue);

  testSplat(f64Ty, splatValue);
}
133
// Splat of a FloatAttr, passed as a generic Attribute, over an f32 element
// type.
TEST(DenseSplatTest, FloatAttrSplat) {
  MLIRContext context;
  auto f32Ty = FloatType::getF32(&context);
  // Declared as `Attribute` so that testSplat is instantiated for Attribute.
  Attribute splatValue = FloatAttr::get(f32Ty, 10.0);

  testSplat(f32Ty, splatValue);
}
141
// Splat of a FloatAttr, passed as a generic Attribute, over a bf16 element
// type.
TEST(DenseSplatTest, BF16Splat) {
  MLIRContext context;
  auto bf16Ty = FloatType::getBF16(&context);
  // Declared as `Attribute` so that testSplat is instantiated for Attribute.
  Attribute splatValue = FloatAttr::get(bf16Ty, 10.0);

  testSplat(bf16Ty, splatValue);
}
149
// Splat of a StringRef over an opaque "string" type in the "test" dialect
// namespace.
TEST(DenseSplatTest, StringSplat) {
  MLIRContext context;
  context.allowUnregisteredDialects();

  auto dialectName = StringAttr::get(&context, "test");
  Type opaqueStringTy = OpaqueType::get(dialectName, "string");
  StringRef splatValue = "test-string";

  testSplat(opaqueStringTy, splatValue);
}
158
// Splat of a StringAttr, passed as a generic Attribute, over an opaque
// "string" type in the "test" dialect namespace.
TEST(DenseSplatTest, StringAttrSplat) {
  MLIRContext context;
  context.allowUnregisteredDialects();

  auto dialectName = StringAttr::get(&context, "test");
  Type opaqueStringTy = OpaqueType::get(dialectName, "string");
  // Declared as `Attribute` so that testSplat is instantiated for Attribute.
  Attribute splatValue = StringAttr::get("test-string", opaqueStringTy);

  testSplat(opaqueStringTy, splatValue);
}
167
// Splat of a std::complex<float> over a complex<f32> element type.
TEST(DenseComplexTest, ComplexFloatSplat) {
  MLIRContext context;
  auto f32Ty = FloatType::getF32(&context);
  auto complexF32Ty = ComplexType::get(f32Ty);

  std::complex<float> splatValue(10.0, 15.0);
  testSplat(complexF32Ty, splatValue);
}
174
// Splat of a std::complex<int64_t> over a complex<i64> element type.
TEST(DenseComplexTest, ComplexIntSplat) {
  MLIRContext context;
  auto i64Ty = IntegerType::get(&context, 64);
  auto complexI64Ty = ComplexType::get(i64Ty);

  std::complex<int64_t> splatValue(10, 15);
  testSplat(complexI64Ty, splatValue);
}
181
// Splat of a std::complex<APFloat> over a complex<f32> element type.
TEST(DenseComplexTest, ComplexAPFloatSplat) {
  MLIRContext context;
  auto f32Ty = FloatType::getF32(&context);
  auto complexF32Ty = ComplexType::get(f32Ty);

  // Both parts are single-precision APFloats.
  std::complex<APFloat> splatValue(APFloat(10.0f), APFloat(15.0f));
  testSplat(complexF32Ty, splatValue);
}
188
// Splat of a std::complex<APInt> over a complex<i64> element type.
TEST(DenseComplexTest, ComplexAPIntSplat) {
  MLIRContext context;
  auto i64Ty = IntegerType::get(&context, 64);
  auto complexI64Ty = ComplexType::get(i64Ty);

  // Both parts are 64-bit APInts.
  std::complex<APInt> splatValue(APInt(64, 10), APInt(64, 15));
  testSplat(complexI64Ty, splatValue);
}
195
// Reading element 0 of a rank-0 (empty shape) i32 tensor gives back the
// stored integer as an IntegerAttr.
TEST(DenseScalarTest, ExtractZeroRankElement) {
  MLIRContext context;
  const int scalar = 12;
  auto i32Ty = IntegerType::get(&context, 32);
  auto scalarTensorTy = RankedTensorType::get({}, i32Ty);

  // The attribute expected when the element is read back.
  Attribute expected = IntegerAttr::get(i32Ty, scalar);

  auto dense =
      DenseElementsAttr::get(scalarTensorTy, llvm::makeArrayRef({scalar}));
  EXPECT_TRUE(dense.getValues<Attribute>()[0] == expected);
}
206
// Reading an index that has no stored element from a sparse attribute must
// give the zero/empty value of the element type, for integer, float and
// string elements alike.
TEST(SparseElementsAttrTest, GetZero) {
  MLIRContext context;
  context.allowUnregisteredDialects();

  // A single stored index, (0, 0), shared by all three sparse attributes.
  auto indexTensorTy =
      RankedTensorType::get({1, 2}, IntegerType::get(&context, 64));
  auto storedIndices =
      DenseIntElementsAttr::get(indexTensorTy, {APInt(64, 0), APInt(64, 0)});

  // Integer: 2x2 tensor of i32 with one stored value.
  IntegerType i32Ty = IntegerType::get(&context, 32);
  ShapedType i32TensorTy = RankedTensorType::get({2, 2}, i32Ty);
  RankedTensorType i32ValuesTy = RankedTensorType::get({1}, i32Ty);
  auto i32Values = DenseIntElementsAttr::get(i32ValuesTy, {1});
  auto sparseI32 =
      SparseElementsAttr::get(i32TensorTy, storedIndices, i32Values);

  // Float: 2x2 tensor of f32 with one stored value.
  FloatType f32Ty = FloatType::getF32(&context);
  ShapedType f32TensorTy = RankedTensorType::get({2, 2}, f32Ty);
  RankedTensorType f32ValuesTy = RankedTensorType::get({1}, f32Ty);
  auto f32Values = DenseFPElementsAttr::get(f32ValuesTy, {1.0f});
  auto sparseF32 =
      SparseElementsAttr::get(f32TensorTy, storedIndices, f32Values);

  // String: 2x2 tensor of an opaque "string" type with one stored value.
  Type strTy = OpaqueType::get(StringAttr::get(&context, "test"), "string");
  ShapedType strTensorTy = RankedTensorType::get({2, 2}, strTy);
  RankedTensorType strValuesTy = RankedTensorType::get({1}, strTy);
  auto strValues = DenseElementsAttr::get(strValuesTy, {StringRef("foo")});
  auto sparseStr =
      SparseElementsAttr::get(strTensorTy, storedIndices, strValues);

  // Index (1, 1) holds no element, so each lookup is expected to return the
  // zero/empty value carrying the matching element type.
  auto missingInt = sparseI32.getValues<Attribute>()[{1, 1}];
  EXPECT_EQ(missingInt.cast<IntegerAttr>().getInt(), 0);
  EXPECT_TRUE(missingInt.getType() == i32Ty);

  auto missingFloat = sparseF32.getValues<Attribute>()[{1, 1}];
  EXPECT_EQ(missingFloat.cast<FloatAttr>().getValueAsDouble(), 0.0f);
  EXPECT_TRUE(missingFloat.getType() == f32Ty);

  auto missingStr = sparseStr.getValues<Attribute>()[{1, 1}];
  EXPECT_TRUE(missingStr.cast<StringAttr>().getValue().empty());
  EXPECT_TRUE(missingStr.getType() == strTy);
}
252
253 } // namespace
254