1 //===- TestTransformDialectExtension.cpp ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines an extension of the MLIR Transform dialect for testing 10 // purposes. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "TestTransformDialectExtension.h" 15 #include "TestTransformStateExtension.h" 16 #include "mlir/Dialect/PDL/IR/PDL.h" 17 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 18 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 19 #include "mlir/IR/OpImplementation.h" 20 21 using namespace mlir; 22 23 namespace { 24 /// Simple transform op defined outside of the dialect. Just emits a remark when 25 /// applied. This op is defined in C++ to test that C++ definitions also work 26 /// for op injection into the Transform dialect. 27 class TestTransformOp 28 : public Op<TestTransformOp, transform::TransformOpInterface::Trait, 29 MemoryEffectOpInterface::Trait> { 30 public: 31 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp) 32 33 using Op::Op; 34 35 static ArrayRef<StringRef> getAttributeNames() { return {}; } 36 37 static constexpr llvm::StringLiteral getOperationName() { 38 return llvm::StringLiteral("transform.test_transform_op"); 39 } 40 41 DiagnosedSilenceableFailure apply(transform::TransformResults &results, 42 transform::TransformState &state) { 43 InFlightDiagnostic remark = emitRemark() << "applying transformation"; 44 if (Attribute message = getMessage()) 45 remark << " " << message; 46 47 return DiagnosedSilenceableFailure::success(); 48 } 49 50 Attribute getMessage() { return getOperation()->getAttr("message"); } 51 52 static ParseResult parse(OpAsmParser &parser, OperationState &state) { 53 StringAttr message; 54 OptionalParseResult result = parser.parseOptionalAttribute(message); 55 if (!result.hasValue()) 56 return success(); 57 58 if (result.getValue().succeeded()) 59 state.addAttribute("message", message); 60 return result.getValue(); 61 } 62 63 void print(OpAsmPrinter &printer) { 64 if (getMessage()) 65 printer << " " << getMessage(); 66 } 67 68 // No side effects. 69 void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} 70 }; 71 72 /// A test op to exercise the verifier of the PossibleTopLevelTransformOpTrait 73 /// in cases where it is attached to ops that do not comply with the trait 74 /// requirements. This op cannot be defined in ODS because ODS generates strict 75 /// verifiers that overalp with those in the trait and run earlier. 76 class TestTransformUnrestrictedOpNoInterface 77 : public Op<TestTransformUnrestrictedOpNoInterface, 78 transform::PossibleTopLevelTransformOpTrait, 79 transform::TransformOpInterface::Trait, 80 MemoryEffectOpInterface::Trait> { 81 public: 82 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 83 TestTransformUnrestrictedOpNoInterface) 84 85 using Op::Op; 86 87 static ArrayRef<StringRef> getAttributeNames() { return {}; } 88 89 static constexpr llvm::StringLiteral getOperationName() { 90 return llvm::StringLiteral( 91 "transform.test_transform_unrestricted_op_no_interface"); 92 } 93 94 DiagnosedSilenceableFailure apply(transform::TransformResults &results, 95 transform::TransformState &state) { 96 return DiagnosedSilenceableFailure::success(); 97 } 98 99 // No side effects. 100 void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} 101 }; 102 } // namespace 103 104 DiagnosedSilenceableFailure 105 mlir::test::TestProduceParamOrForwardOperandOp::apply( 106 transform::TransformResults &results, transform::TransformState &state) { 107 if (getOperation()->getNumOperands() != 0) { 108 results.set(getResult().cast<OpResult>(), 109 getOperation()->getOperand(0).getDefiningOp()); 110 } else { 111 results.set(getResult().cast<OpResult>(), 112 reinterpret_cast<Operation *>(*getParameter())); 113 } 114 return DiagnosedSilenceableFailure::success(); 115 } 116 117 LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() { 118 if (getParameter().has_value() ^ (getNumOperands() != 1)) 119 return emitOpError() << "expects either a parameter or an operand"; 120 return success(); 121 } 122 123 DiagnosedSilenceableFailure 124 mlir::test::TestConsumeOperand::apply(transform::TransformResults &results, 125 transform::TransformState &state) { 126 return DiagnosedSilenceableFailure::success(); 127 } 128 129 DiagnosedSilenceableFailure 130 mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply( 131 transform::TransformResults &results, transform::TransformState &state) { 132 ArrayRef<Operation *> payload = state.getPayloadOps(getOperand()); 133 assert(payload.size() == 1 && "expected a single target op"); 134 auto value = reinterpret_cast<intptr_t>(payload[0]); 135 if (static_cast<uint64_t>(value) != getParameter()) { 136 return emitSilenceableError() 137 << "op expected the operand to be associated with " << getParameter() 138 << " got " << value; 139 } 140 141 emitRemark() << "succeeded"; 142 return DiagnosedSilenceableFailure::success(); 143 } 144 145 DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandOp::apply( 146 transform::TransformResults &results, transform::TransformState &state) { 147 ArrayRef<Operation *> payload = state.getPayloadOps(getOperand()); 148 for (Operation *op : payload) 149 op->emitRemark() << getMessage(); 150 151 return DiagnosedSilenceableFailure::success(); 152 } 153 154 DiagnosedSilenceableFailure 155 mlir::test::TestAddTestExtensionOp::apply(transform::TransformResults &results, 156 transform::TransformState &state) { 157 state.addExtension<TestTransformStateExtension>(getMessageAttr()); 158 return DiagnosedSilenceableFailure::success(); 159 } 160 161 DiagnosedSilenceableFailure 162 mlir::test::TestCheckIfTestExtensionPresentOp::apply( 163 transform::TransformResults &results, transform::TransformState &state) { 164 auto *extension = state.getExtension<TestTransformStateExtension>(); 165 if (!extension) { 166 emitRemark() << "extension absent"; 167 return DiagnosedSilenceableFailure::success(); 168 } 169 170 InFlightDiagnostic diag = emitRemark() 171 << "extension present, " << extension->getMessage(); 172 for (Operation *payload : state.getPayloadOps(getOperand())) { 173 diag.attachNote(payload->getLoc()) << "associated payload op"; 174 assert(state.getHandleForPayloadOp(payload) == getOperand() && 175 "inconsistent mapping between transform IR handles and payload IR " 176 "operations"); 177 } 178 179 return DiagnosedSilenceableFailure::success(); 180 } 181 182 DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply( 183 transform::TransformResults &results, transform::TransformState &state) { 184 auto *extension = state.getExtension<TestTransformStateExtension>(); 185 if (!extension) { 186 emitError() << "TestTransformStateExtension missing"; 187 return DiagnosedSilenceableFailure::definiteFailure(); 188 } 189 190 if (failed(extension->updateMapping(state.getPayloadOps(getOperand()).front(), 191 getOperation()))) 192 return DiagnosedSilenceableFailure::definiteFailure(); 193 return DiagnosedSilenceableFailure::success(); 194 } 195 196 DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply( 197 transform::TransformResults &results, transform::TransformState &state) { 198 state.removeExtension<TestTransformStateExtension>(); 199 return DiagnosedSilenceableFailure::success(); 200 } 201 202 DiagnosedSilenceableFailure 203 mlir::test::TestReversePayloadOpsOp::apply(transform::TransformResults &results, 204 transform::TransformState &state) { 205 ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget()); 206 auto reversedOps = llvm::to_vector(llvm::reverse(payloadOps)); 207 results.set(getResult().cast<OpResult>(), reversedOps); 208 return DiagnosedSilenceableFailure::success(); 209 } 210 211 DiagnosedSilenceableFailure mlir::test::TestTransformOpWithRegions::apply( 212 transform::TransformResults &results, transform::TransformState &state) { 213 return DiagnosedSilenceableFailure::success(); 214 } 215 216 void mlir::test::TestTransformOpWithRegions::getEffects( 217 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} 218 219 DiagnosedSilenceableFailure 220 mlir::test::TestBranchingTransformOpTerminator::apply( 221 transform::TransformResults &results, transform::TransformState &state) { 222 return DiagnosedSilenceableFailure::success(); 223 } 224 225 void mlir::test::TestBranchingTransformOpTerminator::getEffects( 226 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} 227 228 DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply( 229 transform::TransformResults &results, transform::TransformState &state) { 230 emitRemark() << getRemark(); 231 for (Operation *op : state.getPayloadOps(getTarget())) 232 op->erase(); 233 234 if (getFailAfterErase()) 235 return emitSilenceableError() << "silencable error"; 236 return DiagnosedSilenceableFailure::success(); 237 } 238 239 DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfResultsOp::applyToOne( 240 Operation *target, SmallVectorImpl<Operation *> &results, 241 transform::TransformState &state) { 242 OperationState opState(target->getLoc(), "foo"); 243 results.push_back(OpBuilder(target).create(opState)); 244 return DiagnosedSilenceableFailure::success(); 245 } 246 247 DiagnosedSilenceableFailure 248 mlir::test::TestWrongNumberOfMultiResultsOp::applyToOne( 249 Operation *target, SmallVectorImpl<Operation *> &results, 250 transform::TransformState &state) { 251 static int count = 0; 252 if (count++ == 0) { 253 OperationState opState(target->getLoc(), "foo"); 254 results.push_back(OpBuilder(target).create(opState)); 255 } 256 return DiagnosedSilenceableFailure::success(); 257 } 258 259 DiagnosedSilenceableFailure 260 mlir::test::TestCorrectNumberOfMultiResultsOp::applyToOne( 261 Operation *target, SmallVectorImpl<Operation *> &results, 262 transform::TransformState &state) { 263 OperationState opState(target->getLoc(), "foo"); 264 results.push_back(OpBuilder(target).create(opState)); 265 results.push_back(OpBuilder(target).create(opState)); 266 return DiagnosedSilenceableFailure::success(); 267 } 268 269 DiagnosedSilenceableFailure 270 mlir::test::TestMixedNullAndNonNullResultsOp::applyToOne( 271 Operation *target, SmallVectorImpl<Operation *> &results, 272 transform::TransformState &state) { 273 OperationState opState(target->getLoc(), "foo"); 274 results.push_back(nullptr); 275 results.push_back(OpBuilder(target).create(opState)); 276 return DiagnosedSilenceableFailure::success(); 277 } 278 279 DiagnosedSilenceableFailure 280 mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne( 281 Operation *target, SmallVectorImpl<Operation *> &results, 282 transform::TransformState &state) { 283 if (target->hasAttr("target_me")) 284 return DiagnosedSilenceableFailure::success(); 285 return emitDefaultSilenceableFailure(target); 286 } 287 288 DiagnosedSilenceableFailure 289 mlir::test::TestPrintNumberOfAssociatedPayloadIROps::apply( 290 transform::TransformResults &results, transform::TransformState &state) { 291 emitRemark() << state.getPayloadOps(getHandle()).size(); 292 return DiagnosedSilenceableFailure::success(); 293 } 294 295 void mlir::test::TestPrintNumberOfAssociatedPayloadIROps::getEffects( 296 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 297 transform::onlyReadsHandle(getHandle(), effects); 298 } 299 300 namespace { 301 /// Test extension of the Transform dialect. Registers additional ops and 302 /// declares PDL as dependent dialect since the additional ops are using PDL 303 /// types for operands and results. 304 class TestTransformDialectExtension 305 : public transform::TransformDialectExtension< 306 TestTransformDialectExtension> { 307 public: 308 using Base::Base; 309 310 void init() { 311 declareDependentDialect<pdl::PDLDialect>(); 312 registerTransformOps<TestTransformOp, 313 TestTransformUnrestrictedOpNoInterface, 314 #define GET_OP_LIST 315 #include "TestTransformDialectExtension.cpp.inc" 316 >(); 317 } 318 }; 319 } // namespace 320 321 #define GET_OP_CLASSES 322 #include "TestTransformDialectExtension.cpp.inc" 323 324 void ::test::registerTestTransformDialectExtension(DialectRegistry ®istry) { 325 registry.addExtensions<TestTransformDialectExtension>(); 326 } 327