1 //===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <type_traits>
10 
11 #include "mlir/Analysis/SliceAnalysis.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/SCF/SCF.h"
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 #include "mlir/Dialect/Vector/VectorTransforms.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 
21 using namespace mlir;
22 using namespace mlir::vector;
23 
24 namespace {
25 
26 struct TestVectorToVectorConversion
27     : public PassWrapper<TestVectorToVectorConversion, FunctionPass> {
28   TestVectorToVectorConversion() = default;
29   TestVectorToVectorConversion(const TestVectorToVectorConversion &pass) {}
30   StringRef getArgument() const final {
31     return "test-vector-to-vector-conversion";
32   }
33   StringRef getDescription() const final {
34     return "Test conversion patterns between ops in the vector dialect";
35   }
36 
37   void getDependentDialects(DialectRegistry &registry) const override {
38     registry.insert<AffineDialect>();
39   }
40 
41   Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
42                       llvm::cl::init(false)};
43 
44   void runOnFunction() override {
45     auto *ctx = &getContext();
46     RewritePatternSet patterns(ctx);
47     if (unroll) {
48       populateVectorUnrollPatterns(
49           patterns,
50           UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
51               filter));
52     }
53     populateVectorToVectorCanonicalizationPatterns(patterns);
54     populateBubbleVectorBitCastOpPatterns(patterns);
55     populateCastAwayVectorLeadingOneDimPatterns(patterns);
56     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
57   }
58 
59 private:
60   // Return the target shape based on op type.
61   static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
62     if (isa<arith::AddFOp, SelectOp, arith::CmpFOp>(op))
63       return SmallVector<int64_t, 4>(2, 2);
64     if (isa<vector::ContractionOp>(op))
65       return SmallVector<int64_t, 4>(3, 2);
66     // For transfer ops, just propagate the shape coming from
67     // InsertStridedSlices/ExtractStridedSlices.
68     if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
69       VectorType dstVec;
70       for (Operation *users : readOp->getUsers()) {
71         auto extract = dyn_cast<ExtractStridedSliceOp>(users);
72         if (!extract)
73           return llvm::None;
74         auto vecType = extract.getResult().getType().cast<VectorType>();
75         if (dstVec && dstVec != vecType)
76           return llvm::None;
77         dstVec = vecType;
78       }
79       return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
80                                      dstVec.getShape().end());
81     }
82     if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
83       auto insert = writeOp.vector().getDefiningOp<InsertStridedSliceOp>();
84       if (!insert)
85         return llvm::None;
86       ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
87       return SmallVector<int64_t, 4>(shape.begin(), shape.end());
88     }
89     return llvm::None;
90   }
91 
92   static LogicalResult filter(Operation *op) {
93     return success(isa<arith::AddFOp, SelectOp, arith::CmpFOp, ContractionOp,
94                        TransferReadOp, TransferWriteOp>(op));
95   }
96 };
97 
98 struct TestVectorContractionConversion
99     : public PassWrapper<TestVectorContractionConversion, FunctionPass> {
100   StringRef getArgument() const final {
101     return "test-vector-contraction-conversion";
102   }
103   StringRef getDescription() const final {
104     return "Test conversion patterns that lower contract ops in the vector "
105            "dialect";
106   }
107   TestVectorContractionConversion() = default;
108   TestVectorContractionConversion(const TestVectorContractionConversion &pass) {
109   }
110 
111   Option<bool> lowerToFlatMatrix{
112       *this, "vector-lower-matrix-intrinsics",
113       llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
114       llvm::cl::init(false)};
115   Option<bool> lowerToFlatTranspose{
116       *this, "vector-flat-transpose",
117       llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
118       llvm::cl::init(false)};
119   Option<bool> lowerToShuffleTranspose{
120       *this, "vector-shuffle-transpose",
121       llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
122       llvm::cl::init(false)};
123   Option<bool> lowerToOuterProduct{
124       *this, "vector-outerproduct",
125       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
126       llvm::cl::init(false)};
127   Option<bool> lowerToFilterOuterProduct{
128       *this, "vector-filter-outerproduct",
129       llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
130                      "vectors of size 4."),
131       llvm::cl::init(false)};
132 
133   void runOnFunction() override {
134     RewritePatternSet patterns(&getContext());
135 
136     // Test on one pattern in isolation.
137     if (lowerToOuterProduct) {
138       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
139       VectorTransformsOptions options{lowering};
140       patterns.add<ContractionOpToOuterProductOpLowering>(options,
141                                                           &getContext());
142       (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
143       return;
144     }
145 
146     // Test on one pattern in isolation.
147     if (lowerToFilterOuterProduct) {
148       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
149       VectorTransformsOptions options{lowering};
150       patterns.add<ContractionOpToOuterProductOpLowering>(
151           options, &getContext(), [](vector::ContractionOp op) {
152             // Only lowers vector.contract where the lhs as a type vector<MxNx?>
153             // where M is not 4.
154             if (op.getRhsType().getShape()[0] == 4)
155               return failure();
156             return success();
157           });
158       (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
159       return;
160     }
161 
162     // Test on all contract lowering patterns.
163     VectorContractLowering contractLowering = VectorContractLowering::Dot;
164     if (lowerToFlatMatrix)
165       contractLowering = VectorContractLowering::Matmul;
166     VectorMultiReductionLowering vectorMultiReductionLowering =
167         VectorMultiReductionLowering::InnerParallel;
168     VectorTransposeLowering transposeLowering =
169         VectorTransposeLowering::EltWise;
170     if (lowerToFlatTranspose)
171       transposeLowering = VectorTransposeLowering::Flat;
172     if (lowerToShuffleTranspose)
173       transposeLowering = VectorTransposeLowering::Shuffle;
174     VectorTransformsOptions options{
175         contractLowering, vectorMultiReductionLowering, transposeLowering};
176     populateVectorBroadcastLoweringPatterns(patterns);
177     populateVectorContractLoweringPatterns(patterns, options);
178     populateVectorMaskOpLoweringPatterns(patterns);
179     if (!lowerToShuffleTranspose)
180       populateVectorShapeCastLoweringPatterns(patterns);
181     populateVectorTransposeLoweringPatterns(patterns, options);
182     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
183   }
184 };
185 
186 struct TestVectorUnrollingPatterns
187     : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
188   StringRef getArgument() const final {
189     return "test-vector-unrolling-patterns";
190   }
191   StringRef getDescription() const final {
192     return "Test conversion patterns to unroll contract ops in the vector "
193            "dialect";
194   }
195   TestVectorUnrollingPatterns() = default;
196   TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {}
197   void runOnFunction() override {
198     MLIRContext *ctx = &getContext();
199     RewritePatternSet patterns(ctx);
200     populateVectorUnrollPatterns(
201         patterns, UnrollVectorOptions()
202                       .setNativeShape(ArrayRef<int64_t>{2, 2})
203                       .setFilterConstraint([](Operation *op) {
204                         return success(isa<arith::AddFOp, vector::FMAOp>(op));
205                       }));
206 
207     if (unrollBasedOnType) {
208       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
209           [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
210         vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
211         SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
212         if (auto floatType = contractOp.getLhsType()
213                                  .getElementType()
214                                  .dyn_cast<FloatType>()) {
215           if (floatType.getWidth() == 16) {
216             nativeShape[2] = 4;
217           }
218         }
219         return nativeShape;
220       };
221       populateVectorUnrollPatterns(patterns,
222                                    UnrollVectorOptions()
223                                        .setNativeShapeFn(nativeShapeFn)
224                                        .setFilterConstraint([](Operation *op) {
225                                          return success(isa<ContractionOp>(op));
226                                        }));
227     } else {
228       populateVectorUnrollPatterns(
229           patterns, UnrollVectorOptions()
230                         .setNativeShape(ArrayRef<int64_t>{2, 2, 2})
231                         .setFilterConstraint([](Operation *op) {
232                           return success(isa<ContractionOp>(op));
233                         }));
234     }
235     populateVectorToVectorCanonicalizationPatterns(patterns);
236     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
237   }
238 
239   Option<bool> unrollBasedOnType{
240       *this, "unroll-based-on-type",
241       llvm::cl::desc("Set the unroll factor based on type of the operation"),
242       llvm::cl::init(false)};
243 };
244 
245 struct TestVectorDistributePatterns
246     : public PassWrapper<TestVectorDistributePatterns, FunctionPass> {
247   StringRef getArgument() const final {
248     return "test-vector-distribute-patterns";
249   }
250   StringRef getDescription() const final {
251     return "Test conversion patterns to distribute vector ops in the vector "
252            "dialect";
253   }
254   TestVectorDistributePatterns() = default;
255   TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {}
256   void getDependentDialects(DialectRegistry &registry) const override {
257     registry.insert<VectorDialect>();
258     registry.insert<AffineDialect>();
259   }
260   ListOption<int32_t> multiplicity{
261       *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated,
262       llvm::cl::desc("Set the multiplicity used for distributing vector")};
263 
264   void runOnFunction() override {
265     MLIRContext *ctx = &getContext();
266     RewritePatternSet patterns(ctx);
267     FuncOp func = getFunction();
268     func.walk([&](arith::AddFOp op) {
269       OpBuilder builder(op);
270       if (auto vecType = op.getType().dyn_cast<VectorType>()) {
271         SmallVector<int64_t, 2> mul;
272         SmallVector<AffineExpr, 2> perm;
273         SmallVector<Value, 2> ids;
274         unsigned count = 0;
275         // Remove the multiplicity of 1 and calculate the affine map based on
276         // the multiplicity.
277         SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
278         for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
279           if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) {
280             mul.push_back(m[i]);
281             ids.push_back(func.getArgument(count++));
282             perm.push_back(getAffineDimExpr(i, ctx));
283           }
284         }
285         auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0,
286                                   perm, ctx);
287         Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
288             builder, op.getOperation(), ids, mul, map);
289         if (ops.hasValue()) {
290           SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
291           op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
292                                               extractOp);
293         }
294       }
295     });
296     populatePropagateVectorDistributionPatterns(patterns);
297     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
298   }
299 };
300 
301 struct TestVectorToLoopPatterns
302     : public PassWrapper<TestVectorToLoopPatterns, FunctionPass> {
303   StringRef getArgument() const final { return "test-vector-to-forloop"; }
304   StringRef getDescription() const final {
305     return "Test conversion patterns to break up a vector op into a for loop";
306   }
307   TestVectorToLoopPatterns() = default;
308   TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {}
309   void getDependentDialects(DialectRegistry &registry) const override {
310     registry.insert<VectorDialect>();
311     registry.insert<AffineDialect>();
312   }
313   Option<int32_t> multiplicity{
314       *this, "distribution-multiplicity",
315       llvm::cl::desc("Set the multiplicity used for distributing vector"),
316       llvm::cl::init(32)};
317   void runOnFunction() override {
318     MLIRContext *ctx = &getContext();
319     RewritePatternSet patterns(ctx);
320     FuncOp func = getFunction();
321     func.walk([&](arith::AddFOp op) {
322       // Check that the operation type can be broken down into a loop.
323       VectorType type = op.getType().dyn_cast<VectorType>();
324       if (!type || type.getRank() != 1 ||
325           type.getNumElements() % multiplicity != 0)
326         return mlir::WalkResult::advance();
327       auto filterAlloc = [](Operation *op) {
328         if (isa<arith::ConstantOp, memref::AllocOp, CallOp>(op))
329           return false;
330         return true;
331       };
332       auto dependentOps = getSlice(op, filterAlloc);
333       // Create a loop and move instructions from the Op slice into the loop.
334       OpBuilder builder(op);
335       auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0);
336       auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1);
337       auto numIter =
338           builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity);
339       auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
340       for (Operation *it : dependentOps) {
341         it->moveBefore(forOp.getBody()->getTerminator());
342       }
343       auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
344       // break up the original op and let the patterns propagate.
345       Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
346           builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity},
347           map);
348       if (ops.hasValue()) {
349         SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
350         op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
351       }
352       return mlir::WalkResult::interrupt();
353     });
354     populatePropagateVectorDistributionPatterns(patterns);
355     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
356   }
357 };
358 
359 struct TestVectorTransferUnrollingPatterns
360     : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> {
361   void getDependentDialects(DialectRegistry &registry) const override {
362     registry.insert<AffineDialect>();
363   }
364   StringRef getArgument() const final {
365     return "test-vector-transfer-unrolling-patterns";
366   }
367   StringRef getDescription() const final {
368     return "Test conversion patterns to unroll transfer ops in the vector "
369            "dialect";
370   }
371   void runOnFunction() override {
372     MLIRContext *ctx = &getContext();
373     RewritePatternSet patterns(ctx);
374     populateVectorUnrollPatterns(
375         patterns,
376         UnrollVectorOptions()
377             .setNativeShape(ArrayRef<int64_t>{2, 2})
378             .setFilterConstraint([](Operation *op) {
379               return success(
380                   isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
381             }));
382     populateVectorToVectorCanonicalizationPatterns(patterns);
383     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
384   }
385 };
386 
387 struct TestVectorTransferFullPartialSplitPatterns
388     : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
389                          FunctionPass> {
390   StringRef getArgument() const final {
391     return "test-vector-transfer-full-partial-split";
392   }
393   StringRef getDescription() const final {
394     return "Test conversion patterns to split "
395            "transfer ops via scf.if + linalg ops";
396   }
397   TestVectorTransferFullPartialSplitPatterns() = default;
398   TestVectorTransferFullPartialSplitPatterns(
399       const TestVectorTransferFullPartialSplitPatterns &pass) {}
400 
401   void getDependentDialects(DialectRegistry &registry) const override {
402     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
403                     scf::SCFDialect>();
404   }
405 
406   Option<bool> useLinalgOps{
407       *this, "use-linalg-copy",
408       llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
409                      "linalg.copy operations."),
410       llvm::cl::init(false)};
411   void runOnFunction() override {
412     MLIRContext *ctx = &getContext();
413     RewritePatternSet patterns(ctx);
414     VectorTransformsOptions options;
415     if (useLinalgOps)
416       options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
417     else
418       options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
419     patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
420     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
421   }
422 };
423 
424 struct TestVectorTransferOpt
425     : public PassWrapper<TestVectorTransferOpt, FunctionPass> {
426   StringRef getArgument() const final { return "test-vector-transferop-opt"; }
427   StringRef getDescription() const final {
428     return "Test optimization transformations for transfer ops";
429   }
430   void runOnFunction() override { transferOpflowOpt(getFunction()); }
431 };
432 
433 struct TestVectorTransferLoweringPatterns
434     : public PassWrapper<TestVectorTransferLoweringPatterns, FunctionPass> {
435   void getDependentDialects(DialectRegistry &registry) const override {
436     registry.insert<memref::MemRefDialect>();
437   }
438   StringRef getArgument() const final {
439     return "test-vector-transfer-lowering-patterns";
440   }
441   StringRef getDescription() const final {
442     return "Test conversion patterns to lower transfer ops to other vector ops";
443   }
444   void runOnFunction() override {
445     RewritePatternSet patterns(&getContext());
446     populateVectorTransferLoweringPatterns(patterns);
447     populateVectorTransferPermutationMapLoweringPatterns(patterns);
448     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
449   }
450 };
451 
452 struct TestVectorMultiReductionLoweringPatterns
453     : public PassWrapper<TestVectorMultiReductionLoweringPatterns,
454                          FunctionPass> {
455   TestVectorMultiReductionLoweringPatterns() = default;
456   TestVectorMultiReductionLoweringPatterns(
457       const TestVectorMultiReductionLoweringPatterns &pass) {}
458   void getDependentDialects(DialectRegistry &registry) const override {
459     registry.insert<memref::MemRefDialect>();
460   }
461   StringRef getArgument() const final {
462     return "test-vector-multi-reduction-lowering-patterns";
463   }
464   StringRef getDescription() const final {
465     return "Test conversion patterns to lower vector.multi_reduction to other "
466            "vector ops";
467   }
468   Option<bool> useOuterReductions{
469       *this, "use-outer-reductions",
470       llvm::cl::desc("Move reductions to outer most dimensions"),
471       llvm::cl::init(false)};
472   void runOnFunction() override {
473     RewritePatternSet patterns(&getContext());
474     populateVectorMultiReductionLoweringPatterns(
475         patterns, useOuterReductions
476                       ? vector::VectorMultiReductionLowering::InnerParallel
477                       : vector::VectorMultiReductionLowering::InnerReduction);
478     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
479   }
480 };
481 
482 struct TestVectorTransferCollapseInnerMostContiguousDims
483     : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
484                          FunctionPass> {
485   TestVectorTransferCollapseInnerMostContiguousDims() = default;
486   TestVectorTransferCollapseInnerMostContiguousDims(
487       const TestVectorTransferCollapseInnerMostContiguousDims &pass) {}
488 
489   void getDependentDialects(DialectRegistry &registry) const override {
490     registry.insert<memref::MemRefDialect, AffineDialect>();
491   }
492 
493   StringRef getArgument() const final {
494     return "test-vector-transfer-collapse-inner-most-dims";
495   }
496 
497   StringRef getDescription() const final {
498     return "Test conversion patterns that reducedes the rank of the vector "
499            "transfer memory and vector operands.";
500   }
501 
502   void runOnFunction() override {
503     RewritePatternSet patterns(&getContext());
504     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
505     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
506   }
507 };
508 
509 struct TestVectorReduceToContractPatternsPatterns
510     : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
511                          FunctionPass> {
512   StringRef getArgument() const final {
513     return "test-vector-reduction-to-contract-patterns";
514   }
515   StringRef getDescription() const final {
516     return "Test patterns to convert multireduce op to contract and combine "
517            "broadcast/transpose to contract";
518   }
519   void runOnFunction() override {
520     RewritePatternSet patterns(&getContext());
521     populateVectorReductionToContractPatterns(patterns);
522     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
523   }
524 };
525 
526 } // end anonymous namespace
527 
528 namespace mlir {
529 namespace test {
530 void registerTestVectorConversions() {
531   PassRegistration<TestVectorToVectorConversion>();
532 
533   PassRegistration<TestVectorContractionConversion>();
534 
535   PassRegistration<TestVectorUnrollingPatterns>();
536 
537   PassRegistration<TestVectorTransferUnrollingPatterns>();
538 
539   PassRegistration<TestVectorTransferFullPartialSplitPatterns>();
540 
541   PassRegistration<TestVectorDistributePatterns>();
542 
543   PassRegistration<TestVectorToLoopPatterns>();
544 
545   PassRegistration<TestVectorTransferOpt>();
546 
547   PassRegistration<TestVectorTransferLoweringPatterns>();
548 
549   PassRegistration<TestVectorMultiReductionLoweringPatterns>();
550 
551   PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
552 
553   PassRegistration<TestVectorReduceToContractPatternsPatterns>();
554 }
555 } // namespace test
556 } // namespace mlir
557