1 //===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <type_traits>
10 
11 #include "mlir/Analysis/SliceAnalysis.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/SCF/SCF.h"
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 #include "mlir/Dialect/Vector/VectorOps.h"
18 #include "mlir/Dialect/Vector/VectorTransforms.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 
22 using namespace mlir;
23 using namespace mlir::vector;
24 namespace {
25 
26 struct TestVectorToVectorConversion
27     : public PassWrapper<TestVectorToVectorConversion, FunctionPass> {
28   TestVectorToVectorConversion() = default;
29   TestVectorToVectorConversion(const TestVectorToVectorConversion &pass) {}
30   StringRef getArgument() const final {
31     return "test-vector-to-vector-conversion";
32   }
33   StringRef getDescription() const final {
34     return "Test conversion patterns between ops in the vector dialect";
35   }
36 
37   void getDependentDialects(DialectRegistry &registry) const override {
38     registry.insert<AffineDialect>();
39   }
40 
41   Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
42                       llvm::cl::init(false)};
43 
44   void runOnFunction() override {
45     auto *ctx = &getContext();
46     RewritePatternSet patterns(ctx);
47     if (unroll) {
48       populateVectorUnrollPatterns(
49           patterns,
50           UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
51               filter));
52     }
53     populateVectorToVectorCanonicalizationPatterns(patterns);
54     populateBubbleVectorBitCastOpPatterns(patterns);
55     populateCastAwayVectorLeadingOneDimPatterns(patterns);
56     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
57   }
58 
59 private:
60   // Return the target shape based on op type.
61   static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
62     if (isa<AddFOp, SelectOp, CmpFOp>(op))
63       return SmallVector<int64_t, 4>(2, 2);
64     if (isa<vector::ContractionOp>(op))
65       return SmallVector<int64_t, 4>(3, 2);
66     // For transfer ops, just propagate the shape coming from
67     // InsertStridedSlices/ExtractStridedSlices.
68     if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
69       VectorType dstVec;
70       for (Operation *users : readOp->getUsers()) {
71         auto extract = dyn_cast<ExtractStridedSliceOp>(users);
72         if (!extract)
73           return llvm::None;
74         auto vecType = extract.getResult().getType().cast<VectorType>();
75         if (dstVec && dstVec != vecType)
76           return llvm::None;
77         dstVec = vecType;
78       }
79       return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
80                                      dstVec.getShape().end());
81     }
82     if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
83       auto insert = writeOp.vector().getDefiningOp<InsertStridedSliceOp>();
84       if (!insert)
85         return llvm::None;
86       ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
87       return SmallVector<int64_t, 4>(shape.begin(), shape.end());
88     }
89     return llvm::None;
90   }
91 
92   static LogicalResult filter(Operation *op) {
93     return success(isa<AddFOp, SelectOp, CmpFOp, ContractionOp, TransferReadOp,
94                        TransferWriteOp>(op));
95   }
96 };
97 
98 struct TestVectorContractionConversion
99     : public PassWrapper<TestVectorContractionConversion, FunctionPass> {
100   StringRef getArgument() const final {
101     return "test-vector-contraction-conversion";
102   }
103   StringRef getDescription() const final {
104     return "Test conversion patterns that lower contract ops in the vector "
105            "dialect";
106   }
107   TestVectorContractionConversion() = default;
108   TestVectorContractionConversion(const TestVectorContractionConversion &pass) {
109   }
110 
111   Option<bool> lowerToFlatMatrix{
112       *this, "vector-lower-matrix-intrinsics",
113       llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
114       llvm::cl::init(false)};
115   Option<bool> lowerToFlatTranspose{
116       *this, "vector-flat-transpose",
117       llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
118       llvm::cl::init(false)};
119   Option<bool> lowerToOuterProduct{
120       *this, "vector-outerproduct",
121       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
122       llvm::cl::init(false)};
123   Option<bool> lowerToFilterOuterProduct{
124       *this, "vector-filter-outerproduct",
125       llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
126                      "vectors of size 4."),
127       llvm::cl::init(false)};
128 
129   void runOnFunction() override {
130     RewritePatternSet patterns(&getContext());
131 
132     // Test on one pattern in isolation.
133     if (lowerToOuterProduct) {
134       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
135       VectorTransformsOptions options{lowering};
136       patterns.add<ContractionOpToOuterProductOpLowering>(options,
137                                                           &getContext());
138       (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
139       return;
140     }
141 
142     // Test on one pattern in isolation.
143     if (lowerToFilterOuterProduct) {
144       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
145       VectorTransformsOptions options{lowering};
146       patterns.add<ContractionOpToOuterProductOpLowering>(
147           options, &getContext(), [](vector::ContractionOp op) {
148             // Only lowers vector.contract where the lhs as a type vector<MxNx?>
149             // where M is not 4.
150             if (op.getRhsType().getShape()[0] == 4)
151               return failure();
152             return success();
153           });
154       (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
155       return;
156     }
157 
158     // Test on all contract lowering patterns.
159     VectorContractLowering contractLowering = VectorContractLowering::Dot;
160     if (lowerToFlatMatrix)
161       contractLowering = VectorContractLowering::Matmul;
162     VectorTransposeLowering transposeLowering =
163         VectorTransposeLowering::EltWise;
164     if (lowerToFlatTranspose)
165       transposeLowering = VectorTransposeLowering::Flat;
166     VectorTransformsOptions options{contractLowering, transposeLowering};
167     populateVectorBroadcastLoweringPatterns(patterns);
168     populateVectorContractLoweringPatterns(patterns, options);
169     populateVectorMaskOpLoweringPatterns(patterns);
170     populateVectorShapeCastLoweringPatterns(patterns);
171     populateVectorTransposeLoweringPatterns(patterns, options);
172     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
173   }
174 };
175 
176 struct TestVectorUnrollingPatterns
177     : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
178   StringRef getArgument() const final {
179     return "test-vector-unrolling-patterns";
180   }
181   StringRef getDescription() const final {
182     return "Test conversion patterns to unroll contract ops in the vector "
183            "dialect";
184   }
185   TestVectorUnrollingPatterns() = default;
186   TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {}
187   void runOnFunction() override {
188     MLIRContext *ctx = &getContext();
189     RewritePatternSet patterns(ctx);
190     populateVectorUnrollPatterns(
191         patterns, UnrollVectorOptions()
192                       .setNativeShape(ArrayRef<int64_t>{2, 2})
193                       .setFilterConstraint([](Operation *op) {
194                         return success(isa<AddFOp, vector::FMAOp>(op));
195                       }));
196 
197     if (unrollBasedOnType) {
198       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
199           [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
200         vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
201         SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
202         if (auto floatType = contractOp.getLhsType()
203                                  .getElementType()
204                                  .dyn_cast<FloatType>()) {
205           if (floatType.getWidth() == 16) {
206             nativeShape[2] = 4;
207           }
208         }
209         return nativeShape;
210       };
211       populateVectorUnrollPatterns(patterns,
212                                    UnrollVectorOptions()
213                                        .setNativeShapeFn(nativeShapeFn)
214                                        .setFilterConstraint([](Operation *op) {
215                                          return success(isa<ContractionOp>(op));
216                                        }));
217     } else {
218       populateVectorUnrollPatterns(
219           patterns, UnrollVectorOptions()
220                         .setNativeShape(ArrayRef<int64_t>{2, 2, 2})
221                         .setFilterConstraint([](Operation *op) {
222                           return success(isa<ContractionOp>(op));
223                         }));
224     }
225     populateVectorToVectorCanonicalizationPatterns(patterns);
226     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
227   }
228 
229   Option<bool> unrollBasedOnType{
230       *this, "unroll-based-on-type",
231       llvm::cl::desc("Set the unroll factor based on type of the operation"),
232       llvm::cl::init(false)};
233 };
234 
235 struct TestVectorDistributePatterns
236     : public PassWrapper<TestVectorDistributePatterns, FunctionPass> {
237   StringRef getArgument() const final {
238     return "test-vector-distribute-patterns";
239   }
240   StringRef getDescription() const final {
241     return "Test conversion patterns to distribute vector ops in the vector "
242            "dialect";
243   }
244   TestVectorDistributePatterns() = default;
245   TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {}
246   void getDependentDialects(DialectRegistry &registry) const override {
247     registry.insert<VectorDialect>();
248     registry.insert<AffineDialect>();
249   }
250   ListOption<int32_t> multiplicity{
251       *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated,
252       llvm::cl::desc("Set the multiplicity used for distributing vector")};
253 
254   void runOnFunction() override {
255     MLIRContext *ctx = &getContext();
256     RewritePatternSet patterns(ctx);
257     FuncOp func = getFunction();
258     func.walk([&](AddFOp op) {
259       OpBuilder builder(op);
260       if (auto vecType = op.getType().dyn_cast<VectorType>()) {
261         SmallVector<int64_t, 2> mul;
262         SmallVector<AffineExpr, 2> perm;
263         SmallVector<Value, 2> ids;
264         unsigned count = 0;
265         // Remove the multiplicity of 1 and calculate the affine map based on
266         // the multiplicity.
267         SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
268         for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
269           if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) {
270             mul.push_back(m[i]);
271             ids.push_back(func.getArgument(count++));
272             perm.push_back(getAffineDimExpr(i, ctx));
273           }
274         }
275         auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0,
276                                   perm, ctx);
277         Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
278             builder, op.getOperation(), ids, mul, map);
279         if (ops.hasValue()) {
280           SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
281           op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
282                                               extractOp);
283         }
284       }
285     });
286     populatePropagateVectorDistributionPatterns(patterns);
287     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
288   }
289 };
290 
291 struct TestVectorToLoopPatterns
292     : public PassWrapper<TestVectorToLoopPatterns, FunctionPass> {
293   StringRef getArgument() const final { return "test-vector-to-forloop"; }
294   StringRef getDescription() const final {
295     return "Test conversion patterns to break up a vector op into a for loop";
296   }
297   TestVectorToLoopPatterns() = default;
298   TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {}
299   void getDependentDialects(DialectRegistry &registry) const override {
300     registry.insert<VectorDialect>();
301     registry.insert<AffineDialect>();
302   }
303   Option<int32_t> multiplicity{
304       *this, "distribution-multiplicity",
305       llvm::cl::desc("Set the multiplicity used for distributing vector"),
306       llvm::cl::init(32)};
307   void runOnFunction() override {
308     MLIRContext *ctx = &getContext();
309     RewritePatternSet patterns(ctx);
310     FuncOp func = getFunction();
311     func.walk([&](AddFOp op) {
312       // Check that the operation type can be broken down into a loop.
313       VectorType type = op.getType().dyn_cast<VectorType>();
314       if (!type || type.getRank() != 1 ||
315           type.getNumElements() % multiplicity != 0)
316         return mlir::WalkResult::advance();
317       auto filterAlloc = [](Operation *op) {
318         if (isa<ConstantOp, memref::AllocOp, CallOp>(op))
319           return false;
320         return true;
321       };
322       auto dependentOps = getSlice(op, filterAlloc);
323       // Create a loop and move instructions from the Op slice into the loop.
324       OpBuilder builder(op);
325       auto zero = builder.create<ConstantOp>(
326           op.getLoc(), builder.getIndexType(),
327           builder.getIntegerAttr(builder.getIndexType(), 0));
328       auto one = builder.create<ConstantOp>(
329           op.getLoc(), builder.getIndexType(),
330           builder.getIntegerAttr(builder.getIndexType(), 1));
331       auto numIter = builder.create<ConstantOp>(
332           op.getLoc(), builder.getIndexType(),
333           builder.getIntegerAttr(builder.getIndexType(), multiplicity));
334       auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
335       for (Operation *it : dependentOps) {
336         it->moveBefore(forOp.getBody()->getTerminator());
337       }
338       auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
339       // break up the original op and let the patterns propagate.
340       Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
341           builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity},
342           map);
343       if (ops.hasValue()) {
344         SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
345         op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
346       }
347       return mlir::WalkResult::interrupt();
348     });
349     populatePropagateVectorDistributionPatterns(patterns);
350     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
351   }
352 };
353 
354 struct TestVectorTransferUnrollingPatterns
355     : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> {
356   void getDependentDialects(DialectRegistry &registry) const override {
357     registry.insert<AffineDialect>();
358   }
359   StringRef getArgument() const final {
360     return "test-vector-transfer-unrolling-patterns";
361   }
362   StringRef getDescription() const final {
363     return "Test conversion patterns to unroll transfer ops in the vector "
364            "dialect";
365   }
366   void runOnFunction() override {
367     MLIRContext *ctx = &getContext();
368     RewritePatternSet patterns(ctx);
369     populateVectorUnrollPatterns(
370         patterns,
371         UnrollVectorOptions()
372             .setNativeShape(ArrayRef<int64_t>{2, 2})
373             .setFilterConstraint([](Operation *op) {
374               return success(
375                   isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
376             }));
377     populateVectorToVectorCanonicalizationPatterns(patterns);
378     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
379   }
380 };
381 
382 struct TestVectorTransferFullPartialSplitPatterns
383     : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
384                          FunctionPass> {
385   StringRef getArgument() const final {
386     return "test-vector-transfer-full-partial-split";
387   }
388   StringRef getDescription() const final {
389     return "Test conversion patterns to split "
390            "transfer ops via scf.if + linalg ops";
391   }
392   TestVectorTransferFullPartialSplitPatterns() = default;
393   TestVectorTransferFullPartialSplitPatterns(
394       const TestVectorTransferFullPartialSplitPatterns &pass) {}
395 
396   void getDependentDialects(DialectRegistry &registry) const override {
397     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
398                     scf::SCFDialect>();
399   }
400 
401   Option<bool> useLinalgOps{
402       *this, "use-linalg-copy",
403       llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
404                      "linalg.copy operations."),
405       llvm::cl::init(false)};
406   void runOnFunction() override {
407     MLIRContext *ctx = &getContext();
408     RewritePatternSet patterns(ctx);
409     VectorTransformsOptions options;
410     if (useLinalgOps)
411       options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
412     else
413       options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
414     patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
415     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
416   }
417 };
418 
419 struct TestVectorTransferOpt
420     : public PassWrapper<TestVectorTransferOpt, FunctionPass> {
421   StringRef getArgument() const final { return "test-vector-transferop-opt"; }
422   StringRef getDescription() const final {
423     return "Test optimization transformations for transfer ops";
424   }
425   void runOnFunction() override { transferOpflowOpt(getFunction()); }
426 };
427 
428 struct TestVectorTransferLoweringPatterns
429     : public PassWrapper<TestVectorTransferLoweringPatterns, FunctionPass> {
430   void getDependentDialects(DialectRegistry &registry) const override {
431     registry.insert<memref::MemRefDialect>();
432   }
433   StringRef getArgument() const final {
434     return "test-vector-transfer-lowering-patterns";
435   }
436   StringRef getDescription() const final {
437     return "Test conversion patterns to lower transfer ops to other vector ops";
438   }
439   void runOnFunction() override {
440     RewritePatternSet patterns(&getContext());
441     populateVectorTransferLoweringPatterns(patterns);
442     populateVectorTransferPermutationMapLoweringPatterns(patterns);
443     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
444   }
445 };
446 
447 struct TestVectorMultiReductionLoweringPatterns
448     : public PassWrapper<TestVectorMultiReductionLoweringPatterns,
449                          FunctionPass> {
450   TestVectorMultiReductionLoweringPatterns() = default;
451   TestVectorMultiReductionLoweringPatterns(
452       const TestVectorMultiReductionLoweringPatterns &pass) {}
453   void getDependentDialects(DialectRegistry &registry) const override {
454     registry.insert<memref::MemRefDialect>();
455   }
456   StringRef getArgument() const final {
457     return "test-vector-multi-reduction-lowering-patterns";
458   }
459   StringRef getDescription() const final {
460     return "Test conversion patterns to lower vector.multi_reduction to other "
461            "vector ops";
462   }
463   Option<bool> useOuterReductions{
464       *this, "use-outer-reductions",
465       llvm::cl::desc("Move reductions to outer most dimensions"),
466       llvm::cl::init(false)};
467   void runOnFunction() override {
468     RewritePatternSet patterns(&getContext());
469     populateVectorMultiReductionLoweringPatterns(patterns, !useOuterReductions);
470     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
471   }
472 };
473 
474 } // end anonymous namespace
475 
476 namespace mlir {
477 namespace test {
478 void registerTestVectorConversions() {
479   PassRegistration<TestVectorToVectorConversion>();
480 
481   PassRegistration<TestVectorContractionConversion>();
482 
483   PassRegistration<TestVectorUnrollingPatterns>();
484 
485   PassRegistration<TestVectorTransferUnrollingPatterns>();
486 
487   PassRegistration<TestVectorTransferFullPartialSplitPatterns>();
488 
489   PassRegistration<TestVectorDistributePatterns>();
490 
491   PassRegistration<TestVectorToLoopPatterns>();
492 
493   PassRegistration<TestVectorTransferOpt>();
494 
495   PassRegistration<TestVectorTransferLoweringPatterns>();
496 
497   PassRegistration<TestVectorMultiReductionLoweringPatterns>();
498 }
499 } // namespace test
500 } // namespace mlir
501