1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements MLIR to byte-code generation and the interpreter. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "ByteCode.h" 14 #include "mlir/Analysis/Liveness.h" 15 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/RegionGraphTraits.h" 19 #include "llvm/ADT/IntervalMap.h" 20 #include "llvm/ADT/PostOrderIterator.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 #include "llvm/Support/Debug.h" 23 #include "llvm/Support/Format.h" 24 #include "llvm/Support/FormatVariadic.h" 25 #include <numeric> 26 27 #define DEBUG_TYPE "pdl-bytecode" 28 29 using namespace mlir; 30 using namespace mlir::detail; 31 32 //===----------------------------------------------------------------------===// 33 // PDLByteCodePattern 34 //===----------------------------------------------------------------------===// 35 36 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp, 37 ByteCodeAddr rewriterAddr) { 38 SmallVector<StringRef, 8> generatedOps; 39 if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr()) 40 generatedOps = 41 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>()); 42 43 PatternBenefit benefit = matchOp.benefit(); 44 MLIRContext *ctx = matchOp.getContext(); 45 46 // Check to see if this is pattern matches a specific operation type. 
47 if (Optional<StringRef> rootKind = matchOp.rootKind()) 48 return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx, 49 generatedOps); 50 return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx, 51 generatedOps); 52 } 53 54 //===----------------------------------------------------------------------===// 55 // PDLByteCodeMutableState 56 //===----------------------------------------------------------------------===// 57 58 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds 59 /// to the position of the pattern within the range returned by 60 /// `PDLByteCode::getPatterns`. 61 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex, 62 PatternBenefit benefit) { 63 currentPatternBenefits[patternIndex] = benefit; 64 } 65 66 /// Cleanup any allocated state after a full match/rewrite has been completed. 67 /// This method should be called irregardless of whether the match+rewrite was a 68 /// success or not. 69 void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() { 70 allocatedTypeRangeMemory.clear(); 71 allocatedValueRangeMemory.clear(); 72 } 73 74 //===----------------------------------------------------------------------===// 75 // Bytecode OpCodes 76 //===----------------------------------------------------------------------===// 77 78 namespace { 79 enum OpCode : ByteCodeField { 80 /// Apply an externally registered constraint. 81 ApplyConstraint, 82 /// Apply an externally registered rewrite. 83 ApplyRewrite, 84 /// Check if two generic values are equal. 85 AreEqual, 86 /// Check if two ranges are equal. 87 AreRangesEqual, 88 /// Unconditional branch. 89 Branch, 90 /// Compare the operand count of an operation with a constant. 91 CheckOperandCount, 92 /// Compare the name of an operation with a constant. 93 CheckOperationName, 94 /// Compare the result count of an operation with a constant. 95 CheckResultCount, 96 /// Compare a range of types to a constant range of types. 
97 CheckTypes, 98 /// Continue to the next iteration of a loop. 99 Continue, 100 /// Create an operation. 101 CreateOperation, 102 /// Create a range of types. 103 CreateTypes, 104 /// Erase an operation. 105 EraseOp, 106 /// Extract the op from a range at the specified index. 107 ExtractOp, 108 /// Extract the type from a range at the specified index. 109 ExtractType, 110 /// Extract the value from a range at the specified index. 111 ExtractValue, 112 /// Terminate a matcher or rewrite sequence. 113 Finalize, 114 /// Iterate over a range of values. 115 ForEach, 116 /// Get a specific attribute of an operation. 117 GetAttribute, 118 /// Get the type of an attribute. 119 GetAttributeType, 120 /// Get the defining operation of a value. 121 GetDefiningOp, 122 /// Get a specific operand of an operation. 123 GetOperand0, 124 GetOperand1, 125 GetOperand2, 126 GetOperand3, 127 GetOperandN, 128 /// Get a specific operand group of an operation. 129 GetOperands, 130 /// Get a specific result of an operation. 131 GetResult0, 132 GetResult1, 133 GetResult2, 134 GetResult3, 135 GetResultN, 136 /// Get a specific result group of an operation. 137 GetResults, 138 /// Get the users of a value or a range of values. 139 GetUsers, 140 /// Get the type of a value. 141 GetValueType, 142 /// Get the types of a value range. 143 GetValueRangeTypes, 144 /// Check if a generic value is not null. 145 IsNotNull, 146 /// Record a successful pattern match. 147 RecordMatch, 148 /// Replace an operation. 149 ReplaceOp, 150 /// Compare an attribute with a set of constants. 151 SwitchAttribute, 152 /// Compare the operand count of an operation with a set of constants. 153 SwitchOperandCount, 154 /// Compare the name of an operation with a set of constants. 155 SwitchOperationName, 156 /// Compare the result count of an operation with a set of constants. 157 SwitchResultCount, 158 /// Compare a type with a set of constants. 159 SwitchType, 160 /// Compare a range of types with a set of constants. 
161 SwitchTypes, 162 }; 163 } // namespace 164 165 //===----------------------------------------------------------------------===// 166 // ByteCode Generation 167 //===----------------------------------------------------------------------===// 168 169 //===----------------------------------------------------------------------===// 170 // Generator 171 172 namespace { 173 struct ByteCodeLiveRange; 174 struct ByteCodeWriter; 175 176 /// Check if the given class `T` can be converted to an opaque pointer. 177 template <typename T, typename... Args> 178 using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer()); 179 180 /// This class represents the main generator for the pattern bytecode. 181 class Generator { 182 public: 183 Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData, 184 SmallVectorImpl<ByteCodeField> &matcherByteCode, 185 SmallVectorImpl<ByteCodeField> &rewriterByteCode, 186 SmallVectorImpl<PDLByteCodePattern> &patterns, 187 ByteCodeField &maxValueMemoryIndex, 188 ByteCodeField &maxOpRangeMemoryIndex, 189 ByteCodeField &maxTypeRangeMemoryIndex, 190 ByteCodeField &maxValueRangeMemoryIndex, 191 ByteCodeField &maxLoopLevel, 192 llvm::StringMap<PDLConstraintFunction> &constraintFns, 193 llvm::StringMap<PDLRewriteFunction> &rewriteFns) 194 : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), 195 rewriterByteCode(rewriterByteCode), patterns(patterns), 196 maxValueMemoryIndex(maxValueMemoryIndex), 197 maxOpRangeMemoryIndex(maxOpRangeMemoryIndex), 198 maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex), 199 maxValueRangeMemoryIndex(maxValueRangeMemoryIndex), 200 maxLoopLevel(maxLoopLevel) { 201 for (const auto &it : llvm::enumerate(constraintFns)) 202 constraintToMemIndex.try_emplace(it.value().first(), it.index()); 203 for (const auto &it : llvm::enumerate(rewriteFns)) 204 externalRewriterToMemIndex.try_emplace(it.value().first(), it.index()); 205 } 206 207 /// Generate the bytecode for the given PDL interpreter 
module. 208 void generate(ModuleOp module); 209 210 /// Return the memory index to use for the given value. 211 ByteCodeField &getMemIndex(Value value) { 212 assert(valueToMemIndex.count(value) && 213 "expected memory index to be assigned"); 214 return valueToMemIndex[value]; 215 } 216 217 /// Return the range memory index used to store the given range value. 218 ByteCodeField &getRangeStorageIndex(Value value) { 219 assert(valueToRangeIndex.count(value) && 220 "expected range index to be assigned"); 221 return valueToRangeIndex[value]; 222 } 223 224 /// Return an index to use when referring to the given data that is uniqued in 225 /// the MLIR context. 226 template <typename T> 227 std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &> 228 getMemIndex(T val) { 229 const void *opaqueVal = val.getAsOpaquePointer(); 230 231 // Get or insert a reference to this value. 232 auto it = uniquedDataToMemIndex.try_emplace( 233 opaqueVal, maxValueMemoryIndex + uniquedData.size()); 234 if (it.second) 235 uniquedData.push_back(opaqueVal); 236 return it.first->second; 237 } 238 239 private: 240 /// Allocate memory indices for the results of operations within the matcher 241 /// and rewriters. 242 void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule); 243 244 /// Generate the bytecode for the given operation. 
245 void generate(Region *region, ByteCodeWriter &writer); 246 void generate(Operation *op, ByteCodeWriter &writer); 247 void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer); 248 void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer); 249 void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer); 250 void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer); 251 void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer); 252 void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer); 253 void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer); 254 void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); 255 void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); 256 void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer); 257 void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer); 258 void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); 259 void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); 260 void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); 261 void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer); 262 void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); 263 void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer); 264 void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); 265 void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer); 266 void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); 267 void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); 268 void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); 269 void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer); 270 void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer); 271 void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); 272 void generate(pdl_interp::GetResultsOp op, 
ByteCodeWriter &writer); 273 void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer); 274 void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); 275 void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer); 276 void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); 277 void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer); 278 void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); 279 void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer); 280 void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer); 281 void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer); 282 void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer); 283 void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer); 284 void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer); 285 286 /// Mapping from value to its corresponding memory index. 287 DenseMap<Value, ByteCodeField> valueToMemIndex; 288 289 /// Mapping from a range value to its corresponding range storage index. 290 DenseMap<Value, ByteCodeField> valueToRangeIndex; 291 292 /// Mapping from the name of an externally registered rewrite to its index in 293 /// the bytecode registry. 294 llvm::StringMap<ByteCodeField> externalRewriterToMemIndex; 295 296 /// Mapping from the name of an externally registered constraint to its index 297 /// in the bytecode registry. 298 llvm::StringMap<ByteCodeField> constraintToMemIndex; 299 300 /// Mapping from rewriter function name to the bytecode address of the 301 /// rewriter function in byte. 302 llvm::StringMap<ByteCodeAddr> rewriterToAddr; 303 304 /// Mapping from a uniqued storage object to its memory index within 305 /// `uniquedData`. 306 DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex; 307 308 /// The current level of the foreach loop. 309 ByteCodeField curLoopLevel = 0; 310 311 /// The current MLIR context. 
312 MLIRContext *ctx; 313 314 /// Mapping from block to its address. 315 DenseMap<Block *, ByteCodeAddr> blockToAddr; 316 317 /// Data of the ByteCode class to be populated. 318 std::vector<const void *> &uniquedData; 319 SmallVectorImpl<ByteCodeField> &matcherByteCode; 320 SmallVectorImpl<ByteCodeField> &rewriterByteCode; 321 SmallVectorImpl<PDLByteCodePattern> &patterns; 322 ByteCodeField &maxValueMemoryIndex; 323 ByteCodeField &maxOpRangeMemoryIndex; 324 ByteCodeField &maxTypeRangeMemoryIndex; 325 ByteCodeField &maxValueRangeMemoryIndex; 326 ByteCodeField &maxLoopLevel; 327 }; 328 329 /// This class provides utilities for writing a bytecode stream. 330 struct ByteCodeWriter { 331 ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator) 332 : bytecode(bytecode), generator(generator) {} 333 334 /// Append a field to the bytecode. 335 void append(ByteCodeField field) { bytecode.push_back(field); } 336 void append(OpCode opCode) { bytecode.push_back(opCode); } 337 338 /// Append an address to the bytecode. 339 void append(ByteCodeAddr field) { 340 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 341 "unexpected ByteCode address size"); 342 343 ByteCodeField fieldParts[2]; 344 std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr)); 345 bytecode.append({fieldParts[0], fieldParts[1]}); 346 } 347 348 /// Append a single successor to the bytecode, the exact address will need to 349 /// be resolved later. 350 void append(Block *successor) { 351 // Add back a reference to the successor so that the address can be resolved 352 // later. 353 unresolvedSuccessorRefs[successor].push_back(bytecode.size()); 354 append(ByteCodeAddr(0)); 355 } 356 357 /// Append a successor range to the bytecode, the exact address will need to 358 /// be resolved later. 359 void append(SuccessorRange successors) { 360 for (Block *successor : successors) 361 append(successor); 362 } 363 364 /// Append a range of values that will be read as generic PDLValues. 
365 void appendPDLValueList(OperandRange values) { 366 bytecode.push_back(values.size()); 367 for (Value value : values) 368 appendPDLValue(value); 369 } 370 371 /// Append a value as a PDLValue. 372 void appendPDLValue(Value value) { 373 appendPDLValueKind(value); 374 append(value); 375 } 376 377 /// Append the PDLValue::Kind of the given value. 378 void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); } 379 380 /// Append the PDLValue::Kind of the given type. 381 void appendPDLValueKind(Type type) { 382 PDLValue::Kind kind = 383 TypeSwitch<Type, PDLValue::Kind>(type) 384 .Case<pdl::AttributeType>( 385 [](Type) { return PDLValue::Kind::Attribute; }) 386 .Case<pdl::OperationType>( 387 [](Type) { return PDLValue::Kind::Operation; }) 388 .Case<pdl::RangeType>([](pdl::RangeType rangeTy) { 389 if (rangeTy.getElementType().isa<pdl::TypeType>()) 390 return PDLValue::Kind::TypeRange; 391 return PDLValue::Kind::ValueRange; 392 }) 393 .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; }) 394 .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; }); 395 bytecode.push_back(static_cast<ByteCodeField>(kind)); 396 } 397 398 /// Append a value that will be stored in a memory slot and not inline within 399 /// the bytecode. 400 template <typename T> 401 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value || 402 std::is_pointer<T>::value> 403 append(T value) { 404 bytecode.push_back(generator.getMemIndex(value)); 405 } 406 407 /// Append a range of values. 408 template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>> 409 std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value> 410 append(T range) { 411 bytecode.push_back(llvm::size(range)); 412 for (auto it : range) 413 append(it); 414 } 415 416 /// Append a variadic number of fields to the bytecode. 417 template <typename FieldTy, typename Field2Ty, typename... FieldTys> 418 void append(FieldTy field, Field2Ty field2, FieldTys... 
fields) { 419 append(field); 420 append(field2, fields...); 421 } 422 423 /// Appends a value as a pointer, stored inline within the bytecode. 424 template <typename T> 425 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value> 426 appendInline(T value) { 427 constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField); 428 const void *pointer = value.getAsOpaquePointer(); 429 ByteCodeField fieldParts[numParts]; 430 std::memcpy(fieldParts, &pointer, sizeof(const void *)); 431 bytecode.append(fieldParts, fieldParts + numParts); 432 } 433 434 /// Successor references in the bytecode that have yet to be resolved. 435 DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs; 436 437 /// The underlying bytecode buffer. 438 SmallVectorImpl<ByteCodeField> &bytecode; 439 440 /// The main generator producing PDL. 441 Generator &generator; 442 }; 443 444 /// This class represents a live range of PDL Interpreter values, containing 445 /// information about when values are live within a match/rewrite. 446 struct ByteCodeLiveRange { 447 using Set = llvm::IntervalMap<uint64_t, char, 16>; 448 using Allocator = Set::Allocator; 449 450 ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {} 451 452 /// Union this live range with the one provided. 453 void unionWith(const ByteCodeLiveRange &rhs) { 454 for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e; 455 ++it) 456 liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0); 457 } 458 459 /// Returns true if this range overlaps with the one provided. 460 bool overlaps(const ByteCodeLiveRange &rhs) const { 461 return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness) 462 .valid(); 463 } 464 465 /// A map representing the ranges of the match/rewrite that a value is live in 466 /// the interpreter. 467 /// 468 /// We use std::unique_ptr here, because IntervalMap does not provide a 469 /// correct copy or move constructor. 
We can eliminate the pointer once 470 /// https://reviews.llvm.org/D113240 lands. 471 std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness; 472 473 /// The operation range storage index for this range. 474 Optional<unsigned> opRangeIndex; 475 476 /// The type range storage index for this range. 477 Optional<unsigned> typeRangeIndex; 478 479 /// The value range storage index for this range. 480 Optional<unsigned> valueRangeIndex; 481 }; 482 } // namespace 483 484 void Generator::generate(ModuleOp module) { 485 FuncOp matcherFunc = module.lookupSymbol<FuncOp>( 486 pdl_interp::PDLInterpDialect::getMatcherFunctionName()); 487 ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>( 488 pdl_interp::PDLInterpDialect::getRewriterModuleName()); 489 assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module"); 490 491 // Allocate memory indices for the results of operations within the matcher 492 // and rewriters. 493 allocateMemoryIndices(matcherFunc, rewriterModule); 494 495 // Generate code for the rewriter functions. 496 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this); 497 for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) { 498 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size()); 499 for (Operation &op : rewriterFunc.getOps()) 500 generate(&op, rewriterByteCodeWriter); 501 } 502 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() && 503 "unexpected branches in rewriter function"); 504 505 // Generate code for the matcher function. 506 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this); 507 generate(&matcherFunc.getBody(), matcherByteCodeWriter); 508 509 // Resolve successor references in the matcher. 
510 for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) { 511 ByteCodeAddr addr = blockToAddr[it.first]; 512 for (unsigned offsetToFix : it.second) 513 std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr)); 514 } 515 } 516 517 void Generator::allocateMemoryIndices(FuncOp matcherFunc, 518 ModuleOp rewriterModule) { 519 // Rewriters use simplistic allocation scheme that simply assigns an index to 520 // each result. 521 for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) { 522 ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0; 523 auto processRewriterValue = [&](Value val) { 524 valueToMemIndex.try_emplace(val, index++); 525 if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) { 526 Type elementTy = rangeType.getElementType(); 527 if (elementTy.isa<pdl::TypeType>()) 528 valueToRangeIndex.try_emplace(val, typeRangeIndex++); 529 else if (elementTy.isa<pdl::ValueType>()) 530 valueToRangeIndex.try_emplace(val, valueRangeIndex++); 531 } 532 }; 533 534 for (BlockArgument arg : rewriterFunc.getArguments()) 535 processRewriterValue(arg); 536 rewriterFunc.getBody().walk([&](Operation *op) { 537 for (Value result : op->getResults()) 538 processRewriterValue(result); 539 }); 540 if (index > maxValueMemoryIndex) 541 maxValueMemoryIndex = index; 542 if (typeRangeIndex > maxTypeRangeMemoryIndex) 543 maxTypeRangeMemoryIndex = typeRangeIndex; 544 if (valueRangeIndex > maxValueRangeMemoryIndex) 545 maxValueRangeMemoryIndex = valueRangeIndex; 546 } 547 548 // The matcher function uses a more sophisticated numbering that tries to 549 // minimize the number of memory indices assigned. This is done by determining 550 // a live range of the values within the matcher, then the allocation is just 551 // finding the minimal number of overlapping live ranges. 
This is essentially 552 // a simplified form of register allocation where we don't necessarily have a 553 // limited number of registers, but we still want to minimize the number used. 554 DenseMap<Operation *, unsigned> opToFirstIndex; 555 DenseMap<Operation *, unsigned> opToLastIndex; 556 557 // A custom walk that marks the first and the last index of each operation. 558 // The entry marks the beginning of the liveness range for this operation, 559 // followed by nested operations, followed by the end of the liveness range. 560 unsigned index = 0; 561 llvm::unique_function<void(Operation *)> walk = [&](Operation *op) { 562 opToFirstIndex.try_emplace(op, index++); 563 for (Region ®ion : op->getRegions()) 564 for (Block &block : region.getBlocks()) 565 for (Operation &nested : block) 566 walk(&nested); 567 opToLastIndex.try_emplace(op, index++); 568 }; 569 walk(matcherFunc); 570 571 // Liveness info for each of the defs within the matcher. 572 ByteCodeLiveRange::Allocator allocator; 573 DenseMap<Value, ByteCodeLiveRange> valueDefRanges; 574 575 // Assign the root operation being matched to slot 0. 576 BlockArgument rootOpArg = matcherFunc.getArgument(0); 577 valueToMemIndex[rootOpArg] = 0; 578 579 // Walk each of the blocks, computing the def interval that the value is used. 580 Liveness matcherLiveness(matcherFunc); 581 matcherFunc->walk([&](Block *block) { 582 const LivenessBlockInfo *info = matcherLiveness.getLiveness(block); 583 assert(info && "expected liveness info for block"); 584 auto processValue = [&](Value value, Operation *firstUseOrDef) { 585 // We don't need to process the root op argument, this value is always 586 // assigned to the first memory slot. 587 if (value == rootOpArg) 588 return; 589 590 // Set indices for the range of this block that the value is used. 
591 auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; 592 defRangeIt->second.liveness->insert( 593 opToFirstIndex[firstUseOrDef], 594 opToLastIndex[info->getEndOperation(value, firstUseOrDef)], 595 /*dummyValue*/ 0); 596 597 // Check to see if this value is a range type. 598 if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) { 599 Type eleType = rangeTy.getElementType(); 600 if (eleType.isa<pdl::OperationType>()) 601 defRangeIt->second.opRangeIndex = 0; 602 else if (eleType.isa<pdl::TypeType>()) 603 defRangeIt->second.typeRangeIndex = 0; 604 else if (eleType.isa<pdl::ValueType>()) 605 defRangeIt->second.valueRangeIndex = 0; 606 } 607 }; 608 609 // Process the live-ins of this block. 610 for (Value liveIn : info->in()) { 611 // Only process the value if it has been defined in the current region. 612 // Other values that span across pdl_interp.foreach will be added higher 613 // up. This ensures that the we keep them alive for the entire duration 614 // of the loop. 615 if (liveIn.getParentRegion() == block->getParent()) 616 processValue(liveIn, &block->front()); 617 } 618 619 // Process the block arguments for the entry block (those are not live-in). 620 if (block->isEntryBlock()) { 621 for (Value argument : block->getArguments()) 622 processValue(argument, &block->front()); 623 } 624 625 // Process any new defs within this block. 626 for (Operation &op : *block) 627 for (Value result : op.getResults()) 628 processValue(result, &op); 629 }); 630 631 // Greedily allocate memory slots using the computed def live ranges. 632 std::vector<ByteCodeLiveRange> allocatedIndices; 633 634 // The number of memory indices currently allocated (and its next value). 635 // Recall that the root gets allocated memory index 0. 636 ByteCodeField numIndices = 1; 637 638 // The number of memory ranges of various types (and their next values). 
639 ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0; 640 641 for (auto &defIt : valueDefRanges) { 642 ByteCodeField &memIndex = valueToMemIndex[defIt.first]; 643 ByteCodeLiveRange &defRange = defIt.second; 644 645 // Try to allocate to an existing index. 646 for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) { 647 ByteCodeLiveRange &existingRange = existingIndexIt.value(); 648 if (!defRange.overlaps(existingRange)) { 649 existingRange.unionWith(defRange); 650 memIndex = existingIndexIt.index() + 1; 651 652 if (defRange.opRangeIndex) { 653 if (!existingRange.opRangeIndex) 654 existingRange.opRangeIndex = numOpRanges++; 655 valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex; 656 } else if (defRange.typeRangeIndex) { 657 if (!existingRange.typeRangeIndex) 658 existingRange.typeRangeIndex = numTypeRanges++; 659 valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex; 660 } else if (defRange.valueRangeIndex) { 661 if (!existingRange.valueRangeIndex) 662 existingRange.valueRangeIndex = numValueRanges++; 663 valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex; 664 } 665 break; 666 } 667 } 668 669 // If no existing index could be used, add a new one. 670 if (memIndex == 0) { 671 allocatedIndices.emplace_back(allocator); 672 ByteCodeLiveRange &newRange = allocatedIndices.back(); 673 newRange.unionWith(defRange); 674 675 // Allocate an index for op/type/value ranges. 
676 if (defRange.opRangeIndex) { 677 newRange.opRangeIndex = numOpRanges; 678 valueToRangeIndex[defIt.first] = numOpRanges++; 679 } else if (defRange.typeRangeIndex) { 680 newRange.typeRangeIndex = numTypeRanges; 681 valueToRangeIndex[defIt.first] = numTypeRanges++; 682 } else if (defRange.valueRangeIndex) { 683 newRange.valueRangeIndex = numValueRanges; 684 valueToRangeIndex[defIt.first] = numValueRanges++; 685 } 686 687 memIndex = allocatedIndices.size(); 688 ++numIndices; 689 } 690 } 691 692 // Print the index usage and ensure that we did not run out of index space. 693 LLVM_DEBUG({ 694 llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices " 695 << "(down from initial " << valueDefRanges.size() << ").\n"; 696 }); 697 assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() && 698 "Ran out of memory for allocated indices"); 699 700 // Update the max number of indices. 701 if (numIndices > maxValueMemoryIndex) 702 maxValueMemoryIndex = numIndices; 703 if (numOpRanges > maxOpRangeMemoryIndex) 704 maxOpRangeMemoryIndex = numOpRanges; 705 if (numTypeRanges > maxTypeRangeMemoryIndex) 706 maxTypeRangeMemoryIndex = numTypeRanges; 707 if (numValueRanges > maxValueRangeMemoryIndex) 708 maxValueRangeMemoryIndex = numValueRanges; 709 } 710 711 void Generator::generate(Region *region, ByteCodeWriter &writer) { 712 llvm::ReversePostOrderTraversal<Region *> rpot(region); 713 for (Block *block : rpot) { 714 // Keep track of where this block begins within the matcher function. 715 blockToAddr.try_emplace(block, matcherByteCode.size()); 716 for (Operation &op : *block) 717 generate(&op, writer); 718 } 719 } 720 721 void Generator::generate(Operation *op, ByteCodeWriter &writer) { 722 LLVM_DEBUG({ 723 // The following list must contain all the operations that do not 724 // produce any bytecode. 
// The ops excluded below produce no bytecode of their own (see their generate
// overloads further down), so no inline location is recorded for them.
// NOTE(review): this check presumably sits inside a debug-only scope opened
// above this point -- TODO confirm.
    if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp,
             pdl_interp::InferredTypesOp>(op))
      writer.appendInline(op->getLoc());
  });
  // Dispatch to the overload for the concrete `pdl_interp` op. Every supported
  // op kind must be listed here; anything else is a programming error.
  TypeSwitch<Operation *>(op)
      .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
            pdl_interp::AreEqualOp, pdl_interp::BranchOp,
            pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
            pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
            pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
            pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
            pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
            pdl_interp::CreateTypesOp, pdl_interp::EraseOp,
            pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
            pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
            pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
            pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
            pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
            pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
            pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp,
            pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp,
            pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp,
            pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp,
            pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
          [&](auto interpOp) { this->generate(interpOp, writer); })
      .Default([](Operation *) {
        llvm_unreachable("unknown `pdl_interp` operation");
      });
}

// NOTE: Each overload below appends its fields in a fixed order. The matching
// `ByteCodeExecutor::execute*` method reads them back in exactly that order, so
// the two sides must be changed together.

/// Emit `ApplyConstraint`: the constraint function index, the constant
/// parameters, the argument list, then the successor addresses. The executor
/// takes the first successor when the constraint succeeds and the second when
/// it fails.
void Generator::generate(pdl_interp::ApplyConstraintOp op,
                         ByteCodeWriter &writer) {
  assert(constraintToMemIndex.count(op.name()) &&
         "expected index for constraint function");
  writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()],
                op.constParamsAttr());
  writer.appendPDLValueList(op.args());
  writer.append(op.getSuccessors());
}
/// Emit `ApplyRewrite`: the rewrite function index, the constant parameters,
/// the argument list, and the result count. Then, for each result: its PDL
/// value kind (debug builds only), a range storage index (range results only),
/// and the memory slot that receives the result.
void Generator::generate(pdl_interp::ApplyRewriteOp op,
                         ByteCodeWriter &writer) {
  assert(externalRewriterToMemIndex.count(op.name()) &&
         "expected index for rewrite function");
  writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()],
                op.constParamsAttr());
  writer.appendPDLValueList(op.args());

  ResultRange results = op.results();
  writer.append(ByteCodeField(results.size()));
  for (Value result : results) {
    // In debug mode we also record the expected kind of the result, so that we
    // can provide extra verification of the native rewrite function.
#ifndef NDEBUG
    writer.appendPDLValueKind(result);
#endif

    // Range results also need to append the range storage index.
    if (result.getType().isa<pdl::RangeType>())
      writer.append(getRangeStorageIndex(result));
    writer.append(result);
  }
}
/// Emit an equality check between two values.
/// - Range-typed operands use `AreRangesEqual`, which also records the PDL
///   value kind so the executor knows which range type to compare.
/// - All other operands use `AreEqual`, which compares the opaque memory
///   entries directly.
void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
  Value lhs = op.lhs();
  if (lhs.getType().isa<pdl::RangeType>()) {
    writer.append(OpCode::AreRangesEqual);
    writer.appendPDLValueKind(lhs);
    writer.append(op.lhs(), op.rhs(), op.getSuccessors());
    return;
  }

  writer.append(OpCode::AreEqual, lhs, op.rhs(), op.getSuccessors());
}
/// Emit an unconditional branch to the op's successor.
void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
}
/// Lowered onto `AreEqual`: compare the attribute with the constant value.
void Generator::generate(pdl_interp::CheckAttributeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(),
                op.getSuccessors());
}
/// Emit `CheckOperandCount`: the operation, the expected count, the
/// "at least" flag (encoded as a ByteCodeField), then the successors.
void Generator::generate(pdl_interp::CheckOperandCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CheckOperandCount, op.operation(), op.count(),
                static_cast<ByteCodeField>(op.compareAtLeast()),
                op.getSuccessors());
}
/// Emit `CheckOperationName`: the operation, the expected OperationName, then
/// the successors.
void Generator::generate(pdl_interp::CheckOperationNameOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CheckOperationName, op.operation(),
                OperationName(op.name(), ctx), op.getSuccessors());
}
/// Emit `CheckResultCount`: the operation, the expected count, the
/// "at least" flag (encoded as a ByteCodeField), then the successors.
void Generator::generate(pdl_interp::CheckResultCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CheckResultCount, op.operation(), op.count(),
                static_cast<ByteCodeField>(op.compareAtLeast()),
                op.getSuccessors());
}
/// Lowered onto `AreEqual`: compare the value with the expected type.
void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors());
}
/// Emit `CheckTypes`: the type range to check, the constant array of expected
/// types, then the successors.
void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::CheckTypes, op.value(), op.types(), op.getSuccessors());
}
/// Emit `Continue` with the index of the innermost enclosing loop.
/// `curLoopLevel` counts the loops currently being generated, so the innermost
/// loop's index is `curLoopLevel - 1`.
void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
  assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
  writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
}
/// Emits no bytecode: the result reuses the memory index of the constant.
void Generator::generate(pdl_interp::CreateAttributeOp op,
                         ByteCodeWriter &writer) {
  // Simply repoint the memory index of the result to the constant.
  getMemIndex(op.attribute()) = getMemIndex(op.value());
}
/// Emit `CreateOperation`, in this order:
/// - the result slot and the operation name;
/// - the operand list;
/// - the attribute count, followed by (name, value) pairs;
/// - the result type list.
void Generator::generate(pdl_interp::CreateOperationOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CreateOperation, op.operation(),
                OperationName(op.name(), ctx));
  writer.appendPDLValueList(op.operands());

  // Add the attributes.
  OperandRange attributes = op.attributes();
  writer.append(static_cast<ByteCodeField>(attributes.size()));
  for (auto it : llvm::zip(op.attributeNames(), op.attributes()))
    writer.append(std::get<0>(it), std::get<1>(it));
  writer.appendPDLValueList(op.types());
}
/// Emits no bytecode: the result reuses the memory index of the constant.
void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
  // Simply repoint the memory index of the result to the constant.
  getMemIndex(op.result()) = getMemIndex(op.value());
}
/// Emit `CreateTypes`: the result slot, its range storage index, and the
/// constant array of types to materialize.
void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::CreateTypes, op.result(),
                getRangeStorageIndex(op.result()), op.value());
}
/// Emit `EraseOp` followed by the operation to erase.
void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::EraseOp, op.operation());
}
/// Emit an Extract opcode chosen from the result's element type (operation,
/// value, or type), followed by the range, the element index, and the result
/// slot.
void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
  OpCode opCode =
      TypeSwitch<Type, OpCode>(op.result().getType())
          .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
          .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
          .Case([](pdl::TypeType) { return OpCode::ExtractType; })
          .Default([](Type) -> OpCode {
            llvm_unreachable("unsupported element type");
          });
  writer.append(opCode, op.range(), op.index(), op.result());
}
/// Emit `Finalize`, which has no operands.
void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::Finalize);
}
/// Emit `ForEach`, in this order:
/// - the range storage index of the iterated values;
/// - the loop variable's slot and its PDL value kind;
/// - the current loop level;
/// - the exit successor.
/// The body region is generated inline one loop level deeper, and
/// `maxLoopLevel` tracks the deepest nesting seen.
void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
  BlockArgument arg = op.getLoopVariable();
  writer.append(OpCode::ForEach, getRangeStorageIndex(op.values()), arg);
  writer.appendPDLValueKind(arg.getType());
  writer.append(curLoopLevel, op.successor());
  ++curLoopLevel;
  if (curLoopLevel > maxLoopLevel)
    maxLoopLevel = curLoopLevel;
  generate(&op.region(), writer);
  --curLoopLevel;
}
/// Emit `GetAttribute`: the result slot, the operation, and the attribute name.
void Generator::generate(pdl_interp::GetAttributeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::GetAttribute, op.attribute(), op.operation(),
                op.nameAttr());
}
/// Emit `GetAttributeType`: the result slot and the attribute.
void Generator::generate(pdl_interp::GetAttributeTypeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::GetAttributeType, op.result(), op.value());
}
/// Emit `GetDefiningOp`: the result slot, then the queried value prefixed by
/// its PDL value kind. The value may be a single value or a value range.
void Generator::generate(pdl_interp::GetDefiningOpOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::GetDefiningOp, op.operation());
  writer.appendPDLValue(op.value());
}
/// Emit an operand getter.
/// - Indices below 4 use the dedicated opcode `GetOperand0 + index`.
/// - Larger indices use `GetOperandN` with the index written inline.
/// Both forms are followed by the operation and the result slot.
void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
  uint32_t index = op.index();
  if (index < 4)
    writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
  else
    writer.append(OpCode::GetOperandN, index);
  writer.append(op.operation(), op.value());
}
/// Emit `GetOperands`, in this order:
/// - the operand-group index, or the `uint32_t` max sentinel when no index is
///   given (meaning all operands);
/// - the operation;
/// - a range storage index, or the `ByteCodeField` max placeholder when the
///   result is a single value;
/// - the result slot.
void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
  Value result = op.value();
  Optional<uint32_t> index = op.index();
  writer.append(OpCode::GetOperands,
                index.getValueOr(std::numeric_limits<uint32_t>::max()),
                op.operation());
  if (result.getType().isa<pdl::RangeType>())
    writer.append(getRangeStorageIndex(result));
  else
    writer.append(std::numeric_limits<ByteCodeField>::max());
  writer.append(result);
}
/// Emit a result getter.
/// - Indices below 4 use the dedicated opcode `GetResult0 + index`.
/// - Larger indices use `GetResultN` with the index written inline.
/// Both forms are followed by the operation and the result slot.
void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
  uint32_t index = op.index();
  if (index < 4)
    writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
  else
    writer.append(OpCode::GetResultN, index);
  writer.append(op.operation(), op.value());
}
/// Emit `GetResults`, in this order:
/// - the result-group index, or the `uint32_t` max sentinel when no index is
///   given (meaning all results);
/// - the operation;
/// - a range storage index, or the `ByteCodeField` max placeholder when the
///   result is a single value;
/// - the result slot.
void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
  Value result = op.value();
  Optional<uint32_t> index = op.index();
  writer.append(OpCode::GetResults,
                index.getValueOr(std::numeric_limits<uint32_t>::max()),
                op.operation());
  if (result.getType().isa<pdl::RangeType>())
    writer.append(getRangeStorageIndex(result));
  else
    writer.append(std::numeric_limits<ByteCodeField>::max());
  writer.append(result);
}
/// Emit `GetUsers`: the result operation range, its range storage index, then
/// the queried value(s) prefixed by their PDL value kind.
void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
  Value operations = op.operations();
  ByteCodeField rangeIndex = getRangeStorageIndex(operations);
  writer.append(OpCode::GetUsers, operations, rangeIndex);
  writer.appendPDLValue(op.value());
}
/// Emit a type query on a value.
/// - Range results use `GetValueRangeTypes`, which also carries a range
///   storage index.
/// - Single results use `GetValueType`.
void Generator::generate(pdl_interp::GetValueTypeOp op,
                         ByteCodeWriter &writer) {
  if (op.getType().isa<pdl::RangeType>()) {
    Value result = op.result();
    writer.append(OpCode::GetValueRangeTypes, result,
                  getRangeStorageIndex(result), op.value());
  } else {
    writer.append(OpCode::GetValueType, op.result(), op.value());
  }
}

/// Emits no bytecode: the result is mapped to the memory index of the null
/// `Type`. CreateOperation treats a null type range as the signal to infer
/// result types.
void Generator::generate(pdl_interp::InferredTypesOp op,
                         ByteCodeWriter &writer) {
  // InferType maps to a null type as a marker for inferring result types.
  getMemIndex(op.type()) = getMemIndex(Type());
}
/// Emit `IsNotNull`: the value to test, then the successors.
void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors());
}
/// Register a new PDLByteCodePattern bound to its rewriter's bytecode address.
/// The pattern's index is its position in `patterns`. Then emit `RecordMatch`:
/// the pattern index, the successor, the matched ops, and the rewriter inputs.
/// NOTE(review): `rewriterToAddr[...]` presumably default-inserts when the
/// rewriter name is missing, so this relies on every rewriter having been
/// generated beforehand -- confirm.
void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
  ByteCodeField patternIndex = patterns.size();
  patterns.emplace_back(PDLByteCodePattern::create(
      op, rewriterToAddr[op.rewriter().getLeafReference().getValue()]));
  writer.append(OpCode::RecordMatch, patternIndex,
                SuccessorRange(op.getOperation()), op.matchedOps());
  writer.appendPDLValueList(op.inputs());
}
/// Emit `ReplaceOp`: the operation to replace, then the replacement values.
void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::ReplaceOp, op.operation());
  writer.appendPDLValueList(op.replValues());
}
/// Emit `SwitchAttribute`: the attribute, the case values, then the
/// successors. The default successor comes first, then one per case.
void Generator::generate(pdl_interp::SwitchAttributeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(),
                op.getSuccessors());
}
/// Emit `SwitchOperandCount`: the operation, the case counts, then the
/// successors (default first).
void Generator::generate(pdl_interp::SwitchOperandCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(),
                op.getSuccessors());
}
/// Emit `SwitchOperationName`: the operation, the case names converted to
/// OperationName, then the successors (default first).
void Generator::generate(pdl_interp::SwitchOperationNameOp op,
                         ByteCodeWriter &writer) {
  auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) {
    return OperationName(attr.cast<StringAttr>().getValue(), ctx);
  });
  writer.append(OpCode::SwitchOperationName, op.operation(), cases,
                op.getSuccessors());
}
/// Emit `SwitchResultCount`: the operation, the case counts, then the
/// successors (default first).
void Generator::generate(pdl_interp::SwitchResultCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(),
                op.getSuccessors());
}
/// Emit `SwitchType`: the type value, the case types, then the successors
/// (default first).
void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(),
                op.getSuccessors());
}
/// Emit `SwitchTypes`: the type range, the case type arrays, then the
/// successors (default first).
void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchTypes, op.value(), op.caseValuesAttr(),
                op.getSuccessors());
}

//===----------------------------------------------------------------------===//
// PDLByteCode
//===----------------------------------------------------------------------===//

/// Build the bytecode for `module`. The Generator fills in the uniqued data,
/// the matcher/rewriter bytecode, the patterns, and the memory/range/loop
/// sizing counters. The external constraint and rewrite functions are then
/// moved into flat lists that the executor indexes directly.
/// NOTE(review): the moves follow StringMap iteration order; this presumably
/// matches the indices the Generator assigned -- confirm.
PDLByteCode::PDLByteCode(ModuleOp module,
                         llvm::StringMap<PDLConstraintFunction> constraintFns,
                         llvm::StringMap<PDLRewriteFunction> rewriteFns) {
  Generator generator(module.getContext(), uniquedData, matcherByteCode,
                      rewriterByteCode, patterns, maxValueMemoryIndex,
                      maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
                      maxLoopLevel, constraintFns, rewriteFns);
  generator.generate(module);

  // Initialize the external functions.
  for (auto &it : constraintFns)
    constraintFunctions.push_back(std::move(it.second));
  for (auto &it : rewriteFns)
    rewriteFunctions.push_back(std::move(it.second));
}

/// Initialize the given state such that it can be used to execute the current
/// bytecode: size the value/range/loop-index storage and seed the per-pattern
/// benefits.
1032 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { 1033 state.memory.resize(maxValueMemoryIndex, nullptr); 1034 state.opRangeMemory.resize(maxOpRangeCount); 1035 state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange()); 1036 state.valueRangeMemory.resize(maxValueRangeCount, ValueRange()); 1037 state.loopIndex.resize(maxLoopLevel, 0); 1038 state.currentPatternBenefits.reserve(patterns.size()); 1039 for (const PDLByteCodePattern &pattern : patterns) 1040 state.currentPatternBenefits.push_back(pattern.getBenefit()); 1041 } 1042 1043 //===----------------------------------------------------------------------===// 1044 // ByteCode Execution 1045 1046 namespace { 1047 /// This class provides support for executing a bytecode stream. 1048 class ByteCodeExecutor { 1049 public: 1050 ByteCodeExecutor( 1051 const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory, 1052 MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory, 1053 MutableArrayRef<TypeRange> typeRangeMemory, 1054 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory, 1055 MutableArrayRef<ValueRange> valueRangeMemory, 1056 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory, 1057 MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory, 1058 ArrayRef<ByteCodeField> code, 1059 ArrayRef<PatternBenefit> currentPatternBenefits, 1060 ArrayRef<PDLByteCodePattern> patterns, 1061 ArrayRef<PDLConstraintFunction> constraintFunctions, 1062 ArrayRef<PDLRewriteFunction> rewriteFunctions) 1063 : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory), 1064 typeRangeMemory(typeRangeMemory), 1065 allocatedTypeRangeMemory(allocatedTypeRangeMemory), 1066 valueRangeMemory(valueRangeMemory), 1067 allocatedValueRangeMemory(allocatedValueRangeMemory), 1068 loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code), 1069 currentPatternBenefits(currentPatternBenefits), patterns(patterns), 1070 
constraintFunctions(constraintFunctions), 1071 rewriteFunctions(rewriteFunctions) {} 1072 1073 /// Start executing the code at the current bytecode index. `matches` is an 1074 /// optional field provided when this function is executed in a matching 1075 /// context. 1076 void execute(PatternRewriter &rewriter, 1077 SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr, 1078 Optional<Location> mainRewriteLoc = {}); 1079 1080 private: 1081 /// Internal implementation of executing each of the bytecode commands. 1082 void executeApplyConstraint(PatternRewriter &rewriter); 1083 void executeApplyRewrite(PatternRewriter &rewriter); 1084 void executeAreEqual(); 1085 void executeAreRangesEqual(); 1086 void executeBranch(); 1087 void executeCheckOperandCount(); 1088 void executeCheckOperationName(); 1089 void executeCheckResultCount(); 1090 void executeCheckTypes(); 1091 void executeContinue(); 1092 void executeCreateOperation(PatternRewriter &rewriter, 1093 Location mainRewriteLoc); 1094 void executeCreateTypes(); 1095 void executeEraseOp(PatternRewriter &rewriter); 1096 template <typename T, typename Range, PDLValue::Kind kind> 1097 void executeExtract(); 1098 void executeFinalize(); 1099 void executeForEach(); 1100 void executeGetAttribute(); 1101 void executeGetAttributeType(); 1102 void executeGetDefiningOp(); 1103 void executeGetOperand(unsigned index); 1104 void executeGetOperands(); 1105 void executeGetResult(unsigned index); 1106 void executeGetResults(); 1107 void executeGetUsers(); 1108 void executeGetValueType(); 1109 void executeGetValueRangeTypes(); 1110 void executeIsNotNull(); 1111 void executeRecordMatch(PatternRewriter &rewriter, 1112 SmallVectorImpl<PDLByteCode::MatchResult> &matches); 1113 void executeReplaceOp(PatternRewriter &rewriter); 1114 void executeSwitchAttribute(); 1115 void executeSwitchOperandCount(); 1116 void executeSwitchOperationName(); 1117 void executeSwitchResultCount(); 1118 void executeSwitchType(); 1119 void 
executeSwitchTypes(); 1120 1121 /// Pushes a code iterator to the stack. 1122 void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); } 1123 1124 /// Pops a code iterator from the stack, returning true on success. 1125 void popCodeIt() { 1126 assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack"); 1127 curCodeIt = resumeCodeIt.back(); 1128 resumeCodeIt.pop_back(); 1129 } 1130 1131 /// Return the bytecode iterator at the start of the current op code. 1132 const ByteCodeField *getPrevCodeIt() const { 1133 LLVM_DEBUG({ 1134 // Account for the op code and the Location stored inline. 1135 return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField); 1136 }); 1137 1138 // Account for the op code only. 1139 return curCodeIt - 1; 1140 } 1141 1142 /// Read a value from the bytecode buffer, optionally skipping a certain 1143 /// number of prefix values. These methods always update the buffer to point 1144 /// to the next field after the read data. 1145 template <typename T = ByteCodeField> 1146 T read(size_t skipN = 0) { 1147 curCodeIt += skipN; 1148 return readImpl<T>(); 1149 } 1150 ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); } 1151 1152 /// Read a list of values from the bytecode buffer. 1153 template <typename ValueT, typename T> 1154 void readList(SmallVectorImpl<T> &list) { 1155 list.clear(); 1156 for (unsigned i = 0, e = read(); i != e; ++i) 1157 list.push_back(read<ValueT>()); 1158 } 1159 1160 /// Read a list of values from the bytecode buffer. The values may be encoded 1161 /// as either Value or ValueRange elements. 
1162 void readValueList(SmallVectorImpl<Value> &list) { 1163 for (unsigned i = 0, e = read(); i != e; ++i) { 1164 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1165 list.push_back(read<Value>()); 1166 } else { 1167 ValueRange *values = read<ValueRange *>(); 1168 list.append(values->begin(), values->end()); 1169 } 1170 } 1171 } 1172 1173 /// Read a value stored inline as a pointer. 1174 template <typename T> 1175 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T> 1176 readInline() { 1177 const void *pointer; 1178 std::memcpy(&pointer, curCodeIt, sizeof(const void *)); 1179 curCodeIt += sizeof(const void *) / sizeof(ByteCodeField); 1180 return T::getFromOpaquePointer(pointer); 1181 } 1182 1183 /// Jump to a specific successor based on a predicate value. 1184 void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); } 1185 /// Jump to a specific successor based on a destination index. 1186 void selectJump(size_t destIndex) { 1187 curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)]; 1188 } 1189 1190 /// Handle a switch operation with the provided value and cases. 1191 template <typename T, typename RangeT, typename Comparator = std::equal_to<T>> 1192 void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) { 1193 LLVM_DEBUG({ 1194 llvm::dbgs() << " * Value: " << value << "\n" 1195 << " * Cases: "; 1196 llvm::interleaveComma(cases, llvm::dbgs()); 1197 llvm::dbgs() << "\n"; 1198 }); 1199 1200 // Check to see if the attribute value is within the case list. Jump to 1201 // the correct successor index based on the result. 1202 for (auto it = cases.begin(), e = cases.end(); it != e; ++it) 1203 if (cmp(*it, value)) 1204 return selectJump(size_t((it - cases.begin()) + 1)); 1205 selectJump(size_t(0)); 1206 } 1207 1208 /// Store a pointer to memory. 1209 void storeToMemory(unsigned index, const void *value) { 1210 memory[index] = value; 1211 } 1212 1213 /// Store a value to memory as an opaque pointer. 
1214 template <typename T> 1215 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value> 1216 storeToMemory(unsigned index, T value) { 1217 memory[index] = value.getAsOpaquePointer(); 1218 } 1219 1220 /// Internal implementation of reading various data types from the bytecode 1221 /// stream. 1222 template <typename T> 1223 const void *readFromMemory() { 1224 size_t index = *curCodeIt++; 1225 1226 // If this type is an SSA value, it can only be stored in non-const memory. 1227 if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *, 1228 Value>::value || 1229 index < memory.size()) 1230 return memory[index]; 1231 1232 // Otherwise, if this index is not inbounds it is uniqued. 1233 return uniquedMemory[index - memory.size()]; 1234 } 1235 template <typename T> 1236 std::enable_if_t<std::is_pointer<T>::value, T> readImpl() { 1237 return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>())); 1238 } 1239 template <typename T> 1240 std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value, 1241 T> 1242 readImpl() { 1243 return T(T::getFromOpaquePointer(readFromMemory<T>())); 1244 } 1245 template <typename T> 1246 std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() { 1247 switch (read<PDLValue::Kind>()) { 1248 case PDLValue::Kind::Attribute: 1249 return read<Attribute>(); 1250 case PDLValue::Kind::Operation: 1251 return read<Operation *>(); 1252 case PDLValue::Kind::Type: 1253 return read<Type>(); 1254 case PDLValue::Kind::Value: 1255 return read<Value>(); 1256 case PDLValue::Kind::TypeRange: 1257 return read<TypeRange *>(); 1258 case PDLValue::Kind::ValueRange: 1259 return read<ValueRange *>(); 1260 } 1261 llvm_unreachable("unhandled PDLValue::Kind"); 1262 } 1263 template <typename T> 1264 std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() { 1265 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 1266 "unexpected ByteCode address size"); 1267 ByteCodeAddr result; 1268 
std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr)); 1269 curCodeIt += 2; 1270 return result; 1271 } 1272 template <typename T> 1273 std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() { 1274 return *curCodeIt++; 1275 } 1276 template <typename T> 1277 std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() { 1278 return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>()); 1279 } 1280 1281 /// The underlying bytecode buffer. 1282 const ByteCodeField *curCodeIt; 1283 1284 /// The stack of bytecode positions at which to resume operation. 1285 SmallVector<const ByteCodeField *> resumeCodeIt; 1286 1287 /// The current execution memory. 1288 MutableArrayRef<const void *> memory; 1289 MutableArrayRef<OwningOpRange> opRangeMemory; 1290 MutableArrayRef<TypeRange> typeRangeMemory; 1291 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory; 1292 MutableArrayRef<ValueRange> valueRangeMemory; 1293 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory; 1294 1295 /// The current loop indices. 1296 MutableArrayRef<unsigned> loopIndex; 1297 1298 /// References to ByteCode data necessary for execution. 1299 ArrayRef<const void *> uniquedMemory; 1300 ArrayRef<ByteCodeField> code; 1301 ArrayRef<PatternBenefit> currentPatternBenefits; 1302 ArrayRef<PDLByteCodePattern> patterns; 1303 ArrayRef<PDLConstraintFunction> constraintFunctions; 1304 ArrayRef<PDLRewriteFunction> rewriteFunctions; 1305 }; 1306 1307 /// This class is an instantiation of the PDLResultList that provides access to 1308 /// the returned results. This API is not on `PDLResultList` to avoid 1309 /// overexposing access to information specific solely to the ByteCode. 1310 class ByteCodeRewriteResultList : public PDLResultList { 1311 public: 1312 ByteCodeRewriteResultList(unsigned maxNumResults) 1313 : PDLResultList(maxNumResults) {} 1314 1315 /// Return the list of PDL results. 
1316 MutableArrayRef<PDLValue> getResults() { return results; } 1317 1318 /// Return the type ranges allocated by this list. 1319 MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() { 1320 return allocatedTypeRanges; 1321 } 1322 1323 /// Return the value ranges allocated by this list. 1324 MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() { 1325 return allocatedValueRanges; 1326 } 1327 }; 1328 } // namespace 1329 1330 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { 1331 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); 1332 const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; 1333 ArrayAttr constParams = read<ArrayAttr>(); 1334 SmallVector<PDLValue, 16> args; 1335 readList<PDLValue>(args); 1336 1337 LLVM_DEBUG({ 1338 llvm::dbgs() << " * Arguments: "; 1339 llvm::interleaveComma(args, llvm::dbgs()); 1340 llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; 1341 }); 1342 1343 // Invoke the constraint and jump to the proper destination. 1344 selectJump(succeeded(constraintFn(args, constParams, rewriter))); 1345 } 1346 1347 void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { 1348 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); 1349 const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; 1350 ArrayAttr constParams = read<ArrayAttr>(); 1351 SmallVector<PDLValue, 16> args; 1352 readList<PDLValue>(args); 1353 1354 LLVM_DEBUG({ 1355 llvm::dbgs() << " * Arguments: "; 1356 llvm::interleaveComma(args, llvm::dbgs()); 1357 llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; 1358 }); 1359 1360 // Execute the rewrite function. 
1361 ByteCodeField numResults = read(); 1362 ByteCodeRewriteResultList results(numResults); 1363 rewriteFn(args, constParams, rewriter, results); 1364 1365 assert(results.getResults().size() == numResults && 1366 "native PDL rewrite function returned unexpected number of results"); 1367 1368 // Store the results in the bytecode memory. 1369 for (PDLValue &result : results.getResults()) { 1370 LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); 1371 1372 // In debug mode we also verify the expected kind of the result. 1373 #ifndef NDEBUG 1374 assert(result.getKind() == read<PDLValue::Kind>() && 1375 "native PDL rewrite function returned an unexpected type of result"); 1376 #endif 1377 1378 // If the result is a range, we need to copy it over to the bytecodes 1379 // range memory. 1380 if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) { 1381 unsigned rangeIndex = read(); 1382 typeRangeMemory[rangeIndex] = *typeRange; 1383 memory[read()] = &typeRangeMemory[rangeIndex]; 1384 } else if (Optional<ValueRange> valueRange = 1385 result.dyn_cast<ValueRange>()) { 1386 unsigned rangeIndex = read(); 1387 valueRangeMemory[rangeIndex] = *valueRange; 1388 memory[read()] = &valueRangeMemory[rangeIndex]; 1389 } else { 1390 memory[read()] = result.getAsOpaquePointer(); 1391 } 1392 } 1393 1394 // Copy over any underlying storage allocated for result ranges. 
1395 for (auto &it : results.getAllocatedTypeRanges()) 1396 allocatedTypeRangeMemory.push_back(std::move(it)); 1397 for (auto &it : results.getAllocatedValueRanges()) 1398 allocatedValueRangeMemory.push_back(std::move(it)); 1399 } 1400 1401 void ByteCodeExecutor::executeAreEqual() { 1402 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 1403 const void *lhs = read<const void *>(); 1404 const void *rhs = read<const void *>(); 1405 1406 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n"); 1407 selectJump(lhs == rhs); 1408 } 1409 1410 void ByteCodeExecutor::executeAreRangesEqual() { 1411 LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n"); 1412 PDLValue::Kind valueKind = read<PDLValue::Kind>(); 1413 const void *lhs = read<const void *>(); 1414 const void *rhs = read<const void *>(); 1415 1416 switch (valueKind) { 1417 case PDLValue::Kind::TypeRange: { 1418 const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs); 1419 const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs); 1420 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1421 selectJump(*lhsRange == *rhsRange); 1422 break; 1423 } 1424 case PDLValue::Kind::ValueRange: { 1425 const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs); 1426 const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs); 1427 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1428 selectJump(*lhsRange == *rhsRange); 1429 break; 1430 } 1431 default: 1432 llvm_unreachable("unexpected `AreRangesEqual` value kind"); 1433 } 1434 } 1435 1436 void ByteCodeExecutor::executeBranch() { 1437 LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n"); 1438 curCodeIt = &code[read<ByteCodeAddr>()]; 1439 } 1440 1441 void ByteCodeExecutor::executeCheckOperandCount() { 1442 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); 1443 Operation *op = read<Operation *>(); 1444 uint32_t expectedCount = read<uint32_t>(); 1445 bool compareAtLeast = read(); 1446 1447 
LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" 1448 << " * Expected: " << expectedCount << "\n" 1449 << " * Comparator: " 1450 << (compareAtLeast ? ">=" : "==") << "\n"); 1451 if (compareAtLeast) 1452 selectJump(op->getNumOperands() >= expectedCount); 1453 else 1454 selectJump(op->getNumOperands() == expectedCount); 1455 } 1456 1457 void ByteCodeExecutor::executeCheckOperationName() { 1458 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n"); 1459 Operation *op = read<Operation *>(); 1460 OperationName expectedName = read<OperationName>(); 1461 1462 LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n" 1463 << " * Expected: \"" << expectedName << "\"\n"); 1464 selectJump(op->getName() == expectedName); 1465 } 1466 1467 void ByteCodeExecutor::executeCheckResultCount() { 1468 LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); 1469 Operation *op = read<Operation *>(); 1470 uint32_t expectedCount = read<uint32_t>(); 1471 bool compareAtLeast = read(); 1472 1473 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" 1474 << " * Expected: " << expectedCount << "\n" 1475 << " * Comparator: " 1476 << (compareAtLeast ? 
">=" : "==") << "\n"); 1477 if (compareAtLeast) 1478 selectJump(op->getNumResults() >= expectedCount); 1479 else 1480 selectJump(op->getNumResults() == expectedCount); 1481 } 1482 1483 void ByteCodeExecutor::executeCheckTypes() { 1484 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 1485 TypeRange *lhs = read<TypeRange *>(); 1486 Attribute rhs = read<Attribute>(); 1487 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1488 1489 selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>()); 1490 } 1491 1492 void ByteCodeExecutor::executeContinue() { 1493 ByteCodeField level = read(); 1494 LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n" 1495 << " * Level: " << level << "\n"); 1496 ++loopIndex[level]; 1497 popCodeIt(); 1498 } 1499 1500 void ByteCodeExecutor::executeCreateTypes() { 1501 LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n"); 1502 unsigned memIndex = read(); 1503 unsigned rangeIndex = read(); 1504 ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>(); 1505 1506 LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n"); 1507 1508 // Allocate a buffer for this type range. 1509 llvm::OwningArrayRef<Type> storage(typesAttr.size()); 1510 llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin()); 1511 allocatedTypeRangeMemory.emplace_back(std::move(storage)); 1512 1513 // Assign this to the range slot and use the range as the value for the 1514 // memory index. 
1515 typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back(); 1516 memory[memIndex] = &typeRangeMemory[rangeIndex]; 1517 } 1518 1519 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, 1520 Location mainRewriteLoc) { 1521 LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); 1522 1523 unsigned memIndex = read(); 1524 OperationState state(mainRewriteLoc, read<OperationName>()); 1525 readValueList(state.operands); 1526 for (unsigned i = 0, e = read(); i != e; ++i) { 1527 StringAttr name = read<StringAttr>(); 1528 if (Attribute attr = read<Attribute>()) 1529 state.addAttribute(name, attr); 1530 } 1531 1532 for (unsigned i = 0, e = read(); i != e; ++i) { 1533 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) { 1534 state.types.push_back(read<Type>()); 1535 continue; 1536 } 1537 1538 // If we find a null range, this signals that the types are infered. 1539 if (TypeRange *resultTypes = read<TypeRange *>()) { 1540 state.types.append(resultTypes->begin(), resultTypes->end()); 1541 continue; 1542 } 1543 1544 // Handle the case where the operation has inferred types. 1545 InferTypeOpInterface::Concept *concept = 1546 state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>(); 1547 1548 // TODO: Handle failure. 
1549 state.types.clear(); 1550 if (failed(concept->inferReturnTypes( 1551 state.getContext(), state.location, state.operands, 1552 state.attributes.getDictionary(state.getContext()), state.regions, 1553 state.types))) 1554 return; 1555 break; 1556 } 1557 1558 Operation *resultOp = rewriter.createOperation(state); 1559 memory[memIndex] = resultOp; 1560 1561 LLVM_DEBUG({ 1562 llvm::dbgs() << " * Attributes: " 1563 << state.attributes.getDictionary(state.getContext()) 1564 << "\n * Operands: "; 1565 llvm::interleaveComma(state.operands, llvm::dbgs()); 1566 llvm::dbgs() << "\n * Result Types: "; 1567 llvm::interleaveComma(state.types, llvm::dbgs()); 1568 llvm::dbgs() << "\n * Result: " << *resultOp << "\n"; 1569 }); 1570 } 1571 1572 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) { 1573 LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n"); 1574 Operation *op = read<Operation *>(); 1575 1576 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1577 rewriter.eraseOp(op); 1578 } 1579 1580 template <typename T, typename Range, PDLValue::Kind kind> 1581 void ByteCodeExecutor::executeExtract() { 1582 LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n"); 1583 Range *range = read<Range *>(); 1584 unsigned index = read<uint32_t>(); 1585 unsigned memIndex = read(); 1586 1587 if (!range) { 1588 memory[memIndex] = nullptr; 1589 return; 1590 } 1591 1592 T result = index < range->size() ? 
(*range)[index] : T(); 1593 LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n" 1594 << " * Index: " << index << "\n" 1595 << " * Result: " << result << "\n"); 1596 storeToMemory(memIndex, result); 1597 } 1598 1599 void ByteCodeExecutor::executeFinalize() { 1600 LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n"); 1601 } 1602 1603 void ByteCodeExecutor::executeForEach() { 1604 LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n"); 1605 const ByteCodeField *prevCodeIt = getPrevCodeIt(); 1606 unsigned rangeIndex = read(); 1607 unsigned memIndex = read(); 1608 const void *value = nullptr; 1609 1610 switch (read<PDLValue::Kind>()) { 1611 case PDLValue::Kind::Operation: { 1612 unsigned &index = loopIndex[read()]; 1613 ArrayRef<Operation *> array = opRangeMemory[rangeIndex]; 1614 assert(index <= array.size() && "iterated past the end"); 1615 if (index < array.size()) { 1616 LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n"); 1617 value = array[index]; 1618 break; 1619 } 1620 1621 LLVM_DEBUG(llvm::dbgs() << " * Done\n"); 1622 index = 0; 1623 selectJump(size_t(0)); 1624 return; 1625 } 1626 default: 1627 llvm_unreachable("unexpected `ForEach` value kind"); 1628 } 1629 1630 // Store the iterate value and the stack address. 1631 memory[memIndex] = value; 1632 pushCodeIt(prevCodeIt); 1633 1634 // Skip over the successor (we will enter the body of the loop). 
1635 read<ByteCodeAddr>(); 1636 } 1637 1638 void ByteCodeExecutor::executeGetAttribute() { 1639 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); 1640 unsigned memIndex = read(); 1641 Operation *op = read<Operation *>(); 1642 StringAttr attrName = read<StringAttr>(); 1643 Attribute attr = op->getAttr(attrName); 1644 1645 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1646 << " * Attribute: " << attrName << "\n" 1647 << " * Result: " << attr << "\n"); 1648 memory[memIndex] = attr.getAsOpaquePointer(); 1649 } 1650 1651 void ByteCodeExecutor::executeGetAttributeType() { 1652 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n"); 1653 unsigned memIndex = read(); 1654 Attribute attr = read<Attribute>(); 1655 Type type = attr ? attr.getType() : Type(); 1656 1657 LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" 1658 << " * Result: " << type << "\n"); 1659 memory[memIndex] = type.getAsOpaquePointer(); 1660 } 1661 1662 void ByteCodeExecutor::executeGetDefiningOp() { 1663 LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); 1664 unsigned memIndex = read(); 1665 Operation *op = nullptr; 1666 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1667 Value value = read<Value>(); 1668 if (value) 1669 op = value.getDefiningOp(); 1670 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1671 } else { 1672 ValueRange *values = read<ValueRange *>(); 1673 if (values && !values->empty()) { 1674 op = values->front().getDefiningOp(); 1675 } 1676 LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n"); 1677 } 1678 1679 LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n"); 1680 memory[memIndex] = op; 1681 } 1682 1683 void ByteCodeExecutor::executeGetOperand(unsigned index) { 1684 Operation *op = read<Operation *>(); 1685 unsigned memIndex = read(); 1686 Value operand = 1687 index < op->getNumOperands() ? 
op->getOperand(index) : Value(); 1688 1689 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1690 << " * Index: " << index << "\n" 1691 << " * Result: " << operand << "\n"); 1692 memory[memIndex] = operand.getAsOpaquePointer(); 1693 } 1694 1695 /// This function is the internal implementation of `GetResults` and 1696 /// `GetOperands` that provides support for extracting a value range from the 1697 /// given operation. 1698 template <template <typename> class AttrSizedSegmentsT, typename RangeT> 1699 static void * 1700 executeGetOperandsResults(RangeT values, Operation *op, unsigned index, 1701 ByteCodeField rangeIndex, StringRef attrSizedSegments, 1702 MutableArrayRef<ValueRange> valueRangeMemory) { 1703 // Check for the sentinel index that signals that all values should be 1704 // returned. 1705 if (index == std::numeric_limits<uint32_t>::max()) { 1706 LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n"); 1707 // `values` is already the full value range. 1708 1709 // Otherwise, check to see if this operation uses AttrSizedSegments. 1710 } else if (op->hasTrait<AttrSizedSegmentsT>()) { 1711 LLVM_DEBUG(llvm::dbgs() 1712 << " * Extracting values from `" << attrSizedSegments << "`\n"); 1713 1714 auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments); 1715 if (!segmentAttr || segmentAttr.getNumElements() <= index) 1716 return nullptr; 1717 1718 auto segments = segmentAttr.getValues<int32_t>(); 1719 unsigned startIndex = 1720 std::accumulate(segments.begin(), segments.begin() + index, 0); 1721 values = values.slice(startIndex, *std::next(segments.begin(), index)); 1722 1723 LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", " 1724 << *std::next(segments.begin(), index) << "]\n"); 1725 1726 // Otherwise, assume this is the last operand group of the operation. 
1727 // FIXME: We currently don't support operations with 1728 // SameVariadicOperandSize/SameVariadicResultSize here given that we don't 1729 // have a way to detect it's presence. 1730 } else if (values.size() >= index) { 1731 LLVM_DEBUG(llvm::dbgs() 1732 << " * Treating values as trailing variadic range\n"); 1733 values = values.drop_front(index); 1734 1735 // If we couldn't detect a way to compute the values, bail out. 1736 } else { 1737 return nullptr; 1738 } 1739 1740 // If the range index is valid, we are returning a range. 1741 if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) { 1742 valueRangeMemory[rangeIndex] = values; 1743 return &valueRangeMemory[rangeIndex]; 1744 } 1745 1746 // If a range index wasn't provided, the range is required to be non-variadic. 1747 return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer(); 1748 } 1749 1750 void ByteCodeExecutor::executeGetOperands() { 1751 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n"); 1752 unsigned index = read<uint32_t>(); 1753 Operation *op = read<Operation *>(); 1754 ByteCodeField rangeIndex = read(); 1755 1756 void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>( 1757 op->getOperands(), op, index, rangeIndex, "operand_segment_sizes", 1758 valueRangeMemory); 1759 if (!result) 1760 LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n"); 1761 memory[read()] = result; 1762 } 1763 1764 void ByteCodeExecutor::executeGetResult(unsigned index) { 1765 Operation *op = read<Operation *>(); 1766 unsigned memIndex = read(); 1767 OpResult result = 1768 index < op->getNumResults() ? 
op->getResult(index) : OpResult(); 1769 1770 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1771 << " * Index: " << index << "\n" 1772 << " * Result: " << result << "\n"); 1773 memory[memIndex] = result.getAsOpaquePointer(); 1774 } 1775 1776 void ByteCodeExecutor::executeGetResults() { 1777 LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n"); 1778 unsigned index = read<uint32_t>(); 1779 Operation *op = read<Operation *>(); 1780 ByteCodeField rangeIndex = read(); 1781 1782 void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>( 1783 op->getResults(), op, index, rangeIndex, "result_segment_sizes", 1784 valueRangeMemory); 1785 if (!result) 1786 LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n"); 1787 memory[read()] = result; 1788 } 1789 1790 void ByteCodeExecutor::executeGetUsers() { 1791 LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n"); 1792 unsigned memIndex = read(); 1793 unsigned rangeIndex = read(); 1794 OwningOpRange &range = opRangeMemory[rangeIndex]; 1795 memory[memIndex] = ⦥ 1796 1797 range = OwningOpRange(); 1798 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1799 // Read the value. 1800 Value value = read<Value>(); 1801 if (!value) 1802 return; 1803 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1804 1805 // Extract the users of a single value. 1806 range = OwningOpRange(std::distance(value.user_begin(), value.user_end())); 1807 llvm::copy(value.getUsers(), range.begin()); 1808 } else { 1809 // Read a range of values. 1810 ValueRange *values = read<ValueRange *>(); 1811 if (!values) 1812 return; 1813 LLVM_DEBUG({ 1814 llvm::dbgs() << " * Values (" << values->size() << "): "; 1815 llvm::interleaveComma(*values, llvm::dbgs()); 1816 llvm::dbgs() << "\n"; 1817 }); 1818 1819 // Extract all the users of a range of values. 
1820 SmallVector<Operation *> users; 1821 for (Value value : *values) 1822 users.append(value.user_begin(), value.user_end()); 1823 range = OwningOpRange(users.size()); 1824 llvm::copy(users, range.begin()); 1825 } 1826 1827 LLVM_DEBUG(llvm::dbgs() << " * Result: " << range.size() << " operations\n"); 1828 } 1829 1830 void ByteCodeExecutor::executeGetValueType() { 1831 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); 1832 unsigned memIndex = read(); 1833 Value value = read<Value>(); 1834 Type type = value ? value.getType() : Type(); 1835 1836 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" 1837 << " * Result: " << type << "\n"); 1838 memory[memIndex] = type.getAsOpaquePointer(); 1839 } 1840 1841 void ByteCodeExecutor::executeGetValueRangeTypes() { 1842 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n"); 1843 unsigned memIndex = read(); 1844 unsigned rangeIndex = read(); 1845 ValueRange *values = read<ValueRange *>(); 1846 if (!values) { 1847 LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n"); 1848 memory[memIndex] = nullptr; 1849 return; 1850 } 1851 1852 LLVM_DEBUG({ 1853 llvm::dbgs() << " * Values (" << values->size() << "): "; 1854 llvm::interleaveComma(*values, llvm::dbgs()); 1855 llvm::dbgs() << "\n * Result: "; 1856 llvm::interleaveComma(values->getType(), llvm::dbgs()); 1857 llvm::dbgs() << "\n"; 1858 }); 1859 typeRangeMemory[rangeIndex] = values->getType(); 1860 memory[memIndex] = &typeRangeMemory[rangeIndex]; 1861 } 1862 1863 void ByteCodeExecutor::executeIsNotNull() { 1864 LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); 1865 const void *value = read<const void *>(); 1866 1867 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1868 selectJump(value != nullptr); 1869 } 1870 1871 void ByteCodeExecutor::executeRecordMatch( 1872 PatternRewriter &rewriter, 1873 SmallVectorImpl<PDLByteCode::MatchResult> &matches) { 1874 LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n"); 1875 unsigned patternIndex = read(); 1876 
PatternBenefit benefit = currentPatternBenefits[patternIndex]; 1877 const ByteCodeField *dest = &code[read<ByteCodeAddr>()]; 1878 1879 // If the benefit of the pattern is impossible, skip the processing of the 1880 // rest of the pattern. 1881 if (benefit.isImpossibleToMatch()) { 1882 LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n"); 1883 curCodeIt = dest; 1884 return; 1885 } 1886 1887 // Create a fused location containing the locations of each of the 1888 // operations used in the match. This will be used as the location for 1889 // created operations during the rewrite that don't already have an 1890 // explicit location set. 1891 unsigned numMatchLocs = read(); 1892 SmallVector<Location, 4> matchLocs; 1893 matchLocs.reserve(numMatchLocs); 1894 for (unsigned i = 0; i != numMatchLocs; ++i) 1895 matchLocs.push_back(read<Operation *>()->getLoc()); 1896 Location matchLoc = rewriter.getFusedLoc(matchLocs); 1897 1898 LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n" 1899 << " * Location: " << matchLoc << "\n"); 1900 matches.emplace_back(matchLoc, patterns[patternIndex], benefit); 1901 PDLByteCode::MatchResult &match = matches.back(); 1902 1903 // Record all of the inputs to the match. If any of the inputs are ranges, we 1904 // will also need to remap the range pointer to memory stored in the match 1905 // state. 
1906 unsigned numInputs = read(); 1907 match.values.reserve(numInputs); 1908 match.typeRangeValues.reserve(numInputs); 1909 match.valueRangeValues.reserve(numInputs); 1910 for (unsigned i = 0; i < numInputs; ++i) { 1911 switch (read<PDLValue::Kind>()) { 1912 case PDLValue::Kind::TypeRange: 1913 match.typeRangeValues.push_back(*read<TypeRange *>()); 1914 match.values.push_back(&match.typeRangeValues.back()); 1915 break; 1916 case PDLValue::Kind::ValueRange: 1917 match.valueRangeValues.push_back(*read<ValueRange *>()); 1918 match.values.push_back(&match.valueRangeValues.back()); 1919 break; 1920 default: 1921 match.values.push_back(read<const void *>()); 1922 break; 1923 } 1924 } 1925 curCodeIt = dest; 1926 } 1927 1928 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) { 1929 LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); 1930 Operation *op = read<Operation *>(); 1931 SmallVector<Value, 16> args; 1932 readValueList(args); 1933 1934 LLVM_DEBUG({ 1935 llvm::dbgs() << " * Operation: " << *op << "\n" 1936 << " * Values: "; 1937 llvm::interleaveComma(args, llvm::dbgs()); 1938 llvm::dbgs() << "\n"; 1939 }); 1940 rewriter.replaceOp(op, args); 1941 } 1942 1943 void ByteCodeExecutor::executeSwitchAttribute() { 1944 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n"); 1945 Attribute value = read<Attribute>(); 1946 ArrayAttr cases = read<ArrayAttr>(); 1947 handleSwitch(value, cases); 1948 } 1949 1950 void ByteCodeExecutor::executeSwitchOperandCount() { 1951 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n"); 1952 Operation *op = read<Operation *>(); 1953 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1954 1955 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1956 handleSwitch(op->getNumOperands(), cases); 1957 } 1958 1959 void ByteCodeExecutor::executeSwitchOperationName() { 1960 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n"); 1961 OperationName value = read<Operation *>()->getName(); 
1962 size_t caseCount = read(); 1963 1964 // The operation names are stored in-line, so to print them out for 1965 // debugging purposes we need to read the array before executing the 1966 // switch so that we can display all of the possible values. 1967 LLVM_DEBUG({ 1968 const ByteCodeField *prevCodeIt = curCodeIt; 1969 llvm::dbgs() << " * Value: " << value << "\n" 1970 << " * Cases: "; 1971 llvm::interleaveComma( 1972 llvm::map_range(llvm::seq<size_t>(0, caseCount), 1973 [&](size_t) { return read<OperationName>(); }), 1974 llvm::dbgs()); 1975 llvm::dbgs() << "\n"; 1976 curCodeIt = prevCodeIt; 1977 }); 1978 1979 // Try to find the switch value within any of the cases. 1980 for (size_t i = 0; i != caseCount; ++i) { 1981 if (read<OperationName>() == value) { 1982 curCodeIt += (caseCount - i - 1); 1983 return selectJump(i + 1); 1984 } 1985 } 1986 selectJump(size_t(0)); 1987 } 1988 1989 void ByteCodeExecutor::executeSwitchResultCount() { 1990 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n"); 1991 Operation *op = read<Operation *>(); 1992 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1993 1994 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1995 handleSwitch(op->getNumResults(), cases); 1996 } 1997 1998 void ByteCodeExecutor::executeSwitchType() { 1999 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n"); 2000 Type value = read<Type>(); 2001 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>(); 2002 handleSwitch(value, cases); 2003 } 2004 2005 void ByteCodeExecutor::executeSwitchTypes() { 2006 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n"); 2007 TypeRange *value = read<TypeRange *>(); 2008 auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>(); 2009 if (!value) { 2010 LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n"); 2011 return selectJump(size_t(0)); 2012 } 2013 handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) { 2014 return value == caseValue.getAsValueRange<TypeAttr>(); 2015 }); 
2016 } 2017 2018 void ByteCodeExecutor::execute( 2019 PatternRewriter &rewriter, 2020 SmallVectorImpl<PDLByteCode::MatchResult> *matches, 2021 Optional<Location> mainRewriteLoc) { 2022 while (true) { 2023 // Print the location of the operation being executed. 2024 LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n"); 2025 2026 OpCode opCode = static_cast<OpCode>(read()); 2027 switch (opCode) { 2028 case ApplyConstraint: 2029 executeApplyConstraint(rewriter); 2030 break; 2031 case ApplyRewrite: 2032 executeApplyRewrite(rewriter); 2033 break; 2034 case AreEqual: 2035 executeAreEqual(); 2036 break; 2037 case AreRangesEqual: 2038 executeAreRangesEqual(); 2039 break; 2040 case Branch: 2041 executeBranch(); 2042 break; 2043 case CheckOperandCount: 2044 executeCheckOperandCount(); 2045 break; 2046 case CheckOperationName: 2047 executeCheckOperationName(); 2048 break; 2049 case CheckResultCount: 2050 executeCheckResultCount(); 2051 break; 2052 case CheckTypes: 2053 executeCheckTypes(); 2054 break; 2055 case Continue: 2056 executeContinue(); 2057 break; 2058 case CreateOperation: 2059 executeCreateOperation(rewriter, *mainRewriteLoc); 2060 break; 2061 case CreateTypes: 2062 executeCreateTypes(); 2063 break; 2064 case EraseOp: 2065 executeEraseOp(rewriter); 2066 break; 2067 case ExtractOp: 2068 executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>(); 2069 break; 2070 case ExtractType: 2071 executeExtract<Type, TypeRange, PDLValue::Kind::Type>(); 2072 break; 2073 case ExtractValue: 2074 executeExtract<Value, ValueRange, PDLValue::Kind::Value>(); 2075 break; 2076 case Finalize: 2077 executeFinalize(); 2078 LLVM_DEBUG(llvm::dbgs() << "\n"); 2079 return; 2080 case ForEach: 2081 executeForEach(); 2082 break; 2083 case GetAttribute: 2084 executeGetAttribute(); 2085 break; 2086 case GetAttributeType: 2087 executeGetAttributeType(); 2088 break; 2089 case GetDefiningOp: 2090 executeGetDefiningOp(); 2091 break; 2092 case GetOperand0: 2093 case GetOperand1: 2094 
case GetOperand2: 2095 case GetOperand3: { 2096 unsigned index = opCode - GetOperand0; 2097 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n"); 2098 executeGetOperand(index); 2099 break; 2100 } 2101 case GetOperandN: 2102 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n"); 2103 executeGetOperand(read<uint32_t>()); 2104 break; 2105 case GetOperands: 2106 executeGetOperands(); 2107 break; 2108 case GetResult0: 2109 case GetResult1: 2110 case GetResult2: 2111 case GetResult3: { 2112 unsigned index = opCode - GetResult0; 2113 LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n"); 2114 executeGetResult(index); 2115 break; 2116 } 2117 case GetResultN: 2118 LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n"); 2119 executeGetResult(read<uint32_t>()); 2120 break; 2121 case GetResults: 2122 executeGetResults(); 2123 break; 2124 case GetUsers: 2125 executeGetUsers(); 2126 break; 2127 case GetValueType: 2128 executeGetValueType(); 2129 break; 2130 case GetValueRangeTypes: 2131 executeGetValueRangeTypes(); 2132 break; 2133 case IsNotNull: 2134 executeIsNotNull(); 2135 break; 2136 case RecordMatch: 2137 assert(matches && 2138 "expected matches to be provided when executing the matcher"); 2139 executeRecordMatch(rewriter, *matches); 2140 break; 2141 case ReplaceOp: 2142 executeReplaceOp(rewriter); 2143 break; 2144 case SwitchAttribute: 2145 executeSwitchAttribute(); 2146 break; 2147 case SwitchOperandCount: 2148 executeSwitchOperandCount(); 2149 break; 2150 case SwitchOperationName: 2151 executeSwitchOperationName(); 2152 break; 2153 case SwitchResultCount: 2154 executeSwitchResultCount(); 2155 break; 2156 case SwitchType: 2157 executeSwitchType(); 2158 break; 2159 case SwitchTypes: 2160 executeSwitchTypes(); 2161 break; 2162 } 2163 LLVM_DEBUG(llvm::dbgs() << "\n"); 2164 } 2165 } 2166 2167 /// Run the pattern matcher on the given root operation, collecting the matched 2168 /// patterns in `matches`. 
2169 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, 2170 SmallVectorImpl<MatchResult> &matches, 2171 PDLByteCodeMutableState &state) const { 2172 // The first memory slot is always the root operation. 2173 state.memory[0] = op; 2174 2175 // The matcher function always starts at code address 0. 2176 ByteCodeExecutor executor( 2177 matcherByteCode.data(), state.memory, state.opRangeMemory, 2178 state.typeRangeMemory, state.allocatedTypeRangeMemory, 2179 state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex, 2180 uniquedData, matcherByteCode, state.currentPatternBenefits, patterns, 2181 constraintFunctions, rewriteFunctions); 2182 executor.execute(rewriter, &matches); 2183 2184 // Order the found matches by benefit. 2185 std::stable_sort(matches.begin(), matches.end(), 2186 [](const MatchResult &lhs, const MatchResult &rhs) { 2187 return lhs.benefit > rhs.benefit; 2188 }); 2189 } 2190 2191 /// Run the rewriter of the given pattern on the root operation `op`. 2192 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, 2193 PDLByteCodeMutableState &state) const { 2194 // The arguments of the rewrite function are stored at the start of the 2195 // memory buffer. 2196 llvm::copy(match.values, state.memory.begin()); 2197 2198 ByteCodeExecutor executor( 2199 &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, 2200 state.opRangeMemory, state.typeRangeMemory, 2201 state.allocatedTypeRangeMemory, state.valueRangeMemory, 2202 state.allocatedValueRangeMemory, state.loopIndex, uniquedData, 2203 rewriterByteCode, state.currentPatternBenefits, patterns, 2204 constraintFunctions, rewriteFunctions); 2205 executor.execute(rewriter, /*matches=*/nullptr, match.location); 2206 } 2207