//===- TestTensorTransforms.cpp - Test Tensor transformation patterns -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a test pass for Tensor dialect transformation patterns.
//
//===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/SCF/IR/SCF.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19
20 using namespace mlir;
21
22 namespace {
23 struct TestTensorTransforms
24 : public PassWrapper<TestTensorTransforms, OperationPass<>> {
25 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTensorTransforms)
26
27 TestTensorTransforms() = default;
TestTensorTransforms__anondc46a2a10111::TestTensorTransforms28 TestTensorTransforms(const TestTensorTransforms &pass) : PassWrapper(pass) {}
29
getDependentDialects__anondc46a2a10111::TestTensorTransforms30 void getDependentDialects(DialectRegistry ®istry) const override {
31 registry.insert<arith::ArithmeticDialect, scf::SCFDialect>();
32 }
33
getArgument__anondc46a2a10111::TestTensorTransforms34 StringRef getArgument() const final {
35 return "test-tensor-transform-patterns";
36 }
getDescription__anondc46a2a10111::TestTensorTransforms37 StringRef getDescription() const final {
38 return "Test Tensor transformation patterns by applying them greedily.";
39 }
40
41 void runOnOperation() override;
42
43 Option<bool> testSplitPaddingPatterns{
44 *this, "test-split-padding-patterns",
45 llvm::cl::desc("Test patterns to split tensor.pad ops"),
46 llvm::cl::init(false)};
47
48 Option<bool> testFoldConstantExtractSlice{
49 *this, "test-fold-constant-extract-slice",
50 llvm::cl::desc("Test folding arith.constant and tensor.extract_slice"),
51 llvm::cl::init(false)};
52 };
53 } // namespace
54
applySplitPaddingPatterns(Operation * rootOp)55 static void applySplitPaddingPatterns(Operation *rootOp) {
56 RewritePatternSet patterns(rootOp->getContext());
57 tensor::populateSplitPaddingPatterns(patterns);
58 (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
59 }
60
applyFoldConstantExtractSlicePatterns(Operation * rootOp)61 static void applyFoldConstantExtractSlicePatterns(Operation *rootOp) {
62 RewritePatternSet patterns(rootOp->getContext());
63 tensor::ControlConstantExtractSliceFusionFn controlFn =
64 [](tensor::ExtractSliceOp op) {
65 if (!op.getSource().hasOneUse())
66 return false;
67
68 auto resultType = op.getResult().getType().cast<ShapedType>();
69 constexpr int64_t kConstantFoldingMaxNumElements = 1024;
70 return resultType.getNumElements() <= kConstantFoldingMaxNumElements;
71 };
72
73 tensor::populateFoldConstantExtractSlicePatterns(patterns, controlFn);
74 (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
75 }
76
runOnOperation()77 void TestTensorTransforms::runOnOperation() {
78 Operation *rootOp = getOperation();
79 if (testSplitPaddingPatterns)
80 applySplitPaddingPatterns(rootOp);
81 if (testFoldConstantExtractSlice)
82 applyFoldConstantExtractSlicePatterns(rootOp);
83 }
84
85 namespace mlir {
86 namespace test {
registerTestTensorTransforms()87 void registerTestTensorTransforms() {
88 PassRegistration<TestTensorTransforms>();
89 }
90 } // namespace test
91 } // namespace mlir
92