1 //===- TestTransformDialectExtension.cpp ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines an extension of the MLIR Transform dialect for testing
10 // purposes.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "TestTransformDialectExtension.h"
15 #include "mlir/Dialect/PDL/IR/PDL.h"
16 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
17 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/OpImplementation.h"
20 
21 using namespace mlir;
22 
23 namespace {
24 /// Simple transform op defined outside of the dialect. Just emits a remark when
25 /// applied.
26 class TestTransformOp
27     : public Op<TestTransformOp, transform::TransformOpInterface::Trait> {
28 public:
29   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp)
30 
31   using Op::Op;
32 
33   static ArrayRef<StringRef> getAttributeNames() { return {}; }
34 
35   static constexpr llvm::StringLiteral getOperationName() {
36     return llvm::StringLiteral("transform.test_transform_op");
37   }
38 
39   LogicalResult apply(transform::TransformResults &results,
40                       transform::TransformState &state) {
41     emitRemark() << "applying transformation";
42     return success();
43   }
44 
45   static ParseResult parse(OpAsmParser &parser, OperationState &state) {
46     return success();
47   }
48 
49   void print(OpAsmPrinter &printer) {}
50 };
51 } // namespace
52 
53 LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::apply(
54     transform::TransformResults &results, transform::TransformState &state) {
55   if (getOperation()->getNumOperands() != 0) {
56     results.set(getResult().cast<OpResult>(), getOperand(0).getDefiningOp());
57   } else {
58     results.set(getResult().cast<OpResult>(),
59                 reinterpret_cast<Operation *>(*parameter()));
60   }
61   return success();
62 }
63 
64 LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() {
65   if (parameter().hasValue() ^ (getNumOperands() != 1))
66     return emitOpError() << "expects either a parameter or an operand";
67   return success();
68 }
69 
70 LogicalResult mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply(
71     transform::TransformResults &results, transform::TransformState &state) {
72   ArrayRef<Operation *> payload = state.getPayloadOps(getOperand());
73   assert(payload.size() == 1 && "expected a single target op");
74   auto value = reinterpret_cast<intptr_t>(payload[0]);
75   if (static_cast<uint64_t>(value) != parameter()) {
76     return emitOpError() << "expected the operand to be associated with "
77                          << parameter() << " got " << value;
78   }
79 
80   emitRemark() << "succeeded";
81   return success();
82 }
83 
namespace {
/// Test extension of the Transform dialect. Registers additional ops and
/// declares PDL as a dependent dialect because the additional ops use PDL
/// types for their operands and results.
class TestTransformDialectExtension
    : public transform::TransformDialectExtension<
          TestTransformDialectExtension> {
public:
  /// Populates the extension. Declares the PDL dialect as a dependency, then
  /// registers the hand-written TestTransformOp together with every op from
  /// the ODS-generated op list.
  TestTransformDialectExtension() {
    declareDependentDialect<pdl::PDLDialect>();
    // The generated op list is spliced directly into the template argument
    // list below. With GET_OP_LIST defined, the .inc file expands to the
    // generated op class names that follow `TestTransformOp,`.
    registerTransformOps<TestTransformOp,
#define GET_OP_LIST
#include "TestTransformDialectExtension.cpp.inc"
                         >();
  }
};
} // namespace
101 
102 #define GET_OP_CLASSES
103 #include "TestTransformDialectExtension.cpp.inc"
104 
105 void ::test::registerTestTransformDialectExtension(DialectRegistry &registry) {
106   registry.addExtensions<TestTransformDialectExtension>();
107 }
108