1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements MLIR to byte-code generation and the interpreter. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "ByteCode.h" 14 #include "mlir/Analysis/Liveness.h" 15 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/RegionGraphTraits.h" 19 #include "llvm/ADT/IntervalMap.h" 20 #include "llvm/ADT/PostOrderIterator.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 #include "llvm/Support/Debug.h" 23 #include "llvm/Support/Format.h" 24 #include "llvm/Support/FormatVariadic.h" 25 #include <numeric> 26 27 #define DEBUG_TYPE "pdl-bytecode" 28 29 using namespace mlir; 30 using namespace mlir::detail; 31 32 //===----------------------------------------------------------------------===// 33 // PDLByteCodePattern 34 //===----------------------------------------------------------------------===// 35 36 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp, 37 ByteCodeAddr rewriterAddr) { 38 SmallVector<StringRef, 8> generatedOps; 39 if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr()) 40 generatedOps = 41 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>()); 42 43 PatternBenefit benefit = matchOp.benefit(); 44 MLIRContext *ctx = matchOp.getContext(); 45 46 // Check to see if this is pattern matches a specific operation type. 
47 if (Optional<StringRef> rootKind = matchOp.rootKind()) 48 return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx, 49 generatedOps); 50 return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx, 51 generatedOps); 52 } 53 54 //===----------------------------------------------------------------------===// 55 // PDLByteCodeMutableState 56 //===----------------------------------------------------------------------===// 57 58 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds 59 /// to the position of the pattern within the range returned by 60 /// `PDLByteCode::getPatterns`. 61 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex, 62 PatternBenefit benefit) { 63 currentPatternBenefits[patternIndex] = benefit; 64 } 65 66 /// Cleanup any allocated state after a full match/rewrite has been completed. 67 /// This method should be called irregardless of whether the match+rewrite was a 68 /// success or not. 69 void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() { 70 allocatedTypeRangeMemory.clear(); 71 allocatedValueRangeMemory.clear(); 72 } 73 74 //===----------------------------------------------------------------------===// 75 // Bytecode OpCodes 76 //===----------------------------------------------------------------------===// 77 78 namespace { 79 enum OpCode : ByteCodeField { 80 /// Apply an externally registered constraint. 81 ApplyConstraint, 82 /// Apply an externally registered rewrite. 83 ApplyRewrite, 84 /// Check if two generic values are equal. 85 AreEqual, 86 /// Check if two ranges are equal. 87 AreRangesEqual, 88 /// Unconditional branch. 89 Branch, 90 /// Compare the operand count of an operation with a constant. 91 CheckOperandCount, 92 /// Compare the name of an operation with a constant. 93 CheckOperationName, 94 /// Compare the result count of an operation with a constant. 95 CheckResultCount, 96 /// Compare a range of types to a constant range of types. 
97 CheckTypes, 98 /// Continue to the next iteration of a loop. 99 Continue, 100 /// Create an operation. 101 CreateOperation, 102 /// Create a range of types. 103 CreateTypes, 104 /// Erase an operation. 105 EraseOp, 106 /// Extract the op from a range at the specified index. 107 ExtractOp, 108 /// Extract the type from a range at the specified index. 109 ExtractType, 110 /// Extract the value from a range at the specified index. 111 ExtractValue, 112 /// Terminate a matcher or rewrite sequence. 113 Finalize, 114 /// Iterate over a range of values. 115 ForEach, 116 /// Get a specific attribute of an operation. 117 GetAttribute, 118 /// Get the type of an attribute. 119 GetAttributeType, 120 /// Get the defining operation of a value. 121 GetDefiningOp, 122 /// Get a specific operand of an operation. 123 GetOperand0, 124 GetOperand1, 125 GetOperand2, 126 GetOperand3, 127 GetOperandN, 128 /// Get a specific operand group of an operation. 129 GetOperands, 130 /// Get a specific result of an operation. 131 GetResult0, 132 GetResult1, 133 GetResult2, 134 GetResult3, 135 GetResultN, 136 /// Get a specific result group of an operation. 137 GetResults, 138 /// Get the users of a value or a range of values. 139 GetUsers, 140 /// Get the type of a value. 141 GetValueType, 142 /// Get the types of a value range. 143 GetValueRangeTypes, 144 /// Check if a generic value is not null. 145 IsNotNull, 146 /// Record a successful pattern match. 147 RecordMatch, 148 /// Replace an operation. 149 ReplaceOp, 150 /// Compare an attribute with a set of constants. 151 SwitchAttribute, 152 /// Compare the operand count of an operation with a set of constants. 153 SwitchOperandCount, 154 /// Compare the name of an operation with a set of constants. 155 SwitchOperationName, 156 /// Compare the result count of an operation with a set of constants. 157 SwitchResultCount, 158 /// Compare a type with a set of constants. 159 SwitchType, 160 /// Compare a range of types with a set of constants. 
161 SwitchTypes, 162 }; 163 } // namespace 164 165 //===----------------------------------------------------------------------===// 166 // ByteCode Generation 167 //===----------------------------------------------------------------------===// 168 169 //===----------------------------------------------------------------------===// 170 // Generator 171 172 namespace { 173 struct ByteCodeLiveRange; 174 struct ByteCodeWriter; 175 176 /// Check if the given class `T` can be converted to an opaque pointer. 177 template <typename T, typename... Args> 178 using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer()); 179 180 /// This class represents the main generator for the pattern bytecode. 181 class Generator { 182 public: 183 Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData, 184 SmallVectorImpl<ByteCodeField> &matcherByteCode, 185 SmallVectorImpl<ByteCodeField> &rewriterByteCode, 186 SmallVectorImpl<PDLByteCodePattern> &patterns, 187 ByteCodeField &maxValueMemoryIndex, 188 ByteCodeField &maxOpRangeMemoryIndex, 189 ByteCodeField &maxTypeRangeMemoryIndex, 190 ByteCodeField &maxValueRangeMemoryIndex, 191 ByteCodeField &maxLoopLevel, 192 llvm::StringMap<PDLConstraintFunction> &constraintFns, 193 llvm::StringMap<PDLRewriteFunction> &rewriteFns) 194 : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), 195 rewriterByteCode(rewriterByteCode), patterns(patterns), 196 maxValueMemoryIndex(maxValueMemoryIndex), 197 maxOpRangeMemoryIndex(maxOpRangeMemoryIndex), 198 maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex), 199 maxValueRangeMemoryIndex(maxValueRangeMemoryIndex), 200 maxLoopLevel(maxLoopLevel) { 201 for (const auto &it : llvm::enumerate(constraintFns)) 202 constraintToMemIndex.try_emplace(it.value().first(), it.index()); 203 for (const auto &it : llvm::enumerate(rewriteFns)) 204 externalRewriterToMemIndex.try_emplace(it.value().first(), it.index()); 205 } 206 207 /// Generate the bytecode for the given PDL interpreter 
module. 208 void generate(ModuleOp module); 209 210 /// Return the memory index to use for the given value. 211 ByteCodeField &getMemIndex(Value value) { 212 assert(valueToMemIndex.count(value) && 213 "expected memory index to be assigned"); 214 return valueToMemIndex[value]; 215 } 216 217 /// Return the range memory index used to store the given range value. 218 ByteCodeField &getRangeStorageIndex(Value value) { 219 assert(valueToRangeIndex.count(value) && 220 "expected range index to be assigned"); 221 return valueToRangeIndex[value]; 222 } 223 224 /// Return an index to use when referring to the given data that is uniqued in 225 /// the MLIR context. 226 template <typename T> 227 std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &> 228 getMemIndex(T val) { 229 const void *opaqueVal = val.getAsOpaquePointer(); 230 231 // Get or insert a reference to this value. 232 auto it = uniquedDataToMemIndex.try_emplace( 233 opaqueVal, maxValueMemoryIndex + uniquedData.size()); 234 if (it.second) 235 uniquedData.push_back(opaqueVal); 236 return it.first->second; 237 } 238 239 private: 240 /// Allocate memory indices for the results of operations within the matcher 241 /// and rewriters. 242 void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc, 243 ModuleOp rewriterModule); 244 245 /// Generate the bytecode for the given operation. 
246 void generate(Region *region, ByteCodeWriter &writer); 247 void generate(Operation *op, ByteCodeWriter &writer); 248 void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer); 249 void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer); 250 void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer); 251 void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer); 252 void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer); 253 void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer); 254 void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer); 255 void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); 256 void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); 257 void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer); 258 void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer); 259 void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); 260 void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); 261 void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); 262 void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer); 263 void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); 264 void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer); 265 void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); 266 void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer); 267 void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); 268 void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); 269 void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); 270 void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer); 271 void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer); 272 void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); 273 void generate(pdl_interp::GetResultsOp op, 
ByteCodeWriter &writer); 274 void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer); 275 void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); 276 void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer); 277 void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); 278 void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer); 279 void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); 280 void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer); 281 void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer); 282 void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer); 283 void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer); 284 void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer); 285 void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer); 286 287 /// Mapping from value to its corresponding memory index. 288 DenseMap<Value, ByteCodeField> valueToMemIndex; 289 290 /// Mapping from a range value to its corresponding range storage index. 291 DenseMap<Value, ByteCodeField> valueToRangeIndex; 292 293 /// Mapping from the name of an externally registered rewrite to its index in 294 /// the bytecode registry. 295 llvm::StringMap<ByteCodeField> externalRewriterToMemIndex; 296 297 /// Mapping from the name of an externally registered constraint to its index 298 /// in the bytecode registry. 299 llvm::StringMap<ByteCodeField> constraintToMemIndex; 300 301 /// Mapping from rewriter function name to the bytecode address of the 302 /// rewriter function in byte. 303 llvm::StringMap<ByteCodeAddr> rewriterToAddr; 304 305 /// Mapping from a uniqued storage object to its memory index within 306 /// `uniquedData`. 307 DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex; 308 309 /// The current level of the foreach loop. 310 ByteCodeField curLoopLevel = 0; 311 312 /// The current MLIR context. 
313 MLIRContext *ctx; 314 315 /// Mapping from block to its address. 316 DenseMap<Block *, ByteCodeAddr> blockToAddr; 317 318 /// Data of the ByteCode class to be populated. 319 std::vector<const void *> &uniquedData; 320 SmallVectorImpl<ByteCodeField> &matcherByteCode; 321 SmallVectorImpl<ByteCodeField> &rewriterByteCode; 322 SmallVectorImpl<PDLByteCodePattern> &patterns; 323 ByteCodeField &maxValueMemoryIndex; 324 ByteCodeField &maxOpRangeMemoryIndex; 325 ByteCodeField &maxTypeRangeMemoryIndex; 326 ByteCodeField &maxValueRangeMemoryIndex; 327 ByteCodeField &maxLoopLevel; 328 }; 329 330 /// This class provides utilities for writing a bytecode stream. 331 struct ByteCodeWriter { 332 ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator) 333 : bytecode(bytecode), generator(generator) {} 334 335 /// Append a field to the bytecode. 336 void append(ByteCodeField field) { bytecode.push_back(field); } 337 void append(OpCode opCode) { bytecode.push_back(opCode); } 338 339 /// Append an address to the bytecode. 340 void append(ByteCodeAddr field) { 341 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 342 "unexpected ByteCode address size"); 343 344 ByteCodeField fieldParts[2]; 345 std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr)); 346 bytecode.append({fieldParts[0], fieldParts[1]}); 347 } 348 349 /// Append a single successor to the bytecode, the exact address will need to 350 /// be resolved later. 351 void append(Block *successor) { 352 // Add back a reference to the successor so that the address can be resolved 353 // later. 354 unresolvedSuccessorRefs[successor].push_back(bytecode.size()); 355 append(ByteCodeAddr(0)); 356 } 357 358 /// Append a successor range to the bytecode, the exact address will need to 359 /// be resolved later. 360 void append(SuccessorRange successors) { 361 for (Block *successor : successors) 362 append(successor); 363 } 364 365 /// Append a range of values that will be read as generic PDLValues. 
366 void appendPDLValueList(OperandRange values) { 367 bytecode.push_back(values.size()); 368 for (Value value : values) 369 appendPDLValue(value); 370 } 371 372 /// Append a value as a PDLValue. 373 void appendPDLValue(Value value) { 374 appendPDLValueKind(value); 375 append(value); 376 } 377 378 /// Append the PDLValue::Kind of the given value. 379 void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); } 380 381 /// Append the PDLValue::Kind of the given type. 382 void appendPDLValueKind(Type type) { 383 PDLValue::Kind kind = 384 TypeSwitch<Type, PDLValue::Kind>(type) 385 .Case<pdl::AttributeType>( 386 [](Type) { return PDLValue::Kind::Attribute; }) 387 .Case<pdl::OperationType>( 388 [](Type) { return PDLValue::Kind::Operation; }) 389 .Case<pdl::RangeType>([](pdl::RangeType rangeTy) { 390 if (rangeTy.getElementType().isa<pdl::TypeType>()) 391 return PDLValue::Kind::TypeRange; 392 return PDLValue::Kind::ValueRange; 393 }) 394 .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; }) 395 .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; }); 396 bytecode.push_back(static_cast<ByteCodeField>(kind)); 397 } 398 399 /// Append a value that will be stored in a memory slot and not inline within 400 /// the bytecode. 401 template <typename T> 402 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value || 403 std::is_pointer<T>::value> 404 append(T value) { 405 bytecode.push_back(generator.getMemIndex(value)); 406 } 407 408 /// Append a range of values. 409 template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>> 410 std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value> 411 append(T range) { 412 bytecode.push_back(llvm::size(range)); 413 for (auto it : range) 414 append(it); 415 } 416 417 /// Append a variadic number of fields to the bytecode. 418 template <typename FieldTy, typename Field2Ty, typename... FieldTys> 419 void append(FieldTy field, Field2Ty field2, FieldTys... 
fields) { 420 append(field); 421 append(field2, fields...); 422 } 423 424 /// Appends a value as a pointer, stored inline within the bytecode. 425 template <typename T> 426 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value> 427 appendInline(T value) { 428 constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField); 429 const void *pointer = value.getAsOpaquePointer(); 430 ByteCodeField fieldParts[numParts]; 431 std::memcpy(fieldParts, &pointer, sizeof(const void *)); 432 bytecode.append(fieldParts, fieldParts + numParts); 433 } 434 435 /// Successor references in the bytecode that have yet to be resolved. 436 DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs; 437 438 /// The underlying bytecode buffer. 439 SmallVectorImpl<ByteCodeField> &bytecode; 440 441 /// The main generator producing PDL. 442 Generator &generator; 443 }; 444 445 /// This class represents a live range of PDL Interpreter values, containing 446 /// information about when values are live within a match/rewrite. 447 struct ByteCodeLiveRange { 448 using Set = llvm::IntervalMap<uint64_t, char, 16>; 449 using Allocator = Set::Allocator; 450 451 ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {} 452 453 /// Union this live range with the one provided. 454 void unionWith(const ByteCodeLiveRange &rhs) { 455 for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e; 456 ++it) 457 liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0); 458 } 459 460 /// Returns true if this range overlaps with the one provided. 461 bool overlaps(const ByteCodeLiveRange &rhs) const { 462 return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness) 463 .valid(); 464 } 465 466 /// A map representing the ranges of the match/rewrite that a value is live in 467 /// the interpreter. 468 /// 469 /// We use std::unique_ptr here, because IntervalMap does not provide a 470 /// correct copy or move constructor. 
We can eliminate the pointer once 471 /// https://reviews.llvm.org/D113240 lands. 472 std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness; 473 474 /// The operation range storage index for this range. 475 Optional<unsigned> opRangeIndex; 476 477 /// The type range storage index for this range. 478 Optional<unsigned> typeRangeIndex; 479 480 /// The value range storage index for this range. 481 Optional<unsigned> valueRangeIndex; 482 }; 483 } // namespace 484 485 void Generator::generate(ModuleOp module) { 486 auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>( 487 pdl_interp::PDLInterpDialect::getMatcherFunctionName()); 488 ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>( 489 pdl_interp::PDLInterpDialect::getRewriterModuleName()); 490 assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module"); 491 492 // Allocate memory indices for the results of operations within the matcher 493 // and rewriters. 494 allocateMemoryIndices(matcherFunc, rewriterModule); 495 496 // Generate code for the rewriter functions. 497 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this); 498 for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) { 499 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size()); 500 for (Operation &op : rewriterFunc.getOps()) 501 generate(&op, rewriterByteCodeWriter); 502 } 503 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() && 504 "unexpected branches in rewriter function"); 505 506 // Generate code for the matcher function. 507 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this); 508 generate(&matcherFunc.getBody(), matcherByteCodeWriter); 509 510 // Resolve successor references in the matcher. 
511 for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) { 512 ByteCodeAddr addr = blockToAddr[it.first]; 513 for (unsigned offsetToFix : it.second) 514 std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr)); 515 } 516 } 517 518 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc, 519 ModuleOp rewriterModule) { 520 // Rewriters use simplistic allocation scheme that simply assigns an index to 521 // each result. 522 for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) { 523 ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0; 524 auto processRewriterValue = [&](Value val) { 525 valueToMemIndex.try_emplace(val, index++); 526 if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) { 527 Type elementTy = rangeType.getElementType(); 528 if (elementTy.isa<pdl::TypeType>()) 529 valueToRangeIndex.try_emplace(val, typeRangeIndex++); 530 else if (elementTy.isa<pdl::ValueType>()) 531 valueToRangeIndex.try_emplace(val, valueRangeIndex++); 532 } 533 }; 534 535 for (BlockArgument arg : rewriterFunc.getArguments()) 536 processRewriterValue(arg); 537 rewriterFunc.getBody().walk([&](Operation *op) { 538 for (Value result : op->getResults()) 539 processRewriterValue(result); 540 }); 541 if (index > maxValueMemoryIndex) 542 maxValueMemoryIndex = index; 543 if (typeRangeIndex > maxTypeRangeMemoryIndex) 544 maxTypeRangeMemoryIndex = typeRangeIndex; 545 if (valueRangeIndex > maxValueRangeMemoryIndex) 546 maxValueRangeMemoryIndex = valueRangeIndex; 547 } 548 549 // The matcher function uses a more sophisticated numbering that tries to 550 // minimize the number of memory indices assigned. This is done by determining 551 // a live range of the values within the matcher, then the allocation is just 552 // finding the minimal number of overlapping live ranges. 
This is essentially 553 // a simplified form of register allocation where we don't necessarily have a 554 // limited number of registers, but we still want to minimize the number used. 555 DenseMap<Operation *, unsigned> opToFirstIndex; 556 DenseMap<Operation *, unsigned> opToLastIndex; 557 558 // A custom walk that marks the first and the last index of each operation. 559 // The entry marks the beginning of the liveness range for this operation, 560 // followed by nested operations, followed by the end of the liveness range. 561 unsigned index = 0; 562 llvm::unique_function<void(Operation *)> walk = [&](Operation *op) { 563 opToFirstIndex.try_emplace(op, index++); 564 for (Region ®ion : op->getRegions()) 565 for (Block &block : region.getBlocks()) 566 for (Operation &nested : block) 567 walk(&nested); 568 opToLastIndex.try_emplace(op, index++); 569 }; 570 walk(matcherFunc); 571 572 // Liveness info for each of the defs within the matcher. 573 ByteCodeLiveRange::Allocator allocator; 574 DenseMap<Value, ByteCodeLiveRange> valueDefRanges; 575 576 // Assign the root operation being matched to slot 0. 577 BlockArgument rootOpArg = matcherFunc.getArgument(0); 578 valueToMemIndex[rootOpArg] = 0; 579 580 // Walk each of the blocks, computing the def interval that the value is used. 581 Liveness matcherLiveness(matcherFunc); 582 matcherFunc->walk([&](Block *block) { 583 const LivenessBlockInfo *info = matcherLiveness.getLiveness(block); 584 assert(info && "expected liveness info for block"); 585 auto processValue = [&](Value value, Operation *firstUseOrDef) { 586 // We don't need to process the root op argument, this value is always 587 // assigned to the first memory slot. 588 if (value == rootOpArg) 589 return; 590 591 // Set indices for the range of this block that the value is used. 
592 auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; 593 defRangeIt->second.liveness->insert( 594 opToFirstIndex[firstUseOrDef], 595 opToLastIndex[info->getEndOperation(value, firstUseOrDef)], 596 /*dummyValue*/ 0); 597 598 // Check to see if this value is a range type. 599 if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) { 600 Type eleType = rangeTy.getElementType(); 601 if (eleType.isa<pdl::OperationType>()) 602 defRangeIt->second.opRangeIndex = 0; 603 else if (eleType.isa<pdl::TypeType>()) 604 defRangeIt->second.typeRangeIndex = 0; 605 else if (eleType.isa<pdl::ValueType>()) 606 defRangeIt->second.valueRangeIndex = 0; 607 } 608 }; 609 610 // Process the live-ins of this block. 611 for (Value liveIn : info->in()) { 612 // Only process the value if it has been defined in the current region. 613 // Other values that span across pdl_interp.foreach will be added higher 614 // up. This ensures that the we keep them alive for the entire duration 615 // of the loop. 616 if (liveIn.getParentRegion() == block->getParent()) 617 processValue(liveIn, &block->front()); 618 } 619 620 // Process the block arguments for the entry block (those are not live-in). 621 if (block->isEntryBlock()) { 622 for (Value argument : block->getArguments()) 623 processValue(argument, &block->front()); 624 } 625 626 // Process any new defs within this block. 627 for (Operation &op : *block) 628 for (Value result : op.getResults()) 629 processValue(result, &op); 630 }); 631 632 // Greedily allocate memory slots using the computed def live ranges. 633 std::vector<ByteCodeLiveRange> allocatedIndices; 634 635 // The number of memory indices currently allocated (and its next value). 636 // Recall that the root gets allocated memory index 0. 637 ByteCodeField numIndices = 1; 638 639 // The number of memory ranges of various types (and their next values). 
640 ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0; 641 642 for (auto &defIt : valueDefRanges) { 643 ByteCodeField &memIndex = valueToMemIndex[defIt.first]; 644 ByteCodeLiveRange &defRange = defIt.second; 645 646 // Try to allocate to an existing index. 647 for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) { 648 ByteCodeLiveRange &existingRange = existingIndexIt.value(); 649 if (!defRange.overlaps(existingRange)) { 650 existingRange.unionWith(defRange); 651 memIndex = existingIndexIt.index() + 1; 652 653 if (defRange.opRangeIndex) { 654 if (!existingRange.opRangeIndex) 655 existingRange.opRangeIndex = numOpRanges++; 656 valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex; 657 } else if (defRange.typeRangeIndex) { 658 if (!existingRange.typeRangeIndex) 659 existingRange.typeRangeIndex = numTypeRanges++; 660 valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex; 661 } else if (defRange.valueRangeIndex) { 662 if (!existingRange.valueRangeIndex) 663 existingRange.valueRangeIndex = numValueRanges++; 664 valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex; 665 } 666 break; 667 } 668 } 669 670 // If no existing index could be used, add a new one. 671 if (memIndex == 0) { 672 allocatedIndices.emplace_back(allocator); 673 ByteCodeLiveRange &newRange = allocatedIndices.back(); 674 newRange.unionWith(defRange); 675 676 // Allocate an index for op/type/value ranges. 
677 if (defRange.opRangeIndex) { 678 newRange.opRangeIndex = numOpRanges; 679 valueToRangeIndex[defIt.first] = numOpRanges++; 680 } else if (defRange.typeRangeIndex) { 681 newRange.typeRangeIndex = numTypeRanges; 682 valueToRangeIndex[defIt.first] = numTypeRanges++; 683 } else if (defRange.valueRangeIndex) { 684 newRange.valueRangeIndex = numValueRanges; 685 valueToRangeIndex[defIt.first] = numValueRanges++; 686 } 687 688 memIndex = allocatedIndices.size(); 689 ++numIndices; 690 } 691 } 692 693 // Print the index usage and ensure that we did not run out of index space. 694 LLVM_DEBUG({ 695 llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices " 696 << "(down from initial " << valueDefRanges.size() << ").\n"; 697 }); 698 assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() && 699 "Ran out of memory for allocated indices"); 700 701 // Update the max number of indices. 702 if (numIndices > maxValueMemoryIndex) 703 maxValueMemoryIndex = numIndices; 704 if (numOpRanges > maxOpRangeMemoryIndex) 705 maxOpRangeMemoryIndex = numOpRanges; 706 if (numTypeRanges > maxTypeRangeMemoryIndex) 707 maxTypeRangeMemoryIndex = numTypeRanges; 708 if (numValueRanges > maxValueRangeMemoryIndex) 709 maxValueRangeMemoryIndex = numValueRanges; 710 } 711 712 void Generator::generate(Region *region, ByteCodeWriter &writer) { 713 llvm::ReversePostOrderTraversal<Region *> rpot(region); 714 for (Block *block : rpot) { 715 // Keep track of where this block begins within the matcher function. 716 blockToAddr.try_emplace(block, matcherByteCode.size()); 717 for (Operation &op : *block) 718 generate(&op, writer); 719 } 720 } 721 722 void Generator::generate(Operation *op, ByteCodeWriter &writer) { 723 LLVM_DEBUG({ 724 // The following list must contain all the operations that do not 725 // produce any bytecode. 
if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp,
             pdl_interp::InferredTypesOp>(op))
      writer.appendInline(op->getLoc());
  });
  // Dispatch to the overload for the concrete `pdl_interp` operation. Any
  // operation kind not listed here is treated as unreachable.
  TypeSwitch<Operation *>(op)
      .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
            pdl_interp::AreEqualOp, pdl_interp::BranchOp,
            pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
            pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
            pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
            pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
            pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
            pdl_interp::CreateTypesOp, pdl_interp::EraseOp,
            pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
            pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
            pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
            pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
            pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
            pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
            pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp,
            pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp,
            pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp,
            pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp,
            pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
          [&](auto interpOp) { this->generate(interpOp, writer); })
      .Default([](Operation *) {
        llvm_unreachable("unknown `pdl_interp` operation");
      });
}

// NOTE: The order of the `writer.append*` calls in each overload below defines
// the bytecode layout, and must match the order in which the corresponding
// `ByteCodeExecutor::execute*` method reads its operands.

/// Layout: ApplyConstraint, constraint function index (from
/// `constraintToMemIndex`), constant params attribute, argument list,
/// successors.
void Generator::generate(pdl_interp::ApplyConstraintOp op,
                         ByteCodeWriter &writer) {
  assert(constraintToMemIndex.count(op.name()) &&
         "expected index for constraint function");
  writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()],
                op.constParamsAttr());
  writer.appendPDLValueList(op.args());
  writer.append(op.getSuccessors());
}
/// Layout: ApplyRewrite, rewrite function index (from
/// `externalRewriterToMemIndex`), constant params attribute, argument list,
/// result count, then for each result: [value kind, !NDEBUG only],
/// [range storage index, range results only], the result value.
void Generator::generate(pdl_interp::ApplyRewriteOp op,
                         ByteCodeWriter &writer) {
  assert(externalRewriterToMemIndex.count(op.name()) &&
         "expected index for rewrite function");
  writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()],
                op.constParamsAttr());
  writer.appendPDLValueList(op.args());

  ResultRange results = op.results();
  writer.append(ByteCodeField(results.size()));
  for (Value result : results) {
    // In debug mode we also record the expected kind of the result, so that we
    // can provide extra verification of the native rewrite function.
#ifndef NDEBUG
    writer.appendPDLValueKind(result);
#endif

    // Range results also need to append the range storage index.
    if (result.getType().isa<pdl::RangeType>())
      writer.append(getRangeStorageIndex(result));
    writer.append(result);
  }
}
/// Ranges lower to AreRangesEqual, which additionally records the value kind
/// so the executor knows which range type to compare. Everything else lowers
/// to the generic AreEqual.
void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
  Value lhs = op.lhs();
  if (lhs.getType().isa<pdl::RangeType>()) {
    writer.append(OpCode::AreRangesEqual);
    writer.appendPDLValueKind(lhs);
    writer.append(op.lhs(), op.rhs(), op.getSuccessors());
    return;
  }

  writer.append(OpCode::AreEqual, lhs, op.rhs(), op.getSuccessors());
}
/// Layout: Branch, destination successor.
void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
}
/// Attribute checks reuse the generic AreEqual against the constant value.
void Generator::generate(pdl_interp::CheckAttributeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(),
                op.getSuccessors());
}
/// Layout: CheckOperandCount, operation, expected count, compareAtLeast flag
/// (encoded as a ByteCodeField 0/1), successors.
void Generator::generate(pdl_interp::CheckOperandCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CheckOperandCount, op.operation(), op.count(),
                static_cast<ByteCodeField>(op.compareAtLeast()),
                op.getSuccessors());
}
/// Layout: CheckOperationName, operation, expected OperationName, successors.
void Generator::generate(pdl_interp::CheckOperationNameOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CheckOperationName, op.operation(),
                OperationName(op.name(), ctx),
                op.getSuccessors());
}
/// Layout: CheckResultCount, operation, expected count, compareAtLeast flag
/// (encoded as a ByteCodeField 0/1), successors.
void Generator::generate(pdl_interp::CheckResultCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CheckResultCount, op.operation(), op.count(),
                static_cast<ByteCodeField>(op.compareAtLeast()),
                op.getSuccessors());
}
/// Type checks reuse the generic AreEqual against the constant type.
void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors());
}
/// Layout: CheckTypes, type range value, constant types attribute, successors.
void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::CheckTypes, op.value(), op.types(), op.getSuccessors());
}
/// Layout: Continue, level of the innermost enclosing loop. `curLoopLevel` has
/// already been incremented by the enclosing ForEach, hence the `- 1`.
void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
  assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
  writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
}
/// Emits no bytecode: the result simply aliases the constant's memory slot.
void Generator::generate(pdl_interp::CreateAttributeOp op,
                         ByteCodeWriter &writer) {
  // Simply repoint the memory index of the result to the constant.
  getMemIndex(op.attribute()) = getMemIndex(op.value());
}
/// Layout: CreateOperation, result operation, OperationName, operand list,
/// attribute count, (name, value) pairs, result type list.
void Generator::generate(pdl_interp::CreateOperationOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CreateOperation, op.operation(),
                OperationName(op.name(), ctx));
  writer.appendPDLValueList(op.operands());

  // Add the attributes.
  OperandRange attributes = op.attributes();
  writer.append(static_cast<ByteCodeField>(attributes.size()));
  for (auto it : llvm::zip(op.attributeNames(), op.attributes()))
    writer.append(std::get<0>(it), std::get<1>(it));
  writer.appendPDLValueList(op.types());
}
/// Emits no bytecode: the result simply aliases the constant's memory slot.
void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
  // Simply repoint the memory index of the result to the constant.
  getMemIndex(op.result()) = getMemIndex(op.value());
}
/// Layout: CreateTypes, result, result range storage index, types attribute.
void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::CreateTypes, op.result(),
                getRangeStorageIndex(op.result()), op.value());
}
/// Layout: EraseOp, operation to erase.
void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::EraseOp, op.operation());
}
/// The opcode is selected from the element type of the result. Layout:
/// Extract{Op,Value,Type}, range, index, result.
void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
  OpCode opCode =
      TypeSwitch<Type, OpCode>(op.result().getType())
          .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
          .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
          .Case([](pdl::TypeType) { return OpCode::ExtractType; })
          .Default([](Type) -> OpCode {
            llvm_unreachable("unsupported element type");
          });
  writer.append(opCode, op.range(), op.index(), op.result());
}
/// Layout: Finalize (no operands).
void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::Finalize);
}
/// Layout: ForEach, range storage index of the iterated values, loop
/// variable, its value kind, loop nesting level, exit successor. The loop body
/// is generated inline right after. This also tracks the maximum nesting depth
/// so that enough loop index slots can be allocated.
void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
  BlockArgument arg = op.getLoopVariable();
  writer.append(OpCode::ForEach, getRangeStorageIndex(op.values()), arg);
  writer.appendPDLValueKind(arg.getType());
  writer.append(curLoopLevel, op.successor());
  ++curLoopLevel;
  if (curLoopLevel > maxLoopLevel)
    maxLoopLevel = curLoopLevel;
  generate(&op.region(), writer);
  --curLoopLevel;
}
/// Layout: GetAttribute, result, operation, attribute name.
void Generator::generate(pdl_interp::GetAttributeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::GetAttribute, op.attribute(), op.operation(),
                op.nameAttr());
}
/// Layout: GetAttributeType, result, attribute.
void Generator::generate(pdl_interp::GetAttributeTypeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::GetAttributeType, op.result(), op.value());
}
/// Layout: GetDefiningOp, result operation, kind-tagged input value (a single
/// value or a value range).
void Generator::generate(pdl_interp::GetDefiningOpOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::GetDefiningOp, op.operation());
  writer.appendPDLValue(op.value());
}
/// Indices 0-3 use dedicated opcodes so that the index need not be encoded.
/// This relies on GetOperand0..GetOperand3 being consecutive in `OpCode`.
/// Other indices use GetOperandN with an explicit index.
void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
  uint32_t index = op.index();
  if (index < 4)
    writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
  else
    writer.append(OpCode::GetOperandN, index);
  writer.append(op.operation(), op.value());
}
/// Layout: GetOperands, index (uint32_t max when no index is given),
/// operation, range storage index (ByteCodeField max when the result is not
/// a range), result.
void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
  Value result = op.value();
  Optional<uint32_t> index = op.index();
  writer.append(OpCode::GetOperands,
                index.getValueOr(std::numeric_limits<uint32_t>::max()),
                op.operation());
  if (result.getType().isa<pdl::RangeType>())
    writer.append(getRangeStorageIndex(result));
  else
    writer.append(std::numeric_limits<ByteCodeField>::max());
  writer.append(result);
}
/// Indices 0-3 use dedicated opcodes so that the index need not be encoded.
/// This relies on GetResult0..GetResult3 being consecutive in `OpCode`.
/// Other indices use GetResultN with an explicit index.
void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
  uint32_t index = op.index();
  if (index < 4)
    writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
  else
    writer.append(OpCode::GetResultN, index);
  writer.append(op.operation(), op.value());
}
/// Layout: GetResults, index (uint32_t max when no index is given),
/// operation, range storage index (ByteCodeField max when the result is not
/// a range), result.
void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
  Value result = op.value();
  Optional<uint32_t> index = op.index();
  writer.append(OpCode::GetResults,
                index.getValueOr(std::numeric_limits<uint32_t>::max()),
                op.operation());
  if (result.getType().isa<pdl::RangeType>())
    writer.append(getRangeStorageIndex(result));
  else
    writer.append(std::numeric_limits<ByteCodeField>::max());
  writer.append(result);
}
/// Layout: GetUsers, result operation range, its range storage index,
/// kind-tagged input value.
void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
  Value operations = op.operations();
  ByteCodeField rangeIndex = getRangeStorageIndex(operations);
  writer.append(OpCode::GetUsers, operations, rangeIndex);
  writer.appendPDLValue(op.value());
}
/// Range results use GetValueRangeTypes, which also needs a range storage
/// index. Single values use GetValueType.
void Generator::generate(pdl_interp::GetValueTypeOp op,
                         ByteCodeWriter &writer) {
  if (op.getType().isa<pdl::RangeType>()) {
    Value result = op.result();
    writer.append(OpCode::GetValueRangeTypes, result,
                  getRangeStorageIndex(result), op.value());
  } else {
    writer.append(OpCode::GetValueType, op.result(), op.value());
  }
}

/// Emits no bytecode: the result aliases the memory slot of a null Type.
void Generator::generate(pdl_interp::InferredTypesOp op,
                         ByteCodeWriter &writer) {
  // InferType maps to a null type as a marker for inferring result types.
  getMemIndex(op.type()) = getMemIndex(Type());
}
/// Layout: IsNotNull, value, successors.
void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors());
}
/// Registers a new PDLByteCodePattern (bound to the rewriter's bytecode
/// address) and encodes: RecordMatch, pattern index, successor, matched ops,
/// rewriter inputs.
void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
  ByteCodeField patternIndex = patterns.size();
  patterns.emplace_back(PDLByteCodePattern::create(
      op, rewriterToAddr[op.rewriter().getLeafReference().getValue()]));
  writer.append(OpCode::RecordMatch, patternIndex,
                SuccessorRange(op.getOperation()), op.matchedOps());
  writer.appendPDLValueList(op.inputs());
}
/// Layout: ReplaceOp, operation, replacement value list.
void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::ReplaceOp, op.operation());
  writer.appendPDLValueList(op.replValues());
}
/// Layout: SwitchAttribute, attribute, case values, successors.
void Generator::generate(pdl_interp::SwitchAttributeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(),
                op.getSuccessors());
}
/// Layout: SwitchOperandCount, operation, case values, successors.
void Generator::generate(pdl_interp::SwitchOperandCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(),
                op.getSuccessors());
}
/// Case names are converted to OperationName up front so that the executor
/// can compare them directly against `op->getName()`.
void Generator::generate(pdl_interp::SwitchOperationNameOp op,
                         ByteCodeWriter &writer) {
  auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) {
    return OperationName(attr.cast<StringAttr>().getValue(), ctx);
  });
  writer.append(OpCode::SwitchOperationName,
                op.operation(), cases,
                op.getSuccessors());
}
/// Layout: SwitchResultCount, operation, case values, successors.
void Generator::generate(pdl_interp::SwitchResultCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(),
                op.getSuccessors());
}
/// Layout: SwitchType, type value, case values, successors.
void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(),
                op.getSuccessors());
}
/// Layout: SwitchTypes, type range value, case values, successors.
void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchTypes, op.value(), op.caseValuesAttr(),
                op.getSuccessors());
}

//===----------------------------------------------------------------------===//
// PDLByteCode
//===----------------------------------------------------------------------===//

/// Generate the matcher and rewriter bytecode for `module`, then take
/// ownership of the external constraint and rewrite functions so that the
/// executor can look them up by index.
PDLByteCode::PDLByteCode(ModuleOp module,
                         llvm::StringMap<PDLConstraintFunction> constraintFns,
                         llvm::StringMap<PDLRewriteFunction> rewriteFns) {
  Generator generator(module.getContext(), uniquedData, matcherByteCode,
                      rewriterByteCode, patterns, maxValueMemoryIndex,
                      maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
                      maxLoopLevel, constraintFns, rewriteFns);
  generator.generate(module);

  // Initialize the external functions.
  // NOTE(review): the position of each function in these vectors must agree
  // with the indices the Generator encoded (e.g. via `constraintToMemIndex`).
  // This relies on both walking the same, unmodified StringMaps in the same
  // iteration order; confirm against the Generator constructor.
  for (auto &it : constraintFns)
    constraintFunctions.push_back(std::move(it.second));
  for (auto &it : rewriteFns)
    rewriteFunctions.push_back(std::move(it.second));
}

/// Initialize the given state such that it can be used to execute the current
/// bytecode.
1033 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { 1034 state.memory.resize(maxValueMemoryIndex, nullptr); 1035 state.opRangeMemory.resize(maxOpRangeCount); 1036 state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange()); 1037 state.valueRangeMemory.resize(maxValueRangeCount, ValueRange()); 1038 state.loopIndex.resize(maxLoopLevel, 0); 1039 state.currentPatternBenefits.reserve(patterns.size()); 1040 for (const PDLByteCodePattern &pattern : patterns) 1041 state.currentPatternBenefits.push_back(pattern.getBenefit()); 1042 } 1043 1044 //===----------------------------------------------------------------------===// 1045 // ByteCode Execution 1046 1047 namespace { 1048 /// This class provides support for executing a bytecode stream. 1049 class ByteCodeExecutor { 1050 public: 1051 ByteCodeExecutor( 1052 const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory, 1053 MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory, 1054 MutableArrayRef<TypeRange> typeRangeMemory, 1055 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory, 1056 MutableArrayRef<ValueRange> valueRangeMemory, 1057 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory, 1058 MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory, 1059 ArrayRef<ByteCodeField> code, 1060 ArrayRef<PatternBenefit> currentPatternBenefits, 1061 ArrayRef<PDLByteCodePattern> patterns, 1062 ArrayRef<PDLConstraintFunction> constraintFunctions, 1063 ArrayRef<PDLRewriteFunction> rewriteFunctions) 1064 : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory), 1065 typeRangeMemory(typeRangeMemory), 1066 allocatedTypeRangeMemory(allocatedTypeRangeMemory), 1067 valueRangeMemory(valueRangeMemory), 1068 allocatedValueRangeMemory(allocatedValueRangeMemory), 1069 loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code), 1070 currentPatternBenefits(currentPatternBenefits), patterns(patterns), 1071 
constraintFunctions(constraintFunctions), 1072 rewriteFunctions(rewriteFunctions) {} 1073 1074 /// Start executing the code at the current bytecode index. `matches` is an 1075 /// optional field provided when this function is executed in a matching 1076 /// context. 1077 void execute(PatternRewriter &rewriter, 1078 SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr, 1079 Optional<Location> mainRewriteLoc = {}); 1080 1081 private: 1082 /// Internal implementation of executing each of the bytecode commands. 1083 void executeApplyConstraint(PatternRewriter &rewriter); 1084 void executeApplyRewrite(PatternRewriter &rewriter); 1085 void executeAreEqual(); 1086 void executeAreRangesEqual(); 1087 void executeBranch(); 1088 void executeCheckOperandCount(); 1089 void executeCheckOperationName(); 1090 void executeCheckResultCount(); 1091 void executeCheckTypes(); 1092 void executeContinue(); 1093 void executeCreateOperation(PatternRewriter &rewriter, 1094 Location mainRewriteLoc); 1095 void executeCreateTypes(); 1096 void executeEraseOp(PatternRewriter &rewriter); 1097 template <typename T, typename Range, PDLValue::Kind kind> 1098 void executeExtract(); 1099 void executeFinalize(); 1100 void executeForEach(); 1101 void executeGetAttribute(); 1102 void executeGetAttributeType(); 1103 void executeGetDefiningOp(); 1104 void executeGetOperand(unsigned index); 1105 void executeGetOperands(); 1106 void executeGetResult(unsigned index); 1107 void executeGetResults(); 1108 void executeGetUsers(); 1109 void executeGetValueType(); 1110 void executeGetValueRangeTypes(); 1111 void executeIsNotNull(); 1112 void executeRecordMatch(PatternRewriter &rewriter, 1113 SmallVectorImpl<PDLByteCode::MatchResult> &matches); 1114 void executeReplaceOp(PatternRewriter &rewriter); 1115 void executeSwitchAttribute(); 1116 void executeSwitchOperandCount(); 1117 void executeSwitchOperationName(); 1118 void executeSwitchResultCount(); 1119 void executeSwitchType(); 1120 void 
executeSwitchTypes(); 1121 1122 /// Pushes a code iterator to the stack. 1123 void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); } 1124 1125 /// Pops a code iterator from the stack, returning true on success. 1126 void popCodeIt() { 1127 assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack"); 1128 curCodeIt = resumeCodeIt.back(); 1129 resumeCodeIt.pop_back(); 1130 } 1131 1132 /// Return the bytecode iterator at the start of the current op code. 1133 const ByteCodeField *getPrevCodeIt() const { 1134 LLVM_DEBUG({ 1135 // Account for the op code and the Location stored inline. 1136 return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField); 1137 }); 1138 1139 // Account for the op code only. 1140 return curCodeIt - 1; 1141 } 1142 1143 /// Read a value from the bytecode buffer, optionally skipping a certain 1144 /// number of prefix values. These methods always update the buffer to point 1145 /// to the next field after the read data. 1146 template <typename T = ByteCodeField> 1147 T read(size_t skipN = 0) { 1148 curCodeIt += skipN; 1149 return readImpl<T>(); 1150 } 1151 ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); } 1152 1153 /// Read a list of values from the bytecode buffer. 1154 template <typename ValueT, typename T> 1155 void readList(SmallVectorImpl<T> &list) { 1156 list.clear(); 1157 for (unsigned i = 0, e = read(); i != e; ++i) 1158 list.push_back(read<ValueT>()); 1159 } 1160 1161 /// Read a list of values from the bytecode buffer. The values may be encoded 1162 /// as either Value or ValueRange elements. 
1163 void readValueList(SmallVectorImpl<Value> &list) { 1164 for (unsigned i = 0, e = read(); i != e; ++i) { 1165 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1166 list.push_back(read<Value>()); 1167 } else { 1168 ValueRange *values = read<ValueRange *>(); 1169 list.append(values->begin(), values->end()); 1170 } 1171 } 1172 } 1173 1174 /// Read a value stored inline as a pointer. 1175 template <typename T> 1176 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T> 1177 readInline() { 1178 const void *pointer; 1179 std::memcpy(&pointer, curCodeIt, sizeof(const void *)); 1180 curCodeIt += sizeof(const void *) / sizeof(ByteCodeField); 1181 return T::getFromOpaquePointer(pointer); 1182 } 1183 1184 /// Jump to a specific successor based on a predicate value. 1185 void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); } 1186 /// Jump to a specific successor based on a destination index. 1187 void selectJump(size_t destIndex) { 1188 curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)]; 1189 } 1190 1191 /// Handle a switch operation with the provided value and cases. 1192 template <typename T, typename RangeT, typename Comparator = std::equal_to<T>> 1193 void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) { 1194 LLVM_DEBUG({ 1195 llvm::dbgs() << " * Value: " << value << "\n" 1196 << " * Cases: "; 1197 llvm::interleaveComma(cases, llvm::dbgs()); 1198 llvm::dbgs() << "\n"; 1199 }); 1200 1201 // Check to see if the attribute value is within the case list. Jump to 1202 // the correct successor index based on the result. 1203 for (auto it = cases.begin(), e = cases.end(); it != e; ++it) 1204 if (cmp(*it, value)) 1205 return selectJump(size_t((it - cases.begin()) + 1)); 1206 selectJump(size_t(0)); 1207 } 1208 1209 /// Store a pointer to memory. 1210 void storeToMemory(unsigned index, const void *value) { 1211 memory[index] = value; 1212 } 1213 1214 /// Store a value to memory as an opaque pointer. 
1215 template <typename T> 1216 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value> 1217 storeToMemory(unsigned index, T value) { 1218 memory[index] = value.getAsOpaquePointer(); 1219 } 1220 1221 /// Internal implementation of reading various data types from the bytecode 1222 /// stream. 1223 template <typename T> 1224 const void *readFromMemory() { 1225 size_t index = *curCodeIt++; 1226 1227 // If this type is an SSA value, it can only be stored in non-const memory. 1228 if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *, 1229 Value>::value || 1230 index < memory.size()) 1231 return memory[index]; 1232 1233 // Otherwise, if this index is not inbounds it is uniqued. 1234 return uniquedMemory[index - memory.size()]; 1235 } 1236 template <typename T> 1237 std::enable_if_t<std::is_pointer<T>::value, T> readImpl() { 1238 return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>())); 1239 } 1240 template <typename T> 1241 std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value, 1242 T> 1243 readImpl() { 1244 return T(T::getFromOpaquePointer(readFromMemory<T>())); 1245 } 1246 template <typename T> 1247 std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() { 1248 switch (read<PDLValue::Kind>()) { 1249 case PDLValue::Kind::Attribute: 1250 return read<Attribute>(); 1251 case PDLValue::Kind::Operation: 1252 return read<Operation *>(); 1253 case PDLValue::Kind::Type: 1254 return read<Type>(); 1255 case PDLValue::Kind::Value: 1256 return read<Value>(); 1257 case PDLValue::Kind::TypeRange: 1258 return read<TypeRange *>(); 1259 case PDLValue::Kind::ValueRange: 1260 return read<ValueRange *>(); 1261 } 1262 llvm_unreachable("unhandled PDLValue::Kind"); 1263 } 1264 template <typename T> 1265 std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() { 1266 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 1267 "unexpected ByteCode address size"); 1268 ByteCodeAddr result; 1269 
std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr)); 1270 curCodeIt += 2; 1271 return result; 1272 } 1273 template <typename T> 1274 std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() { 1275 return *curCodeIt++; 1276 } 1277 template <typename T> 1278 std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() { 1279 return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>()); 1280 } 1281 1282 /// The underlying bytecode buffer. 1283 const ByteCodeField *curCodeIt; 1284 1285 /// The stack of bytecode positions at which to resume operation. 1286 SmallVector<const ByteCodeField *> resumeCodeIt; 1287 1288 /// The current execution memory. 1289 MutableArrayRef<const void *> memory; 1290 MutableArrayRef<OwningOpRange> opRangeMemory; 1291 MutableArrayRef<TypeRange> typeRangeMemory; 1292 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory; 1293 MutableArrayRef<ValueRange> valueRangeMemory; 1294 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory; 1295 1296 /// The current loop indices. 1297 MutableArrayRef<unsigned> loopIndex; 1298 1299 /// References to ByteCode data necessary for execution. 1300 ArrayRef<const void *> uniquedMemory; 1301 ArrayRef<ByteCodeField> code; 1302 ArrayRef<PatternBenefit> currentPatternBenefits; 1303 ArrayRef<PDLByteCodePattern> patterns; 1304 ArrayRef<PDLConstraintFunction> constraintFunctions; 1305 ArrayRef<PDLRewriteFunction> rewriteFunctions; 1306 }; 1307 1308 /// This class is an instantiation of the PDLResultList that provides access to 1309 /// the returned results. This API is not on `PDLResultList` to avoid 1310 /// overexposing access to information specific solely to the ByteCode. 1311 class ByteCodeRewriteResultList : public PDLResultList { 1312 public: 1313 ByteCodeRewriteResultList(unsigned maxNumResults) 1314 : PDLResultList(maxNumResults) {} 1315 1316 /// Return the list of PDL results. 
1317 MutableArrayRef<PDLValue> getResults() { return results; } 1318 1319 /// Return the type ranges allocated by this list. 1320 MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() { 1321 return allocatedTypeRanges; 1322 } 1323 1324 /// Return the value ranges allocated by this list. 1325 MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() { 1326 return allocatedValueRanges; 1327 } 1328 }; 1329 } // namespace 1330 1331 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { 1332 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); 1333 const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; 1334 ArrayAttr constParams = read<ArrayAttr>(); 1335 SmallVector<PDLValue, 16> args; 1336 readList<PDLValue>(args); 1337 1338 LLVM_DEBUG({ 1339 llvm::dbgs() << " * Arguments: "; 1340 llvm::interleaveComma(args, llvm::dbgs()); 1341 llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; 1342 }); 1343 1344 // Invoke the constraint and jump to the proper destination. 1345 selectJump(succeeded(constraintFn(args, constParams, rewriter))); 1346 } 1347 1348 void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { 1349 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); 1350 const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; 1351 ArrayAttr constParams = read<ArrayAttr>(); 1352 SmallVector<PDLValue, 16> args; 1353 readList<PDLValue>(args); 1354 1355 LLVM_DEBUG({ 1356 llvm::dbgs() << " * Arguments: "; 1357 llvm::interleaveComma(args, llvm::dbgs()); 1358 llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; 1359 }); 1360 1361 // Execute the rewrite function. 
1362 ByteCodeField numResults = read(); 1363 ByteCodeRewriteResultList results(numResults); 1364 rewriteFn(args, constParams, rewriter, results); 1365 1366 assert(results.getResults().size() == numResults && 1367 "native PDL rewrite function returned unexpected number of results"); 1368 1369 // Store the results in the bytecode memory. 1370 for (PDLValue &result : results.getResults()) { 1371 LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); 1372 1373 // In debug mode we also verify the expected kind of the result. 1374 #ifndef NDEBUG 1375 assert(result.getKind() == read<PDLValue::Kind>() && 1376 "native PDL rewrite function returned an unexpected type of result"); 1377 #endif 1378 1379 // If the result is a range, we need to copy it over to the bytecodes 1380 // range memory. 1381 if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) { 1382 unsigned rangeIndex = read(); 1383 typeRangeMemory[rangeIndex] = *typeRange; 1384 memory[read()] = &typeRangeMemory[rangeIndex]; 1385 } else if (Optional<ValueRange> valueRange = 1386 result.dyn_cast<ValueRange>()) { 1387 unsigned rangeIndex = read(); 1388 valueRangeMemory[rangeIndex] = *valueRange; 1389 memory[read()] = &valueRangeMemory[rangeIndex]; 1390 } else { 1391 memory[read()] = result.getAsOpaquePointer(); 1392 } 1393 } 1394 1395 // Copy over any underlying storage allocated for result ranges. 
1396 for (auto &it : results.getAllocatedTypeRanges()) 1397 allocatedTypeRangeMemory.push_back(std::move(it)); 1398 for (auto &it : results.getAllocatedValueRanges()) 1399 allocatedValueRangeMemory.push_back(std::move(it)); 1400 } 1401 1402 void ByteCodeExecutor::executeAreEqual() { 1403 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 1404 const void *lhs = read<const void *>(); 1405 const void *rhs = read<const void *>(); 1406 1407 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n"); 1408 selectJump(lhs == rhs); 1409 } 1410 1411 void ByteCodeExecutor::executeAreRangesEqual() { 1412 LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n"); 1413 PDLValue::Kind valueKind = read<PDLValue::Kind>(); 1414 const void *lhs = read<const void *>(); 1415 const void *rhs = read<const void *>(); 1416 1417 switch (valueKind) { 1418 case PDLValue::Kind::TypeRange: { 1419 const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs); 1420 const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs); 1421 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1422 selectJump(*lhsRange == *rhsRange); 1423 break; 1424 } 1425 case PDLValue::Kind::ValueRange: { 1426 const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs); 1427 const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs); 1428 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1429 selectJump(*lhsRange == *rhsRange); 1430 break; 1431 } 1432 default: 1433 llvm_unreachable("unexpected `AreRangesEqual` value kind"); 1434 } 1435 } 1436 1437 void ByteCodeExecutor::executeBranch() { 1438 LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n"); 1439 curCodeIt = &code[read<ByteCodeAddr>()]; 1440 } 1441 1442 void ByteCodeExecutor::executeCheckOperandCount() { 1443 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); 1444 Operation *op = read<Operation *>(); 1445 uint32_t expectedCount = read<uint32_t>(); 1446 bool compareAtLeast = read(); 1447 1448 
LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" 1449 << " * Expected: " << expectedCount << "\n" 1450 << " * Comparator: " 1451 << (compareAtLeast ? ">=" : "==") << "\n"); 1452 if (compareAtLeast) 1453 selectJump(op->getNumOperands() >= expectedCount); 1454 else 1455 selectJump(op->getNumOperands() == expectedCount); 1456 } 1457 1458 void ByteCodeExecutor::executeCheckOperationName() { 1459 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n"); 1460 Operation *op = read<Operation *>(); 1461 OperationName expectedName = read<OperationName>(); 1462 1463 LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n" 1464 << " * Expected: \"" << expectedName << "\"\n"); 1465 selectJump(op->getName() == expectedName); 1466 } 1467 1468 void ByteCodeExecutor::executeCheckResultCount() { 1469 LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); 1470 Operation *op = read<Operation *>(); 1471 uint32_t expectedCount = read<uint32_t>(); 1472 bool compareAtLeast = read(); 1473 1474 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" 1475 << " * Expected: " << expectedCount << "\n" 1476 << " * Comparator: " 1477 << (compareAtLeast ? 
">=" : "==") << "\n"); 1478 if (compareAtLeast) 1479 selectJump(op->getNumResults() >= expectedCount); 1480 else 1481 selectJump(op->getNumResults() == expectedCount); 1482 } 1483 1484 void ByteCodeExecutor::executeCheckTypes() { 1485 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 1486 TypeRange *lhs = read<TypeRange *>(); 1487 Attribute rhs = read<Attribute>(); 1488 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1489 1490 selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>()); 1491 } 1492 1493 void ByteCodeExecutor::executeContinue() { 1494 ByteCodeField level = read(); 1495 LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n" 1496 << " * Level: " << level << "\n"); 1497 ++loopIndex[level]; 1498 popCodeIt(); 1499 } 1500 1501 void ByteCodeExecutor::executeCreateTypes() { 1502 LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n"); 1503 unsigned memIndex = read(); 1504 unsigned rangeIndex = read(); 1505 ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>(); 1506 1507 LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n"); 1508 1509 // Allocate a buffer for this type range. 1510 llvm::OwningArrayRef<Type> storage(typesAttr.size()); 1511 llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin()); 1512 allocatedTypeRangeMemory.emplace_back(std::move(storage)); 1513 1514 // Assign this to the range slot and use the range as the value for the 1515 // memory index. 
1516 typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back(); 1517 memory[memIndex] = &typeRangeMemory[rangeIndex]; 1518 } 1519 1520 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, 1521 Location mainRewriteLoc) { 1522 LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); 1523 1524 unsigned memIndex = read(); 1525 OperationState state(mainRewriteLoc, read<OperationName>()); 1526 readValueList(state.operands); 1527 for (unsigned i = 0, e = read(); i != e; ++i) { 1528 StringAttr name = read<StringAttr>(); 1529 if (Attribute attr = read<Attribute>()) 1530 state.addAttribute(name, attr); 1531 } 1532 1533 for (unsigned i = 0, e = read(); i != e; ++i) { 1534 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) { 1535 state.types.push_back(read<Type>()); 1536 continue; 1537 } 1538 1539 // If we find a null range, this signals that the types are infered. 1540 if (TypeRange *resultTypes = read<TypeRange *>()) { 1541 state.types.append(resultTypes->begin(), resultTypes->end()); 1542 continue; 1543 } 1544 1545 // Handle the case where the operation has inferred types. 1546 InferTypeOpInterface::Concept *inferInterface = 1547 state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>(); 1548 1549 // TODO: Handle failure. 
1550 state.types.clear(); 1551 if (failed(inferInterface->inferReturnTypes( 1552 state.getContext(), state.location, state.operands, 1553 state.attributes.getDictionary(state.getContext()), state.regions, 1554 state.types))) 1555 return; 1556 break; 1557 } 1558 1559 Operation *resultOp = rewriter.createOperation(state); 1560 memory[memIndex] = resultOp; 1561 1562 LLVM_DEBUG({ 1563 llvm::dbgs() << " * Attributes: " 1564 << state.attributes.getDictionary(state.getContext()) 1565 << "\n * Operands: "; 1566 llvm::interleaveComma(state.operands, llvm::dbgs()); 1567 llvm::dbgs() << "\n * Result Types: "; 1568 llvm::interleaveComma(state.types, llvm::dbgs()); 1569 llvm::dbgs() << "\n * Result: " << *resultOp << "\n"; 1570 }); 1571 } 1572 1573 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) { 1574 LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n"); 1575 Operation *op = read<Operation *>(); 1576 1577 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1578 rewriter.eraseOp(op); 1579 } 1580 1581 template <typename T, typename Range, PDLValue::Kind kind> 1582 void ByteCodeExecutor::executeExtract() { 1583 LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n"); 1584 Range *range = read<Range *>(); 1585 unsigned index = read<uint32_t>(); 1586 unsigned memIndex = read(); 1587 1588 if (!range) { 1589 memory[memIndex] = nullptr; 1590 return; 1591 } 1592 1593 T result = index < range->size() ? 
(*range)[index] : T(); 1594 LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n" 1595 << " * Index: " << index << "\n" 1596 << " * Result: " << result << "\n"); 1597 storeToMemory(memIndex, result); 1598 } 1599 1600 void ByteCodeExecutor::executeFinalize() { 1601 LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n"); 1602 } 1603 1604 void ByteCodeExecutor::executeForEach() { 1605 LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n"); 1606 const ByteCodeField *prevCodeIt = getPrevCodeIt(); 1607 unsigned rangeIndex = read(); 1608 unsigned memIndex = read(); 1609 const void *value = nullptr; 1610 1611 switch (read<PDLValue::Kind>()) { 1612 case PDLValue::Kind::Operation: { 1613 unsigned &index = loopIndex[read()]; 1614 ArrayRef<Operation *> array = opRangeMemory[rangeIndex]; 1615 assert(index <= array.size() && "iterated past the end"); 1616 if (index < array.size()) { 1617 LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n"); 1618 value = array[index]; 1619 break; 1620 } 1621 1622 LLVM_DEBUG(llvm::dbgs() << " * Done\n"); 1623 index = 0; 1624 selectJump(size_t(0)); 1625 return; 1626 } 1627 default: 1628 llvm_unreachable("unexpected `ForEach` value kind"); 1629 } 1630 1631 // Store the iterate value and the stack address. 1632 memory[memIndex] = value; 1633 pushCodeIt(prevCodeIt); 1634 1635 // Skip over the successor (we will enter the body of the loop). 
1636 read<ByteCodeAddr>(); 1637 } 1638 1639 void ByteCodeExecutor::executeGetAttribute() { 1640 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); 1641 unsigned memIndex = read(); 1642 Operation *op = read<Operation *>(); 1643 StringAttr attrName = read<StringAttr>(); 1644 Attribute attr = op->getAttr(attrName); 1645 1646 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1647 << " * Attribute: " << attrName << "\n" 1648 << " * Result: " << attr << "\n"); 1649 memory[memIndex] = attr.getAsOpaquePointer(); 1650 } 1651 1652 void ByteCodeExecutor::executeGetAttributeType() { 1653 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n"); 1654 unsigned memIndex = read(); 1655 Attribute attr = read<Attribute>(); 1656 Type type = attr ? attr.getType() : Type(); 1657 1658 LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" 1659 << " * Result: " << type << "\n"); 1660 memory[memIndex] = type.getAsOpaquePointer(); 1661 } 1662 1663 void ByteCodeExecutor::executeGetDefiningOp() { 1664 LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); 1665 unsigned memIndex = read(); 1666 Operation *op = nullptr; 1667 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1668 Value value = read<Value>(); 1669 if (value) 1670 op = value.getDefiningOp(); 1671 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1672 } else { 1673 ValueRange *values = read<ValueRange *>(); 1674 if (values && !values->empty()) { 1675 op = values->front().getDefiningOp(); 1676 } 1677 LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n"); 1678 } 1679 1680 LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n"); 1681 memory[memIndex] = op; 1682 } 1683 1684 void ByteCodeExecutor::executeGetOperand(unsigned index) { 1685 Operation *op = read<Operation *>(); 1686 unsigned memIndex = read(); 1687 Value operand = 1688 index < op->getNumOperands() ? 
op->getOperand(index) : Value(); 1689 1690 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1691 << " * Index: " << index << "\n" 1692 << " * Result: " << operand << "\n"); 1693 memory[memIndex] = operand.getAsOpaquePointer(); 1694 } 1695 1696 /// This function is the internal implementation of `GetResults` and 1697 /// `GetOperands` that provides support for extracting a value range from the 1698 /// given operation. 1699 template <template <typename> class AttrSizedSegmentsT, typename RangeT> 1700 static void * 1701 executeGetOperandsResults(RangeT values, Operation *op, unsigned index, 1702 ByteCodeField rangeIndex, StringRef attrSizedSegments, 1703 MutableArrayRef<ValueRange> valueRangeMemory) { 1704 // Check for the sentinel index that signals that all values should be 1705 // returned. 1706 if (index == std::numeric_limits<uint32_t>::max()) { 1707 LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n"); 1708 // `values` is already the full value range. 1709 1710 // Otherwise, check to see if this operation uses AttrSizedSegments. 1711 } else if (op->hasTrait<AttrSizedSegmentsT>()) { 1712 LLVM_DEBUG(llvm::dbgs() 1713 << " * Extracting values from `" << attrSizedSegments << "`\n"); 1714 1715 auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments); 1716 if (!segmentAttr || segmentAttr.getNumElements() <= index) 1717 return nullptr; 1718 1719 auto segments = segmentAttr.getValues<int32_t>(); 1720 unsigned startIndex = 1721 std::accumulate(segments.begin(), segments.begin() + index, 0); 1722 values = values.slice(startIndex, *std::next(segments.begin(), index)); 1723 1724 LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", " 1725 << *std::next(segments.begin(), index) << "]\n"); 1726 1727 // Otherwise, assume this is the last operand group of the operation. 
1728 // FIXME: We currently don't support operations with 1729 // SameVariadicOperandSize/SameVariadicResultSize here given that we don't 1730 // have a way to detect it's presence. 1731 } else if (values.size() >= index) { 1732 LLVM_DEBUG(llvm::dbgs() 1733 << " * Treating values as trailing variadic range\n"); 1734 values = values.drop_front(index); 1735 1736 // If we couldn't detect a way to compute the values, bail out. 1737 } else { 1738 return nullptr; 1739 } 1740 1741 // If the range index is valid, we are returning a range. 1742 if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) { 1743 valueRangeMemory[rangeIndex] = values; 1744 return &valueRangeMemory[rangeIndex]; 1745 } 1746 1747 // If a range index wasn't provided, the range is required to be non-variadic. 1748 return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer(); 1749 } 1750 1751 void ByteCodeExecutor::executeGetOperands() { 1752 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n"); 1753 unsigned index = read<uint32_t>(); 1754 Operation *op = read<Operation *>(); 1755 ByteCodeField rangeIndex = read(); 1756 1757 void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>( 1758 op->getOperands(), op, index, rangeIndex, "operand_segment_sizes", 1759 valueRangeMemory); 1760 if (!result) 1761 LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n"); 1762 memory[read()] = result; 1763 } 1764 1765 void ByteCodeExecutor::executeGetResult(unsigned index) { 1766 Operation *op = read<Operation *>(); 1767 unsigned memIndex = read(); 1768 OpResult result = 1769 index < op->getNumResults() ? 
op->getResult(index) : OpResult(); 1770 1771 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1772 << " * Index: " << index << "\n" 1773 << " * Result: " << result << "\n"); 1774 memory[memIndex] = result.getAsOpaquePointer(); 1775 } 1776 1777 void ByteCodeExecutor::executeGetResults() { 1778 LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n"); 1779 unsigned index = read<uint32_t>(); 1780 Operation *op = read<Operation *>(); 1781 ByteCodeField rangeIndex = read(); 1782 1783 void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>( 1784 op->getResults(), op, index, rangeIndex, "result_segment_sizes", 1785 valueRangeMemory); 1786 if (!result) 1787 LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n"); 1788 memory[read()] = result; 1789 } 1790 1791 void ByteCodeExecutor::executeGetUsers() { 1792 LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n"); 1793 unsigned memIndex = read(); 1794 unsigned rangeIndex = read(); 1795 OwningOpRange &range = opRangeMemory[rangeIndex]; 1796 memory[memIndex] = ⦥ 1797 1798 range = OwningOpRange(); 1799 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1800 // Read the value. 1801 Value value = read<Value>(); 1802 if (!value) 1803 return; 1804 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1805 1806 // Extract the users of a single value. 1807 range = OwningOpRange(std::distance(value.user_begin(), value.user_end())); 1808 llvm::copy(value.getUsers(), range.begin()); 1809 } else { 1810 // Read a range of values. 1811 ValueRange *values = read<ValueRange *>(); 1812 if (!values) 1813 return; 1814 LLVM_DEBUG({ 1815 llvm::dbgs() << " * Values (" << values->size() << "): "; 1816 llvm::interleaveComma(*values, llvm::dbgs()); 1817 llvm::dbgs() << "\n"; 1818 }); 1819 1820 // Extract all the users of a range of values. 
1821 SmallVector<Operation *> users; 1822 for (Value value : *values) 1823 users.append(value.user_begin(), value.user_end()); 1824 range = OwningOpRange(users.size()); 1825 llvm::copy(users, range.begin()); 1826 } 1827 1828 LLVM_DEBUG(llvm::dbgs() << " * Result: " << range.size() << " operations\n"); 1829 } 1830 1831 void ByteCodeExecutor::executeGetValueType() { 1832 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); 1833 unsigned memIndex = read(); 1834 Value value = read<Value>(); 1835 Type type = value ? value.getType() : Type(); 1836 1837 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" 1838 << " * Result: " << type << "\n"); 1839 memory[memIndex] = type.getAsOpaquePointer(); 1840 } 1841 1842 void ByteCodeExecutor::executeGetValueRangeTypes() { 1843 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n"); 1844 unsigned memIndex = read(); 1845 unsigned rangeIndex = read(); 1846 ValueRange *values = read<ValueRange *>(); 1847 if (!values) { 1848 LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n"); 1849 memory[memIndex] = nullptr; 1850 return; 1851 } 1852 1853 LLVM_DEBUG({ 1854 llvm::dbgs() << " * Values (" << values->size() << "): "; 1855 llvm::interleaveComma(*values, llvm::dbgs()); 1856 llvm::dbgs() << "\n * Result: "; 1857 llvm::interleaveComma(values->getType(), llvm::dbgs()); 1858 llvm::dbgs() << "\n"; 1859 }); 1860 typeRangeMemory[rangeIndex] = values->getType(); 1861 memory[memIndex] = &typeRangeMemory[rangeIndex]; 1862 } 1863 1864 void ByteCodeExecutor::executeIsNotNull() { 1865 LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); 1866 const void *value = read<const void *>(); 1867 1868 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1869 selectJump(value != nullptr); 1870 } 1871 1872 void ByteCodeExecutor::executeRecordMatch( 1873 PatternRewriter &rewriter, 1874 SmallVectorImpl<PDLByteCode::MatchResult> &matches) { 1875 LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n"); 1876 unsigned patternIndex = read(); 1877 
PatternBenefit benefit = currentPatternBenefits[patternIndex]; 1878 const ByteCodeField *dest = &code[read<ByteCodeAddr>()]; 1879 1880 // If the benefit of the pattern is impossible, skip the processing of the 1881 // rest of the pattern. 1882 if (benefit.isImpossibleToMatch()) { 1883 LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n"); 1884 curCodeIt = dest; 1885 return; 1886 } 1887 1888 // Create a fused location containing the locations of each of the 1889 // operations used in the match. This will be used as the location for 1890 // created operations during the rewrite that don't already have an 1891 // explicit location set. 1892 unsigned numMatchLocs = read(); 1893 SmallVector<Location, 4> matchLocs; 1894 matchLocs.reserve(numMatchLocs); 1895 for (unsigned i = 0; i != numMatchLocs; ++i) 1896 matchLocs.push_back(read<Operation *>()->getLoc()); 1897 Location matchLoc = rewriter.getFusedLoc(matchLocs); 1898 1899 LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n" 1900 << " * Location: " << matchLoc << "\n"); 1901 matches.emplace_back(matchLoc, patterns[patternIndex], benefit); 1902 PDLByteCode::MatchResult &match = matches.back(); 1903 1904 // Record all of the inputs to the match. If any of the inputs are ranges, we 1905 // will also need to remap the range pointer to memory stored in the match 1906 // state. 
1907 unsigned numInputs = read(); 1908 match.values.reserve(numInputs); 1909 match.typeRangeValues.reserve(numInputs); 1910 match.valueRangeValues.reserve(numInputs); 1911 for (unsigned i = 0; i < numInputs; ++i) { 1912 switch (read<PDLValue::Kind>()) { 1913 case PDLValue::Kind::TypeRange: 1914 match.typeRangeValues.push_back(*read<TypeRange *>()); 1915 match.values.push_back(&match.typeRangeValues.back()); 1916 break; 1917 case PDLValue::Kind::ValueRange: 1918 match.valueRangeValues.push_back(*read<ValueRange *>()); 1919 match.values.push_back(&match.valueRangeValues.back()); 1920 break; 1921 default: 1922 match.values.push_back(read<const void *>()); 1923 break; 1924 } 1925 } 1926 curCodeIt = dest; 1927 } 1928 1929 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) { 1930 LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); 1931 Operation *op = read<Operation *>(); 1932 SmallVector<Value, 16> args; 1933 readValueList(args); 1934 1935 LLVM_DEBUG({ 1936 llvm::dbgs() << " * Operation: " << *op << "\n" 1937 << " * Values: "; 1938 llvm::interleaveComma(args, llvm::dbgs()); 1939 llvm::dbgs() << "\n"; 1940 }); 1941 rewriter.replaceOp(op, args); 1942 } 1943 1944 void ByteCodeExecutor::executeSwitchAttribute() { 1945 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n"); 1946 Attribute value = read<Attribute>(); 1947 ArrayAttr cases = read<ArrayAttr>(); 1948 handleSwitch(value, cases); 1949 } 1950 1951 void ByteCodeExecutor::executeSwitchOperandCount() { 1952 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n"); 1953 Operation *op = read<Operation *>(); 1954 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1955 1956 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1957 handleSwitch(op->getNumOperands(), cases); 1958 } 1959 1960 void ByteCodeExecutor::executeSwitchOperationName() { 1961 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n"); 1962 OperationName value = read<Operation *>()->getName(); 
1963 size_t caseCount = read(); 1964 1965 // The operation names are stored in-line, so to print them out for 1966 // debugging purposes we need to read the array before executing the 1967 // switch so that we can display all of the possible values. 1968 LLVM_DEBUG({ 1969 const ByteCodeField *prevCodeIt = curCodeIt; 1970 llvm::dbgs() << " * Value: " << value << "\n" 1971 << " * Cases: "; 1972 llvm::interleaveComma( 1973 llvm::map_range(llvm::seq<size_t>(0, caseCount), 1974 [&](size_t) { return read<OperationName>(); }), 1975 llvm::dbgs()); 1976 llvm::dbgs() << "\n"; 1977 curCodeIt = prevCodeIt; 1978 }); 1979 1980 // Try to find the switch value within any of the cases. 1981 for (size_t i = 0; i != caseCount; ++i) { 1982 if (read<OperationName>() == value) { 1983 curCodeIt += (caseCount - i - 1); 1984 return selectJump(i + 1); 1985 } 1986 } 1987 selectJump(size_t(0)); 1988 } 1989 1990 void ByteCodeExecutor::executeSwitchResultCount() { 1991 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n"); 1992 Operation *op = read<Operation *>(); 1993 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1994 1995 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1996 handleSwitch(op->getNumResults(), cases); 1997 } 1998 1999 void ByteCodeExecutor::executeSwitchType() { 2000 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n"); 2001 Type value = read<Type>(); 2002 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>(); 2003 handleSwitch(value, cases); 2004 } 2005 2006 void ByteCodeExecutor::executeSwitchTypes() { 2007 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n"); 2008 TypeRange *value = read<TypeRange *>(); 2009 auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>(); 2010 if (!value) { 2011 LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n"); 2012 return selectJump(size_t(0)); 2013 } 2014 handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) { 2015 return value == caseValue.getAsValueRange<TypeAttr>(); 2016 }); 
2017 } 2018 2019 void ByteCodeExecutor::execute( 2020 PatternRewriter &rewriter, 2021 SmallVectorImpl<PDLByteCode::MatchResult> *matches, 2022 Optional<Location> mainRewriteLoc) { 2023 while (true) { 2024 // Print the location of the operation being executed. 2025 LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n"); 2026 2027 OpCode opCode = static_cast<OpCode>(read()); 2028 switch (opCode) { 2029 case ApplyConstraint: 2030 executeApplyConstraint(rewriter); 2031 break; 2032 case ApplyRewrite: 2033 executeApplyRewrite(rewriter); 2034 break; 2035 case AreEqual: 2036 executeAreEqual(); 2037 break; 2038 case AreRangesEqual: 2039 executeAreRangesEqual(); 2040 break; 2041 case Branch: 2042 executeBranch(); 2043 break; 2044 case CheckOperandCount: 2045 executeCheckOperandCount(); 2046 break; 2047 case CheckOperationName: 2048 executeCheckOperationName(); 2049 break; 2050 case CheckResultCount: 2051 executeCheckResultCount(); 2052 break; 2053 case CheckTypes: 2054 executeCheckTypes(); 2055 break; 2056 case Continue: 2057 executeContinue(); 2058 break; 2059 case CreateOperation: 2060 executeCreateOperation(rewriter, *mainRewriteLoc); 2061 break; 2062 case CreateTypes: 2063 executeCreateTypes(); 2064 break; 2065 case EraseOp: 2066 executeEraseOp(rewriter); 2067 break; 2068 case ExtractOp: 2069 executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>(); 2070 break; 2071 case ExtractType: 2072 executeExtract<Type, TypeRange, PDLValue::Kind::Type>(); 2073 break; 2074 case ExtractValue: 2075 executeExtract<Value, ValueRange, PDLValue::Kind::Value>(); 2076 break; 2077 case Finalize: 2078 executeFinalize(); 2079 LLVM_DEBUG(llvm::dbgs() << "\n"); 2080 return; 2081 case ForEach: 2082 executeForEach(); 2083 break; 2084 case GetAttribute: 2085 executeGetAttribute(); 2086 break; 2087 case GetAttributeType: 2088 executeGetAttributeType(); 2089 break; 2090 case GetDefiningOp: 2091 executeGetDefiningOp(); 2092 break; 2093 case GetOperand0: 2094 case GetOperand1: 2095 
case GetOperand2: 2096 case GetOperand3: { 2097 unsigned index = opCode - GetOperand0; 2098 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n"); 2099 executeGetOperand(index); 2100 break; 2101 } 2102 case GetOperandN: 2103 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n"); 2104 executeGetOperand(read<uint32_t>()); 2105 break; 2106 case GetOperands: 2107 executeGetOperands(); 2108 break; 2109 case GetResult0: 2110 case GetResult1: 2111 case GetResult2: 2112 case GetResult3: { 2113 unsigned index = opCode - GetResult0; 2114 LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n"); 2115 executeGetResult(index); 2116 break; 2117 } 2118 case GetResultN: 2119 LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n"); 2120 executeGetResult(read<uint32_t>()); 2121 break; 2122 case GetResults: 2123 executeGetResults(); 2124 break; 2125 case GetUsers: 2126 executeGetUsers(); 2127 break; 2128 case GetValueType: 2129 executeGetValueType(); 2130 break; 2131 case GetValueRangeTypes: 2132 executeGetValueRangeTypes(); 2133 break; 2134 case IsNotNull: 2135 executeIsNotNull(); 2136 break; 2137 case RecordMatch: 2138 assert(matches && 2139 "expected matches to be provided when executing the matcher"); 2140 executeRecordMatch(rewriter, *matches); 2141 break; 2142 case ReplaceOp: 2143 executeReplaceOp(rewriter); 2144 break; 2145 case SwitchAttribute: 2146 executeSwitchAttribute(); 2147 break; 2148 case SwitchOperandCount: 2149 executeSwitchOperandCount(); 2150 break; 2151 case SwitchOperationName: 2152 executeSwitchOperationName(); 2153 break; 2154 case SwitchResultCount: 2155 executeSwitchResultCount(); 2156 break; 2157 case SwitchType: 2158 executeSwitchType(); 2159 break; 2160 case SwitchTypes: 2161 executeSwitchTypes(); 2162 break; 2163 } 2164 LLVM_DEBUG(llvm::dbgs() << "\n"); 2165 } 2166 } 2167 2168 /// Run the pattern matcher on the given root operation, collecting the matched 2169 /// patterns in `matches`. 
2170 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, 2171 SmallVectorImpl<MatchResult> &matches, 2172 PDLByteCodeMutableState &state) const { 2173 // The first memory slot is always the root operation. 2174 state.memory[0] = op; 2175 2176 // The matcher function always starts at code address 0. 2177 ByteCodeExecutor executor( 2178 matcherByteCode.data(), state.memory, state.opRangeMemory, 2179 state.typeRangeMemory, state.allocatedTypeRangeMemory, 2180 state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex, 2181 uniquedData, matcherByteCode, state.currentPatternBenefits, patterns, 2182 constraintFunctions, rewriteFunctions); 2183 executor.execute(rewriter, &matches); 2184 2185 // Order the found matches by benefit. 2186 std::stable_sort(matches.begin(), matches.end(), 2187 [](const MatchResult &lhs, const MatchResult &rhs) { 2188 return lhs.benefit > rhs.benefit; 2189 }); 2190 } 2191 2192 /// Run the rewriter of the given pattern on the root operation `op`. 2193 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, 2194 PDLByteCodeMutableState &state) const { 2195 // The arguments of the rewrite function are stored at the start of the 2196 // memory buffer. 2197 llvm::copy(match.values, state.memory.begin()); 2198 2199 ByteCodeExecutor executor( 2200 &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, 2201 state.opRangeMemory, state.typeRangeMemory, 2202 state.allocatedTypeRangeMemory, state.valueRangeMemory, 2203 state.allocatedValueRangeMemory, state.loopIndex, uniquedData, 2204 rewriterByteCode, state.currentPatternBenefits, patterns, 2205 constraintFunctions, rewriteFunctions); 2206 executor.execute(rewriter, /*matches=*/nullptr, match.location); 2207 } 2208