1 //===- PDL.cpp - Pattern Descriptor Language Dialect ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/PDL/IR/PDL.h"
10 #include "mlir/Dialect/PDL/IR/PDLOps.h"
11 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/Interfaces/InferTypeOpInterface.h"
14 #include "llvm/ADT/DenseSet.h"
15 #include "llvm/ADT/TypeSwitch.h"
16 
17 using namespace mlir;
18 using namespace mlir::pdl;
19 
20 #include "mlir/Dialect/PDL/IR/PDLOpsDialect.cpp.inc"
21 
22 //===----------------------------------------------------------------------===//
23 // PDLDialect
24 //===----------------------------------------------------------------------===//
25 
/// Dialect initialization hook: registers every PDL operation with the
/// dialect and then registers the dialect's types.
///
/// The operation list is produced by TableGen: defining GET_OP_LIST before
/// including PDLOps.cpp.inc expands the include into a comma-separated list
/// of op classes, which becomes the template argument list of
/// `addOperations`.
void PDLDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
      >();
  // NOTE(review): `registerTypes` is defined outside this file; presumably it
  // registers the PDL types declared in PDLTypes.h — confirm.
  registerTypes();
}
33 
34 //===----------------------------------------------------------------------===//
35 // PDL Operations
36 //===----------------------------------------------------------------------===//
37 
38 /// Returns true if the given operation is used by a "binding" pdl operation.
39 static bool hasBindingUse(Operation *op) {
40   for (Operation *user : op->getUsers())
41     // A result by itself is not binding, it must also be bound.
42     if (!isa<ResultOp, ResultsOp>(user) || hasBindingUse(user))
43       return true;
44   return false;
45 }
46 
47 /// Returns success if the given operation is not in the main matcher body or
48 /// is used by a "binding" operation. On failure, emits an error.
49 static LogicalResult verifyHasBindingUse(Operation *op) {
50   // If the parent is not a pattern, there is nothing to do.
51   if (!isa<PatternOp>(op->getParentOp()))
52     return success();
53   if (hasBindingUse(op))
54     return success();
55   return op->emitOpError(
56       "expected a bindable user when defined in the matcher body of a "
57       "`pdl.pattern`");
58 }
59 
60 /// Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s)
61 /// connected to the given operation.
62 static void visit(Operation *op, DenseSet<Operation *> &visited) {
63   // If the parent is not a pattern, there is nothing to do.
64   if (!isa<PatternOp>(op->getParentOp()) || isa<RewriteOp>(op))
65     return;
66 
67   // Ignore if already visited.
68   if (visited.contains(op))
69     return;
70 
71   // Mark as visited.
72   visited.insert(op);
73 
74   // Traverse the operands / parent.
75   TypeSwitch<Operation *>(op)
76       .Case<OperationOp>([&visited](auto operation) {
77         for (Value operand : operation.operands())
78           visit(operand.getDefiningOp(), visited);
79       })
80       .Case<ResultOp, ResultsOp>([&visited](auto result) {
81         visit(result.parent().getDefiningOp(), visited);
82       });
83 
84   // Traverse the users.
85   for (Operation *user : op->getUsers())
86     visit(user, visited);
87 }
88 
89 //===----------------------------------------------------------------------===//
90 // pdl::ApplyNativeConstraintOp
91 //===----------------------------------------------------------------------===//
92 
93 static LogicalResult verify(ApplyNativeConstraintOp op) {
94   if (op.getNumOperands() == 0)
95     return op.emitOpError("expected at least one argument");
96   return success();
97 }
98 
99 //===----------------------------------------------------------------------===//
100 // pdl::ApplyNativeRewriteOp
101 //===----------------------------------------------------------------------===//
102 
103 static LogicalResult verify(ApplyNativeRewriteOp op) {
104   if (op.getNumOperands() == 0 && op.getNumResults() == 0)
105     return op.emitOpError("expected at least one argument or result");
106   return success();
107 }
108 
109 //===----------------------------------------------------------------------===//
110 // pdl::AttributeOp
111 //===----------------------------------------------------------------------===//
112 
113 static LogicalResult verify(AttributeOp op) {
114   Value attrType = op.type();
115   Optional<Attribute> attrValue = op.value();
116 
117   if (!attrValue) {
118     if (isa<RewriteOp>(op->getParentOp()))
119       return op.emitOpError("expected constant value when specified within a "
120                             "`pdl.rewrite`");
121     return verifyHasBindingUse(op);
122   }
123   if (attrType)
124     return op.emitOpError("expected only one of [`type`, `value`] to be set");
125   return success();
126 }
127 
128 //===----------------------------------------------------------------------===//
129 // pdl::OperandOp
130 //===----------------------------------------------------------------------===//
131 
132 static LogicalResult verify(OperandOp op) { return verifyHasBindingUse(op); }
133 
134 //===----------------------------------------------------------------------===//
135 // pdl::OperandsOp
136 //===----------------------------------------------------------------------===//
137 
138 static LogicalResult verify(OperandsOp op) { return verifyHasBindingUse(op); }
139 
140 //===----------------------------------------------------------------------===//
141 // pdl::OperationOp
142 //===----------------------------------------------------------------------===//
143 
144 static ParseResult parseOperationOpAttributes(
145     OpAsmParser &p, SmallVectorImpl<OpAsmParser::OperandType> &attrOperands,
146     ArrayAttr &attrNamesAttr) {
147   Builder &builder = p.getBuilder();
148   SmallVector<Attribute, 4> attrNames;
149   if (succeeded(p.parseOptionalLBrace())) {
150     do {
151       StringAttr nameAttr;
152       OpAsmParser::OperandType operand;
153       if (p.parseAttribute(nameAttr) || p.parseEqual() ||
154           p.parseOperand(operand))
155         return failure();
156       attrNames.push_back(nameAttr);
157       attrOperands.push_back(operand);
158     } while (succeeded(p.parseOptionalComma()));
159     if (p.parseRBrace())
160       return failure();
161   }
162   attrNamesAttr = builder.getArrayAttr(attrNames);
163   return success();
164 }
165 
166 static void printOperationOpAttributes(OpAsmPrinter &p, OperationOp op,
167                                        OperandRange attrArgs,
168                                        ArrayAttr attrNames) {
169   if (attrNames.empty())
170     return;
171   p << " {";
172   interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
173                   [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
174   p << '}';
175 }
176 
177 /// Verifies that the result types of this operation, defined within a
178 /// `pdl.rewrite`, can be inferred.
179 static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
180                                                     OperandRange resultTypes) {
181   // Functor that returns if the given use can be used to infer a type.
182   Block *rewriterBlock = op->getBlock();
183   auto canInferTypeFromUse = [&](OpOperand &use) {
184     // If the use is within a ReplaceOp and isn't the operation being replaced
185     // (i.e. is not the first operand of the replacement), we can infer a type.
186     ReplaceOp replOpUser = dyn_cast<ReplaceOp>(use.getOwner());
187     if (!replOpUser || use.getOperandNumber() == 0)
188       return false;
189     // Make sure the replaced operation was defined before this one.
190     Operation *replacedOp = replOpUser.operation().getDefiningOp();
191     return replacedOp->getBlock() != rewriterBlock ||
192            replacedOp->isBeforeInBlock(op);
193   };
194 
195   // Check to see if the uses of the operation itself can be used to infer
196   // types.
197   if (llvm::any_of(op.op().getUses(), canInferTypeFromUse))
198     return success();
199 
200   // Otherwise, make sure each of the types can be inferred.
201   for (const auto &it : llvm::enumerate(resultTypes)) {
202     Operation *resultTypeOp = it.value().getDefiningOp();
203     assert(resultTypeOp && "expected valid result type operation");
204 
205     // If the op was defined by a `apply_native_rewrite`, it is guaranteed to be
206     // usable.
207     if (isa<ApplyNativeRewriteOp>(resultTypeOp))
208       continue;
209 
210     // If the type operation was defined in the matcher and constrains an
211     // operand or the result of an input operation, it can be used.
212     auto constrainsInput = [rewriterBlock](Operation *user) {
213       return user->getBlock() != rewriterBlock &&
214              isa<OperandOp, OperandsOp, OperationOp>(user);
215     };
216     if (TypeOp typeOp = dyn_cast<TypeOp>(resultTypeOp)) {
217       if (typeOp.type() || llvm::any_of(typeOp->getUsers(), constrainsInput))
218         continue;
219     } else if (TypesOp typeOp = dyn_cast<TypesOp>(resultTypeOp)) {
220       if (typeOp.types() || llvm::any_of(typeOp->getUsers(), constrainsInput))
221         continue;
222     }
223 
224     return op
225         .emitOpError("must have inferable or constrained result types when "
226                      "nested within `pdl.rewrite`")
227         .attachNote()
228         .append("result type #", it.index(), " was not constrained");
229   }
230   return success();
231 }
232 
233 static LogicalResult verify(OperationOp op) {
234   bool isWithinRewrite = isa<RewriteOp>(op->getParentOp());
235   if (isWithinRewrite && !op.name())
236     return op.emitOpError("must have an operation name when nested within "
237                           "a `pdl.rewrite`");
238   ArrayAttr attributeNames = op.attributeNames();
239   auto attributeValues = op.attributes();
240   if (attributeNames.size() != attributeValues.size()) {
241     return op.emitOpError()
242            << "expected the same number of attribute values and attribute "
243               "names, got "
244            << attributeNames.size() << " names and " << attributeValues.size()
245            << " values";
246   }
247 
248   // If the operation is within a rewrite body and doesn't have type inference,
249   // ensure that the result types can be resolved.
250   if (isWithinRewrite && !op.hasTypeInference()) {
251     if (failed(verifyResultTypesAreInferrable(op, op.types())))
252       return failure();
253   }
254 
255   return verifyHasBindingUse(op);
256 }
257 
258 bool OperationOp::hasTypeInference() {
259   Optional<StringRef> opName = name();
260   if (!opName)
261     return false;
262 
263   if (auto rInfo = RegisteredOperationName::lookup(*opName, getContext()))
264     return rInfo->hasInterface<InferTypeOpInterface>();
265   return false;
266 }
267 
268 //===----------------------------------------------------------------------===//
269 // pdl::PatternOp
270 //===----------------------------------------------------------------------===//
271 
272 static LogicalResult verify(PatternOp pattern) {
273   Region &body = pattern.body();
274   Operation *term = body.front().getTerminator();
275   auto rewriteOp = dyn_cast<RewriteOp>(term);
276   if (!rewriteOp) {
277     return pattern.emitOpError("expected body to terminate with `pdl.rewrite`")
278         .attachNote(term->getLoc())
279         .append("see terminator defined here");
280   }
281 
282   // Check that all values defined in the top-level pattern belong to the PDL
283   // dialect.
284   WalkResult result = body.walk([&](Operation *op) -> WalkResult {
285     if (!isa_and_nonnull<PDLDialect>(op->getDialect())) {
286       pattern
287           .emitOpError("expected only `pdl` operations within the pattern body")
288           .attachNote(op->getLoc())
289           .append("see non-`pdl` operation defined here");
290       return WalkResult::interrupt();
291     }
292     return WalkResult::advance();
293   });
294   if (result.wasInterrupted())
295     return failure();
296 
297   // Check that there is at least one operation.
298   if (body.front().getOps<OperationOp>().empty())
299     return pattern.emitOpError(
300         "the pattern must contain at least one `pdl.operation`");
301 
302   // Determine if the operations within the pdl.pattern form a connected
303   // component. This is determined by starting the search from the first
304   // operand/result/operation and visiting their users / parents / operands.
305   // We limit our attention to operations that have a user in pdl.rewrite,
306   // those that do not will be detected via other means (expected bindable
307   // user).
308   bool first = true;
309   DenseSet<Operation *> visited;
310   for (Operation &op : body.front()) {
311     // The following are the operations forming the connected component.
312     if (!isa<OperandOp, OperandsOp, ResultOp, ResultsOp, OperationOp>(op))
313       continue;
314 
315     // Determine if the operation has a user in `pdl.rewrite`.
316     bool hasUserInRewrite = false;
317     for (Operation *user : op.getUsers()) {
318       Region *region = user->getParentRegion();
319       if (isa<RewriteOp>(user) ||
320           (region && isa<RewriteOp>(region->getParentOp()))) {
321         hasUserInRewrite = true;
322         break;
323       }
324     }
325 
326     // If the operation does not have a user in `pdl.rewrite`, ignore it.
327     if (!hasUserInRewrite)
328       continue;
329 
330     if (first) {
331       // For the first operation, invoke visit.
332       visit(&op, visited);
333       first = false;
334     } else if (!visited.count(&op)) {
335       // For the subsequent operations, check if already visited.
336       return pattern
337           .emitOpError("the operations must form a connected component")
338           .attachNote(op.getLoc())
339           .append("see a disconnected value / operation here");
340     }
341   }
342 
343   return success();
344 }
345 
346 void PatternOp::build(OpBuilder &builder, OperationState &state,
347                       Optional<uint16_t> benefit, Optional<StringRef> name) {
348   build(builder, state, builder.getI16IntegerAttr(benefit ? *benefit : 0),
349         name ? builder.getStringAttr(*name) : StringAttr());
350   state.regions[0]->emplaceBlock();
351 }
352 
353 /// Returns the rewrite operation of this pattern.
354 RewriteOp PatternOp::getRewriter() {
355   return cast<RewriteOp>(body().front().getTerminator());
356 }
357 
358 //===----------------------------------------------------------------------===//
359 // pdl::ReplaceOp
360 //===----------------------------------------------------------------------===//
361 
362 static LogicalResult verify(ReplaceOp op) {
363   if (op.replOperation() && !op.replValues().empty())
364     return op.emitOpError() << "expected no replacement values to be provided"
365                                " when the replacement operation is present";
366   return success();
367 }
368 
369 //===----------------------------------------------------------------------===//
370 // pdl::ResultsOp
371 //===----------------------------------------------------------------------===//
372 
373 static ParseResult parseResultsValueType(OpAsmParser &p, IntegerAttr index,
374                                          Type &resultType) {
375   if (!index) {
376     resultType = RangeType::get(p.getBuilder().getType<ValueType>());
377     return success();
378   }
379   if (p.parseArrow() || p.parseType(resultType))
380     return failure();
381   return success();
382 }
383 
384 static void printResultsValueType(OpAsmPrinter &p, ResultsOp op,
385                                   IntegerAttr index, Type resultType) {
386   if (index)
387     p << " -> " << resultType;
388 }
389 
390 static LogicalResult verify(ResultsOp op) {
391   if (!op.index() && op.getType().isa<pdl::ValueType>()) {
392     return op.emitOpError() << "expected `pdl.range<value>` result type when "
393                                "no index is specified, but got: "
394                             << op.getType();
395   }
396   return success();
397 }
398 
399 //===----------------------------------------------------------------------===//
400 // pdl::RewriteOp
401 //===----------------------------------------------------------------------===//
402 
403 static LogicalResult verify(RewriteOp op) {
404   Region &rewriteRegion = op.body();
405 
406   // Handle the case where the rewrite is external.
407   if (op.name()) {
408     if (!rewriteRegion.empty()) {
409       return op.emitOpError()
410              << "expected rewrite region to be empty when rewrite is external";
411     }
412     return success();
413   }
414 
415   // Otherwise, check that the rewrite region only contains a single block.
416   if (rewriteRegion.empty()) {
417     return op.emitOpError() << "expected rewrite region to be non-empty if "
418                                "external name is not specified";
419   }
420 
421   // Check that no additional arguments were provided.
422   if (!op.externalArgs().empty()) {
423     return op.emitOpError() << "expected no external arguments when the "
424                                "rewrite is specified inline";
425   }
426   if (op.externalConstParams()) {
427     return op.emitOpError() << "expected no external constant parameters when "
428                                "the rewrite is specified inline";
429   }
430 
431   return success();
432 }
433 
434 //===----------------------------------------------------------------------===//
435 // pdl::TypeOp
436 //===----------------------------------------------------------------------===//
437 
438 static LogicalResult verify(TypeOp op) {
439   if (!op.typeAttr())
440     return verifyHasBindingUse(op);
441   return success();
442 }
443 
444 //===----------------------------------------------------------------------===//
445 // pdl::TypesOp
446 //===----------------------------------------------------------------------===//
447 
448 static LogicalResult verify(TypesOp op) {
449   if (!op.typesAttr())
450     return verifyHasBindingUse(op);
451   return success();
452 }
453 
454 //===----------------------------------------------------------------------===//
455 // TableGen'd op method definitions
456 //===----------------------------------------------------------------------===//
457 
458 #define GET_OP_CLASSES
459 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
460