1 //===- TestTraits.cpp - Test trait folding --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "mlir/Pass/Pass.h"
11 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
12 
13 using namespace mlir;
14 
15 //===----------------------------------------------------------------------===//
16 // Trait Folder.
17 //===----------------------------------------------------------------------===//
18 
19 OpFoldResult TestInvolutionTraitFailingOperationFolderOp::fold(
20     ArrayRef<Attribute> operands) {
21   // This failure should cause the trait fold to run instead.
22   return {};
23 }
24 
25 OpFoldResult TestInvolutionTraitSuccesfulOperationFolderOp::fold(
26     ArrayRef<Attribute> operands) {
27   auto argumentOp = getOperand();
28   // The success case should cause the trait fold to be supressed.
29   return argumentOp.getDefiningOp() ? argumentOp : OpFoldResult{};
30 }
31 
32 namespace {
33 struct TestTraitFolder : public PassWrapper<TestTraitFolder, FunctionPass> {
34   void runOnFunction() override {
35     applyPatternsAndFoldGreedily(getFunction(), OwningRewritePatternList());
36   }
37 };
38 } // end anonymous namespace
39 
40 namespace mlir {
41 void registerTestTraitsPass() {
42   PassRegistration<TestTraitFolder>("test-trait-folder", "Run trait folding");
43 }
44 } // namespace mlir
45