1 //===- PDL.cpp - Pattern Descriptor Language Dialect ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/PDL/IR/PDL.h" 10 #include "mlir/Dialect/PDL/IR/PDLOps.h" 11 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 12 #include "mlir/IR/BuiltinTypes.h" 13 #include "mlir/Interfaces/InferTypeOpInterface.h" 14 #include "llvm/ADT/StringSwitch.h" 15 16 using namespace mlir; 17 using namespace mlir::pdl; 18 19 //===----------------------------------------------------------------------===// 20 // PDLDialect 21 //===----------------------------------------------------------------------===// 22 23 void PDLDialect::initialize() { 24 addOperations< 25 #define GET_OP_LIST 26 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc" 27 >(); 28 addTypes< 29 #define GET_TYPEDEF_LIST 30 #include "mlir/Dialect/PDL/IR/PDLOpsTypes.cpp.inc" 31 >(); 32 } 33 34 /// Returns true if the given operation is used by a "binding" pdl operation 35 /// within the main matcher body of a `pdl.pattern`. 36 static LogicalResult 37 verifyHasBindingUseInMatcher(Operation *op, 38 StringRef bindableContextStr = "`pdl.operation`") { 39 // If the pattern is not a pattern, there is nothing to do. 40 if (!isa<PatternOp>(op->getParentOp())) 41 return success(); 42 Block *matcherBlock = op->getBlock(); 43 for (Operation *user : op->getUsers()) { 44 if (user->getBlock() != matcherBlock) 45 continue; 46 if (isa<AttributeOp, OperandOp, OperationOp, RewriteOp>(user)) 47 return success(); 48 } 49 return op->emitOpError() 50 << "expected a bindable (i.e. " << bindableContextStr 51 << ") user when defined in the matcher body of a `pdl.pattern`"; 52 } 53 54 //===----------------------------------------------------------------------===// 55 // pdl::ApplyConstraintOp 56 //===----------------------------------------------------------------------===// 57 58 static LogicalResult verify(ApplyConstraintOp op) { 59 if (op.getNumOperands() == 0) 60 return op.emitOpError("expected at least one argument"); 61 return success(); 62 } 63 64 //===----------------------------------------------------------------------===// 65 // pdl::AttributeOp 66 //===----------------------------------------------------------------------===// 67 68 static LogicalResult verify(AttributeOp op) { 69 Value attrType = op.type(); 70 Optional<Attribute> attrValue = op.value(); 71 72 if (!attrValue && isa<RewriteOp>(op->getParentOp())) 73 return op.emitOpError("expected constant value when specified within a " 74 "`pdl.rewrite`"); 75 if (attrValue && attrType) 76 return op.emitOpError("expected only one of [`type`, `value`] to be set"); 77 return verifyHasBindingUseInMatcher(op); 78 } 79 80 //===----------------------------------------------------------------------===// 81 // pdl::OperandOp 82 //===----------------------------------------------------------------------===// 83 84 static LogicalResult verify(OperandOp op) { 85 return verifyHasBindingUseInMatcher(op); 86 } 87 88 //===----------------------------------------------------------------------===// 89 // pdl::OperationOp 90 //===----------------------------------------------------------------------===// 91 92 static ParseResult parseOperationOp(OpAsmParser &p, OperationState &state) { 93 Builder &builder = p.getBuilder(); 94 95 // Parse the optional operation name. 96 bool startsWithOperands = succeeded(p.parseOptionalLParen()); 97 bool startsWithAttributes = 98 !startsWithOperands && succeeded(p.parseOptionalLBrace()); 99 bool startsWithOpName = false; 100 if (!startsWithAttributes && !startsWithOperands) { 101 StringAttr opName; 102 OptionalParseResult opNameResult = 103 p.parseOptionalAttribute(opName, "name", state.attributes); 104 startsWithOpName = opNameResult.hasValue(); 105 if (startsWithOpName && failed(*opNameResult)) 106 return failure(); 107 } 108 109 // Parse the operands. 110 SmallVector<OpAsmParser::OperandType, 4> operands; 111 if (startsWithOperands || 112 (!startsWithAttributes && succeeded(p.parseOptionalLParen()))) { 113 if (p.parseOperandList(operands) || p.parseRParen() || 114 p.resolveOperands(operands, builder.getType<ValueType>(), 115 state.operands)) 116 return failure(); 117 } 118 119 // Parse the attributes. 120 SmallVector<Attribute, 4> attrNames; 121 if (startsWithAttributes || succeeded(p.parseOptionalLBrace())) { 122 SmallVector<OpAsmParser::OperandType, 4> attrOps; 123 do { 124 StringAttr nameAttr; 125 OpAsmParser::OperandType operand; 126 if (p.parseAttribute(nameAttr) || p.parseEqual() || 127 p.parseOperand(operand)) 128 return failure(); 129 attrNames.push_back(nameAttr); 130 attrOps.push_back(operand); 131 } while (succeeded(p.parseOptionalComma())); 132 133 if (p.parseRBrace() || 134 p.resolveOperands(attrOps, builder.getType<AttributeType>(), 135 state.operands)) 136 return failure(); 137 } 138 state.addAttribute("attributeNames", builder.getArrayAttr(attrNames)); 139 state.addTypes(builder.getType<OperationType>()); 140 141 // Parse the result types. 142 SmallVector<OpAsmParser::OperandType, 4> opResultTypes; 143 if (succeeded(p.parseOptionalArrow())) { 144 if (p.parseOperandList(opResultTypes) || 145 p.resolveOperands(opResultTypes, builder.getType<TypeType>(), 146 state.operands)) 147 return failure(); 148 state.types.append(opResultTypes.size(), builder.getType<ValueType>()); 149 } 150 151 if (p.parseOptionalAttrDict(state.attributes)) 152 return failure(); 153 154 int32_t operandSegmentSizes[] = {static_cast<int32_t>(operands.size()), 155 static_cast<int32_t>(attrNames.size()), 156 static_cast<int32_t>(opResultTypes.size())}; 157 state.addAttribute("operand_segment_sizes", 158 builder.getI32VectorAttr(operandSegmentSizes)); 159 return success(); 160 } 161 162 static void print(OpAsmPrinter &p, OperationOp op) { 163 p << "pdl.operation "; 164 if (Optional<StringRef> name = op.name()) 165 p << '"' << *name << '"'; 166 167 auto operandValues = op.operands(); 168 if (!operandValues.empty()) 169 p << '(' << operandValues << ')'; 170 171 // Emit the optional attributes. 172 ArrayAttr attrNames = op.attributeNames(); 173 if (!attrNames.empty()) { 174 Operation::operand_range attrArgs = op.attributes(); 175 p << " {"; 176 interleaveComma(llvm::seq<int>(0, attrNames.size()), p, 177 [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; }); 178 p << '}'; 179 } 180 181 // Print the result type constraints of the operation. 182 if (!op.results().empty()) 183 p << " -> " << op.types(); 184 p.printOptionalAttrDict(op->getAttrs(), 185 {"attributeNames", "name", "operand_segment_sizes"}); 186 } 187 188 /// Verifies that the result types of this operation, defined within a 189 /// `pdl.rewrite`, can be inferred. 190 static LogicalResult verifyResultTypesAreInferrable(OperationOp op, 191 ResultRange opResults, 192 OperandRange resultTypes) { 193 // Functor that returns if the given use can be used to infer a type. 194 Block *rewriterBlock = op->getBlock(); 195 auto canInferTypeFromUse = [&](OpOperand &use) { 196 // If the use is within a ReplaceOp and isn't the operation being replaced 197 // (i.e. is not the first operand of the replacement), we can infer a type. 198 ReplaceOp replOpUser = dyn_cast<ReplaceOp>(use.getOwner()); 199 if (!replOpUser || use.getOperandNumber() == 0) 200 return false; 201 // Make sure the replaced operation was defined before this one. 202 Operation *replacedOp = replOpUser.operation().getDefiningOp(); 203 return replacedOp->getBlock() != rewriterBlock || 204 replacedOp->isBeforeInBlock(op); 205 }; 206 207 // Check to see if the uses of the operation itself can be used to infer 208 // types. 209 if (llvm::any_of(op.op().getUses(), canInferTypeFromUse)) 210 return success(); 211 212 // Otherwise, make sure each of the types can be inferred. 213 for (int i : llvm::seq<int>(0, opResults.size())) { 214 Operation *resultTypeOp = resultTypes[i].getDefiningOp(); 215 assert(resultTypeOp && "expected valid result type operation"); 216 217 // If the op was defined by a `create_native`, it is guaranteed to be 218 // usable. 219 if (isa<CreateNativeOp>(resultTypeOp)) 220 continue; 221 222 // If the type is already constrained, there is nothing to do. 223 TypeOp typeOp = cast<TypeOp>(resultTypeOp); 224 if (typeOp.type()) 225 continue; 226 227 // If the type operation was defined in the matcher and constrains the 228 // result of an input operation, it can be used. 229 auto constrainsInputOp = [rewriterBlock](Operation *user) { 230 return user->getBlock() != rewriterBlock && isa<OperationOp>(user); 231 }; 232 if (llvm::any_of(typeOp.getResult().getUsers(), constrainsInputOp)) 233 continue; 234 235 // Otherwise, check to see if any uses of the result can infer the type. 236 if (llvm::any_of(opResults[i].getUses(), canInferTypeFromUse)) 237 continue; 238 return op 239 .emitOpError("must have inferable or constrained result types when " 240 "nested within `pdl.rewrite`") 241 .attachNote() 242 .append("result type #", i, " was not constrained"); 243 } 244 return success(); 245 } 246 247 static LogicalResult verify(OperationOp op) { 248 bool isWithinRewrite = isa<RewriteOp>(op->getParentOp()); 249 if (isWithinRewrite && !op.name()) 250 return op.emitOpError("must have an operation name when nested within " 251 "a `pdl.rewrite`"); 252 ArrayAttr attributeNames = op.attributeNames(); 253 auto attributeValues = op.attributes(); 254 if (attributeNames.size() != attributeValues.size()) { 255 return op.emitOpError() 256 << "expected the same number of attribute values and attribute " 257 "names, got " 258 << attributeNames.size() << " names and " << attributeValues.size() 259 << " values"; 260 } 261 262 OperandRange resultTypes = op.types(); 263 auto opResults = op.results(); 264 if (resultTypes.size() != opResults.size()) { 265 return op.emitOpError() << "expected the same number of result values and " 266 "result type constraints, got " 267 << opResults.size() << " results and " 268 << resultTypes.size() << " constraints"; 269 } 270 271 // If the operation is within a rewrite body and doesn't have type inference, 272 // ensure that the result types can be resolved. 273 if (isWithinRewrite && !op.hasTypeInference()) { 274 if (failed(verifyResultTypesAreInferrable(op, opResults, resultTypes))) 275 return failure(); 276 } 277 278 return verifyHasBindingUseInMatcher(op, "`pdl.operation` or `pdl.rewrite`"); 279 } 280 281 bool OperationOp::hasTypeInference() { 282 Optional<StringRef> opName = name(); 283 if (!opName) 284 return false; 285 286 OperationName name(*opName, getContext()); 287 if (const AbstractOperation *op = name.getAbstractOperation()) 288 return op->getInterface<InferTypeOpInterface>(); 289 return false; 290 } 291 292 //===----------------------------------------------------------------------===// 293 // pdl::PatternOp 294 //===----------------------------------------------------------------------===// 295 296 static LogicalResult verify(PatternOp pattern) { 297 Region &body = pattern.body(); 298 auto *term = body.front().getTerminator(); 299 if (!isa<RewriteOp>(term)) { 300 return pattern.emitOpError("expected body to terminate with `pdl.rewrite`") 301 .attachNote(term->getLoc()) 302 .append("see terminator defined here"); 303 } 304 305 // Check that all values defined in the top-level pattern are referenced at 306 // least once in the source tree. 307 WalkResult result = body.walk([&](Operation *op) -> WalkResult { 308 if (!isa_and_nonnull<PDLDialect>(op->getDialect())) { 309 pattern 310 .emitOpError("expected only `pdl` operations within the pattern body") 311 .attachNote(op->getLoc()) 312 .append("see non-`pdl` operation defined here"); 313 return WalkResult::interrupt(); 314 } 315 return WalkResult::advance(); 316 }); 317 return failure(result.wasInterrupted()); 318 } 319 320 void PatternOp::build(OpBuilder &builder, OperationState &state, 321 Optional<StringRef> rootKind, Optional<uint16_t> benefit, 322 Optional<StringRef> name) { 323 build(builder, state, 324 rootKind ? builder.getStringAttr(*rootKind) : StringAttr(), 325 builder.getI16IntegerAttr(benefit ? *benefit : 0), 326 name ? builder.getStringAttr(*name) : StringAttr()); 327 builder.createBlock(state.addRegion()); 328 } 329 330 /// Returns the rewrite operation of this pattern. 331 RewriteOp PatternOp::getRewriter() { 332 return cast<RewriteOp>(body().front().getTerminator()); 333 } 334 335 /// Return the root operation kind that this pattern matches, or None if 336 /// there isn't a specific root. 337 Optional<StringRef> PatternOp::getRootKind() { 338 OperationOp rootOp = cast<OperationOp>(getRewriter().root().getDefiningOp()); 339 return rootOp.name(); 340 } 341 342 //===----------------------------------------------------------------------===// 343 // pdl::ReplaceOp 344 //===----------------------------------------------------------------------===// 345 346 static LogicalResult verify(ReplaceOp op) { 347 auto sourceOp = cast<OperationOp>(op.operation().getDefiningOp()); 348 auto sourceOpResults = sourceOp.results(); 349 auto replValues = op.replValues(); 350 351 if (Value replOpVal = op.replOperation()) { 352 auto replOp = cast<OperationOp>(replOpVal.getDefiningOp()); 353 auto replOpResults = replOp.results(); 354 if (sourceOpResults.size() != replOpResults.size()) { 355 return op.emitOpError() 356 << "expected source operation to have the same number of results " 357 "as the replacement operation, replacement operation provided " 358 << replOpResults.size() << " but expected " 359 << sourceOpResults.size(); 360 } 361 362 if (!replValues.empty()) { 363 return op.emitOpError() << "expected no replacement values to be provided" 364 " when the replacement operation is present"; 365 } 366 367 return success(); 368 } 369 370 if (sourceOpResults.size() != replValues.size()) { 371 return op.emitOpError() 372 << "expected source operation to have the same number of results " 373 "as the provided replacement values, found " 374 << replValues.size() << " replacement values but expected " 375 << sourceOpResults.size(); 376 } 377 378 return success(); 379 } 380 381 //===----------------------------------------------------------------------===// 382 // pdl::RewriteOp 383 //===----------------------------------------------------------------------===// 384 385 static LogicalResult verify(RewriteOp op) { 386 Region &rewriteRegion = op.body(); 387 388 // Handle the case where the rewrite is external. 389 if (op.name()) { 390 if (!rewriteRegion.empty()) { 391 return op.emitOpError() 392 << "expected rewrite region to be empty when rewrite is external"; 393 } 394 return success(); 395 } 396 397 // Otherwise, check that the rewrite region only contains a single block. 398 if (rewriteRegion.empty()) { 399 return op.emitOpError() << "expected rewrite region to be non-empty if " 400 "external name is not specified"; 401 } 402 403 // Check that no additional arguments were provided. 404 if (!op.externalArgs().empty()) { 405 return op.emitOpError() << "expected no external arguments when the " 406 "rewrite is specified inline"; 407 } 408 if (op.externalConstParams()) { 409 return op.emitOpError() << "expected no external constant parameters when " 410 "the rewrite is specified inline"; 411 } 412 413 return success(); 414 } 415 416 //===----------------------------------------------------------------------===// 417 // pdl::TypeOp 418 //===----------------------------------------------------------------------===// 419 420 static LogicalResult verify(TypeOp op) { 421 return verifyHasBindingUseInMatcher( 422 op, "`pdl.attribute`, `pdl.operand`, or `pdl.operation`"); 423 } 424 425 //===----------------------------------------------------------------------===// 426 // TableGen'd op method definitions 427 //===----------------------------------------------------------------------===// 428 429 #define GET_OP_CLASSES 430 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc" 431