1 //===- TestTransformDialectExtension.cpp ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines an extension of the MLIR Transform dialect for testing 10 // purposes. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "TestTransformDialectExtension.h" 15 #include "TestTransformStateExtension.h" 16 #include "mlir/Dialect/PDL/IR/PDL.h" 17 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 18 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 19 #include "mlir/IR/OpImplementation.h" 20 21 using namespace mlir; 22 23 namespace { 24 /// Simple transform op defined outside of the dialect. Just emits a remark when 25 /// applied. This op is defined in C++ to test that C++ definitions also work 26 /// for op injection into the Transform dialect. 27 class TestTransformOp 28 : public Op<TestTransformOp, transform::TransformOpInterface::Trait, 29 MemoryEffectOpInterface::Trait> { 30 public: 31 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp) 32 33 using Op::Op; 34 35 static ArrayRef<StringRef> getAttributeNames() { return {}; } 36 37 static constexpr llvm::StringLiteral getOperationName() { 38 return llvm::StringLiteral("transform.test_transform_op"); 39 } 40 41 LogicalResult apply(transform::TransformResults &results, 42 transform::TransformState &state) { 43 InFlightDiagnostic remark = emitRemark() << "applying transformation"; 44 if (Attribute message = getMessage()) 45 remark << " " << message; 46 47 return success(); 48 } 49 50 Attribute getMessage() { return getOperation()->getAttr("message"); } 51 52 static ParseResult parse(OpAsmParser &parser, OperationState &state) { 53 StringAttr message; 54 OptionalParseResult result = parser.parseOptionalAttribute(message); 55 if (!result.hasValue()) 56 return success(); 57 58 if (result.getValue().succeeded()) 59 state.addAttribute("message", message); 60 return result.getValue(); 61 } 62 63 void print(OpAsmPrinter &printer) { 64 if (getMessage()) 65 printer << " " << getMessage(); 66 } 67 68 // No side effects. 69 void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} 70 }; 71 72 /// A test op to exercise the verifier of the PossibleTopLevelTransformOpTrait 73 /// in cases where it is attached to ops that do not comply with the trait 74 /// requirements. This op cannot be defined in ODS because ODS generates strict 75 /// verifiers that overalp with those in the trait and run earlier. 76 class TestTransformUnrestrictedOpNoInterface 77 : public Op<TestTransformUnrestrictedOpNoInterface, 78 transform::PossibleTopLevelTransformOpTrait, 79 transform::TransformOpInterface::Trait, 80 MemoryEffectOpInterface::Trait> { 81 public: 82 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 83 TestTransformUnrestrictedOpNoInterface) 84 85 using Op::Op; 86 87 static ArrayRef<StringRef> getAttributeNames() { return {}; } 88 89 static constexpr llvm::StringLiteral getOperationName() { 90 return llvm::StringLiteral( 91 "transform.test_transform_unrestricted_op_no_interface"); 92 } 93 94 LogicalResult apply(transform::TransformResults &results, 95 transform::TransformState &state) { 96 return success(); 97 } 98 99 // No side effects. 100 void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} 101 }; 102 } // namespace 103 104 LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::apply( 105 transform::TransformResults &results, transform::TransformState &state) { 106 if (getOperation()->getNumOperands() != 0) { 107 results.set(getResult().cast<OpResult>(), 108 getOperation()->getOperand(0).getDefiningOp()); 109 } else { 110 results.set(getResult().cast<OpResult>(), 111 reinterpret_cast<Operation *>(*getParameter())); 112 } 113 return success(); 114 } 115 116 LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() { 117 if (getParameter().hasValue() ^ (getNumOperands() != 1)) 118 return emitOpError() << "expects either a parameter or an operand"; 119 return success(); 120 } 121 122 LogicalResult 123 mlir::test::TestConsumeOperand::apply(transform::TransformResults &results, 124 transform::TransformState &state) { 125 return success(); 126 } 127 128 LogicalResult mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply( 129 transform::TransformResults &results, transform::TransformState &state) { 130 ArrayRef<Operation *> payload = state.getPayloadOps(getOperand()); 131 assert(payload.size() == 1 && "expected a single target op"); 132 auto value = reinterpret_cast<intptr_t>(payload[0]); 133 if (static_cast<uint64_t>(value) != getParameter()) { 134 return emitOpError() << "expected the operand to be associated with " 135 << getParameter() << " got " << value; 136 } 137 138 emitRemark() << "succeeded"; 139 return success(); 140 } 141 142 LogicalResult mlir::test::TestPrintRemarkAtOperandOp::apply( 143 transform::TransformResults &results, transform::TransformState &state) { 144 ArrayRef<Operation *> payload = state.getPayloadOps(getOperand()); 145 for (Operation *op : payload) 146 op->emitRemark() << getMessage(); 147 148 return success(); 149 } 150 151 LogicalResult 152 mlir::test::TestAddTestExtensionOp::apply(transform::TransformResults &results, 153 transform::TransformState &state) { 154 state.addExtension<TestTransformStateExtension>(getMessageAttr()); 155 return success(); 156 } 157 158 LogicalResult mlir::test::TestCheckIfTestExtensionPresentOp::apply( 159 transform::TransformResults &results, transform::TransformState &state) { 160 auto *extension = state.getExtension<TestTransformStateExtension>(); 161 if (!extension) { 162 emitRemark() << "extension absent"; 163 return success(); 164 } 165 166 InFlightDiagnostic diag = emitRemark() 167 << "extension present, " << extension->getMessage(); 168 for (Operation *payload : state.getPayloadOps(getOperand())) { 169 diag.attachNote(payload->getLoc()) << "associated payload op"; 170 assert(state.getHandleForPayloadOp(payload) == getOperand() && 171 "inconsistent mapping between transform IR handles and payload IR " 172 "operations"); 173 } 174 175 return success(); 176 } 177 178 LogicalResult mlir::test::TestRemapOperandPayloadToSelfOp::apply( 179 transform::TransformResults &results, transform::TransformState &state) { 180 auto *extension = state.getExtension<TestTransformStateExtension>(); 181 if (!extension) 182 return emitError() << "TestTransformStateExtension missing"; 183 184 return extension->updateMapping(state.getPayloadOps(getOperand()).front(), 185 getOperation()); 186 } 187 188 LogicalResult mlir::test::TestRemoveTestExtensionOp::apply( 189 transform::TransformResults &results, transform::TransformState &state) { 190 state.removeExtension<TestTransformStateExtension>(); 191 return success(); 192 } 193 LogicalResult mlir::test::TestTransformOpWithRegions::apply( 194 transform::TransformResults &results, transform::TransformState &state) { 195 return success(); 196 } 197 198 void mlir::test::TestTransformOpWithRegions::getEffects( 199 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} 200 201 LogicalResult mlir::test::TestBranchingTransformOpTerminator::apply( 202 transform::TransformResults &results, transform::TransformState &state) { 203 return success(); 204 } 205 206 void mlir::test::TestBranchingTransformOpTerminator::getEffects( 207 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} 208 209 namespace { 210 /// Test extension of the Transform dialect. Registers additional ops and 211 /// declares PDL as dependent dialect since the additional ops are using PDL 212 /// types for operands and results. 213 class TestTransformDialectExtension 214 : public transform::TransformDialectExtension< 215 TestTransformDialectExtension> { 216 public: 217 TestTransformDialectExtension() { 218 declareDependentDialect<pdl::PDLDialect>(); 219 registerTransformOps<TestTransformOp, 220 TestTransformUnrestrictedOpNoInterface, 221 #define GET_OP_LIST 222 #include "TestTransformDialectExtension.cpp.inc" 223 >(); 224 } 225 }; 226 } // namespace 227 228 #define GET_OP_CLASSES 229 #include "TestTransformDialectExtension.cpp.inc" 230 231 void ::test::registerTestTransformDialectExtension(DialectRegistry ®istry) { 232 registry.addExtensions<TestTransformDialectExtension>(); 233 } 234