1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements MLIR to byte-code generation and the interpreter. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "ByteCode.h" 14 #include "mlir/Analysis/Liveness.h" 15 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/RegionGraphTraits.h" 19 #include "llvm/ADT/IntervalMap.h" 20 #include "llvm/ADT/PostOrderIterator.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 #include "llvm/Support/Debug.h" 23 #include "llvm/Support/Format.h" 24 #include "llvm/Support/FormatVariadic.h" 25 #include <numeric> 26 27 #define DEBUG_TYPE "pdl-bytecode" 28 29 using namespace mlir; 30 using namespace mlir::detail; 31 32 //===----------------------------------------------------------------------===// 33 // PDLByteCodePattern 34 //===----------------------------------------------------------------------===// 35 36 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp, 37 ByteCodeAddr rewriterAddr) { 38 SmallVector<StringRef, 8> generatedOps; 39 if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr()) 40 generatedOps = 41 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>()); 42 43 PatternBenefit benefit = matchOp.getBenefit(); 44 MLIRContext *ctx = matchOp.getContext(); 45 46 // Check to see if this is pattern matches a specific operation type. 
47 if (Optional<StringRef> rootKind = matchOp.getRootKind()) 48 return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx, 49 generatedOps); 50 return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx, 51 generatedOps); 52 } 53 54 //===----------------------------------------------------------------------===// 55 // PDLByteCodeMutableState 56 //===----------------------------------------------------------------------===// 57 58 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds 59 /// to the position of the pattern within the range returned by 60 /// `PDLByteCode::getPatterns`. 61 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex, 62 PatternBenefit benefit) { 63 currentPatternBenefits[patternIndex] = benefit; 64 } 65 66 /// Cleanup any allocated state after a full match/rewrite has been completed. 67 /// This method should be called irregardless of whether the match+rewrite was a 68 /// success or not. 69 void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() { 70 allocatedTypeRangeMemory.clear(); 71 allocatedValueRangeMemory.clear(); 72 } 73 74 //===----------------------------------------------------------------------===// 75 // Bytecode OpCodes 76 //===----------------------------------------------------------------------===// 77 78 namespace { 79 enum OpCode : ByteCodeField { 80 /// Apply an externally registered constraint. 81 ApplyConstraint, 82 /// Apply an externally registered rewrite. 83 ApplyRewrite, 84 /// Check if two generic values are equal. 85 AreEqual, 86 /// Check if two ranges are equal. 87 AreRangesEqual, 88 /// Unconditional branch. 89 Branch, 90 /// Compare the operand count of an operation with a constant. 91 CheckOperandCount, 92 /// Compare the name of an operation with a constant. 93 CheckOperationName, 94 /// Compare the result count of an operation with a constant. 95 CheckResultCount, 96 /// Compare a range of types to a constant range of types. 
97 CheckTypes, 98 /// Continue to the next iteration of a loop. 99 Continue, 100 /// Create an operation. 101 CreateOperation, 102 /// Create a range of types. 103 CreateTypes, 104 /// Erase an operation. 105 EraseOp, 106 /// Extract the op from a range at the specified index. 107 ExtractOp, 108 /// Extract the type from a range at the specified index. 109 ExtractType, 110 /// Extract the value from a range at the specified index. 111 ExtractValue, 112 /// Terminate a matcher or rewrite sequence. 113 Finalize, 114 /// Iterate over a range of values. 115 ForEach, 116 /// Get a specific attribute of an operation. 117 GetAttribute, 118 /// Get the type of an attribute. 119 GetAttributeType, 120 /// Get the defining operation of a value. 121 GetDefiningOp, 122 /// Get a specific operand of an operation. 123 GetOperand0, 124 GetOperand1, 125 GetOperand2, 126 GetOperand3, 127 GetOperandN, 128 /// Get a specific operand group of an operation. 129 GetOperands, 130 /// Get a specific result of an operation. 131 GetResult0, 132 GetResult1, 133 GetResult2, 134 GetResult3, 135 GetResultN, 136 /// Get a specific result group of an operation. 137 GetResults, 138 /// Get the users of a value or a range of values. 139 GetUsers, 140 /// Get the type of a value. 141 GetValueType, 142 /// Get the types of a value range. 143 GetValueRangeTypes, 144 /// Check if a generic value is not null. 145 IsNotNull, 146 /// Record a successful pattern match. 147 RecordMatch, 148 /// Replace an operation. 149 ReplaceOp, 150 /// Compare an attribute with a set of constants. 151 SwitchAttribute, 152 /// Compare the operand count of an operation with a set of constants. 153 SwitchOperandCount, 154 /// Compare the name of an operation with a set of constants. 155 SwitchOperationName, 156 /// Compare the result count of an operation with a set of constants. 157 SwitchResultCount, 158 /// Compare a type with a set of constants. 159 SwitchType, 160 /// Compare a range of types with a set of constants. 
161 SwitchTypes, 162 }; 163 } // namespace 164 165 //===----------------------------------------------------------------------===// 166 // ByteCode Generation 167 //===----------------------------------------------------------------------===// 168 169 //===----------------------------------------------------------------------===// 170 // Generator 171 172 namespace { 173 struct ByteCodeLiveRange; 174 struct ByteCodeWriter; 175 176 /// Check if the given class `T` can be converted to an opaque pointer. 177 template <typename T, typename... Args> 178 using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer()); 179 180 /// This class represents the main generator for the pattern bytecode. 181 class Generator { 182 public: 183 Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData, 184 SmallVectorImpl<ByteCodeField> &matcherByteCode, 185 SmallVectorImpl<ByteCodeField> &rewriterByteCode, 186 SmallVectorImpl<PDLByteCodePattern> &patterns, 187 ByteCodeField &maxValueMemoryIndex, 188 ByteCodeField &maxOpRangeMemoryIndex, 189 ByteCodeField &maxTypeRangeMemoryIndex, 190 ByteCodeField &maxValueRangeMemoryIndex, 191 ByteCodeField &maxLoopLevel, 192 llvm::StringMap<PDLConstraintFunction> &constraintFns, 193 llvm::StringMap<PDLRewriteFunction> &rewriteFns) 194 : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), 195 rewriterByteCode(rewriterByteCode), patterns(patterns), 196 maxValueMemoryIndex(maxValueMemoryIndex), 197 maxOpRangeMemoryIndex(maxOpRangeMemoryIndex), 198 maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex), 199 maxValueRangeMemoryIndex(maxValueRangeMemoryIndex), 200 maxLoopLevel(maxLoopLevel) { 201 for (const auto &it : llvm::enumerate(constraintFns)) 202 constraintToMemIndex.try_emplace(it.value().first(), it.index()); 203 for (const auto &it : llvm::enumerate(rewriteFns)) 204 externalRewriterToMemIndex.try_emplace(it.value().first(), it.index()); 205 } 206 207 /// Generate the bytecode for the given PDL interpreter 
module. 208 void generate(ModuleOp module); 209 210 /// Return the memory index to use for the given value. 211 ByteCodeField &getMemIndex(Value value) { 212 assert(valueToMemIndex.count(value) && 213 "expected memory index to be assigned"); 214 return valueToMemIndex[value]; 215 } 216 217 /// Return the range memory index used to store the given range value. 218 ByteCodeField &getRangeStorageIndex(Value value) { 219 assert(valueToRangeIndex.count(value) && 220 "expected range index to be assigned"); 221 return valueToRangeIndex[value]; 222 } 223 224 /// Return an index to use when referring to the given data that is uniqued in 225 /// the MLIR context. 226 template <typename T> 227 std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &> 228 getMemIndex(T val) { 229 const void *opaqueVal = val.getAsOpaquePointer(); 230 231 // Get or insert a reference to this value. 232 auto it = uniquedDataToMemIndex.try_emplace( 233 opaqueVal, maxValueMemoryIndex + uniquedData.size()); 234 if (it.second) 235 uniquedData.push_back(opaqueVal); 236 return it.first->second; 237 } 238 239 private: 240 /// Allocate memory indices for the results of operations within the matcher 241 /// and rewriters. 242 void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc, 243 ModuleOp rewriterModule); 244 245 /// Generate the bytecode for the given operation. 
246 void generate(Region *region, ByteCodeWriter &writer); 247 void generate(Operation *op, ByteCodeWriter &writer); 248 void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer); 249 void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer); 250 void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer); 251 void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer); 252 void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer); 253 void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer); 254 void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer); 255 void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); 256 void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); 257 void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer); 258 void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer); 259 void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); 260 void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); 261 void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); 262 void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer); 263 void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); 264 void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer); 265 void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); 266 void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer); 267 void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); 268 void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); 269 void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); 270 void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer); 271 void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer); 272 void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); 273 void generate(pdl_interp::GetResultsOp op, 
ByteCodeWriter &writer); 274 void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer); 275 void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); 276 void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer); 277 void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); 278 void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer); 279 void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); 280 void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer); 281 void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer); 282 void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer); 283 void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer); 284 void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer); 285 void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer); 286 287 /// Mapping from value to its corresponding memory index. 288 DenseMap<Value, ByteCodeField> valueToMemIndex; 289 290 /// Mapping from a range value to its corresponding range storage index. 291 DenseMap<Value, ByteCodeField> valueToRangeIndex; 292 293 /// Mapping from the name of an externally registered rewrite to its index in 294 /// the bytecode registry. 295 llvm::StringMap<ByteCodeField> externalRewriterToMemIndex; 296 297 /// Mapping from the name of an externally registered constraint to its index 298 /// in the bytecode registry. 299 llvm::StringMap<ByteCodeField> constraintToMemIndex; 300 301 /// Mapping from rewriter function name to the bytecode address of the 302 /// rewriter function in byte. 303 llvm::StringMap<ByteCodeAddr> rewriterToAddr; 304 305 /// Mapping from a uniqued storage object to its memory index within 306 /// `uniquedData`. 307 DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex; 308 309 /// The current level of the foreach loop. 310 ByteCodeField curLoopLevel = 0; 311 312 /// The current MLIR context. 
313 MLIRContext *ctx; 314 315 /// Mapping from block to its address. 316 DenseMap<Block *, ByteCodeAddr> blockToAddr; 317 318 /// Data of the ByteCode class to be populated. 319 std::vector<const void *> &uniquedData; 320 SmallVectorImpl<ByteCodeField> &matcherByteCode; 321 SmallVectorImpl<ByteCodeField> &rewriterByteCode; 322 SmallVectorImpl<PDLByteCodePattern> &patterns; 323 ByteCodeField &maxValueMemoryIndex; 324 ByteCodeField &maxOpRangeMemoryIndex; 325 ByteCodeField &maxTypeRangeMemoryIndex; 326 ByteCodeField &maxValueRangeMemoryIndex; 327 ByteCodeField &maxLoopLevel; 328 }; 329 330 /// This class provides utilities for writing a bytecode stream. 331 struct ByteCodeWriter { 332 ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator) 333 : bytecode(bytecode), generator(generator) {} 334 335 /// Append a field to the bytecode. 336 void append(ByteCodeField field) { bytecode.push_back(field); } 337 void append(OpCode opCode) { bytecode.push_back(opCode); } 338 339 /// Append an address to the bytecode. 340 void append(ByteCodeAddr field) { 341 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 342 "unexpected ByteCode address size"); 343 344 ByteCodeField fieldParts[2]; 345 std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr)); 346 bytecode.append({fieldParts[0], fieldParts[1]}); 347 } 348 349 /// Append a single successor to the bytecode, the exact address will need to 350 /// be resolved later. 351 void append(Block *successor) { 352 // Add back a reference to the successor so that the address can be resolved 353 // later. 354 unresolvedSuccessorRefs[successor].push_back(bytecode.size()); 355 append(ByteCodeAddr(0)); 356 } 357 358 /// Append a successor range to the bytecode, the exact address will need to 359 /// be resolved later. 360 void append(SuccessorRange successors) { 361 for (Block *successor : successors) 362 append(successor); 363 } 364 365 /// Append a range of values that will be read as generic PDLValues. 
366 void appendPDLValueList(OperandRange values) { 367 bytecode.push_back(values.size()); 368 for (Value value : values) 369 appendPDLValue(value); 370 } 371 372 /// Append a value as a PDLValue. 373 void appendPDLValue(Value value) { 374 appendPDLValueKind(value); 375 append(value); 376 } 377 378 /// Append the PDLValue::Kind of the given value. 379 void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); } 380 381 /// Append the PDLValue::Kind of the given type. 382 void appendPDLValueKind(Type type) { 383 PDLValue::Kind kind = 384 TypeSwitch<Type, PDLValue::Kind>(type) 385 .Case<pdl::AttributeType>( 386 [](Type) { return PDLValue::Kind::Attribute; }) 387 .Case<pdl::OperationType>( 388 [](Type) { return PDLValue::Kind::Operation; }) 389 .Case<pdl::RangeType>([](pdl::RangeType rangeTy) { 390 if (rangeTy.getElementType().isa<pdl::TypeType>()) 391 return PDLValue::Kind::TypeRange; 392 return PDLValue::Kind::ValueRange; 393 }) 394 .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; }) 395 .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; }); 396 bytecode.push_back(static_cast<ByteCodeField>(kind)); 397 } 398 399 /// Append a value that will be stored in a memory slot and not inline within 400 /// the bytecode. 401 template <typename T> 402 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value || 403 std::is_pointer<T>::value> 404 append(T value) { 405 bytecode.push_back(generator.getMemIndex(value)); 406 } 407 408 /// Append a range of values. 409 template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>> 410 std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value> 411 append(T range) { 412 bytecode.push_back(llvm::size(range)); 413 for (auto it : range) 414 append(it); 415 } 416 417 /// Append a variadic number of fields to the bytecode. 418 template <typename FieldTy, typename Field2Ty, typename... FieldTys> 419 void append(FieldTy field, Field2Ty field2, FieldTys... 
fields) { 420 append(field); 421 append(field2, fields...); 422 } 423 424 /// Appends a value as a pointer, stored inline within the bytecode. 425 template <typename T> 426 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value> 427 appendInline(T value) { 428 constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField); 429 const void *pointer = value.getAsOpaquePointer(); 430 ByteCodeField fieldParts[numParts]; 431 std::memcpy(fieldParts, &pointer, sizeof(const void *)); 432 bytecode.append(fieldParts, fieldParts + numParts); 433 } 434 435 /// Successor references in the bytecode that have yet to be resolved. 436 DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs; 437 438 /// The underlying bytecode buffer. 439 SmallVectorImpl<ByteCodeField> &bytecode; 440 441 /// The main generator producing PDL. 442 Generator &generator; 443 }; 444 445 /// This class represents a live range of PDL Interpreter values, containing 446 /// information about when values are live within a match/rewrite. 447 struct ByteCodeLiveRange { 448 using Set = llvm::IntervalMap<uint64_t, char, 16>; 449 using Allocator = Set::Allocator; 450 451 ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {} 452 453 /// Union this live range with the one provided. 454 void unionWith(const ByteCodeLiveRange &rhs) { 455 for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e; 456 ++it) 457 liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0); 458 } 459 460 /// Returns true if this range overlaps with the one provided. 461 bool overlaps(const ByteCodeLiveRange &rhs) const { 462 return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness) 463 .valid(); 464 } 465 466 /// A map representing the ranges of the match/rewrite that a value is live in 467 /// the interpreter. 468 /// 469 /// We use std::unique_ptr here, because IntervalMap does not provide a 470 /// correct copy or move constructor. 
We can eliminate the pointer once 471 /// https://reviews.llvm.org/D113240 lands. 472 std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness; 473 474 /// The operation range storage index for this range. 475 Optional<unsigned> opRangeIndex; 476 477 /// The type range storage index for this range. 478 Optional<unsigned> typeRangeIndex; 479 480 /// The value range storage index for this range. 481 Optional<unsigned> valueRangeIndex; 482 }; 483 } // namespace 484 485 void Generator::generate(ModuleOp module) { 486 auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>( 487 pdl_interp::PDLInterpDialect::getMatcherFunctionName()); 488 ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>( 489 pdl_interp::PDLInterpDialect::getRewriterModuleName()); 490 assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module"); 491 492 // Allocate memory indices for the results of operations within the matcher 493 // and rewriters. 494 allocateMemoryIndices(matcherFunc, rewriterModule); 495 496 // Generate code for the rewriter functions. 497 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this); 498 for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) { 499 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size()); 500 for (Operation &op : rewriterFunc.getOps()) 501 generate(&op, rewriterByteCodeWriter); 502 } 503 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() && 504 "unexpected branches in rewriter function"); 505 506 // Generate code for the matcher function. 507 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this); 508 generate(&matcherFunc.getBody(), matcherByteCodeWriter); 509 510 // Resolve successor references in the matcher. 
511 for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) { 512 ByteCodeAddr addr = blockToAddr[it.first]; 513 for (unsigned offsetToFix : it.second) 514 std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr)); 515 } 516 } 517 518 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc, 519 ModuleOp rewriterModule) { 520 // Rewriters use simplistic allocation scheme that simply assigns an index to 521 // each result. 522 for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) { 523 ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0; 524 auto processRewriterValue = [&](Value val) { 525 valueToMemIndex.try_emplace(val, index++); 526 if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) { 527 Type elementTy = rangeType.getElementType(); 528 if (elementTy.isa<pdl::TypeType>()) 529 valueToRangeIndex.try_emplace(val, typeRangeIndex++); 530 else if (elementTy.isa<pdl::ValueType>()) 531 valueToRangeIndex.try_emplace(val, valueRangeIndex++); 532 } 533 }; 534 535 for (BlockArgument arg : rewriterFunc.getArguments()) 536 processRewriterValue(arg); 537 rewriterFunc.getBody().walk([&](Operation *op) { 538 for (Value result : op->getResults()) 539 processRewriterValue(result); 540 }); 541 if (index > maxValueMemoryIndex) 542 maxValueMemoryIndex = index; 543 if (typeRangeIndex > maxTypeRangeMemoryIndex) 544 maxTypeRangeMemoryIndex = typeRangeIndex; 545 if (valueRangeIndex > maxValueRangeMemoryIndex) 546 maxValueRangeMemoryIndex = valueRangeIndex; 547 } 548 549 // The matcher function uses a more sophisticated numbering that tries to 550 // minimize the number of memory indices assigned. This is done by determining 551 // a live range of the values within the matcher, then the allocation is just 552 // finding the minimal number of overlapping live ranges. 
This is essentially 553 // a simplified form of register allocation where we don't necessarily have a 554 // limited number of registers, but we still want to minimize the number used. 555 DenseMap<Operation *, unsigned> opToFirstIndex; 556 DenseMap<Operation *, unsigned> opToLastIndex; 557 558 // A custom walk that marks the first and the last index of each operation. 559 // The entry marks the beginning of the liveness range for this operation, 560 // followed by nested operations, followed by the end of the liveness range. 561 unsigned index = 0; 562 llvm::unique_function<void(Operation *)> walk = [&](Operation *op) { 563 opToFirstIndex.try_emplace(op, index++); 564 for (Region ®ion : op->getRegions()) 565 for (Block &block : region.getBlocks()) 566 for (Operation &nested : block) 567 walk(&nested); 568 opToLastIndex.try_emplace(op, index++); 569 }; 570 walk(matcherFunc); 571 572 // Liveness info for each of the defs within the matcher. 573 ByteCodeLiveRange::Allocator allocator; 574 DenseMap<Value, ByteCodeLiveRange> valueDefRanges; 575 576 // Assign the root operation being matched to slot 0. 577 BlockArgument rootOpArg = matcherFunc.getArgument(0); 578 valueToMemIndex[rootOpArg] = 0; 579 580 // Walk each of the blocks, computing the def interval that the value is used. 581 Liveness matcherLiveness(matcherFunc); 582 matcherFunc->walk([&](Block *block) { 583 const LivenessBlockInfo *info = matcherLiveness.getLiveness(block); 584 assert(info && "expected liveness info for block"); 585 auto processValue = [&](Value value, Operation *firstUseOrDef) { 586 // We don't need to process the root op argument, this value is always 587 // assigned to the first memory slot. 588 if (value == rootOpArg) 589 return; 590 591 // Set indices for the range of this block that the value is used. 
592 auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; 593 defRangeIt->second.liveness->insert( 594 opToFirstIndex[firstUseOrDef], 595 opToLastIndex[info->getEndOperation(value, firstUseOrDef)], 596 /*dummyValue*/ 0); 597 598 // Check to see if this value is a range type. 599 if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) { 600 Type eleType = rangeTy.getElementType(); 601 if (eleType.isa<pdl::OperationType>()) 602 defRangeIt->second.opRangeIndex = 0; 603 else if (eleType.isa<pdl::TypeType>()) 604 defRangeIt->second.typeRangeIndex = 0; 605 else if (eleType.isa<pdl::ValueType>()) 606 defRangeIt->second.valueRangeIndex = 0; 607 } 608 }; 609 610 // Process the live-ins of this block. 611 for (Value liveIn : info->in()) { 612 // Only process the value if it has been defined in the current region. 613 // Other values that span across pdl_interp.foreach will be added higher 614 // up. This ensures that the we keep them alive for the entire duration 615 // of the loop. 616 if (liveIn.getParentRegion() == block->getParent()) 617 processValue(liveIn, &block->front()); 618 } 619 620 // Process the block arguments for the entry block (those are not live-in). 621 if (block->isEntryBlock()) { 622 for (Value argument : block->getArguments()) 623 processValue(argument, &block->front()); 624 } 625 626 // Process any new defs within this block. 627 for (Operation &op : *block) 628 for (Value result : op.getResults()) 629 processValue(result, &op); 630 }); 631 632 // Greedily allocate memory slots using the computed def live ranges. 633 std::vector<ByteCodeLiveRange> allocatedIndices; 634 635 // The number of memory indices currently allocated (and its next value). 636 // Recall that the root gets allocated memory index 0. 637 ByteCodeField numIndices = 1; 638 639 // The number of memory ranges of various types (and their next values). 
640 ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0; 641 642 for (auto &defIt : valueDefRanges) { 643 ByteCodeField &memIndex = valueToMemIndex[defIt.first]; 644 ByteCodeLiveRange &defRange = defIt.second; 645 646 // Try to allocate to an existing index. 647 for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) { 648 ByteCodeLiveRange &existingRange = existingIndexIt.value(); 649 if (!defRange.overlaps(existingRange)) { 650 existingRange.unionWith(defRange); 651 memIndex = existingIndexIt.index() + 1; 652 653 if (defRange.opRangeIndex) { 654 if (!existingRange.opRangeIndex) 655 existingRange.opRangeIndex = numOpRanges++; 656 valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex; 657 } else if (defRange.typeRangeIndex) { 658 if (!existingRange.typeRangeIndex) 659 existingRange.typeRangeIndex = numTypeRanges++; 660 valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex; 661 } else if (defRange.valueRangeIndex) { 662 if (!existingRange.valueRangeIndex) 663 existingRange.valueRangeIndex = numValueRanges++; 664 valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex; 665 } 666 break; 667 } 668 } 669 670 // If no existing index could be used, add a new one. 671 if (memIndex == 0) { 672 allocatedIndices.emplace_back(allocator); 673 ByteCodeLiveRange &newRange = allocatedIndices.back(); 674 newRange.unionWith(defRange); 675 676 // Allocate an index for op/type/value ranges. 
677 if (defRange.opRangeIndex) { 678 newRange.opRangeIndex = numOpRanges; 679 valueToRangeIndex[defIt.first] = numOpRanges++; 680 } else if (defRange.typeRangeIndex) { 681 newRange.typeRangeIndex = numTypeRanges; 682 valueToRangeIndex[defIt.first] = numTypeRanges++; 683 } else if (defRange.valueRangeIndex) { 684 newRange.valueRangeIndex = numValueRanges; 685 valueToRangeIndex[defIt.first] = numValueRanges++; 686 } 687 688 memIndex = allocatedIndices.size(); 689 ++numIndices; 690 } 691 } 692 693 // Print the index usage and ensure that we did not run out of index space. 694 LLVM_DEBUG({ 695 llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices " 696 << "(down from initial " << valueDefRanges.size() << ").\n"; 697 }); 698 assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() && 699 "Ran out of memory for allocated indices"); 700 701 // Update the max number of indices. 702 if (numIndices > maxValueMemoryIndex) 703 maxValueMemoryIndex = numIndices; 704 if (numOpRanges > maxOpRangeMemoryIndex) 705 maxOpRangeMemoryIndex = numOpRanges; 706 if (numTypeRanges > maxTypeRangeMemoryIndex) 707 maxTypeRangeMemoryIndex = numTypeRanges; 708 if (numValueRanges > maxValueRangeMemoryIndex) 709 maxValueRangeMemoryIndex = numValueRanges; 710 } 711 712 void Generator::generate(Region *region, ByteCodeWriter &writer) { 713 llvm::ReversePostOrderTraversal<Region *> rpot(region); 714 for (Block *block : rpot) { 715 // Keep track of where this block begins within the matcher function. 716 blockToAddr.try_emplace(block, matcherByteCode.size()); 717 for (Operation &op : *block) 718 generate(&op, writer); 719 } 720 } 721 722 void Generator::generate(Operation *op, ByteCodeWriter &writer) { 723 LLVM_DEBUG({ 724 // The following list must contain all the operations that do not 725 // produce any bytecode. 
726 if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp, 727 pdl_interp::InferredTypesOp>(op)) 728 writer.appendInline(op->getLoc()); 729 }); 730 TypeSwitch<Operation *>(op) 731 .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp, 732 pdl_interp::AreEqualOp, pdl_interp::BranchOp, 733 pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp, 734 pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp, 735 pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp, 736 pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp, 737 pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp, 738 pdl_interp::CreateTypesOp, pdl_interp::EraseOp, 739 pdl_interp::ExtractOp, pdl_interp::FinalizeOp, 740 pdl_interp::ForEachOp, pdl_interp::GetAttributeOp, 741 pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp, 742 pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp, 743 pdl_interp::GetResultOp, pdl_interp::GetResultsOp, 744 pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp, 745 pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp, 746 pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp, 747 pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp, 748 pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp, 749 pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>( 750 [&](auto interpOp) { this->generate(interpOp, writer); }) 751 .Default([](Operation *) { 752 llvm_unreachable("unknown `pdl_interp` operation"); 753 }); 754 } 755 756 void Generator::generate(pdl_interp::ApplyConstraintOp op, 757 ByteCodeWriter &writer) { 758 assert(constraintToMemIndex.count(op.getName()) && 759 "expected index for constraint function"); 760 writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()], 761 op.getConstParamsAttr()); 762 writer.appendPDLValueList(op.getArgs()); 763 writer.append(op.getSuccessors()); 764 } 765 void Generator::generate(pdl_interp::ApplyRewriteOp op, 766 ByteCodeWriter &writer) { 767 
assert(externalRewriterToMemIndex.count(op.getName()) && 768 "expected index for rewrite function"); 769 writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()], 770 op.getConstParamsAttr()); 771 writer.appendPDLValueList(op.getArgs()); 772 773 ResultRange results = op.getResults(); 774 writer.append(ByteCodeField(results.size())); 775 for (Value result : results) { 776 // In debug mode we also record the expected kind of the result, so that we 777 // can provide extra verification of the native rewrite function. 778 #ifndef NDEBUG 779 writer.appendPDLValueKind(result); 780 #endif 781 782 // Range results also need to append the range storage index. 783 if (result.getType().isa<pdl::RangeType>()) 784 writer.append(getRangeStorageIndex(result)); 785 writer.append(result); 786 } 787 } 788 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) { 789 Value lhs = op.getLhs(); 790 if (lhs.getType().isa<pdl::RangeType>()) { 791 writer.append(OpCode::AreRangesEqual); 792 writer.appendPDLValueKind(lhs); 793 writer.append(op.getLhs(), op.getRhs(), op.getSuccessors()); 794 return; 795 } 796 797 writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors()); 798 } 799 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) { 800 writer.append(OpCode::Branch, SuccessorRange(op.getOperation())); 801 } 802 void Generator::generate(pdl_interp::CheckAttributeOp op, 803 ByteCodeWriter &writer) { 804 writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(), 805 op.getSuccessors()); 806 } 807 void Generator::generate(pdl_interp::CheckOperandCountOp op, 808 ByteCodeWriter &writer) { 809 writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(), 810 static_cast<ByteCodeField>(op.getCompareAtLeast()), 811 op.getSuccessors()); 812 } 813 void Generator::generate(pdl_interp::CheckOperationNameOp op, 814 ByteCodeWriter &writer) { 815 writer.append(OpCode::CheckOperationName, op.getInputOp(), 816 
OperationName(op.getName(), ctx), op.getSuccessors()); 817 } 818 void Generator::generate(pdl_interp::CheckResultCountOp op, 819 ByteCodeWriter &writer) { 820 writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(), 821 static_cast<ByteCodeField>(op.getCompareAtLeast()), 822 op.getSuccessors()); 823 } 824 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) { 825 writer.append(OpCode::AreEqual, op.getValue(), op.getType(), 826 op.getSuccessors()); 827 } 828 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) { 829 writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(), 830 op.getSuccessors()); 831 } 832 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) { 833 assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level"); 834 writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1)); 835 } 836 void Generator::generate(pdl_interp::CreateAttributeOp op, 837 ByteCodeWriter &writer) { 838 // Simply repoint the memory index of the result to the constant. 839 getMemIndex(op.getAttribute()) = getMemIndex(op.getValue()); 840 } 841 void Generator::generate(pdl_interp::CreateOperationOp op, 842 ByteCodeWriter &writer) { 843 writer.append(OpCode::CreateOperation, op.getResultOp(), 844 OperationName(op.getName(), ctx)); 845 writer.appendPDLValueList(op.getInputOperands()); 846 847 // Add the attributes. 848 OperandRange attributes = op.getInputAttributes(); 849 writer.append(static_cast<ByteCodeField>(attributes.size())); 850 for (auto it : llvm::zip(op.getInputAttributeNames(), attributes)) 851 writer.append(std::get<0>(it), std::get<1>(it)); 852 writer.appendPDLValueList(op.getInputResultTypes()); 853 } 854 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) { 855 // Simply repoint the memory index of the result to the constant. 
856 getMemIndex(op.getResult()) = getMemIndex(op.getValue()); 857 } 858 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) { 859 writer.append(OpCode::CreateTypes, op.getResult(), 860 getRangeStorageIndex(op.getResult()), op.getValue()); 861 } 862 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) { 863 writer.append(OpCode::EraseOp, op.getInputOp()); 864 } 865 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) { 866 OpCode opCode = 867 TypeSwitch<Type, OpCode>(op.getResult().getType()) 868 .Case([](pdl::OperationType) { return OpCode::ExtractOp; }) 869 .Case([](pdl::ValueType) { return OpCode::ExtractValue; }) 870 .Case([](pdl::TypeType) { return OpCode::ExtractType; }) 871 .Default([](Type) -> OpCode { 872 llvm_unreachable("unsupported element type"); 873 }); 874 writer.append(opCode, op.getRange(), op.getIndex(), op.getResult()); 875 } 876 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) { 877 writer.append(OpCode::Finalize); 878 } 879 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) { 880 BlockArgument arg = op.getLoopVariable(); 881 writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg); 882 writer.appendPDLValueKind(arg.getType()); 883 writer.append(curLoopLevel, op.getSuccessor()); 884 ++curLoopLevel; 885 if (curLoopLevel > maxLoopLevel) 886 maxLoopLevel = curLoopLevel; 887 generate(&op.getRegion(), writer); 888 --curLoopLevel; 889 } 890 void Generator::generate(pdl_interp::GetAttributeOp op, 891 ByteCodeWriter &writer) { 892 writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(), 893 op.getNameAttr()); 894 } 895 void Generator::generate(pdl_interp::GetAttributeTypeOp op, 896 ByteCodeWriter &writer) { 897 writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue()); 898 } 899 void Generator::generate(pdl_interp::GetDefiningOpOp op, 900 ByteCodeWriter &writer) { 901 
writer.append(OpCode::GetDefiningOp, op.getInputOp()); 902 writer.appendPDLValue(op.getValue()); 903 } 904 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) { 905 uint32_t index = op.getIndex(); 906 if (index < 4) 907 writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index)); 908 else 909 writer.append(OpCode::GetOperandN, index); 910 writer.append(op.getInputOp(), op.getValue()); 911 } 912 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) { 913 Value result = op.getValue(); 914 Optional<uint32_t> index = op.getIndex(); 915 writer.append(OpCode::GetOperands, 916 index.getValueOr(std::numeric_limits<uint32_t>::max()), 917 op.getInputOp()); 918 if (result.getType().isa<pdl::RangeType>()) 919 writer.append(getRangeStorageIndex(result)); 920 else 921 writer.append(std::numeric_limits<ByteCodeField>::max()); 922 writer.append(result); 923 } 924 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) { 925 uint32_t index = op.getIndex(); 926 if (index < 4) 927 writer.append(static_cast<OpCode>(OpCode::GetResult0 + index)); 928 else 929 writer.append(OpCode::GetResultN, index); 930 writer.append(op.getInputOp(), op.getValue()); 931 } 932 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) { 933 Value result = op.getValue(); 934 Optional<uint32_t> index = op.getIndex(); 935 writer.append(OpCode::GetResults, 936 index.getValueOr(std::numeric_limits<uint32_t>::max()), 937 op.getInputOp()); 938 if (result.getType().isa<pdl::RangeType>()) 939 writer.append(getRangeStorageIndex(result)); 940 else 941 writer.append(std::numeric_limits<ByteCodeField>::max()); 942 writer.append(result); 943 } 944 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) { 945 Value operations = op.getOperations(); 946 ByteCodeField rangeIndex = getRangeStorageIndex(operations); 947 writer.append(OpCode::GetUsers, operations, rangeIndex); 948 
writer.appendPDLValue(op.getValue()); 949 } 950 void Generator::generate(pdl_interp::GetValueTypeOp op, 951 ByteCodeWriter &writer) { 952 if (op.getType().isa<pdl::RangeType>()) { 953 Value result = op.getResult(); 954 writer.append(OpCode::GetValueRangeTypes, result, 955 getRangeStorageIndex(result), op.getValue()); 956 } else { 957 writer.append(OpCode::GetValueType, op.getResult(), op.getValue()); 958 } 959 } 960 961 void Generator::generate(pdl_interp::InferredTypesOp op, 962 ByteCodeWriter &writer) { 963 // InferType maps to a null type as a marker for inferring result types. 964 getMemIndex(op.getResult()) = getMemIndex(Type()); 965 } 966 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) { 967 writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors()); 968 } 969 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) { 970 ByteCodeField patternIndex = patterns.size(); 971 patterns.emplace_back(PDLByteCodePattern::create( 972 op, rewriterToAddr[op.getRewriter().getLeafReference().getValue()])); 973 writer.append(OpCode::RecordMatch, patternIndex, 974 SuccessorRange(op.getOperation()), op.getMatchedOps()); 975 writer.appendPDLValueList(op.getInputs()); 976 } 977 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) { 978 writer.append(OpCode::ReplaceOp, op.getInputOp()); 979 writer.appendPDLValueList(op.getReplValues()); 980 } 981 void Generator::generate(pdl_interp::SwitchAttributeOp op, 982 ByteCodeWriter &writer) { 983 writer.append(OpCode::SwitchAttribute, op.getAttribute(), 984 op.getCaseValuesAttr(), op.getSuccessors()); 985 } 986 void Generator::generate(pdl_interp::SwitchOperandCountOp op, 987 ByteCodeWriter &writer) { 988 writer.append(OpCode::SwitchOperandCount, op.getInputOp(), 989 op.getCaseValuesAttr(), op.getSuccessors()); 990 } 991 void Generator::generate(pdl_interp::SwitchOperationNameOp op, 992 ByteCodeWriter &writer) { 993 auto cases = 
llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) { 994 return OperationName(attr.cast<StringAttr>().getValue(), ctx); 995 }); 996 writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases, 997 op.getSuccessors()); 998 } 999 void Generator::generate(pdl_interp::SwitchResultCountOp op, 1000 ByteCodeWriter &writer) { 1001 writer.append(OpCode::SwitchResultCount, op.getInputOp(), 1002 op.getCaseValuesAttr(), op.getSuccessors()); 1003 } 1004 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) { 1005 writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(), 1006 op.getSuccessors()); 1007 } 1008 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) { 1009 writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(), 1010 op.getSuccessors()); 1011 } 1012 1013 //===----------------------------------------------------------------------===// 1014 // PDLByteCode 1015 //===----------------------------------------------------------------------===// 1016 1017 PDLByteCode::PDLByteCode(ModuleOp module, 1018 llvm::StringMap<PDLConstraintFunction> constraintFns, 1019 llvm::StringMap<PDLRewriteFunction> rewriteFns) { 1020 Generator generator(module.getContext(), uniquedData, matcherByteCode, 1021 rewriterByteCode, patterns, maxValueMemoryIndex, 1022 maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount, 1023 maxLoopLevel, constraintFns, rewriteFns); 1024 generator.generate(module); 1025 1026 // Initialize the external functions. 1027 for (auto &it : constraintFns) 1028 constraintFunctions.push_back(std::move(it.second)); 1029 for (auto &it : rewriteFns) 1030 rewriteFunctions.push_back(std::move(it.second)); 1031 } 1032 1033 /// Initialize the given state such that it can be used to execute the current 1034 /// bytecode. 
1035 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { 1036 state.memory.resize(maxValueMemoryIndex, nullptr); 1037 state.opRangeMemory.resize(maxOpRangeCount); 1038 state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange()); 1039 state.valueRangeMemory.resize(maxValueRangeCount, ValueRange()); 1040 state.loopIndex.resize(maxLoopLevel, 0); 1041 state.currentPatternBenefits.reserve(patterns.size()); 1042 for (const PDLByteCodePattern &pattern : patterns) 1043 state.currentPatternBenefits.push_back(pattern.getBenefit()); 1044 } 1045 1046 //===----------------------------------------------------------------------===// 1047 // ByteCode Execution 1048 1049 namespace { 1050 /// This class provides support for executing a bytecode stream. 1051 class ByteCodeExecutor { 1052 public: 1053 ByteCodeExecutor( 1054 const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory, 1055 MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory, 1056 MutableArrayRef<TypeRange> typeRangeMemory, 1057 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory, 1058 MutableArrayRef<ValueRange> valueRangeMemory, 1059 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory, 1060 MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory, 1061 ArrayRef<ByteCodeField> code, 1062 ArrayRef<PatternBenefit> currentPatternBenefits, 1063 ArrayRef<PDLByteCodePattern> patterns, 1064 ArrayRef<PDLConstraintFunction> constraintFunctions, 1065 ArrayRef<PDLRewriteFunction> rewriteFunctions) 1066 : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory), 1067 typeRangeMemory(typeRangeMemory), 1068 allocatedTypeRangeMemory(allocatedTypeRangeMemory), 1069 valueRangeMemory(valueRangeMemory), 1070 allocatedValueRangeMemory(allocatedValueRangeMemory), 1071 loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code), 1072 currentPatternBenefits(currentPatternBenefits), patterns(patterns), 1073 
constraintFunctions(constraintFunctions),
        rewriteFunctions(rewriteFunctions) {}

  /// Start executing the code at the current bytecode index. `matches` is an
  /// optional field provided when this function is executed in a matching
  /// context.
  void execute(PatternRewriter &rewriter,
               SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
               Optional<Location> mainRewriteLoc = {});

private:
  /// Internal implementation of executing each of the bytecode commands.
  /// There is one handler per instruction; each handler reads its own operands
  /// from the bytecode stream.
  void executeApplyConstraint(PatternRewriter &rewriter);
  void executeApplyRewrite(PatternRewriter &rewriter);
  void executeAreEqual();
  void executeAreRangesEqual();
  void executeBranch();
  void executeCheckOperandCount();
  void executeCheckOperationName();
  void executeCheckResultCount();
  void executeCheckTypes();
  void executeContinue();
  void executeCreateOperation(PatternRewriter &rewriter,
                              Location mainRewriteLoc);
  void executeCreateTypes();
  void executeEraseOp(PatternRewriter &rewriter);
  // Shared implementation of the extract instructions, parameterized on the
  // element type, the range type and the PDL value kind.
  template <typename T, typename Range, PDLValue::Kind kind>
  void executeExtract();
  void executeFinalize();
  void executeForEach();
  void executeGetAttribute();
  void executeGetAttributeType();
  void executeGetDefiningOp();
  // The operand/result index is supplied by the caller rather than read here.
  void executeGetOperand(unsigned index);
  void executeGetOperands();
  void executeGetResult(unsigned index);
  void executeGetResults();
  void executeGetUsers();
  void executeGetValueType();
  void executeGetValueRangeTypes();
  void executeIsNotNull();
  void executeRecordMatch(PatternRewriter &rewriter,
                          SmallVectorImpl<PDLByteCode::MatchResult> &matches);
  void executeReplaceOp(PatternRewriter &rewriter);
  void executeSwitchAttribute();
  void executeSwitchOperandCount();
  void executeSwitchOperationName();
  void executeSwitchResultCount();
  void executeSwitchType();
  void
executeSwitchTypes(); 1123 1124 /// Pushes a code iterator to the stack. 1125 void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); } 1126 1127 /// Pops a code iterator from the stack, returning true on success. 1128 void popCodeIt() { 1129 assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack"); 1130 curCodeIt = resumeCodeIt.back(); 1131 resumeCodeIt.pop_back(); 1132 } 1133 1134 /// Return the bytecode iterator at the start of the current op code. 1135 const ByteCodeField *getPrevCodeIt() const { 1136 LLVM_DEBUG({ 1137 // Account for the op code and the Location stored inline. 1138 return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField); 1139 }); 1140 1141 // Account for the op code only. 1142 return curCodeIt - 1; 1143 } 1144 1145 /// Read a value from the bytecode buffer, optionally skipping a certain 1146 /// number of prefix values. These methods always update the buffer to point 1147 /// to the next field after the read data. 1148 template <typename T = ByteCodeField> 1149 T read(size_t skipN = 0) { 1150 curCodeIt += skipN; 1151 return readImpl<T>(); 1152 } 1153 ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); } 1154 1155 /// Read a list of values from the bytecode buffer. 1156 template <typename ValueT, typename T> 1157 void readList(SmallVectorImpl<T> &list) { 1158 list.clear(); 1159 for (unsigned i = 0, e = read(); i != e; ++i) 1160 list.push_back(read<ValueT>()); 1161 } 1162 1163 /// Read a list of values from the bytecode buffer. The values may be encoded 1164 /// as either Value or ValueRange elements. 
1165 void readValueList(SmallVectorImpl<Value> &list) { 1166 for (unsigned i = 0, e = read(); i != e; ++i) { 1167 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1168 list.push_back(read<Value>()); 1169 } else { 1170 ValueRange *values = read<ValueRange *>(); 1171 list.append(values->begin(), values->end()); 1172 } 1173 } 1174 } 1175 1176 /// Read a value stored inline as a pointer. 1177 template <typename T> 1178 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T> 1179 readInline() { 1180 const void *pointer; 1181 std::memcpy(&pointer, curCodeIt, sizeof(const void *)); 1182 curCodeIt += sizeof(const void *) / sizeof(ByteCodeField); 1183 return T::getFromOpaquePointer(pointer); 1184 } 1185 1186 /// Jump to a specific successor based on a predicate value. 1187 void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); } 1188 /// Jump to a specific successor based on a destination index. 1189 void selectJump(size_t destIndex) { 1190 curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)]; 1191 } 1192 1193 /// Handle a switch operation with the provided value and cases. 1194 template <typename T, typename RangeT, typename Comparator = std::equal_to<T>> 1195 void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) { 1196 LLVM_DEBUG({ 1197 llvm::dbgs() << " * Value: " << value << "\n" 1198 << " * Cases: "; 1199 llvm::interleaveComma(cases, llvm::dbgs()); 1200 llvm::dbgs() << "\n"; 1201 }); 1202 1203 // Check to see if the attribute value is within the case list. Jump to 1204 // the correct successor index based on the result. 1205 for (auto it = cases.begin(), e = cases.end(); it != e; ++it) 1206 if (cmp(*it, value)) 1207 return selectJump(size_t((it - cases.begin()) + 1)); 1208 selectJump(size_t(0)); 1209 } 1210 1211 /// Store a pointer to memory. 1212 void storeToMemory(unsigned index, const void *value) { 1213 memory[index] = value; 1214 } 1215 1216 /// Store a value to memory as an opaque pointer. 
1217 template <typename T> 1218 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value> 1219 storeToMemory(unsigned index, T value) { 1220 memory[index] = value.getAsOpaquePointer(); 1221 } 1222 1223 /// Internal implementation of reading various data types from the bytecode 1224 /// stream. 1225 template <typename T> 1226 const void *readFromMemory() { 1227 size_t index = *curCodeIt++; 1228 1229 // If this type is an SSA value, it can only be stored in non-const memory. 1230 if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *, 1231 Value>::value || 1232 index < memory.size()) 1233 return memory[index]; 1234 1235 // Otherwise, if this index is not inbounds it is uniqued. 1236 return uniquedMemory[index - memory.size()]; 1237 } 1238 template <typename T> 1239 std::enable_if_t<std::is_pointer<T>::value, T> readImpl() { 1240 return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>())); 1241 } 1242 template <typename T> 1243 std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value, 1244 T> 1245 readImpl() { 1246 return T(T::getFromOpaquePointer(readFromMemory<T>())); 1247 } 1248 template <typename T> 1249 std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() { 1250 switch (read<PDLValue::Kind>()) { 1251 case PDLValue::Kind::Attribute: 1252 return read<Attribute>(); 1253 case PDLValue::Kind::Operation: 1254 return read<Operation *>(); 1255 case PDLValue::Kind::Type: 1256 return read<Type>(); 1257 case PDLValue::Kind::Value: 1258 return read<Value>(); 1259 case PDLValue::Kind::TypeRange: 1260 return read<TypeRange *>(); 1261 case PDLValue::Kind::ValueRange: 1262 return read<ValueRange *>(); 1263 } 1264 llvm_unreachable("unhandled PDLValue::Kind"); 1265 } 1266 template <typename T> 1267 std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() { 1268 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 1269 "unexpected ByteCode address size"); 1270 ByteCodeAddr result; 1271 
std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr)); 1272 curCodeIt += 2; 1273 return result; 1274 } 1275 template <typename T> 1276 std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() { 1277 return *curCodeIt++; 1278 } 1279 template <typename T> 1280 std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() { 1281 return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>()); 1282 } 1283 1284 /// The underlying bytecode buffer. 1285 const ByteCodeField *curCodeIt; 1286 1287 /// The stack of bytecode positions at which to resume operation. 1288 SmallVector<const ByteCodeField *> resumeCodeIt; 1289 1290 /// The current execution memory. 1291 MutableArrayRef<const void *> memory; 1292 MutableArrayRef<OwningOpRange> opRangeMemory; 1293 MutableArrayRef<TypeRange> typeRangeMemory; 1294 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory; 1295 MutableArrayRef<ValueRange> valueRangeMemory; 1296 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory; 1297 1298 /// The current loop indices. 1299 MutableArrayRef<unsigned> loopIndex; 1300 1301 /// References to ByteCode data necessary for execution. 1302 ArrayRef<const void *> uniquedMemory; 1303 ArrayRef<ByteCodeField> code; 1304 ArrayRef<PatternBenefit> currentPatternBenefits; 1305 ArrayRef<PDLByteCodePattern> patterns; 1306 ArrayRef<PDLConstraintFunction> constraintFunctions; 1307 ArrayRef<PDLRewriteFunction> rewriteFunctions; 1308 }; 1309 1310 /// This class is an instantiation of the PDLResultList that provides access to 1311 /// the returned results. This API is not on `PDLResultList` to avoid 1312 /// overexposing access to information specific solely to the ByteCode. 1313 class ByteCodeRewriteResultList : public PDLResultList { 1314 public: 1315 ByteCodeRewriteResultList(unsigned maxNumResults) 1316 : PDLResultList(maxNumResults) {} 1317 1318 /// Return the list of PDL results. 
1319 MutableArrayRef<PDLValue> getResults() { return results; } 1320 1321 /// Return the type ranges allocated by this list. 1322 MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() { 1323 return allocatedTypeRanges; 1324 } 1325 1326 /// Return the value ranges allocated by this list. 1327 MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() { 1328 return allocatedValueRanges; 1329 } 1330 }; 1331 } // namespace 1332 1333 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { 1334 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); 1335 const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; 1336 ArrayAttr constParams = read<ArrayAttr>(); 1337 SmallVector<PDLValue, 16> args; 1338 readList<PDLValue>(args); 1339 1340 LLVM_DEBUG({ 1341 llvm::dbgs() << " * Arguments: "; 1342 llvm::interleaveComma(args, llvm::dbgs()); 1343 llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; 1344 }); 1345 1346 // Invoke the constraint and jump to the proper destination. 1347 selectJump(succeeded(constraintFn(args, constParams, rewriter))); 1348 } 1349 1350 void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { 1351 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); 1352 const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; 1353 ArrayAttr constParams = read<ArrayAttr>(); 1354 SmallVector<PDLValue, 16> args; 1355 readList<PDLValue>(args); 1356 1357 LLVM_DEBUG({ 1358 llvm::dbgs() << " * Arguments: "; 1359 llvm::interleaveComma(args, llvm::dbgs()); 1360 llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; 1361 }); 1362 1363 // Execute the rewrite function. 
ByteCodeField numResults = read();
  ByteCodeRewriteResultList results(numResults);
  rewriteFn(args, constParams, rewriter, results);

  // The native function must produce exactly as many results as the bytecode
  // has slots encoded for.
  assert(results.getResults().size() == numResults &&
         "native PDL rewrite function returned unexpected number of results");

  // Store the results in the bytecode memory.
  for (PDLValue &result : results.getResults()) {
    LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");

    // In debug mode we also verify the expected kind of the result.
    // The kind field is consumed inside the assert; this whole statement is
    // compiled only when NDEBUG is not defined, so the read happens exactly
    // when the check does.
#ifndef NDEBUG
    assert(result.getKind() == read<PDLValue::Kind>() &&
           "native PDL rewrite function returned an unexpected type of result");
#endif

    // If the result is a range, we need to copy it over to the bytecode's
    // range memory. Range results are encoded as a range storage index
    // followed by the memory index; other results only have a memory index.
    if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
      unsigned rangeIndex = read();
      typeRangeMemory[rangeIndex] = *typeRange;
      memory[read()] = &typeRangeMemory[rangeIndex];
    } else if (Optional<ValueRange> valueRange =
                   result.dyn_cast<ValueRange>()) {
      unsigned rangeIndex = read();
      valueRangeMemory[rangeIndex] = *valueRange;
      memory[read()] = &valueRangeMemory[rangeIndex];
    } else {
      memory[read()] = result.getAsOpaquePointer();
    }
  }

  // Copy over any underlying storage allocated for result ranges.
1398 for (auto &it : results.getAllocatedTypeRanges()) 1399 allocatedTypeRangeMemory.push_back(std::move(it)); 1400 for (auto &it : results.getAllocatedValueRanges()) 1401 allocatedValueRangeMemory.push_back(std::move(it)); 1402 } 1403 1404 void ByteCodeExecutor::executeAreEqual() { 1405 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 1406 const void *lhs = read<const void *>(); 1407 const void *rhs = read<const void *>(); 1408 1409 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n"); 1410 selectJump(lhs == rhs); 1411 } 1412 1413 void ByteCodeExecutor::executeAreRangesEqual() { 1414 LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n"); 1415 PDLValue::Kind valueKind = read<PDLValue::Kind>(); 1416 const void *lhs = read<const void *>(); 1417 const void *rhs = read<const void *>(); 1418 1419 switch (valueKind) { 1420 case PDLValue::Kind::TypeRange: { 1421 const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs); 1422 const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs); 1423 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1424 selectJump(*lhsRange == *rhsRange); 1425 break; 1426 } 1427 case PDLValue::Kind::ValueRange: { 1428 const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs); 1429 const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs); 1430 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1431 selectJump(*lhsRange == *rhsRange); 1432 break; 1433 } 1434 default: 1435 llvm_unreachable("unexpected `AreRangesEqual` value kind"); 1436 } 1437 } 1438 1439 void ByteCodeExecutor::executeBranch() { 1440 LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n"); 1441 curCodeIt = &code[read<ByteCodeAddr>()]; 1442 } 1443 1444 void ByteCodeExecutor::executeCheckOperandCount() { 1445 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); 1446 Operation *op = read<Operation *>(); 1447 uint32_t expectedCount = read<uint32_t>(); 1448 bool compareAtLeast = read(); 1449 1450 
LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" 1451 << " * Expected: " << expectedCount << "\n" 1452 << " * Comparator: " 1453 << (compareAtLeast ? ">=" : "==") << "\n"); 1454 if (compareAtLeast) 1455 selectJump(op->getNumOperands() >= expectedCount); 1456 else 1457 selectJump(op->getNumOperands() == expectedCount); 1458 } 1459 1460 void ByteCodeExecutor::executeCheckOperationName() { 1461 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n"); 1462 Operation *op = read<Operation *>(); 1463 OperationName expectedName = read<OperationName>(); 1464 1465 LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n" 1466 << " * Expected: \"" << expectedName << "\"\n"); 1467 selectJump(op->getName() == expectedName); 1468 } 1469 1470 void ByteCodeExecutor::executeCheckResultCount() { 1471 LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); 1472 Operation *op = read<Operation *>(); 1473 uint32_t expectedCount = read<uint32_t>(); 1474 bool compareAtLeast = read(); 1475 1476 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" 1477 << " * Expected: " << expectedCount << "\n" 1478 << " * Comparator: " 1479 << (compareAtLeast ? 
">=" : "==") << "\n"); 1480 if (compareAtLeast) 1481 selectJump(op->getNumResults() >= expectedCount); 1482 else 1483 selectJump(op->getNumResults() == expectedCount); 1484 } 1485 1486 void ByteCodeExecutor::executeCheckTypes() { 1487 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 1488 TypeRange *lhs = read<TypeRange *>(); 1489 Attribute rhs = read<Attribute>(); 1490 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1491 1492 selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>()); 1493 } 1494 1495 void ByteCodeExecutor::executeContinue() { 1496 ByteCodeField level = read(); 1497 LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n" 1498 << " * Level: " << level << "\n"); 1499 ++loopIndex[level]; 1500 popCodeIt(); 1501 } 1502 1503 void ByteCodeExecutor::executeCreateTypes() { 1504 LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n"); 1505 unsigned memIndex = read(); 1506 unsigned rangeIndex = read(); 1507 ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>(); 1508 1509 LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n"); 1510 1511 // Allocate a buffer for this type range. 1512 llvm::OwningArrayRef<Type> storage(typesAttr.size()); 1513 llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin()); 1514 allocatedTypeRangeMemory.emplace_back(std::move(storage)); 1515 1516 // Assign this to the range slot and use the range as the value for the 1517 // memory index. 
typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back();
  memory[memIndex] = &typeRangeMemory[rangeIndex];
}

/// Execute `CreateOperation`. The operands are read in order: the memory index
/// for the created operation, the operation name, the operand value list, the
/// attribute count followed by (name, value) pairs, and the result type list.
/// The operation is created at `mainRewriteLoc` through `rewriter`.
void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
                                              Location mainRewriteLoc) {
  LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");

  unsigned memIndex = read();
  OperationState state(mainRewriteLoc, read<OperationName>());
  readValueList(state.operands);
  // Attributes: both fields of every pair are always consumed, but null
  // attribute values are not added to the operation.
  for (unsigned i = 0, e = read(); i != e; ++i) {
    StringAttr name = read<StringAttr>();
    if (Attribute attr = read<Attribute>())
      state.addAttribute(name, attr);
  }

  // Result types: each entry is prefixed with its kind, and is either a
  // single type or a type range.
  for (unsigned i = 0, e = read(); i != e; ++i) {
    if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
      state.types.push_back(read<Type>());
      continue;
    }

    // If we find a null range, this signals that the types are inferred.
    if (TypeRange *resultTypes = read<TypeRange *>()) {
      state.types.append(resultTypes->begin(), resultTypes->end());
      continue;
    }

    // Handle the case where the operation has inferred types.
    // NOTE(review): neither the registered info nor the returned interface
    // pointer is checked before use; confirm the op is always registered and
    // implements InferTypeOpInterface when this path is reached.
    InferTypeOpInterface::Concept *inferInterface =
        state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>();

    // TODO: Handle failure.
1552 state.types.clear(); 1553 if (failed(inferInterface->inferReturnTypes( 1554 state.getContext(), state.location, state.operands, 1555 state.attributes.getDictionary(state.getContext()), state.regions, 1556 state.types))) 1557 return; 1558 break; 1559 } 1560 1561 Operation *resultOp = rewriter.createOperation(state); 1562 memory[memIndex] = resultOp; 1563 1564 LLVM_DEBUG({ 1565 llvm::dbgs() << " * Attributes: " 1566 << state.attributes.getDictionary(state.getContext()) 1567 << "\n * Operands: "; 1568 llvm::interleaveComma(state.operands, llvm::dbgs()); 1569 llvm::dbgs() << "\n * Result Types: "; 1570 llvm::interleaveComma(state.types, llvm::dbgs()); 1571 llvm::dbgs() << "\n * Result: " << *resultOp << "\n"; 1572 }); 1573 } 1574 1575 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) { 1576 LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n"); 1577 Operation *op = read<Operation *>(); 1578 1579 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1580 rewriter.eraseOp(op); 1581 } 1582 1583 template <typename T, typename Range, PDLValue::Kind kind> 1584 void ByteCodeExecutor::executeExtract() { 1585 LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n"); 1586 Range *range = read<Range *>(); 1587 unsigned index = read<uint32_t>(); 1588 unsigned memIndex = read(); 1589 1590 if (!range) { 1591 memory[memIndex] = nullptr; 1592 return; 1593 } 1594 1595 T result = index < range->size() ? 
(*range)[index] : T();
  LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n"
                          << " * Index: " << index << "\n"
                          << " * Result: " << result << "\n");
  storeToMemory(memIndex, result);
}

/// Execute `Finalize`. This handler only emits a debug trace.
void ByteCodeExecutor::executeFinalize() {
  LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
}

/// Execute `ForEach`. The operands are read in order: the range storage index,
/// the memory index of the loop variable, the kind of the iterated values, the
/// loop level and the successor address. While elements remain, the current
/// element is stored to the loop variable and execution falls through into
/// the loop body; once the range is exhausted the loop index is reset and
/// execution jumps to the successor. Only operation ranges are supported.
void ByteCodeExecutor::executeForEach() {
  LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
  // Remember where this instruction starts so that the position can be pushed
  // onto the resume stack below.
  const ByteCodeField *prevCodeIt = getPrevCodeIt();
  unsigned rangeIndex = read();
  unsigned memIndex = read();
  const void *value = nullptr;

  switch (read<PDLValue::Kind>()) {
  case PDLValue::Kind::Operation: {
    // The index is held by reference so that it can be reset in place.
    unsigned &index = loopIndex[read()];
    ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
    assert(index <= array.size() && "iterated past the end");
    if (index < array.size()) {
      LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n");
      value = array[index];
      break;
    }

    // Exhausted: reset the index for this level and leave the loop.
    LLVM_DEBUG(llvm::dbgs() << " * Done\n");
    index = 0;
    selectJump(size_t(0));
    return;
  }
  default:
    llvm_unreachable("unexpected `ForEach` value kind");
  }

  // Store the iterate value and the stack address.
  memory[memIndex] = value;
  pushCodeIt(prevCodeIt);

  // Skip over the successor (we will enter the body of the loop).
1638 read<ByteCodeAddr>(); 1639 } 1640 1641 void ByteCodeExecutor::executeGetAttribute() { 1642 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); 1643 unsigned memIndex = read(); 1644 Operation *op = read<Operation *>(); 1645 StringAttr attrName = read<StringAttr>(); 1646 Attribute attr = op->getAttr(attrName); 1647 1648 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1649 << " * Attribute: " << attrName << "\n" 1650 << " * Result: " << attr << "\n"); 1651 memory[memIndex] = attr.getAsOpaquePointer(); 1652 } 1653 1654 void ByteCodeExecutor::executeGetAttributeType() { 1655 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n"); 1656 unsigned memIndex = read(); 1657 Attribute attr = read<Attribute>(); 1658 Type type = attr ? attr.getType() : Type(); 1659 1660 LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" 1661 << " * Result: " << type << "\n"); 1662 memory[memIndex] = type.getAsOpaquePointer(); 1663 } 1664 1665 void ByteCodeExecutor::executeGetDefiningOp() { 1666 LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); 1667 unsigned memIndex = read(); 1668 Operation *op = nullptr; 1669 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1670 Value value = read<Value>(); 1671 if (value) 1672 op = value.getDefiningOp(); 1673 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1674 } else { 1675 ValueRange *values = read<ValueRange *>(); 1676 if (values && !values->empty()) { 1677 op = values->front().getDefiningOp(); 1678 } 1679 LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n"); 1680 } 1681 1682 LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n"); 1683 memory[memIndex] = op; 1684 } 1685 1686 void ByteCodeExecutor::executeGetOperand(unsigned index) { 1687 Operation *op = read<Operation *>(); 1688 unsigned memIndex = read(); 1689 Value operand = 1690 index < op->getNumOperands() ? 
op->getOperand(index) : Value(); 1691 1692 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1693 << " * Index: " << index << "\n" 1694 << " * Result: " << operand << "\n"); 1695 memory[memIndex] = operand.getAsOpaquePointer(); 1696 } 1697 1698 /// This function is the internal implementation of `GetResults` and 1699 /// `GetOperands` that provides support for extracting a value range from the 1700 /// given operation. 1701 template <template <typename> class AttrSizedSegmentsT, typename RangeT> 1702 static void * 1703 executeGetOperandsResults(RangeT values, Operation *op, unsigned index, 1704 ByteCodeField rangeIndex, StringRef attrSizedSegments, 1705 MutableArrayRef<ValueRange> valueRangeMemory) { 1706 // Check for the sentinel index that signals that all values should be 1707 // returned. 1708 if (index == std::numeric_limits<uint32_t>::max()) { 1709 LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n"); 1710 // `values` is already the full value range. 1711 1712 // Otherwise, check to see if this operation uses AttrSizedSegments. 1713 } else if (op->hasTrait<AttrSizedSegmentsT>()) { 1714 LLVM_DEBUG(llvm::dbgs() 1715 << " * Extracting values from `" << attrSizedSegments << "`\n"); 1716 1717 auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments); 1718 if (!segmentAttr || segmentAttr.getNumElements() <= index) 1719 return nullptr; 1720 1721 auto segments = segmentAttr.getValues<int32_t>(); 1722 unsigned startIndex = 1723 std::accumulate(segments.begin(), segments.begin() + index, 0); 1724 values = values.slice(startIndex, *std::next(segments.begin(), index)); 1725 1726 LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", " 1727 << *std::next(segments.begin(), index) << "]\n"); 1728 1729 // Otherwise, assume this is the last operand group of the operation. 
1730 // FIXME: We currently don't support operations with 1731 // SameVariadicOperandSize/SameVariadicResultSize here given that we don't 1732 // have a way to detect it's presence. 1733 } else if (values.size() >= index) { 1734 LLVM_DEBUG(llvm::dbgs() 1735 << " * Treating values as trailing variadic range\n"); 1736 values = values.drop_front(index); 1737 1738 // If we couldn't detect a way to compute the values, bail out. 1739 } else { 1740 return nullptr; 1741 } 1742 1743 // If the range index is valid, we are returning a range. 1744 if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) { 1745 valueRangeMemory[rangeIndex] = values; 1746 return &valueRangeMemory[rangeIndex]; 1747 } 1748 1749 // If a range index wasn't provided, the range is required to be non-variadic. 1750 return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer(); 1751 } 1752 1753 void ByteCodeExecutor::executeGetOperands() { 1754 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n"); 1755 unsigned index = read<uint32_t>(); 1756 Operation *op = read<Operation *>(); 1757 ByteCodeField rangeIndex = read(); 1758 1759 void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>( 1760 op->getOperands(), op, index, rangeIndex, "operand_segment_sizes", 1761 valueRangeMemory); 1762 if (!result) 1763 LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n"); 1764 memory[read()] = result; 1765 } 1766 1767 void ByteCodeExecutor::executeGetResult(unsigned index) { 1768 Operation *op = read<Operation *>(); 1769 unsigned memIndex = read(); 1770 OpResult result = 1771 index < op->getNumResults() ? 
op->getResult(index) : OpResult(); 1772 1773 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1774 << " * Index: " << index << "\n" 1775 << " * Result: " << result << "\n"); 1776 memory[memIndex] = result.getAsOpaquePointer(); 1777 } 1778 1779 void ByteCodeExecutor::executeGetResults() { 1780 LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n"); 1781 unsigned index = read<uint32_t>(); 1782 Operation *op = read<Operation *>(); 1783 ByteCodeField rangeIndex = read(); 1784 1785 void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>( 1786 op->getResults(), op, index, rangeIndex, "result_segment_sizes", 1787 valueRangeMemory); 1788 if (!result) 1789 LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n"); 1790 memory[read()] = result; 1791 } 1792 1793 void ByteCodeExecutor::executeGetUsers() { 1794 LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n"); 1795 unsigned memIndex = read(); 1796 unsigned rangeIndex = read(); 1797 OwningOpRange &range = opRangeMemory[rangeIndex]; 1798 memory[memIndex] = ⦥ 1799 1800 range = OwningOpRange(); 1801 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1802 // Read the value. 1803 Value value = read<Value>(); 1804 if (!value) 1805 return; 1806 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1807 1808 // Extract the users of a single value. 1809 range = OwningOpRange(std::distance(value.user_begin(), value.user_end())); 1810 llvm::copy(value.getUsers(), range.begin()); 1811 } else { 1812 // Read a range of values. 1813 ValueRange *values = read<ValueRange *>(); 1814 if (!values) 1815 return; 1816 LLVM_DEBUG({ 1817 llvm::dbgs() << " * Values (" << values->size() << "): "; 1818 llvm::interleaveComma(*values, llvm::dbgs()); 1819 llvm::dbgs() << "\n"; 1820 }); 1821 1822 // Extract all the users of a range of values. 
1823 SmallVector<Operation *> users; 1824 for (Value value : *values) 1825 users.append(value.user_begin(), value.user_end()); 1826 range = OwningOpRange(users.size()); 1827 llvm::copy(users, range.begin()); 1828 } 1829 1830 LLVM_DEBUG(llvm::dbgs() << " * Result: " << range.size() << " operations\n"); 1831 } 1832 1833 void ByteCodeExecutor::executeGetValueType() { 1834 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); 1835 unsigned memIndex = read(); 1836 Value value = read<Value>(); 1837 Type type = value ? value.getType() : Type(); 1838 1839 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" 1840 << " * Result: " << type << "\n"); 1841 memory[memIndex] = type.getAsOpaquePointer(); 1842 } 1843 1844 void ByteCodeExecutor::executeGetValueRangeTypes() { 1845 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n"); 1846 unsigned memIndex = read(); 1847 unsigned rangeIndex = read(); 1848 ValueRange *values = read<ValueRange *>(); 1849 if (!values) { 1850 LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n"); 1851 memory[memIndex] = nullptr; 1852 return; 1853 } 1854 1855 LLVM_DEBUG({ 1856 llvm::dbgs() << " * Values (" << values->size() << "): "; 1857 llvm::interleaveComma(*values, llvm::dbgs()); 1858 llvm::dbgs() << "\n * Result: "; 1859 llvm::interleaveComma(values->getType(), llvm::dbgs()); 1860 llvm::dbgs() << "\n"; 1861 }); 1862 typeRangeMemory[rangeIndex] = values->getType(); 1863 memory[memIndex] = &typeRangeMemory[rangeIndex]; 1864 } 1865 1866 void ByteCodeExecutor::executeIsNotNull() { 1867 LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); 1868 const void *value = read<const void *>(); 1869 1870 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1871 selectJump(value != nullptr); 1872 } 1873 1874 void ByteCodeExecutor::executeRecordMatch( 1875 PatternRewriter &rewriter, 1876 SmallVectorImpl<PDLByteCode::MatchResult> &matches) { 1877 LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n"); 1878 unsigned patternIndex = read(); 1879 
PatternBenefit benefit = currentPatternBenefits[patternIndex]; 1880 const ByteCodeField *dest = &code[read<ByteCodeAddr>()]; 1881 1882 // If the benefit of the pattern is impossible, skip the processing of the 1883 // rest of the pattern. 1884 if (benefit.isImpossibleToMatch()) { 1885 LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n"); 1886 curCodeIt = dest; 1887 return; 1888 } 1889 1890 // Create a fused location containing the locations of each of the 1891 // operations used in the match. This will be used as the location for 1892 // created operations during the rewrite that don't already have an 1893 // explicit location set. 1894 unsigned numMatchLocs = read(); 1895 SmallVector<Location, 4> matchLocs; 1896 matchLocs.reserve(numMatchLocs); 1897 for (unsigned i = 0; i != numMatchLocs; ++i) 1898 matchLocs.push_back(read<Operation *>()->getLoc()); 1899 Location matchLoc = rewriter.getFusedLoc(matchLocs); 1900 1901 LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n" 1902 << " * Location: " << matchLoc << "\n"); 1903 matches.emplace_back(matchLoc, patterns[patternIndex], benefit); 1904 PDLByteCode::MatchResult &match = matches.back(); 1905 1906 // Record all of the inputs to the match. If any of the inputs are ranges, we 1907 // will also need to remap the range pointer to memory stored in the match 1908 // state. 
1909 unsigned numInputs = read(); 1910 match.values.reserve(numInputs); 1911 match.typeRangeValues.reserve(numInputs); 1912 match.valueRangeValues.reserve(numInputs); 1913 for (unsigned i = 0; i < numInputs; ++i) { 1914 switch (read<PDLValue::Kind>()) { 1915 case PDLValue::Kind::TypeRange: 1916 match.typeRangeValues.push_back(*read<TypeRange *>()); 1917 match.values.push_back(&match.typeRangeValues.back()); 1918 break; 1919 case PDLValue::Kind::ValueRange: 1920 match.valueRangeValues.push_back(*read<ValueRange *>()); 1921 match.values.push_back(&match.valueRangeValues.back()); 1922 break; 1923 default: 1924 match.values.push_back(read<const void *>()); 1925 break; 1926 } 1927 } 1928 curCodeIt = dest; 1929 } 1930 1931 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) { 1932 LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); 1933 Operation *op = read<Operation *>(); 1934 SmallVector<Value, 16> args; 1935 readValueList(args); 1936 1937 LLVM_DEBUG({ 1938 llvm::dbgs() << " * Operation: " << *op << "\n" 1939 << " * Values: "; 1940 llvm::interleaveComma(args, llvm::dbgs()); 1941 llvm::dbgs() << "\n"; 1942 }); 1943 rewriter.replaceOp(op, args); 1944 } 1945 1946 void ByteCodeExecutor::executeSwitchAttribute() { 1947 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n"); 1948 Attribute value = read<Attribute>(); 1949 ArrayAttr cases = read<ArrayAttr>(); 1950 handleSwitch(value, cases); 1951 } 1952 1953 void ByteCodeExecutor::executeSwitchOperandCount() { 1954 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n"); 1955 Operation *op = read<Operation *>(); 1956 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1957 1958 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1959 handleSwitch(op->getNumOperands(), cases); 1960 } 1961 1962 void ByteCodeExecutor::executeSwitchOperationName() { 1963 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n"); 1964 OperationName value = read<Operation *>()->getName(); 
1965 size_t caseCount = read(); 1966 1967 // The operation names are stored in-line, so to print them out for 1968 // debugging purposes we need to read the array before executing the 1969 // switch so that we can display all of the possible values. 1970 LLVM_DEBUG({ 1971 const ByteCodeField *prevCodeIt = curCodeIt; 1972 llvm::dbgs() << " * Value: " << value << "\n" 1973 << " * Cases: "; 1974 llvm::interleaveComma( 1975 llvm::map_range(llvm::seq<size_t>(0, caseCount), 1976 [&](size_t) { return read<OperationName>(); }), 1977 llvm::dbgs()); 1978 llvm::dbgs() << "\n"; 1979 curCodeIt = prevCodeIt; 1980 }); 1981 1982 // Try to find the switch value within any of the cases. 1983 for (size_t i = 0; i != caseCount; ++i) { 1984 if (read<OperationName>() == value) { 1985 curCodeIt += (caseCount - i - 1); 1986 return selectJump(i + 1); 1987 } 1988 } 1989 selectJump(size_t(0)); 1990 } 1991 1992 void ByteCodeExecutor::executeSwitchResultCount() { 1993 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n"); 1994 Operation *op = read<Operation *>(); 1995 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1996 1997 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1998 handleSwitch(op->getNumResults(), cases); 1999 } 2000 2001 void ByteCodeExecutor::executeSwitchType() { 2002 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n"); 2003 Type value = read<Type>(); 2004 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>(); 2005 handleSwitch(value, cases); 2006 } 2007 2008 void ByteCodeExecutor::executeSwitchTypes() { 2009 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n"); 2010 TypeRange *value = read<TypeRange *>(); 2011 auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>(); 2012 if (!value) { 2013 LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n"); 2014 return selectJump(size_t(0)); 2015 } 2016 handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) { 2017 return value == caseValue.getAsValueRange<TypeAttr>(); 2018 }); 
2019 } 2020 2021 void ByteCodeExecutor::execute( 2022 PatternRewriter &rewriter, 2023 SmallVectorImpl<PDLByteCode::MatchResult> *matches, 2024 Optional<Location> mainRewriteLoc) { 2025 while (true) { 2026 // Print the location of the operation being executed. 2027 LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n"); 2028 2029 OpCode opCode = static_cast<OpCode>(read()); 2030 switch (opCode) { 2031 case ApplyConstraint: 2032 executeApplyConstraint(rewriter); 2033 break; 2034 case ApplyRewrite: 2035 executeApplyRewrite(rewriter); 2036 break; 2037 case AreEqual: 2038 executeAreEqual(); 2039 break; 2040 case AreRangesEqual: 2041 executeAreRangesEqual(); 2042 break; 2043 case Branch: 2044 executeBranch(); 2045 break; 2046 case CheckOperandCount: 2047 executeCheckOperandCount(); 2048 break; 2049 case CheckOperationName: 2050 executeCheckOperationName(); 2051 break; 2052 case CheckResultCount: 2053 executeCheckResultCount(); 2054 break; 2055 case CheckTypes: 2056 executeCheckTypes(); 2057 break; 2058 case Continue: 2059 executeContinue(); 2060 break; 2061 case CreateOperation: 2062 executeCreateOperation(rewriter, *mainRewriteLoc); 2063 break; 2064 case CreateTypes: 2065 executeCreateTypes(); 2066 break; 2067 case EraseOp: 2068 executeEraseOp(rewriter); 2069 break; 2070 case ExtractOp: 2071 executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>(); 2072 break; 2073 case ExtractType: 2074 executeExtract<Type, TypeRange, PDLValue::Kind::Type>(); 2075 break; 2076 case ExtractValue: 2077 executeExtract<Value, ValueRange, PDLValue::Kind::Value>(); 2078 break; 2079 case Finalize: 2080 executeFinalize(); 2081 LLVM_DEBUG(llvm::dbgs() << "\n"); 2082 return; 2083 case ForEach: 2084 executeForEach(); 2085 break; 2086 case GetAttribute: 2087 executeGetAttribute(); 2088 break; 2089 case GetAttributeType: 2090 executeGetAttributeType(); 2091 break; 2092 case GetDefiningOp: 2093 executeGetDefiningOp(); 2094 break; 2095 case GetOperand0: 2096 case GetOperand1: 2097 
case GetOperand2: 2098 case GetOperand3: { 2099 unsigned index = opCode - GetOperand0; 2100 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n"); 2101 executeGetOperand(index); 2102 break; 2103 } 2104 case GetOperandN: 2105 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n"); 2106 executeGetOperand(read<uint32_t>()); 2107 break; 2108 case GetOperands: 2109 executeGetOperands(); 2110 break; 2111 case GetResult0: 2112 case GetResult1: 2113 case GetResult2: 2114 case GetResult3: { 2115 unsigned index = opCode - GetResult0; 2116 LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n"); 2117 executeGetResult(index); 2118 break; 2119 } 2120 case GetResultN: 2121 LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n"); 2122 executeGetResult(read<uint32_t>()); 2123 break; 2124 case GetResults: 2125 executeGetResults(); 2126 break; 2127 case GetUsers: 2128 executeGetUsers(); 2129 break; 2130 case GetValueType: 2131 executeGetValueType(); 2132 break; 2133 case GetValueRangeTypes: 2134 executeGetValueRangeTypes(); 2135 break; 2136 case IsNotNull: 2137 executeIsNotNull(); 2138 break; 2139 case RecordMatch: 2140 assert(matches && 2141 "expected matches to be provided when executing the matcher"); 2142 executeRecordMatch(rewriter, *matches); 2143 break; 2144 case ReplaceOp: 2145 executeReplaceOp(rewriter); 2146 break; 2147 case SwitchAttribute: 2148 executeSwitchAttribute(); 2149 break; 2150 case SwitchOperandCount: 2151 executeSwitchOperandCount(); 2152 break; 2153 case SwitchOperationName: 2154 executeSwitchOperationName(); 2155 break; 2156 case SwitchResultCount: 2157 executeSwitchResultCount(); 2158 break; 2159 case SwitchType: 2160 executeSwitchType(); 2161 break; 2162 case SwitchTypes: 2163 executeSwitchTypes(); 2164 break; 2165 } 2166 LLVM_DEBUG(llvm::dbgs() << "\n"); 2167 } 2168 } 2169 2170 /// Run the pattern matcher on the given root operation, collecting the matched 2171 /// patterns in `matches`. 
2172 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, 2173 SmallVectorImpl<MatchResult> &matches, 2174 PDLByteCodeMutableState &state) const { 2175 // The first memory slot is always the root operation. 2176 state.memory[0] = op; 2177 2178 // The matcher function always starts at code address 0. 2179 ByteCodeExecutor executor( 2180 matcherByteCode.data(), state.memory, state.opRangeMemory, 2181 state.typeRangeMemory, state.allocatedTypeRangeMemory, 2182 state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex, 2183 uniquedData, matcherByteCode, state.currentPatternBenefits, patterns, 2184 constraintFunctions, rewriteFunctions); 2185 executor.execute(rewriter, &matches); 2186 2187 // Order the found matches by benefit. 2188 std::stable_sort(matches.begin(), matches.end(), 2189 [](const MatchResult &lhs, const MatchResult &rhs) { 2190 return lhs.benefit > rhs.benefit; 2191 }); 2192 } 2193 2194 /// Run the rewriter of the given pattern on the root operation `op`. 2195 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, 2196 PDLByteCodeMutableState &state) const { 2197 // The arguments of the rewrite function are stored at the start of the 2198 // memory buffer. 2199 llvm::copy(match.values, state.memory.begin()); 2200 2201 ByteCodeExecutor executor( 2202 &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, 2203 state.opRangeMemory, state.typeRangeMemory, 2204 state.allocatedTypeRangeMemory, state.valueRangeMemory, 2205 state.allocatedValueRangeMemory, state.loopIndex, uniquedData, 2206 rewriterByteCode, state.currentPatternBenefits, patterns, 2207 constraintFunctions, rewriteFunctions); 2208 executor.execute(rewriter, /*matches=*/nullptr, match.location); 2209 } 2210