1 //===- TestTransformDialectExtension.cpp ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines an extension of the MLIR Transform dialect for testing
10 // purposes.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "TestTransformDialectExtension.h"
15 #include "TestTransformStateExtension.h"
16 #include "mlir/Dialect/PDL/IR/PDL.h"
17 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
18 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
19 #include "mlir/IR/OpImplementation.h"
20
21 using namespace mlir;
22
23 namespace {
24 /// Simple transform op defined outside of the dialect. Just emits a remark when
25 /// applied. This op is defined in C++ to test that C++ definitions also work
26 /// for op injection into the Transform dialect.
27 class TestTransformOp
28 : public Op<TestTransformOp, transform::TransformOpInterface::Trait,
29 MemoryEffectOpInterface::Trait> {
30 public:
31 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp)
32
33 using Op::Op;
34
getAttributeNames()35 static ArrayRef<StringRef> getAttributeNames() { return {}; }
36
getOperationName()37 static constexpr llvm::StringLiteral getOperationName() {
38 return llvm::StringLiteral("transform.test_transform_op");
39 }
40
apply(transform::TransformResults & results,transform::TransformState & state)41 DiagnosedSilenceableFailure apply(transform::TransformResults &results,
42 transform::TransformState &state) {
43 InFlightDiagnostic remark = emitRemark() << "applying transformation";
44 if (Attribute message = getMessage())
45 remark << " " << message;
46
47 return DiagnosedSilenceableFailure::success();
48 }
49
getMessage()50 Attribute getMessage() { return getOperation()->getAttr("message"); }
51
parse(OpAsmParser & parser,OperationState & state)52 static ParseResult parse(OpAsmParser &parser, OperationState &state) {
53 StringAttr message;
54 OptionalParseResult result = parser.parseOptionalAttribute(message);
55 if (!result.hasValue())
56 return success();
57
58 if (result.getValue().succeeded())
59 state.addAttribute("message", message);
60 return result.getValue();
61 }
62
print(OpAsmPrinter & printer)63 void print(OpAsmPrinter &printer) {
64 if (getMessage())
65 printer << " " << getMessage();
66 }
67
68 // No side effects.
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)69 void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
70 };
71
72 /// A test op to exercise the verifier of the PossibleTopLevelTransformOpTrait
73 /// in cases where it is attached to ops that do not comply with the trait
74 /// requirements. This op cannot be defined in ODS because ODS generates strict
75 /// verifiers that overalp with those in the trait and run earlier.
76 class TestTransformUnrestrictedOpNoInterface
77 : public Op<TestTransformUnrestrictedOpNoInterface,
78 transform::PossibleTopLevelTransformOpTrait,
79 transform::TransformOpInterface::Trait,
80 MemoryEffectOpInterface::Trait> {
81 public:
82 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
83 TestTransformUnrestrictedOpNoInterface)
84
85 using Op::Op;
86
getAttributeNames()87 static ArrayRef<StringRef> getAttributeNames() { return {}; }
88
getOperationName()89 static constexpr llvm::StringLiteral getOperationName() {
90 return llvm::StringLiteral(
91 "transform.test_transform_unrestricted_op_no_interface");
92 }
93
apply(transform::TransformResults & results,transform::TransformState & state)94 DiagnosedSilenceableFailure apply(transform::TransformResults &results,
95 transform::TransformState &state) {
96 return DiagnosedSilenceableFailure::success();
97 }
98
99 // No side effects.
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)100 void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
101 };
102 } // namespace
103
104 DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)105 mlir::test::TestProduceParamOrForwardOperandOp::apply(
106 transform::TransformResults &results, transform::TransformState &state) {
107 if (getOperation()->getNumOperands() != 0) {
108 results.set(getResult().cast<OpResult>(),
109 getOperation()->getOperand(0).getDefiningOp());
110 } else {
111 results.set(getResult().cast<OpResult>(),
112 reinterpret_cast<Operation *>(*getParameter()));
113 }
114 return DiagnosedSilenceableFailure::success();
115 }
116
verify()117 LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() {
118 if (getParameter().has_value() ^ (getNumOperands() != 1))
119 return emitOpError() << "expects either a parameter or an operand";
120 return success();
121 }
122
123 DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)124 mlir::test::TestConsumeOperand::apply(transform::TransformResults &results,
125 transform::TransformState &state) {
126 return DiagnosedSilenceableFailure::success();
127 }
128
129 DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)130 mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply(
131 transform::TransformResults &results, transform::TransformState &state) {
132 ArrayRef<Operation *> payload = state.getPayloadOps(getOperand());
133 assert(payload.size() == 1 && "expected a single target op");
134 auto value = reinterpret_cast<intptr_t>(payload[0]);
135 if (static_cast<uint64_t>(value) != getParameter()) {
136 return emitSilenceableError()
137 << "op expected the operand to be associated with " << getParameter()
138 << " got " << value;
139 }
140
141 emitRemark() << "succeeded";
142 return DiagnosedSilenceableFailure::success();
143 }
144
apply(transform::TransformResults & results,transform::TransformState & state)145 DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandOp::apply(
146 transform::TransformResults &results, transform::TransformState &state) {
147 ArrayRef<Operation *> payload = state.getPayloadOps(getOperand());
148 for (Operation *op : payload)
149 op->emitRemark() << getMessage();
150
151 return DiagnosedSilenceableFailure::success();
152 }
153
154 DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)155 mlir::test::TestAddTestExtensionOp::apply(transform::TransformResults &results,
156 transform::TransformState &state) {
157 state.addExtension<TestTransformStateExtension>(getMessageAttr());
158 return DiagnosedSilenceableFailure::success();
159 }
160
161 DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)162 mlir::test::TestCheckIfTestExtensionPresentOp::apply(
163 transform::TransformResults &results, transform::TransformState &state) {
164 auto *extension = state.getExtension<TestTransformStateExtension>();
165 if (!extension) {
166 emitRemark() << "extension absent";
167 return DiagnosedSilenceableFailure::success();
168 }
169
170 InFlightDiagnostic diag = emitRemark()
171 << "extension present, " << extension->getMessage();
172 for (Operation *payload : state.getPayloadOps(getOperand())) {
173 diag.attachNote(payload->getLoc()) << "associated payload op";
174 assert(state.getHandleForPayloadOp(payload) == getOperand() &&
175 "inconsistent mapping between transform IR handles and payload IR "
176 "operations");
177 }
178
179 return DiagnosedSilenceableFailure::success();
180 }
181
apply(transform::TransformResults & results,transform::TransformState & state)182 DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply(
183 transform::TransformResults &results, transform::TransformState &state) {
184 auto *extension = state.getExtension<TestTransformStateExtension>();
185 if (!extension) {
186 emitError() << "TestTransformStateExtension missing";
187 return DiagnosedSilenceableFailure::definiteFailure();
188 }
189
190 if (failed(extension->updateMapping(state.getPayloadOps(getOperand()).front(),
191 getOperation())))
192 return DiagnosedSilenceableFailure::definiteFailure();
193 return DiagnosedSilenceableFailure::success();
194 }
195
apply(transform::TransformResults & results,transform::TransformState & state)196 DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply(
197 transform::TransformResults &results, transform::TransformState &state) {
198 state.removeExtension<TestTransformStateExtension>();
199 return DiagnosedSilenceableFailure::success();
200 }
201
202 DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)203 mlir::test::TestReversePayloadOpsOp::apply(transform::TransformResults &results,
204 transform::TransformState &state) {
205 ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
206 auto reversedOps = llvm::to_vector(llvm::reverse(payloadOps));
207 results.set(getResult().cast<OpResult>(), reversedOps);
208 return DiagnosedSilenceableFailure::success();
209 }
210
apply(transform::TransformResults & results,transform::TransformState & state)211 DiagnosedSilenceableFailure mlir::test::TestTransformOpWithRegions::apply(
212 transform::TransformResults &results, transform::TransformState &state) {
213 return DiagnosedSilenceableFailure::success();
214 }
215
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)216 void mlir::test::TestTransformOpWithRegions::getEffects(
217 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
218
219 DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)220 mlir::test::TestBranchingTransformOpTerminator::apply(
221 transform::TransformResults &results, transform::TransformState &state) {
222 return DiagnosedSilenceableFailure::success();
223 }
224
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)225 void mlir::test::TestBranchingTransformOpTerminator::getEffects(
226 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
227
apply(transform::TransformResults & results,transform::TransformState & state)228 DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply(
229 transform::TransformResults &results, transform::TransformState &state) {
230 emitRemark() << getRemark();
231 for (Operation *op : state.getPayloadOps(getTarget()))
232 op->erase();
233
234 if (getFailAfterErase())
235 return emitSilenceableError() << "silencable error";
236 return DiagnosedSilenceableFailure::success();
237 }
238
applyToOne(Operation * target,SmallVectorImpl<Operation * > & results,transform::TransformState & state)239 DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfResultsOp::applyToOne(
240 Operation *target, SmallVectorImpl<Operation *> &results,
241 transform::TransformState &state) {
242 OperationState opState(target->getLoc(), "foo");
243 results.push_back(OpBuilder(target).create(opState));
244 return DiagnosedSilenceableFailure::success();
245 }
246
247 DiagnosedSilenceableFailure
applyToOne(Operation * target,SmallVectorImpl<Operation * > & results,transform::TransformState & state)248 mlir::test::TestWrongNumberOfMultiResultsOp::applyToOne(
249 Operation *target, SmallVectorImpl<Operation *> &results,
250 transform::TransformState &state) {
251 static int count = 0;
252 if (count++ == 0) {
253 OperationState opState(target->getLoc(), "foo");
254 results.push_back(OpBuilder(target).create(opState));
255 }
256 return DiagnosedSilenceableFailure::success();
257 }
258
259 DiagnosedSilenceableFailure
applyToOne(Operation * target,SmallVectorImpl<Operation * > & results,transform::TransformState & state)260 mlir::test::TestCorrectNumberOfMultiResultsOp::applyToOne(
261 Operation *target, SmallVectorImpl<Operation *> &results,
262 transform::TransformState &state) {
263 OperationState opState(target->getLoc(), "foo");
264 results.push_back(OpBuilder(target).create(opState));
265 results.push_back(OpBuilder(target).create(opState));
266 return DiagnosedSilenceableFailure::success();
267 }
268
269 DiagnosedSilenceableFailure
applyToOne(Operation * target,SmallVectorImpl<Operation * > & results,transform::TransformState & state)270 mlir::test::TestMixedNullAndNonNullResultsOp::applyToOne(
271 Operation *target, SmallVectorImpl<Operation *> &results,
272 transform::TransformState &state) {
273 OperationState opState(target->getLoc(), "foo");
274 results.push_back(nullptr);
275 results.push_back(OpBuilder(target).create(opState));
276 return DiagnosedSilenceableFailure::success();
277 }
278
279 DiagnosedSilenceableFailure
applyToOne(Operation * target,SmallVectorImpl<Operation * > & results,transform::TransformState & state)280 mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne(
281 Operation *target, SmallVectorImpl<Operation *> &results,
282 transform::TransformState &state) {
283 if (target->hasAttr("target_me"))
284 return DiagnosedSilenceableFailure::success();
285 return emitDefaultSilenceableFailure(target);
286 }
287
288 DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)289 mlir::test::TestPrintNumberOfAssociatedPayloadIROps::apply(
290 transform::TransformResults &results, transform::TransformState &state) {
291 emitRemark() << state.getPayloadOps(getHandle()).size();
292 return DiagnosedSilenceableFailure::success();
293 }
294
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)295 void mlir::test::TestPrintNumberOfAssociatedPayloadIROps::getEffects(
296 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
297 transform::onlyReadsHandle(getHandle(), effects);
298 }
299
300 namespace {
/// Test extension of the Transform dialect. Registers additional ops and
/// declares PDL as a dependent dialect, since the additional ops use PDL types
/// for their operands and results.
class TestTransformDialectExtension
    : public transform::TransformDialectExtension<
          TestTransformDialectExtension> {
public:
  using Base::Base;

  /// Declares the PDL dependency and registers two groups of ops: the
  /// hand-written C++ test ops from this file and the ops listed by the
  /// generated .cpp.inc file.
  void init() {
    declareDependentDialect<pdl::PDLDialect>();
    registerTransformOps<TestTransformOp,
                         TestTransformUnrestrictedOpNoInterface,
// With GET_OP_LIST defined, the generated file is expected to expand to a
// comma-separated list of op classes that completes this template argument
// list.
#define GET_OP_LIST
#include "TestTransformDialectExtension.cpp.inc"
                         >();
  }
};
319 } // namespace
320
321 #define GET_OP_CLASSES
322 #include "TestTransformDialectExtension.cpp.inc"
323
registerTestTransformDialectExtension(DialectRegistry & registry)324 void ::test::registerTestTransformDialectExtension(DialectRegistry ®istry) {
325 registry.addExtensions<TestTransformDialectExtension>();
326 }
327