//===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::vector;

namespace {

struct TestVectorToVectorLowering
    : public PassWrapper<TestVectorToVectorLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering)

  TestVectorToVectorLowering() = default;
  TestVectorToVectorLowering(const TestVectorToVectorLowering &pass)
      : PassWrapper(pass) {}
  StringRef getArgument() const final {
    return "test-vector-to-vector-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns between ops in the vector dialect";
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AffineDialect>();
  }

  Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
                      llvm::cl::init(false)};

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    if (unroll) {
      populateVectorUnrollPatterns(
          patterns,
          UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
              filter));
    }
    populateVectorToVectorCanonicalizationPatterns(patterns);
    populateBubbleVectorBitCastOpPatterns(patterns);
    populateCastAwayVectorLeadingOneDimPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

private:
  // Return the target shape based on op type.
  static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
    if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op))
      return SmallVector<int64_t, 4>(2, 2);
    if (isa<vector::ContractionOp>(op))
      return SmallVector<int64_t, 4>(3, 2);
    // For transfer ops, just propagate the shape coming from the surrounding
    // insert_strided_slice/extract_strided_slice ops.
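    // Illustrative example (hypothetical IR, not taken from a test): a
    // transfer_read producing vector<4x4xf32> whose only users are
    // vector<2x2xf32> extract_strided_slice ops yields the native shape
    // [2, 2].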
    if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
      VectorType dstVec;
      for (Operation *user : readOp->getUsers()) {
        auto extract = dyn_cast<ExtractStridedSliceOp>(user);
        if (!extract)
          return llvm::None;
        auto vecType = extract.getResult().getType().cast<VectorType>();
        if (dstVec && dstVec != vecType)
          return llvm::None;
        dstVec = vecType;
      }
      // Bail out if the read has no extract_strided_slice users at all.
      if (!dstVec)
        return llvm::None;
      return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
                                     dstVec.getShape().end());
    }
    if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
      auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>();
      if (!insert)
        return llvm::None;
      ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
      return SmallVector<int64_t, 4>(shape.begin(), shape.end());
    }
    return llvm::None;
  }

  static LogicalResult filter(Operation *op) {
    return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp,
                       ContractionOp, TransferReadOp, TransferWriteOp>(op));
  }
};

struct TestVectorContractionLowering
    : public PassWrapper<TestVectorContractionLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorContractionLowering)

  StringRef getArgument() const final {
    return "test-vector-contraction-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns that lower contract ops in the vector "
           "dialect";
  }
  TestVectorContractionLowering() = default;
  TestVectorContractionLowering(const TestVectorContractionLowering &pass)
      : PassWrapper(pass) {}

  Option<bool> lowerToFlatMatrix{
      *this, "vector-lower-matrix-intrinsics",
      llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
      llvm::cl::init(false)};
  Option<bool> lowerToOuterProduct{
      *this, "vector-outerproduct",
      llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
      llvm::cl::init(false)};
  Option<bool> lowerToFilterOuterProduct{
      *this, "vector-filter-outerproduct",
      llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
                     "vectors of size 4."),
      llvm::cl::init(false)};
  Option<bool> lowerToParallelArith{
      *this, "vector-parallel-arith",
      llvm::cl::desc("Lower vector.contract to elementwise vector ops."),
      llvm::cl::init(false)};

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());

    // Test on one pattern in isolation.
    if (lowerToOuterProduct) {
      VectorContractLowering lowering = VectorContractLowering::OuterProduct;
      VectorTransformsOptions options{lowering};
      patterns.add<ContractionOpToOuterProductOpLowering>(options,
                                                          &getContext());
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
      return;
    }

    // Test on one pattern in isolation.
    if (lowerToFilterOuterProduct) {
      VectorContractLowering lowering = VectorContractLowering::OuterProduct;
      VectorTransformsOptions options{lowering};
      patterns.add<ContractionOpToOuterProductOpLowering>(
          options, &getContext(), [](vector::ContractionOp op) {
            // Only lower vector.contract ops whose rhs type does not have a
            // leading dimension of size 4.
            if (op.getRhsType().getShape()[0] == 4)
              return failure();
            return success();
          });
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
      return;
    }

    if (lowerToParallelArith) {
      vector::populateVectorContractLoweringPatterns(
          patterns,
          vector::VectorTransformsOptions().setVectorTransformsOptions(
              vector::VectorContractLowering::ParallelArith));
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
      return;
    }

    // Test on all contract lowering patterns.
    VectorContractLowering contractLowering = VectorContractLowering::Dot;
    if (lowerToFlatMatrix)
      contractLowering = VectorContractLowering::Matmul;
    VectorMultiReductionLowering vectorMultiReductionLowering =
        VectorMultiReductionLowering::InnerParallel;
    VectorTransformsOptions options{contractLowering,
                                    vectorMultiReductionLowering,
                                    VectorTransposeLowering()};
    populateVectorBroadcastLoweringPatterns(patterns);
    populateVectorContractLoweringPatterns(patterns, options);
    populateVectorMaskOpLoweringPatterns(patterns);
    populateVectorShapeCastLoweringPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorTransposeLowering
    : public PassWrapper<TestVectorTransposeLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransposeLowering)

  StringRef getArgument() const final {
    return "test-vector-transpose-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns that lower the vector.transpose op in the "
           "vector dialect";
  }
  TestVectorTransposeLowering() = default;
  TestVectorTransposeLowering(const TestVectorTransposeLowering &pass)
      : PassWrapper(pass) {}

  Option<bool> lowerToEltwise{
      *this, "eltwise",
      llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"),
      llvm::cl::init(false)};
  Option<bool> lowerToFlatTranspose{
      *this, "flat",
      llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
      llvm::cl::init(false)};
  Option<bool> lowerToShuffleTranspose{
      *this, "shuffle",
      llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
      llvm::cl::init(false)};
  Option<bool> lowerToAvx2{
      *this, "avx2",
      llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"),
      llvm::cl::init(false)};

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<LLVM::LLVMDialect>();
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());

    // Test on one pattern in isolation.
    // Explicitly disable shape_cast lowering.
    LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions()
                                              .enableVectorTransposeLowering()
                                              .enableShapeCastLowering(false);
    if (lowerToEltwise) {
      options = options.setVectorTransformsOptions(
          VectorTransformsOptions().setVectorTransposeLowering(
              VectorTransposeLowering::EltWise));
    }
    if (lowerToFlatTranspose) {
      options = options.setVectorTransformsOptions(
          VectorTransformsOptions().setVectorTransposeLowering(
              VectorTransposeLowering::Flat));
    }
    if (lowerToShuffleTranspose) {
      options = options.setVectorTransformsOptions(
          VectorTransformsOptions().setVectorTransposeLowering(
              VectorTransposeLowering::Shuffle));
    }
    if (lowerToAvx2) {
      options = options.enableAVX2Lowering().setAVX2LoweringOptions(
          x86vector::avx2::LoweringOptions().setTransposeOptions(
              x86vector::avx2::TransposeLoweringOptions()
                  .lower4x8xf32()
                  .lower8x8xf32()));
    }

    OpPassManager dynamicPM("func.func");
    dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options));
    if (failed(runPipeline(dynamicPM, getOperation())))
      return signalPassFailure();
  }
};

struct TestVectorUnrollingPatterns
    : public PassWrapper<TestVectorUnrollingPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns)

  StringRef getArgument() const final {
    return "test-vector-unrolling-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to unroll contract ops in the vector "
           "dialect";
  }
  TestVectorUnrollingPatterns() = default;
  TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass)
      : PassWrapper(pass) {}
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{2, 2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<arith::AddFOp, vector::FMAOp,
                                           vector::MultiDimReductionOp>(op));
                      }));
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<vector::ReductionOp>(op));
                      }));
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<vector::TransposeOp>(op));
                      }));

    if (unrollBasedOnType) {
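      // The native shape below unrolls contractions to [4, ..., 4, k], where
      // the innermost dimension k is 4 when the lhs element type is f16 and 2
      // otherwise.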
      UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
          [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
        vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
        SmallVector<int64_t, 4> nativeShape(
            contractOp.getIteratorTypes().size(), 4);
        Type lhsType = contractOp.getLhsType().getElementType();
        nativeShape[nativeShape.size() - 1] = lhsType.isF16() ? 4 : 2;
        return nativeShape;
      };

      UnrollVectorOptions opts;
      opts.setNativeShapeFn(nativeShapeFn)
          .setFilterConstraint(
              [](Operation *op) { return success(isa<ContractionOp>(op)); });

      if (!unrollOrder.empty()) {
        opts.setUnrollTraversalOrderFn([this](Operation *op)
                                           -> Optional<SmallVector<int64_t>> {
          vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
          if (contractOp.getIteratorTypes().size() == unrollOrder.size())
            return SmallVector<int64_t>(unrollOrder.begin(), unrollOrder.end());
          return None;
        });
      }
      populateVectorUnrollPatterns(patterns, opts);
    } else {
      auto nativeShapeFn =
          [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
        auto contractOp = dyn_cast<ContractionOp>(op);
        if (!contractOp)
          return None;
        return SmallVector<int64_t, 4>(contractOp.getIteratorTypes().size(), 2);
      };
      populateVectorUnrollPatterns(patterns,
                                   UnrollVectorOptions()
                                       .setNativeShapeFn(nativeShapeFn)
                                       .setFilterConstraint([](Operation *op) {
                                         return success(isa<ContractionOp>(op));
                                       }));
    }
    populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

  ListOption<int64_t> unrollOrder{*this, "unroll-order",
                                  llvm::cl::desc("set the unroll order"),
                                  llvm::cl::ZeroOrMore};

  Option<bool> unrollBasedOnType{
      *this, "unroll-based-on-type",
      llvm::cl::desc("Set the unroll factor based on type of the operation"),
      llvm::cl::init(false)};
};

struct TestVectorDistributePatterns
    : public PassWrapper<TestVectorDistributePatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistributePatterns)

  StringRef getArgument() const final {
    return "test-vector-distribute-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to distribute vector ops in the vector "
           "dialect";
  }
  TestVectorDistributePatterns() = default;
  TestVectorDistributePatterns(const TestVectorDistributePatterns &pass)
      : PassWrapper(pass) {}
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<VectorDialect>();
    registry.insert<AffineDialect>();
  }
  ListOption<int32_t> multiplicity{
      *this, "distribution-multiplicity",
      llvm::cl::desc("Set the multiplicity used for distributing vector")};

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    func::FuncOp func = getOperation();
    func.walk([&](arith::AddFOp op) {
      OpBuilder builder(op);
      if (auto vecType = op.getType().dyn_cast<VectorType>()) {
        SmallVector<int64_t, 2> mul;
        SmallVector<AffineExpr, 2> perm;
        SmallVector<Value, 2> ids;
        unsigned count = 0;
        // Skip dimensions with multiplicity 1 and build the affine map from
        // the remaining multiplicities.
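        // For example (hypothetical values): for a vector<32x64xf32> result
        // with multiplicity = {2, 1}, only dimension 0 is distributed, giving
        // mul = [2], one distribution id, and the map (d0, d1) -> (d0).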
        SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
        for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
          if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) {
            mul.push_back(m[i]);
            ids.push_back(func.getArgument(count++));
            perm.push_back(getAffineDimExpr(i, ctx));
          }
        }
        auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0,
                                  perm, ctx);
        Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
            builder, op.getOperation(), ids, mul, map);
        if (ops.hasValue()) {
          SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
          op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
                                              extractOp);
        }
      }
    });
    populatePropagateVectorDistributionPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorToLoopPatterns
    : public PassWrapper<TestVectorToLoopPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToLoopPatterns)

  StringRef getArgument() const final { return "test-vector-to-forloop"; }
  StringRef getDescription() const final {
    return "Test lowering patterns to break up a vector op into a for loop";
  }
  TestVectorToLoopPatterns() = default;
  TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass)
      : PassWrapper(pass) {}
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<VectorDialect>();
    registry.insert<AffineDialect>();
  }
  Option<int32_t> multiplicity{
      *this, "distribution-multiplicity",
      llvm::cl::desc("Set the multiplicity used for distributing vector"),
      llvm::cl::init(32)};
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    func::FuncOp func = getOperation();
    func.walk([&](arith::AddFOp op) {
      // Check that the operation type can be broken down into a loop.
      VectorType type = op.getType().dyn_cast<VectorType>();
      if (!type || type.getRank() != 1 ||
          type.getNumElements() % multiplicity != 0)
        return mlir::WalkResult::advance();
      auto filterAlloc = [](Operation *op) {
        return !isa<arith::ConstantOp, memref::AllocOp, func::CallOp>(op);
      };
      auto dependentOps = getSlice(op, filterAlloc);
      // Create a loop and move instructions from the Op slice into the loop.
      OpBuilder builder(op);
      auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0);
      auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1);
      auto numIter =
          builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity);
      auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
      for (Operation *it : dependentOps) {
        it->moveBefore(forOp.getBody()->getTerminator());
      }
      auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
      // Break up the original op and let the patterns propagate.
      Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
          builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity},
          map);
      if (ops.hasValue()) {
        SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
        op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
      }
      return mlir::WalkResult::interrupt();
    });
    populatePropagateVectorDistributionPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorTransferUnrollingPatterns
    : public PassWrapper<TestVectorTransferUnrollingPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferUnrollingPatterns)

  TestVectorTransferUnrollingPatterns() = default;
  TestVectorTransferUnrollingPatterns(
      const TestVectorTransferUnrollingPatterns &pass)
      : PassWrapper(pass) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AffineDialect>();
  }
  StringRef getArgument() const final {
    return "test-vector-transfer-unrolling-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to unroll transfer ops in the vector "
           "dialect";
  }
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    UnrollVectorOptions opts;
    opts.setNativeShape(ArrayRef<int64_t>{2, 2})
        .setFilterConstraint([](Operation *op) {
          return success(
              isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
        });
    if (reverseUnrollOrder.getValue()) {
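      // Unroll the transfer dimensions from innermost to outermost, e.g. a
      // rank-2 transfer is traversed in the order [1, 0].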
      opts.setUnrollTraversalOrderFn(
          [](Operation *op) -> Optional<SmallVector<int64_t>> {
            int64_t numLoops = 0;
            if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
              numLoops = readOp.getVectorType().getRank();
            else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op))
              numLoops = writeOp.getVectorType().getRank();
            else
              return None;
            auto order = llvm::reverse(llvm::seq<int64_t>(0, numLoops));
            return llvm::to_vector(order);
          });
    }
    populateVectorUnrollPatterns(patterns, opts);
    populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

  Option<bool> reverseUnrollOrder{
      *this, "reverse-unroll-order",
      llvm::cl::desc(
          "reverse the order of unrolling of vector transfer operations"),
      llvm::cl::init(false)};
};

struct TestVectorTransferFullPartialSplitPatterns
    : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferFullPartialSplitPatterns)

  StringRef getArgument() const final {
    return "test-vector-transfer-full-partial-split";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to split "
           "transfer ops via scf.if + linalg ops";
  }
  TestVectorTransferFullPartialSplitPatterns() = default;
  TestVectorTransferFullPartialSplitPatterns(
      const TestVectorTransferFullPartialSplitPatterns &pass)
      : PassWrapper(pass) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
                    scf::SCFDialect>();
  }

  Option<bool> useLinalgOps{
      *this, "use-memref-copy",
      llvm::cl::desc("Split using unmasked vector.transfer, linalg.fill, and "
                     "memref.copy operations."),
      llvm::cl::init(false)};
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    VectorTransformsOptions options;
    if (useLinalgOps)
      options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
    else
      options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
    patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorTransferOpt
    : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)

  StringRef getArgument() const final { return "test-vector-transferop-opt"; }
  StringRef getDescription() const final {
    return "Test optimization transformations for transfer ops";
  }
  void runOnOperation() override { transferOpflowOpt(getOperation()); }
};

struct TestVectorTransferLoweringPatterns
    : public PassWrapper<TestVectorTransferLoweringPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferLoweringPatterns)

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
  }
  StringRef getArgument() const final {
    return "test-vector-transfer-lowering-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to lower transfer ops to other vector ops";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorTransferLoweringPatterns(patterns);
    populateVectorTransferPermutationMapLoweringPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorMultiReductionLoweringPatterns
    : public PassWrapper<TestVectorMultiReductionLoweringPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorMultiReductionLoweringPatterns)

  TestVectorMultiReductionLoweringPatterns() = default;
  TestVectorMultiReductionLoweringPatterns(
      const TestVectorMultiReductionLoweringPatterns &pass)
      : PassWrapper(pass) {}
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  StringRef getArgument() const final {
    return "test-vector-multi-reduction-lowering-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to lower vector.multi_reduction to other "
           "vector ops";
  }
  Option<bool> useOuterReductions{
      *this, "use-outer-reductions",
      llvm::cl::desc("Move reductions to the outermost dimensions"),
      llvm::cl::init(false)};
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorMultiReductionLoweringPatterns(
        patterns, useOuterReductions
                      ? vector::VectorMultiReductionLowering::InnerParallel
                      : vector::VectorMultiReductionLowering::InnerReduction);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorTransferCollapseInnerMostContiguousDims
    : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferCollapseInnerMostContiguousDims)

  TestVectorTransferCollapseInnerMostContiguousDims() = default;
  TestVectorTransferCollapseInnerMostContiguousDims(
      const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, AffineDialect>();
  }

  StringRef getArgument() const final {
    return "test-vector-transfer-collapse-inner-most-dims";
  }

  StringRef getDescription() const final {
    return "Test lowering patterns that reduce the rank of the vector "
           "transfer memory and vector operands.";
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorReduceToContractPatternsPatterns
    : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorReduceToContractPatternsPatterns)

  StringRef getArgument() const final {
    return "test-vector-reduction-to-contract-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns to convert vector.multi_reduction ops to contract "
           "ops and to combine broadcast/transpose ops into contract ops";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorReductionToContractPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorTransferDropUnitDimsPatterns
    : public PassWrapper<TestVectorTransferDropUnitDimsPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferDropUnitDimsPatterns)

  StringRef getArgument() const final {
    return "test-vector-transfer-drop-unit-dims-patterns";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorTransferDropUnitDimsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestFlattenVectorTransferPatterns
    : public PassWrapper<TestFlattenVectorTransferPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestFlattenVectorTransferPatterns)

  StringRef getArgument() const final {
    return "test-vector-transfer-flatten-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns to rewrite contiguous row-major N-dimensional "
           "vector.transfer_{read,write} ops into 1D transfers";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateFlattenVectorTransferPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorScanLowering
    : public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering)

  StringRef getArgument() const final { return "test-vector-scan-lowering"; }
  StringRef getDescription() const final {
    return "Test lowering patterns that lower the scan op in the vector "
           "dialect";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorScanLoweringPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

/// Allocate shared memory for a single warp to test lowering of
/// WarpExecuteOnLane0Op.
static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
                                        WarpExecuteOnLane0Op warpOp,
                                        Type type) {
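  // Note: memory space 3 is conventionally the GPU shared (workgroup) memory
  // address space.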
  static constexpr int64_t kSharedMemorySpace = 3;
  // Compute type of shared memory buffer.
  MemRefType memrefType;
  if (auto vectorType = type.dyn_cast<VectorType>()) {
    memrefType =
        MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {},
                        kSharedMemorySpace);
  } else {
    memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace);
  }

  // Get symbol table holding all shared memory globals.
  ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>();
  SymbolTable symbolTable(moduleOp);

  // Create a pretty name.
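  // The name encodes the buffer shape and element type, e.g. (hypothetically)
  // "__shared_16x16xf32" for a memref<16x16xf32> buffer.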
  SmallString<64> buf;
  llvm::raw_svector_ostream os(buf);
  interleave(memrefType.getShape(), os, "x");
  os << "x" << memrefType.getElementType();
  std::string symbolName = (Twine("__shared_") + os.str()).str();

  auto ip = builder.saveInsertionPoint();
  builder.setInsertionPoint(moduleOp);
  auto global = builder.create<memref::GlobalOp>(
      loc,
      /*sym_name=*/symbolName,
      /*sym_visibility=*/builder.getStringAttr("private"),
      /*type=*/memrefType,
      /*initial_value=*/Attribute(),
      /*constant=*/false,
      /*alignment=*/IntegerAttr());
  symbolTable.insert(global);
  // The symbol table inserts at the end of the module, but globals are a bit
  // nicer if they are at the beginning.
  global->moveBefore(&moduleOp.front());

  builder.restoreInsertionPoint(ip);
  return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName);
}

static Value warpReduction(Location loc, OpBuilder &builder, Value input,
                           CombiningKind kind, uint32_t size) {
  Value laneVal = input;
  // Parallel reduction using butterfly shuffles.
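  // Illustrative example (assuming a hypothetical warp of size 4): the XOR
  // offsets are 1 and 2, so lanes first combine in pairs and then across
  // pairs, and every lane holds the full reduction after log2(size) steps.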
  for (uint64_t i = 1; i < size; i <<= 1) {
    Value shuffled = builder
                         .create<gpu::ShuffleOp>(loc, laneVal, i,
                                                 /*width=*/size,
                                                 /*mode=*/gpu::ShuffleMode::XOR)
                         .result();
    laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
  }
  return laneVal;
}

struct TestVectorDistribution
    : public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution)

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect,
                    AffineDialect>();
  }

  StringRef getArgument() const final { return "test-vector-warp-distribute"; }
  StringRef getDescription() const final {
    return "Test vector warp distribute transformation and lowering patterns";
  }
  TestVectorDistribution() = default;
  TestVectorDistribution(const TestVectorDistribution &pass)
      : PassWrapper(pass) {}

  Option<bool> warpOpToSCF{
      *this, "rewrite-warp-ops-to-scf-if",
      llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"),
      llvm::cl::init(false)};

  Option<bool> distributeTransferWriteOps{
      *this, "distribute-transfer-write",
      llvm::cl::desc("Test distribution of transfer write"),
      llvm::cl::init(false)};

  Option<bool> hoistUniform{*this, "hoist-uniform",
                            llvm::cl::desc("Test hoist uniform"),
                            llvm::cl::init(false)};

  Option<bool> propagateDistribution{
      *this, "propagate-distribution",
      llvm::cl::desc("Test distribution propagation"), llvm::cl::init(false)};

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());

    getOperation().walk([&](Operation *op) {
      if (auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(op)) {
        if (hoistUniform) {
          moveScalarUniformCode(warpOp);
        }
        // Stop walking once the first warp op has been handled.
        return WalkResult::interrupt();
      }
      return WalkResult::advance();
    });
    MLIRContext *ctx = &getContext();
    if (distributeTransferWriteOps) {
      auto distributionFn = [](vector::TransferWriteOp writeOp) {
        // Create a map (d0, d1) -> (d1) to distribute along the inner
        // dimension. Once we support n-d distribution we can add more
        // complex cases.
        int64_t vecRank = writeOp.getVectorType().getRank();
        OpBuilder builder(writeOp.getContext());
        auto map =
            AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
        return map;
      };
      RewritePatternSet patterns(ctx);
      populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
    }
    if (propagateDistribution) {
      RewritePatternSet patterns(ctx);
      vector::populatePropagateWarpVectorDistributionPatterns(patterns);
      vector::populateDistributeReduction(patterns, warpReduction);
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
    }
    WarpExecuteOnLane0LoweringOptions options;
    options.warpAllocationFn = allocateGlobalSharedMemory;
    options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,
                                      WarpExecuteOnLane0Op warpOp) {
      builder.create<gpu::BarrierOp>(loc);
    };
    // Test on one pattern in isolation.
    if (warpOpToSCF) {
      populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
      return;
    }
  }
};

} // namespace

namespace mlir {
namespace test {
void registerTestVectorLowerings() {
  PassRegistration<TestVectorToVectorLowering>();

  PassRegistration<TestVectorContractionLowering>();

  PassRegistration<TestVectorTransposeLowering>();

  PassRegistration<TestVectorUnrollingPatterns>();

  PassRegistration<TestVectorTransferUnrollingPatterns>();

  PassRegistration<TestVectorTransferFullPartialSplitPatterns>();

  PassRegistration<TestVectorDistributePatterns>();

  PassRegistration<TestVectorToLoopPatterns>();

  PassRegistration<TestVectorTransferOpt>();

  PassRegistration<TestVectorTransferLoweringPatterns>();

  PassRegistration<TestVectorMultiReductionLoweringPatterns>();

  PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();

  PassRegistration<TestVectorReduceToContractPatternsPatterns>();

  PassRegistration<TestVectorTransferDropUnitDimsPatterns>();

  PassRegistration<TestFlattenVectorTransferPatterns>();

  PassRegistration<TestVectorScanLowering>();

  PassRegistration<TestVectorDistribution>();
}
} // namespace test
} // namespace mlir