//===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::vector;

namespace {

struct TestVectorToVectorLowering
    : public PassWrapper<TestVectorToVectorLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering)

  TestVectorToVectorLowering() = default;
  TestVectorToVectorLowering(const TestVectorToVectorLowering &pass)
      : PassWrapper(pass) {}
  StringRef getArgument() const final {
    return "test-vector-to-vector-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns between ops in the vector dialect";
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AffineDialect>();
  }

  Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
                      llvm::cl::init(false)};

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    if (unroll) {
      populateVectorUnrollPatterns(
          patterns,
          UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
              filter));
    }
    populateVectorToVectorCanonicalizationPatterns(patterns);
    populateBubbleVectorBitCastOpPatterns(patterns);
    populateCastAwayVectorLeadingOneDimPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

private:
  // Return the target shape based on op type.
  static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
    if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op))
      return SmallVector<int64_t, 4>(2, 2);
    if (isa<vector::ContractionOp>(op))
      return SmallVector<int64_t, 4>(3, 2);
    // For transfer ops, just propagate the shape coming from
    // InsertStridedSlices/ExtractStridedSlices.
    if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
      VectorType dstVec;
      for (Operation *users : readOp->getUsers()) {
        auto extract = dyn_cast<ExtractStridedSliceOp>(users);
        if (!extract)
          return llvm::None;
        auto vecType = extract.getResult().getType().cast<VectorType>();
        if (dstVec && dstVec != vecType)
          return llvm::None;
        dstVec = vecType;
      }
      return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
                                     dstVec.getShape().end());
    }
    if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
      auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>();
      if (!insert)
        return llvm::None;
      ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
      return SmallVector<int64_t, 4>(shape.begin(), shape.end());
    }
    return llvm::None;
  }

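  // Unrolling is only attempted for the op kinds that getShape above knows
  // how to assign a native shape to.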
  static LogicalResult filter(Operation *op) {
    return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp,
                       ContractionOp, TransferReadOp, TransferWriteOp>(op));
  }
};

struct TestVectorContractionLowering
    : public PassWrapper<TestVectorContractionLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorContractionLowering)

  StringRef getArgument() const final {
    return "test-vector-contraction-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns that lower contract ops in the vector "
           "dialect";
  }
  TestVectorContractionLowering() = default;
  TestVectorContractionLowering(const TestVectorContractionLowering &pass)
      : PassWrapper(pass) {}

  Option<bool> lowerToFlatMatrix{
      *this, "vector-lower-matrix-intrinsics",
      llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
      llvm::cl::init(false)};
  Option<bool> lowerToOuterProduct{
      *this, "vector-outerproduct",
      llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
      llvm::cl::init(false)};
  Option<bool> lowerToFilterOuterProduct{
      *this, "vector-filter-outerproduct",
      llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
                     "vectors of size 4."),
      llvm::cl::init(false)};
  Option<bool> lowerToParallelArith{
      *this, "vector-parallel-arith",
      llvm::cl::desc("Lower vector.contract to elementwise vector ops."),
      llvm::cl::init(false)};

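  // Each flag above exercises a single lowering path in isolation; with no
  // flags set, the full set of contract lowering patterns is applied (Dot by
  // default, Matmul when vector-lower-matrix-intrinsics is given).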
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());

    // Test on one pattern in isolation.
    if (lowerToOuterProduct) {
      VectorContractLowering lowering = VectorContractLowering::OuterProduct;
      VectorTransformsOptions options{lowering};
      patterns.add<ContractionOpToOuterProductOpLowering>(options,
                                                          &getContext());
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
      return;
    }

    // Test on one pattern in isolation.
    if (lowerToFilterOuterProduct) {
      VectorContractLowering lowering = VectorContractLowering::OuterProduct;
      VectorTransformsOptions options{lowering};
      patterns.add<ContractionOpToOuterProductOpLowering>(
          options, &getContext(), [](vector::ContractionOp op) {
            // Only lower vector.contract ops whose lhs has a type
            // vector<MxNx?> where M is not 4.
            if (op.getRhsType().getShape()[0] == 4)
              return failure();
            return success();
          });
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
      return;
    }

    if (lowerToParallelArith) {
      vector::populateVectorContractLoweringPatterns(
          patterns,
          vector::VectorTransformsOptions().setVectorTransformsOptions(
              vector::VectorContractLowering::ParallelArith));
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
      return;
    }

    // Test on all contract lowering patterns.
    VectorContractLowering contractLowering = VectorContractLowering::Dot;
    if (lowerToFlatMatrix)
      contractLowering = VectorContractLowering::Matmul;
    VectorMultiReductionLowering vectorMultiReductionLowering =
        VectorMultiReductionLowering::InnerParallel;
    VectorTransformsOptions options{contractLowering,
                                    vectorMultiReductionLowering,
                                    VectorTransposeLowering()};
    populateVectorBroadcastLoweringPatterns(patterns);
    populateVectorContractLoweringPatterns(patterns, options);
    populateVectorMaskOpLoweringPatterns(patterns);
    populateVectorShapeCastLoweringPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorTransposeLowering
    : public PassWrapper<TestVectorTransposeLowering,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransposeLowering)

  StringRef getArgument() const final {
    return "test-vector-transpose-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns that lower transpose ops in the vector "
           "dialect";
  }
  TestVectorTransposeLowering() = default;
  TestVectorTransposeLowering(const TestVectorTransposeLowering &pass)
      : PassWrapper(pass) {}

  Option<bool> lowerToEltwise{
      *this, "eltwise",
      llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"),
      llvm::cl::init(false)};
  Option<bool> lowerToFlatTranspose{
      *this, "flat",
      llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
      llvm::cl::init(false)};
  Option<bool> lowerToShuffleTranspose{
      *this, "shuffle",
      llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
      llvm::cl::init(false)};
  Option<bool> lowerToAvx2{
      *this, "avx2",
      llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"),
      llvm::cl::init(false)};

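  // Some of the lowerings exercised below may create LLVM dialect ops, so the
  // dialect is declared as a dependency before the pipeline runs.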
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<LLVM::LLVMDialect>();
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());

    // Test on one pattern in isolation.
    // Explicitly disable shape_cast lowering.
    LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions()
                                              .enableVectorTransposeLowering()
                                              .enableShapeCastLowering(false);
    if (lowerToEltwise) {
      options = options.setVectorTransformsOptions(
          VectorTransformsOptions().setVectorTransposeLowering(
              VectorTransposeLowering::EltWise));
    }
    if (lowerToFlatTranspose) {
      options = options.setVectorTransformsOptions(
          VectorTransformsOptions().setVectorTransposeLowering(
              VectorTransposeLowering::Flat));
    }
    if (lowerToShuffleTranspose) {
      options = options.setVectorTransformsOptions(
          VectorTransformsOptions().setVectorTransposeLowering(
              VectorTransposeLowering::Shuffle));
    }
    if (lowerToAvx2) {
      options = options.enableAVX2Lowering().setAVX2LoweringOptions(
          x86vector::avx2::LoweringOptions().setTransposeOptions(
              x86vector::avx2::TransposeLoweringOptions()
                  .lower4x8xf32()
                  .lower8x8xf32()));
    }

    OpPassManager dynamicPM("func.func");
    dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options));
    if (failed(runPipeline(dynamicPM, getOperation())))
      return signalPassFailure();
  }
};

struct TestVectorUnrollingPatterns
    : public PassWrapper<TestVectorUnrollingPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns)

  StringRef getArgument() const final {
    return "test-vector-unrolling-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to unroll contract ops in the vector "
           "dialect";
  }
  TestVectorUnrollingPatterns() = default;
  TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass)
      : PassWrapper(pass) {}
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{2, 2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<arith::AddFOp, vector::FMAOp,
                                           vector::MultiDimReductionOp>(op));
                      }));
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<vector::ReductionOp>(op));
                      }));
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<vector::TransposeOp>(op));
                      }));

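    // When unroll-based-on-type is set, contract ops get a native shape of 4
    // per iterator, with the innermost dimension lowered to 2 unless the lhs
    // element type is f16; otherwise every iterator is unrolled by 2.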
    if (unrollBasedOnType) {
      UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
          [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
        vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
        SmallVector<int64_t, 4> nativeShape(
            contractOp.getIteratorTypes().size(), 4);
        Type lhsType = contractOp.getLhsType().getElementType();
        nativeShape[nativeShape.size() - 1] = lhsType.isF16() ? 4 : 2;
        return nativeShape;
      };

      UnrollVectorOptions opts;
      opts.setNativeShapeFn(nativeShapeFn)
          .setFilterConstraint(
              [](Operation *op) { return success(isa<ContractionOp>(op)); });

      if (!unrollOrder.empty()) {
        opts.setUnrollTraversalOrderFn([this](Operation *op)
                                           -> Optional<SmallVector<int64_t>> {
          vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
          if (contractOp.getIteratorTypes().size() == unrollOrder.size())
            return SmallVector<int64_t>(unrollOrder.begin(), unrollOrder.end());
          return None;
        });
      }
      populateVectorUnrollPatterns(patterns, opts);
    } else {
      auto nativeShapeFn =
          [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
        auto contractOp = dyn_cast<ContractionOp>(op);
        if (!contractOp)
          return None;
        return SmallVector<int64_t, 4>(contractOp.getIteratorTypes().size(), 2);
      };
      populateVectorUnrollPatterns(patterns,
                                   UnrollVectorOptions()
                                       .setNativeShapeFn(nativeShapeFn)
                                       .setFilterConstraint([](Operation *op) {
                                         return success(isa<ContractionOp>(op));
                                       }));
    }
    populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

  ListOption<int64_t> unrollOrder{*this, "unroll-order",
                                  llvm::cl::desc("set the unroll order")};

  Option<bool> unrollBasedOnType{
      *this, "unroll-based-on-type",
      llvm::cl::desc("Set the unroll factor based on type of the operation"),
      llvm::cl::init(false)};
};

struct TestVectorDistributePatterns
    : public PassWrapper<TestVectorDistributePatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistributePatterns)

  StringRef getArgument() const final {
    return "test-vector-distribute-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to distribute vector ops in the vector "
           "dialect";
  }
  TestVectorDistributePatterns() = default;
  TestVectorDistributePatterns(const TestVectorDistributePatterns &pass)
      : PassWrapper(pass) {}
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<VectorDialect>();
    registry.insert<AffineDialect>();
  }
  ListOption<int32_t> multiplicity{
      *this, "distribution-multiplicity",
      llvm::cl::desc("Set the multiplicity used for distributing vector")};

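  // For every vector arith.addf, distribute along each dimension whose
  // requested multiplicity is not 1 and evenly divides the dimension size,
  // using the enclosing function's arguments as distribution ids; the
  // propagation patterns then clean up the resulting extract/insert pairs.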
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    func::FuncOp func = getOperation();
    func.walk([&](arith::AddFOp op) {
      OpBuilder builder(op);
      if (auto vecType = op.getType().dyn_cast<VectorType>()) {
        SmallVector<int64_t, 2> mul;
        SmallVector<AffineExpr, 2> perm;
        SmallVector<Value, 2> ids;
        unsigned count = 0;
        // Remove the multiplicity of 1 and calculate the affine map based on
        // the multiplicity.
        SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
        for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
          if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) {
            mul.push_back(m[i]);
            ids.push_back(func.getArgument(count++));
            perm.push_back(getAffineDimExpr(i, ctx));
          }
        }
        auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0,
                                  perm, ctx);
        Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
            builder, op.getOperation(), ids, mul, map);
        if (ops) {
          SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
          op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
                                              extractOp);
        }
      }
    });
    populatePropagateVectorDistributionPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorToLoopPatterns
    : public PassWrapper<TestVectorToLoopPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToLoopPatterns)

  StringRef getArgument() const final { return "test-vector-to-forloop"; }
  StringRef getDescription() const final {
    return "Test lowering patterns to break up a vector op into a for loop";
  }
  TestVectorToLoopPatterns() = default;
  TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass)
      : PassWrapper(pass) {}
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<VectorDialect>();
    registry.insert<AffineDialect>();
  }
  Option<int32_t> multiplicity{
      *this, "distribution-multiplicity",
      llvm::cl::desc("Set the multiplicity used for distributing vector"),
      llvm::cl::init(32)};
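  // Rewrite the first suitable 1-D vector arith.addf into an scf.for loop with
  // `multiplicity` iterations: the op's backward slice is moved into the loop
  // body and the op itself is distributed along the induction variable.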
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    func::FuncOp func = getOperation();
    func.walk([&](arith::AddFOp op) {
      // Check that the operation type can be broken down into a loop.
      VectorType type = op.getType().dyn_cast<VectorType>();
      if (!type || type.getRank() != 1 ||
          type.getNumElements() % multiplicity != 0)
        return mlir::WalkResult::advance();
      auto filterAlloc = [](Operation *op) {
        return !isa<arith::ConstantOp, memref::AllocOp, func::CallOp>(op);
      };
      auto dependentOps = getSlice(op, filterAlloc);
      // Create a loop and move instructions from the Op slice into the loop.
      OpBuilder builder(op);
      auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0);
      auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1);
      auto numIter =
          builder.create<arith::ConstantIndexOp>(op.getLoc(), multiplicity);
      auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
      for (Operation *it : dependentOps) {
        it->moveBefore(forOp.getBody()->getTerminator());
      }
      auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
      // Break up the original op and let the patterns propagate.
      Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
          builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity},
          map);
      if (ops) {
        SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
        op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
      }
      return mlir::WalkResult::interrupt();
    });
    populatePropagateVectorDistributionPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorTransferUnrollingPatterns
    : public PassWrapper<TestVectorTransferUnrollingPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferUnrollingPatterns)

  TestVectorTransferUnrollingPatterns() = default;
  TestVectorTransferUnrollingPatterns(
      const TestVectorTransferUnrollingPatterns &pass)
      : PassWrapper(pass) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AffineDialect>();
  }
  StringRef getArgument() const final {
    return "test-vector-transfer-unrolling-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to unroll transfer ops in the vector "
           "dialect";
  }
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    UnrollVectorOptions opts;
    opts.setNativeShape(ArrayRef<int64_t>{2, 2})
        .setFilterConstraint([](Operation *op) {
          return success(
              isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
        });
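    // Optionally traverse the unrolled tiles from the innermost vector
    // dimension to the outermost instead of the default order.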
    if (reverseUnrollOrder.getValue()) {
      opts.setUnrollTraversalOrderFn(
          [](Operation *op) -> Optional<SmallVector<int64_t>> {
            int64_t numLoops = 0;
            if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
              numLoops = readOp.getVectorType().getRank();
            else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op))
              numLoops = writeOp.getVectorType().getRank();
            else
              return None;
            auto order = llvm::reverse(llvm::seq<int64_t>(0, numLoops));
            return llvm::to_vector(order);
          });
    }
    populateVectorUnrollPatterns(patterns, opts);
    populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

  Option<bool> reverseUnrollOrder{
      *this, "reverse-unroll-order",
      llvm::cl::desc(
          "reverse the order of unrolling of vector transfer operations"),
      llvm::cl::init(false)};
};

struct TestVectorTransferFullPartialSplitPatterns
    : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferFullPartialSplitPatterns)

  StringRef getArgument() const final {
    return "test-vector-transfer-full-partial-split";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to split "
           "transfer ops via scf.if + linalg ops";
  }
  TestVectorTransferFullPartialSplitPatterns() = default;
  TestVectorTransferFullPartialSplitPatterns(
      const TestVectorTransferFullPartialSplitPatterns &pass)
      : PassWrapper(pass) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
                    scf::SCFDialect>();
  }

  Option<bool> useLinalgOps{
      *this, "use-memref-copy",
      llvm::cl::desc("Split using unmasked vector.transfer, linalg.fill, and "
                     "memref.copy operations."),
      llvm::cl::init(false)};
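  // Selects how the slow (partial) path is materialized: via another
  // vector.transfer by default, or via linalg.fill + memref.copy when
  // use-memref-copy is set.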
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    VectorTransformsOptions options;
    if (useLinalgOps)
      options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
    else
      options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
    patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorTransferOpt
    : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)

  StringRef getArgument() const final { return "test-vector-transferop-opt"; }
  StringRef getDescription() const final {
    return "Test optimization transformations for transfer ops";
  }
  void runOnOperation() override { transferOpflowOpt(getOperation()); }
};

struct TestVectorTransferLoweringPatterns
    : public PassWrapper<TestVectorTransferLoweringPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferLoweringPatterns)

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
  }
  StringRef getArgument() const final {
    return "test-vector-transfer-lowering-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to lower transfer ops to other vector ops";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorTransferLoweringPatterns(patterns);
    populateVectorTransferPermutationMapLoweringPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorMultiReductionLoweringPatterns
    : public PassWrapper<TestVectorMultiReductionLoweringPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorMultiReductionLoweringPatterns)

  TestVectorMultiReductionLoweringPatterns() = default;
  TestVectorMultiReductionLoweringPatterns(
      const TestVectorMultiReductionLoweringPatterns &pass)
      : PassWrapper(pass) {}
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  StringRef getArgument() const final {
    return "test-vector-multi-reduction-lowering-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to lower vector.multi_reduction to other "
           "vector ops";
  }
  Option<bool> useOuterReductions{
      *this, "use-outer-reductions",
      llvm::cl::desc("Move reductions to outer most dimensions"),
      llvm::cl::init(false)};
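  // With use-outer-reductions the lowering keeps the parallel dimensions
  // innermost (InnerParallel), effectively moving the reduction dimensions
  // outward; otherwise the reduction dimensions stay innermost
  // (InnerReduction).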
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorMultiReductionLoweringPatterns(
        patterns, useOuterReductions
                      ? vector::VectorMultiReductionLowering::InnerParallel
                      : vector::VectorMultiReductionLowering::InnerReduction);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorTransferCollapseInnerMostContiguousDims
    : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorTransferCollapseInnerMostContiguousDims)

  TestVectorTransferCollapseInnerMostContiguousDims() = default;
  TestVectorTransferCollapseInnerMostContiguousDims(
      const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, AffineDialect>();
  }

  StringRef getArgument() const final {
    return "test-vector-transfer-collapse-inner-most-dims";
  }

  StringRef getDescription() const final {
    return "Test lowering patterns that reduce the rank of the vector "
           "transfer memory and vector operands.";
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorReduceToContractPatternsPatterns
    : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorReduceToContractPatternsPatterns)

  StringRef getArgument() const final {
    return "test-vector-reduction-to-contract-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns to convert multireduce op to contract and combine "
           "broadcast/transpose to contract";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorReductionToContractPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

701
702 struct TestVectorTransferDropUnitDimsPatterns
703 : public PassWrapper<TestVectorTransferDropUnitDimsPatterns,
704 OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonb56dea510111::TestVectorTransferDropUnitDimsPatterns705 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
706 TestVectorTransferDropUnitDimsPatterns)
707
708 StringRef getArgument() const final {
709 return "test-vector-transfer-drop-unit-dims-patterns";
710 }
getDependentDialects__anonb56dea510111::TestVectorTransferDropUnitDimsPatterns711 void getDependentDialects(DialectRegistry ®istry) const override {
712 registry.insert<memref::MemRefDialect>();
713 }
runOnOperation__anonb56dea510111::TestVectorTransferDropUnitDimsPatterns714 void runOnOperation() override {
715 RewritePatternSet patterns(&getContext());
716 populateVectorTransferDropUnitDimsPatterns(patterns);
717 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
718 }
719 };
720
struct TestFlattenVectorTransferPatterns
    : public PassWrapper<TestFlattenVectorTransferPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestFlattenVectorTransferPatterns)

  StringRef getArgument() const final {
    return "test-vector-transfer-flatten-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns to rewrite contiguous row-major N-dimensional "
           "vector.transfer_{read,write} ops into 1D transfers";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateFlattenVectorTransferPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

struct TestVectorScanLowering
    : public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering)

  StringRef getArgument() const final { return "test-vector-scan-lowering"; }
  StringRef getDescription() const final {
    return "Test lowering patterns that lower the scan op in the vector "
           "dialect";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorScanLoweringPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

/// Allocate shared memory for a single warp to test lowering of
/// WarpExecuteOnLane0Op.
static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
                                        WarpExecuteOnLane0Op warpOp,
                                        Type type) {
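  // Memory space 3 is used here for GPU shared (workgroup) memory.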
  static constexpr int64_t kSharedMemorySpace = 3;
  // Compute type of shared memory buffer.
  MemRefType memrefType;
  if (auto vectorType = type.dyn_cast<VectorType>()) {
    memrefType =
        MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {},
                        kSharedMemorySpace);
  } else {
    memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace);
  }

  // Get symbol table holding all shared memory globals.
  ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>();
  SymbolTable symbolTable(moduleOp);

  // Create a pretty name.
  SmallString<64> buf;
  llvm::raw_svector_ostream os(buf);
  interleave(memrefType.getShape(), os, "x");
  os << "x" << memrefType.getElementType();
  std::string symbolName = (Twine("__shared_") + os.str()).str();

  auto ip = builder.saveInsertionPoint();
  builder.setInsertionPoint(moduleOp);
  auto global = builder.create<memref::GlobalOp>(
      loc,
      /*sym_name=*/symbolName,
      /*sym_visibility=*/builder.getStringAttr("private"),
      /*type=*/memrefType,
      /*initial_value=*/Attribute(),
      /*constant=*/false,
      /*alignment=*/IntegerAttr());
  symbolTable.insert(global);
  // The symbol table inserts at the end of the module, but globals are a bit
  // nicer if they are at the beginning.
  global->moveBefore(&moduleOp.front());

  builder.restoreInsertionPoint(ip);
  return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName);
}

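/// Reduce `input` across a warp of `size` lanes with gpu.shuffle xor: after
/// log2(size) butterfly steps every lane holds the combined value.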
static Value warpReduction(Location loc, OpBuilder &builder, Value input,
                           CombiningKind kind, uint32_t size) {
  Value laneVal = input;
  // Parallel reduction using butterfly shuffles.
  for (uint64_t i = 1; i < size; i <<= 1) {
    Value shuffled = builder
                         .create<gpu::ShuffleOp>(loc, laneVal, i,
                                                 /*width=*/size,
                                                 /*mode=*/gpu::ShuffleMode::XOR)
                         .result();
    laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
  }
  return laneVal;
}

struct TestVectorDistribution
    : public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution)

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect,
                    AffineDialect>();
  }

  StringRef getArgument() const final { return "test-vector-warp-distribute"; }
  StringRef getDescription() const final {
    return "Test vector warp distribute transformation and lowering patterns";
  }
  TestVectorDistribution() = default;
  TestVectorDistribution(const TestVectorDistribution &pass)
      : PassWrapper(pass) {}

  Option<bool> warpOpToSCF{
      *this, "rewrite-warp-ops-to-scf-if",
      llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"),
      llvm::cl::init(false)};

  Option<bool> distributeTransferWriteOps{
      *this, "distribute-transfer-write",
      llvm::cl::desc("Test distribution of transfer write"),
      llvm::cl::init(false)};

  Option<bool> hoistUniform{*this, "hoist-uniform",
                            llvm::cl::desc("Test hoist uniform"),
                            llvm::cl::init(false)};

  Option<bool> propagateDistribution{
      *this, "propagate-distribution",
      llvm::cl::desc("Test distribution propagation"), llvm::cl::init(false)};

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());

    getOperation().walk([&](Operation *op) {
      if (auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(op)) {
        if (hoistUniform) {
          moveScalarUniformCode(warpOp);
        }
        WalkResult::interrupt();
      }
    });
    MLIRContext *ctx = &getContext();
    if (distributeTransferWriteOps) {
      auto distributionFn = [](vector::TransferWriteOp writeOp) {
        // Create a map (d0, d1) -> (d1) to distribute along the inner
        // dimension. Once we support n-d distribution we can add more
        // complex cases.
        int64_t vecRank = writeOp.getVectorType().getRank();
        OpBuilder builder(writeOp.getContext());
        auto map =
            AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
        return map;
      };
      RewritePatternSet patterns(ctx);
      populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
    }
    if (propagateDistribution) {
      RewritePatternSet patterns(ctx);
      vector::populatePropagateWarpVectorDistributionPatterns(patterns);
      vector::populateDistributeReduction(patterns, warpReduction);
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
    }
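    // Lowering WarpExecuteOnLane0Op requires a callback that allocates the
    // staging buffer (shared memory here) and a callback that synchronizes
    // the warp (a gpu.barrier here).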
    WarpExecuteOnLane0LoweringOptions options;
    options.warpAllocationFn = allocateGlobalSharedMemory;
    options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,
                                      WarpExecuteOnLane0Op warpOp) {
      builder.create<gpu::BarrierOp>(loc);
    };
    // Test on one pattern in isolation.
    if (warpOpToSCF) {
      populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
      return;
    }
  }
};

} // namespace

namespace mlir {
namespace test {
void registerTestVectorLowerings() {
  PassRegistration<TestVectorToVectorLowering>();

  PassRegistration<TestVectorContractionLowering>();

  PassRegistration<TestVectorTransposeLowering>();

  PassRegistration<TestVectorUnrollingPatterns>();

  PassRegistration<TestVectorTransferUnrollingPatterns>();

  PassRegistration<TestVectorTransferFullPartialSplitPatterns>();

  PassRegistration<TestVectorDistributePatterns>();

  PassRegistration<TestVectorToLoopPatterns>();

  PassRegistration<TestVectorTransferOpt>();

  PassRegistration<TestVectorTransferLoweringPatterns>();

  PassRegistration<TestVectorMultiReductionLoweringPatterns>();

  PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();

  PassRegistration<TestVectorReduceToContractPatternsPatterns>();

  PassRegistration<TestVectorTransferDropUnitDimsPatterns>();

  PassRegistration<TestFlattenVectorTransferPatterns>();

  PassRegistration<TestVectorScanLowering>();

  PassRegistration<TestVectorDistribution>();
}
} // namespace test
} // namespace mlir