1 //===- TestTensorTransforms.cpp - Test Tensor transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Tensor transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/SCF/IR/SCF.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19 
20 using namespace mlir;
21 
22 namespace {
23 struct TestTensorTransforms
24     : public PassWrapper<TestTensorTransforms, OperationPass<>> {
25   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTensorTransforms)
26 
27   TestTensorTransforms() = default;
TestTensorTransforms__anondc46a2a10111::TestTensorTransforms28   TestTensorTransforms(const TestTensorTransforms &pass) : PassWrapper(pass) {}
29 
getDependentDialects__anondc46a2a10111::TestTensorTransforms30   void getDependentDialects(DialectRegistry &registry) const override {
31     registry.insert<arith::ArithmeticDialect, scf::SCFDialect>();
32   }
33 
getArgument__anondc46a2a10111::TestTensorTransforms34   StringRef getArgument() const final {
35     return "test-tensor-transform-patterns";
36   }
getDescription__anondc46a2a10111::TestTensorTransforms37   StringRef getDescription() const final {
38     return "Test Tensor transformation patterns by applying them greedily.";
39   }
40 
41   void runOnOperation() override;
42 
43   Option<bool> testSplitPaddingPatterns{
44       *this, "test-split-padding-patterns",
45       llvm::cl::desc("Test patterns to split tensor.pad ops"),
46       llvm::cl::init(false)};
47 
48   Option<bool> testFoldConstantExtractSlice{
49       *this, "test-fold-constant-extract-slice",
50       llvm::cl::desc("Test folding arith.constant and tensor.extract_slice"),
51       llvm::cl::init(false)};
52 };
53 } // namespace
54 
applySplitPaddingPatterns(Operation * rootOp)55 static void applySplitPaddingPatterns(Operation *rootOp) {
56   RewritePatternSet patterns(rootOp->getContext());
57   tensor::populateSplitPaddingPatterns(patterns);
58   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
59 }
60 
applyFoldConstantExtractSlicePatterns(Operation * rootOp)61 static void applyFoldConstantExtractSlicePatterns(Operation *rootOp) {
62   RewritePatternSet patterns(rootOp->getContext());
63   tensor::ControlConstantExtractSliceFusionFn controlFn =
64       [](tensor::ExtractSliceOp op) {
65         if (!op.getSource().hasOneUse())
66           return false;
67 
68         auto resultType = op.getResult().getType().cast<ShapedType>();
69         constexpr int64_t kConstantFoldingMaxNumElements = 1024;
70         return resultType.getNumElements() <= kConstantFoldingMaxNumElements;
71       };
72 
73   tensor::populateFoldConstantExtractSlicePatterns(patterns, controlFn);
74   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
75 }
76 
runOnOperation()77 void TestTensorTransforms::runOnOperation() {
78   Operation *rootOp = getOperation();
79   if (testSplitPaddingPatterns)
80     applySplitPaddingPatterns(rootOp);
81   if (testFoldConstantExtractSlice)
82     applyFoldConstantExtractSlicePatterns(rootOp);
83 }
84 
85 namespace mlir {
86 namespace test {
registerTestTensorTransforms()87 void registerTestTensorTransforms() {
88   PassRegistration<TestTensorTransforms>();
89 }
90 } // namespace test
91 } // namespace mlir
92