1 //===- PDL.cpp - Pattern Descriptor Language Dialect ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/PDL/IR/PDL.h" 10 #include "mlir/Dialect/PDL/IR/PDLOps.h" 11 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 12 #include "mlir/IR/BuiltinTypes.h" 13 #include "mlir/Interfaces/InferTypeOpInterface.h" 14 #include "llvm/ADT/DenseSet.h" 15 #include "llvm/ADT/TypeSwitch.h" 16 17 using namespace mlir; 18 using namespace mlir::pdl; 19 20 #include "mlir/Dialect/PDL/IR/PDLOpsDialect.cpp.inc" 21 22 //===----------------------------------------------------------------------===// 23 // PDLDialect 24 //===----------------------------------------------------------------------===// 25 26 void PDLDialect::initialize() { 27 addOperations< 28 #define GET_OP_LIST 29 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc" 30 >(); 31 registerTypes(); 32 } 33 34 //===----------------------------------------------------------------------===// 35 // PDL Operations 36 //===----------------------------------------------------------------------===// 37 38 /// Returns true if the given operation is used by a "binding" pdl operation. 39 static bool hasBindingUse(Operation *op) { 40 for (Operation *user : op->getUsers()) 41 // A result by itself is not binding, it must also be bound. 42 if (!isa<ResultOp, ResultsOp>(user) || hasBindingUse(user)) 43 return true; 44 return false; 45 } 46 47 /// Returns success if the given operation is not in the main matcher body or 48 /// is used by a "binding" operation. On failure, emits an error. 49 static LogicalResult verifyHasBindingUse(Operation *op) { 50 // If the parent is not a pattern, there is nothing to do. 51 if (!isa<PatternOp>(op->getParentOp())) 52 return success(); 53 if (hasBindingUse(op)) 54 return success(); 55 return op->emitOpError( 56 "expected a bindable user when defined in the matcher body of a " 57 "`pdl.pattern`"); 58 } 59 60 /// Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) 61 /// connected to the given operation. 62 static void visit(Operation *op, DenseSet<Operation *> &visited) { 63 // If the parent is not a pattern, there is nothing to do. 64 if (!isa<PatternOp>(op->getParentOp()) || isa<RewriteOp>(op)) 65 return; 66 67 // Ignore if already visited. 68 if (visited.contains(op)) 69 return; 70 71 // Mark as visited. 72 visited.insert(op); 73 74 // Traverse the operands / parent. 75 TypeSwitch<Operation *>(op) 76 .Case<OperationOp>([&visited](auto operation) { 77 for (Value operand : operation.operands()) 78 visit(operand.getDefiningOp(), visited); 79 }) 80 .Case<ResultOp, ResultsOp>([&visited](auto result) { 81 visit(result.parent().getDefiningOp(), visited); 82 }); 83 84 // Traverse the users. 85 for (Operation *user : op->getUsers()) 86 visit(user, visited); 87 } 88 89 //===----------------------------------------------------------------------===// 90 // pdl::ApplyNativeConstraintOp 91 //===----------------------------------------------------------------------===// 92 93 LogicalResult ApplyNativeConstraintOp::verify() { 94 if (getNumOperands() == 0) 95 return emitOpError("expected at least one argument"); 96 return success(); 97 } 98 99 //===----------------------------------------------------------------------===// 100 // pdl::ApplyNativeRewriteOp 101 //===----------------------------------------------------------------------===// 102 103 LogicalResult ApplyNativeRewriteOp::verify() { 104 if (getNumOperands() == 0 && getNumResults() == 0) 105 return emitOpError("expected at least one argument or result"); 106 return success(); 107 } 108 109 //===----------------------------------------------------------------------===// 110 // pdl::AttributeOp 111 //===----------------------------------------------------------------------===// 112 113 LogicalResult AttributeOp::verify() { 114 Value attrType = type(); 115 Optional<Attribute> attrValue = value(); 116 117 if (!attrValue) { 118 if (isa<RewriteOp>((*this)->getParentOp())) 119 return emitOpError( 120 "expected constant value when specified within a `pdl.rewrite`"); 121 return verifyHasBindingUse(*this); 122 } 123 if (attrType) 124 return emitOpError("expected only one of [`type`, `value`] to be set"); 125 return success(); 126 } 127 128 //===----------------------------------------------------------------------===// 129 // pdl::OperandOp 130 //===----------------------------------------------------------------------===// 131 132 LogicalResult OperandOp::verify() { return verifyHasBindingUse(*this); } 133 134 //===----------------------------------------------------------------------===// 135 // pdl::OperandsOp 136 //===----------------------------------------------------------------------===// 137 138 LogicalResult OperandsOp::verify() { return verifyHasBindingUse(*this); } 139 140 //===----------------------------------------------------------------------===// 141 // pdl::OperationOp 142 //===----------------------------------------------------------------------===// 143 144 static ParseResult parseOperationOpAttributes( 145 OpAsmParser &p, SmallVectorImpl<OpAsmParser::OperandType> &attrOperands, 146 ArrayAttr &attrNamesAttr) { 147 Builder &builder = p.getBuilder(); 148 SmallVector<Attribute, 4> attrNames; 149 if (succeeded(p.parseOptionalLBrace())) { 150 do { 151 StringAttr nameAttr; 152 OpAsmParser::OperandType operand; 153 if (p.parseAttribute(nameAttr) || p.parseEqual() || 154 p.parseOperand(operand)) 155 return failure(); 156 attrNames.push_back(nameAttr); 157 attrOperands.push_back(operand); 158 } while (succeeded(p.parseOptionalComma())); 159 if (p.parseRBrace()) 160 return failure(); 161 } 162 attrNamesAttr = builder.getArrayAttr(attrNames); 163 return success(); 164 } 165 166 static void printOperationOpAttributes(OpAsmPrinter &p, OperationOp op, 167 OperandRange attrArgs, 168 ArrayAttr attrNames) { 169 if (attrNames.empty()) 170 return; 171 p << " {"; 172 interleaveComma(llvm::seq<int>(0, attrNames.size()), p, 173 [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; }); 174 p << '}'; 175 } 176 177 /// Verifies that the result types of this operation, defined within a 178 /// `pdl.rewrite`, can be inferred. 179 static LogicalResult verifyResultTypesAreInferrable(OperationOp op, 180 OperandRange resultTypes) { 181 // Functor that returns if the given use can be used to infer a type. 182 Block *rewriterBlock = op->getBlock(); 183 auto canInferTypeFromUse = [&](OpOperand &use) { 184 // If the use is within a ReplaceOp and isn't the operation being replaced 185 // (i.e. is not the first operand of the replacement), we can infer a type. 186 ReplaceOp replOpUser = dyn_cast<ReplaceOp>(use.getOwner()); 187 if (!replOpUser || use.getOperandNumber() == 0) 188 return false; 189 // Make sure the replaced operation was defined before this one. 190 Operation *replacedOp = replOpUser.operation().getDefiningOp(); 191 return replacedOp->getBlock() != rewriterBlock || 192 replacedOp->isBeforeInBlock(op); 193 }; 194 195 // Check to see if the uses of the operation itself can be used to infer 196 // types. 197 if (llvm::any_of(op.op().getUses(), canInferTypeFromUse)) 198 return success(); 199 200 // Otherwise, make sure each of the types can be inferred. 201 for (const auto &it : llvm::enumerate(resultTypes)) { 202 Operation *resultTypeOp = it.value().getDefiningOp(); 203 assert(resultTypeOp && "expected valid result type operation"); 204 205 // If the op was defined by a `apply_native_rewrite`, it is guaranteed to be 206 // usable. 207 if (isa<ApplyNativeRewriteOp>(resultTypeOp)) 208 continue; 209 210 // If the type operation was defined in the matcher and constrains an 211 // operand or the result of an input operation, it can be used. 212 auto constrainsInput = [rewriterBlock](Operation *user) { 213 return user->getBlock() != rewriterBlock && 214 isa<OperandOp, OperandsOp, OperationOp>(user); 215 }; 216 if (TypeOp typeOp = dyn_cast<TypeOp>(resultTypeOp)) { 217 if (typeOp.type() || llvm::any_of(typeOp->getUsers(), constrainsInput)) 218 continue; 219 } else if (TypesOp typeOp = dyn_cast<TypesOp>(resultTypeOp)) { 220 if (typeOp.types() || llvm::any_of(typeOp->getUsers(), constrainsInput)) 221 continue; 222 } 223 224 return op 225 .emitOpError("must have inferable or constrained result types when " 226 "nested within `pdl.rewrite`") 227 .attachNote() 228 .append("result type #", it.index(), " was not constrained"); 229 } 230 return success(); 231 } 232 233 LogicalResult OperationOp::verify() { 234 bool isWithinRewrite = isa<RewriteOp>((*this)->getParentOp()); 235 if (isWithinRewrite && !name()) 236 return emitOpError("must have an operation name when nested within " 237 "a `pdl.rewrite`"); 238 ArrayAttr attributeNames = attributeNamesAttr(); 239 auto attributeValues = attributes(); 240 if (attributeNames.size() != attributeValues.size()) { 241 return emitOpError() 242 << "expected the same number of attribute values and attribute " 243 "names, got " 244 << attributeNames.size() << " names and " << attributeValues.size() 245 << " values"; 246 } 247 248 // If the operation is within a rewrite body and doesn't have type inference, 249 // ensure that the result types can be resolved. 250 if (isWithinRewrite && !hasTypeInference()) { 251 if (failed(verifyResultTypesAreInferrable(*this, types()))) 252 return failure(); 253 } 254 255 return verifyHasBindingUse(*this); 256 } 257 258 bool OperationOp::hasTypeInference() { 259 Optional<StringRef> opName = name(); 260 if (!opName) 261 return false; 262 263 if (auto rInfo = RegisteredOperationName::lookup(*opName, getContext())) 264 return rInfo->hasInterface<InferTypeOpInterface>(); 265 return false; 266 } 267 268 //===----------------------------------------------------------------------===// 269 // pdl::PatternOp 270 //===----------------------------------------------------------------------===// 271 272 LogicalResult PatternOp::verifyRegions() { 273 Region &body = getBodyRegion(); 274 Operation *term = body.front().getTerminator(); 275 auto rewriteOp = dyn_cast<RewriteOp>(term); 276 if (!rewriteOp) { 277 return emitOpError("expected body to terminate with `pdl.rewrite`") 278 .attachNote(term->getLoc()) 279 .append("see terminator defined here"); 280 } 281 282 // Check that all values defined in the top-level pattern belong to the PDL 283 // dialect. 284 WalkResult result = body.walk([&](Operation *op) -> WalkResult { 285 if (!isa_and_nonnull<PDLDialect>(op->getDialect())) { 286 emitOpError("expected only `pdl` operations within the pattern body") 287 .attachNote(op->getLoc()) 288 .append("see non-`pdl` operation defined here"); 289 return WalkResult::interrupt(); 290 } 291 return WalkResult::advance(); 292 }); 293 if (result.wasInterrupted()) 294 return failure(); 295 296 // Check that there is at least one operation. 297 if (body.front().getOps<OperationOp>().empty()) 298 return emitOpError("the pattern must contain at least one `pdl.operation`"); 299 300 // Determine if the operations within the pdl.pattern form a connected 301 // component. This is determined by starting the search from the first 302 // operand/result/operation and visiting their users / parents / operands. 303 // We limit our attention to operations that have a user in pdl.rewrite, 304 // those that do not will be detected via other means (expected bindable 305 // user). 306 bool first = true; 307 DenseSet<Operation *> visited; 308 for (Operation &op : body.front()) { 309 // The following are the operations forming the connected component. 310 if (!isa<OperandOp, OperandsOp, ResultOp, ResultsOp, OperationOp>(op)) 311 continue; 312 313 // Determine if the operation has a user in `pdl.rewrite`. 314 bool hasUserInRewrite = false; 315 for (Operation *user : op.getUsers()) { 316 Region *region = user->getParentRegion(); 317 if (isa<RewriteOp>(user) || 318 (region && isa<RewriteOp>(region->getParentOp()))) { 319 hasUserInRewrite = true; 320 break; 321 } 322 } 323 324 // If the operation does not have a user in `pdl.rewrite`, ignore it. 325 if (!hasUserInRewrite) 326 continue; 327 328 if (first) { 329 // For the first operation, invoke visit. 330 visit(&op, visited); 331 first = false; 332 } else if (!visited.count(&op)) { 333 // For the subsequent operations, check if already visited. 334 return emitOpError("the operations must form a connected component") 335 .attachNote(op.getLoc()) 336 .append("see a disconnected value / operation here"); 337 } 338 } 339 340 return success(); 341 } 342 343 void PatternOp::build(OpBuilder &builder, OperationState &state, 344 Optional<uint16_t> benefit, Optional<StringRef> name) { 345 build(builder, state, builder.getI16IntegerAttr(benefit ? *benefit : 0), 346 name ? builder.getStringAttr(*name) : StringAttr()); 347 state.regions[0]->emplaceBlock(); 348 } 349 350 /// Returns the rewrite operation of this pattern. 351 RewriteOp PatternOp::getRewriter() { 352 return cast<RewriteOp>(body().front().getTerminator()); 353 } 354 355 /// The default dialect is `pdl`. 356 StringRef PatternOp::getDefaultDialect() { 357 return PDLDialect::getDialectNamespace(); 358 } 359 360 //===----------------------------------------------------------------------===// 361 // pdl::ReplaceOp 362 //===----------------------------------------------------------------------===// 363 364 LogicalResult ReplaceOp::verify() { 365 if (replOperation() && !replValues().empty()) 366 return emitOpError() << "expected no replacement values to be provided" 367 " when the replacement operation is present"; 368 return success(); 369 } 370 371 //===----------------------------------------------------------------------===// 372 // pdl::ResultsOp 373 //===----------------------------------------------------------------------===// 374 375 static ParseResult parseResultsValueType(OpAsmParser &p, IntegerAttr index, 376 Type &resultType) { 377 if (!index) { 378 resultType = RangeType::get(p.getBuilder().getType<ValueType>()); 379 return success(); 380 } 381 if (p.parseArrow() || p.parseType(resultType)) 382 return failure(); 383 return success(); 384 } 385 386 static void printResultsValueType(OpAsmPrinter &p, ResultsOp op, 387 IntegerAttr index, Type resultType) { 388 if (index) 389 p << " -> " << resultType; 390 } 391 392 LogicalResult ResultsOp::verify() { 393 if (!index() && getType().isa<pdl::ValueType>()) { 394 return emitOpError() << "expected `pdl.range<value>` result type when " 395 "no index is specified, but got: " 396 << getType(); 397 } 398 return success(); 399 } 400 401 //===----------------------------------------------------------------------===// 402 // pdl::RewriteOp 403 //===----------------------------------------------------------------------===// 404 405 LogicalResult RewriteOp::verifyRegions() { 406 Region &rewriteRegion = body(); 407 408 // Handle the case where the rewrite is external. 409 if (name()) { 410 if (!rewriteRegion.empty()) { 411 return emitOpError() 412 << "expected rewrite region to be empty when rewrite is external"; 413 } 414 return success(); 415 } 416 417 // Otherwise, check that the rewrite region only contains a single block. 418 if (rewriteRegion.empty()) { 419 return emitOpError() << "expected rewrite region to be non-empty if " 420 "external name is not specified"; 421 } 422 423 // Check that no additional arguments were provided. 424 if (!externalArgs().empty()) { 425 return emitOpError() << "expected no external arguments when the " 426 "rewrite is specified inline"; 427 } 428 if (externalConstParams()) { 429 return emitOpError() << "expected no external constant parameters when " 430 "the rewrite is specified inline"; 431 } 432 433 return success(); 434 } 435 436 /// The default dialect is `pdl`. 437 StringRef RewriteOp::getDefaultDialect() { 438 return PDLDialect::getDialectNamespace(); 439 } 440 441 //===----------------------------------------------------------------------===// 442 // pdl::TypeOp 443 //===----------------------------------------------------------------------===// 444 445 LogicalResult TypeOp::verify() { 446 if (!typeAttr()) 447 return verifyHasBindingUse(*this); 448 return success(); 449 } 450 451 //===----------------------------------------------------------------------===// 452 // pdl::TypesOp 453 //===----------------------------------------------------------------------===// 454 455 LogicalResult TypesOp::verify() { 456 if (!typesAttr()) 457 return verifyHasBindingUse(*this); 458 return success(); 459 } 460 461 //===----------------------------------------------------------------------===// 462 // TableGen'd op method definitions 463 //===----------------------------------------------------------------------===// 464 465 #define GET_OP_CLASSES 466 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc" 467