1 //===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <type_traits>
10 
11 #include "mlir/Analysis/SliceAnalysis.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/SCF/SCF.h"
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 #include "mlir/Dialect/Vector/VectorTransforms.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 
21 using namespace mlir;
22 using namespace mlir::vector;
23 
24 namespace {
25 
26 struct TestVectorToVectorConversion
27     : public PassWrapper<TestVectorToVectorConversion, FunctionPass> {
28   TestVectorToVectorConversion() = default;
29   TestVectorToVectorConversion(const TestVectorToVectorConversion &pass) {}
30   StringRef getArgument() const final {
31     return "test-vector-to-vector-conversion";
32   }
33   StringRef getDescription() const final {
34     return "Test conversion patterns between ops in the vector dialect";
35   }
36 
37   void getDependentDialects(DialectRegistry &registry) const override {
38     registry.insert<AffineDialect>();
39   }
40 
41   Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
42                       llvm::cl::init(false)};
43 
44   void runOnFunction() override {
45     auto *ctx = &getContext();
46     RewritePatternSet patterns(ctx);
47     if (unroll) {
48       populateVectorUnrollPatterns(
49           patterns,
50           UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
51               filter));
52     }
53     populateVectorToVectorCanonicalizationPatterns(patterns);
54     populateBubbleVectorBitCastOpPatterns(patterns);
55     populateCastAwayVectorLeadingOneDimPatterns(patterns);
56     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
57   }
58 
59 private:
60   // Return the target shape based on op type.
61   static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
62     if (isa<arith::AddFOp, SelectOp, arith::CmpFOp>(op))
63       return SmallVector<int64_t, 4>(2, 2);
64     if (isa<vector::ContractionOp>(op))
65       return SmallVector<int64_t, 4>(3, 2);
66     // For transfer ops, just propagate the shape coming from
67     // InsertStridedSlices/ExtractStridedSlices.
68     if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
69       VectorType dstVec;
70       for (Operation *users : readOp->getUsers()) {
71         auto extract = dyn_cast<ExtractStridedSliceOp>(users);
72         if (!extract)
73           return llvm::None;
74         auto vecType = extract.getResult().getType().cast<VectorType>();
75         if (dstVec && dstVec != vecType)
76           return llvm::None;
77         dstVec = vecType;
78       }
79       return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
80                                      dstVec.getShape().end());
81     }
82     if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
83       auto insert = writeOp.vector().getDefiningOp<InsertStridedSliceOp>();
84       if (!insert)
85         return llvm::None;
86       ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
87       return SmallVector<int64_t, 4>(shape.begin(), shape.end());
88     }
89     return llvm::None;
90   }
91 
92   static LogicalResult filter(Operation *op) {
93     return success(isa<arith::AddFOp, SelectOp, arith::CmpFOp, ContractionOp,
94                        TransferReadOp, TransferWriteOp>(op));
95   }
96 };
97 
98 struct TestVectorContractionConversion
99     : public PassWrapper<TestVectorContractionConversion, FunctionPass> {
100   StringRef getArgument() const final {
101     return "test-vector-contraction-conversion";
102   }
103   StringRef getDescription() const final {
104     return "Test conversion patterns that lower contract ops in the vector "
105            "dialect";
106   }
107   TestVectorContractionConversion() = default;
108   TestVectorContractionConversion(const TestVectorContractionConversion &pass) {
109   }
110 
111   Option<bool> lowerToFlatMatrix{
112       *this, "vector-lower-matrix-intrinsics",
113       llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
114       llvm::cl::init(false)};
115   Option<bool> lowerToFlatTranspose{
116       *this, "vector-flat-transpose",
117       llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
118       llvm::cl::init(false)};
119   Option<bool> lowerToOuterProduct{
120       *this, "vector-outerproduct",
121       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
122       llvm::cl::init(false)};
123   Option<bool> lowerToFilterOuterProduct{
124       *this, "vector-filter-outerproduct",
125       llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
126                      "vectors of size 4."),
127       llvm::cl::init(false)};
128 
129   void runOnFunction() override {
130     RewritePatternSet patterns(&getContext());
131 
132     // Test on one pattern in isolation.
133     if (lowerToOuterProduct) {
134       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
135       VectorTransformsOptions options{lowering};
136       patterns.add<ContractionOpToOuterProductOpLowering>(options,
137                                                           &getContext());
138       (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
139       return;
140     }
141 
142     // Test on one pattern in isolation.
143     if (lowerToFilterOuterProduct) {
144       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
145       VectorTransformsOptions options{lowering};
146       patterns.add<ContractionOpToOuterProductOpLowering>(
147           options, &getContext(), [](vector::ContractionOp op) {
148             // Only lowers vector.contract where the lhs as a type vector<MxNx?>
149             // where M is not 4.
150             if (op.getRhsType().getShape()[0] == 4)
151               return failure();
152             return success();
153           });
154       (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
155       return;
156     }
157 
158     // Test on all contract lowering patterns.
159     VectorContractLowering contractLowering = VectorContractLowering::Dot;
160     if (lowerToFlatMatrix)
161       contractLowering = VectorContractLowering::Matmul;
162     VectorMultiReductionLowering vectorMultiReductionLowering =
163         VectorMultiReductionLowering::InnerParallel;
164     VectorTransposeLowering transposeLowering =
165         VectorTransposeLowering::EltWise;
166     if (lowerToFlatTranspose)
167       transposeLowering = VectorTransposeLowering::Flat;
168     VectorTransformsOptions options{
169         contractLowering, vectorMultiReductionLowering, transposeLowering};
170     populateVectorBroadcastLoweringPatterns(patterns);
171     populateVectorContractLoweringPatterns(patterns, options);
172     populateVectorMaskOpLoweringPatterns(patterns);
173     populateVectorShapeCastLoweringPatterns(patterns);
174     populateVectorTransposeLoweringPatterns(patterns, options);
175     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
176   }
177 };
178 
179 struct TestVectorUnrollingPatterns
180     : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
181   StringRef getArgument() const final {
182     return "test-vector-unrolling-patterns";
183   }
184   StringRef getDescription() const final {
185     return "Test conversion patterns to unroll contract ops in the vector "
186            "dialect";
187   }
188   TestVectorUnrollingPatterns() = default;
189   TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {}
190   void runOnFunction() override {
191     MLIRContext *ctx = &getContext();
192     RewritePatternSet patterns(ctx);
193     populateVectorUnrollPatterns(
194         patterns, UnrollVectorOptions()
195                       .setNativeShape(ArrayRef<int64_t>{2, 2})
196                       .setFilterConstraint([](Operation *op) {
197                         return success(isa<arith::AddFOp, vector::FMAOp>(op));
198                       }));
199 
200     if (unrollBasedOnType) {
201       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
202           [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
203         vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
204         SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
205         if (auto floatType = contractOp.getLhsType()
206                                  .getElementType()
207                                  .dyn_cast<FloatType>()) {
208           if (floatType.getWidth() == 16) {
209             nativeShape[2] = 4;
210           }
211         }
212         return nativeShape;
213       };
214       populateVectorUnrollPatterns(patterns,
215                                    UnrollVectorOptions()
216                                        .setNativeShapeFn(nativeShapeFn)
217                                        .setFilterConstraint([](Operation *op) {
218                                          return success(isa<ContractionOp>(op));
219                                        }));
220     } else {
221       populateVectorUnrollPatterns(
222           patterns, UnrollVectorOptions()
223                         .setNativeShape(ArrayRef<int64_t>{2, 2, 2})
224                         .setFilterConstraint([](Operation *op) {
225                           return success(isa<ContractionOp>(op));
226                         }));
227     }
228     populateVectorToVectorCanonicalizationPatterns(patterns);
229     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
230   }
231 
232   Option<bool> unrollBasedOnType{
233       *this, "unroll-based-on-type",
234       llvm::cl::desc("Set the unroll factor based on type of the operation"),
235       llvm::cl::init(false)};
236 };
237 
238 struct TestVectorDistributePatterns
239     : public PassWrapper<TestVectorDistributePatterns, FunctionPass> {
240   StringRef getArgument() const final {
241     return "test-vector-distribute-patterns";
242   }
243   StringRef getDescription() const final {
244     return "Test conversion patterns to distribute vector ops in the vector "
245            "dialect";
246   }
247   TestVectorDistributePatterns() = default;
248   TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {}
249   void getDependentDialects(DialectRegistry &registry) const override {
250     registry.insert<VectorDialect>();
251     registry.insert<AffineDialect>();
252   }
253   ListOption<int32_t> multiplicity{
254       *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated,
255       llvm::cl::desc("Set the multiplicity used for distributing vector")};
256 
257   void runOnFunction() override {
258     MLIRContext *ctx = &getContext();
259     RewritePatternSet patterns(ctx);
260     FuncOp func = getFunction();
261     func.walk([&](arith::AddFOp op) {
262       OpBuilder builder(op);
263       if (auto vecType = op.getType().dyn_cast<VectorType>()) {
264         SmallVector<int64_t, 2> mul;
265         SmallVector<AffineExpr, 2> perm;
266         SmallVector<Value, 2> ids;
267         unsigned count = 0;
268         // Remove the multiplicity of 1 and calculate the affine map based on
269         // the multiplicity.
270         SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
271         for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
272           if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) {
273             mul.push_back(m[i]);
274             ids.push_back(func.getArgument(count++));
275             perm.push_back(getAffineDimExpr(i, ctx));
276           }
277         }
278         auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0,
279                                   perm, ctx);
280         Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
281             builder, op.getOperation(), ids, mul, map);
282         if (ops.hasValue()) {
283           SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
284           op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
285                                               extractOp);
286         }
287       }
288     });
289     populatePropagateVectorDistributionPatterns(patterns);
290     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
291   }
292 };
293 
294 struct TestVectorToLoopPatterns
295     : public PassWrapper<TestVectorToLoopPatterns, FunctionPass> {
296   StringRef getArgument() const final { return "test-vector-to-forloop"; }
297   StringRef getDescription() const final {
298     return "Test conversion patterns to break up a vector op into a for loop";
299   }
300   TestVectorToLoopPatterns() = default;
301   TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {}
302   void getDependentDialects(DialectRegistry &registry) const override {
303     registry.insert<VectorDialect>();
304     registry.insert<AffineDialect>();
305   }
306   Option<int32_t> multiplicity{
307       *this, "distribution-multiplicity",
308       llvm::cl::desc("Set the multiplicity used for distributing vector"),
309       llvm::cl::init(32)};
310   void runOnFunction() override {
311     MLIRContext *ctx = &getContext();
312     RewritePatternSet patterns(ctx);
313     FuncOp func = getFunction();
314     func.walk([&](arith::AddFOp op) {
315       // Check that the operation type can be broken down into a loop.
316       VectorType type = op.getType().dyn_cast<VectorType>();
317       if (!type || type.getRank() != 1 ||
318           type.getNumElements() % multiplicity != 0)
319         return mlir::WalkResult::advance();
320       auto filterAlloc = [](Operation *op) {
321         if (isa<arith::ConstantOp, memref::AllocOp, CallOp>(op))
322           return false;
323         return true;
324       };
325       auto dependentOps = getSlice(op, filterAlloc);
326       // Create a loop and move instructions from the Op slice into the loop.
327       OpBuilder builder(op);
328       auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0);
329       auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1);
330       auto numIter =
331           builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity);
332       auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
333       for (Operation *it : dependentOps) {
334         it->moveBefore(forOp.getBody()->getTerminator());
335       }
336       auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
337       // break up the original op and let the patterns propagate.
338       Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
339           builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity},
340           map);
341       if (ops.hasValue()) {
342         SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
343         op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
344       }
345       return mlir::WalkResult::interrupt();
346     });
347     populatePropagateVectorDistributionPatterns(patterns);
348     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
349   }
350 };
351 
352 struct TestVectorTransferUnrollingPatterns
353     : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> {
354   void getDependentDialects(DialectRegistry &registry) const override {
355     registry.insert<AffineDialect>();
356   }
357   StringRef getArgument() const final {
358     return "test-vector-transfer-unrolling-patterns";
359   }
360   StringRef getDescription() const final {
361     return "Test conversion patterns to unroll transfer ops in the vector "
362            "dialect";
363   }
364   void runOnFunction() override {
365     MLIRContext *ctx = &getContext();
366     RewritePatternSet patterns(ctx);
367     populateVectorUnrollPatterns(
368         patterns,
369         UnrollVectorOptions()
370             .setNativeShape(ArrayRef<int64_t>{2, 2})
371             .setFilterConstraint([](Operation *op) {
372               return success(
373                   isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
374             }));
375     populateVectorToVectorCanonicalizationPatterns(patterns);
376     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
377   }
378 };
379 
380 struct TestVectorTransferFullPartialSplitPatterns
381     : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
382                          FunctionPass> {
383   StringRef getArgument() const final {
384     return "test-vector-transfer-full-partial-split";
385   }
386   StringRef getDescription() const final {
387     return "Test conversion patterns to split "
388            "transfer ops via scf.if + linalg ops";
389   }
390   TestVectorTransferFullPartialSplitPatterns() = default;
391   TestVectorTransferFullPartialSplitPatterns(
392       const TestVectorTransferFullPartialSplitPatterns &pass) {}
393 
394   void getDependentDialects(DialectRegistry &registry) const override {
395     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
396                     scf::SCFDialect>();
397   }
398 
399   Option<bool> useLinalgOps{
400       *this, "use-linalg-copy",
401       llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
402                      "linalg.copy operations."),
403       llvm::cl::init(false)};
404   void runOnFunction() override {
405     MLIRContext *ctx = &getContext();
406     RewritePatternSet patterns(ctx);
407     VectorTransformsOptions options;
408     if (useLinalgOps)
409       options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
410     else
411       options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
412     patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
413     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
414   }
415 };
416 
417 struct TestVectorTransferOpt
418     : public PassWrapper<TestVectorTransferOpt, FunctionPass> {
419   StringRef getArgument() const final { return "test-vector-transferop-opt"; }
420   StringRef getDescription() const final {
421     return "Test optimization transformations for transfer ops";
422   }
423   void runOnFunction() override { transferOpflowOpt(getFunction()); }
424 };
425 
426 struct TestVectorTransferLoweringPatterns
427     : public PassWrapper<TestVectorTransferLoweringPatterns, FunctionPass> {
428   void getDependentDialects(DialectRegistry &registry) const override {
429     registry.insert<memref::MemRefDialect>();
430   }
431   StringRef getArgument() const final {
432     return "test-vector-transfer-lowering-patterns";
433   }
434   StringRef getDescription() const final {
435     return "Test conversion patterns to lower transfer ops to other vector ops";
436   }
437   void runOnFunction() override {
438     RewritePatternSet patterns(&getContext());
439     populateVectorTransferLoweringPatterns(patterns);
440     populateVectorTransferPermutationMapLoweringPatterns(patterns);
441     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
442   }
443 };
444 
445 struct TestVectorMultiReductionLoweringPatterns
446     : public PassWrapper<TestVectorMultiReductionLoweringPatterns,
447                          FunctionPass> {
448   TestVectorMultiReductionLoweringPatterns() = default;
449   TestVectorMultiReductionLoweringPatterns(
450       const TestVectorMultiReductionLoweringPatterns &pass) {}
451   void getDependentDialects(DialectRegistry &registry) const override {
452     registry.insert<memref::MemRefDialect>();
453   }
454   StringRef getArgument() const final {
455     return "test-vector-multi-reduction-lowering-patterns";
456   }
457   StringRef getDescription() const final {
458     return "Test conversion patterns to lower vector.multi_reduction to other "
459            "vector ops";
460   }
461   Option<bool> useOuterReductions{
462       *this, "use-outer-reductions",
463       llvm::cl::desc("Move reductions to outer most dimensions"),
464       llvm::cl::init(false)};
465   void runOnFunction() override {
466     RewritePatternSet patterns(&getContext());
467     populateVectorMultiReductionLoweringPatterns(
468         patterns, useOuterReductions
469                       ? vector::VectorMultiReductionLowering::InnerParallel
470                       : vector::VectorMultiReductionLowering::InnerReduction);
471     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
472   }
473 };
474 
475 struct TestVectorTransferCollapseInnerMostContiguousDims
476     : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
477                          FunctionPass> {
478   TestVectorTransferCollapseInnerMostContiguousDims() = default;
479   TestVectorTransferCollapseInnerMostContiguousDims(
480       const TestVectorTransferCollapseInnerMostContiguousDims &pass) {}
481 
482   void getDependentDialects(DialectRegistry &registry) const override {
483     registry.insert<memref::MemRefDialect, AffineDialect>();
484   }
485 
486   StringRef getArgument() const final {
487     return "test-vector-transfer-collapse-inner-most-dims";
488   }
489 
490   StringRef getDescription() const final {
491     return "Test conversion patterns that reducedes the rank of the vector "
492            "transfer memory and vector operands.";
493   }
494 
495   void runOnFunction() override {
496     RewritePatternSet patterns(&getContext());
497     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
498     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
499   }
500 };
501 
502 struct TestVectorReduceToContractPatternsPatterns
503     : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
504                          FunctionPass> {
505   StringRef getArgument() const final {
506     return "test-vector-reduction-to-contract-patterns";
507   }
508   StringRef getDescription() const final {
509     return "Test patterns to convert multireduce op to contract and combine "
510            "broadcast/transpose to contract";
511   }
512   void runOnFunction() override {
513     RewritePatternSet patterns(&getContext());
514     populateVectorReductionToContractPatterns(patterns);
515     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
516   }
517 };
518 
519 } // end anonymous namespace
520 
521 namespace mlir {
522 namespace test {
523 void registerTestVectorConversions() {
524   PassRegistration<TestVectorToVectorConversion>();
525 
526   PassRegistration<TestVectorContractionConversion>();
527 
528   PassRegistration<TestVectorUnrollingPatterns>();
529 
530   PassRegistration<TestVectorTransferUnrollingPatterns>();
531 
532   PassRegistration<TestVectorTransferFullPartialSplitPatterns>();
533 
534   PassRegistration<TestVectorDistributePatterns>();
535 
536   PassRegistration<TestVectorToLoopPatterns>();
537 
538   PassRegistration<TestVectorTransferOpt>();
539 
540   PassRegistration<TestVectorTransferLoweringPatterns>();
541 
542   PassRegistration<TestVectorMultiReductionLoweringPatterns>();
543 
544   PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
545 
546   PassRegistration<TestVectorReduceToContractPatternsPatterns>();
547 }
548 } // namespace test
549 } // namespace mlir
550