1 //===- PDL.cpp - Pattern Descriptor Language Dialect ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/PDL/IR/PDL.h"
10 #include "mlir/Dialect/PDL/IR/PDLOps.h"
11 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/Interfaces/InferTypeOpInterface.h"
14 #include "llvm/ADT/StringSwitch.h"
15 
16 using namespace mlir;
17 using namespace mlir::pdl;
18 
19 //===----------------------------------------------------------------------===//
20 // PDLDialect
21 //===----------------------------------------------------------------------===//
22 
/// Registers every PDL operation and type with this dialect. The operation and
/// type lists are expanded from the `GET_OP_LIST` / `GET_TYPEDEF_LIST` sections
/// of the included generated `.cpp.inc` files.
void PDLDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/PDL/IR/PDLOpsTypes.cpp.inc"
      >();
}
33 
34 /// Returns true if the given operation is used by a "binding" pdl operation
35 /// within the main matcher body of a `pdl.pattern`.
36 static LogicalResult
37 verifyHasBindingUseInMatcher(Operation *op,
38                              StringRef bindableContextStr = "`pdl.operation`") {
39   // If the pattern is not a pattern, there is nothing to do.
40   if (!isa<PatternOp>(op->getParentOp()))
41     return success();
42   Block *matcherBlock = op->getBlock();
43   for (Operation *user : op->getUsers()) {
44     if (user->getBlock() != matcherBlock)
45       continue;
46     if (isa<AttributeOp, OperandOp, OperationOp, RewriteOp>(user))
47       return success();
48   }
49   return op->emitOpError()
50          << "expected a bindable (i.e. " << bindableContextStr
51          << ") user when defined in the matcher body of a `pdl.pattern`";
52 }
53 
54 //===----------------------------------------------------------------------===//
55 // pdl::ApplyConstraintOp
56 //===----------------------------------------------------------------------===//
57 
58 static LogicalResult verify(ApplyConstraintOp op) {
59   if (op.getNumOperands() == 0)
60     return op.emitOpError("expected at least one argument");
61   return success();
62 }
63 
64 //===----------------------------------------------------------------------===//
65 // pdl::AttributeOp
66 //===----------------------------------------------------------------------===//
67 
68 static LogicalResult verify(AttributeOp op) {
69   Value attrType = op.type();
70   Optional<Attribute> attrValue = op.value();
71 
72   if (!attrValue && isa<RewriteOp>(op->getParentOp()))
73     return op.emitOpError("expected constant value when specified within a "
74                           "`pdl.rewrite`");
75   if (attrValue && attrType)
76     return op.emitOpError("expected only one of [`type`, `value`] to be set");
77   return verifyHasBindingUseInMatcher(op);
78 }
79 
80 //===----------------------------------------------------------------------===//
81 // pdl::OperandOp
82 //===----------------------------------------------------------------------===//
83 
84 static LogicalResult verify(OperandOp op) {
85   return verifyHasBindingUseInMatcher(op);
86 }
87 
88 //===----------------------------------------------------------------------===//
89 // pdl::OperationOp
90 //===----------------------------------------------------------------------===//
91 
/// Custom assembly parser for `pdl.operation`.
///
/// Every piece of the syntax is optional; the pieces are parsed in this order:
///   1. the operation name, a string attribute stored under `name`;
///   2. a parenthesized operand list, resolved with the PDL `ValueType`;
///   3. a braced, comma-separated list of `"attr-name" = %value` entries; the
///      values are resolved with the PDL `AttributeType` and the names are
///      collected into the `attributeNames` array attribute;
///   4. `->` followed by result-type operands resolved with the PDL
///      `TypeType`; each one also adds a `ValueType` result;
///   5. a trailing attribute dictionary.
/// The op's first result is always an `OperationType`. Operands are appended
/// in segment order (operands, attribute values, result types), and
/// `operand_segment_sizes` records the size of each segment.
static ParseResult parseOperationOp(OpAsmParser &p, OperationState &state) {
  Builder &builder = p.getBuilder();

  // Parse the optional operation name. A leading `(` or `{` means the name was
  // omitted; that token is consumed here, so remember which one was seen.
  bool startsWithOperands = succeeded(p.parseOptionalLParen());
  bool startsWithAttributes =
      !startsWithOperands && succeeded(p.parseOptionalLBrace());
  bool startsWithOpName = false;
  if (!startsWithAttributes && !startsWithOperands) {
    StringAttr opName;
    OptionalParseResult opNameResult =
        p.parseOptionalAttribute(opName, "name", state.attributes);
    // An attribute was present but malformed: propagate the failure.
    startsWithOpName = opNameResult.hasValue();
    if (startsWithOpName && failed(*opNameResult))
      return failure();
  }

  // Parse the operands. The `(` may already have been consumed above; if the
  // syntax started with `{`, the operand list was skipped entirely.
  SmallVector<OpAsmParser::OperandType, 4> operands;
  if (startsWithOperands ||
      (!startsWithAttributes && succeeded(p.parseOptionalLParen()))) {
    if (p.parseOperandList(operands) || p.parseRParen() ||
        p.resolveOperands(operands, builder.getType<ValueType>(),
                          state.operands))
      return failure();
  }

  // Parse the attributes: at least one `name = %value` entry inside braces.
  SmallVector<Attribute, 4> attrNames;
  if (startsWithAttributes || succeeded(p.parseOptionalLBrace())) {
    SmallVector<OpAsmParser::OperandType, 4> attrOps;
    do {
      StringAttr nameAttr;
      OpAsmParser::OperandType operand;
      if (p.parseAttribute(nameAttr) || p.parseEqual() ||
          p.parseOperand(operand))
        return failure();
      attrNames.push_back(nameAttr);
      attrOps.push_back(operand);
    } while (succeeded(p.parseOptionalComma()));

    if (p.parseRBrace() ||
        p.resolveOperands(attrOps, builder.getType<AttributeType>(),
                          state.operands))
      return failure();
  }
  // `attributeNames` is always set (possibly empty); the op handle result
  // precedes any value results.
  state.addAttribute("attributeNames", builder.getArrayAttr(attrNames));
  state.addTypes(builder.getType<OperationType>());

  // Parse the result types; each type operand yields one value result.
  SmallVector<OpAsmParser::OperandType, 4> opResultTypes;
  if (succeeded(p.parseOptionalArrow())) {
    if (p.parseOperandList(opResultTypes) ||
        p.resolveOperands(opResultTypes, builder.getType<TypeType>(),
                          state.operands))
      return failure();
    state.types.append(opResultTypes.size(), builder.getType<ValueType>());
  }

  if (p.parseOptionalAttrDict(state.attributes))
    return failure();

  // Record how the flat operand list splits into operands, attribute values,
  // and result types (matching the order they were appended above).
  int32_t operandSegmentSizes[] = {static_cast<int32_t>(operands.size()),
                                   static_cast<int32_t>(attrNames.size()),
                                   static_cast<int32_t>(opResultTypes.size())};
  state.addAttribute("operand_segment_sizes",
                     builder.getI32VectorAttr(operandSegmentSizes));
  return success();
}
161 
162 static void print(OpAsmPrinter &p, OperationOp op) {
163   p << "pdl.operation ";
164   if (Optional<StringRef> name = op.name())
165     p << '"' << *name << '"';
166 
167   auto operandValues = op.operands();
168   if (!operandValues.empty())
169     p << '(' << operandValues << ')';
170 
171   // Emit the optional attributes.
172   ArrayAttr attrNames = op.attributeNames();
173   if (!attrNames.empty()) {
174     Operation::operand_range attrArgs = op.attributes();
175     p << " {";
176     interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
177                     [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
178     p << '}';
179   }
180 
181   // Print the result type constraints of the operation.
182   if (!op.results().empty())
183     p << " -> " << op.types();
184   p.printOptionalAttrDict(op->getAttrs(),
185                           {"attributeNames", "name", "operand_segment_sizes"});
186 }
187 
/// Verifies that the result types of `op`, a `pdl.operation` defined within a
/// `pdl.rewrite`, can be inferred. `opResults` are the op's value results and
/// `resultTypes` the matching type-constraint operands (the caller checks that
/// both have the same size). A result type is considered resolvable when:
///  - the op itself (or the specific result) is used as a replacement value in
///    a `pdl.replace` whose replaced operation is defined before `op`;
///  - the type is produced by a `create_native` op;
///  - the `pdl.type` has a constant `type`; or
///  - the `pdl.type` also constrains a `pdl.operation` outside the rewriter
///    block (i.e. an input operation in the matcher).
/// Emits an error with a note naming the first unresolved result otherwise.
static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
                                                    ResultRange opResults,
                                                    OperandRange resultTypes) {
  // Functor that returns true if the given use can be used to infer a type.
  Block *rewriterBlock = op->getBlock();
  auto canInferTypeFromUse = [&](OpOperand &use) {
    // If the use is within a ReplaceOp and isn't the operation being replaced
    // (i.e. is not the first operand of the replacement), we can infer a type.
    ReplaceOp replOpUser = dyn_cast<ReplaceOp>(use.getOwner());
    if (!replOpUser || use.getOperandNumber() == 0)
      return false;
    // Make sure the replaced operation was defined before this one.
    Operation *replacedOp = replOpUser.operation().getDefiningOp();
    return replacedOp->getBlock() != rewriterBlock ||
           replacedOp->isBeforeInBlock(op);
  };

  // Check to see if the uses of the operation itself can be used to infer
  // types.
  if (llvm::any_of(op.op().getUses(), canInferTypeFromUse))
    return success();

  // Otherwise, make sure each of the types can be inferred.
  for (int i : llvm::seq<int>(0, opResults.size())) {
    Operation *resultTypeOp = resultTypes[i].getDefiningOp();
    assert(resultTypeOp && "expected valid result type operation");

    // If the op was defined by a `create_native`, it is guaranteed to be
    // usable.
    if (isa<CreateNativeOp>(resultTypeOp))
      continue;

    // If the type is already constrained, there is nothing to do.
    // NOTE(review): assumes the only other producer of a type value is
    // `pdl.type`; `cast` asserts otherwise.
    TypeOp typeOp = cast<TypeOp>(resultTypeOp);
    if (typeOp.type())
      continue;

    // If the type operation was defined in the matcher and constrains the
    // result of an input operation, it can be used.
    auto constrainsInputOp = [rewriterBlock](Operation *user) {
      return user->getBlock() != rewriterBlock && isa<OperationOp>(user);
    };
    if (llvm::any_of(typeOp.getResult().getUsers(), constrainsInputOp))
      continue;

    // Otherwise, check to see if any uses of the result can infer the type.
    if (llvm::any_of(opResults[i].getUses(), canInferTypeFromUse))
      continue;
    return op
        .emitOpError("must have inferable or constrained result types when "
                     "nested within `pdl.rewrite`")
        .attachNote()
        .append("result type #", i, " was not constrained");
  }
  return success();
}
246 
247 static LogicalResult verify(OperationOp op) {
248   bool isWithinRewrite = isa<RewriteOp>(op->getParentOp());
249   if (isWithinRewrite && !op.name())
250     return op.emitOpError("must have an operation name when nested within "
251                           "a `pdl.rewrite`");
252   ArrayAttr attributeNames = op.attributeNames();
253   auto attributeValues = op.attributes();
254   if (attributeNames.size() != attributeValues.size()) {
255     return op.emitOpError()
256            << "expected the same number of attribute values and attribute "
257               "names, got "
258            << attributeNames.size() << " names and " << attributeValues.size()
259            << " values";
260   }
261 
262   OperandRange resultTypes = op.types();
263   auto opResults = op.results();
264   if (resultTypes.size() != opResults.size()) {
265     return op.emitOpError() << "expected the same number of result values and "
266                                "result type constraints, got "
267                             << opResults.size() << " results and "
268                             << resultTypes.size() << " constraints";
269   }
270 
271   // If the operation is within a rewrite body and doesn't have type inference,
272   // ensure that the result types can be resolved.
273   if (isWithinRewrite && !op.hasTypeInference()) {
274     if (failed(verifyResultTypesAreInferrable(op, opResults, resultTypes)))
275       return failure();
276   }
277 
278   return verifyHasBindingUseInMatcher(op, "`pdl.operation` or `pdl.rewrite`");
279 }
280 
281 bool OperationOp::hasTypeInference() {
282   Optional<StringRef> opName = name();
283   if (!opName)
284     return false;
285 
286   OperationName name(*opName, getContext());
287   if (const AbstractOperation *op = name.getAbstractOperation())
288     return op->getInterface<InferTypeOpInterface>();
289   return false;
290 }
291 
292 //===----------------------------------------------------------------------===//
293 // pdl::PatternOp
294 //===----------------------------------------------------------------------===//
295 
296 static LogicalResult verify(PatternOp pattern) {
297   Region &body = pattern.body();
298   auto *term = body.front().getTerminator();
299   if (!isa<RewriteOp>(term)) {
300     return pattern.emitOpError("expected body to terminate with `pdl.rewrite`")
301         .attachNote(term->getLoc())
302         .append("see terminator defined here");
303   }
304 
305   // Check that all values defined in the top-level pattern are referenced at
306   // least once in the source tree.
307   WalkResult result = body.walk([&](Operation *op) -> WalkResult {
308     if (!isa_and_nonnull<PDLDialect>(op->getDialect())) {
309       pattern
310           .emitOpError("expected only `pdl` operations within the pattern body")
311           .attachNote(op->getLoc())
312           .append("see non-`pdl` operation defined here");
313       return WalkResult::interrupt();
314     }
315     return WalkResult::advance();
316   });
317   return failure(result.wasInterrupted());
318 }
319 
320 void PatternOp::build(OpBuilder &builder, OperationState &state,
321                       Optional<StringRef> rootKind, Optional<uint16_t> benefit,
322                       Optional<StringRef> name) {
323   build(builder, state,
324         rootKind ? builder.getStringAttr(*rootKind) : StringAttr(),
325         builder.getI16IntegerAttr(benefit ? *benefit : 0),
326         name ? builder.getStringAttr(*name) : StringAttr());
327   builder.createBlock(state.addRegion());
328 }
329 
330 /// Returns the rewrite operation of this pattern.
331 RewriteOp PatternOp::getRewriter() {
332   return cast<RewriteOp>(body().front().getTerminator());
333 }
334 
335 /// Return the root operation kind that this pattern matches, or None if
336 /// there isn't a specific root.
337 Optional<StringRef> PatternOp::getRootKind() {
338   OperationOp rootOp = cast<OperationOp>(getRewriter().root().getDefiningOp());
339   return rootOp.name();
340 }
341 
342 //===----------------------------------------------------------------------===//
343 // pdl::ReplaceOp
344 //===----------------------------------------------------------------------===//
345 
346 static LogicalResult verify(ReplaceOp op) {
347   auto sourceOp = cast<OperationOp>(op.operation().getDefiningOp());
348   auto sourceOpResults = sourceOp.results();
349   auto replValues = op.replValues();
350 
351   if (Value replOpVal = op.replOperation()) {
352     auto replOp = cast<OperationOp>(replOpVal.getDefiningOp());
353     auto replOpResults = replOp.results();
354     if (sourceOpResults.size() != replOpResults.size()) {
355       return op.emitOpError()
356              << "expected source operation to have the same number of results "
357                 "as the replacement operation, replacement operation provided "
358              << replOpResults.size() << " but expected "
359              << sourceOpResults.size();
360     }
361 
362     if (!replValues.empty()) {
363       return op.emitOpError() << "expected no replacement values to be provided"
364                                  " when the replacement operation is present";
365     }
366 
367     return success();
368   }
369 
370   if (sourceOpResults.size() != replValues.size()) {
371     return op.emitOpError()
372            << "expected source operation to have the same number of results "
373               "as the provided replacement values, found "
374            << replValues.size() << " replacement values but expected "
375            << sourceOpResults.size();
376   }
377 
378   return success();
379 }
380 
381 //===----------------------------------------------------------------------===//
382 // pdl::RewriteOp
383 //===----------------------------------------------------------------------===//
384 
385 static LogicalResult verify(RewriteOp op) {
386   Region &rewriteRegion = op.body();
387 
388   // Handle the case where the rewrite is external.
389   if (op.name()) {
390     if (!rewriteRegion.empty()) {
391       return op.emitOpError()
392              << "expected rewrite region to be empty when rewrite is external";
393     }
394     return success();
395   }
396 
397   // Otherwise, check that the rewrite region only contains a single block.
398   if (rewriteRegion.empty()) {
399     return op.emitOpError() << "expected rewrite region to be non-empty if "
400                                "external name is not specified";
401   }
402 
403   // Check that no additional arguments were provided.
404   if (!op.externalArgs().empty()) {
405     return op.emitOpError() << "expected no external arguments when the "
406                                "rewrite is specified inline";
407   }
408   if (op.externalConstParams()) {
409     return op.emitOpError() << "expected no external constant parameters when "
410                                "the rewrite is specified inline";
411   }
412 
413   return success();
414 }
415 
416 //===----------------------------------------------------------------------===//
417 // pdl::TypeOp
418 //===----------------------------------------------------------------------===//
419 
420 static LogicalResult verify(TypeOp op) {
421   return verifyHasBindingUseInMatcher(
422       op, "`pdl.attribute`, `pdl.operand`, or `pdl.operation`");
423 }
424 
425 //===----------------------------------------------------------------------===//
426 // TableGen'd op method definitions
427 //===----------------------------------------------------------------------===//
428 
429 #define GET_OP_CLASSES
430 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
431