1 //===- TestTransformDialectExtension.cpp ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines an extension of the MLIR Transform dialect for testing 10 // purposes. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "TestTransformDialectExtension.h" 15 #include "mlir/Dialect/PDL/IR/PDL.h" 16 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 17 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/OpImplementation.h" 20 21 using namespace mlir; 22 23 namespace { 24 /// Simple transform op defined outside of the dialect. Just emits a remark when 25 /// applied. This op is defined in C++ to test that C++ definitions also work 26 /// for op injection into the Transform dialect. 27 class TestTransformOp 28 : public Op<TestTransformOp, transform::TransformOpInterface::Trait, 29 MemoryEffectOpInterface::Trait> { 30 public: 31 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp) 32 33 using Op::Op; 34 35 static ArrayRef<StringRef> getAttributeNames() { return {}; } 36 37 static constexpr llvm::StringLiteral getOperationName() { 38 return llvm::StringLiteral("transform.test_transform_op"); 39 } 40 41 LogicalResult apply(transform::TransformResults &results, 42 transform::TransformState &state) { 43 InFlightDiagnostic remark = emitRemark() << "applying transformation"; 44 if (Attribute message = getMessage()) 45 remark << " " << message; 46 47 return success(); 48 } 49 50 Attribute getMessage() { return getOperation()->getAttr("message"); } 51 52 static ParseResult parse(OpAsmParser &parser, OperationState &state) { 53 StringAttr message; 54 OptionalParseResult result = parser.parseOptionalAttribute(message); 55 if (!result.hasValue()) 56 return success(); 57 58 if (result.getValue().succeeded()) 59 state.addAttribute("message", message); 60 return result.getValue(); 61 } 62 63 void print(OpAsmPrinter &printer) { 64 if (getMessage()) 65 printer << " " << getMessage(); 66 } 67 68 // No side effects. 69 void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} 70 }; 71 72 /// A test op to exercise the verifier of the PossibleTopLevelTransformOpTrait 73 /// in cases where it is attached to ops that do not comply with the trait 74 /// requirements. This op cannot be defined in ODS because ODS generates strict 75 /// verifiers that overalp with those in the trait and run earlier. 76 class TestTransformUnrestrictedOpNoInterface 77 : public Op<TestTransformUnrestrictedOpNoInterface, 78 transform::PossibleTopLevelTransformOpTrait, 79 transform::TransformOpInterface::Trait, 80 MemoryEffectOpInterface::Trait> { 81 public: 82 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 83 TestTransformUnrestrictedOpNoInterface) 84 85 using Op::Op; 86 87 static ArrayRef<StringRef> getAttributeNames() { return {}; } 88 89 static constexpr llvm::StringLiteral getOperationName() { 90 return llvm::StringLiteral( 91 "transform.test_transform_unrestricted_op_no_interface"); 92 } 93 94 LogicalResult apply(transform::TransformResults &results, 95 transform::TransformState &state) { 96 return success(); 97 } 98 99 // No side effects. 100 void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} 101 }; 102 } // namespace 103 104 LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::apply( 105 transform::TransformResults &results, transform::TransformState &state) { 106 if (getOperation()->getNumOperands() != 0) { 107 results.set(getResult().cast<OpResult>(), 108 getOperation()->getOperand(0).getDefiningOp()); 109 } else { 110 results.set(getResult().cast<OpResult>(), 111 reinterpret_cast<Operation *>(*getParameter())); 112 } 113 return success(); 114 } 115 116 LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() { 117 if (getParameter().hasValue() ^ (getNumOperands() != 1)) 118 return emitOpError() << "expects either a parameter or an operand"; 119 return success(); 120 } 121 122 LogicalResult mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply( 123 transform::TransformResults &results, transform::TransformState &state) { 124 ArrayRef<Operation *> payload = state.getPayloadOps(getOperand()); 125 assert(payload.size() == 1 && "expected a single target op"); 126 auto value = reinterpret_cast<intptr_t>(payload[0]); 127 if (static_cast<uint64_t>(value) != getParameter()) { 128 return emitOpError() << "expected the operand to be associated with " 129 << getParameter() << " got " << value; 130 } 131 132 emitRemark() << "succeeded"; 133 return success(); 134 } 135 136 LogicalResult mlir::test::TestPrintRemarkAtOperandOp::apply( 137 transform::TransformResults &results, transform::TransformState &state) { 138 ArrayRef<Operation *> payload = state.getPayloadOps(getOperand()); 139 for (Operation *op : payload) 140 op->emitRemark() << getMessage(); 141 142 return success(); 143 } 144 145 namespace { 146 /// Test extension of the Transform dialect. Registers additional ops and 147 /// declares PDL as dependent dialect since the additional ops are using PDL 148 /// types for operands and results. 149 class TestTransformDialectExtension 150 : public transform::TransformDialectExtension< 151 TestTransformDialectExtension> { 152 public: 153 TestTransformDialectExtension() { 154 declareDependentDialect<pdl::PDLDialect>(); 155 registerTransformOps<TestTransformOp, 156 TestTransformUnrestrictedOpNoInterface, 157 #define GET_OP_LIST 158 #include "TestTransformDialectExtension.cpp.inc" 159 >(); 160 } 161 }; 162 } // namespace 163 164 #define GET_OP_CLASSES 165 #include "TestTransformDialectExtension.cpp.inc" 166 167 void ::test::registerTestTransformDialectExtension(DialectRegistry ®istry) { 168 registry.addExtensions<TestTransformDialectExtension>(); 169 } 170