1 //===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <type_traits>
10 
11 #include "mlir/Analysis/SliceAnalysis.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/SCF/SCF.h"
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 #include "mlir/Dialect/Vector/VectorOps.h"
18 #include "mlir/Dialect/Vector/VectorTransforms.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 
22 using namespace mlir;
23 using namespace mlir::vector;
24 namespace {
25 
26 struct TestVectorToVectorConversion
27     : public PassWrapper<TestVectorToVectorConversion, FunctionPass> {
28   TestVectorToVectorConversion() = default;
29   TestVectorToVectorConversion(const TestVectorToVectorConversion &pass) {}
30 
31   void getDependentDialects(DialectRegistry &registry) const override {
32     registry.insert<AffineDialect>();
33   }
34 
35   Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
36                       llvm::cl::init(false)};
37 
38   void runOnFunction() override {
39     auto *ctx = &getContext();
40     RewritePatternSet patterns(ctx);
41     if (unroll) {
42       patterns.add<UnrollVectorPattern>(
43           ctx,
44           UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
45               filter));
46     }
47     populateVectorToVectorCanonicalizationPatterns(patterns);
48     populateVectorToVectorTransformationPatterns(patterns);
49     populateBubbleVectorBitCastOpPatterns(patterns);
50     populateCastAwayVectorLeadingOneDimPatterns(patterns);
51     populateSplitVectorTransferPatterns(patterns);
52     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
53   }
54 
55 private:
56   // Return the target shape based on op type.
57   static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
58     if (isa<AddFOp, SelectOp, CmpFOp>(op))
59       return SmallVector<int64_t, 4>(2, 2);
60     if (isa<vector::ContractionOp>(op))
61       return SmallVector<int64_t, 4>(3, 2);
62     return llvm::None;
63   }
64 
65   static LogicalResult filter(Operation *op) {
66     return success(isa<AddFOp, SelectOp, CmpFOp, ContractionOp>(op));
67   }
68 };
69 
70 struct TestVectorSlicesConversion
71     : public PassWrapper<TestVectorSlicesConversion, FunctionPass> {
72   void runOnFunction() override {
73     RewritePatternSet patterns(&getContext());
74     populateVectorSlicesLoweringPatterns(patterns);
75     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
76   }
77 };
78 
79 struct TestVectorContractionConversion
80     : public PassWrapper<TestVectorContractionConversion, FunctionPass> {
81   TestVectorContractionConversion() = default;
82   TestVectorContractionConversion(const TestVectorContractionConversion &pass) {
83   }
84 
85   Option<bool> lowerToFlatMatrix{
86       *this, "vector-lower-matrix-intrinsics",
87       llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
88       llvm::cl::init(false)};
89   Option<bool> lowerToFlatTranspose{
90       *this, "vector-flat-transpose",
91       llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
92       llvm::cl::init(false)};
93   Option<bool> lowerToOuterProduct{
94       *this, "vector-outerproduct",
95       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
96       llvm::cl::init(false)};
97   Option<bool> lowerToFilterOuterProduct{
98       *this, "vector-filter-outerproduct",
99       llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
100                      "vectors of size 4."),
101       llvm::cl::init(false)};
102 
103   void runOnFunction() override {
104     RewritePatternSet patterns(&getContext());
105 
106     // Test on one pattern in isolation.
107     if (lowerToOuterProduct) {
108       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
109       VectorTransformsOptions options{lowering};
110       patterns.add<ContractionOpToOuterProductOpLowering>(options,
111                                                           &getContext());
112       (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
113       return;
114     }
115 
116     // Test on one pattern in isolation.
117     if (lowerToFilterOuterProduct) {
118       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
119       VectorTransformsOptions options{lowering};
120       patterns.add<ContractionOpToOuterProductOpLowering>(
121           options, &getContext(), [](vector::ContractionOp op) {
122             // Only lowers vector.contract where the lhs as a type vector<MxNx?>
123             // where M is not 4.
124             if (op.getRhsType().getShape()[0] == 4)
125               return failure();
126             return success();
127           });
128       (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
129       return;
130     }
131 
132     // Test on all contract lowering patterns.
133     VectorContractLowering contractLowering = VectorContractLowering::Dot;
134     if (lowerToFlatMatrix)
135       contractLowering = VectorContractLowering::Matmul;
136     VectorTransposeLowering transposeLowering =
137         VectorTransposeLowering::EltWise;
138     if (lowerToFlatTranspose)
139       transposeLowering = VectorTransposeLowering::Flat;
140     VectorTransformsOptions options{contractLowering, transposeLowering};
141     populateVectorContractLoweringPatterns(patterns, options);
142     populateVectorTransposeLoweringPatterns(patterns, options);
143     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
144   }
145 };
146 
147 struct TestVectorUnrollingPatterns
148     : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
149   TestVectorUnrollingPatterns() = default;
150   TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {}
151   void runOnFunction() override {
152     MLIRContext *ctx = &getContext();
153     RewritePatternSet patterns(ctx);
154     patterns.add<UnrollVectorPattern>(
155         ctx, UnrollVectorOptions()
156                  .setNativeShape(ArrayRef<int64_t>{2, 2})
157                  .setFilterConstraint([](Operation *op) {
158                    return success(isa<AddFOp, vector::FMAOp>(op));
159                  }));
160 
161     if (unrollBasedOnType) {
162       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
163           [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
164         vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
165         SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
166         if (auto floatType = contractOp.getLhsType()
167                                  .getElementType()
168                                  .dyn_cast<FloatType>()) {
169           if (floatType.getWidth() == 16) {
170             nativeShape[2] = 4;
171           }
172         }
173         return nativeShape;
174       };
175       patterns.add<UnrollVectorPattern>(
176           ctx, UnrollVectorOptions()
177                    .setNativeShapeFn(nativeShapeFn)
178                    .setFilterConstraint([](Operation *op) {
179                      return success(isa<ContractionOp>(op));
180                    }));
181     } else {
182       patterns.add<UnrollVectorPattern>(
183           ctx, UnrollVectorOptions()
184                    .setNativeShape(ArrayRef<int64_t>{2, 2, 2})
185                    .setFilterConstraint([](Operation *op) {
186                      return success(isa<ContractionOp>(op));
187                    }));
188     }
189     populateVectorToVectorCanonicalizationPatterns(patterns);
190     populateVectorToVectorTransformationPatterns(patterns);
191     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
192   }
193 
194   Option<bool> unrollBasedOnType{
195       *this, "unroll-based-on-type",
196       llvm::cl::desc("Set the unroll factor based on type of the operation"),
197       llvm::cl::init(false)};
198 };
199 
200 struct TestVectorDistributePatterns
201     : public PassWrapper<TestVectorDistributePatterns, FunctionPass> {
202   TestVectorDistributePatterns() = default;
203   TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {}
204   void getDependentDialects(DialectRegistry &registry) const override {
205     registry.insert<VectorDialect>();
206     registry.insert<AffineDialect>();
207   }
208   ListOption<int32_t> multiplicity{
209       *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated,
210       llvm::cl::desc("Set the multiplicity used for distributing vector")};
211 
212   void runOnFunction() override {
213     MLIRContext *ctx = &getContext();
214     RewritePatternSet patterns(ctx);
215     FuncOp func = getFunction();
216     func.walk([&](AddFOp op) {
217       OpBuilder builder(op);
218       if (auto vecType = op.getType().dyn_cast<VectorType>()) {
219         SmallVector<int64_t, 2> mul;
220         SmallVector<AffineExpr, 2> perm;
221         SmallVector<Value, 2> ids;
222         unsigned count = 0;
223         // Remove the multiplicity of 1 and calculate the affine map based on
224         // the multiplicity.
225         SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
226         for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
227           if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) {
228             mul.push_back(m[i]);
229             ids.push_back(func.getArgument(count++));
230             perm.push_back(getAffineDimExpr(i, ctx));
231           }
232         }
233         auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0,
234                                   perm, ctx);
235         Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
236             builder, op.getOperation(), ids, mul, map);
237         if (ops.hasValue()) {
238           SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
239           op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
240                                               extractOp);
241         }
242       }
243     });
244     patterns.add<PointwiseExtractPattern>(ctx);
245     populateVectorToVectorTransformationPatterns(patterns);
246     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
247   }
248 };
249 
250 struct TestVectorToLoopPatterns
251     : public PassWrapper<TestVectorToLoopPatterns, FunctionPass> {
252   TestVectorToLoopPatterns() = default;
253   TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {}
254   void getDependentDialects(DialectRegistry &registry) const override {
255     registry.insert<VectorDialect>();
256     registry.insert<AffineDialect>();
257   }
258   Option<int32_t> multiplicity{
259       *this, "distribution-multiplicity",
260       llvm::cl::desc("Set the multiplicity used for distributing vector"),
261       llvm::cl::init(32)};
262   void runOnFunction() override {
263     MLIRContext *ctx = &getContext();
264     RewritePatternSet patterns(ctx);
265     FuncOp func = getFunction();
266     func.walk([&](AddFOp op) {
267       // Check that the operation type can be broken down into a loop.
268       VectorType type = op.getType().dyn_cast<VectorType>();
269       if (!type || type.getRank() != 1 ||
270           type.getNumElements() % multiplicity != 0)
271         return mlir::WalkResult::advance();
272       auto filterAlloc = [](Operation *op) {
273         if (isa<ConstantOp, memref::AllocOp, CallOp>(op))
274           return false;
275         return true;
276       };
277       auto dependentOps = getSlice(op, filterAlloc);
278       // Create a loop and move instructions from the Op slice into the loop.
279       OpBuilder builder(op);
280       auto zero = builder.create<ConstantOp>(
281           op.getLoc(), builder.getIndexType(),
282           builder.getIntegerAttr(builder.getIndexType(), 0));
283       auto one = builder.create<ConstantOp>(
284           op.getLoc(), builder.getIndexType(),
285           builder.getIntegerAttr(builder.getIndexType(), 1));
286       auto numIter = builder.create<ConstantOp>(
287           op.getLoc(), builder.getIndexType(),
288           builder.getIntegerAttr(builder.getIndexType(), multiplicity));
289       auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
290       for (Operation *it : dependentOps) {
291         it->moveBefore(forOp.getBody()->getTerminator());
292       }
293       auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
294       // break up the original op and let the patterns propagate.
295       Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
296           builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity},
297           map);
298       if (ops.hasValue()) {
299         SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
300         op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
301       }
302       return mlir::WalkResult::interrupt();
303     });
304     patterns.add<PointwiseExtractPattern>(ctx);
305     populateVectorToVectorTransformationPatterns(patterns);
306     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
307   }
308 };
309 
310 struct TestVectorTransferUnrollingPatterns
311     : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> {
312   void getDependentDialects(DialectRegistry &registry) const override {
313     registry.insert<AffineDialect>();
314   }
315   void runOnFunction() override {
316     MLIRContext *ctx = &getContext();
317     RewritePatternSet patterns(ctx);
318     patterns.add<UnrollVectorPattern>(
319         ctx,
320         UnrollVectorOptions()
321             .setNativeShape(ArrayRef<int64_t>{2, 2})
322             .setFilterConstraint([](Operation *op) {
323               return success(
324                   isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
325             }));
326     populateVectorToVectorCanonicalizationPatterns(patterns);
327     populateVectorToVectorTransformationPatterns(patterns);
328     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
329   }
330 };
331 
332 struct TestVectorTransferFullPartialSplitPatterns
333     : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
334                          FunctionPass> {
335   TestVectorTransferFullPartialSplitPatterns() = default;
336   TestVectorTransferFullPartialSplitPatterns(
337       const TestVectorTransferFullPartialSplitPatterns &pass) {}
338 
339   void getDependentDialects(DialectRegistry &registry) const override {
340     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
341                     scf::SCFDialect>();
342   }
343 
344   Option<bool> useLinalgOps{
345       *this, "use-linalg-copy",
346       llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
347                      "linalg.copy operations."),
348       llvm::cl::init(false)};
349   void runOnFunction() override {
350     MLIRContext *ctx = &getContext();
351     RewritePatternSet patterns(ctx);
352     VectorTransformsOptions options;
353     if (useLinalgOps)
354       options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
355     else
356       options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
357     patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
358     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
359   }
360 };
361 
362 struct TestVectorTransferOpt
363     : public PassWrapper<TestVectorTransferOpt, FunctionPass> {
364   void runOnFunction() override { transferOpflowOpt(getFunction()); }
365 };
366 
367 struct TestVectorTransferLoweringPatterns
368     : public PassWrapper<TestVectorTransferLoweringPatterns, FunctionPass> {
369   void getDependentDialects(DialectRegistry &registry) const override {
370     registry.insert<memref::MemRefDialect>();
371   }
372   void runOnFunction() override {
373     RewritePatternSet patterns(&getContext());
374     populateVectorTransferLoweringPatterns(patterns);
375     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
376   }
377 };
378 
379 struct TestVectorMultiReductionLoweringPatterns
380     : public PassWrapper<TestVectorMultiReductionLoweringPatterns,
381                          FunctionPass> {
382   void getDependentDialects(DialectRegistry &registry) const override {
383     registry.insert<memref::MemRefDialect>();
384   }
385   void runOnFunction() override {
386     RewritePatternSet patterns(&getContext());
387     populateVectorMultiReductionLoweringPatterns(patterns);
388     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
389   }
390 };
391 
392 } // end anonymous namespace
393 
394 namespace mlir {
395 namespace test {
396 void registerTestVectorConversions() {
397   PassRegistration<TestVectorToVectorConversion> vectorToVectorPass(
398       "test-vector-to-vector-conversion",
399       "Test conversion patterns between ops in the vector dialect");
400 
401   PassRegistration<TestVectorSlicesConversion> slicesPass(
402       "test-vector-slices-conversion",
403       "Test conversion patterns that lower slices ops in the vector dialect");
404 
405   PassRegistration<TestVectorContractionConversion> contractionPass(
406       "test-vector-contraction-conversion",
407       "Test conversion patterns that lower contract ops in the vector dialect");
408 
409   PassRegistration<TestVectorUnrollingPatterns> contractionUnrollingPass(
410       "test-vector-unrolling-patterns",
411       "Test conversion patterns to unroll contract ops in the vector dialect");
412 
413   PassRegistration<TestVectorTransferUnrollingPatterns> transferOpUnrollingPass(
414       "test-vector-transfer-unrolling-patterns",
415       "Test conversion patterns to unroll transfer ops in the vector dialect");
416 
417   PassRegistration<TestVectorTransferFullPartialSplitPatterns>
418       vectorTransformFullPartialPass("test-vector-transfer-full-partial-split",
419                                      "Test conversion patterns to split "
420                                      "transfer ops via scf.if + linalg ops");
421 
422   PassRegistration<TestVectorDistributePatterns> distributePass(
423       "test-vector-distribute-patterns",
424       "Test conversion patterns to distribute vector ops in the vector "
425       "dialect");
426 
427   PassRegistration<TestVectorToLoopPatterns> vectorToForLoop(
428       "test-vector-to-forloop",
429       "Test conversion patterns to break up a vector op into a for loop");
430 
431   PassRegistration<TestVectorTransferOpt> transferOpOpt(
432       "test-vector-transferop-opt",
433       "Test optimization transformations for transfer ops");
434 
435   PassRegistration<TestVectorTransferLoweringPatterns> transferOpLoweringPass(
436       "test-vector-transfer-lowering-patterns",
437       "Test conversion patterns to lower transfer ops to other vector ops");
438 
439   PassRegistration<TestVectorMultiReductionLoweringPatterns>
440       multiDimReductionOpLoweringPass(
441           "test-vector-multi-reduction-lowering-patterns",
442           "Test conversion patterns to lower vector.multi_reduction to other "
443           "vector ops");
444 }
445 } // namespace test
446 } // namespace mlir
447