1b8737614SUday Bondhugula //===- VectorizerTestPass.cpp - VectorizerTestPass Pass Impl --------------===//
2b8737614SUday Bondhugula //
3b8737614SUday Bondhugula // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4b8737614SUday Bondhugula // See https://llvm.org/LICENSE.txt for license information.
5b8737614SUday Bondhugula // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6b8737614SUday Bondhugula //
7b8737614SUday Bondhugula //===----------------------------------------------------------------------===//
8b8737614SUday Bondhugula //
9b8737614SUday Bondhugula // This file implements a simple testing pass for vectorization functionality.
10b8737614SUday Bondhugula //
11b8737614SUday Bondhugula //===----------------------------------------------------------------------===//
12b8737614SUday Bondhugula 
13b8737614SUday Bondhugula #include "mlir/Analysis/SliceAnalysis.h"
14755dc07dSRiver Riddle #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
15755dc07dSRiver Riddle #include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
16b8737614SUday Bondhugula #include "mlir/Dialect/Affine/IR/AffineOps.h"
17a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/LoopUtils.h"
1893936da9SDiego Caballero #include "mlir/Dialect/Affine/Utils.h"
1936550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
2099ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
2199ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
22b8737614SUday Bondhugula #include "mlir/IR/Builders.h"
2309f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
24b8737614SUday Bondhugula #include "mlir/IR/Diagnostics.h"
25b8737614SUday Bondhugula #include "mlir/Pass/Pass.h"
26b8737614SUday Bondhugula #include "mlir/Transforms/Passes.h"
27b8737614SUday Bondhugula 
28b8737614SUday Bondhugula #include "llvm/ADT/STLExtras.h"
29b8737614SUday Bondhugula #include "llvm/Support/CommandLine.h"
30b8737614SUday Bondhugula #include "llvm/Support/Debug.h"
31b8737614SUday Bondhugula 
32b8737614SUday Bondhugula #define DEBUG_TYPE "affine-super-vectorizer-test"
33b8737614SUday Bondhugula 
34b8737614SUday Bondhugula using namespace mlir;
35b8737614SUday Bondhugula 
36b8737614SUday Bondhugula static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
37b8737614SUday Bondhugula 
38b8737614SUday Bondhugula static llvm::cl::list<int> clTestVectorShapeRatio(
39b8737614SUday Bondhugula     "vector-shape-ratio",
40b8737614SUday Bondhugula     llvm::cl::desc("Specify the HW vector size for vectorization"),
41d86a206fSFangrui Song     llvm::cl::cat(clOptionsCategory));
42b8737614SUday Bondhugula static llvm::cl::opt<bool> clTestForwardSlicingAnalysis(
43b8737614SUday Bondhugula     "forward-slicing",
44b8737614SUday Bondhugula     llvm::cl::desc("Enable testing forward static slicing and topological sort "
45b8737614SUday Bondhugula                    "functionalities"),
46b8737614SUday Bondhugula     llvm::cl::cat(clOptionsCategory));
47b8737614SUday Bondhugula static llvm::cl::opt<bool> clTestBackwardSlicingAnalysis(
48b8737614SUday Bondhugula     "backward-slicing",
49b8737614SUday Bondhugula     llvm::cl::desc("Enable testing backward static slicing and "
50b8737614SUday Bondhugula                    "topological sort functionalities"),
51b8737614SUday Bondhugula     llvm::cl::cat(clOptionsCategory));
52b8737614SUday Bondhugula static llvm::cl::opt<bool> clTestSlicingAnalysis(
53b8737614SUday Bondhugula     "slicing",
54b8737614SUday Bondhugula     llvm::cl::desc("Enable testing static slicing and topological sort "
55b8737614SUday Bondhugula                    "functionalities"),
56b8737614SUday Bondhugula     llvm::cl::cat(clOptionsCategory));
57b8737614SUday Bondhugula static llvm::cl::opt<bool> clTestComposeMaps(
58b8737614SUday Bondhugula     "compose-maps",
59b8737614SUday Bondhugula     llvm::cl::desc(
60b8737614SUday Bondhugula         "Enable testing the composition of AffineMap where each "
61b8737614SUday Bondhugula         "AffineMap in the composition is specified as the affine_map attribute "
62b8737614SUday Bondhugula         "in a constant op."),
63b8737614SUday Bondhugula     llvm::cl::cat(clOptionsCategory));
6493936da9SDiego Caballero static llvm::cl::opt<bool> clTestVecAffineLoopNest(
6593936da9SDiego Caballero     "vectorize-affine-loop-nest",
6693936da9SDiego Caballero     llvm::cl::desc(
6793936da9SDiego Caballero         "Enable testing for the 'vectorizeAffineLoopNest' utility by "
6893936da9SDiego Caballero         "vectorizing the outermost loops found"),
6993936da9SDiego Caballero     llvm::cl::cat(clOptionsCategory));
70b8737614SUday Bondhugula 
71b8737614SUday Bondhugula namespace {
7280aca1eaSRiver Riddle struct VectorizerTestPass
7358ceae95SRiver Riddle     : public PassWrapper<VectorizerTestPass, OperationPass<func::FuncOp>> {
745e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VectorizerTestPass)
755e50dd04SRiver Riddle 
76b8737614SUday Bondhugula   static constexpr auto kTestAffineMapOpName = "test_affine_map";
77b8737614SUday Bondhugula   static constexpr auto kTestAffineMapAttrName = "affine_map";
getDependentDialects__anona61fa8470111::VectorizerTestPass78f9dc2b70SMehdi Amini   void getDependentDialects(DialectRegistry &registry) const override {
79f9dc2b70SMehdi Amini     registry.insert<vector::VectorDialect>();
80f9dc2b70SMehdi Amini   }
getArgument__anona61fa8470111::VectorizerTestPass81b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "affine-super-vectorizer-test"; }
getDescription__anona61fa8470111::VectorizerTestPass82b5e22e6dSMehdi Amini   StringRef getDescription() const final {
83b5e22e6dSMehdi Amini     return "Tests vectorizer standalone functionality.";
84b5e22e6dSMehdi Amini   }
85b8737614SUday Bondhugula 
8641574554SRiver Riddle   void runOnOperation() override;
87b8737614SUday Bondhugula   void testVectorShapeRatio(llvm::raw_ostream &outs);
88b8737614SUday Bondhugula   void testForwardSlicing(llvm::raw_ostream &outs);
89b8737614SUday Bondhugula   void testBackwardSlicing(llvm::raw_ostream &outs);
90b8737614SUday Bondhugula   void testSlicing(llvm::raw_ostream &outs);
91b8737614SUday Bondhugula   void testComposeMaps(llvm::raw_ostream &outs);
9293936da9SDiego Caballero 
9393936da9SDiego Caballero   /// Test for 'vectorizeAffineLoopNest' utility.
9493936da9SDiego Caballero   void testVecAffineLoopNest();
95b8737614SUday Bondhugula };
96b8737614SUday Bondhugula 
97be0a7e9fSMehdi Amini } // namespace
98b8737614SUday Bondhugula 
testVectorShapeRatio(llvm::raw_ostream & outs)99b8737614SUday Bondhugula void VectorizerTestPass::testVectorShapeRatio(llvm::raw_ostream &outs) {
10041574554SRiver Riddle   auto f = getOperation();
101b8737614SUday Bondhugula   using matcher::Op;
102b8737614SUday Bondhugula   SmallVector<int64_t, 8> shape(clTestVectorShapeRatio.begin(),
103b8737614SUday Bondhugula                                 clTestVectorShapeRatio.end());
104b8737614SUday Bondhugula   auto subVectorType =
105b8737614SUday Bondhugula       VectorType::get(shape, FloatType::getF32(f.getContext()));
106b8737614SUday Bondhugula   // Only filter operations that operate on a strict super-vector and have one
107b8737614SUday Bondhugula   // return. This makes testing easier.
108b8737614SUday Bondhugula   auto filter = [&](Operation &op) {
109b8737614SUday Bondhugula     assert(subVectorType.getElementType().isF32() &&
110b8737614SUday Bondhugula            "Only f32 supported for now");
111b8737614SUday Bondhugula     if (!matcher::operatesOnSuperVectorsOf(op, subVectorType)) {
112b8737614SUday Bondhugula       return false;
113b8737614SUday Bondhugula     }
114b8737614SUday Bondhugula     if (op.getNumResults() != 1) {
115b8737614SUday Bondhugula       return false;
116b8737614SUday Bondhugula     }
117b8737614SUday Bondhugula     return true;
118b8737614SUday Bondhugula   };
119b8737614SUday Bondhugula   auto pat = Op(filter);
120b8737614SUday Bondhugula   SmallVector<NestedMatch, 8> matches;
121b8737614SUday Bondhugula   pat.match(f, &matches);
122b8737614SUday Bondhugula   for (auto m : matches) {
123b8737614SUday Bondhugula     auto *opInst = m.getMatchedOperation();
124b8737614SUday Bondhugula     // This is a unit test that only checks and prints shape ratio.
125b8737614SUday Bondhugula     // As a consequence we write only Ops with a single return type for the
126b8737614SUday Bondhugula     // purpose of this test. If we need to test more intricate behavior in the
127b8737614SUday Bondhugula     // future we can always extend.
128b8737614SUday Bondhugula     auto superVectorType = opInst->getResult(0).getType().cast<VectorType>();
129b8737614SUday Bondhugula     auto ratio = shapeRatio(superVectorType, subVectorType);
130*037f0995SKazu Hirata     if (!ratio) {
131b8737614SUday Bondhugula       opInst->emitRemark("NOT MATCHED");
132b8737614SUday Bondhugula     } else {
133b8737614SUday Bondhugula       outs << "\nmatched: " << *opInst << " with shape ratio: ";
1342f21a579SRiver Riddle       llvm::interleaveComma(MutableArrayRef<int64_t>(*ratio), outs);
135b8737614SUday Bondhugula     }
136b8737614SUday Bondhugula   }
137b8737614SUday Bondhugula }
138b8737614SUday Bondhugula 
patternTestSlicingOps()139b8737614SUday Bondhugula static NestedPattern patternTestSlicingOps() {
140b8737614SUday Bondhugula   using matcher::Op;
141b8737614SUday Bondhugula   // Match all operations with the kTestSlicingOpName name.
142b8737614SUday Bondhugula   auto filter = [](Operation &op) {
143b8737614SUday Bondhugula     // Just use a custom op name for this test, it makes life easier.
144b8737614SUday Bondhugula     return op.getName().getStringRef() == "slicing-test-op";
145b8737614SUday Bondhugula   };
146b8737614SUday Bondhugula   return Op(filter);
147b8737614SUday Bondhugula }
148b8737614SUday Bondhugula 
testBackwardSlicing(llvm::raw_ostream & outs)149b8737614SUday Bondhugula void VectorizerTestPass::testBackwardSlicing(llvm::raw_ostream &outs) {
15041574554SRiver Riddle   auto f = getOperation();
151b8737614SUday Bondhugula   outs << "\n" << f.getName();
152b8737614SUday Bondhugula 
153b8737614SUday Bondhugula   SmallVector<NestedMatch, 8> matches;
154b8737614SUday Bondhugula   patternTestSlicingOps().match(f, &matches);
155b8737614SUday Bondhugula   for (auto m : matches) {
156b8737614SUday Bondhugula     SetVector<Operation *> backwardSlice;
157b8737614SUday Bondhugula     getBackwardSlice(m.getMatchedOperation(), &backwardSlice);
158b8737614SUday Bondhugula     outs << "\nmatched: " << *m.getMatchedOperation()
159b8737614SUday Bondhugula          << " backward static slice: ";
160b8737614SUday Bondhugula     for (auto *op : backwardSlice)
161b8737614SUday Bondhugula       outs << "\n" << *op;
162b8737614SUday Bondhugula   }
163b8737614SUday Bondhugula }
164b8737614SUday Bondhugula 
testForwardSlicing(llvm::raw_ostream & outs)165b8737614SUday Bondhugula void VectorizerTestPass::testForwardSlicing(llvm::raw_ostream &outs) {
16641574554SRiver Riddle   auto f = getOperation();
167b8737614SUday Bondhugula   outs << "\n" << f.getName();
168b8737614SUday Bondhugula 
169b8737614SUday Bondhugula   SmallVector<NestedMatch, 8> matches;
170b8737614SUday Bondhugula   patternTestSlicingOps().match(f, &matches);
171b8737614SUday Bondhugula   for (auto m : matches) {
172b8737614SUday Bondhugula     SetVector<Operation *> forwardSlice;
173b8737614SUday Bondhugula     getForwardSlice(m.getMatchedOperation(), &forwardSlice);
174b8737614SUday Bondhugula     outs << "\nmatched: " << *m.getMatchedOperation()
175b8737614SUday Bondhugula          << " forward static slice: ";
176b8737614SUday Bondhugula     for (auto *op : forwardSlice)
177b8737614SUday Bondhugula       outs << "\n" << *op;
178b8737614SUday Bondhugula   }
179b8737614SUday Bondhugula }
180b8737614SUday Bondhugula 
testSlicing(llvm::raw_ostream & outs)181b8737614SUday Bondhugula void VectorizerTestPass::testSlicing(llvm::raw_ostream &outs) {
18241574554SRiver Riddle   auto f = getOperation();
183b8737614SUday Bondhugula   outs << "\n" << f.getName();
184b8737614SUday Bondhugula 
185b8737614SUday Bondhugula   SmallVector<NestedMatch, 8> matches;
186b8737614SUday Bondhugula   patternTestSlicingOps().match(f, &matches);
187b8737614SUday Bondhugula   for (auto m : matches) {
188b8737614SUday Bondhugula     SetVector<Operation *> staticSlice = getSlice(m.getMatchedOperation());
189b8737614SUday Bondhugula     outs << "\nmatched: " << *m.getMatchedOperation() << " static slice: ";
190b8737614SUday Bondhugula     for (auto *op : staticSlice)
191b8737614SUday Bondhugula       outs << "\n" << *op;
192b8737614SUday Bondhugula   }
193b8737614SUday Bondhugula }
194b8737614SUday Bondhugula 
customOpWithAffineMapAttribute(Operation & op)195b8737614SUday Bondhugula static bool customOpWithAffineMapAttribute(Operation &op) {
196b8737614SUday Bondhugula   return op.getName().getStringRef() ==
197b8737614SUday Bondhugula          VectorizerTestPass::kTestAffineMapOpName;
198b8737614SUday Bondhugula }
199b8737614SUday Bondhugula 
testComposeMaps(llvm::raw_ostream & outs)200b8737614SUday Bondhugula void VectorizerTestPass::testComposeMaps(llvm::raw_ostream &outs) {
20141574554SRiver Riddle   auto f = getOperation();
202b8737614SUday Bondhugula 
203b8737614SUday Bondhugula   using matcher::Op;
204b8737614SUday Bondhugula   auto pattern = Op(customOpWithAffineMapAttribute);
205b8737614SUday Bondhugula   SmallVector<NestedMatch, 8> matches;
206b8737614SUday Bondhugula   pattern.match(f, &matches);
207b8737614SUday Bondhugula   SmallVector<AffineMap, 4> maps;
208b8737614SUday Bondhugula   maps.reserve(matches.size());
209b8737614SUday Bondhugula   for (auto m : llvm::reverse(matches)) {
210b8737614SUday Bondhugula     auto *opInst = m.getMatchedOperation();
211b8737614SUday Bondhugula     auto map = opInst->getAttr(VectorizerTestPass::kTestAffineMapAttrName)
212b8737614SUday Bondhugula                    .cast<AffineMapAttr>()
213b8737614SUday Bondhugula                    .getValue();
214b8737614SUday Bondhugula     maps.push_back(map);
215b8737614SUday Bondhugula   }
216b8737614SUday Bondhugula   AffineMap res;
217b8737614SUday Bondhugula   for (auto m : maps) {
218b8737614SUday Bondhugula     res = res ? res.compose(m) : m;
219b8737614SUday Bondhugula   }
220b8737614SUday Bondhugula   simplifyAffineMap(res).print(outs << "\nComposed map: ");
221b8737614SUday Bondhugula }
222b8737614SUday Bondhugula 
22393936da9SDiego Caballero /// Test for 'vectorizeAffineLoopNest' utility.
testVecAffineLoopNest()22493936da9SDiego Caballero void VectorizerTestPass::testVecAffineLoopNest() {
22593936da9SDiego Caballero   std::vector<SmallVector<AffineForOp, 2>> loops;
22641574554SRiver Riddle   gatherLoops(getOperation(), loops);
227b8737614SUday Bondhugula 
22893936da9SDiego Caballero   // Expected only one loop nest.
22993936da9SDiego Caballero   if (loops.empty() || loops[0].size() != 1)
23093936da9SDiego Caballero     return;
23193936da9SDiego Caballero 
23293936da9SDiego Caballero   // We vectorize the outermost loop found with VF=4.
23393936da9SDiego Caballero   AffineForOp outermostLoop = loops[0][0];
23493936da9SDiego Caballero   VectorizationStrategy strategy;
23593936da9SDiego Caballero   strategy.vectorSizes.push_back(4 /*vectorization factor*/);
23693936da9SDiego Caballero   strategy.loopToVectorDim[outermostLoop] = 0;
23793936da9SDiego Caballero   std::vector<SmallVector<AffineForOp, 2>> loopsToVectorize;
23893936da9SDiego Caballero   loopsToVectorize.push_back({outermostLoop});
239e21adfa3SRiver Riddle   (void)vectorizeAffineLoopNest(loopsToVectorize, strategy);
24093936da9SDiego Caballero }
24193936da9SDiego Caballero 
runOnOperation()24241574554SRiver Riddle void VectorizerTestPass::runOnOperation() {
243b8737614SUday Bondhugula   // Only support single block functions at this point.
24458ceae95SRiver Riddle   func::FuncOp f = getOperation();
2452eaadfc4SRahul Joshi   if (!llvm::hasSingleElement(f))
246b8737614SUday Bondhugula     return;
247b8737614SUday Bondhugula 
248b8737614SUday Bondhugula   std::string str;
249b8737614SUday Bondhugula   llvm::raw_string_ostream outs(str);
250b8737614SUday Bondhugula 
25193936da9SDiego Caballero   { // Tests that expect a NestedPatternContext to be allocated externally.
25293936da9SDiego Caballero     NestedPatternContext mlContext;
25393936da9SDiego Caballero 
254b8737614SUday Bondhugula     if (!clTestVectorShapeRatio.empty())
255b8737614SUday Bondhugula       testVectorShapeRatio(outs);
256b8737614SUday Bondhugula 
257b8737614SUday Bondhugula     if (clTestForwardSlicingAnalysis)
258b8737614SUday Bondhugula       testForwardSlicing(outs);
259b8737614SUday Bondhugula 
260b8737614SUday Bondhugula     if (clTestBackwardSlicingAnalysis)
261b8737614SUday Bondhugula       testBackwardSlicing(outs);
262b8737614SUday Bondhugula 
263b8737614SUday Bondhugula     if (clTestSlicingAnalysis)
264b8737614SUday Bondhugula       testSlicing(outs);
265b8737614SUday Bondhugula 
266b8737614SUday Bondhugula     if (clTestComposeMaps)
267b8737614SUday Bondhugula       testComposeMaps(outs);
26893936da9SDiego Caballero   }
26993936da9SDiego Caballero 
27093936da9SDiego Caballero   if (clTestVecAffineLoopNest)
27193936da9SDiego Caballero     testVecAffineLoopNest();
272b8737614SUday Bondhugula 
273b8737614SUday Bondhugula   if (!outs.str().empty()) {
274b8737614SUday Bondhugula     emitRemark(UnknownLoc::get(&getContext()), outs.str());
275b8737614SUday Bondhugula   }
276b8737614SUday Bondhugula }
277b8737614SUday Bondhugula 
278b8737614SUday Bondhugula namespace mlir {
registerVectorizerTestPass()279b5e22e6dSMehdi Amini void registerVectorizerTestPass() { PassRegistration<VectorizerTestPass>(); }
280b8737614SUday Bondhugula } // namespace mlir
281