1 //===- TestTransformDialectExtension.cpp ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines an extension of the MLIR Transform dialect for testing 10 // purposes. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "TestTransformDialectExtension.h" 15 #include "mlir/Dialect/PDL/IR/PDL.h" 16 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 17 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/OpImplementation.h" 20 21 using namespace mlir; 22 23 namespace { 24 /// Simple transform op defined outside of the dialect. Just emits a remark when 25 /// applied. 26 class TestTransformOp 27 : public Op<TestTransformOp, transform::TransformOpInterface::Trait> { 28 public: 29 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp) 30 31 using Op::Op; 32 33 static ArrayRef<StringRef> getAttributeNames() { return {}; } 34 35 static constexpr llvm::StringLiteral getOperationName() { 36 return llvm::StringLiteral("transform.test_transform_op"); 37 } 38 39 LogicalResult apply(transform::TransformResults &results, 40 transform::TransformState &state) { 41 InFlightDiagnostic remark = emitRemark() << "applying transformation"; 42 if (Attribute message = getMessage()) 43 remark << " " << message; 44 45 return success(); 46 } 47 48 Attribute getMessage() { return getOperation()->getAttr("message"); } 49 50 static ParseResult parse(OpAsmParser &parser, OperationState &state) { 51 StringAttr message; 52 OptionalParseResult result = parser.parseOptionalAttribute(message); 53 if (!result.hasValue()) 54 return success(); 55 56 if (result.getValue().succeeded()) 57 state.addAttribute("message", message); 58 return result.getValue(); 59 } 60 61 void print(OpAsmPrinter &printer) { 62 if (getMessage()) 63 printer << " " << getMessage(); 64 } 65 }; 66 } // namespace 67 68 LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::apply( 69 transform::TransformResults &results, transform::TransformState &state) { 70 if (getOperation()->getNumOperands() != 0) { 71 results.set(getResult().cast<OpResult>(), 72 getOperation()->getOperand(0).getDefiningOp()); 73 } else { 74 results.set(getResult().cast<OpResult>(), 75 reinterpret_cast<Operation *>(*getParameter())); 76 } 77 return success(); 78 } 79 80 LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() { 81 if (getParameter().hasValue() ^ (getNumOperands() != 1)) 82 return emitOpError() << "expects either a parameter or an operand"; 83 return success(); 84 } 85 86 LogicalResult mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply( 87 transform::TransformResults &results, transform::TransformState &state) { 88 ArrayRef<Operation *> payload = state.getPayloadOps(getOperand()); 89 assert(payload.size() == 1 && "expected a single target op"); 90 auto value = reinterpret_cast<intptr_t>(payload[0]); 91 if (static_cast<uint64_t>(value) != getParameter()) { 92 return emitOpError() << "expected the operand to be associated with " 93 << getParameter() << " got " << value; 94 } 95 96 emitRemark() << "succeeded"; 97 return success(); 98 } 99 100 namespace { 101 /// Test extension of the Transform dialect. Registers additional ops and 102 /// declares PDL as dependent dialect since the additional ops are using PDL 103 /// types for operands and results. 104 class TestTransformDialectExtension 105 : public transform::TransformDialectExtension< 106 TestTransformDialectExtension> { 107 public: 108 TestTransformDialectExtension() { 109 declareDependentDialect<pdl::PDLDialect>(); 110 registerTransformOps<TestTransformOp, 111 #define GET_OP_LIST 112 #include "TestTransformDialectExtension.cpp.inc" 113 >(); 114 } 115 }; 116 } // namespace 117 118 #define GET_OP_CLASSES 119 #include "TestTransformDialectExtension.cpp.inc" 120 121 void ::test::registerTestTransformDialectExtension(DialectRegistry ®istry) { 122 registry.addExtensions<TestTransformDialectExtension>(); 123 } 124