1 //===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <type_traits>
10 
11 #include "mlir/Analysis/SliceAnalysis.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/Dialect/Linalg/IR/Linalg.h"
17 #include "mlir/Dialect/Linalg/Passes.h"
18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/SCF/IR/SCF.h"
21 #include "mlir/Dialect/Vector/IR/VectorOps.h"
22 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
23 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Pass/PassManager.h"
26 #include "mlir/Support/LLVM.h"
27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28 
29 using namespace mlir;
30 using namespace mlir::linalg;
31 using namespace mlir::vector;
32 
33 namespace {
34 
35 struct TestVectorToVectorLowering
36     : public PassWrapper<TestVectorToVectorLowering,
37                          OperationPass<func::FuncOp>> {
38   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering)
39 
40   TestVectorToVectorLowering() = default;
TestVectorToVectorLowering__anonb56dea510111::TestVectorToVectorLowering41   TestVectorToVectorLowering(const TestVectorToVectorLowering &pass)
42       : PassWrapper(pass) {}
getArgument__anonb56dea510111::TestVectorToVectorLowering43   StringRef getArgument() const final {
44     return "test-vector-to-vector-lowering";
45   }
getDescription__anonb56dea510111::TestVectorToVectorLowering46   StringRef getDescription() const final {
47     return "Test lowering patterns between ops in the vector dialect";
48   }
49 
getDependentDialects__anonb56dea510111::TestVectorToVectorLowering50   void getDependentDialects(DialectRegistry &registry) const override {
51     registry.insert<AffineDialect>();
52   }
53 
54   Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
55                       llvm::cl::init(false)};
56 
runOnOperation__anonb56dea510111::TestVectorToVectorLowering57   void runOnOperation() override {
58     auto *ctx = &getContext();
59     RewritePatternSet patterns(ctx);
60     if (unroll) {
61       populateVectorUnrollPatterns(
62           patterns,
63           UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
64               filter));
65     }
66     populateVectorToVectorCanonicalizationPatterns(patterns);
67     populateBubbleVectorBitCastOpPatterns(patterns);
68     populateCastAwayVectorLeadingOneDimPatterns(patterns);
69     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
70   }
71 
72 private:
73   // Return the target shape based on op type.
getShape__anonb56dea510111::TestVectorToVectorLowering74   static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
75     if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op))
76       return SmallVector<int64_t, 4>(2, 2);
77     if (isa<vector::ContractionOp>(op))
78       return SmallVector<int64_t, 4>(3, 2);
79     // For transfer ops, just propagate the shape coming from
80     // InsertStridedSlices/ExtractStridedSlices.
81     if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
82       VectorType dstVec;
83       for (Operation *users : readOp->getUsers()) {
84         auto extract = dyn_cast<ExtractStridedSliceOp>(users);
85         if (!extract)
86           return llvm::None;
87         auto vecType = extract.getResult().getType().cast<VectorType>();
88         if (dstVec && dstVec != vecType)
89           return llvm::None;
90         dstVec = vecType;
91       }
92       return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
93                                      dstVec.getShape().end());
94     }
95     if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
96       auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>();
97       if (!insert)
98         return llvm::None;
99       ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
100       return SmallVector<int64_t, 4>(shape.begin(), shape.end());
101     }
102     return llvm::None;
103   }
104 
filter__anonb56dea510111::TestVectorToVectorLowering105   static LogicalResult filter(Operation *op) {
106     return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp,
107                        ContractionOp, TransferReadOp, TransferWriteOp>(op));
108   }
109 };
110 
111 struct TestVectorContractionLowering
112     : public PassWrapper<TestVectorContractionLowering,
113                          OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorContractionLowering114   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorContractionLowering)
115 
116   StringRef getArgument() const final {
117     return "test-vector-contraction-lowering";
118   }
getDescription__anonb56dea510111::TestVectorContractionLowering119   StringRef getDescription() const final {
120     return "Test lowering patterns that lower contract ops in the vector "
121            "dialect";
122   }
123   TestVectorContractionLowering() = default;
TestVectorContractionLowering__anonb56dea510111::TestVectorContractionLowering124   TestVectorContractionLowering(const TestVectorContractionLowering &pass)
125       : PassWrapper(pass) {}
126 
127   Option<bool> lowerToFlatMatrix{
128       *this, "vector-lower-matrix-intrinsics",
129       llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
130       llvm::cl::init(false)};
131   Option<bool> lowerToOuterProduct{
132       *this, "vector-outerproduct",
133       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
134       llvm::cl::init(false)};
135   Option<bool> lowerToFilterOuterProduct{
136       *this, "vector-filter-outerproduct",
137       llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
138                      "vectors of size 4."),
139       llvm::cl::init(false)};
140   Option<bool> lowerToParallelArith{
141       *this, "vector-parallel-arith",
142       llvm::cl::desc("Lower vector.contract to elementwise vector ops."),
143       llvm::cl::init(false)};
144 
runOnOperation__anonb56dea510111::TestVectorContractionLowering145   void runOnOperation() override {
146     RewritePatternSet patterns(&getContext());
147 
148     // Test on one pattern in isolation.
149     if (lowerToOuterProduct) {
150       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
151       VectorTransformsOptions options{lowering};
152       patterns.add<ContractionOpToOuterProductOpLowering>(options,
153                                                           &getContext());
154       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
155       return;
156     }
157 
158     // Test on one pattern in isolation.
159     if (lowerToFilterOuterProduct) {
160       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
161       VectorTransformsOptions options{lowering};
162       patterns.add<ContractionOpToOuterProductOpLowering>(
163           options, &getContext(), [](vector::ContractionOp op) {
164             // Only lowers vector.contract where the lhs as a type vector<MxNx?>
165             // where M is not 4.
166             if (op.getRhsType().getShape()[0] == 4)
167               return failure();
168             return success();
169           });
170       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
171       return;
172     }
173 
174     if (lowerToParallelArith) {
175       vector::populateVectorContractLoweringPatterns(
176           patterns,
177           vector::VectorTransformsOptions().setVectorTransformsOptions(
178               vector::VectorContractLowering::ParallelArith));
179       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
180       return;
181     }
182 
183     // Test on all contract lowering patterns.
184     VectorContractLowering contractLowering = VectorContractLowering::Dot;
185     if (lowerToFlatMatrix)
186       contractLowering = VectorContractLowering::Matmul;
187     VectorMultiReductionLowering vectorMultiReductionLowering =
188         VectorMultiReductionLowering::InnerParallel;
189     VectorTransformsOptions options{contractLowering,
190                                     vectorMultiReductionLowering,
191                                     VectorTransposeLowering()};
192     populateVectorBroadcastLoweringPatterns(patterns);
193     populateVectorContractLoweringPatterns(patterns, options);
194     populateVectorMaskOpLoweringPatterns(patterns);
195     populateVectorShapeCastLoweringPatterns(patterns);
196     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
197   }
198 };
199 
200 struct TestVectorTransposeLowering
201     : public PassWrapper<TestVectorTransposeLowering,
202                          OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorTransposeLowering203   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransposeLowering)
204 
205   StringRef getArgument() const final {
206     return "test-vector-transpose-lowering";
207   }
getDescription__anonb56dea510111::TestVectorTransposeLowering208   StringRef getDescription() const final {
209     return "Test lowering patterns that lower contract ops in the vector "
210            "dialect";
211   }
212   TestVectorTransposeLowering() = default;
TestVectorTransposeLowering__anonb56dea510111::TestVectorTransposeLowering213   TestVectorTransposeLowering(const TestVectorTransposeLowering &pass)
214       : PassWrapper(pass) {}
215 
216   Option<bool> lowerToEltwise{
217       *this, "eltwise",
218       llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"),
219       llvm::cl::init(false)};
220   Option<bool> lowerToFlatTranspose{
221       *this, "flat",
222       llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
223       llvm::cl::init(false)};
224   Option<bool> lowerToShuffleTranspose{
225       *this, "shuffle",
226       llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
227       llvm::cl::init(false)};
228   Option<bool> lowerToAvx2{
229       *this, "avx2",
230       llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"),
231       llvm::cl::init(false)};
232 
getDependentDialects__anonb56dea510111::TestVectorTransposeLowering233   void getDependentDialects(DialectRegistry &registry) const override {
234     registry.insert<LLVM::LLVMDialect>();
235   }
236 
runOnOperation__anonb56dea510111::TestVectorTransposeLowering237   void runOnOperation() override {
238     RewritePatternSet patterns(&getContext());
239 
240     // Test on one pattern in isolation.
241     // Explicitly disable shape_cast lowering.
242     LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions()
243                                               .enableVectorTransposeLowering()
244                                               .enableShapeCastLowering(false);
245     if (lowerToEltwise) {
246       options = options.setVectorTransformsOptions(
247           VectorTransformsOptions().setVectorTransposeLowering(
248               VectorTransposeLowering::EltWise));
249     }
250     if (lowerToFlatTranspose) {
251       options = options.setVectorTransformsOptions(
252           VectorTransformsOptions().setVectorTransposeLowering(
253               VectorTransposeLowering::Flat));
254     }
255     if (lowerToShuffleTranspose) {
256       options = options.setVectorTransformsOptions(
257           VectorTransformsOptions().setVectorTransposeLowering(
258               VectorTransposeLowering::Shuffle));
259     }
260     if (lowerToAvx2) {
261       options = options.enableAVX2Lowering().setAVX2LoweringOptions(
262           x86vector::avx2::LoweringOptions().setTransposeOptions(
263               x86vector::avx2::TransposeLoweringOptions()
264                   .lower4x8xf32()
265                   .lower8x8xf32()));
266     }
267 
268     OpPassManager dynamicPM("func.func");
269     dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options));
270     if (failed(runPipeline(dynamicPM, getOperation())))
271       return signalPassFailure();
272   }
273 };
274 
275 struct TestVectorUnrollingPatterns
276     : public PassWrapper<TestVectorUnrollingPatterns,
277                          OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorUnrollingPatterns278   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns)
279 
280   StringRef getArgument() const final {
281     return "test-vector-unrolling-patterns";
282   }
getDescription__anonb56dea510111::TestVectorUnrollingPatterns283   StringRef getDescription() const final {
284     return "Test lowering patterns to unroll contract ops in the vector "
285            "dialect";
286   }
287   TestVectorUnrollingPatterns() = default;
TestVectorUnrollingPatterns__anonb56dea510111::TestVectorUnrollingPatterns288   TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass)
289       : PassWrapper(pass) {}
runOnOperation__anonb56dea510111::TestVectorUnrollingPatterns290   void runOnOperation() override {
291     MLIRContext *ctx = &getContext();
292     RewritePatternSet patterns(ctx);
293     populateVectorUnrollPatterns(
294         patterns, UnrollVectorOptions()
295                       .setNativeShape(ArrayRef<int64_t>{2, 2})
296                       .setFilterConstraint([](Operation *op) {
297                         return success(isa<arith::AddFOp, vector::FMAOp,
298                                            vector::MultiDimReductionOp>(op));
299                       }));
300     populateVectorUnrollPatterns(
301         patterns, UnrollVectorOptions()
302                       .setNativeShape(ArrayRef<int64_t>{2})
303                       .setFilterConstraint([](Operation *op) {
304                         return success(isa<vector::ReductionOp>(op));
305                       }));
306     populateVectorUnrollPatterns(
307         patterns, UnrollVectorOptions()
308                       .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
309                       .setFilterConstraint([](Operation *op) {
310                         return success(isa<vector::TransposeOp>(op));
311                       }));
312 
313     if (unrollBasedOnType) {
314       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
315           [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
316         vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
317         SmallVector<int64_t, 4> nativeShape(
318             contractOp.getIteratorTypes().size(), 4);
319         Type lhsType = contractOp.getLhsType().getElementType();
320         nativeShape[nativeShape.size() - 1] = lhsType.isF16() ? 4 : 2;
321         return nativeShape;
322       };
323 
324       UnrollVectorOptions opts;
325       opts.setNativeShapeFn(nativeShapeFn)
326           .setFilterConstraint(
327               [](Operation *op) { return success(isa<ContractionOp>(op)); });
328 
329       if (!unrollOrder.empty()) {
330         opts.setUnrollTraversalOrderFn([this](Operation *op)
331                                            -> Optional<SmallVector<int64_t>> {
332           vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
333           if (contractOp.getIteratorTypes().size() == unrollOrder.size())
334             return SmallVector<int64_t>(unrollOrder.begin(), unrollOrder.end());
335           return None;
336         });
337       }
338       populateVectorUnrollPatterns(patterns, opts);
339     } else {
340       auto nativeShapeFn =
341           [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
342         auto contractOp = dyn_cast<ContractionOp>(op);
343         if (!contractOp)
344           return None;
345         return SmallVector<int64_t, 4>(contractOp.getIteratorTypes().size(), 2);
346       };
347       populateVectorUnrollPatterns(patterns,
348                                    UnrollVectorOptions()
349                                        .setNativeShapeFn(nativeShapeFn)
350                                        .setFilterConstraint([](Operation *op) {
351                                          return success(isa<ContractionOp>(op));
352                                        }));
353     }
354     populateVectorToVectorCanonicalizationPatterns(patterns);
355     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
356   }
357 
358   ListOption<int64_t> unrollOrder{*this, "unroll-order",
359                                   llvm::cl::desc("set the unroll order")};
360 
361   Option<bool> unrollBasedOnType{
362       *this, "unroll-based-on-type",
363       llvm::cl::desc("Set the unroll factor based on type of the operation"),
364       llvm::cl::init(false)};
365 };
366 
367 struct TestVectorDistributePatterns
368     : public PassWrapper<TestVectorDistributePatterns,
369                          OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorDistributePatterns370   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistributePatterns)
371 
372   StringRef getArgument() const final {
373     return "test-vector-distribute-patterns";
374   }
getDescription__anonb56dea510111::TestVectorDistributePatterns375   StringRef getDescription() const final {
376     return "Test lowering patterns to distribute vector ops in the vector "
377            "dialect";
378   }
379   TestVectorDistributePatterns() = default;
TestVectorDistributePatterns__anonb56dea510111::TestVectorDistributePatterns380   TestVectorDistributePatterns(const TestVectorDistributePatterns &pass)
381       : PassWrapper(pass) {}
getDependentDialects__anonb56dea510111::TestVectorDistributePatterns382   void getDependentDialects(DialectRegistry &registry) const override {
383     registry.insert<VectorDialect>();
384     registry.insert<AffineDialect>();
385   }
386   ListOption<int32_t> multiplicity{
387       *this, "distribution-multiplicity",
388       llvm::cl::desc("Set the multiplicity used for distributing vector")};
389 
runOnOperation__anonb56dea510111::TestVectorDistributePatterns390   void runOnOperation() override {
391     MLIRContext *ctx = &getContext();
392     RewritePatternSet patterns(ctx);
393     func::FuncOp func = getOperation();
394     func.walk([&](arith::AddFOp op) {
395       OpBuilder builder(op);
396       if (auto vecType = op.getType().dyn_cast<VectorType>()) {
397         SmallVector<int64_t, 2> mul;
398         SmallVector<AffineExpr, 2> perm;
399         SmallVector<Value, 2> ids;
400         unsigned count = 0;
401         // Remove the multiplicity of 1 and calculate the affine map based on
402         // the multiplicity.
403         SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
404         for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
405           if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) {
406             mul.push_back(m[i]);
407             ids.push_back(func.getArgument(count++));
408             perm.push_back(getAffineDimExpr(i, ctx));
409           }
410         }
411         auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0,
412                                   perm, ctx);
413         Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
414             builder, op.getOperation(), ids, mul, map);
415         if (ops) {
416           SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
417           op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
418                                               extractOp);
419         }
420       }
421     });
422     populatePropagateVectorDistributionPatterns(patterns);
423     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
424   }
425 };
426 
427 struct TestVectorToLoopPatterns
428     : public PassWrapper<TestVectorToLoopPatterns,
429                          OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorToLoopPatterns430   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToLoopPatterns)
431 
432   StringRef getArgument() const final { return "test-vector-to-forloop"; }
getDescription__anonb56dea510111::TestVectorToLoopPatterns433   StringRef getDescription() const final {
434     return "Test lowering patterns to break up a vector op into a for loop";
435   }
436   TestVectorToLoopPatterns() = default;
TestVectorToLoopPatterns__anonb56dea510111::TestVectorToLoopPatterns437   TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass)
438       : PassWrapper(pass) {}
getDependentDialects__anonb56dea510111::TestVectorToLoopPatterns439   void getDependentDialects(DialectRegistry &registry) const override {
440     registry.insert<VectorDialect>();
441     registry.insert<AffineDialect>();
442   }
443   Option<int32_t> multiplicity{
444       *this, "distribution-multiplicity",
445       llvm::cl::desc("Set the multiplicity used for distributing vector"),
446       llvm::cl::init(32)};
runOnOperation__anonb56dea510111::TestVectorToLoopPatterns447   void runOnOperation() override {
448     MLIRContext *ctx = &getContext();
449     RewritePatternSet patterns(ctx);
450     func::FuncOp func = getOperation();
451     func.walk([&](arith::AddFOp op) {
452       // Check that the operation type can be broken down into a loop.
453       VectorType type = op.getType().dyn_cast<VectorType>();
454       if (!type || type.getRank() != 1 ||
455           type.getNumElements() % multiplicity != 0)
456         return mlir::WalkResult::advance();
457       auto filterAlloc = [](Operation *op) {
458         return !isa<arith::ConstantOp, memref::AllocOp, func::CallOp>(op);
459       };
460       auto dependentOps = getSlice(op, filterAlloc);
461       // Create a loop and move instructions from the Op slice into the loop.
462       OpBuilder builder(op);
463       auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0);
464       auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1);
465       auto numIter =
466           builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity);
467       auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
468       for (Operation *it : dependentOps) {
469         it->moveBefore(forOp.getBody()->getTerminator());
470       }
471       auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
472       // break up the original op and let the patterns propagate.
473       Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
474           builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity},
475           map);
476       if (ops) {
477         SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
478         op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
479       }
480       return mlir::WalkResult::interrupt();
481     });
482     populatePropagateVectorDistributionPatterns(patterns);
483     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
484   }
485 };
486 
487 struct TestVectorTransferUnrollingPatterns
488     : public PassWrapper<TestVectorTransferUnrollingPatterns,
489                          OperationPass<func::FuncOp>> {
490   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
491       TestVectorTransferUnrollingPatterns)
492 
493   TestVectorTransferUnrollingPatterns() = default;
TestVectorTransferUnrollingPatterns__anonb56dea510111::TestVectorTransferUnrollingPatterns494   TestVectorTransferUnrollingPatterns(
495       const TestVectorTransferUnrollingPatterns &pass)
496       : PassWrapper(pass) {}
497 
getDependentDialects__anonb56dea510111::TestVectorTransferUnrollingPatterns498   void getDependentDialects(DialectRegistry &registry) const override {
499     registry.insert<AffineDialect>();
500   }
getArgument__anonb56dea510111::TestVectorTransferUnrollingPatterns501   StringRef getArgument() const final {
502     return "test-vector-transfer-unrolling-patterns";
503   }
getDescription__anonb56dea510111::TestVectorTransferUnrollingPatterns504   StringRef getDescription() const final {
505     return "Test lowering patterns to unroll transfer ops in the vector "
506            "dialect";
507   }
runOnOperation__anonb56dea510111::TestVectorTransferUnrollingPatterns508   void runOnOperation() override {
509     MLIRContext *ctx = &getContext();
510     RewritePatternSet patterns(ctx);
511     UnrollVectorOptions opts;
512     opts.setNativeShape(ArrayRef<int64_t>{2, 2})
513         .setFilterConstraint([](Operation *op) {
514           return success(
515               isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
516         });
517     if (reverseUnrollOrder.getValue()) {
518       opts.setUnrollTraversalOrderFn(
519           [](Operation *op) -> Optional<SmallVector<int64_t>> {
520             int64_t numLoops = 0;
521             if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
522               numLoops = readOp.getVectorType().getRank();
523             else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op))
524               numLoops = writeOp.getVectorType().getRank();
525             else
526               return None;
527             auto order = llvm::reverse(llvm::seq<int64_t>(0, numLoops));
528             return llvm::to_vector(order);
529           });
530     }
531     populateVectorUnrollPatterns(patterns, opts);
532     populateVectorToVectorCanonicalizationPatterns(patterns);
533     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
534   }
535 
536   Option<bool> reverseUnrollOrder{
537       *this, "reverse-unroll-order",
538       llvm::cl::desc(
539           "reverse the order of unrolling of vector transfer operations"),
540       llvm::cl::init(false)};
541 };
542 
543 struct TestVectorTransferFullPartialSplitPatterns
544     : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
545                          OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorTransferFullPartialSplitPatterns546   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
547       TestVectorTransferFullPartialSplitPatterns)
548 
549   StringRef getArgument() const final {
550     return "test-vector-transfer-full-partial-split";
551   }
getDescription__anonb56dea510111::TestVectorTransferFullPartialSplitPatterns552   StringRef getDescription() const final {
553     return "Test lowering patterns to split "
554            "transfer ops via scf.if + linalg ops";
555   }
556   TestVectorTransferFullPartialSplitPatterns() = default;
TestVectorTransferFullPartialSplitPatterns__anonb56dea510111::TestVectorTransferFullPartialSplitPatterns557   TestVectorTransferFullPartialSplitPatterns(
558       const TestVectorTransferFullPartialSplitPatterns &pass)
559       : PassWrapper(pass) {}
560 
getDependentDialects__anonb56dea510111::TestVectorTransferFullPartialSplitPatterns561   void getDependentDialects(DialectRegistry &registry) const override {
562     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
563                     scf::SCFDialect>();
564   }
565 
566   Option<bool> useLinalgOps{
567       *this, "use-memref-copy",
568       llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
569                      "memref.copy operations."),
570       llvm::cl::init(false)};
runOnOperation__anonb56dea510111::TestVectorTransferFullPartialSplitPatterns571   void runOnOperation() override {
572     MLIRContext *ctx = &getContext();
573     RewritePatternSet patterns(ctx);
574     VectorTransformsOptions options;
575     if (useLinalgOps)
576       options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
577     else
578       options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
579     patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
580     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
581   }
582 };
583 
584 struct TestVectorTransferOpt
585     : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorTransferOpt586   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)
587 
588   StringRef getArgument() const final { return "test-vector-transferop-opt"; }
getDescription__anonb56dea510111::TestVectorTransferOpt589   StringRef getDescription() const final {
590     return "Test optimization transformations for transfer ops";
591   }
runOnOperation__anonb56dea510111::TestVectorTransferOpt592   void runOnOperation() override { transferOpflowOpt(getOperation()); }
593 };
594 
595 struct TestVectorTransferLoweringPatterns
596     : public PassWrapper<TestVectorTransferLoweringPatterns,
597                          OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorTransferLoweringPatterns598   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
599       TestVectorTransferLoweringPatterns)
600 
601   void getDependentDialects(DialectRegistry &registry) const override {
602     registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
603   }
getArgument__anonb56dea510111::TestVectorTransferLoweringPatterns604   StringRef getArgument() const final {
605     return "test-vector-transfer-lowering-patterns";
606   }
getDescription__anonb56dea510111::TestVectorTransferLoweringPatterns607   StringRef getDescription() const final {
608     return "Test lowering patterns to lower transfer ops to other vector ops";
609   }
runOnOperation__anonb56dea510111::TestVectorTransferLoweringPatterns610   void runOnOperation() override {
611     RewritePatternSet patterns(&getContext());
612     populateVectorTransferLoweringPatterns(patterns);
613     populateVectorTransferPermutationMapLoweringPatterns(patterns);
614     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
615   }
616 };
617 
618 struct TestVectorMultiReductionLoweringPatterns
619     : public PassWrapper<TestVectorMultiReductionLoweringPatterns,
620                          OperationPass<func::FuncOp>> {
621   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
622       TestVectorMultiReductionLoweringPatterns)
623 
624   TestVectorMultiReductionLoweringPatterns() = default;
TestVectorMultiReductionLoweringPatterns__anonb56dea510111::TestVectorMultiReductionLoweringPatterns625   TestVectorMultiReductionLoweringPatterns(
626       const TestVectorMultiReductionLoweringPatterns &pass)
627       : PassWrapper(pass) {}
getDependentDialects__anonb56dea510111::TestVectorMultiReductionLoweringPatterns628   void getDependentDialects(DialectRegistry &registry) const override {
629     registry.insert<memref::MemRefDialect>();
630   }
getArgument__anonb56dea510111::TestVectorMultiReductionLoweringPatterns631   StringRef getArgument() const final {
632     return "test-vector-multi-reduction-lowering-patterns";
633   }
getDescription__anonb56dea510111::TestVectorMultiReductionLoweringPatterns634   StringRef getDescription() const final {
635     return "Test lowering patterns to lower vector.multi_reduction to other "
636            "vector ops";
637   }
638   Option<bool> useOuterReductions{
639       *this, "use-outer-reductions",
640       llvm::cl::desc("Move reductions to outer most dimensions"),
641       llvm::cl::init(false)};
runOnOperation__anonb56dea510111::TestVectorMultiReductionLoweringPatterns642   void runOnOperation() override {
643     RewritePatternSet patterns(&getContext());
644     populateVectorMultiReductionLoweringPatterns(
645         patterns, useOuterReductions
646                       ? vector::VectorMultiReductionLowering::InnerParallel
647                       : vector::VectorMultiReductionLowering::InnerReduction);
648     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
649   }
650 };
651 
652 struct TestVectorTransferCollapseInnerMostContiguousDims
653     : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
654                          OperationPass<func::FuncOp>> {
655   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
656       TestVectorTransferCollapseInnerMostContiguousDims)
657 
658   TestVectorTransferCollapseInnerMostContiguousDims() = default;
659   TestVectorTransferCollapseInnerMostContiguousDims(
660       const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default;
661 
getDependentDialects__anonb56dea510111::TestVectorTransferCollapseInnerMostContiguousDims662   void getDependentDialects(DialectRegistry &registry) const override {
663     registry.insert<memref::MemRefDialect, AffineDialect>();
664   }
665 
getArgument__anonb56dea510111::TestVectorTransferCollapseInnerMostContiguousDims666   StringRef getArgument() const final {
667     return "test-vector-transfer-collapse-inner-most-dims";
668   }
669 
getDescription__anonb56dea510111::TestVectorTransferCollapseInnerMostContiguousDims670   StringRef getDescription() const final {
671     return "Test lowering patterns that reducedes the rank of the vector "
672            "transfer memory and vector operands.";
673   }
674 
runOnOperation__anonb56dea510111::TestVectorTransferCollapseInnerMostContiguousDims675   void runOnOperation() override {
676     RewritePatternSet patterns(&getContext());
677     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
678     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
679   }
680 };
681 
682 struct TestVectorReduceToContractPatternsPatterns
683     : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
684                          OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorReduceToContractPatternsPatterns685   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
686       TestVectorReduceToContractPatternsPatterns)
687 
688   StringRef getArgument() const final {
689     return "test-vector-reduction-to-contract-patterns";
690   }
getDescription__anonb56dea510111::TestVectorReduceToContractPatternsPatterns691   StringRef getDescription() const final {
692     return "Test patterns to convert multireduce op to contract and combine "
693            "broadcast/transpose to contract";
694   }
runOnOperation__anonb56dea510111::TestVectorReduceToContractPatternsPatterns695   void runOnOperation() override {
696     RewritePatternSet patterns(&getContext());
697     populateVectorReductionToContractPatterns(patterns);
698     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
699   }
700 };
701 
702 struct TestVectorTransferDropUnitDimsPatterns
703     : public PassWrapper<TestVectorTransferDropUnitDimsPatterns,
704                          OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorTransferDropUnitDimsPatterns705   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
706       TestVectorTransferDropUnitDimsPatterns)
707 
708   StringRef getArgument() const final {
709     return "test-vector-transfer-drop-unit-dims-patterns";
710   }
getDependentDialects__anonb56dea510111::TestVectorTransferDropUnitDimsPatterns711   void getDependentDialects(DialectRegistry &registry) const override {
712     registry.insert<memref::MemRefDialect>();
713   }
runOnOperation__anonb56dea510111::TestVectorTransferDropUnitDimsPatterns714   void runOnOperation() override {
715     RewritePatternSet patterns(&getContext());
716     populateVectorTransferDropUnitDimsPatterns(patterns);
717     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
718   }
719 };
720 
721 struct TestFlattenVectorTransferPatterns
722     : public PassWrapper<TestFlattenVectorTransferPatterns,
723                          OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestFlattenVectorTransferPatterns724   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
725       TestFlattenVectorTransferPatterns)
726 
727   StringRef getArgument() const final {
728     return "test-vector-transfer-flatten-patterns";
729   }
getDescription__anonb56dea510111::TestFlattenVectorTransferPatterns730   StringRef getDescription() const final {
731     return "Test patterns to rewrite contiguous row-major N-dimensional "
732            "vector.transfer_{read,write} ops into 1D transfers";
733   }
getDependentDialects__anonb56dea510111::TestFlattenVectorTransferPatterns734   void getDependentDialects(DialectRegistry &registry) const override {
735     registry.insert<memref::MemRefDialect>();
736   }
runOnOperation__anonb56dea510111::TestFlattenVectorTransferPatterns737   void runOnOperation() override {
738     RewritePatternSet patterns(&getContext());
739     populateFlattenVectorTransferPatterns(patterns);
740     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
741   }
742 };
743 
744 struct TestVectorScanLowering
745     : public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorScanLowering746   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering)
747 
748   StringRef getArgument() const final { return "test-vector-scan-lowering"; }
getDescription__anonb56dea510111::TestVectorScanLowering749   StringRef getDescription() const final {
750     return "Test lowering patterns that lower the scan op in the vector "
751            "dialect";
752   }
runOnOperation__anonb56dea510111::TestVectorScanLowering753   void runOnOperation() override {
754     RewritePatternSet patterns(&getContext());
755     populateVectorScanLoweringPatterns(patterns);
756     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
757   }
758 };
759 
760 /// Allocate shared memory for a single warp to test lowering of
761 /// WarpExecuteOnLane0Op.
allocateGlobalSharedMemory(Location loc,OpBuilder & builder,WarpExecuteOnLane0Op warpOp,Type type)762 static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
763                                         WarpExecuteOnLane0Op warpOp,
764                                         Type type) {
765   static constexpr int64_t kSharedMemorySpace = 3;
766   // Compute type of shared memory buffer.
767   MemRefType memrefType;
768   if (auto vectorType = type.dyn_cast<VectorType>()) {
769     memrefType =
770         MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {},
771                         kSharedMemorySpace);
772   } else {
773     memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace);
774   }
775 
776   // Get symbol table holding all shared memory globals.
777   ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>();
778   SymbolTable symbolTable(moduleOp);
779 
780   // Create a pretty name.
781   SmallString<64> buf;
782   llvm::raw_svector_ostream os(buf);
783   interleave(memrefType.getShape(), os, "x");
784   os << "x" << memrefType.getElementType();
785   std::string symbolName = (Twine("__shared_") + os.str()).str();
786 
787   auto ip = builder.saveInsertionPoint();
788   builder.setInsertionPoint(moduleOp);
789   auto global = builder.create<memref::GlobalOp>(
790       loc,
791       /*sym_name=*/symbolName,
792       /*sym_visibility=*/builder.getStringAttr("private"),
793       /*type=*/memrefType,
794       /*initial_value=*/Attribute(),
795       /*constant=*/false,
796       /*alignment=*/IntegerAttr());
797   symbolTable.insert(global);
798   // The symbol table inserts at the end of the module, but globals are a bit
799   // nicer if they are at the beginning.
800   global->moveBefore(&moduleOp.front());
801 
802   builder.restoreInsertionPoint(ip);
803   return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName);
804 }
805 
warpReduction(Location loc,OpBuilder & builder,Value input,CombiningKind kind,uint32_t size)806 static Value warpReduction(Location loc, OpBuilder &builder, Value input,
807                            CombiningKind kind, uint32_t size) {
808   Value laneVal = input;
809   // Parallel reduction using butterfly shuffles.
810   for (uint64_t i = 1; i < size; i <<= 1) {
811     Value shuffled = builder
812                          .create<gpu::ShuffleOp>(loc, laneVal, i,
813                                                  /*width=*/size,
814                                                  /*mode=*/gpu::ShuffleMode::XOR)
815                          .result();
816     laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
817   }
818   return laneVal;
819 }
820 
821 struct TestVectorDistribution
822     : public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorDistribution823   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution)
824 
825   void getDependentDialects(DialectRegistry &registry) const override {
826     registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect,
827                     AffineDialect>();
828   }
829 
getArgument__anonb56dea510111::TestVectorDistribution830   StringRef getArgument() const final { return "test-vector-warp-distribute"; }
getDescription__anonb56dea510111::TestVectorDistribution831   StringRef getDescription() const final {
832     return "Test vector warp distribute transformation and lowering patterns";
833   }
834   TestVectorDistribution() = default;
TestVectorDistribution__anonb56dea510111::TestVectorDistribution835   TestVectorDistribution(const TestVectorDistribution &pass)
836       : PassWrapper(pass) {}
837 
838   Option<bool> warpOpToSCF{
839       *this, "rewrite-warp-ops-to-scf-if",
840       llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"),
841       llvm::cl::init(false)};
842 
843   Option<bool> distributeTransferWriteOps{
844       *this, "distribute-transfer-write",
845       llvm::cl::desc("Test distribution of transfer write"),
846       llvm::cl::init(false)};
847 
848   Option<bool> hoistUniform{*this, "hoist-uniform",
849                             llvm::cl::desc("Test hoist uniform"),
850                             llvm::cl::init(false)};
851 
852   Option<bool> propagateDistribution{
853       *this, "propagate-distribution",
854       llvm::cl::desc("Test distribution propgation"), llvm::cl::init(false)};
855 
runOnOperation__anonb56dea510111::TestVectorDistribution856   void runOnOperation() override {
857     RewritePatternSet patterns(&getContext());
858 
859     getOperation().walk([&](Operation *op) {
860       if (auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(op)) {
861         if (hoistUniform) {
862           moveScalarUniformCode(warpOp);
863         }
864         WalkResult::interrupt();
865       }
866     });
867     MLIRContext *ctx = &getContext();
868     if (distributeTransferWriteOps) {
869       auto distributionFn = [](vector::TransferWriteOp writeOp) {
870         // Create a map (d0, d1) -> (d1) to distribute along the inner
871         // dimension. Once we support n-d distribution we can add more
872         // complex cases.
873         int64_t vecRank = writeOp.getVectorType().getRank();
874         OpBuilder builder(writeOp.getContext());
875         auto map =
876             AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
877         return map;
878       };
879       RewritePatternSet patterns(ctx);
880       populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
881       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
882     }
883     if (propagateDistribution) {
884       RewritePatternSet patterns(ctx);
885       vector::populatePropagateWarpVectorDistributionPatterns(patterns);
886       vector::populateDistributeReduction(patterns, warpReduction);
887       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
888     }
889     WarpExecuteOnLane0LoweringOptions options;
890     options.warpAllocationFn = allocateGlobalSharedMemory;
891     options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,
892                                       WarpExecuteOnLane0Op warpOp) {
893       builder.create<gpu::BarrierOp>(loc);
894     };
895     // Test on one pattern in isolation.
896     if (warpOpToSCF) {
897       populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
898       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
899       return;
900     }
901   }
902 };
903 
904 } // namespace
905 
906 namespace mlir {
907 namespace test {
registerTestVectorLowerings()908 void registerTestVectorLowerings() {
909   PassRegistration<TestVectorToVectorLowering>();
910 
911   PassRegistration<TestVectorContractionLowering>();
912 
913   PassRegistration<TestVectorTransposeLowering>();
914 
915   PassRegistration<TestVectorUnrollingPatterns>();
916 
917   PassRegistration<TestVectorTransferUnrollingPatterns>();
918 
919   PassRegistration<TestVectorTransferFullPartialSplitPatterns>();
920 
921   PassRegistration<TestVectorDistributePatterns>();
922 
923   PassRegistration<TestVectorToLoopPatterns>();
924 
925   PassRegistration<TestVectorTransferOpt>();
926 
927   PassRegistration<TestVectorTransferLoweringPatterns>();
928 
929   PassRegistration<TestVectorMultiReductionLoweringPatterns>();
930 
931   PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
932 
933   PassRegistration<TestVectorReduceToContractPatternsPatterns>();
934 
935   PassRegistration<TestVectorTransferDropUnitDimsPatterns>();
936 
937   PassRegistration<TestFlattenVectorTransferPatterns>();
938 
939   PassRegistration<TestVectorScanLowering>();
940 
941   PassRegistration<TestVectorDistribution>();
942 }
943 } // namespace test
944 } // namespace mlir
945