1 //===- TestTransformDialectExtension.cpp ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines an extension of the MLIR Transform dialect for testing 10 // purposes. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "TestTransformDialectExtension.h" 15 #include "mlir/Dialect/PDL/IR/PDL.h" 16 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 17 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/OpImplementation.h" 20 21 using namespace mlir; 22 23 namespace { 24 /// Simple transform op defined outside of the dialect. Just emits a remark when 25 /// applied. This op is defined in C++ to test that C++ definitions also work 26 /// for op injection into the Transform dialect. 27 class TestTransformOp 28 : public Op<TestTransformOp, transform::TransformOpInterface::Trait> { 29 public: 30 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp) 31 32 using Op::Op; 33 34 static ArrayRef<StringRef> getAttributeNames() { return {}; } 35 36 static constexpr llvm::StringLiteral getOperationName() { 37 return llvm::StringLiteral("transform.test_transform_op"); 38 } 39 40 LogicalResult apply(transform::TransformResults &results, 41 transform::TransformState &state) { 42 InFlightDiagnostic remark = emitRemark() << "applying transformation"; 43 if (Attribute message = getMessage()) 44 remark << " " << message; 45 46 return success(); 47 } 48 49 Attribute getMessage() { return getOperation()->getAttr("message"); } 50 51 static ParseResult parse(OpAsmParser &parser, OperationState &state) { 52 StringAttr message; 53 OptionalParseResult result = parser.parseOptionalAttribute(message); 54 if (!result.hasValue()) 55 return success(); 56 57 if (result.getValue().succeeded()) 58 state.addAttribute("message", message); 59 return result.getValue(); 60 } 61 62 void print(OpAsmPrinter &printer) { 63 if (getMessage()) 64 printer << " " << getMessage(); 65 } 66 }; 67 68 /// A test op to exercise the verifier of the PossibleTopLevelTransformOpTrait 69 /// in cases where it is attached to ops that do not comply with the trait 70 /// requirements. This op cannot be defined in ODS because ODS generates strict 71 /// verifiers that overalp with those in the trait and run earlier. 72 class TestTransformUnrestrictedOpNoInterface 73 : public Op<TestTransformUnrestrictedOpNoInterface, 74 transform::PossibleTopLevelTransformOpTrait, 75 transform::TransformOpInterface::Trait> { 76 public: 77 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 78 TestTransformUnrestrictedOpNoInterface) 79 80 using Op::Op; 81 82 static ArrayRef<StringRef> getAttributeNames() { return {}; } 83 84 static constexpr llvm::StringLiteral getOperationName() { 85 return llvm::StringLiteral( 86 "transform.test_transform_unrestricted_op_no_interface"); 87 } 88 89 LogicalResult apply(transform::TransformResults &results, 90 transform::TransformState &state) { 91 return success(); 92 } 93 }; 94 } // namespace 95 96 LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::apply( 97 transform::TransformResults &results, transform::TransformState &state) { 98 if (getOperation()->getNumOperands() != 0) { 99 results.set(getResult().cast<OpResult>(), 100 getOperation()->getOperand(0).getDefiningOp()); 101 } else { 102 results.set(getResult().cast<OpResult>(), 103 reinterpret_cast<Operation *>(*getParameter())); 104 } 105 return success(); 106 } 107 108 LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() { 109 if (getParameter().hasValue() ^ (getNumOperands() != 1)) 110 return emitOpError() << "expects either a parameter or an operand"; 111 return success(); 112 } 113 114 LogicalResult mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply( 115 transform::TransformResults &results, transform::TransformState &state) { 116 ArrayRef<Operation *> payload = state.getPayloadOps(getOperand()); 117 assert(payload.size() == 1 && "expected a single target op"); 118 auto value = reinterpret_cast<intptr_t>(payload[0]); 119 if (static_cast<uint64_t>(value) != getParameter()) { 120 return emitOpError() << "expected the operand to be associated with " 121 << getParameter() << " got " << value; 122 } 123 124 emitRemark() << "succeeded"; 125 return success(); 126 } 127 128 LogicalResult mlir::test::TestPrintRemarkAtOperandOp::apply( 129 transform::TransformResults &results, transform::TransformState &state) { 130 ArrayRef<Operation *> payload = state.getPayloadOps(getOperand()); 131 for (Operation *op : payload) 132 op->emitRemark() << getMessage(); 133 134 return success(); 135 } 136 137 namespace { 138 /// Test extension of the Transform dialect. Registers additional ops and 139 /// declares PDL as dependent dialect since the additional ops are using PDL 140 /// types for operands and results. 141 class TestTransformDialectExtension 142 : public transform::TransformDialectExtension< 143 TestTransformDialectExtension> { 144 public: 145 TestTransformDialectExtension() { 146 declareDependentDialect<pdl::PDLDialect>(); 147 registerTransformOps<TestTransformOp, 148 TestTransformUnrestrictedOpNoInterface, 149 #define GET_OP_LIST 150 #include "TestTransformDialectExtension.cpp.inc" 151 >(); 152 } 153 }; 154 } // namespace 155 156 #define GET_OP_CLASSES 157 #include "TestTransformDialectExtension.cpp.inc" 158 159 void ::test::registerTestTransformDialectExtension(DialectRegistry ®istry) { 160 registry.addExtensions<TestTransformDialectExtension>(); 161 } 162