1 //===- TestTransformDialectExtension.cpp ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines an extension of the MLIR Transform dialect for testing
10 // purposes.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "TestTransformDialectExtension.h"
15 #include "mlir/Dialect/PDL/IR/PDL.h"
16 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
17 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/OpImplementation.h"
20 
21 using namespace mlir;
22 
23 namespace {
24 /// Simple transform op defined outside of the dialect. Just emits a remark when
25 /// applied. This op is defined in C++ to test that C++ definitions also work
26 /// for op injection into the Transform dialect.
27 class TestTransformOp
28     : public Op<TestTransformOp, transform::TransformOpInterface::Trait> {
29 public:
30   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp)
31 
32   using Op::Op;
33 
34   static ArrayRef<StringRef> getAttributeNames() { return {}; }
35 
36   static constexpr llvm::StringLiteral getOperationName() {
37     return llvm::StringLiteral("transform.test_transform_op");
38   }
39 
40   LogicalResult apply(transform::TransformResults &results,
41                       transform::TransformState &state) {
42     InFlightDiagnostic remark = emitRemark() << "applying transformation";
43     if (Attribute message = getMessage())
44       remark << " " << message;
45 
46     return success();
47   }
48 
49   Attribute getMessage() { return getOperation()->getAttr("message"); }
50 
51   static ParseResult parse(OpAsmParser &parser, OperationState &state) {
52     StringAttr message;
53     OptionalParseResult result = parser.parseOptionalAttribute(message);
54     if (!result.hasValue())
55       return success();
56 
57     if (result.getValue().succeeded())
58       state.addAttribute("message", message);
59     return result.getValue();
60   }
61 
62   void print(OpAsmPrinter &printer) {
63     if (getMessage())
64       printer << " " << getMessage();
65   }
66 };
67 
/// A test op to exercise the verifier of the PossibleTopLevelTransformOpTrait
/// in cases where it is attached to ops that do not comply with the trait
/// requirements. This op cannot be defined in ODS because ODS generates strict
/// verifiers that overlap with those in the trait and run earlier.
72 class TestTransformUnrestrictedOpNoInterface
73     : public Op<TestTransformUnrestrictedOpNoInterface,
74                 transform::PossibleTopLevelTransformOpTrait,
75                 transform::TransformOpInterface::Trait> {
76 public:
77   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
78       TestTransformUnrestrictedOpNoInterface)
79 
80   using Op::Op;
81 
82   static ArrayRef<StringRef> getAttributeNames() { return {}; }
83 
84   static constexpr llvm::StringLiteral getOperationName() {
85     return llvm::StringLiteral(
86         "transform.test_transform_unrestricted_op_no_interface");
87   }
88 
89   LogicalResult apply(transform::TransformResults &results,
90                       transform::TransformState &state) {
91     return success();
92   }
93 };
94 } // namespace
95 
96 LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::apply(
97     transform::TransformResults &results, transform::TransformState &state) {
98   if (getOperation()->getNumOperands() != 0) {
99     results.set(getResult().cast<OpResult>(),
100                 getOperation()->getOperand(0).getDefiningOp());
101   } else {
102     results.set(getResult().cast<OpResult>(),
103                 reinterpret_cast<Operation *>(*getParameter()));
104   }
105   return success();
106 }
107 
108 LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() {
109   if (getParameter().hasValue() ^ (getNumOperands() != 1))
110     return emitOpError() << "expects either a parameter or an operand";
111   return success();
112 }
113 
114 LogicalResult mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply(
115     transform::TransformResults &results, transform::TransformState &state) {
116   ArrayRef<Operation *> payload = state.getPayloadOps(getOperand());
117   assert(payload.size() == 1 && "expected a single target op");
118   auto value = reinterpret_cast<intptr_t>(payload[0]);
119   if (static_cast<uint64_t>(value) != getParameter()) {
120     return emitOpError() << "expected the operand to be associated with "
121                          << getParameter() << " got " << value;
122   }
123 
124   emitRemark() << "succeeded";
125   return success();
126 }
127 
128 LogicalResult mlir::test::TestPrintRemarkAtOperandOp::apply(
129     transform::TransformResults &results, transform::TransformState &state) {
130   ArrayRef<Operation *> payload = state.getPayloadOps(getOperand());
131   for (Operation *op : payload)
132     op->emitRemark() << getMessage();
133 
134   return success();
135 }
136 
137 namespace {
138 /// Test extension of the Transform dialect. Registers additional ops and
139 /// declares PDL as dependent dialect since the additional ops are using PDL
140 /// types for operands and results.
141 class TestTransformDialectExtension
142     : public transform::TransformDialectExtension<
143           TestTransformDialectExtension> {
144 public:
145   TestTransformDialectExtension() {
146     declareDependentDialect<pdl::PDLDialect>();
147     registerTransformOps<TestTransformOp,
148                          TestTransformUnrestrictedOpNoInterface,
149 #define GET_OP_LIST
150 #include "TestTransformDialectExtension.cpp.inc"
151                          >();
152   }
153 };
154 } // namespace
155 
156 #define GET_OP_CLASSES
157 #include "TestTransformDialectExtension.cpp.inc"
158 
159 void ::test::registerTestTransformDialectExtension(DialectRegistry &registry) {
160   registry.addExtensions<TestTransformDialectExtension>();
161 }
162