1 //===- StructuredOpsUtilsTest.cpp - StructuredOpsUtils unit tests ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
10 #include "mlir/IR/AffineExpr.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "gmock/gmock.h"
13 #include "gtest/gtest.h"
14
15 using namespace mlir;
16 using testing::Not;
17 using testing::Truly;
18
19 namespace {
20
TEST(isRowMajorMatmul,Simple)21 TEST(isRowMajorMatmul, Simple) {
22 MLIRContext context;
23
24 AffineExpr m, n, k;
25 bindDims(&context, m, n, k);
26 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
27 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
28 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
29 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
30
31 EXPECT_THAT(maps, Truly(isRowMajorMatmul));
32 }
33
TEST(isRowMajorMatmul,BindingShifted)34 TEST(isRowMajorMatmul, BindingShifted) {
35 MLIRContext context;
36
37 AffineExpr m, n, k;
38 bindDims(&context, k, m, n); // bind in different order
39 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
40 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
41 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
42 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
43
44 EXPECT_THAT(maps, Truly(isRowMajorMatmul));
45 }
46
TEST(isRowMajorMatmul,BindingSwapped)47 TEST(isRowMajorMatmul, BindingSwapped) {
48 MLIRContext context;
49
50 AffineExpr m, n, k;
51 bindDims(&context, k, n, m); // bind in different order
52 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
53 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
54 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
55 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
56
57 EXPECT_THAT(maps, Truly(isRowMajorMatmul));
58 }
59
TEST(isRowMajorMatmul,ColumnMajor)60 TEST(isRowMajorMatmul, ColumnMajor) {
61 MLIRContext context;
62
63 AffineExpr m, n, k;
64 bindDims(&context, m, n, k);
65 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
66 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
67 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
68 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
69
70 EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul)));
71 }
72
TEST(isRowMajorMatmul,FirstInputSwapped)73 TEST(isRowMajorMatmul, FirstInputSwapped) {
74 MLIRContext context;
75
76 AffineExpr m, n, k;
77 bindDims(&context, m, n, k);
78 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, m}, &context));
79 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
80 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
81 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
82
83 EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul)));
84 }
85
TEST(isRowMajorMatmul,TooFewMaps)86 TEST(isRowMajorMatmul, TooFewMaps) {
87 MLIRContext context;
88
89 AffineExpr m, n, k;
90 bindDims(&context, m, n, k);
91 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
92 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
93 auto maps = ArrayAttr::get(&context, {mapA, mapB});
94
95 EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul)));
96 }
97
TEST(isRowMajorMatmul,TooManyMaps)98 TEST(isRowMajorMatmul, TooManyMaps) {
99 MLIRContext context;
100
101 AffineExpr m, n, k;
102 bindDims(&context, m, n, k);
103 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
104 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
105 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
106 auto mapD = AffineMapAttr::get(AffineMap::get(3, 0, {k, m}, &context));
107
108 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC, mapD});
109
110 EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul)));
111 }
112
TEST(isRowMajorMatmul,TooFewOutputs)113 TEST(isRowMajorMatmul, TooFewOutputs) {
114 MLIRContext context;
115
116 AffineExpr m, n, k;
117 bindDims(&context, m, n, k);
118 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m}, &context));
119 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
120 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
121 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
122
123 EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul)));
124 }
125
TEST(isColumnMajorMatmul,Simple)126 TEST(isColumnMajorMatmul, Simple) {
127 MLIRContext context;
128
129 AffineExpr m, n, k;
130 bindDims(&context, m, n, k);
131 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
132 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
133 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
134 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
135
136 EXPECT_THAT(maps, Truly(isColumnMajorMatmul));
137 }
138
TEST(isColumnMajorMatmul,BindingShifted)139 TEST(isColumnMajorMatmul, BindingShifted) {
140 MLIRContext context;
141
142 AffineExpr m, n, k;
143 bindDims(&context, k, m, n); // bind in different order
144 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
145 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
146 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
147 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
148
149 EXPECT_THAT(maps, Truly(isColumnMajorMatmul));
150 }
151
TEST(isColumnMajorMatmul,BindingSwapped)152 TEST(isColumnMajorMatmul, BindingSwapped) {
153 MLIRContext context;
154
155 AffineExpr m, n, k;
156 bindDims(&context, k, n, m); // bind in different order
157 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
158 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
159 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
160 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
161
162 EXPECT_THAT(maps, Truly(isColumnMajorMatmul));
163 }
164
TEST(isColumnMajorMatmul,RowMajor)165 TEST(isColumnMajorMatmul, RowMajor) {
166 MLIRContext context;
167
168 AffineExpr m, n, k;
169 bindDims(&context, m, n, k);
170 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
171 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
172 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
173 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
174
175 EXPECT_THAT(maps, Not(Truly(isColumnMajorMatmul)));
176 }
177
TEST(isColumnMajorMatmul,FirstInputSwapped)178 TEST(isColumnMajorMatmul, FirstInputSwapped) {
179 MLIRContext context;
180
181 AffineExpr m, n, k;
182 bindDims(&context, m, n, k);
183 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {n, k}, &context));
184 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
185 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
186 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
187
188 EXPECT_THAT(maps, Not(Truly(isColumnMajorMatmul)));
189 }
190
TEST(isRowMajorBatchMatmul,Simple)191 TEST(isRowMajorBatchMatmul, Simple) {
192 MLIRContext context;
193
194 AffineExpr batch, m, n, k;
195 bindDims(&context, batch, m, n, k);
196 auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, k}, &context));
197 auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, n}, &context));
198 auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, n}, &context));
199 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
200
201 EXPECT_THAT(maps, Truly(isRowMajorBatchMatmul));
202 }
203
TEST(isRowMajorBatchMatmul,BindingShifted)204 TEST(isRowMajorBatchMatmul, BindingShifted) {
205 MLIRContext context;
206
207 AffineExpr batch, m, n, k;
208 bindDims(&context, k, batch, m, n); // bind in different order
209 auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, k}, &context));
210 auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, n}, &context));
211 auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, n}, &context));
212 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
213
214 EXPECT_THAT(maps, Truly(isRowMajorBatchMatmul));
215 }
216
TEST(isRowMajorBatchMatmul,BindingSwapped)217 TEST(isRowMajorBatchMatmul, BindingSwapped) {
218 MLIRContext context;
219
220 AffineExpr batch, m, n, k;
221 bindDims(&context, batch, k, n, m); // bind in different order
222 auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, k}, &context));
223 auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, n}, &context));
224 auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, n}, &context));
225 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
226
227 EXPECT_THAT(maps, Truly(isRowMajorBatchMatmul));
228 }
229
TEST(isRowMajorBatchMatmul,FirstInputSwapped)230 TEST(isRowMajorBatchMatmul, FirstInputSwapped) {
231 MLIRContext context;
232
233 AffineExpr batch, m, n, k;
234 bindDims(&context, batch, m, n, k);
235 auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, m}, &context));
236 auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, n}, &context));
237 auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, n}, &context));
238 auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
239
240 EXPECT_THAT(maps, Not(Truly(isRowMajorBatchMatmul)));
241 }
242
243 } // namespace
244