1 //===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <type_traits>
10 
11 #include "mlir/Analysis/SliceAnalysis.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/SCF/SCF.h"
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 #include "mlir/Dialect/Vector/VectorOps.h"
18 #include "mlir/Dialect/Vector/VectorTransforms.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 
22 using namespace mlir;
23 using namespace mlir::vector;
24 namespace {
25 
26 struct TestVectorToVectorConversion
27     : public PassWrapper<TestVectorToVectorConversion, FunctionPass> {
28   TestVectorToVectorConversion() = default;
29   TestVectorToVectorConversion(const TestVectorToVectorConversion &pass) {}
30   StringRef getArgument() const final {
31     return "test-vector-to-vector-conversion";
32   }
33   StringRef getDescription() const final {
34     return "Test conversion patterns between ops in the vector dialect";
35   }
36 
37   void getDependentDialects(DialectRegistry &registry) const override {
38     registry.insert<AffineDialect>();
39   }
40 
41   Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
42                       llvm::cl::init(false)};
43 
44   void runOnFunction() override {
45     auto *ctx = &getContext();
46     RewritePatternSet patterns(ctx);
47     if (unroll) {
48       patterns.add<UnrollVectorPattern>(
49           ctx,
50           UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
51               filter));
52     }
53     populateVectorToVectorCanonicalizationPatterns(patterns);
54     populateVectorToVectorTransformationPatterns(patterns);
55     populateBubbleVectorBitCastOpPatterns(patterns);
56     populateCastAwayVectorLeadingOneDimPatterns(patterns);
57     populateSplitVectorTransferPatterns(patterns);
58     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
59   }
60 
61 private:
62   // Return the target shape based on op type.
63   static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
64     if (isa<AddFOp, SelectOp, CmpFOp>(op))
65       return SmallVector<int64_t, 4>(2, 2);
66     if (isa<vector::ContractionOp>(op))
67       return SmallVector<int64_t, 4>(3, 2);
68     return llvm::None;
69   }
70 
71   static LogicalResult filter(Operation *op) {
72     return success(isa<AddFOp, SelectOp, CmpFOp, ContractionOp>(op));
73   }
74 };
75 
76 struct TestVectorSlicesConversion
77     : public PassWrapper<TestVectorSlicesConversion, FunctionPass> {
78   StringRef getArgument() const final {
79     return "test-vector-slices-conversion";
80   }
81   StringRef getDescription() const final {
82     return "Test conversion patterns that lower slices ops in the vector "
83            "dialect";
84   }
85   void runOnFunction() override {
86     RewritePatternSet patterns(&getContext());
87     populateVectorSlicesLoweringPatterns(patterns);
88     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
89   }
90 };
91 
92 struct TestVectorContractionConversion
93     : public PassWrapper<TestVectorContractionConversion, FunctionPass> {
94   StringRef getArgument() const final {
95     return "test-vector-contraction-conversion";
96   }
97   StringRef getDescription() const final {
98     return "Test conversion patterns that lower contract ops in the vector "
99            "dialect";
100   }
101   TestVectorContractionConversion() = default;
102   TestVectorContractionConversion(const TestVectorContractionConversion &pass) {
103   }
104 
105   Option<bool> lowerToFlatMatrix{
106       *this, "vector-lower-matrix-intrinsics",
107       llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
108       llvm::cl::init(false)};
109   Option<bool> lowerToFlatTranspose{
110       *this, "vector-flat-transpose",
111       llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
112       llvm::cl::init(false)};
113   Option<bool> lowerToOuterProduct{
114       *this, "vector-outerproduct",
115       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
116       llvm::cl::init(false)};
117   Option<bool> lowerToFilterOuterProduct{
118       *this, "vector-filter-outerproduct",
119       llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
120                      "vectors of size 4."),
121       llvm::cl::init(false)};
122 
123   void runOnFunction() override {
124     RewritePatternSet patterns(&getContext());
125 
126     // Test on one pattern in isolation.
127     if (lowerToOuterProduct) {
128       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
129       VectorTransformsOptions options{lowering};
130       patterns.add<ContractionOpToOuterProductOpLowering>(options,
131                                                           &getContext());
132       (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
133       return;
134     }
135 
136     // Test on one pattern in isolation.
137     if (lowerToFilterOuterProduct) {
138       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
139       VectorTransformsOptions options{lowering};
140       patterns.add<ContractionOpToOuterProductOpLowering>(
141           options, &getContext(), [](vector::ContractionOp op) {
142             // Only lowers vector.contract where the lhs as a type vector<MxNx?>
143             // where M is not 4.
144             if (op.getRhsType().getShape()[0] == 4)
145               return failure();
146             return success();
147           });
148       (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
149       return;
150     }
151 
152     // Test on all contract lowering patterns.
153     VectorContractLowering contractLowering = VectorContractLowering::Dot;
154     if (lowerToFlatMatrix)
155       contractLowering = VectorContractLowering::Matmul;
156     VectorTransposeLowering transposeLowering =
157         VectorTransposeLowering::EltWise;
158     if (lowerToFlatTranspose)
159       transposeLowering = VectorTransposeLowering::Flat;
160     VectorTransformsOptions options{contractLowering, transposeLowering};
161     populateVectorContractLoweringPatterns(patterns, options);
162     populateVectorTransposeLoweringPatterns(patterns, options);
163     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
164   }
165 };
166 
167 struct TestVectorUnrollingPatterns
168     : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
169   StringRef getArgument() const final {
170     return "test-vector-unrolling-patterns";
171   }
172   StringRef getDescription() const final {
173     return "Test conversion patterns to unroll contract ops in the vector "
174            "dialect";
175   }
176   TestVectorUnrollingPatterns() = default;
177   TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {}
178   void runOnFunction() override {
179     MLIRContext *ctx = &getContext();
180     RewritePatternSet patterns(ctx);
181     patterns.add<UnrollVectorPattern>(
182         ctx, UnrollVectorOptions()
183                  .setNativeShape(ArrayRef<int64_t>{2, 2})
184                  .setFilterConstraint([](Operation *op) {
185                    return success(isa<AddFOp, vector::FMAOp>(op));
186                  }));
187 
188     if (unrollBasedOnType) {
189       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
190           [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
191         vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
192         SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
193         if (auto floatType = contractOp.getLhsType()
194                                  .getElementType()
195                                  .dyn_cast<FloatType>()) {
196           if (floatType.getWidth() == 16) {
197             nativeShape[2] = 4;
198           }
199         }
200         return nativeShape;
201       };
202       patterns.add<UnrollVectorPattern>(
203           ctx, UnrollVectorOptions()
204                    .setNativeShapeFn(nativeShapeFn)
205                    .setFilterConstraint([](Operation *op) {
206                      return success(isa<ContractionOp>(op));
207                    }));
208     } else {
209       patterns.add<UnrollVectorPattern>(
210           ctx, UnrollVectorOptions()
211                    .setNativeShape(ArrayRef<int64_t>{2, 2, 2})
212                    .setFilterConstraint([](Operation *op) {
213                      return success(isa<ContractionOp>(op));
214                    }));
215     }
216     populateVectorToVectorCanonicalizationPatterns(patterns);
217     populateVectorToVectorTransformationPatterns(patterns);
218     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
219   }
220 
221   Option<bool> unrollBasedOnType{
222       *this, "unroll-based-on-type",
223       llvm::cl::desc("Set the unroll factor based on type of the operation"),
224       llvm::cl::init(false)};
225 };
226 
227 struct TestVectorDistributePatterns
228     : public PassWrapper<TestVectorDistributePatterns, FunctionPass> {
229   StringRef getArgument() const final {
230     return "test-vector-distribute-patterns";
231   }
232   StringRef getDescription() const final {
233     return "Test conversion patterns to distribute vector ops in the vector "
234            "dialect";
235   }
236   TestVectorDistributePatterns() = default;
237   TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {}
238   void getDependentDialects(DialectRegistry &registry) const override {
239     registry.insert<VectorDialect>();
240     registry.insert<AffineDialect>();
241   }
242   ListOption<int32_t> multiplicity{
243       *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated,
244       llvm::cl::desc("Set the multiplicity used for distributing vector")};
245 
246   void runOnFunction() override {
247     MLIRContext *ctx = &getContext();
248     RewritePatternSet patterns(ctx);
249     FuncOp func = getFunction();
250     func.walk([&](AddFOp op) {
251       OpBuilder builder(op);
252       if (auto vecType = op.getType().dyn_cast<VectorType>()) {
253         SmallVector<int64_t, 2> mul;
254         SmallVector<AffineExpr, 2> perm;
255         SmallVector<Value, 2> ids;
256         unsigned count = 0;
257         // Remove the multiplicity of 1 and calculate the affine map based on
258         // the multiplicity.
259         SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
260         for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
261           if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) {
262             mul.push_back(m[i]);
263             ids.push_back(func.getArgument(count++));
264             perm.push_back(getAffineDimExpr(i, ctx));
265           }
266         }
267         auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0,
268                                   perm, ctx);
269         Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
270             builder, op.getOperation(), ids, mul, map);
271         if (ops.hasValue()) {
272           SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
273           op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
274                                               extractOp);
275         }
276       }
277     });
278     patterns.add<PointwiseExtractPattern>(ctx);
279     populateVectorToVectorTransformationPatterns(patterns);
280     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
281   }
282 };
283 
284 struct TestVectorToLoopPatterns
285     : public PassWrapper<TestVectorToLoopPatterns, FunctionPass> {
286   StringRef getArgument() const final { return "test-vector-to-forloop"; }
287   StringRef getDescription() const final {
288     return "Test conversion patterns to break up a vector op into a for loop";
289   }
290   TestVectorToLoopPatterns() = default;
291   TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {}
292   void getDependentDialects(DialectRegistry &registry) const override {
293     registry.insert<VectorDialect>();
294     registry.insert<AffineDialect>();
295   }
296   Option<int32_t> multiplicity{
297       *this, "distribution-multiplicity",
298       llvm::cl::desc("Set the multiplicity used for distributing vector"),
299       llvm::cl::init(32)};
300   void runOnFunction() override {
301     MLIRContext *ctx = &getContext();
302     RewritePatternSet patterns(ctx);
303     FuncOp func = getFunction();
304     func.walk([&](AddFOp op) {
305       // Check that the operation type can be broken down into a loop.
306       VectorType type = op.getType().dyn_cast<VectorType>();
307       if (!type || type.getRank() != 1 ||
308           type.getNumElements() % multiplicity != 0)
309         return mlir::WalkResult::advance();
310       auto filterAlloc = [](Operation *op) {
311         if (isa<ConstantOp, memref::AllocOp, CallOp>(op))
312           return false;
313         return true;
314       };
315       auto dependentOps = getSlice(op, filterAlloc);
316       // Create a loop and move instructions from the Op slice into the loop.
317       OpBuilder builder(op);
318       auto zero = builder.create<ConstantOp>(
319           op.getLoc(), builder.getIndexType(),
320           builder.getIntegerAttr(builder.getIndexType(), 0));
321       auto one = builder.create<ConstantOp>(
322           op.getLoc(), builder.getIndexType(),
323           builder.getIntegerAttr(builder.getIndexType(), 1));
324       auto numIter = builder.create<ConstantOp>(
325           op.getLoc(), builder.getIndexType(),
326           builder.getIntegerAttr(builder.getIndexType(), multiplicity));
327       auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
328       for (Operation *it : dependentOps) {
329         it->moveBefore(forOp.getBody()->getTerminator());
330       }
331       auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
332       // break up the original op and let the patterns propagate.
333       Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
334           builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity},
335           map);
336       if (ops.hasValue()) {
337         SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
338         op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
339       }
340       return mlir::WalkResult::interrupt();
341     });
342     patterns.add<PointwiseExtractPattern>(ctx);
343     populateVectorToVectorTransformationPatterns(patterns);
344     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
345   }
346 };
347 
348 struct TestVectorTransferUnrollingPatterns
349     : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> {
350   void getDependentDialects(DialectRegistry &registry) const override {
351     registry.insert<AffineDialect>();
352   }
353   StringRef getArgument() const final {
354     return "test-vector-transfer-unrolling-patterns";
355   }
356   StringRef getDescription() const final {
357     return "Test conversion patterns to unroll transfer ops in the vector "
358            "dialect";
359   }
360   void runOnFunction() override {
361     MLIRContext *ctx = &getContext();
362     RewritePatternSet patterns(ctx);
363     patterns.add<UnrollVectorPattern>(
364         ctx,
365         UnrollVectorOptions()
366             .setNativeShape(ArrayRef<int64_t>{2, 2})
367             .setFilterConstraint([](Operation *op) {
368               return success(
369                   isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
370             }));
371     populateVectorToVectorCanonicalizationPatterns(patterns);
372     populateVectorToVectorTransformationPatterns(patterns);
373     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
374   }
375 };
376 
377 struct TestVectorTransferFullPartialSplitPatterns
378     : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
379                          FunctionPass> {
380   StringRef getArgument() const final {
381     return "test-vector-transfer-full-partial-split";
382   }
383   StringRef getDescription() const final {
384     return "Test conversion patterns to split "
385            "transfer ops via scf.if + linalg ops";
386   }
387   TestVectorTransferFullPartialSplitPatterns() = default;
388   TestVectorTransferFullPartialSplitPatterns(
389       const TestVectorTransferFullPartialSplitPatterns &pass) {}
390 
391   void getDependentDialects(DialectRegistry &registry) const override {
392     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
393                     scf::SCFDialect>();
394   }
395 
396   Option<bool> useLinalgOps{
397       *this, "use-linalg-copy",
398       llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
399                      "linalg.copy operations."),
400       llvm::cl::init(false)};
401   void runOnFunction() override {
402     MLIRContext *ctx = &getContext();
403     RewritePatternSet patterns(ctx);
404     VectorTransformsOptions options;
405     if (useLinalgOps)
406       options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
407     else
408       options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
409     patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
410     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
411   }
412 };
413 
414 struct TestVectorTransferOpt
415     : public PassWrapper<TestVectorTransferOpt, FunctionPass> {
416   StringRef getArgument() const final { return "test-vector-transferop-opt"; }
417   StringRef getDescription() const final {
418     return "Test optimization transformations for transfer ops";
419   }
420   void runOnFunction() override { transferOpflowOpt(getFunction()); }
421 };
422 
423 struct TestVectorTransferLoweringPatterns
424     : public PassWrapper<TestVectorTransferLoweringPatterns, FunctionPass> {
425   void getDependentDialects(DialectRegistry &registry) const override {
426     registry.insert<memref::MemRefDialect>();
427   }
428   StringRef getArgument() const final {
429     return "test-vector-transfer-lowering-patterns";
430   }
431   StringRef getDescription() const final {
432     return "Test conversion patterns to lower transfer ops to other vector ops";
433   }
434   void runOnFunction() override {
435     RewritePatternSet patterns(&getContext());
436     populateVectorTransferLoweringPatterns(patterns);
437     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
438   }
439 };
440 
441 struct TestVectorMultiReductionLoweringPatterns
442     : public PassWrapper<TestVectorMultiReductionLoweringPatterns,
443                          FunctionPass> {
444   void getDependentDialects(DialectRegistry &registry) const override {
445     registry.insert<memref::MemRefDialect>();
446   }
447   StringRef getArgument() const final {
448     return "test-vector-multi-reduction-lowering-patterns";
449   }
450   StringRef getDescription() const final {
451     return "Test conversion patterns to lower vector.multi_reduction to other "
452            "vector ops";
453   }
454   void runOnFunction() override {
455     RewritePatternSet patterns(&getContext());
456     populateVectorMultiReductionLoweringPatterns(patterns);
457     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
458   }
459 };
460 
461 } // end anonymous namespace
462 
463 namespace mlir {
464 namespace test {
465 void registerTestVectorConversions() {
466   PassRegistration<TestVectorToVectorConversion>();
467 
468   PassRegistration<TestVectorSlicesConversion>();
469 
470   PassRegistration<TestVectorContractionConversion>();
471 
472   PassRegistration<TestVectorUnrollingPatterns>();
473 
474   PassRegistration<TestVectorTransferUnrollingPatterns>();
475 
476   PassRegistration<TestVectorTransferFullPartialSplitPatterns>();
477 
478   PassRegistration<TestVectorDistributePatterns>();
479 
480   PassRegistration<TestVectorToLoopPatterns>();
481 
482   PassRegistration<TestVectorTransferOpt>();
483 
484   PassRegistration<TestVectorTransferLoweringPatterns>();
485 
486   PassRegistration<TestVectorMultiReductionLoweringPatterns>();
487 }
488 } // namespace test
489 } // namespace mlir
490