1 //===- PDL.cpp - Pattern Descriptor Language Dialect ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/PDL/IR/PDL.h" 10 #include "mlir/Dialect/PDL/IR/PDLOps.h" 11 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 12 #include "mlir/IR/BuiltinTypes.h" 13 #include "mlir/Interfaces/InferTypeOpInterface.h" 14 #include "llvm/ADT/DenseSet.h" 15 #include "llvm/ADT/TypeSwitch.h" 16 17 using namespace mlir; 18 using namespace mlir::pdl; 19 20 #include "mlir/Dialect/PDL/IR/PDLOpsDialect.cpp.inc" 21 22 //===----------------------------------------------------------------------===// 23 // PDLDialect 24 //===----------------------------------------------------------------------===// 25 26 void PDLDialect::initialize() { 27 addOperations< 28 #define GET_OP_LIST 29 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc" 30 >(); 31 registerTypes(); 32 } 33 34 //===----------------------------------------------------------------------===// 35 // PDL Operations 36 //===----------------------------------------------------------------------===// 37 38 /// Returns true if the given operation is used by a "binding" pdl operation. 39 static bool hasBindingUse(Operation *op) { 40 for (Operation *user : op->getUsers()) 41 // A result by itself is not binding, it must also be bound. 42 if (!isa<ResultOp, ResultsOp>(user) || hasBindingUse(user)) 43 return true; 44 return false; 45 } 46 47 /// Returns success if the given operation is not in the main matcher body or 48 /// is used by a "binding" operation. On failure, emits an error. 49 static LogicalResult verifyHasBindingUse(Operation *op) { 50 // If the parent is not a pattern, there is nothing to do. 51 if (!isa<PatternOp>(op->getParentOp())) 52 return success(); 53 if (hasBindingUse(op)) 54 return success(); 55 return op->emitOpError( 56 "expected a bindable user when defined in the matcher body of a " 57 "`pdl.pattern`"); 58 } 59 60 /// Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) 61 /// connected to the given operation. 62 static void visit(Operation *op, DenseSet<Operation *> &visited) { 63 // If the parent is not a pattern, there is nothing to do. 64 if (!isa<PatternOp>(op->getParentOp()) || isa<RewriteOp>(op)) 65 return; 66 67 // Ignore if already visited. 68 if (visited.contains(op)) 69 return; 70 71 // Mark as visited. 72 visited.insert(op); 73 74 // Traverse the operands / parent. 75 TypeSwitch<Operation *>(op) 76 .Case<OperationOp>([&visited](auto operation) { 77 for (Value operand : operation.operands()) 78 visit(operand.getDefiningOp(), visited); 79 }) 80 .Case<ResultOp, ResultsOp>([&visited](auto result) { 81 visit(result.parent().getDefiningOp(), visited); 82 }); 83 84 // Traverse the users. 85 for (Operation *user : op->getUsers()) 86 visit(user, visited); 87 } 88 89 //===----------------------------------------------------------------------===// 90 // pdl::ApplyNativeConstraintOp 91 //===----------------------------------------------------------------------===// 92 93 LogicalResult ApplyNativeConstraintOp::verify() { 94 if (getNumOperands() == 0) 95 return emitOpError("expected at least one argument"); 96 return success(); 97 } 98 99 //===----------------------------------------------------------------------===// 100 // pdl::ApplyNativeRewriteOp 101 //===----------------------------------------------------------------------===// 102 103 LogicalResult ApplyNativeRewriteOp::verify() { 104 if (getNumOperands() == 0 && getNumResults() == 0) 105 return emitOpError("expected at least one argument or result"); 106 return success(); 107 } 108 109 //===----------------------------------------------------------------------===// 110 // pdl::AttributeOp 111 //===----------------------------------------------------------------------===// 112 113 LogicalResult AttributeOp::verify() { 114 Value attrType = type(); 115 Optional<Attribute> attrValue = value(); 116 117 if (!attrValue) { 118 if (isa<RewriteOp>((*this)->getParentOp())) 119 return emitOpError( 120 "expected constant value when specified within a `pdl.rewrite`"); 121 return verifyHasBindingUse(*this); 122 } 123 if (attrType) 124 return emitOpError("expected only one of [`type`, `value`] to be set"); 125 return success(); 126 } 127 128 //===----------------------------------------------------------------------===// 129 // pdl::OperandOp 130 //===----------------------------------------------------------------------===// 131 132 LogicalResult OperandOp::verify() { return verifyHasBindingUse(*this); } 133 134 //===----------------------------------------------------------------------===// 135 // pdl::OperandsOp 136 //===----------------------------------------------------------------------===// 137 138 LogicalResult OperandsOp::verify() { return verifyHasBindingUse(*this); } 139 140 //===----------------------------------------------------------------------===// 141 // pdl::OperationOp 142 //===----------------------------------------------------------------------===// 143 144 static ParseResult parseOperationOpAttributes( 145 OpAsmParser &p, 146 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands, 147 ArrayAttr &attrNamesAttr) { 148 Builder &builder = p.getBuilder(); 149 SmallVector<Attribute, 4> attrNames; 150 if (succeeded(p.parseOptionalLBrace())) { 151 do { 152 StringAttr nameAttr; 153 OpAsmParser::UnresolvedOperand operand; 154 if (p.parseAttribute(nameAttr) || p.parseEqual() || 155 p.parseOperand(operand)) 156 return failure(); 157 attrNames.push_back(nameAttr); 158 attrOperands.push_back(operand); 159 } while (succeeded(p.parseOptionalComma())); 160 if (p.parseRBrace()) 161 return failure(); 162 } 163 attrNamesAttr = builder.getArrayAttr(attrNames); 164 return success(); 165 } 166 167 static void printOperationOpAttributes(OpAsmPrinter &p, OperationOp op, 168 OperandRange attrArgs, 169 ArrayAttr attrNames) { 170 if (attrNames.empty()) 171 return; 172 p << " {"; 173 interleaveComma(llvm::seq<int>(0, attrNames.size()), p, 174 [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; }); 175 p << '}'; 176 } 177 178 /// Verifies that the result types of this operation, defined within a 179 /// `pdl.rewrite`, can be inferred. 180 static LogicalResult verifyResultTypesAreInferrable(OperationOp op, 181 OperandRange resultTypes) { 182 // Functor that returns if the given use can be used to infer a type. 183 Block *rewriterBlock = op->getBlock(); 184 auto canInferTypeFromUse = [&](OpOperand &use) { 185 // If the use is within a ReplaceOp and isn't the operation being replaced 186 // (i.e. is not the first operand of the replacement), we can infer a type. 187 ReplaceOp replOpUser = dyn_cast<ReplaceOp>(use.getOwner()); 188 if (!replOpUser || use.getOperandNumber() == 0) 189 return false; 190 // Make sure the replaced operation was defined before this one. 191 Operation *replacedOp = replOpUser.operation().getDefiningOp(); 192 return replacedOp->getBlock() != rewriterBlock || 193 replacedOp->isBeforeInBlock(op); 194 }; 195 196 // Check to see if the uses of the operation itself can be used to infer 197 // types. 198 if (llvm::any_of(op.op().getUses(), canInferTypeFromUse)) 199 return success(); 200 201 // Otherwise, make sure each of the types can be inferred. 202 for (const auto &it : llvm::enumerate(resultTypes)) { 203 Operation *resultTypeOp = it.value().getDefiningOp(); 204 assert(resultTypeOp && "expected valid result type operation"); 205 206 // If the op was defined by a `apply_native_rewrite`, it is guaranteed to be 207 // usable. 208 if (isa<ApplyNativeRewriteOp>(resultTypeOp)) 209 continue; 210 211 // If the type operation was defined in the matcher and constrains an 212 // operand or the result of an input operation, it can be used. 213 auto constrainsInput = [rewriterBlock](Operation *user) { 214 return user->getBlock() != rewriterBlock && 215 isa<OperandOp, OperandsOp, OperationOp>(user); 216 }; 217 if (TypeOp typeOp = dyn_cast<TypeOp>(resultTypeOp)) { 218 if (typeOp.type() || llvm::any_of(typeOp->getUsers(), constrainsInput)) 219 continue; 220 } else if (TypesOp typeOp = dyn_cast<TypesOp>(resultTypeOp)) { 221 if (typeOp.types() || llvm::any_of(typeOp->getUsers(), constrainsInput)) 222 continue; 223 } 224 225 return op 226 .emitOpError("must have inferable or constrained result types when " 227 "nested within `pdl.rewrite`") 228 .attachNote() 229 .append("result type #", it.index(), " was not constrained"); 230 } 231 return success(); 232 } 233 234 LogicalResult OperationOp::verify() { 235 bool isWithinRewrite = isa<RewriteOp>((*this)->getParentOp()); 236 if (isWithinRewrite && !name()) 237 return emitOpError("must have an operation name when nested within " 238 "a `pdl.rewrite`"); 239 ArrayAttr attributeNames = attributeNamesAttr(); 240 auto attributeValues = attributes(); 241 if (attributeNames.size() != attributeValues.size()) { 242 return emitOpError() 243 << "expected the same number of attribute values and attribute " 244 "names, got " 245 << attributeNames.size() << " names and " << attributeValues.size() 246 << " values"; 247 } 248 249 // If the operation is within a rewrite body and doesn't have type inference, 250 // ensure that the result types can be resolved. 251 if (isWithinRewrite && !hasTypeInference()) { 252 if (failed(verifyResultTypesAreInferrable(*this, types()))) 253 return failure(); 254 } 255 256 return verifyHasBindingUse(*this); 257 } 258 259 bool OperationOp::hasTypeInference() { 260 Optional<StringRef> opName = name(); 261 if (!opName) 262 return false; 263 264 if (auto rInfo = RegisteredOperationName::lookup(*opName, getContext())) 265 return rInfo->hasInterface<InferTypeOpInterface>(); 266 return false; 267 } 268 269 //===----------------------------------------------------------------------===// 270 // pdl::PatternOp 271 //===----------------------------------------------------------------------===// 272 273 LogicalResult PatternOp::verifyRegions() { 274 Region &body = getBodyRegion(); 275 Operation *term = body.front().getTerminator(); 276 auto rewriteOp = dyn_cast<RewriteOp>(term); 277 if (!rewriteOp) { 278 return emitOpError("expected body to terminate with `pdl.rewrite`") 279 .attachNote(term->getLoc()) 280 .append("see terminator defined here"); 281 } 282 283 // Check that all values defined in the top-level pattern belong to the PDL 284 // dialect. 285 WalkResult result = body.walk([&](Operation *op) -> WalkResult { 286 if (!isa_and_nonnull<PDLDialect>(op->getDialect())) { 287 emitOpError("expected only `pdl` operations within the pattern body") 288 .attachNote(op->getLoc()) 289 .append("see non-`pdl` operation defined here"); 290 return WalkResult::interrupt(); 291 } 292 return WalkResult::advance(); 293 }); 294 if (result.wasInterrupted()) 295 return failure(); 296 297 // Check that there is at least one operation. 298 if (body.front().getOps<OperationOp>().empty()) 299 return emitOpError("the pattern must contain at least one `pdl.operation`"); 300 301 // Determine if the operations within the pdl.pattern form a connected 302 // component. This is determined by starting the search from the first 303 // operand/result/operation and visiting their users / parents / operands. 304 // We limit our attention to operations that have a user in pdl.rewrite, 305 // those that do not will be detected via other means (expected bindable 306 // user). 307 bool first = true; 308 DenseSet<Operation *> visited; 309 for (Operation &op : body.front()) { 310 // The following are the operations forming the connected component. 311 if (!isa<OperandOp, OperandsOp, ResultOp, ResultsOp, OperationOp>(op)) 312 continue; 313 314 // Determine if the operation has a user in `pdl.rewrite`. 315 bool hasUserInRewrite = false; 316 for (Operation *user : op.getUsers()) { 317 Region *region = user->getParentRegion(); 318 if (isa<RewriteOp>(user) || 319 (region && isa<RewriteOp>(region->getParentOp()))) { 320 hasUserInRewrite = true; 321 break; 322 } 323 } 324 325 // If the operation does not have a user in `pdl.rewrite`, ignore it. 326 if (!hasUserInRewrite) 327 continue; 328 329 if (first) { 330 // For the first operation, invoke visit. 331 visit(&op, visited); 332 first = false; 333 } else if (!visited.count(&op)) { 334 // For the subsequent operations, check if already visited. 335 return emitOpError("the operations must form a connected component") 336 .attachNote(op.getLoc()) 337 .append("see a disconnected value / operation here"); 338 } 339 } 340 341 return success(); 342 } 343 344 void PatternOp::build(OpBuilder &builder, OperationState &state, 345 Optional<uint16_t> benefit, Optional<StringRef> name) { 346 build(builder, state, builder.getI16IntegerAttr(benefit ? *benefit : 0), 347 name ? builder.getStringAttr(*name) : StringAttr()); 348 state.regions[0]->emplaceBlock(); 349 } 350 351 /// Returns the rewrite operation of this pattern. 352 RewriteOp PatternOp::getRewriter() { 353 return cast<RewriteOp>(body().front().getTerminator()); 354 } 355 356 /// The default dialect is `pdl`. 357 StringRef PatternOp::getDefaultDialect() { 358 return PDLDialect::getDialectNamespace(); 359 } 360 361 //===----------------------------------------------------------------------===// 362 // pdl::ReplaceOp 363 //===----------------------------------------------------------------------===// 364 365 LogicalResult ReplaceOp::verify() { 366 if (replOperation() && !replValues().empty()) 367 return emitOpError() << "expected no replacement values to be provided" 368 " when the replacement operation is present"; 369 return success(); 370 } 371 372 //===----------------------------------------------------------------------===// 373 // pdl::ResultsOp 374 //===----------------------------------------------------------------------===// 375 376 static ParseResult parseResultsValueType(OpAsmParser &p, IntegerAttr index, 377 Type &resultType) { 378 if (!index) { 379 resultType = RangeType::get(p.getBuilder().getType<ValueType>()); 380 return success(); 381 } 382 if (p.parseArrow() || p.parseType(resultType)) 383 return failure(); 384 return success(); 385 } 386 387 static void printResultsValueType(OpAsmPrinter &p, ResultsOp op, 388 IntegerAttr index, Type resultType) { 389 if (index) 390 p << " -> " << resultType; 391 } 392 393 LogicalResult ResultsOp::verify() { 394 if (!index() && getType().isa<pdl::ValueType>()) { 395 return emitOpError() << "expected `pdl.range<value>` result type when " 396 "no index is specified, but got: " 397 << getType(); 398 } 399 return success(); 400 } 401 402 //===----------------------------------------------------------------------===// 403 // pdl::RewriteOp 404 //===----------------------------------------------------------------------===// 405 406 LogicalResult RewriteOp::verifyRegions() { 407 Region &rewriteRegion = body(); 408 409 // Handle the case where the rewrite is external. 410 if (name()) { 411 if (!rewriteRegion.empty()) { 412 return emitOpError() 413 << "expected rewrite region to be empty when rewrite is external"; 414 } 415 return success(); 416 } 417 418 // Otherwise, check that the rewrite region only contains a single block. 419 if (rewriteRegion.empty()) { 420 return emitOpError() << "expected rewrite region to be non-empty if " 421 "external name is not specified"; 422 } 423 424 // Check that no additional arguments were provided. 425 if (!externalArgs().empty()) { 426 return emitOpError() << "expected no external arguments when the " 427 "rewrite is specified inline"; 428 } 429 430 return success(); 431 } 432 433 /// The default dialect is `pdl`. 434 StringRef RewriteOp::getDefaultDialect() { 435 return PDLDialect::getDialectNamespace(); 436 } 437 438 //===----------------------------------------------------------------------===// 439 // pdl::TypeOp 440 //===----------------------------------------------------------------------===// 441 442 LogicalResult TypeOp::verify() { 443 if (!typeAttr()) 444 return verifyHasBindingUse(*this); 445 return success(); 446 } 447 448 //===----------------------------------------------------------------------===// 449 // pdl::TypesOp 450 //===----------------------------------------------------------------------===// 451 452 LogicalResult TypesOp::verify() { 453 if (!typesAttr()) 454 return verifyHasBindingUse(*this); 455 return success(); 456 } 457 458 //===----------------------------------------------------------------------===// 459 // TableGen'd op method definitions 460 //===----------------------------------------------------------------------===// 461 462 #define GET_OP_CLASSES 463 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc" 464