1 //===- TestTransformDialectExtension.cpp ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines an extension of the MLIR Transform dialect for testing
10 // purposes.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "TestTransformDialectExtension.h"
15 #include "TestTransformStateExtension.h"
16 #include "mlir/Dialect/PDL/IR/PDL.h"
17 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
18 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
19 #include "mlir/IR/OpImplementation.h"
20 
21 using namespace mlir;
22 
23 namespace {
24 /// Simple transform op defined outside of the dialect. Just emits a remark when
25 /// applied. This op is defined in C++ to test that C++ definitions also work
26 /// for op injection into the Transform dialect.
27 class TestTransformOp
28     : public Op<TestTransformOp, transform::TransformOpInterface::Trait,
29                 MemoryEffectOpInterface::Trait> {
30 public:
31   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp)
32 
33   using Op::Op;
34 
35   static ArrayRef<StringRef> getAttributeNames() { return {}; }
36 
37   static constexpr llvm::StringLiteral getOperationName() {
38     return llvm::StringLiteral("transform.test_transform_op");
39   }
40 
41   LogicalResult apply(transform::TransformResults &results,
42                       transform::TransformState &state) {
43     InFlightDiagnostic remark = emitRemark() << "applying transformation";
44     if (Attribute message = getMessage())
45       remark << " " << message;
46 
47     return success();
48   }
49 
50   Attribute getMessage() { return getOperation()->getAttr("message"); }
51 
52   static ParseResult parse(OpAsmParser &parser, OperationState &state) {
53     StringAttr message;
54     OptionalParseResult result = parser.parseOptionalAttribute(message);
55     if (!result.hasValue())
56       return success();
57 
58     if (result.getValue().succeeded())
59       state.addAttribute("message", message);
60     return result.getValue();
61   }
62 
63   void print(OpAsmPrinter &printer) {
64     if (getMessage())
65       printer << " " << getMessage();
66   }
67 
68   // No side effects.
69   void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
70 };
71 
72 /// A test op to exercise the verifier of the PossibleTopLevelTransformOpTrait
73 /// in cases where it is attached to ops that do not comply with the trait
74 /// requirements. This op cannot be defined in ODS because ODS generates strict
75 /// verifiers that overalp with those in the trait and run earlier.
76 class TestTransformUnrestrictedOpNoInterface
77     : public Op<TestTransformUnrestrictedOpNoInterface,
78                 transform::PossibleTopLevelTransformOpTrait,
79                 transform::TransformOpInterface::Trait,
80                 MemoryEffectOpInterface::Trait> {
81 public:
82   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
83       TestTransformUnrestrictedOpNoInterface)
84 
85   using Op::Op;
86 
87   static ArrayRef<StringRef> getAttributeNames() { return {}; }
88 
89   static constexpr llvm::StringLiteral getOperationName() {
90     return llvm::StringLiteral(
91         "transform.test_transform_unrestricted_op_no_interface");
92   }
93 
94   LogicalResult apply(transform::TransformResults &results,
95                       transform::TransformState &state) {
96     return success();
97   }
98 
99   // No side effects.
100   void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
101 };
102 } // namespace
103 
104 LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::apply(
105     transform::TransformResults &results, transform::TransformState &state) {
106   if (getOperation()->getNumOperands() != 0) {
107     results.set(getResult().cast<OpResult>(),
108                 getOperation()->getOperand(0).getDefiningOp());
109   } else {
110     results.set(getResult().cast<OpResult>(),
111                 reinterpret_cast<Operation *>(*getParameter()));
112   }
113   return success();
114 }
115 
116 LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() {
117   if (getParameter().hasValue() ^ (getNumOperands() != 1))
118     return emitOpError() << "expects either a parameter or an operand";
119   return success();
120 }
121 
122 LogicalResult
123 mlir::test::TestConsumeOperand::apply(transform::TransformResults &results,
124                                       transform::TransformState &state) {
125   return success();
126 }
127 
128 LogicalResult mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply(
129     transform::TransformResults &results, transform::TransformState &state) {
130   ArrayRef<Operation *> payload = state.getPayloadOps(getOperand());
131   assert(payload.size() == 1 && "expected a single target op");
132   auto value = reinterpret_cast<intptr_t>(payload[0]);
133   if (static_cast<uint64_t>(value) != getParameter()) {
134     return emitOpError() << "expected the operand to be associated with "
135                          << getParameter() << " got " << value;
136   }
137 
138   emitRemark() << "succeeded";
139   return success();
140 }
141 
142 LogicalResult mlir::test::TestPrintRemarkAtOperandOp::apply(
143     transform::TransformResults &results, transform::TransformState &state) {
144   ArrayRef<Operation *> payload = state.getPayloadOps(getOperand());
145   for (Operation *op : payload)
146     op->emitRemark() << getMessage();
147 
148   return success();
149 }
150 
151 LogicalResult
152 mlir::test::TestAddTestExtensionOp::apply(transform::TransformResults &results,
153                                           transform::TransformState &state) {
154   state.addExtension<TestTransformStateExtension>(getMessageAttr());
155   return success();
156 }
157 
158 LogicalResult mlir::test::TestCheckIfTestExtensionPresentOp::apply(
159     transform::TransformResults &results, transform::TransformState &state) {
160   auto *extension = state.getExtension<TestTransformStateExtension>();
161   if (!extension) {
162     emitRemark() << "extension absent";
163     return success();
164   }
165 
166   InFlightDiagnostic diag = emitRemark()
167                             << "extension present, " << extension->getMessage();
168   for (Operation *payload : state.getPayloadOps(getOperand())) {
169     diag.attachNote(payload->getLoc()) << "associated payload op";
170     assert(state.getHandleForPayloadOp(payload) == getOperand() &&
171            "inconsistent mapping between transform IR handles and payload IR "
172            "operations");
173   }
174 
175   return success();
176 }
177 
178 LogicalResult mlir::test::TestRemapOperandPayloadToSelfOp::apply(
179     transform::TransformResults &results, transform::TransformState &state) {
180   auto *extension = state.getExtension<TestTransformStateExtension>();
181   if (!extension)
182     return emitError() << "TestTransformStateExtension missing";
183 
184   return extension->updateMapping(state.getPayloadOps(getOperand()).front(),
185                                   getOperation());
186 }
187 
188 LogicalResult mlir::test::TestRemoveTestExtensionOp::apply(
189     transform::TransformResults &results, transform::TransformState &state) {
190   state.removeExtension<TestTransformStateExtension>();
191   return success();
192 }
193 LogicalResult mlir::test::TestTransformOpWithRegions::apply(
194     transform::TransformResults &results, transform::TransformState &state) {
195   return success();
196 }
197 
198 void mlir::test::TestTransformOpWithRegions::getEffects(
199     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
200 
201 LogicalResult mlir::test::TestBranchingTransformOpTerminator::apply(
202     transform::TransformResults &results, transform::TransformState &state) {
203   return success();
204 }
205 
206 void mlir::test::TestBranchingTransformOpTerminator::getEffects(
207     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
208 
209 namespace {
/// Test extension of the Transform dialect. Registers additional ops and
/// declares PDL as dependent dialect since the additional ops are using PDL
/// types for operands and results.
class TestTransformDialectExtension
    : public transform::TransformDialectExtension<
          TestTransformDialectExtension> {
public:
  /// Declares the PDL dependency and registers the test ops. The op list is
  /// the two hand-written C++ ops above, followed by every ODS-generated op
  /// from TestTransformDialectExtension.cpp.inc.
  TestTransformDialectExtension() {
    declareDependentDialect<pdl::PDLDialect>();
    registerTransformOps<TestTransformOp,
                         TestTransformUnrestrictedOpNoInterface,
// GET_OP_LIST makes the generated .inc expand to a comma-separated list of
// op class names, spliced here into the template argument list.
#define GET_OP_LIST
#include "TestTransformDialectExtension.cpp.inc"
                         >();
  }
};
226 } // namespace
227 
228 #define GET_OP_CLASSES
229 #include "TestTransformDialectExtension.cpp.inc"
230 
231 void ::test::registerTestTransformDialectExtension(DialectRegistry &registry) {
232   registry.addExtensions<TestTransformDialectExtension>();
233 }
234