1 //===- PDL.cpp - Pattern Descriptor Language Dialect ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/PDL/IR/PDL.h"
10 #include "mlir/Dialect/PDL/IR/PDLOps.h"
11 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/Interfaces/InferTypeOpInterface.h"
14 #include "llvm/ADT/DenseSet.h"
15 #include "llvm/ADT/TypeSwitch.h"
16
17 using namespace mlir;
18 using namespace mlir::pdl;
19
20 #include "mlir/Dialect/PDL/IR/PDLOpsDialect.cpp.inc"
21
22 //===----------------------------------------------------------------------===//
23 // PDLDialect
24 //===----------------------------------------------------------------------===//
25
initialize()26 void PDLDialect::initialize() {
27 addOperations<
28 #define GET_OP_LIST
29 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
30 >();
31 registerTypes();
32 }
33
34 //===----------------------------------------------------------------------===//
35 // PDL Operations
36 //===----------------------------------------------------------------------===//
37
38 /// Returns true if the given operation is used by a "binding" pdl operation.
hasBindingUse(Operation * op)39 static bool hasBindingUse(Operation *op) {
40 for (Operation *user : op->getUsers())
41 // A result by itself is not binding, it must also be bound.
42 if (!isa<ResultOp, ResultsOp>(user) || hasBindingUse(user))
43 return true;
44 return false;
45 }
46
47 /// Returns success if the given operation is not in the main matcher body or
48 /// is used by a "binding" operation. On failure, emits an error.
verifyHasBindingUse(Operation * op)49 static LogicalResult verifyHasBindingUse(Operation *op) {
50 // If the parent is not a pattern, there is nothing to do.
51 if (!isa<PatternOp>(op->getParentOp()))
52 return success();
53 if (hasBindingUse(op))
54 return success();
55 return op->emitOpError(
56 "expected a bindable user when defined in the matcher body of a "
57 "`pdl.pattern`");
58 }
59
60 /// Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s)
61 /// connected to the given operation.
visit(Operation * op,DenseSet<Operation * > & visited)62 static void visit(Operation *op, DenseSet<Operation *> &visited) {
63 // If the parent is not a pattern, there is nothing to do.
64 if (!isa<PatternOp>(op->getParentOp()) || isa<RewriteOp>(op))
65 return;
66
67 // Ignore if already visited.
68 if (visited.contains(op))
69 return;
70
71 // Mark as visited.
72 visited.insert(op);
73
74 // Traverse the operands / parent.
75 TypeSwitch<Operation *>(op)
76 .Case<OperationOp>([&visited](auto operation) {
77 for (Value operand : operation.operands())
78 visit(operand.getDefiningOp(), visited);
79 })
80 .Case<ResultOp, ResultsOp>([&visited](auto result) {
81 visit(result.parent().getDefiningOp(), visited);
82 });
83
84 // Traverse the users.
85 for (Operation *user : op->getUsers())
86 visit(user, visited);
87 }
88
89 //===----------------------------------------------------------------------===//
90 // pdl::ApplyNativeConstraintOp
91 //===----------------------------------------------------------------------===//
92
verify()93 LogicalResult ApplyNativeConstraintOp::verify() {
94 if (getNumOperands() == 0)
95 return emitOpError("expected at least one argument");
96 return success();
97 }
98
99 //===----------------------------------------------------------------------===//
100 // pdl::ApplyNativeRewriteOp
101 //===----------------------------------------------------------------------===//
102
verify()103 LogicalResult ApplyNativeRewriteOp::verify() {
104 if (getNumOperands() == 0 && getNumResults() == 0)
105 return emitOpError("expected at least one argument or result");
106 return success();
107 }
108
109 //===----------------------------------------------------------------------===//
110 // pdl::AttributeOp
111 //===----------------------------------------------------------------------===//
112
verify()113 LogicalResult AttributeOp::verify() {
114 Value attrType = type();
115 Optional<Attribute> attrValue = value();
116
117 if (!attrValue) {
118 if (isa<RewriteOp>((*this)->getParentOp()))
119 return emitOpError(
120 "expected constant value when specified within a `pdl.rewrite`");
121 return verifyHasBindingUse(*this);
122 }
123 if (attrType)
124 return emitOpError("expected only one of [`type`, `value`] to be set");
125 return success();
126 }
127
128 //===----------------------------------------------------------------------===//
129 // pdl::OperandOp
130 //===----------------------------------------------------------------------===//
131
verify()132 LogicalResult OperandOp::verify() { return verifyHasBindingUse(*this); }
133
134 //===----------------------------------------------------------------------===//
135 // pdl::OperandsOp
136 //===----------------------------------------------------------------------===//
137
verify()138 LogicalResult OperandsOp::verify() { return verifyHasBindingUse(*this); }
139
140 //===----------------------------------------------------------------------===//
141 // pdl::OperationOp
142 //===----------------------------------------------------------------------===//
143
parseOperationOpAttributes(OpAsmParser & p,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & attrOperands,ArrayAttr & attrNamesAttr)144 static ParseResult parseOperationOpAttributes(
145 OpAsmParser &p,
146 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands,
147 ArrayAttr &attrNamesAttr) {
148 Builder &builder = p.getBuilder();
149 SmallVector<Attribute, 4> attrNames;
150 if (succeeded(p.parseOptionalLBrace())) {
151 auto parseOperands = [&]() {
152 StringAttr nameAttr;
153 OpAsmParser::UnresolvedOperand operand;
154 if (p.parseAttribute(nameAttr) || p.parseEqual() ||
155 p.parseOperand(operand))
156 return failure();
157 attrNames.push_back(nameAttr);
158 attrOperands.push_back(operand);
159 return success();
160 };
161 if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace())
162 return failure();
163 }
164 attrNamesAttr = builder.getArrayAttr(attrNames);
165 return success();
166 }
167
printOperationOpAttributes(OpAsmPrinter & p,OperationOp op,OperandRange attrArgs,ArrayAttr attrNames)168 static void printOperationOpAttributes(OpAsmPrinter &p, OperationOp op,
169 OperandRange attrArgs,
170 ArrayAttr attrNames) {
171 if (attrNames.empty())
172 return;
173 p << " {";
174 interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
175 [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
176 p << '}';
177 }
178
179 /// Verifies that the result types of this operation, defined within a
180 /// `pdl.rewrite`, can be inferred.
verifyResultTypesAreInferrable(OperationOp op,OperandRange resultTypes)181 static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
182 OperandRange resultTypes) {
183 // Functor that returns if the given use can be used to infer a type.
184 Block *rewriterBlock = op->getBlock();
185 auto canInferTypeFromUse = [&](OpOperand &use) {
186 // If the use is within a ReplaceOp and isn't the operation being replaced
187 // (i.e. is not the first operand of the replacement), we can infer a type.
188 ReplaceOp replOpUser = dyn_cast<ReplaceOp>(use.getOwner());
189 if (!replOpUser || use.getOperandNumber() == 0)
190 return false;
191 // Make sure the replaced operation was defined before this one.
192 Operation *replacedOp = replOpUser.operation().getDefiningOp();
193 return replacedOp->getBlock() != rewriterBlock ||
194 replacedOp->isBeforeInBlock(op);
195 };
196
197 // Check to see if the uses of the operation itself can be used to infer
198 // types.
199 if (llvm::any_of(op.op().getUses(), canInferTypeFromUse))
200 return success();
201
202 // Handle the case where the operation has no explicit result types.
203 if (resultTypes.empty()) {
204 // If we don't know the concrete operation, don't attempt any verification.
205 // We can't make assumptions if we don't know the concrete operation.
206 Optional<StringRef> rawOpName = op.name();
207 if (!rawOpName)
208 return success();
209 Optional<RegisteredOperationName> opName =
210 RegisteredOperationName::lookup(*rawOpName, op.getContext());
211 if (!opName)
212 return success();
213
214 // If no explicit result types were provided, check to see if the operation
215 // expected at least one result. This doesn't cover all cases, but this
216 // should cover many cases in which the user intended to infer the results
217 // of an operation, but it isn't actually possible.
218 bool expectedAtLeastOneResult =
219 !opName->hasTrait<OpTrait::ZeroResults>() &&
220 !opName->hasTrait<OpTrait::VariadicResults>();
221 if (expectedAtLeastOneResult) {
222 return op
223 .emitOpError("must have inferable or constrained result types when "
224 "nested within `pdl.rewrite`")
225 .attachNote()
226 .append("operation is created in a non-inferrable context, but '",
227 *opName, "' does not implement InferTypeOpInterface");
228 }
229 return success();
230 }
231
232 // Otherwise, make sure each of the types can be inferred.
233 for (const auto &it : llvm::enumerate(resultTypes)) {
234 Operation *resultTypeOp = it.value().getDefiningOp();
235 assert(resultTypeOp && "expected valid result type operation");
236
237 // If the op was defined by a `apply_native_rewrite`, it is guaranteed to be
238 // usable.
239 if (isa<ApplyNativeRewriteOp>(resultTypeOp))
240 continue;
241
242 // If the type operation was defined in the matcher and constrains an
243 // operand or the result of an input operation, it can be used.
244 auto constrainsInput = [rewriterBlock](Operation *user) {
245 return user->getBlock() != rewriterBlock &&
246 isa<OperandOp, OperandsOp, OperationOp>(user);
247 };
248 if (TypeOp typeOp = dyn_cast<TypeOp>(resultTypeOp)) {
249 if (typeOp.type() || llvm::any_of(typeOp->getUsers(), constrainsInput))
250 continue;
251 } else if (TypesOp typeOp = dyn_cast<TypesOp>(resultTypeOp)) {
252 if (typeOp.types() || llvm::any_of(typeOp->getUsers(), constrainsInput))
253 continue;
254 }
255
256 return op
257 .emitOpError("must have inferable or constrained result types when "
258 "nested within `pdl.rewrite`")
259 .attachNote()
260 .append("result type #", it.index(), " was not constrained");
261 }
262 return success();
263 }
264
verify()265 LogicalResult OperationOp::verify() {
266 bool isWithinRewrite = isa<RewriteOp>((*this)->getParentOp());
267 if (isWithinRewrite && !name())
268 return emitOpError("must have an operation name when nested within "
269 "a `pdl.rewrite`");
270 ArrayAttr attributeNames = attributeNamesAttr();
271 auto attributeValues = attributes();
272 if (attributeNames.size() != attributeValues.size()) {
273 return emitOpError()
274 << "expected the same number of attribute values and attribute "
275 "names, got "
276 << attributeNames.size() << " names and " << attributeValues.size()
277 << " values";
278 }
279
280 // If the operation is within a rewrite body and doesn't have type inference,
281 // ensure that the result types can be resolved.
282 if (isWithinRewrite && !mightHaveTypeInference()) {
283 if (failed(verifyResultTypesAreInferrable(*this, types())))
284 return failure();
285 }
286
287 return verifyHasBindingUse(*this);
288 }
289
hasTypeInference()290 bool OperationOp::hasTypeInference() {
291 if (Optional<StringRef> rawOpName = name()) {
292 OperationName opName(*rawOpName, getContext());
293 return opName.hasInterface<InferTypeOpInterface>();
294 }
295 return false;
296 }
297
mightHaveTypeInference()298 bool OperationOp::mightHaveTypeInference() {
299 if (Optional<StringRef> rawOpName = name()) {
300 OperationName opName(*rawOpName, getContext());
301 return opName.mightHaveInterface<InferTypeOpInterface>();
302 }
303 return false;
304 }
305
306 //===----------------------------------------------------------------------===//
307 // pdl::PatternOp
308 //===----------------------------------------------------------------------===//
309
verifyRegions()310 LogicalResult PatternOp::verifyRegions() {
311 Region &body = getBodyRegion();
312 Operation *term = body.front().getTerminator();
313 auto rewriteOp = dyn_cast<RewriteOp>(term);
314 if (!rewriteOp) {
315 return emitOpError("expected body to terminate with `pdl.rewrite`")
316 .attachNote(term->getLoc())
317 .append("see terminator defined here");
318 }
319
320 // Check that all values defined in the top-level pattern belong to the PDL
321 // dialect.
322 WalkResult result = body.walk([&](Operation *op) -> WalkResult {
323 if (!isa_and_nonnull<PDLDialect>(op->getDialect())) {
324 emitOpError("expected only `pdl` operations within the pattern body")
325 .attachNote(op->getLoc())
326 .append("see non-`pdl` operation defined here");
327 return WalkResult::interrupt();
328 }
329 return WalkResult::advance();
330 });
331 if (result.wasInterrupted())
332 return failure();
333
334 // Check that there is at least one operation.
335 if (body.front().getOps<OperationOp>().empty())
336 return emitOpError("the pattern must contain at least one `pdl.operation`");
337
338 // Determine if the operations within the pdl.pattern form a connected
339 // component. This is determined by starting the search from the first
340 // operand/result/operation and visiting their users / parents / operands.
341 // We limit our attention to operations that have a user in pdl.rewrite,
342 // those that do not will be detected via other means (expected bindable
343 // user).
344 bool first = true;
345 DenseSet<Operation *> visited;
346 for (Operation &op : body.front()) {
347 // The following are the operations forming the connected component.
348 if (!isa<OperandOp, OperandsOp, ResultOp, ResultsOp, OperationOp>(op))
349 continue;
350
351 // Determine if the operation has a user in `pdl.rewrite`.
352 bool hasUserInRewrite = false;
353 for (Operation *user : op.getUsers()) {
354 Region *region = user->getParentRegion();
355 if (isa<RewriteOp>(user) ||
356 (region && isa<RewriteOp>(region->getParentOp()))) {
357 hasUserInRewrite = true;
358 break;
359 }
360 }
361
362 // If the operation does not have a user in `pdl.rewrite`, ignore it.
363 if (!hasUserInRewrite)
364 continue;
365
366 if (first) {
367 // For the first operation, invoke visit.
368 visit(&op, visited);
369 first = false;
370 } else if (!visited.count(&op)) {
371 // For the subsequent operations, check if already visited.
372 return emitOpError("the operations must form a connected component")
373 .attachNote(op.getLoc())
374 .append("see a disconnected value / operation here");
375 }
376 }
377
378 return success();
379 }
380
build(OpBuilder & builder,OperationState & state,Optional<uint16_t> benefit,Optional<StringRef> name)381 void PatternOp::build(OpBuilder &builder, OperationState &state,
382 Optional<uint16_t> benefit, Optional<StringRef> name) {
383 build(builder, state, builder.getI16IntegerAttr(benefit ? *benefit : 0),
384 name ? builder.getStringAttr(*name) : StringAttr());
385 state.regions[0]->emplaceBlock();
386 }
387
388 /// Returns the rewrite operation of this pattern.
getRewriter()389 RewriteOp PatternOp::getRewriter() {
390 return cast<RewriteOp>(body().front().getTerminator());
391 }
392
393 /// The default dialect is `pdl`.
getDefaultDialect()394 StringRef PatternOp::getDefaultDialect() {
395 return PDLDialect::getDialectNamespace();
396 }
397
398 //===----------------------------------------------------------------------===//
399 // pdl::ReplaceOp
400 //===----------------------------------------------------------------------===//
401
verify()402 LogicalResult ReplaceOp::verify() {
403 if (replOperation() && !replValues().empty())
404 return emitOpError() << "expected no replacement values to be provided"
405 " when the replacement operation is present";
406 return success();
407 }
408
409 //===----------------------------------------------------------------------===//
410 // pdl::ResultsOp
411 //===----------------------------------------------------------------------===//
412
parseResultsValueType(OpAsmParser & p,IntegerAttr index,Type & resultType)413 static ParseResult parseResultsValueType(OpAsmParser &p, IntegerAttr index,
414 Type &resultType) {
415 if (!index) {
416 resultType = RangeType::get(p.getBuilder().getType<ValueType>());
417 return success();
418 }
419 if (p.parseArrow() || p.parseType(resultType))
420 return failure();
421 return success();
422 }
423
printResultsValueType(OpAsmPrinter & p,ResultsOp op,IntegerAttr index,Type resultType)424 static void printResultsValueType(OpAsmPrinter &p, ResultsOp op,
425 IntegerAttr index, Type resultType) {
426 if (index)
427 p << " -> " << resultType;
428 }
429
verify()430 LogicalResult ResultsOp::verify() {
431 if (!index() && getType().isa<pdl::ValueType>()) {
432 return emitOpError() << "expected `pdl.range<value>` result type when "
433 "no index is specified, but got: "
434 << getType();
435 }
436 return success();
437 }
438
439 //===----------------------------------------------------------------------===//
440 // pdl::RewriteOp
441 //===----------------------------------------------------------------------===//
442
verifyRegions()443 LogicalResult RewriteOp::verifyRegions() {
444 Region &rewriteRegion = body();
445
446 // Handle the case where the rewrite is external.
447 if (name()) {
448 if (!rewriteRegion.empty()) {
449 return emitOpError()
450 << "expected rewrite region to be empty when rewrite is external";
451 }
452 return success();
453 }
454
455 // Otherwise, check that the rewrite region only contains a single block.
456 if (rewriteRegion.empty()) {
457 return emitOpError() << "expected rewrite region to be non-empty if "
458 "external name is not specified";
459 }
460
461 // Check that no additional arguments were provided.
462 if (!externalArgs().empty()) {
463 return emitOpError() << "expected no external arguments when the "
464 "rewrite is specified inline";
465 }
466
467 return success();
468 }
469
470 /// The default dialect is `pdl`.
getDefaultDialect()471 StringRef RewriteOp::getDefaultDialect() {
472 return PDLDialect::getDialectNamespace();
473 }
474
475 //===----------------------------------------------------------------------===//
476 // pdl::TypeOp
477 //===----------------------------------------------------------------------===//
478
verify()479 LogicalResult TypeOp::verify() {
480 if (!typeAttr())
481 return verifyHasBindingUse(*this);
482 return success();
483 }
484
485 //===----------------------------------------------------------------------===//
486 // pdl::TypesOp
487 //===----------------------------------------------------------------------===//
488
verify()489 LogicalResult TypesOp::verify() {
490 if (!typesAttr())
491 return verifyHasBindingUse(*this);
492 return success();
493 }
494
495 //===----------------------------------------------------------------------===//
496 // TableGen'd op method definitions
497 //===----------------------------------------------------------------------===//
498
499 #define GET_OP_CLASSES
500 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
501