1 //===- TestTransformDialectExtension.cpp ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines an extension of the MLIR Transform dialect for testing 10 // purposes. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "TestTransformDialectExtension.h" 15 #include "TestTransformStateExtension.h" 16 #include "mlir/Dialect/PDL/IR/PDL.h" 17 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 18 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 19 #include "mlir/IR/OpImplementation.h" 20 21 using namespace mlir; 22 23 namespace { 24 /// Simple transform op defined outside of the dialect. Just emits a remark when 25 /// applied. This op is defined in C++ to test that C++ definitions also work 26 /// for op injection into the Transform dialect. 27 class TestTransformOp 28 : public Op<TestTransformOp, transform::TransformOpInterface::Trait, 29 MemoryEffectOpInterface::Trait> { 30 public: 31 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp) 32 33 using Op::Op; 34 35 static ArrayRef<StringRef> getAttributeNames() { return {}; } 36 37 static constexpr llvm::StringLiteral getOperationName() { 38 return llvm::StringLiteral("transform.test_transform_op"); 39 } 40 41 LogicalResult apply(transform::TransformResults &results, 42 transform::TransformState &state) { 43 InFlightDiagnostic remark = emitRemark() << "applying transformation"; 44 if (Attribute message = getMessage()) 45 remark << " " << message; 46 47 return success(); 48 } 49 50 Attribute getMessage() { return getOperation()->getAttr("message"); } 51 52 static ParseResult parse(OpAsmParser &parser, OperationState &state) { 53 StringAttr message; 54 OptionalParseResult result = parser.parseOptionalAttribute(message); 55 if (!result.hasValue()) 56 return success(); 57 58 if (result.getValue().succeeded()) 59 state.addAttribute("message", message); 60 return result.getValue(); 61 } 62 63 void print(OpAsmPrinter &printer) { 64 if (getMessage()) 65 printer << " " << getMessage(); 66 } 67 68 // No side effects. 69 void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} 70 }; 71 72 /// A test op to exercise the verifier of the PossibleTopLevelTransformOpTrait 73 /// in cases where it is attached to ops that do not comply with the trait 74 /// requirements. This op cannot be defined in ODS because ODS generates strict 75 /// verifiers that overalp with those in the trait and run earlier. 76 class TestTransformUnrestrictedOpNoInterface 77 : public Op<TestTransformUnrestrictedOpNoInterface, 78 transform::PossibleTopLevelTransformOpTrait, 79 transform::TransformOpInterface::Trait, 80 MemoryEffectOpInterface::Trait> { 81 public: 82 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 83 TestTransformUnrestrictedOpNoInterface) 84 85 using Op::Op; 86 87 static ArrayRef<StringRef> getAttributeNames() { return {}; } 88 89 static constexpr llvm::StringLiteral getOperationName() { 90 return llvm::StringLiteral( 91 "transform.test_transform_unrestricted_op_no_interface"); 92 } 93 94 LogicalResult apply(transform::TransformResults &results, 95 transform::TransformState &state) { 96 return success(); 97 } 98 99 // No side effects. 100 void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} 101 }; 102 } // namespace 103 104 LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::apply( 105 transform::TransformResults &results, transform::TransformState &state) { 106 if (getOperation()->getNumOperands() != 0) { 107 results.set(getResult().cast<OpResult>(), 108 getOperation()->getOperand(0).getDefiningOp()); 109 } else { 110 results.set(getResult().cast<OpResult>(), 111 reinterpret_cast<Operation *>(*getParameter())); 112 } 113 return success(); 114 } 115 116 LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() { 117 if (getParameter().hasValue() ^ (getNumOperands() != 1)) 118 return emitOpError() << "expects either a parameter or an operand"; 119 return success(); 120 } 121 122 LogicalResult mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply( 123 transform::TransformResults &results, transform::TransformState &state) { 124 ArrayRef<Operation *> payload = state.getPayloadOps(getOperand()); 125 assert(payload.size() == 1 && "expected a single target op"); 126 auto value = reinterpret_cast<intptr_t>(payload[0]); 127 if (static_cast<uint64_t>(value) != getParameter()) { 128 return emitOpError() << "expected the operand to be associated with " 129 << getParameter() << " got " << value; 130 } 131 132 emitRemark() << "succeeded"; 133 return success(); 134 } 135 136 LogicalResult mlir::test::TestPrintRemarkAtOperandOp::apply( 137 transform::TransformResults &results, transform::TransformState &state) { 138 ArrayRef<Operation *> payload = state.getPayloadOps(getOperand()); 139 for (Operation *op : payload) 140 op->emitRemark() << getMessage(); 141 142 return success(); 143 } 144 145 LogicalResult 146 mlir::test::TestAddTestExtensionOp::apply(transform::TransformResults &results, 147 transform::TransformState &state) { 148 state.addExtension<TestTransformStateExtension>(getMessageAttr()); 149 return success(); 150 } 151 152 LogicalResult mlir::test::TestCheckIfTestExtensionPresentOp::apply( 153 transform::TransformResults &results, transform::TransformState &state) { 154 auto *extension = state.getExtension<TestTransformStateExtension>(); 155 if (!extension) { 156 emitRemark() << "extension absent"; 157 return success(); 158 } 159 160 InFlightDiagnostic diag = emitRemark() 161 << "extension present, " << extension->getMessage(); 162 for (Operation *payload : state.getPayloadOps(getOperand())) { 163 diag.attachNote(payload->getLoc()) << "associated payload op"; 164 assert(state.getHandleForPayloadOp(payload) == getOperand() && 165 "inconsistent mapping between transform IR handles and payload IR " 166 "operations"); 167 } 168 169 return success(); 170 } 171 172 LogicalResult mlir::test::TestRemapOperandPayloadToSelfOp::apply( 173 transform::TransformResults &results, transform::TransformState &state) { 174 auto *extension = state.getExtension<TestTransformStateExtension>(); 175 if (!extension) 176 return emitError() << "TestTransformStateExtension missing"; 177 178 return extension->updateMapping(state.getPayloadOps(getOperand()).front(), 179 getOperation()); 180 } 181 182 LogicalResult mlir::test::TestRemoveTestExtensionOp::apply( 183 transform::TransformResults &results, transform::TransformState &state) { 184 state.removeExtension<TestTransformStateExtension>(); 185 return success(); 186 } 187 LogicalResult mlir::test::TestTransformOpWithRegions::apply( 188 transform::TransformResults &results, transform::TransformState &state) { 189 return success(); 190 } 191 192 void mlir::test::TestTransformOpWithRegions::getEffects( 193 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} 194 195 LogicalResult mlir::test::TestBranchingTransformOpTerminator::apply( 196 transform::TransformResults &results, transform::TransformState &state) { 197 return success(); 198 } 199 200 void mlir::test::TestBranchingTransformOpTerminator::getEffects( 201 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} 202 203 namespace { 204 /// Test extension of the Transform dialect. Registers additional ops and 205 /// declares PDL as dependent dialect since the additional ops are using PDL 206 /// types for operands and results. 207 class TestTransformDialectExtension 208 : public transform::TransformDialectExtension< 209 TestTransformDialectExtension> { 210 public: 211 TestTransformDialectExtension() { 212 declareDependentDialect<pdl::PDLDialect>(); 213 registerTransformOps<TestTransformOp, 214 TestTransformUnrestrictedOpNoInterface, 215 #define GET_OP_LIST 216 #include "TestTransformDialectExtension.cpp.inc" 217 >(); 218 } 219 }; 220 } // namespace 221 222 #define GET_OP_CLASSES 223 #include "TestTransformDialectExtension.cpp.inc" 224 225 void ::test::registerTestTransformDialectExtension(DialectRegistry ®istry) { 226 registry.addExtensions<TestTransformDialectExtension>(); 227 } 228