1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements MLIR to byte-code generation and the interpreter. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "ByteCode.h" 14 #include "mlir/Analysis/Liveness.h" 15 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/RegionGraphTraits.h" 19 #include "llvm/ADT/IntervalMap.h" 20 #include "llvm/ADT/PostOrderIterator.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 #include "llvm/Support/Debug.h" 23 24 #define DEBUG_TYPE "pdl-bytecode" 25 26 using namespace mlir; 27 using namespace mlir::detail; 28 29 //===----------------------------------------------------------------------===// 30 // PDLByteCodePattern 31 //===----------------------------------------------------------------------===// 32 33 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp, 34 ByteCodeAddr rewriterAddr) { 35 SmallVector<StringRef, 8> generatedOps; 36 if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr()) 37 generatedOps = 38 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>()); 39 40 PatternBenefit benefit = matchOp.benefit(); 41 MLIRContext *ctx = matchOp.getContext(); 42 43 // Check to see if this is pattern matches a specific operation type. 44 if (Optional<StringRef> rootKind = matchOp.rootKind()) 45 return PDLByteCodePattern(rewriterAddr, *rootKind, generatedOps, benefit, 46 ctx); 47 return PDLByteCodePattern(rewriterAddr, generatedOps, benefit, ctx, 48 MatchAnyOpTypeTag()); 49 } 50 51 //===----------------------------------------------------------------------===// 52 // PDLByteCodeMutableState 53 //===----------------------------------------------------------------------===// 54 55 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds 56 /// to the position of the pattern within the range returned by 57 /// `PDLByteCode::getPatterns`. 58 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex, 59 PatternBenefit benefit) { 60 currentPatternBenefits[patternIndex] = benefit; 61 } 62 63 //===----------------------------------------------------------------------===// 64 // Bytecode OpCodes 65 //===----------------------------------------------------------------------===// 66 67 namespace { 68 enum OpCode : ByteCodeField { 69 /// Apply an externally registered constraint. 70 ApplyConstraint, 71 /// Apply an externally registered rewrite. 72 ApplyRewrite, 73 /// Check if two generic values are equal. 74 AreEqual, 75 /// Unconditional branch. 76 Branch, 77 /// Compare the operand count of an operation with a constant. 78 CheckOperandCount, 79 /// Compare the name of an operation with a constant. 80 CheckOperationName, 81 /// Compare the result count of an operation with a constant. 82 CheckResultCount, 83 /// Create an operation. 84 CreateOperation, 85 /// Erase an operation. 86 EraseOp, 87 /// Terminate a matcher or rewrite sequence. 88 Finalize, 89 /// Get a specific attribute of an operation. 90 GetAttribute, 91 /// Get the type of an attribute. 92 GetAttributeType, 93 /// Get the defining operation of a value. 94 GetDefiningOp, 95 /// Get a specific operand of an operation. 96 GetOperand0, 97 GetOperand1, 98 GetOperand2, 99 GetOperand3, 100 GetOperandN, 101 /// Get a specific result of an operation. 102 GetResult0, 103 GetResult1, 104 GetResult2, 105 GetResult3, 106 GetResultN, 107 /// Get the type of a value. 108 GetValueType, 109 /// Check if a generic value is not null. 110 IsNotNull, 111 /// Record a successful pattern match. 112 RecordMatch, 113 /// Replace an operation. 114 ReplaceOp, 115 /// Compare an attribute with a set of constants. 116 SwitchAttribute, 117 /// Compare the operand count of an operation with a set of constants. 118 SwitchOperandCount, 119 /// Compare the name of an operation with a set of constants. 120 SwitchOperationName, 121 /// Compare the result count of an operation with a set of constants. 122 SwitchResultCount, 123 /// Compare a type with a set of constants. 124 SwitchType, 125 }; 126 127 enum class PDLValueKind { Attribute, Operation, Type, Value }; 128 } // end anonymous namespace 129 130 //===----------------------------------------------------------------------===// 131 // ByteCode Generation 132 //===----------------------------------------------------------------------===// 133 134 //===----------------------------------------------------------------------===// 135 // Generator 136 137 namespace { 138 struct ByteCodeWriter; 139 140 /// This class represents the main generator for the pattern bytecode. 141 class Generator { 142 public: 143 Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData, 144 SmallVectorImpl<ByteCodeField> &matcherByteCode, 145 SmallVectorImpl<ByteCodeField> &rewriterByteCode, 146 SmallVectorImpl<PDLByteCodePattern> &patterns, 147 ByteCodeField &maxValueMemoryIndex, 148 llvm::StringMap<PDLConstraintFunction> &constraintFns, 149 llvm::StringMap<PDLRewriteFunction> &rewriteFns) 150 : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), 151 rewriterByteCode(rewriterByteCode), patterns(patterns), 152 maxValueMemoryIndex(maxValueMemoryIndex) { 153 for (auto it : llvm::enumerate(constraintFns)) 154 constraintToMemIndex.try_emplace(it.value().first(), it.index()); 155 for (auto it : llvm::enumerate(rewriteFns)) 156 externalRewriterToMemIndex.try_emplace(it.value().first(), it.index()); 157 } 158 159 /// Generate the bytecode for the given PDL interpreter module. 160 void generate(ModuleOp module); 161 162 /// Return the memory index to use for the given value. 163 ByteCodeField &getMemIndex(Value value) { 164 assert(valueToMemIndex.count(value) && 165 "expected memory index to be assigned"); 166 return valueToMemIndex[value]; 167 } 168 169 /// Return an index to use when referring to the given data that is uniqued in 170 /// the MLIR context. 171 template <typename T> 172 std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &> 173 getMemIndex(T val) { 174 const void *opaqueVal = val.getAsOpaquePointer(); 175 176 // Get or insert a reference to this value. 177 auto it = uniquedDataToMemIndex.try_emplace( 178 opaqueVal, maxValueMemoryIndex + uniquedData.size()); 179 if (it.second) 180 uniquedData.push_back(opaqueVal); 181 return it.first->second; 182 } 183 184 private: 185 /// Allocate memory indices for the results of operations within the matcher 186 /// and rewriters. 187 void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule); 188 189 /// Generate the bytecode for the given operation. 190 void generate(Operation *op, ByteCodeWriter &writer); 191 void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer); 192 void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer); 193 void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer); 194 void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer); 195 void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer); 196 void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer); 197 void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer); 198 void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); 199 void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); 200 void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); 201 void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); 202 void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); 203 void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); 204 void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); 205 void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); 206 void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); 207 void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); 208 void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer); 209 void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); 210 void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); 211 void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer); 212 void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); 213 void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer); 214 void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); 215 void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer); 216 void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer); 217 void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer); 218 void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer); 219 void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer); 220 221 /// Mapping from value to its corresponding memory index. 222 DenseMap<Value, ByteCodeField> valueToMemIndex; 223 224 /// Mapping from the name of an externally registered rewrite to its index in 225 /// the bytecode registry. 226 llvm::StringMap<ByteCodeField> externalRewriterToMemIndex; 227 228 /// Mapping from the name of an externally registered constraint to its index 229 /// in the bytecode registry. 230 llvm::StringMap<ByteCodeField> constraintToMemIndex; 231 232 /// Mapping from rewriter function name to the bytecode address of the 233 /// rewriter function in byte. 234 llvm::StringMap<ByteCodeAddr> rewriterToAddr; 235 236 /// Mapping from a uniqued storage object to its memory index within 237 /// `uniquedData`. 238 DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex; 239 240 /// The current MLIR context. 241 MLIRContext *ctx; 242 243 /// Data of the ByteCode class to be populated. 244 std::vector<const void *> &uniquedData; 245 SmallVectorImpl<ByteCodeField> &matcherByteCode; 246 SmallVectorImpl<ByteCodeField> &rewriterByteCode; 247 SmallVectorImpl<PDLByteCodePattern> &patterns; 248 ByteCodeField &maxValueMemoryIndex; 249 }; 250 251 /// This class provides utilities for writing a bytecode stream. 252 struct ByteCodeWriter { 253 ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator) 254 : bytecode(bytecode), generator(generator) {} 255 256 /// Append a field to the bytecode. 257 void append(ByteCodeField field) { bytecode.push_back(field); } 258 void append(OpCode opCode) { bytecode.push_back(opCode); } 259 260 /// Append an address to the bytecode. 261 void append(ByteCodeAddr field) { 262 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 263 "unexpected ByteCode address size"); 264 265 ByteCodeField fieldParts[2]; 266 std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr)); 267 bytecode.append({fieldParts[0], fieldParts[1]}); 268 } 269 270 /// Append a successor range to the bytecode, the exact address will need to 271 /// be resolved later. 272 void append(SuccessorRange successors) { 273 // Add back references to the any successors so that the address can be 274 // resolved later. 275 for (Block *successor : successors) { 276 unresolvedSuccessorRefs[successor].push_back(bytecode.size()); 277 append(ByteCodeAddr(0)); 278 } 279 } 280 281 /// Append a range of values that will be read as generic PDLValues. 282 void appendPDLValueList(OperandRange values) { 283 bytecode.push_back(values.size()); 284 for (Value value : values) { 285 // Append the type of the value in addition to the value itself. 286 PDLValueKind kind = 287 TypeSwitch<Type, PDLValueKind>(value.getType()) 288 .Case<pdl::AttributeType>( 289 [](Type) { return PDLValueKind::Attribute; }) 290 .Case<pdl::OperationType>( 291 [](Type) { return PDLValueKind::Operation; }) 292 .Case<pdl::TypeType>([](Type) { return PDLValueKind::Type; }) 293 .Case<pdl::ValueType>([](Type) { return PDLValueKind::Value; }); 294 bytecode.push_back(static_cast<ByteCodeField>(kind)); 295 append(value); 296 } 297 } 298 299 /// Check if the given class `T` has an iterator type. 300 template <typename T, typename... Args> 301 using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer()); 302 303 /// Append a value that will be stored in a memory slot and not inline within 304 /// the bytecode. 305 template <typename T> 306 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value || 307 std::is_pointer<T>::value> 308 append(T value) { 309 bytecode.push_back(generator.getMemIndex(value)); 310 } 311 312 /// Append a range of values. 313 template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>> 314 std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value> 315 append(T range) { 316 bytecode.push_back(llvm::size(range)); 317 for (auto it : range) 318 append(it); 319 } 320 321 /// Append a variadic number of fields to the bytecode. 322 template <typename FieldTy, typename Field2Ty, typename... FieldTys> 323 void append(FieldTy field, Field2Ty field2, FieldTys... fields) { 324 append(field); 325 append(field2, fields...); 326 } 327 328 /// Successor references in the bytecode that have yet to be resolved. 329 DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs; 330 331 /// The underlying bytecode buffer. 332 SmallVectorImpl<ByteCodeField> &bytecode; 333 334 /// The main generator producing PDL. 335 Generator &generator; 336 }; 337 } // end anonymous namespace 338 339 void Generator::generate(ModuleOp module) { 340 FuncOp matcherFunc = module.lookupSymbol<FuncOp>( 341 pdl_interp::PDLInterpDialect::getMatcherFunctionName()); 342 ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>( 343 pdl_interp::PDLInterpDialect::getRewriterModuleName()); 344 assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module"); 345 346 // Allocate memory indices for the results of operations within the matcher 347 // and rewriters. 348 allocateMemoryIndices(matcherFunc, rewriterModule); 349 350 // Generate code for the rewriter functions. 351 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this); 352 for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) { 353 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size()); 354 for (Operation &op : rewriterFunc.getOps()) 355 generate(&op, rewriterByteCodeWriter); 356 } 357 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() && 358 "unexpected branches in rewriter function"); 359 360 // Generate code for the matcher function. 361 DenseMap<Block *, ByteCodeAddr> blockToAddr; 362 llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody()); 363 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this); 364 for (Block *block : rpot) { 365 // Keep track of where this block begins within the matcher function. 366 blockToAddr.try_emplace(block, matcherByteCode.size()); 367 for (Operation &op : *block) 368 generate(&op, matcherByteCodeWriter); 369 } 370 371 // Resolve successor references in the matcher. 372 for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) { 373 ByteCodeAddr addr = blockToAddr[it.first]; 374 for (unsigned offsetToFix : it.second) 375 std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr)); 376 } 377 } 378 379 void Generator::allocateMemoryIndices(FuncOp matcherFunc, 380 ModuleOp rewriterModule) { 381 // Rewriters use simplistic allocation scheme that simply assigns an index to 382 // each result. 383 for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) { 384 ByteCodeField index = 0; 385 for (BlockArgument arg : rewriterFunc.getArguments()) 386 valueToMemIndex.try_emplace(arg, index++); 387 rewriterFunc.getBody().walk([&](Operation *op) { 388 for (Value result : op->getResults()) 389 valueToMemIndex.try_emplace(result, index++); 390 }); 391 if (index > maxValueMemoryIndex) 392 maxValueMemoryIndex = index; 393 } 394 395 // The matcher function uses a more sophisticated numbering that tries to 396 // minimize the number of memory indices assigned. This is done by determining 397 // a live range of the values within the matcher, then the allocation is just 398 // finding the minimal number of overlapping live ranges. This is essentially 399 // a simplified form of register allocation where we don't necessarily have a 400 // limited number of registers, but we still want to minimize the number used. 401 DenseMap<Operation *, ByteCodeField> opToIndex; 402 matcherFunc.getBody().walk([&](Operation *op) { 403 opToIndex.insert(std::make_pair(op, opToIndex.size())); 404 }); 405 406 // Liveness info for each of the defs within the matcher. 407 using LivenessSet = llvm::IntervalMap<ByteCodeField, char, 16>; 408 LivenessSet::Allocator allocator; 409 DenseMap<Value, LivenessSet> valueDefRanges; 410 411 // Assign the root operation being matched to slot 0. 412 BlockArgument rootOpArg = matcherFunc.getArgument(0); 413 valueToMemIndex[rootOpArg] = 0; 414 415 // Walk each of the blocks, computing the def interval that the value is used. 416 Liveness matcherLiveness(matcherFunc); 417 for (Block &block : matcherFunc.getBody()) { 418 const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block); 419 assert(info && "expected liveness info for block"); 420 auto processValue = [&](Value value, Operation *firstUseOrDef) { 421 // We don't need to process the root op argument, this value is always 422 // assigned to the first memory slot. 423 if (value == rootOpArg) 424 return; 425 426 // Set indices for the range of this block that the value is used. 427 auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; 428 defRangeIt->second.insert( 429 opToIndex[firstUseOrDef], 430 opToIndex[info->getEndOperation(value, firstUseOrDef)], 431 /*dummyValue*/ 0); 432 }; 433 434 // Process the live-ins of this block. 435 for (Value liveIn : info->in()) 436 processValue(liveIn, &block.front()); 437 438 // Process any new defs within this block. 439 for (Operation &op : block) 440 for (Value result : op.getResults()) 441 processValue(result, &op); 442 } 443 444 // Greedily allocate memory slots using the computed def live ranges. 445 std::vector<LivenessSet> allocatedIndices; 446 for (auto &defIt : valueDefRanges) { 447 ByteCodeField &memIndex = valueToMemIndex[defIt.first]; 448 LivenessSet &defSet = defIt.second; 449 450 // Try to allocate to an existing index. 451 for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) { 452 LivenessSet &existingIndex = existingIndexIt.value(); 453 llvm::IntervalMapOverlaps<LivenessSet, LivenessSet> overlaps( 454 defIt.second, existingIndex); 455 if (overlaps.valid()) 456 continue; 457 // Union the range of the def within the existing index. 458 for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it) 459 existingIndex.insert(it.start(), it.stop(), /*dummyValue*/ 0); 460 memIndex = existingIndexIt.index() + 1; 461 } 462 463 // If no existing index could be used, add a new one. 464 if (memIndex == 0) { 465 allocatedIndices.emplace_back(allocator); 466 for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it) 467 allocatedIndices.back().insert(it.start(), it.stop(), /*dummyValue*/ 0); 468 memIndex = allocatedIndices.size(); 469 } 470 } 471 472 // Update the max number of indices. 473 ByteCodeField numMatcherIndices = allocatedIndices.size() + 1; 474 if (numMatcherIndices > maxValueMemoryIndex) 475 maxValueMemoryIndex = numMatcherIndices; 476 } 477 478 void Generator::generate(Operation *op, ByteCodeWriter &writer) { 479 TypeSwitch<Operation *>(op) 480 .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp, 481 pdl_interp::AreEqualOp, pdl_interp::BranchOp, 482 pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp, 483 pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp, 484 pdl_interp::CheckTypeOp, pdl_interp::CreateAttributeOp, 485 pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp, 486 pdl_interp::EraseOp, pdl_interp::FinalizeOp, 487 pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp, 488 pdl_interp::GetDefiningOpOp, pdl_interp::GetOperandOp, 489 pdl_interp::GetResultOp, pdl_interp::GetValueTypeOp, 490 pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp, 491 pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp, 492 pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp, 493 pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp, 494 pdl_interp::SwitchResultCountOp>( 495 [&](auto interpOp) { this->generate(interpOp, writer); }) 496 .Default([](Operation *) { 497 llvm_unreachable("unknown `pdl_interp` operation"); 498 }); 499 } 500 501 void Generator::generate(pdl_interp::ApplyConstraintOp op, 502 ByteCodeWriter &writer) { 503 assert(constraintToMemIndex.count(op.name()) && 504 "expected index for constraint function"); 505 writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()], 506 op.constParamsAttr()); 507 writer.appendPDLValueList(op.args()); 508 writer.append(op.getSuccessors()); 509 } 510 void Generator::generate(pdl_interp::ApplyRewriteOp op, 511 ByteCodeWriter &writer) { 512 assert(externalRewriterToMemIndex.count(op.name()) && 513 "expected index for rewrite function"); 514 writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()], 515 op.constParamsAttr()); 516 writer.appendPDLValueList(op.args()); 517 518 #ifndef NDEBUG 519 // In debug mode we also append the number of results so that we can assert 520 // that the native creation function gave us the correct number of results. 521 writer.append(ByteCodeField(op.results().size())); 522 #endif 523 for (Value result : op.results()) 524 writer.append(result); 525 } 526 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) { 527 writer.append(OpCode::AreEqual, op.lhs(), op.rhs(), op.getSuccessors()); 528 } 529 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) { 530 writer.append(OpCode::Branch, SuccessorRange(op.getOperation())); 531 } 532 void Generator::generate(pdl_interp::CheckAttributeOp op, 533 ByteCodeWriter &writer) { 534 writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(), 535 op.getSuccessors()); 536 } 537 void Generator::generate(pdl_interp::CheckOperandCountOp op, 538 ByteCodeWriter &writer) { 539 writer.append(OpCode::CheckOperandCount, op.operation(), op.count(), 540 op.getSuccessors()); 541 } 542 void Generator::generate(pdl_interp::CheckOperationNameOp op, 543 ByteCodeWriter &writer) { 544 writer.append(OpCode::CheckOperationName, op.operation(), 545 OperationName(op.name(), ctx), op.getSuccessors()); 546 } 547 void Generator::generate(pdl_interp::CheckResultCountOp op, 548 ByteCodeWriter &writer) { 549 writer.append(OpCode::CheckResultCount, op.operation(), op.count(), 550 op.getSuccessors()); 551 } 552 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) { 553 writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors()); 554 } 555 void Generator::generate(pdl_interp::CreateAttributeOp op, 556 ByteCodeWriter &writer) { 557 // Simply repoint the memory index of the result to the constant. 558 getMemIndex(op.attribute()) = getMemIndex(op.value()); 559 } 560 void Generator::generate(pdl_interp::CreateOperationOp op, 561 ByteCodeWriter &writer) { 562 writer.append(OpCode::CreateOperation, op.operation(), 563 OperationName(op.name(), ctx), op.operands()); 564 565 // Add the attributes. 566 OperandRange attributes = op.attributes(); 567 writer.append(static_cast<ByteCodeField>(attributes.size())); 568 for (auto it : llvm::zip(op.attributeNames(), op.attributes())) { 569 writer.append( 570 Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx), 571 std::get<1>(it)); 572 } 573 writer.append(op.types()); 574 } 575 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) { 576 // Simply repoint the memory index of the result to the constant. 577 getMemIndex(op.result()) = getMemIndex(op.value()); 578 } 579 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) { 580 writer.append(OpCode::EraseOp, op.operation()); 581 } 582 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) { 583 writer.append(OpCode::Finalize); 584 } 585 void Generator::generate(pdl_interp::GetAttributeOp op, 586 ByteCodeWriter &writer) { 587 writer.append(OpCode::GetAttribute, op.attribute(), op.operation(), 588 Identifier::get(op.name(), ctx)); 589 } 590 void Generator::generate(pdl_interp::GetAttributeTypeOp op, 591 ByteCodeWriter &writer) { 592 writer.append(OpCode::GetAttributeType, op.result(), op.value()); 593 } 594 void Generator::generate(pdl_interp::GetDefiningOpOp op, 595 ByteCodeWriter &writer) { 596 writer.append(OpCode::GetDefiningOp, op.operation(), op.value()); 597 } 598 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) { 599 uint32_t index = op.index(); 600 if (index < 4) 601 writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index)); 602 else 603 writer.append(OpCode::GetOperandN, index); 604 writer.append(op.operation(), op.value()); 605 } 606 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) { 607 uint32_t index = op.index(); 608 if (index < 4) 609 writer.append(static_cast<OpCode>(OpCode::GetResult0 + index)); 610 else 611 writer.append(OpCode::GetResultN, index); 612 writer.append(op.operation(), op.value()); 613 } 614 void Generator::generate(pdl_interp::GetValueTypeOp op, 615 ByteCodeWriter &writer) { 616 writer.append(OpCode::GetValueType, op.result(), op.value()); 617 } 618 void Generator::generate(pdl_interp::InferredTypesOp op, 619 ByteCodeWriter &writer) { 620 // InferType maps to a null type as a marker for inferring result types. 621 getMemIndex(op.type()) = getMemIndex(Type()); 622 } 623 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) { 624 writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors()); 625 } 626 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) { 627 ByteCodeField patternIndex = patterns.size(); 628 patterns.emplace_back(PDLByteCodePattern::create( 629 op, rewriterToAddr[op.rewriter().getLeafReference()])); 630 writer.append(OpCode::RecordMatch, patternIndex, 631 SuccessorRange(op.getOperation()), op.matchedOps(), 632 op.inputs()); 633 } 634 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) { 635 writer.append(OpCode::ReplaceOp, op.operation(), op.replValues()); 636 } 637 void Generator::generate(pdl_interp::SwitchAttributeOp op, 638 ByteCodeWriter &writer) { 639 writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(), 640 op.getSuccessors()); 641 } 642 void Generator::generate(pdl_interp::SwitchOperandCountOp op, 643 ByteCodeWriter &writer) { 644 writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(), 645 op.getSuccessors()); 646 } 647 void Generator::generate(pdl_interp::SwitchOperationNameOp op, 648 ByteCodeWriter &writer) { 649 auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) { 650 return OperationName(attr.cast<StringAttr>().getValue(), ctx); 651 }); 652 writer.append(OpCode::SwitchOperationName, op.operation(), cases, 653 op.getSuccessors()); 654 } 655 void Generator::generate(pdl_interp::SwitchResultCountOp op, 656 ByteCodeWriter &writer) { 657 writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(), 658 op.getSuccessors()); 659 } 660 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) { 661 writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(), 662 op.getSuccessors()); 663 } 664 665 //===----------------------------------------------------------------------===// 666 // PDLByteCode 667 //===----------------------------------------------------------------------===// 668 669 PDLByteCode::PDLByteCode(ModuleOp module, 670 llvm::StringMap<PDLConstraintFunction> constraintFns, 671 llvm::StringMap<PDLRewriteFunction> rewriteFns) { 672 Generator generator(module.getContext(), uniquedData, matcherByteCode, 673 rewriterByteCode, patterns, maxValueMemoryIndex, 674 constraintFns, rewriteFns); 675 generator.generate(module); 676 677 // Initialize the external functions. 678 for (auto &it : constraintFns) 679 constraintFunctions.push_back(std::move(it.second)); 680 for (auto &it : rewriteFns) 681 rewriteFunctions.push_back(std::move(it.second)); 682 } 683 684 /// Initialize the given state such that it can be used to execute the current 685 /// bytecode. 686 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { 687 state.memory.resize(maxValueMemoryIndex, nullptr); 688 state.currentPatternBenefits.reserve(patterns.size()); 689 for (const PDLByteCodePattern &pattern : patterns) 690 state.currentPatternBenefits.push_back(pattern.getBenefit()); 691 } 692 693 //===----------------------------------------------------------------------===// 694 // ByteCode Execution 695 696 namespace { 697 /// This class provides support for executing a bytecode stream. 698 class ByteCodeExecutor { 699 public: 700 ByteCodeExecutor(const ByteCodeField *curCodeIt, 701 MutableArrayRef<const void *> memory, 702 ArrayRef<const void *> uniquedMemory, 703 ArrayRef<ByteCodeField> code, 704 ArrayRef<PatternBenefit> currentPatternBenefits, 705 ArrayRef<PDLByteCodePattern> patterns, 706 ArrayRef<PDLConstraintFunction> constraintFunctions, 707 ArrayRef<PDLRewriteFunction> rewriteFunctions) 708 : curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory), 709 code(code), currentPatternBenefits(currentPatternBenefits), 710 patterns(patterns), constraintFunctions(constraintFunctions), 711 rewriteFunctions(rewriteFunctions) {} 712 713 /// Start executing the code at the current bytecode index. `matches` is an 714 /// optional field provided when this function is executed in a matching 715 /// context. 716 void execute(PatternRewriter &rewriter, 717 SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr, 718 Optional<Location> mainRewriteLoc = {}); 719 720 private: 721 /// Internal implementation of executing each of the bytecode commands. 722 void executeApplyConstraint(PatternRewriter &rewriter); 723 void executeApplyRewrite(PatternRewriter &rewriter); 724 void executeAreEqual(); 725 void executeBranch(); 726 void executeCheckOperandCount(); 727 void executeCheckOperationName(); 728 void executeCheckResultCount(); 729 void executeCreateOperation(PatternRewriter &rewriter, 730 Location mainRewriteLoc); 731 void executeEraseOp(PatternRewriter &rewriter); 732 void executeGetAttribute(); 733 void executeGetAttributeType(); 734 void executeGetDefiningOp(); 735 void executeGetOperand(unsigned index); 736 void executeGetResult(unsigned index); 737 void executeGetValueType(); 738 void executeIsNotNull(); 739 void executeRecordMatch(PatternRewriter &rewriter, 740 SmallVectorImpl<PDLByteCode::MatchResult> &matches); 741 void executeReplaceOp(PatternRewriter &rewriter); 742 void executeSwitchAttribute(); 743 void executeSwitchOperandCount(); 744 void executeSwitchOperationName(); 745 void executeSwitchResultCount(); 746 void executeSwitchType(); 747 748 /// Read a value from the bytecode buffer, optionally skipping a certain 749 /// number of prefix values. These methods always update the buffer to point 750 /// to the next field after the read data. 751 template <typename T = ByteCodeField> 752 T read(size_t skipN = 0) { 753 curCodeIt += skipN; 754 return readImpl<T>(); 755 } 756 ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); } 757 758 /// Read a list of values from the bytecode buffer. 759 template <typename ValueT, typename T> 760 void readList(SmallVectorImpl<T> &list) { 761 list.clear(); 762 for (unsigned i = 0, e = read(); i != e; ++i) 763 list.push_back(read<ValueT>()); 764 } 765 766 /// Jump to a specific successor based on a predicate value. 767 void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); } 768 /// Jump to a specific successor based on a destination index. 769 void selectJump(size_t destIndex) { 770 curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)]; 771 } 772 773 /// Handle a switch operation with the provided value and cases. 774 template <typename T, typename RangeT> 775 void handleSwitch(const T &value, RangeT &&cases) { 776 LLVM_DEBUG({ 777 llvm::dbgs() << " * Value: " << value << "\n" 778 << " * Cases: "; 779 llvm::interleaveComma(cases, llvm::dbgs()); 780 llvm::dbgs() << "\n"; 781 }); 782 783 // Check to see if the attribute value is within the case list. Jump to 784 // the correct successor index based on the result. 785 for (auto it = cases.begin(), e = cases.end(); it != e; ++it) 786 if (*it == value) 787 return selectJump(size_t((it - cases.begin()) + 1)); 788 selectJump(size_t(0)); 789 } 790 791 /// Internal implementation of reading various data types from the bytecode 792 /// stream. 793 template <typename T> 794 const void *readFromMemory() { 795 size_t index = *curCodeIt++; 796 797 // If this type is an SSA value, it can only be stored in non-const memory. 798 if (llvm::is_one_of<T, Operation *, Value>::value || index < memory.size()) 799 return memory[index]; 800 801 // Otherwise, if this index is not inbounds it is uniqued. 802 return uniquedMemory[index - memory.size()]; 803 } 804 template <typename T> 805 std::enable_if_t<std::is_pointer<T>::value, T> readImpl() { 806 return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>())); 807 } 808 template <typename T> 809 std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value, 810 T> 811 readImpl() { 812 return T(T::getFromOpaquePointer(readFromMemory<T>())); 813 } 814 template <typename T> 815 std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() { 816 switch (static_cast<PDLValueKind>(read())) { 817 case PDLValueKind::Attribute: 818 return read<Attribute>(); 819 case PDLValueKind::Operation: 820 return read<Operation *>(); 821 case PDLValueKind::Type: 822 return read<Type>(); 823 case PDLValueKind::Value: 824 return read<Value>(); 825 } 826 llvm_unreachable("unhandled PDLValueKind"); 827 } 828 template <typename T> 829 std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() { 830 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 831 "unexpected ByteCode address size"); 832 ByteCodeAddr result; 833 std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr)); 834 curCodeIt += 2; 835 return result; 836 } 837 template <typename T> 838 std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() { 839 return *curCodeIt++; 840 } 841 842 /// The underlying bytecode buffer. 843 const ByteCodeField *curCodeIt; 844 845 /// The current execution memory. 846 MutableArrayRef<const void *> memory; 847 848 /// References to ByteCode data necessary for execution. 849 ArrayRef<const void *> uniquedMemory; 850 ArrayRef<ByteCodeField> code; 851 ArrayRef<PatternBenefit> currentPatternBenefits; 852 ArrayRef<PDLByteCodePattern> patterns; 853 ArrayRef<PDLConstraintFunction> constraintFunctions; 854 ArrayRef<PDLRewriteFunction> rewriteFunctions; 855 }; 856 857 /// This class is an instantiation of the PDLResultList that provides access to 858 /// the returned results. This API is not on `PDLResultList` to avoid 859 /// overexposing access to information specific solely to the ByteCode. 860 class ByteCodeRewriteResultList : public PDLResultList { 861 public: 862 /// Return the list of PDL results. 863 MutableArrayRef<PDLValue> getResults() { return results; } 864 }; 865 } // end anonymous namespace 866 867 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { 868 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); 869 const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; 870 ArrayAttr constParams = read<ArrayAttr>(); 871 SmallVector<PDLValue, 16> args; 872 readList<PDLValue>(args); 873 874 LLVM_DEBUG({ 875 llvm::dbgs() << " * Arguments: "; 876 llvm::interleaveComma(args, llvm::dbgs()); 877 llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; 878 }); 879 880 // Invoke the constraint and jump to the proper destination. 881 selectJump(succeeded(constraintFn(args, constParams, rewriter))); 882 } 883 884 void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { 885 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); 886 const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; 887 ArrayAttr constParams = read<ArrayAttr>(); 888 SmallVector<PDLValue, 16> args; 889 readList<PDLValue>(args); 890 891 LLVM_DEBUG({ 892 llvm::dbgs() << " * Arguments: "; 893 llvm::interleaveComma(args, llvm::dbgs()); 894 llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; 895 }); 896 ByteCodeRewriteResultList results; 897 rewriteFn(args, constParams, rewriter, results); 898 899 // Store the results in the bytecode memory. 900 #ifndef NDEBUG 901 ByteCodeField expectedNumberOfResults = read(); 902 assert(results.getResults().size() == expectedNumberOfResults && 903 "native PDL rewrite function returned unexpected number of results"); 904 #endif 905 906 // Store the results in the bytecode memory. 907 for (PDLValue &result : results.getResults()) { 908 LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); 909 memory[read()] = result.getAsOpaquePointer(); 910 } 911 } 912 913 void ByteCodeExecutor::executeAreEqual() { 914 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 915 const void *lhs = read<const void *>(); 916 const void *rhs = read<const void *>(); 917 918 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n"); 919 selectJump(lhs == rhs); 920 } 921 922 void ByteCodeExecutor::executeBranch() { 923 LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n"); 924 curCodeIt = &code[read<ByteCodeAddr>()]; 925 } 926 927 void ByteCodeExecutor::executeCheckOperandCount() { 928 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); 929 Operation *op = read<Operation *>(); 930 uint32_t expectedCount = read<uint32_t>(); 931 932 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" 933 << " * Expected: " << expectedCount << "\n"); 934 selectJump(op->getNumOperands() == expectedCount); 935 } 936 937 void ByteCodeExecutor::executeCheckOperationName() { 938 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n"); 939 Operation *op = read<Operation *>(); 940 OperationName expectedName = read<OperationName>(); 941 942 LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n" 943 << " * Expected: \"" << expectedName << "\"\n"); 944 selectJump(op->getName() == expectedName); 945 } 946 947 void ByteCodeExecutor::executeCheckResultCount() { 948 LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); 949 Operation *op = read<Operation *>(); 950 uint32_t expectedCount = read<uint32_t>(); 951 952 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" 953 << " * Expected: " << expectedCount << "\n"); 954 selectJump(op->getNumResults() == expectedCount); 955 } 956 957 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, 958 Location mainRewriteLoc) { 959 LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); 960 961 unsigned memIndex = read(); 962 OperationState state(mainRewriteLoc, read<OperationName>()); 963 readList<Value>(state.operands); 964 for (unsigned i = 0, e = read(); i != e; ++i) { 965 Identifier name = read<Identifier>(); 966 if (Attribute attr = read<Attribute>()) 967 state.addAttribute(name, attr); 968 } 969 970 bool hasInferredTypes = false; 971 for (unsigned i = 0, e = read(); i != e; ++i) { 972 Type resultType = read<Type>(); 973 hasInferredTypes |= !resultType; 974 state.types.push_back(resultType); 975 } 976 977 // Handle the case where the operation has inferred types. 978 if (hasInferredTypes) { 979 InferTypeOpInterface::Concept *concept = 980 state.name.getAbstractOperation()->getInterface<InferTypeOpInterface>(); 981 982 // TODO: Handle failure. 983 state.types.clear(); 984 if (failed(concept->inferReturnTypes( 985 state.getContext(), state.location, state.operands, 986 state.attributes.getDictionary(state.getContext()), state.regions, 987 state.types))) 988 return; 989 } 990 Operation *resultOp = rewriter.createOperation(state); 991 memory[memIndex] = resultOp; 992 993 LLVM_DEBUG({ 994 llvm::dbgs() << " * Attributes: " 995 << state.attributes.getDictionary(state.getContext()) 996 << "\n * Operands: "; 997 llvm::interleaveComma(state.operands, llvm::dbgs()); 998 llvm::dbgs() << "\n * Result Types: "; 999 llvm::interleaveComma(state.types, llvm::dbgs()); 1000 llvm::dbgs() << "\n * Result: " << *resultOp << "\n"; 1001 }); 1002 } 1003 1004 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) { 1005 LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n"); 1006 Operation *op = read<Operation *>(); 1007 1008 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1009 rewriter.eraseOp(op); 1010 } 1011 1012 void ByteCodeExecutor::executeGetAttribute() { 1013 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); 1014 unsigned memIndex = read(); 1015 Operation *op = read<Operation *>(); 1016 Identifier attrName = read<Identifier>(); 1017 Attribute attr = op->getAttr(attrName); 1018 1019 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1020 << " * Attribute: " << attrName << "\n" 1021 << " * Result: " << attr << "\n"); 1022 memory[memIndex] = attr.getAsOpaquePointer(); 1023 } 1024 1025 void ByteCodeExecutor::executeGetAttributeType() { 1026 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n"); 1027 unsigned memIndex = read(); 1028 Attribute attr = read<Attribute>(); 1029 Type type = attr ? attr.getType() : Type(); 1030 1031 LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" 1032 << " * Result: " << type << "\n"); 1033 memory[memIndex] = type.getAsOpaquePointer(); 1034 } 1035 1036 void ByteCodeExecutor::executeGetDefiningOp() { 1037 LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); 1038 unsigned memIndex = read(); 1039 Value value = read<Value>(); 1040 Operation *op = value ? value.getDefiningOp() : nullptr; 1041 1042 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" 1043 << " * Result: " << *op << "\n"); 1044 memory[memIndex] = op; 1045 } 1046 1047 void ByteCodeExecutor::executeGetOperand(unsigned index) { 1048 Operation *op = read<Operation *>(); 1049 unsigned memIndex = read(); 1050 Value operand = 1051 index < op->getNumOperands() ? op->getOperand(index) : Value(); 1052 1053 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1054 << " * Index: " << index << "\n" 1055 << " * Result: " << operand << "\n"); 1056 memory[memIndex] = operand.getAsOpaquePointer(); 1057 } 1058 1059 void ByteCodeExecutor::executeGetResult(unsigned index) { 1060 Operation *op = read<Operation *>(); 1061 unsigned memIndex = read(); 1062 OpResult result = 1063 index < op->getNumResults() ? op->getResult(index) : OpResult(); 1064 1065 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1066 << " * Index: " << index << "\n" 1067 << " * Result: " << result << "\n"); 1068 memory[memIndex] = result.getAsOpaquePointer(); 1069 } 1070 1071 void ByteCodeExecutor::executeGetValueType() { 1072 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); 1073 unsigned memIndex = read(); 1074 Value value = read<Value>(); 1075 Type type = value ? value.getType() : Type(); 1076 1077 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" 1078 << " * Result: " << type << "\n"); 1079 memory[memIndex] = type.getAsOpaquePointer(); 1080 } 1081 1082 void ByteCodeExecutor::executeIsNotNull() { 1083 LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); 1084 const void *value = read<const void *>(); 1085 1086 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1087 selectJump(value != nullptr); 1088 } 1089 1090 void ByteCodeExecutor::executeRecordMatch( 1091 PatternRewriter &rewriter, 1092 SmallVectorImpl<PDLByteCode::MatchResult> &matches) { 1093 LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n"); 1094 unsigned patternIndex = read(); 1095 PatternBenefit benefit = currentPatternBenefits[patternIndex]; 1096 const ByteCodeField *dest = &code[read<ByteCodeAddr>()]; 1097 1098 // If the benefit of the pattern is impossible, skip the processing of the 1099 // rest of the pattern. 1100 if (benefit.isImpossibleToMatch()) { 1101 LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n"); 1102 curCodeIt = dest; 1103 return; 1104 } 1105 1106 // Create a fused location containing the locations of each of the 1107 // operations used in the match. This will be used as the location for 1108 // created operations during the rewrite that don't already have an 1109 // explicit location set. 1110 unsigned numMatchLocs = read(); 1111 SmallVector<Location, 4> matchLocs; 1112 matchLocs.reserve(numMatchLocs); 1113 for (unsigned i = 0; i != numMatchLocs; ++i) 1114 matchLocs.push_back(read<Operation *>()->getLoc()); 1115 Location matchLoc = rewriter.getFusedLoc(matchLocs); 1116 1117 LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n" 1118 << " * Location: " << matchLoc << "\n"); 1119 matches.emplace_back(matchLoc, patterns[patternIndex], benefit); 1120 readList<const void *>(matches.back().values); 1121 curCodeIt = dest; 1122 } 1123 1124 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) { 1125 LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); 1126 Operation *op = read<Operation *>(); 1127 SmallVector<Value, 16> args; 1128 readList<Value>(args); 1129 1130 LLVM_DEBUG({ 1131 llvm::dbgs() << " * Operation: " << *op << "\n" 1132 << " * Values: "; 1133 llvm::interleaveComma(args, llvm::dbgs()); 1134 llvm::dbgs() << "\n"; 1135 }); 1136 rewriter.replaceOp(op, args); 1137 } 1138 1139 void ByteCodeExecutor::executeSwitchAttribute() { 1140 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n"); 1141 Attribute value = read<Attribute>(); 1142 ArrayAttr cases = read<ArrayAttr>(); 1143 handleSwitch(value, cases); 1144 } 1145 1146 void ByteCodeExecutor::executeSwitchOperandCount() { 1147 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n"); 1148 Operation *op = read<Operation *>(); 1149 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1150 1151 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1152 handleSwitch(op->getNumOperands(), cases); 1153 } 1154 1155 void ByteCodeExecutor::executeSwitchOperationName() { 1156 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n"); 1157 OperationName value = read<Operation *>()->getName(); 1158 size_t caseCount = read(); 1159 1160 // The operation names are stored in-line, so to print them out for 1161 // debugging purposes we need to read the array before executing the 1162 // switch so that we can display all of the possible values. 1163 LLVM_DEBUG({ 1164 const ByteCodeField *prevCodeIt = curCodeIt; 1165 llvm::dbgs() << " * Value: " << value << "\n" 1166 << " * Cases: "; 1167 llvm::interleaveComma( 1168 llvm::map_range(llvm::seq<size_t>(0, caseCount), 1169 [&](size_t) { return read<OperationName>(); }), 1170 llvm::dbgs()); 1171 llvm::dbgs() << "\n"; 1172 curCodeIt = prevCodeIt; 1173 }); 1174 1175 // Try to find the switch value within any of the cases. 1176 for (size_t i = 0; i != caseCount; ++i) { 1177 if (read<OperationName>() == value) { 1178 curCodeIt += (caseCount - i - 1); 1179 return selectJump(i + 1); 1180 } 1181 } 1182 selectJump(size_t(0)); 1183 } 1184 1185 void ByteCodeExecutor::executeSwitchResultCount() { 1186 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n"); 1187 Operation *op = read<Operation *>(); 1188 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1189 1190 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1191 handleSwitch(op->getNumResults(), cases); 1192 } 1193 1194 void ByteCodeExecutor::executeSwitchType() { 1195 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n"); 1196 Type value = read<Type>(); 1197 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>(); 1198 handleSwitch(value, cases); 1199 } 1200 1201 void ByteCodeExecutor::execute( 1202 PatternRewriter &rewriter, 1203 SmallVectorImpl<PDLByteCode::MatchResult> *matches, 1204 Optional<Location> mainRewriteLoc) { 1205 while (true) { 1206 OpCode opCode = static_cast<OpCode>(read()); 1207 switch (opCode) { 1208 case ApplyConstraint: 1209 executeApplyConstraint(rewriter); 1210 break; 1211 case ApplyRewrite: 1212 executeApplyRewrite(rewriter); 1213 break; 1214 case AreEqual: 1215 executeAreEqual(); 1216 break; 1217 case Branch: 1218 executeBranch(); 1219 break; 1220 case CheckOperandCount: 1221 executeCheckOperandCount(); 1222 break; 1223 case CheckOperationName: 1224 executeCheckOperationName(); 1225 break; 1226 case CheckResultCount: 1227 executeCheckResultCount(); 1228 break; 1229 case CreateOperation: 1230 executeCreateOperation(rewriter, *mainRewriteLoc); 1231 break; 1232 case EraseOp: 1233 executeEraseOp(rewriter); 1234 break; 1235 case Finalize: 1236 LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n"); 1237 return; 1238 case GetAttribute: 1239 executeGetAttribute(); 1240 break; 1241 case GetAttributeType: 1242 executeGetAttributeType(); 1243 break; 1244 case GetDefiningOp: 1245 executeGetDefiningOp(); 1246 break; 1247 case GetOperand0: 1248 case GetOperand1: 1249 case GetOperand2: 1250 case GetOperand3: { 1251 unsigned index = opCode - GetOperand0; 1252 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n"); 1253 executeGetOperand(index); 1254 break; 1255 } 1256 case GetOperandN: 1257 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n"); 1258 executeGetOperand(read<uint32_t>()); 1259 break; 1260 case GetResult0: 1261 case GetResult1: 1262 case GetResult2: 1263 case GetResult3: { 1264 unsigned index = opCode - GetResult0; 1265 LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n"); 1266 executeGetResult(index); 1267 break; 1268 } 1269 case GetResultN: 1270 LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n"); 1271 executeGetResult(read<uint32_t>()); 1272 break; 1273 case GetValueType: 1274 executeGetValueType(); 1275 break; 1276 case IsNotNull: 1277 executeIsNotNull(); 1278 break; 1279 case RecordMatch: 1280 assert(matches && 1281 "expected matches to be provided when executing the matcher"); 1282 executeRecordMatch(rewriter, *matches); 1283 break; 1284 case ReplaceOp: 1285 executeReplaceOp(rewriter); 1286 break; 1287 case SwitchAttribute: 1288 executeSwitchAttribute(); 1289 break; 1290 case SwitchOperandCount: 1291 executeSwitchOperandCount(); 1292 break; 1293 case SwitchOperationName: 1294 executeSwitchOperationName(); 1295 break; 1296 case SwitchResultCount: 1297 executeSwitchResultCount(); 1298 break; 1299 case SwitchType: 1300 executeSwitchType(); 1301 break; 1302 } 1303 LLVM_DEBUG(llvm::dbgs() << "\n"); 1304 } 1305 } 1306 1307 /// Run the pattern matcher on the given root operation, collecting the matched 1308 /// patterns in `matches`. 1309 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, 1310 SmallVectorImpl<MatchResult> &matches, 1311 PDLByteCodeMutableState &state) const { 1312 // The first memory slot is always the root operation. 1313 state.memory[0] = op; 1314 1315 // The matcher function always starts at code address 0. 1316 ByteCodeExecutor executor(matcherByteCode.data(), state.memory, uniquedData, 1317 matcherByteCode, state.currentPatternBenefits, 1318 patterns, constraintFunctions, rewriteFunctions); 1319 executor.execute(rewriter, &matches); 1320 1321 // Order the found matches by benefit. 1322 std::stable_sort(matches.begin(), matches.end(), 1323 [](const MatchResult &lhs, const MatchResult &rhs) { 1324 return lhs.benefit > rhs.benefit; 1325 }); 1326 } 1327 1328 /// Run the rewriter of the given pattern on the root operation `op`. 1329 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, 1330 PDLByteCodeMutableState &state) const { 1331 // The arguments of the rewrite function are stored at the start of the 1332 // memory buffer. 1333 llvm::copy(match.values, state.memory.begin()); 1334 1335 ByteCodeExecutor executor(&rewriterByteCode[match.pattern->getRewriterAddr()], 1336 state.memory, uniquedData, rewriterByteCode, 1337 state.currentPatternBenefits, patterns, 1338 constraintFunctions, rewriteFunctions); 1339 executor.execute(rewriter, /*matches=*/nullptr, match.location); 1340 } 1341