1 //===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <type_traits>
10 
11 #include "mlir/Analysis/SliceAnalysis.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/SCF/SCF.h"
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 #include "mlir/Dialect/Vector/VectorOps.h"
18 #include "mlir/Dialect/Vector/VectorTransforms.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 
22 using namespace mlir;
23 using namespace mlir::vector;
24 namespace {
25 
26 struct TestVectorToVectorConversion
27     : public PassWrapper<TestVectorToVectorConversion, FunctionPass> {
28   TestVectorToVectorConversion() = default;
29   TestVectorToVectorConversion(const TestVectorToVectorConversion &pass) {}
30   StringRef getArgument() const final {
31     return "test-vector-to-vector-conversion";
32   }
33   StringRef getDescription() const final {
34     return "Test conversion patterns between ops in the vector dialect";
35   }
36 
37   void getDependentDialects(DialectRegistry &registry) const override {
38     registry.insert<AffineDialect>();
39   }
40 
41   Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
42                       llvm::cl::init(false)};
43 
44   void runOnFunction() override {
45     auto *ctx = &getContext();
46     RewritePatternSet patterns(ctx);
47     if (unroll) {
48       populateVectorUnrollPatterns(
49           patterns,
50           UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
51               filter));
52     }
53     populateVectorToVectorCanonicalizationPatterns(patterns);
54     populateBubbleVectorBitCastOpPatterns(patterns);
55     populateCastAwayVectorLeadingOneDimPatterns(patterns);
56     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
57   }
58 
59 private:
60   // Return the target shape based on op type.
61   static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
62     if (isa<arith::AddFOp, SelectOp, arith::CmpFOp>(op))
63       return SmallVector<int64_t, 4>(2, 2);
64     if (isa<vector::ContractionOp>(op))
65       return SmallVector<int64_t, 4>(3, 2);
66     // For transfer ops, just propagate the shape coming from
67     // InsertStridedSlices/ExtractStridedSlices.
68     if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
69       VectorType dstVec;
70       for (Operation *users : readOp->getUsers()) {
71         auto extract = dyn_cast<ExtractStridedSliceOp>(users);
72         if (!extract)
73           return llvm::None;
74         auto vecType = extract.getResult().getType().cast<VectorType>();
75         if (dstVec && dstVec != vecType)
76           return llvm::None;
77         dstVec = vecType;
78       }
79       return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
80                                      dstVec.getShape().end());
81     }
82     if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
83       auto insert = writeOp.vector().getDefiningOp<InsertStridedSliceOp>();
84       if (!insert)
85         return llvm::None;
86       ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
87       return SmallVector<int64_t, 4>(shape.begin(), shape.end());
88     }
89     return llvm::None;
90   }
91 
92   static LogicalResult filter(Operation *op) {
93     return success(isa<arith::AddFOp, SelectOp, arith::CmpFOp, ContractionOp,
94                        TransferReadOp, TransferWriteOp>(op));
95   }
96 };
97 
98 struct TestVectorContractionConversion
99     : public PassWrapper<TestVectorContractionConversion, FunctionPass> {
100   StringRef getArgument() const final {
101     return "test-vector-contraction-conversion";
102   }
103   StringRef getDescription() const final {
104     return "Test conversion patterns that lower contract ops in the vector "
105            "dialect";
106   }
107   TestVectorContractionConversion() = default;
108   TestVectorContractionConversion(const TestVectorContractionConversion &pass) {
109   }
110 
111   Option<bool> lowerToFlatMatrix{
112       *this, "vector-lower-matrix-intrinsics",
113       llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
114       llvm::cl::init(false)};
115   Option<bool> lowerToFlatTranspose{
116       *this, "vector-flat-transpose",
117       llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
118       llvm::cl::init(false)};
119   Option<bool> lowerToOuterProduct{
120       *this, "vector-outerproduct",
121       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
122       llvm::cl::init(false)};
123   Option<bool> lowerToFilterOuterProduct{
124       *this, "vector-filter-outerproduct",
125       llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
126                      "vectors of size 4."),
127       llvm::cl::init(false)};
128 
129   void runOnFunction() override {
130     RewritePatternSet patterns(&getContext());
131 
132     // Test on one pattern in isolation.
133     if (lowerToOuterProduct) {
134       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
135       VectorTransformsOptions options{lowering};
136       patterns.add<ContractionOpToOuterProductOpLowering>(options,
137                                                           &getContext());
138       (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
139       return;
140     }
141 
142     // Test on one pattern in isolation.
143     if (lowerToFilterOuterProduct) {
144       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
145       VectorTransformsOptions options{lowering};
146       patterns.add<ContractionOpToOuterProductOpLowering>(
147           options, &getContext(), [](vector::ContractionOp op) {
148             // Only lowers vector.contract where the lhs as a type vector<MxNx?>
149             // where M is not 4.
150             if (op.getRhsType().getShape()[0] == 4)
151               return failure();
152             return success();
153           });
154       (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
155       return;
156     }
157 
158     // Test on all contract lowering patterns.
159     VectorContractLowering contractLowering = VectorContractLowering::Dot;
160     if (lowerToFlatMatrix)
161       contractLowering = VectorContractLowering::Matmul;
162     VectorTransposeLowering transposeLowering =
163         VectorTransposeLowering::EltWise;
164     if (lowerToFlatTranspose)
165       transposeLowering = VectorTransposeLowering::Flat;
166     VectorTransformsOptions options{contractLowering, transposeLowering};
167     populateVectorBroadcastLoweringPatterns(patterns);
168     populateVectorContractLoweringPatterns(patterns, options);
169     populateVectorMaskOpLoweringPatterns(patterns);
170     populateVectorShapeCastLoweringPatterns(patterns);
171     populateVectorTransposeLoweringPatterns(patterns, options);
172     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
173   }
174 };
175 
176 struct TestVectorUnrollingPatterns
177     : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
178   StringRef getArgument() const final {
179     return "test-vector-unrolling-patterns";
180   }
181   StringRef getDescription() const final {
182     return "Test conversion patterns to unroll contract ops in the vector "
183            "dialect";
184   }
185   TestVectorUnrollingPatterns() = default;
186   TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {}
187   void runOnFunction() override {
188     MLIRContext *ctx = &getContext();
189     RewritePatternSet patterns(ctx);
190     populateVectorUnrollPatterns(
191         patterns, UnrollVectorOptions()
192                       .setNativeShape(ArrayRef<int64_t>{2, 2})
193                       .setFilterConstraint([](Operation *op) {
194                         return success(isa<arith::AddFOp, vector::FMAOp>(op));
195                       }));
196 
197     if (unrollBasedOnType) {
198       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
199           [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
200         vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
201         SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
202         if (auto floatType = contractOp.getLhsType()
203                                  .getElementType()
204                                  .dyn_cast<FloatType>()) {
205           if (floatType.getWidth() == 16) {
206             nativeShape[2] = 4;
207           }
208         }
209         return nativeShape;
210       };
211       populateVectorUnrollPatterns(patterns,
212                                    UnrollVectorOptions()
213                                        .setNativeShapeFn(nativeShapeFn)
214                                        .setFilterConstraint([](Operation *op) {
215                                          return success(isa<ContractionOp>(op));
216                                        }));
217     } else {
218       populateVectorUnrollPatterns(
219           patterns, UnrollVectorOptions()
220                         .setNativeShape(ArrayRef<int64_t>{2, 2, 2})
221                         .setFilterConstraint([](Operation *op) {
222                           return success(isa<ContractionOp>(op));
223                         }));
224     }
225     populateVectorToVectorCanonicalizationPatterns(patterns);
226     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
227   }
228 
229   Option<bool> unrollBasedOnType{
230       *this, "unroll-based-on-type",
231       llvm::cl::desc("Set the unroll factor based on type of the operation"),
232       llvm::cl::init(false)};
233 };
234 
235 struct TestVectorDistributePatterns
236     : public PassWrapper<TestVectorDistributePatterns, FunctionPass> {
237   StringRef getArgument() const final {
238     return "test-vector-distribute-patterns";
239   }
240   StringRef getDescription() const final {
241     return "Test conversion patterns to distribute vector ops in the vector "
242            "dialect";
243   }
244   TestVectorDistributePatterns() = default;
245   TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {}
246   void getDependentDialects(DialectRegistry &registry) const override {
247     registry.insert<VectorDialect>();
248     registry.insert<AffineDialect>();
249   }
250   ListOption<int32_t> multiplicity{
251       *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated,
252       llvm::cl::desc("Set the multiplicity used for distributing vector")};
253 
254   void runOnFunction() override {
255     MLIRContext *ctx = &getContext();
256     RewritePatternSet patterns(ctx);
257     FuncOp func = getFunction();
258     func.walk([&](arith::AddFOp op) {
259       OpBuilder builder(op);
260       if (auto vecType = op.getType().dyn_cast<VectorType>()) {
261         SmallVector<int64_t, 2> mul;
262         SmallVector<AffineExpr, 2> perm;
263         SmallVector<Value, 2> ids;
264         unsigned count = 0;
265         // Remove the multiplicity of 1 and calculate the affine map based on
266         // the multiplicity.
267         SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
268         for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
269           if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) {
270             mul.push_back(m[i]);
271             ids.push_back(func.getArgument(count++));
272             perm.push_back(getAffineDimExpr(i, ctx));
273           }
274         }
275         auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0,
276                                   perm, ctx);
277         Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
278             builder, op.getOperation(), ids, mul, map);
279         if (ops.hasValue()) {
280           SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
281           op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
282                                               extractOp);
283         }
284       }
285     });
286     populatePropagateVectorDistributionPatterns(patterns);
287     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
288   }
289 };
290 
291 struct TestVectorToLoopPatterns
292     : public PassWrapper<TestVectorToLoopPatterns, FunctionPass> {
293   StringRef getArgument() const final { return "test-vector-to-forloop"; }
294   StringRef getDescription() const final {
295     return "Test conversion patterns to break up a vector op into a for loop";
296   }
297   TestVectorToLoopPatterns() = default;
298   TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {}
299   void getDependentDialects(DialectRegistry &registry) const override {
300     registry.insert<VectorDialect>();
301     registry.insert<AffineDialect>();
302   }
303   Option<int32_t> multiplicity{
304       *this, "distribution-multiplicity",
305       llvm::cl::desc("Set the multiplicity used for distributing vector"),
306       llvm::cl::init(32)};
307   void runOnFunction() override {
308     MLIRContext *ctx = &getContext();
309     RewritePatternSet patterns(ctx);
310     FuncOp func = getFunction();
311     func.walk([&](arith::AddFOp op) {
312       // Check that the operation type can be broken down into a loop.
313       VectorType type = op.getType().dyn_cast<VectorType>();
314       if (!type || type.getRank() != 1 ||
315           type.getNumElements() % multiplicity != 0)
316         return mlir::WalkResult::advance();
317       auto filterAlloc = [](Operation *op) {
318         if (isa<arith::ConstantOp, memref::AllocOp, CallOp>(op))
319           return false;
320         return true;
321       };
322       auto dependentOps = getSlice(op, filterAlloc);
323       // Create a loop and move instructions from the Op slice into the loop.
324       OpBuilder builder(op);
325       auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0);
326       auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1);
327       auto numIter =
328           builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity);
329       auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
330       for (Operation *it : dependentOps) {
331         it->moveBefore(forOp.getBody()->getTerminator());
332       }
333       auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
334       // break up the original op and let the patterns propagate.
335       Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
336           builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity},
337           map);
338       if (ops.hasValue()) {
339         SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
340         op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
341       }
342       return mlir::WalkResult::interrupt();
343     });
344     populatePropagateVectorDistributionPatterns(patterns);
345     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
346   }
347 };
348 
349 struct TestVectorTransferUnrollingPatterns
350     : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> {
351   void getDependentDialects(DialectRegistry &registry) const override {
352     registry.insert<AffineDialect>();
353   }
354   StringRef getArgument() const final {
355     return "test-vector-transfer-unrolling-patterns";
356   }
357   StringRef getDescription() const final {
358     return "Test conversion patterns to unroll transfer ops in the vector "
359            "dialect";
360   }
361   void runOnFunction() override {
362     MLIRContext *ctx = &getContext();
363     RewritePatternSet patterns(ctx);
364     populateVectorUnrollPatterns(
365         patterns,
366         UnrollVectorOptions()
367             .setNativeShape(ArrayRef<int64_t>{2, 2})
368             .setFilterConstraint([](Operation *op) {
369               return success(
370                   isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
371             }));
372     populateVectorToVectorCanonicalizationPatterns(patterns);
373     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
374   }
375 };
376 
377 struct TestVectorTransferFullPartialSplitPatterns
378     : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
379                          FunctionPass> {
380   StringRef getArgument() const final {
381     return "test-vector-transfer-full-partial-split";
382   }
383   StringRef getDescription() const final {
384     return "Test conversion patterns to split "
385            "transfer ops via scf.if + linalg ops";
386   }
387   TestVectorTransferFullPartialSplitPatterns() = default;
388   TestVectorTransferFullPartialSplitPatterns(
389       const TestVectorTransferFullPartialSplitPatterns &pass) {}
390 
391   void getDependentDialects(DialectRegistry &registry) const override {
392     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
393                     scf::SCFDialect>();
394   }
395 
396   Option<bool> useLinalgOps{
397       *this, "use-linalg-copy",
398       llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
399                      "linalg.copy operations."),
400       llvm::cl::init(false)};
401   void runOnFunction() override {
402     MLIRContext *ctx = &getContext();
403     RewritePatternSet patterns(ctx);
404     VectorTransformsOptions options;
405     if (useLinalgOps)
406       options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
407     else
408       options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
409     patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
410     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
411   }
412 };
413 
414 struct TestVectorTransferOpt
415     : public PassWrapper<TestVectorTransferOpt, FunctionPass> {
416   StringRef getArgument() const final { return "test-vector-transferop-opt"; }
417   StringRef getDescription() const final {
418     return "Test optimization transformations for transfer ops";
419   }
420   void runOnFunction() override { transferOpflowOpt(getFunction()); }
421 };
422 
423 struct TestVectorTransferLoweringPatterns
424     : public PassWrapper<TestVectorTransferLoweringPatterns, FunctionPass> {
425   void getDependentDialects(DialectRegistry &registry) const override {
426     registry.insert<memref::MemRefDialect>();
427   }
428   StringRef getArgument() const final {
429     return "test-vector-transfer-lowering-patterns";
430   }
431   StringRef getDescription() const final {
432     return "Test conversion patterns to lower transfer ops to other vector ops";
433   }
434   void runOnFunction() override {
435     RewritePatternSet patterns(&getContext());
436     populateVectorTransferLoweringPatterns(patterns);
437     populateVectorTransferPermutationMapLoweringPatterns(patterns);
438     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
439   }
440 };
441 
442 struct TestVectorMultiReductionLoweringPatterns
443     : public PassWrapper<TestVectorMultiReductionLoweringPatterns,
444                          FunctionPass> {
445   TestVectorMultiReductionLoweringPatterns() = default;
446   TestVectorMultiReductionLoweringPatterns(
447       const TestVectorMultiReductionLoweringPatterns &pass) {}
448   void getDependentDialects(DialectRegistry &registry) const override {
449     registry.insert<memref::MemRefDialect>();
450   }
451   StringRef getArgument() const final {
452     return "test-vector-multi-reduction-lowering-patterns";
453   }
454   StringRef getDescription() const final {
455     return "Test conversion patterns to lower vector.multi_reduction to other "
456            "vector ops";
457   }
458   Option<bool> useOuterReductions{
459       *this, "use-outer-reductions",
460       llvm::cl::desc("Move reductions to outer most dimensions"),
461       llvm::cl::init(false)};
462   void runOnFunction() override {
463     RewritePatternSet patterns(&getContext());
464     populateVectorMultiReductionLoweringPatterns(patterns, !useOuterReductions);
465     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
466   }
467 };
468 
469 } // end anonymous namespace
470 
471 namespace mlir {
472 namespace test {
473 void registerTestVectorConversions() {
474   PassRegistration<TestVectorToVectorConversion>();
475 
476   PassRegistration<TestVectorContractionConversion>();
477 
478   PassRegistration<TestVectorUnrollingPatterns>();
479 
480   PassRegistration<TestVectorTransferUnrollingPatterns>();
481 
482   PassRegistration<TestVectorTransferFullPartialSplitPatterns>();
483 
484   PassRegistration<TestVectorDistributePatterns>();
485 
486   PassRegistration<TestVectorToLoopPatterns>();
487 
488   PassRegistration<TestVectorTransferOpt>();
489 
490   PassRegistration<TestVectorTransferLoweringPatterns>();
491 
492   PassRegistration<TestVectorMultiReductionLoweringPatterns>();
493 }
494 } // namespace test
495 } // namespace mlir
496