1 //===- TestTransformDialectExtension.cpp ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines an extension of the MLIR Transform dialect for testing
10 // purposes.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "TestTransformDialectExtension.h"
15 #include "mlir/Dialect/PDL/IR/PDL.h"
16 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
17 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/OpImplementation.h"
20 
21 using namespace mlir;
22 
23 namespace {
24 /// Simple transform op defined outside of the dialect. Just emits a remark when
25 /// applied.
26 class TestTransformOp
27     : public Op<TestTransformOp, transform::TransformOpInterface::Trait> {
28 public:
29   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp)
30 
31   using Op::Op;
32 
33   static ArrayRef<StringRef> getAttributeNames() { return {}; }
34 
35   static constexpr llvm::StringLiteral getOperationName() {
36     return llvm::StringLiteral("transform.test_transform_op");
37   }
38 
39   LogicalResult apply(transform::TransformResults &results,
40                       transform::TransformState &state) {
41     InFlightDiagnostic remark = emitRemark() << "applying transformation";
42     if (Attribute message = getMessage())
43       remark << " " << message;
44 
45     return success();
46   }
47 
48   Attribute getMessage() { return getOperation()->getAttr("message"); }
49 
50   static ParseResult parse(OpAsmParser &parser, OperationState &state) {
51     StringAttr message;
52     OptionalParseResult result = parser.parseOptionalAttribute(message);
53     if (!result.hasValue())
54       return success();
55 
56     if (result.getValue().succeeded())
57       state.addAttribute("message", message);
58     return result.getValue();
59   }
60 
61   void print(OpAsmPrinter &printer) {
62     if (getMessage())
63       printer << " " << getMessage();
64   }
65 };
66 } // namespace
67 
68 LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::apply(
69     transform::TransformResults &results, transform::TransformState &state) {
70   if (getOperation()->getNumOperands() != 0) {
71     results.set(getResult().cast<OpResult>(),
72                 getOperation()->getOperand(0).getDefiningOp());
73   } else {
74     results.set(getResult().cast<OpResult>(),
75                 reinterpret_cast<Operation *>(*getParameter()));
76   }
77   return success();
78 }
79 
80 LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() {
81   if (getParameter().hasValue() ^ (getNumOperands() != 1))
82     return emitOpError() << "expects either a parameter or an operand";
83   return success();
84 }
85 
86 LogicalResult mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply(
87     transform::TransformResults &results, transform::TransformState &state) {
88   ArrayRef<Operation *> payload = state.getPayloadOps(getOperand());
89   assert(payload.size() == 1 && "expected a single target op");
90   auto value = reinterpret_cast<intptr_t>(payload[0]);
91   if (static_cast<uint64_t>(value) != getParameter()) {
92     return emitOpError() << "expected the operand to be associated with "
93                          << getParameter() << " got " << value;
94   }
95 
96   emitRemark() << "succeeded";
97   return success();
98 }
99 
100 namespace {
/// Test extension of the Transform dialect. Registers additional ops and
/// declares PDL as a dependent dialect, because the additional ops use PDL
/// types for their operands and results.
class TestTransformDialectExtension
    : public transform::TransformDialectExtension<
          TestTransformDialectExtension> {
public:
  TestTransformDialectExtension() {
    // The registered ops use PDL types, so PDL must be declared as a
    // dependent dialect of this extension.
    declareDependentDialect<pdl::PDLDialect>();
    // Register the hand-written TestTransformOp together with the generated
    // ops. With GET_OP_LIST defined, the included .inc file supplies the
    // remaining op class names that complete this template argument list.
    // Do not reorder these lines: the include must sit inside the list.
    registerTransformOps<TestTransformOp,
#define GET_OP_LIST
#include "TestTransformDialectExtension.cpp.inc"
                         >();
  }
};
116 } // namespace
117 
// Pull in the generated op class definitions (the apply/verify methods above
// are the hand-written parts of those generated classes).
#define GET_OP_CLASSES
#include "TestTransformDialectExtension.cpp.inc"
120 
/// Adds the test extension of the Transform dialect to `registry`, so that
/// its ops become available when the Transform dialect is loaded from it.
void ::test::registerTestTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<TestTransformDialectExtension>();
}
124