1 //===- PDL.cpp - Pattern Descriptor Language Dialect ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/PDL/IR/PDL.h" 10 #include "mlir/Dialect/PDL/IR/PDLOps.h" 11 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 12 #include "mlir/IR/BuiltinTypes.h" 13 #include "mlir/Interfaces/InferTypeOpInterface.h" 14 #include "llvm/ADT/DenseSet.h" 15 #include "llvm/ADT/TypeSwitch.h" 16 17 using namespace mlir; 18 using namespace mlir::pdl; 19 20 #include "mlir/Dialect/PDL/IR/PDLOpsDialect.cpp.inc" 21 22 //===----------------------------------------------------------------------===// 23 // PDLDialect 24 //===----------------------------------------------------------------------===// 25 26 void PDLDialect::initialize() { 27 addOperations< 28 #define GET_OP_LIST 29 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc" 30 >(); 31 registerTypes(); 32 } 33 34 //===----------------------------------------------------------------------===// 35 // PDL Operations 36 //===----------------------------------------------------------------------===// 37 38 /// Returns true if the given operation is used by a "binding" pdl operation. 39 static bool hasBindingUse(Operation *op) { 40 for (Operation *user : op->getUsers()) 41 // A result by itself is not binding, it must also be bound. 42 if (!isa<ResultOp, ResultsOp>(user) || hasBindingUse(user)) 43 return true; 44 return false; 45 } 46 47 /// Returns success if the given operation is not in the main matcher body or 48 /// is used by a "binding" operation. On failure, emits an error. 49 static LogicalResult verifyHasBindingUse(Operation *op) { 50 // If the parent is not a pattern, there is nothing to do. 51 if (!isa<PatternOp>(op->getParentOp())) 52 return success(); 53 if (hasBindingUse(op)) 54 return success(); 55 return op->emitOpError( 56 "expected a bindable user when defined in the matcher body of a " 57 "`pdl.pattern`"); 58 } 59 60 /// Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) 61 /// connected to the given operation. 62 static void visit(Operation *op, DenseSet<Operation *> &visited) { 63 // If the parent is not a pattern, there is nothing to do. 64 if (!isa<PatternOp>(op->getParentOp()) || isa<RewriteOp>(op)) 65 return; 66 67 // Ignore if already visited. 68 if (visited.contains(op)) 69 return; 70 71 // Mark as visited. 72 visited.insert(op); 73 74 // Traverse the operands / parent. 75 TypeSwitch<Operation *>(op) 76 .Case<OperationOp>([&visited](auto operation) { 77 for (Value operand : operation.operands()) 78 visit(operand.getDefiningOp(), visited); 79 }) 80 .Case<ResultOp, ResultsOp>([&visited](auto result) { 81 visit(result.parent().getDefiningOp(), visited); 82 }); 83 84 // Traverse the users. 85 for (Operation *user : op->getUsers()) 86 visit(user, visited); 87 } 88 89 //===----------------------------------------------------------------------===// 90 // pdl::ApplyNativeConstraintOp 91 //===----------------------------------------------------------------------===// 92 93 static LogicalResult verify(ApplyNativeConstraintOp op) { 94 if (op.getNumOperands() == 0) 95 return op.emitOpError("expected at least one argument"); 96 return success(); 97 } 98 99 //===----------------------------------------------------------------------===// 100 // pdl::ApplyNativeRewriteOp 101 //===----------------------------------------------------------------------===// 102 103 static LogicalResult verify(ApplyNativeRewriteOp op) { 104 if (op.getNumOperands() == 0 && op.getNumResults() == 0) 105 return op.emitOpError("expected at least one argument or result"); 106 return success(); 107 } 108 109 //===----------------------------------------------------------------------===// 110 // pdl::AttributeOp 111 //===----------------------------------------------------------------------===// 112 113 static LogicalResult verify(AttributeOp op) { 114 Value attrType = op.type(); 115 Optional<Attribute> attrValue = op.value(); 116 117 if (!attrValue) { 118 if (isa<RewriteOp>(op->getParentOp())) 119 return op.emitOpError("expected constant value when specified within a " 120 "`pdl.rewrite`"); 121 return verifyHasBindingUse(op); 122 } 123 if (attrType) 124 return op.emitOpError("expected only one of [`type`, `value`] to be set"); 125 return success(); 126 } 127 128 //===----------------------------------------------------------------------===// 129 // pdl::OperandOp 130 //===----------------------------------------------------------------------===// 131 132 static LogicalResult verify(OperandOp op) { return verifyHasBindingUse(op); } 133 134 //===----------------------------------------------------------------------===// 135 // pdl::OperandsOp 136 //===----------------------------------------------------------------------===// 137 138 static LogicalResult verify(OperandsOp op) { return verifyHasBindingUse(op); } 139 140 //===----------------------------------------------------------------------===// 141 // pdl::OperationOp 142 //===----------------------------------------------------------------------===// 143 144 static ParseResult parseOperationOpAttributes( 145 OpAsmParser &p, SmallVectorImpl<OpAsmParser::OperandType> &attrOperands, 146 ArrayAttr &attrNamesAttr) { 147 Builder &builder = p.getBuilder(); 148 SmallVector<Attribute, 4> attrNames; 149 if (succeeded(p.parseOptionalLBrace())) { 150 do { 151 StringAttr nameAttr; 152 OpAsmParser::OperandType operand; 153 if (p.parseAttribute(nameAttr) || p.parseEqual() || 154 p.parseOperand(operand)) 155 return failure(); 156 attrNames.push_back(nameAttr); 157 attrOperands.push_back(operand); 158 } while (succeeded(p.parseOptionalComma())); 159 if (p.parseRBrace()) 160 return failure(); 161 } 162 attrNamesAttr = builder.getArrayAttr(attrNames); 163 return success(); 164 } 165 166 static void printOperationOpAttributes(OpAsmPrinter &p, OperationOp op, 167 OperandRange attrArgs, 168 ArrayAttr attrNames) { 169 if (attrNames.empty()) 170 return; 171 p << " {"; 172 interleaveComma(llvm::seq<int>(0, attrNames.size()), p, 173 [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; }); 174 p << '}'; 175 } 176 177 /// Verifies that the result types of this operation, defined within a 178 /// `pdl.rewrite`, can be inferred. 179 static LogicalResult verifyResultTypesAreInferrable(OperationOp op, 180 OperandRange resultTypes) { 181 // Functor that returns if the given use can be used to infer a type. 182 Block *rewriterBlock = op->getBlock(); 183 auto canInferTypeFromUse = [&](OpOperand &use) { 184 // If the use is within a ReplaceOp and isn't the operation being replaced 185 // (i.e. is not the first operand of the replacement), we can infer a type. 186 ReplaceOp replOpUser = dyn_cast<ReplaceOp>(use.getOwner()); 187 if (!replOpUser || use.getOperandNumber() == 0) 188 return false; 189 // Make sure the replaced operation was defined before this one. 190 Operation *replacedOp = replOpUser.operation().getDefiningOp(); 191 return replacedOp->getBlock() != rewriterBlock || 192 replacedOp->isBeforeInBlock(op); 193 }; 194 195 // Check to see if the uses of the operation itself can be used to infer 196 // types. 197 if (llvm::any_of(op.op().getUses(), canInferTypeFromUse)) 198 return success(); 199 200 // Otherwise, make sure each of the types can be inferred. 201 for (const auto &it : llvm::enumerate(resultTypes)) { 202 Operation *resultTypeOp = it.value().getDefiningOp(); 203 assert(resultTypeOp && "expected valid result type operation"); 204 205 // If the op was defined by a `apply_native_rewrite`, it is guaranteed to be 206 // usable. 207 if (isa<ApplyNativeRewriteOp>(resultTypeOp)) 208 continue; 209 210 // If the type operation was defined in the matcher and constrains an 211 // operand or the result of an input operation, it can be used. 212 auto constrainsInput = [rewriterBlock](Operation *user) { 213 return user->getBlock() != rewriterBlock && 214 isa<OperandOp, OperandsOp, OperationOp>(user); 215 }; 216 if (TypeOp typeOp = dyn_cast<TypeOp>(resultTypeOp)) { 217 if (typeOp.type() || llvm::any_of(typeOp->getUsers(), constrainsInput)) 218 continue; 219 } else if (TypesOp typeOp = dyn_cast<TypesOp>(resultTypeOp)) { 220 if (typeOp.types() || llvm::any_of(typeOp->getUsers(), constrainsInput)) 221 continue; 222 } 223 224 return op 225 .emitOpError("must have inferable or constrained result types when " 226 "nested within `pdl.rewrite`") 227 .attachNote() 228 .append("result type #", it.index(), " was not constrained"); 229 } 230 return success(); 231 } 232 233 static LogicalResult verify(OperationOp op) { 234 bool isWithinRewrite = isa<RewriteOp>(op->getParentOp()); 235 if (isWithinRewrite && !op.name()) 236 return op.emitOpError("must have an operation name when nested within " 237 "a `pdl.rewrite`"); 238 ArrayAttr attributeNames = op.attributeNames(); 239 auto attributeValues = op.attributes(); 240 if (attributeNames.size() != attributeValues.size()) { 241 return op.emitOpError() 242 << "expected the same number of attribute values and attribute " 243 "names, got " 244 << attributeNames.size() << " names and " << attributeValues.size() 245 << " values"; 246 } 247 248 // If the operation is within a rewrite body and doesn't have type inference, 249 // ensure that the result types can be resolved. 250 if (isWithinRewrite && !op.hasTypeInference()) { 251 if (failed(verifyResultTypesAreInferrable(op, op.types()))) 252 return failure(); 253 } 254 255 return verifyHasBindingUse(op); 256 } 257 258 bool OperationOp::hasTypeInference() { 259 Optional<StringRef> opName = name(); 260 if (!opName) 261 return false; 262 263 if (auto rInfo = RegisteredOperationName::lookup(*opName, getContext())) 264 return rInfo->hasInterface<InferTypeOpInterface>(); 265 return false; 266 } 267 268 //===----------------------------------------------------------------------===// 269 // pdl::PatternOp 270 //===----------------------------------------------------------------------===// 271 272 static LogicalResult verify(PatternOp pattern) { 273 Region &body = pattern.body(); 274 Operation *term = body.front().getTerminator(); 275 auto rewriteOp = dyn_cast<RewriteOp>(term); 276 if (!rewriteOp) { 277 return pattern.emitOpError("expected body to terminate with `pdl.rewrite`") 278 .attachNote(term->getLoc()) 279 .append("see terminator defined here"); 280 } 281 282 // Check that all values defined in the top-level pattern belong to the PDL 283 // dialect. 284 WalkResult result = body.walk([&](Operation *op) -> WalkResult { 285 if (!isa_and_nonnull<PDLDialect>(op->getDialect())) { 286 pattern 287 .emitOpError("expected only `pdl` operations within the pattern body") 288 .attachNote(op->getLoc()) 289 .append("see non-`pdl` operation defined here"); 290 return WalkResult::interrupt(); 291 } 292 return WalkResult::advance(); 293 }); 294 if (result.wasInterrupted()) 295 return failure(); 296 297 // Check that there is at least one operation. 298 if (body.front().getOps<OperationOp>().empty()) 299 return pattern.emitOpError( 300 "the pattern must contain at least one `pdl.operation`"); 301 302 // Determine if the operations within the pdl.pattern form a connected 303 // component. This is determined by starting the search from the first 304 // operand/result/operation and visiting their users / parents / operands. 305 // We limit our attention to operations that have a user in pdl.rewrite, 306 // those that do not will be detected via other means (expected bindable 307 // user). 308 bool first = true; 309 DenseSet<Operation *> visited; 310 for (Operation &op : body.front()) { 311 // The following are the operations forming the connected component. 312 if (!isa<OperandOp, OperandsOp, ResultOp, ResultsOp, OperationOp>(op)) 313 continue; 314 315 // Determine if the operation has a user in `pdl.rewrite`. 316 bool hasUserInRewrite = false; 317 for (Operation *user : op.getUsers()) { 318 Region *region = user->getParentRegion(); 319 if (isa<RewriteOp>(user) || 320 (region && isa<RewriteOp>(region->getParentOp()))) { 321 hasUserInRewrite = true; 322 break; 323 } 324 } 325 326 // If the operation does not have a user in `pdl.rewrite`, ignore it. 327 if (!hasUserInRewrite) 328 continue; 329 330 if (first) { 331 // For the first operation, invoke visit. 332 visit(&op, visited); 333 first = false; 334 } else if (!visited.count(&op)) { 335 // For the subsequent operations, check if already visited. 336 return pattern 337 .emitOpError("the operations must form a connected component") 338 .attachNote(op.getLoc()) 339 .append("see a disconnected value / operation here"); 340 } 341 } 342 343 return success(); 344 } 345 346 void PatternOp::build(OpBuilder &builder, OperationState &state, 347 Optional<uint16_t> benefit, Optional<StringRef> name) { 348 build(builder, state, builder.getI16IntegerAttr(benefit ? *benefit : 0), 349 name ? builder.getStringAttr(*name) : StringAttr()); 350 state.regions[0]->emplaceBlock(); 351 } 352 353 /// Returns the rewrite operation of this pattern. 354 RewriteOp PatternOp::getRewriter() { 355 return cast<RewriteOp>(body().front().getTerminator()); 356 } 357 358 //===----------------------------------------------------------------------===// 359 // pdl::ReplaceOp 360 //===----------------------------------------------------------------------===// 361 362 static LogicalResult verify(ReplaceOp op) { 363 if (op.replOperation() && !op.replValues().empty()) 364 return op.emitOpError() << "expected no replacement values to be provided" 365 " when the replacement operation is present"; 366 return success(); 367 } 368 369 //===----------------------------------------------------------------------===// 370 // pdl::ResultsOp 371 //===----------------------------------------------------------------------===// 372 373 static ParseResult parseResultsValueType(OpAsmParser &p, IntegerAttr index, 374 Type &resultType) { 375 if (!index) { 376 resultType = RangeType::get(p.getBuilder().getType<ValueType>()); 377 return success(); 378 } 379 if (p.parseArrow() || p.parseType(resultType)) 380 return failure(); 381 return success(); 382 } 383 384 static void printResultsValueType(OpAsmPrinter &p, ResultsOp op, 385 IntegerAttr index, Type resultType) { 386 if (index) 387 p << " -> " << resultType; 388 } 389 390 static LogicalResult verify(ResultsOp op) { 391 if (!op.index() && op.getType().isa<pdl::ValueType>()) { 392 return op.emitOpError() << "expected `pdl.range<value>` result type when " 393 "no index is specified, but got: " 394 << op.getType(); 395 } 396 return success(); 397 } 398 399 //===----------------------------------------------------------------------===// 400 // pdl::RewriteOp 401 //===----------------------------------------------------------------------===// 402 403 static LogicalResult verify(RewriteOp op) { 404 Region &rewriteRegion = op.body(); 405 406 // Handle the case where the rewrite is external. 407 if (op.name()) { 408 if (!rewriteRegion.empty()) { 409 return op.emitOpError() 410 << "expected rewrite region to be empty when rewrite is external"; 411 } 412 return success(); 413 } 414 415 // Otherwise, check that the rewrite region only contains a single block. 416 if (rewriteRegion.empty()) { 417 return op.emitOpError() << "expected rewrite region to be non-empty if " 418 "external name is not specified"; 419 } 420 421 // Check that no additional arguments were provided. 422 if (!op.externalArgs().empty()) { 423 return op.emitOpError() << "expected no external arguments when the " 424 "rewrite is specified inline"; 425 } 426 if (op.externalConstParams()) { 427 return op.emitOpError() << "expected no external constant parameters when " 428 "the rewrite is specified inline"; 429 } 430 431 return success(); 432 } 433 434 //===----------------------------------------------------------------------===// 435 // pdl::TypeOp 436 //===----------------------------------------------------------------------===// 437 438 static LogicalResult verify(TypeOp op) { 439 if (!op.typeAttr()) 440 return verifyHasBindingUse(op); 441 return success(); 442 } 443 444 //===----------------------------------------------------------------------===// 445 // pdl::TypesOp 446 //===----------------------------------------------------------------------===// 447 448 static LogicalResult verify(TypesOp op) { 449 if (!op.typesAttr()) 450 return verifyHasBindingUse(op); 451 return success(); 452 } 453 454 //===----------------------------------------------------------------------===// 455 // TableGen'd op method definitions 456 //===----------------------------------------------------------------------===// 457 458 #define GET_OP_CLASSES 459 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc" 460