1 //===- VectorizerTestPass.cpp - VectorizerTestPass Pass Impl --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a simple testing pass for vectorization functionality.
10 //
11 //===----------------------------------------------------------------------===//
12
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
31
32 #define DEBUG_TYPE "affine-super-vectorizer-test"
33
34 using namespace mlir;
35
36 static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
37
38 static llvm::cl::list<int> clTestVectorShapeRatio(
39 "vector-shape-ratio",
40 llvm::cl::desc("Specify the HW vector size for vectorization"),
41 llvm::cl::cat(clOptionsCategory));
42 static llvm::cl::opt<bool> clTestForwardSlicingAnalysis(
43 "forward-slicing",
44 llvm::cl::desc("Enable testing forward static slicing and topological sort "
45 "functionalities"),
46 llvm::cl::cat(clOptionsCategory));
47 static llvm::cl::opt<bool> clTestBackwardSlicingAnalysis(
48 "backward-slicing",
49 llvm::cl::desc("Enable testing backward static slicing and "
50 "topological sort functionalities"),
51 llvm::cl::cat(clOptionsCategory));
52 static llvm::cl::opt<bool> clTestSlicingAnalysis(
53 "slicing",
54 llvm::cl::desc("Enable testing static slicing and topological sort "
55 "functionalities"),
56 llvm::cl::cat(clOptionsCategory));
57 static llvm::cl::opt<bool> clTestComposeMaps(
58 "compose-maps",
59 llvm::cl::desc(
60 "Enable testing the composition of AffineMap where each "
61 "AffineMap in the composition is specified as the affine_map attribute "
62 "in a constant op."),
63 llvm::cl::cat(clOptionsCategory));
64 static llvm::cl::opt<bool> clTestVecAffineLoopNest(
65 "vectorize-affine-loop-nest",
66 llvm::cl::desc(
67 "Enable testing for the 'vectorizeAffineLoopNest' utility by "
68 "vectorizing the outermost loops found"),
69 llvm::cl::cat(clOptionsCategory));
70
71 namespace {
72 struct VectorizerTestPass
73 : public PassWrapper<VectorizerTestPass, OperationPass<func::FuncOp>> {
74 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VectorizerTestPass)
75
76 static constexpr auto kTestAffineMapOpName = "test_affine_map";
77 static constexpr auto kTestAffineMapAttrName = "affine_map";
getDependentDialects__anona61fa8470111::VectorizerTestPass78 void getDependentDialects(DialectRegistry ®istry) const override {
79 registry.insert<vector::VectorDialect>();
80 }
getArgument__anona61fa8470111::VectorizerTestPass81 StringRef getArgument() const final { return "affine-super-vectorizer-test"; }
getDescription__anona61fa8470111::VectorizerTestPass82 StringRef getDescription() const final {
83 return "Tests vectorizer standalone functionality.";
84 }
85
86 void runOnOperation() override;
87 void testVectorShapeRatio(llvm::raw_ostream &outs);
88 void testForwardSlicing(llvm::raw_ostream &outs);
89 void testBackwardSlicing(llvm::raw_ostream &outs);
90 void testSlicing(llvm::raw_ostream &outs);
91 void testComposeMaps(llvm::raw_ostream &outs);
92
93 /// Test for 'vectorizeAffineLoopNest' utility.
94 void testVecAffineLoopNest();
95 };
96
97 } // namespace
98
testVectorShapeRatio(llvm::raw_ostream & outs)99 void VectorizerTestPass::testVectorShapeRatio(llvm::raw_ostream &outs) {
100 auto f = getOperation();
101 using matcher::Op;
102 SmallVector<int64_t, 8> shape(clTestVectorShapeRatio.begin(),
103 clTestVectorShapeRatio.end());
104 auto subVectorType =
105 VectorType::get(shape, FloatType::getF32(f.getContext()));
106 // Only filter operations that operate on a strict super-vector and have one
107 // return. This makes testing easier.
108 auto filter = [&](Operation &op) {
109 assert(subVectorType.getElementType().isF32() &&
110 "Only f32 supported for now");
111 if (!matcher::operatesOnSuperVectorsOf(op, subVectorType)) {
112 return false;
113 }
114 if (op.getNumResults() != 1) {
115 return false;
116 }
117 return true;
118 };
119 auto pat = Op(filter);
120 SmallVector<NestedMatch, 8> matches;
121 pat.match(f, &matches);
122 for (auto m : matches) {
123 auto *opInst = m.getMatchedOperation();
124 // This is a unit test that only checks and prints shape ratio.
125 // As a consequence we write only Ops with a single return type for the
126 // purpose of this test. If we need to test more intricate behavior in the
127 // future we can always extend.
128 auto superVectorType = opInst->getResult(0).getType().cast<VectorType>();
129 auto ratio = shapeRatio(superVectorType, subVectorType);
130 if (!ratio) {
131 opInst->emitRemark("NOT MATCHED");
132 } else {
133 outs << "\nmatched: " << *opInst << " with shape ratio: ";
134 llvm::interleaveComma(MutableArrayRef<int64_t>(*ratio), outs);
135 }
136 }
137 }
138
patternTestSlicingOps()139 static NestedPattern patternTestSlicingOps() {
140 using matcher::Op;
141 // Match all operations with the kTestSlicingOpName name.
142 auto filter = [](Operation &op) {
143 // Just use a custom op name for this test, it makes life easier.
144 return op.getName().getStringRef() == "slicing-test-op";
145 };
146 return Op(filter);
147 }
148
testBackwardSlicing(llvm::raw_ostream & outs)149 void VectorizerTestPass::testBackwardSlicing(llvm::raw_ostream &outs) {
150 auto f = getOperation();
151 outs << "\n" << f.getName();
152
153 SmallVector<NestedMatch, 8> matches;
154 patternTestSlicingOps().match(f, &matches);
155 for (auto m : matches) {
156 SetVector<Operation *> backwardSlice;
157 getBackwardSlice(m.getMatchedOperation(), &backwardSlice);
158 outs << "\nmatched: " << *m.getMatchedOperation()
159 << " backward static slice: ";
160 for (auto *op : backwardSlice)
161 outs << "\n" << *op;
162 }
163 }
164
testForwardSlicing(llvm::raw_ostream & outs)165 void VectorizerTestPass::testForwardSlicing(llvm::raw_ostream &outs) {
166 auto f = getOperation();
167 outs << "\n" << f.getName();
168
169 SmallVector<NestedMatch, 8> matches;
170 patternTestSlicingOps().match(f, &matches);
171 for (auto m : matches) {
172 SetVector<Operation *> forwardSlice;
173 getForwardSlice(m.getMatchedOperation(), &forwardSlice);
174 outs << "\nmatched: " << *m.getMatchedOperation()
175 << " forward static slice: ";
176 for (auto *op : forwardSlice)
177 outs << "\n" << *op;
178 }
179 }
180
testSlicing(llvm::raw_ostream & outs)181 void VectorizerTestPass::testSlicing(llvm::raw_ostream &outs) {
182 auto f = getOperation();
183 outs << "\n" << f.getName();
184
185 SmallVector<NestedMatch, 8> matches;
186 patternTestSlicingOps().match(f, &matches);
187 for (auto m : matches) {
188 SetVector<Operation *> staticSlice = getSlice(m.getMatchedOperation());
189 outs << "\nmatched: " << *m.getMatchedOperation() << " static slice: ";
190 for (auto *op : staticSlice)
191 outs << "\n" << *op;
192 }
193 }
194
customOpWithAffineMapAttribute(Operation & op)195 static bool customOpWithAffineMapAttribute(Operation &op) {
196 return op.getName().getStringRef() ==
197 VectorizerTestPass::kTestAffineMapOpName;
198 }
199
testComposeMaps(llvm::raw_ostream & outs)200 void VectorizerTestPass::testComposeMaps(llvm::raw_ostream &outs) {
201 auto f = getOperation();
202
203 using matcher::Op;
204 auto pattern = Op(customOpWithAffineMapAttribute);
205 SmallVector<NestedMatch, 8> matches;
206 pattern.match(f, &matches);
207 SmallVector<AffineMap, 4> maps;
208 maps.reserve(matches.size());
209 for (auto m : llvm::reverse(matches)) {
210 auto *opInst = m.getMatchedOperation();
211 auto map = opInst->getAttr(VectorizerTestPass::kTestAffineMapAttrName)
212 .cast<AffineMapAttr>()
213 .getValue();
214 maps.push_back(map);
215 }
216 AffineMap res;
217 for (auto m : maps) {
218 res = res ? res.compose(m) : m;
219 }
220 simplifyAffineMap(res).print(outs << "\nComposed map: ");
221 }
222
223 /// Test for 'vectorizeAffineLoopNest' utility.
testVecAffineLoopNest()224 void VectorizerTestPass::testVecAffineLoopNest() {
225 std::vector<SmallVector<AffineForOp, 2>> loops;
226 gatherLoops(getOperation(), loops);
227
228 // Expected only one loop nest.
229 if (loops.empty() || loops[0].size() != 1)
230 return;
231
232 // We vectorize the outermost loop found with VF=4.
233 AffineForOp outermostLoop = loops[0][0];
234 VectorizationStrategy strategy;
235 strategy.vectorSizes.push_back(4 /*vectorization factor*/);
236 strategy.loopToVectorDim[outermostLoop] = 0;
237 std::vector<SmallVector<AffineForOp, 2>> loopsToVectorize;
238 loopsToVectorize.push_back({outermostLoop});
239 (void)vectorizeAffineLoopNest(loopsToVectorize, strategy);
240 }
241
runOnOperation()242 void VectorizerTestPass::runOnOperation() {
243 // Only support single block functions at this point.
244 func::FuncOp f = getOperation();
245 if (!llvm::hasSingleElement(f))
246 return;
247
248 std::string str;
249 llvm::raw_string_ostream outs(str);
250
251 { // Tests that expect a NestedPatternContext to be allocated externally.
252 NestedPatternContext mlContext;
253
254 if (!clTestVectorShapeRatio.empty())
255 testVectorShapeRatio(outs);
256
257 if (clTestForwardSlicingAnalysis)
258 testForwardSlicing(outs);
259
260 if (clTestBackwardSlicingAnalysis)
261 testBackwardSlicing(outs);
262
263 if (clTestSlicingAnalysis)
264 testSlicing(outs);
265
266 if (clTestComposeMaps)
267 testComposeMaps(outs);
268 }
269
270 if (clTestVecAffineLoopNest)
271 testVecAffineLoopNest();
272
273 if (!outs.str().empty()) {
274 emitRemark(UnknownLoc::get(&getContext()), outs.str());
275 }
276 }
277
278 namespace mlir {
registerVectorizerTestPass()279 void registerVectorizerTestPass() { PassRegistration<VectorizerTestPass>(); }
280 } // namespace mlir
281