1 //===- TestTransformDialectExtension.cpp ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines an extension of the MLIR Transform dialect for testing 10 // purposes. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "TestTransformDialectExtension.h" 15 #include "mlir/Dialect/PDL/IR/PDL.h" 16 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 17 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/OpImplementation.h" 20 21 using namespace mlir; 22 23 namespace { 24 /// Simple transform op defined outside of the dialect. Just emits a remark when 25 /// applied. 26 class TestTransformOp 27 : public Op<TestTransformOp, transform::TransformOpInterface::Trait> { 28 public: 29 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp) 30 31 using Op::Op; 32 33 static ArrayRef<StringRef> getAttributeNames() { return {}; } 34 35 static constexpr llvm::StringLiteral getOperationName() { 36 return llvm::StringLiteral("transform.test_transform_op"); 37 } 38 39 LogicalResult apply(transform::TransformResults &results, 40 transform::TransformState &state) { 41 emitRemark() << "applying transformation"; 42 return success(); 43 } 44 45 static ParseResult parse(OpAsmParser &parser, OperationState &state) { 46 return success(); 47 } 48 49 void print(OpAsmPrinter &printer) {} 50 }; 51 } // namespace 52 53 LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::apply( 54 transform::TransformResults &results, transform::TransformState &state) { 55 if (getOperation()->getNumOperands() != 0) { 56 results.set(getResult().cast<OpResult>(), getOperand(0).getDefiningOp()); 57 } else { 58 results.set(getResult().cast<OpResult>(), 59 reinterpret_cast<Operation *>(*parameter())); 60 } 61 return success(); 62 } 63 64 LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() { 65 if (parameter().hasValue() ^ (getNumOperands() != 1)) 66 return emitOpError() << "expects either a parameter or an operand"; 67 return success(); 68 } 69 70 LogicalResult mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply( 71 transform::TransformResults &results, transform::TransformState &state) { 72 ArrayRef<Operation *> payload = state.getPayloadOps(getOperand()); 73 assert(payload.size() == 1 && "expected a single target op"); 74 auto value = reinterpret_cast<intptr_t>(payload[0]); 75 if (static_cast<uint64_t>(value) != parameter()) { 76 return emitOpError() << "expected the operand to be associated with " 77 << parameter() << " got " << value; 78 } 79 80 emitRemark() << "succeeded"; 81 return success(); 82 } 83 84 namespace { 85 /// Test extension of the Transform dialect. Registers additional ops and 86 /// declares PDL as dependent dialect since the additional ops are using PDL 87 /// types for operands and results. 88 class TestTransformDialectExtension 89 : public transform::TransformDialectExtension< 90 TestTransformDialectExtension> { 91 public: 92 TestTransformDialectExtension() { 93 declareDependentDialect<pdl::PDLDialect>(); 94 registerTransformOps<TestTransformOp, 95 #define GET_OP_LIST 96 #include "TestTransformDialectExtension.cpp.inc" 97 >(); 98 } 99 }; 100 } // namespace 101 102 #define GET_OP_CLASSES 103 #include "TestTransformDialectExtension.cpp.inc" 104 105 void ::test::registerTestTransformDialectExtension(DialectRegistry ®istry) { 106 registry.addExtensions<TestTransformDialectExtension>(); 107 } 108