1 //===- TestTransformDialectExtension.cpp ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines an extension of the MLIR Transform dialect for testing 10 // purposes. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "TestTransformDialectExtension.h" 15 #include "TestTransformStateExtension.h" 16 #include "mlir/Dialect/PDL/IR/PDL.h" 17 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 18 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 19 #include "mlir/IR/OpImplementation.h" 20 21 using namespace mlir; 22 23 namespace { 24 /// Simple transform op defined outside of the dialect. Just emits a remark when 25 /// applied. This op is defined in C++ to test that C++ definitions also work 26 /// for op injection into the Transform dialect. 27 class TestTransformOp 28 : public Op<TestTransformOp, transform::TransformOpInterface::Trait, 29 MemoryEffectOpInterface::Trait> { 30 public: 31 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp) 32 33 using Op::Op; 34 35 static ArrayRef<StringRef> getAttributeNames() { return {}; } 36 37 static constexpr llvm::StringLiteral getOperationName() { 38 return llvm::StringLiteral("transform.test_transform_op"); 39 } 40 41 DiagnosedSilencableFailure apply(transform::TransformResults &results, 42 transform::TransformState &state) { 43 InFlightDiagnostic remark = emitRemark() << "applying transformation"; 44 if (Attribute message = getMessage()) 45 remark << " " << message; 46 47 return DiagnosedSilencableFailure::success(); 48 } 49 50 Attribute getMessage() { return getOperation()->getAttr("message"); } 51 52 static ParseResult parse(OpAsmParser &parser, OperationState &state) { 53 StringAttr message; 54 OptionalParseResult result = parser.parseOptionalAttribute(message); 55 if (!result.hasValue()) 56 return success(); 57 58 if (result.getValue().succeeded()) 59 state.addAttribute("message", message); 60 return result.getValue(); 61 } 62 63 void print(OpAsmPrinter &printer) { 64 if (getMessage()) 65 printer << " " << getMessage(); 66 } 67 68 // No side effects. 69 void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} 70 }; 71 72 /// A test op to exercise the verifier of the PossibleTopLevelTransformOpTrait 73 /// in cases where it is attached to ops that do not comply with the trait 74 /// requirements. This op cannot be defined in ODS because ODS generates strict 75 /// verifiers that overalp with those in the trait and run earlier. 76 class TestTransformUnrestrictedOpNoInterface 77 : public Op<TestTransformUnrestrictedOpNoInterface, 78 transform::PossibleTopLevelTransformOpTrait, 79 transform::TransformOpInterface::Trait, 80 MemoryEffectOpInterface::Trait> { 81 public: 82 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 83 TestTransformUnrestrictedOpNoInterface) 84 85 using Op::Op; 86 87 static ArrayRef<StringRef> getAttributeNames() { return {}; } 88 89 static constexpr llvm::StringLiteral getOperationName() { 90 return llvm::StringLiteral( 91 "transform.test_transform_unrestricted_op_no_interface"); 92 } 93 94 DiagnosedSilencableFailure apply(transform::TransformResults &results, 95 transform::TransformState &state) { 96 return DiagnosedSilencableFailure::success(); 97 } 98 99 // No side effects. 100 void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} 101 }; 102 } // namespace 103 104 DiagnosedSilencableFailure 105 mlir::test::TestProduceParamOrForwardOperandOp::apply( 106 transform::TransformResults &results, transform::TransformState &state) { 107 if (getOperation()->getNumOperands() != 0) { 108 results.set(getResult().cast<OpResult>(), 109 getOperation()->getOperand(0).getDefiningOp()); 110 } else { 111 results.set(getResult().cast<OpResult>(), 112 reinterpret_cast<Operation *>(*getParameter())); 113 } 114 return DiagnosedSilencableFailure::success(); 115 } 116 117 LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() { 118 if (getParameter().hasValue() ^ (getNumOperands() != 1)) 119 return emitOpError() << "expects either a parameter or an operand"; 120 return success(); 121 } 122 123 DiagnosedSilencableFailure 124 mlir::test::TestConsumeOperand::apply(transform::TransformResults &results, 125 transform::TransformState &state) { 126 return DiagnosedSilencableFailure::success(); 127 } 128 129 DiagnosedSilencableFailure 130 mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply( 131 transform::TransformResults &results, transform::TransformState &state) { 132 ArrayRef<Operation *> payload = state.getPayloadOps(getOperand()); 133 assert(payload.size() == 1 && "expected a single target op"); 134 auto value = reinterpret_cast<intptr_t>(payload[0]); 135 if (static_cast<uint64_t>(value) != getParameter()) { 136 return emitSilencableError() 137 << "op expected the operand to be associated with " << getParameter() 138 << " got " << value; 139 } 140 141 emitRemark() << "succeeded"; 142 return DiagnosedSilencableFailure::success(); 143 } 144 145 DiagnosedSilencableFailure mlir::test::TestPrintRemarkAtOperandOp::apply( 146 transform::TransformResults &results, transform::TransformState &state) { 147 ArrayRef<Operation *> payload = state.getPayloadOps(getOperand()); 148 for (Operation *op : payload) 149 op->emitRemark() << getMessage(); 150 151 return DiagnosedSilencableFailure::success(); 152 } 153 154 DiagnosedSilencableFailure 155 mlir::test::TestAddTestExtensionOp::apply(transform::TransformResults &results, 156 transform::TransformState &state) { 157 state.addExtension<TestTransformStateExtension>(getMessageAttr()); 158 return DiagnosedSilencableFailure::success(); 159 } 160 161 DiagnosedSilencableFailure mlir::test::TestCheckIfTestExtensionPresentOp::apply( 162 transform::TransformResults &results, transform::TransformState &state) { 163 auto *extension = state.getExtension<TestTransformStateExtension>(); 164 if (!extension) { 165 emitRemark() << "extension absent"; 166 return DiagnosedSilencableFailure::success(); 167 } 168 169 InFlightDiagnostic diag = emitRemark() 170 << "extension present, " << extension->getMessage(); 171 for (Operation *payload : state.getPayloadOps(getOperand())) { 172 diag.attachNote(payload->getLoc()) << "associated payload op"; 173 assert(state.getHandleForPayloadOp(payload) == getOperand() && 174 "inconsistent mapping between transform IR handles and payload IR " 175 "operations"); 176 } 177 178 return DiagnosedSilencableFailure::success(); 179 } 180 181 DiagnosedSilencableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply( 182 transform::TransformResults &results, transform::TransformState &state) { 183 auto *extension = state.getExtension<TestTransformStateExtension>(); 184 if (!extension) { 185 emitError() << "TestTransformStateExtension missing"; 186 return DiagnosedSilencableFailure::definiteFailure(); 187 } 188 189 if (failed(extension->updateMapping(state.getPayloadOps(getOperand()).front(), 190 getOperation()))) 191 return DiagnosedSilencableFailure::definiteFailure(); 192 return DiagnosedSilencableFailure::success(); 193 } 194 195 DiagnosedSilencableFailure mlir::test::TestRemoveTestExtensionOp::apply( 196 transform::TransformResults &results, transform::TransformState &state) { 197 state.removeExtension<TestTransformStateExtension>(); 198 return DiagnosedSilencableFailure::success(); 199 } 200 DiagnosedSilencableFailure mlir::test::TestTransformOpWithRegions::apply( 201 transform::TransformResults &results, transform::TransformState &state) { 202 return DiagnosedSilencableFailure::success(); 203 } 204 205 void mlir::test::TestTransformOpWithRegions::getEffects( 206 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} 207 208 DiagnosedSilencableFailure 209 mlir::test::TestBranchingTransformOpTerminator::apply( 210 transform::TransformResults &results, transform::TransformState &state) { 211 return DiagnosedSilencableFailure::success(); 212 } 213 214 void mlir::test::TestBranchingTransformOpTerminator::getEffects( 215 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} 216 217 DiagnosedSilencableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply( 218 transform::TransformResults &results, transform::TransformState &state) { 219 emitRemark() << getRemark(); 220 for (Operation *op : state.getPayloadOps(getTarget())) 221 op->erase(); 222 223 if (getFailAfterErase()) 224 return emitSilencableError() << "silencable error"; 225 return DiagnosedSilencableFailure::success(); 226 } 227 228 namespace { 229 /// Test extension of the Transform dialect. Registers additional ops and 230 /// declares PDL as dependent dialect since the additional ops are using PDL 231 /// types for operands and results. 232 class TestTransformDialectExtension 233 : public transform::TransformDialectExtension< 234 TestTransformDialectExtension> { 235 public: 236 TestTransformDialectExtension() { 237 declareDependentDialect<pdl::PDLDialect>(); 238 registerTransformOps<TestTransformOp, 239 TestTransformUnrestrictedOpNoInterface, 240 #define GET_OP_LIST 241 #include "TestTransformDialectExtension.cpp.inc" 242 >(); 243 } 244 }; 245 } // namespace 246 247 #define GET_OP_CLASSES 248 #include "TestTransformDialectExtension.cpp.inc" 249 250 void ::test::registerTestTransformDialectExtension(DialectRegistry ®istry) { 251 registry.addExtensions<TestTransformDialectExtension>(); 252 } 253