//===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements MLIR to byte-code generation and the interpreter.
//
//===----------------------------------------------------------------------===//

#include "ByteCode.h"
#include "mlir/Analysis/Liveness.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/RegionGraphTraits.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"
// NOTE(review): `std::numeric_limits` (used below) is declared in <limits>,
// not <numeric>; it is presumably reached transitively here. Consider
// including <limits> directly.
#include <numeric>

#define DEBUG_TYPE "pdl-bytecode"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
// PDLByteCodePattern
//===----------------------------------------------------------------------===//

/// Build a bytecode pattern from a `pdl_interp.record_match` operation whose
/// rewriter code begins at `rewriterAddr` in the rewriter bytecode.
///
/// If the match op provides a root kind, the pattern is rooted on that
/// operation name. Otherwise it matches any operation type. When the optional
/// generated-ops attribute is present, its string values are forwarded as the
/// names of the operations the pattern may generate.
PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
                                              ByteCodeAddr rewriterAddr) {
  SmallVector<StringRef, 8> generatedOps;
  if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
    generatedOps =
        llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());

  PatternBenefit benefit = matchOp.getBenefit();
  MLIRContext *ctx = matchOp.getContext();

  // Check to see if this pattern matches a specific operation type.
  if (Optional<StringRef> rootKind = matchOp.getRootKind())
    return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx,
                              generatedOps);
  return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx,
                            generatedOps);
}

//===----------------------------------------------------------------------===//
// PDLByteCodeMutableState
//===----------------------------------------------------------------------===//

/// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
/// to the position of the pattern within the range returned by
/// `PDLByteCode::getPatterns`.
///
/// NOTE(review): `patternIndex` is not bounds-checked here. Callers are
/// presumably expected to pass a valid index.
void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
                                                   PatternBenefit benefit) {
  currentPatternBenefits[patternIndex] = benefit;
}

/// Clean up any allocated state after a full match/rewrite has been completed.
/// This releases the type-range and value-range memory allocated during the
/// match/rewrite. It should be called regardless of whether the match+rewrite
/// was a success or not.
void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() {
  allocatedTypeRangeMemory.clear();
  allocatedValueRangeMemory.clear();
}

//===----------------------------------------------------------------------===//
// Bytecode OpCodes
//===----------------------------------------------------------------------===//

namespace {
/// The opcodes of the PDL bytecode.
///
/// Each opcode's numeric value is its declaration position. That value is what
/// `ByteCodeWriter::append(OpCode)` pushes into the stream, so reordering
/// entries changes the encoding.
enum OpCode : ByteCodeField {
  /// Apply an externally registered constraint.
  ApplyConstraint,
  /// Apply an externally registered rewrite.
  ApplyRewrite,
  /// Check if two generic values are equal.
  AreEqual,
  /// Check if two ranges are equal.
  AreRangesEqual,
  /// Unconditional branch.
  Branch,
  /// Compare the operand count of an operation with a constant.
  CheckOperandCount,
  /// Compare the name of an operation with a constant.
  CheckOperationName,
  /// Compare the result count of an operation with a constant.
  CheckResultCount,
  /// Compare a range of types to a constant range of types.
  CheckTypes,
  /// Continue to the next iteration of a loop.
  Continue,
  /// Create an operation.
  CreateOperation,
  /// Create a range of types.
  CreateTypes,
  /// Erase an operation.
  EraseOp,
  /// Extract the op from a range at the specified index.
  ExtractOp,
  /// Extract the type from a range at the specified index.
  ExtractType,
  /// Extract the value from a range at the specified index.
  ExtractValue,
  /// Terminate a matcher or rewrite sequence.
  Finalize,
  /// Iterate over a range of values.
  ForEach,
  /// Get a specific attribute of an operation.
  GetAttribute,
  /// Get the type of an attribute.
  GetAttributeType,
  /// Get the defining operation of a value.
  GetDefiningOp,
  /// Get a specific operand of an operation.
  GetOperand0,
  GetOperand1,
  GetOperand2,
  GetOperand3,
  GetOperandN,
  /// Get a specific operand group of an operation.
  GetOperands,
  /// Get a specific result of an operation.
  GetResult0,
  GetResult1,
  GetResult2,
  GetResult3,
  GetResultN,
  /// Get a specific result group of an operation.
  GetResults,
  /// Get the users of a value or a range of values.
  GetUsers,
  /// Get the type of a value.
  GetValueType,
  /// Get the types of a value range.
  GetValueRangeTypes,
  /// Check if a generic value is not null.
  IsNotNull,
  /// Record a successful pattern match.
  RecordMatch,
  /// Replace an operation.
  ReplaceOp,
  /// Compare an attribute with a set of constants.
  SwitchAttribute,
  /// Compare the operand count of an operation with a set of constants.
  SwitchOperandCount,
  /// Compare the name of an operation with a set of constants.
  SwitchOperationName,
  /// Compare the result count of an operation with a set of constants.
  SwitchResultCount,
  /// Compare a type with a set of constants.
  SwitchType,
  /// Compare a range of types with a set of constants.
  SwitchTypes,
};
} // namespace

/// A marker used to indicate if an operation should infer types.
///
/// The `CreateOperationOp` generator writes it in place of the result-type
/// count. It is the largest `ByteCodeField` value so that it stands apart from
/// realistic counts.
/// NOTE(review): an explicit result-type list of exactly this length would be
/// indistinguishable from the marker.
static constexpr ByteCodeField kInferTypesMarker =
    std::numeric_limits<ByteCodeField>::max();

//===----------------------------------------------------------------------===//
// ByteCode Generation
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// Generator

namespace {
struct ByteCodeLiveRange;
struct ByteCodeWriter;

/// Check if the given class `T` can be converted to an opaque pointer, i.e. if
/// it provides `getAsOpaquePointer()`. This is used with `llvm::is_detected`.
/// (`Args` is unused.)
template <typename T, typename... Args>
using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());

/// This class represents the main generator for the pattern bytecode.
///
/// It allocates memory/range-storage indices for PDL interpreter values. It
/// emits matcher and rewriter bytecode into the buffers owned by the caller,
/// and it records the high-water marks of memory usage through the `max*`
/// reference members.
class Generator {
public:
  Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
            SmallVectorImpl<ByteCodeField> &matcherByteCode,
            SmallVectorImpl<ByteCodeField> &rewriterByteCode,
            SmallVectorImpl<PDLByteCodePattern> &patterns,
            ByteCodeField &maxValueMemoryIndex,
            ByteCodeField &maxOpRangeMemoryIndex,
            ByteCodeField &maxTypeRangeMemoryIndex,
            ByteCodeField &maxValueRangeMemoryIndex,
            ByteCodeField &maxLoopLevel,
            llvm::StringMap<PDLConstraintFunction> &constraintFns,
            llvm::StringMap<PDLRewriteFunction> &rewriteFns)
      : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
        rewriterByteCode(rewriterByteCode), patterns(patterns),
        maxValueMemoryIndex(maxValueMemoryIndex),
        maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
        maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
        maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
        maxLoopLevel(maxLoopLevel) {
    // Index each externally registered function by its position in the
    // iteration order of the given map. That position is what gets encoded in
    // the bytecode.
    // NOTE(review): StringMap iteration order is hash-based. The interpreter
    // presumably enumerates the same maps in the same way; confirm.
    for (const auto &it : llvm::enumerate(constraintFns))
      constraintToMemIndex.try_emplace(it.value().first(), it.index());
    for (const auto &it : llvm::enumerate(rewriteFns))
      externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
  }

  /// Generate the bytecode for the given PDL interpreter module.
  void generate(ModuleOp module);

  /// Return the memory index to use for the given value.
  ///
  /// A reference is returned so that callers can repoint a value to another
  /// slot (see the `CreateAttributeOp`/`CreateTypeOp` generators). In builds
  /// without asserts, a missing value is default-inserted with index 0.
  ByteCodeField &getMemIndex(Value value) {
    assert(valueToMemIndex.count(value) &&
           "expected memory index to be assigned");
    return valueToMemIndex[value];
  }

  /// Return the range memory index used to store the given range value.
  ByteCodeField &getRangeStorageIndex(Value value) {
    assert(valueToRangeIndex.count(value) &&
           "expected range index to be assigned");
    return valueToRangeIndex[value];
  }

  /// Return an index to use when referring to the given data that is uniqued in
  /// the MLIR context.
  ///
  /// Uniqued objects are numbered after all value slots: the first new object
  /// receives `maxValueMemoryIndex`, the next `maxValueMemoryIndex + 1`, and so
  /// on. This is only correct once `allocateMemoryIndices` has finalized
  /// `maxValueMemoryIndex`.
  /// NOTE(review): the sum is narrowed to `ByteCodeField` without an overflow
  /// check.
  template <typename T>
  std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
  getMemIndex(T val) {
    const void *opaqueVal = val.getAsOpaquePointer();

    // Get or insert a reference to this value.
    auto it = uniquedDataToMemIndex.try_emplace(
        opaqueVal, maxValueMemoryIndex + uniquedData.size());
    if (it.second)
      uniquedData.push_back(opaqueVal);
    return it.first->second;
  }

private:
  /// Allocate memory indices for the results of operations within the matcher
  /// and rewriters.
  void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
                             ModuleOp rewriterModule);

  /// Generate the bytecode for the given operation.
  void generate(Region *region, ByteCodeWriter &writer);
  void generate(Operation *op, ByteCodeWriter &writer);
  void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);

  /// Mapping from value to its corresponding memory index.
  DenseMap<Value, ByteCodeField> valueToMemIndex;

  /// Mapping from a range value to its corresponding range storage index.
  DenseMap<Value, ByteCodeField> valueToRangeIndex;

  /// Mapping from the name of an externally registered rewrite to its index in
  /// the bytecode registry.
  llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;

  /// Mapping from the name of an externally registered constraint to its index
  /// in the bytecode registry.
  llvm::StringMap<ByteCodeField> constraintToMemIndex;

  /// Mapping from rewriter function name to the address at which that
  /// function's code begins in the rewriter bytecode.
  llvm::StringMap<ByteCodeAddr> rewriterToAddr;

  /// Mapping from a uniqued storage object to the memory index assigned to it.
  /// The object itself is stored in `uniquedData`.
  DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;

  /// The current nesting level of `pdl_interp.foreach` loops (0 at top level).
  ByteCodeField curLoopLevel = 0;

  /// The current MLIR context.
  MLIRContext *ctx;

  /// Mapping from a matcher block to its starting address in the matcher
  /// bytecode.
  DenseMap<Block *, ByteCodeAddr> blockToAddr;

  /// Data of the ByteCode class to be populated.
  std::vector<const void *> &uniquedData;
  SmallVectorImpl<ByteCodeField> &matcherByteCode;
  SmallVectorImpl<ByteCodeField> &rewriterByteCode;
  SmallVectorImpl<PDLByteCodePattern> &patterns;
  ByteCodeField &maxValueMemoryIndex;
  ByteCodeField &maxOpRangeMemoryIndex;
  ByteCodeField &maxTypeRangeMemoryIndex;
  ByteCodeField &maxValueRangeMemoryIndex;
  ByteCodeField &maxLoopLevel;
};

/// This class provides utilities for writing a bytecode stream.
struct ByteCodeWriter {
  ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
      : bytecode(bytecode), generator(generator) {}

  /// Append a field to the bytecode.
  void append(ByteCodeField field) { bytecode.push_back(field); }
  void append(OpCode opCode) { bytecode.push_back(opCode); }

  /// Append an address to the bytecode.
  ///
  /// An address occupies exactly two consecutive fields (enforced by the
  /// static_assert). Its object representation is copied as-is via memcpy.
  void append(ByteCodeAddr field) {
    static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
                  "unexpected ByteCode address size");

    ByteCodeField fieldParts[2];
    std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
    bytecode.append({fieldParts[0], fieldParts[1]});
  }

  /// Append a single successor to the bytecode. The exact address will need to
  /// be resolved later.
  void append(Block *successor) {
    // Emit a zero placeholder and remember its offset, so that
    // `Generator::generate(ModuleOp)` can later patch in the block's address.
    unresolvedSuccessorRefs[successor].push_back(bytecode.size());
    append(ByteCodeAddr(0));
  }

  /// Append a successor range to the bytecode. The exact addresses will need
  /// to be resolved later.
  void append(SuccessorRange successors) {
    for (Block *successor : successors)
      append(successor);
  }

  /// Append a range of values that will be read as generic PDLValues.
  /// Layout: the value count, then a (kind, memory index) pair per value.
  /// NOTE(review): the count is implicitly narrowed to `ByteCodeField`.
  void appendPDLValueList(OperandRange values) {
    bytecode.push_back(values.size());
    for (Value value : values)
      appendPDLValue(value);
  }

  /// Append a value as a PDLValue: its kind followed by its memory index.
  void appendPDLValue(Value value) {
    appendPDLValueKind(value);
    append(value);
  }

  /// Append the PDLValue::Kind of the given value.
  void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }

  /// Append the PDLValue::Kind of the given type.
  ///
  /// NOTE(review): every `pdl.range` whose element type is not `!pdl.type`,
  /// including a range of `!pdl.operation`, is reported as `ValueRange`.
  /// Confirm that this is intended. The TypeSwitch has no Default case.
  void appendPDLValueKind(Type type) {
    PDLValue::Kind kind =
        TypeSwitch<Type, PDLValue::Kind>(type)
            .Case<pdl::AttributeType>(
                [](Type) { return PDLValue::Kind::Attribute; })
            .Case<pdl::OperationType>(
                [](Type) { return PDLValue::Kind::Operation; })
            .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
              if (rangeTy.getElementType().isa<pdl::TypeType>())
                return PDLValue::Kind::TypeRange;
              return PDLValue::Kind::ValueRange;
            })
            .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
            .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
    bytecode.push_back(static_cast<ByteCodeField>(kind));
  }

  /// Append a value that will be stored in a memory slot and not inline within
  /// the bytecode.
  ///
  /// This covers anything exposing `getAsOpaquePointer()`, such as `Value`,
  /// attributes, types and `OperationName`. The emitted field is the memory
  /// index returned by `Generator::getMemIndex`.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
                   std::is_pointer<T>::value>
  append(T value) {
    bytecode.push_back(generator.getMemIndex(value));
  }

  /// Append a range of values.
  /// Layout: the element count, then each element via the matching `append`
  /// overload.
  /// NOTE(review): the count is implicitly narrowed to `ByteCodeField`.
  template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
  std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
  append(T range) {
    bytecode.push_back(llvm::size(range));
    for (auto it : range)
      append(it);
  }

  /// Append a variadic number of fields to the bytecode. Each argument is
  /// dispatched, left to right, to the single-argument overload matching its
  /// type.
  template <typename FieldTy, typename Field2Ty, typename... FieldTys>
  void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
    append(field);
    append(field2, fields...);
  }

  /// Appends a value as a pointer, stored inline within the bytecode.
  ///
  /// The pointer's bytes are split across
  /// `sizeof(const void *) / sizeof(ByteCodeField)` consecutive fields.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
  appendInline(T value) {
    constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField);
    const void *pointer = value.getAsOpaquePointer();
    ByteCodeField fieldParts[numParts];
    std::memcpy(fieldParts, &pointer, sizeof(const void *));
    bytecode.append(fieldParts, fieldParts + numParts);
  }

  /// Successor references in the bytecode that have yet to be resolved. Each
  /// block is mapped to the offsets of its placeholder addresses.
  DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;

  /// The underlying bytecode buffer.
  SmallVectorImpl<ByteCodeField> &bytecode;

  /// The generator that owns the memory-index tables this writer consults.
  Generator &generator;
};

/// This class represents a live range of PDL Interpreter values, containing
/// information about when values are live within a match/rewrite.
///
/// All ranges created from the same `Allocator` draw their IntervalMap nodes
/// from it, so the allocator must outlive those ranges.
struct ByteCodeLiveRange {
  using Set = llvm::IntervalMap<uint64_t, char, 16>;
  using Allocator = Set::Allocator;

  ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}

  /// Union this live range with the one provided.
  void unionWith(const ByteCodeLiveRange &rhs) {
    for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
         ++it)
      liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0);
  }

  /// Returns true if this range overlaps with the one provided.
  bool overlaps(const ByteCodeLiveRange &rhs) const {
    return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
        .valid();
  }

  /// A map representing the ranges of the match/rewrite that a value is live in
  /// the interpreter.
  ///
  /// We use std::unique_ptr here, because IntervalMap does not provide a
  /// correct copy or move constructor. We can eliminate the pointer once
  /// https://reviews.llvm.org/D113240 lands.
  std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;

  // The three indices below are used in two ways by
  // `Generator::allocateMemoryIndices`:
  //  * for a single def, a set value (0) merely flags which kind of range
  //    storage the value needs;
  //  * for an allocated memory slot, the value is the actual storage index
  //    shared by the range values assigned to that slot.

  /// The operation range storage index for this range.
  Optional<unsigned> opRangeIndex;

  /// The type range storage index for this range.
  Optional<unsigned> typeRangeIndex;

  /// The value range storage index for this range.
  Optional<unsigned> valueRangeIndex;
};
} // namespace

/// Generate the matcher and rewriter bytecode for the given PDL interpreter
/// module. The module must contain the matcher function and the rewriter
/// module.
void Generator::generate(ModuleOp module) {
  auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
      pdl_interp::PDLInterpDialect::getMatcherFunctionName());
  ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
      pdl_interp::PDLInterpDialect::getRewriterModuleName());
  assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");

  // Allocate memory indices for the results of operations within the matcher
  // and rewriters. This must happen before any code is emitted, because
  // uniqued-data indices are offset by the final `maxValueMemoryIndex`.
  allocateMemoryIndices(matcherFunc, rewriterModule);

  // Generate code for the rewriter functions, recording where each one starts.
  // This is done before the matcher, presumably so that `rewriterToAddr` is
  // populated when matcher code (e.g. `pdl_interp.record_match`) refers to it.
  ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
  for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
    rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
    for (Operation &op : rewriterFunc.getOps())
      generate(&op, rewriterByteCodeWriter);
  }
  assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
         "unexpected branches in rewriter function");

  // Generate code for the matcher function.
  ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
  generate(&matcherFunc.getBody(), matcherByteCodeWriter);

  // Resolve successor references in the matcher by overwriting each
  // placeholder address with the recorded start address of its block.
  for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
    ByteCodeAddr addr = blockToAddr[it.first];
    for (unsigned offsetToFix : it.second)
      std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
  }
}

/// Assign memory (and, for ranges, range-storage) indices to every value in
/// the rewriter functions and in the matcher function. The high-water marks
/// are recorded in the `max*` members.
void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
                                      ModuleOp rewriterModule) {
  // Rewriters use a simplistic allocation scheme that simply assigns an index
  // to each result. Numbering restarts at 0 for every rewriter function.
  for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
    ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
    auto processRewriterValue = [&](Value val) {
      valueToMemIndex.try_emplace(val, index++);
      if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) {
        Type elementTy = rangeType.getElementType();
        if (elementTy.isa<pdl::TypeType>())
          valueToRangeIndex.try_emplace(val, typeRangeIndex++);
        else if (elementTy.isa<pdl::ValueType>())
          valueToRangeIndex.try_emplace(val, valueRangeIndex++);
      }
    };

    for (BlockArgument arg : rewriterFunc.getArguments())
      processRewriterValue(arg);
    rewriterFunc.getBody().walk([&](Operation *op) {
      for (Value result : op->getResults())
        processRewriterValue(result);
    });
    if (index > maxValueMemoryIndex)
      maxValueMemoryIndex = index;
    if (typeRangeIndex > maxTypeRangeMemoryIndex)
      maxTypeRangeMemoryIndex = typeRangeIndex;
    if (valueRangeIndex > maxValueRangeMemoryIndex)
      maxValueRangeMemoryIndex = valueRangeIndex;
  }

  // The matcher function uses a more sophisticated numbering that tries to
  // minimize the number of memory indices assigned. This is done by determining
  // a live range of the values within the matcher, then the allocation is just
  // finding the minimal number of overlapping live ranges. This is essentially
  // a simplified form of register allocation where we don't necessarily have a
  // limited number of registers, but we still want to minimize the number used.
  DenseMap<Operation *, unsigned> opToFirstIndex;
  DenseMap<Operation *, unsigned> opToLastIndex;

  // A custom walk that marks the first and the last index of each operation.
  // The entry marks the beginning of the liveness range for this operation,
  // followed by nested operations, followed by the end of the liveness range.
  // As a result, an operation's [first, last] interval encloses the intervals
  // of all operations nested inside it.
  unsigned index = 0;
  llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
    opToFirstIndex.try_emplace(op, index++);
    for (Region &region : op->getRegions())
      for (Block &block : region.getBlocks())
        for (Operation &nested : block)
          walk(&nested);
    opToLastIndex.try_emplace(op, index++);
  };
  walk(matcherFunc);

  // Liveness info for each of the defs within the matcher. The allocator is
  // declared first so that it outlives every range built from it.
  ByteCodeLiveRange::Allocator allocator;
  DenseMap<Value, ByteCodeLiveRange> valueDefRanges;

  // Assign the root operation being matched to slot 0.
  BlockArgument rootOpArg = matcherFunc.getArgument(0);
  valueToMemIndex[rootOpArg] = 0;

  // Walk each of the blocks, computing the intervals over which each value is
  // live.
  Liveness matcherLiveness(matcherFunc);
  matcherFunc->walk([&](Block *block) {
    const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
    assert(info && "expected liveness info for block");

    // Extend `value`'s live range from `firstUseOrDef` to the end operation
    // reported by the liveness analysis for this block. Also flag which kind
    // of range storage, if any, the value needs.
    auto processValue = [&](Value value, Operation *firstUseOrDef) {
      // We don't need to process the root op argument, this value is always
      // assigned to the first memory slot.
      if (value == rootOpArg)
        return;

      // Set indices for the range of this block that the value is used.
      auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
      defRangeIt->second.liveness->insert(
          opToFirstIndex[firstUseOrDef],
          opToLastIndex[info->getEndOperation(value, firstUseOrDef)],
          /*dummyValue*/ 0);

      // Check to see if this value is a range type. If so, flag the kind of
      // range storage it needs; the real index is assigned during allocation.
      if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) {
        Type eleType = rangeTy.getElementType();
        if (eleType.isa<pdl::OperationType>())
          defRangeIt->second.opRangeIndex = 0;
        else if (eleType.isa<pdl::TypeType>())
          defRangeIt->second.typeRangeIndex = 0;
        else if (eleType.isa<pdl::ValueType>())
          defRangeIt->second.valueRangeIndex = 0;
      }
    };

    // Process the live-ins of this block.
    for (Value liveIn : info->in()) {
      // Only process the value if it has been defined in the current region.
      // Other values that span across pdl_interp.foreach will be added higher
      // up. This ensures that we keep them alive for the entire duration of
      // the loop.
      if (liveIn.getParentRegion() == block->getParent())
        processValue(liveIn, &block->front());
    }

    // Process the block arguments for the entry block (those are not live-in).
    if (block->isEntryBlock()) {
      for (Value argument : block->getArguments())
        processValue(argument, &block->front());
    }

    // Process any new defs within this block.
    for (Operation &op : *block)
      for (Value result : op.getResults())
        processValue(result, &op);
  });

  // Greedily allocate memory slots using the computed def live ranges.
  // Entry k of `allocatedIndices` holds the accumulated live range of memory
  // index k + 1, since index 0 is reserved for the root.
  std::vector<ByteCodeLiveRange> allocatedIndices;

  // The number of memory indices currently allocated (and its next value).
  // Recall that the root gets allocated memory index 0.
  ByteCodeField numIndices = 1;

  // The number of memory ranges of various types (and their next values).
  ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;

  for (auto &defIt : valueDefRanges) {
    // `operator[]` default-inserts 0. Because 0 belongs to the root, a value of
    // 0 after the search below means no existing slot was reused.
    ByteCodeField &memIndex = valueToMemIndex[defIt.first];
    ByteCodeLiveRange &defRange = defIt.second;

    // Try to allocate to an existing index (first fit).
    for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) {
      ByteCodeLiveRange &existingRange = existingIndexIt.value();
      if (!defRange.overlaps(existingRange)) {
        existingRange.unionWith(defRange);
        memIndex = existingIndexIt.index() + 1;

        // Range values sharing a slot also share that slot's range storage
        // index of the matching kind, allocating one on first use.
        if (defRange.opRangeIndex) {
          if (!existingRange.opRangeIndex)
            existingRange.opRangeIndex = numOpRanges++;
          valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
        } else if (defRange.typeRangeIndex) {
          if (!existingRange.typeRangeIndex)
            existingRange.typeRangeIndex = numTypeRanges++;
          valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
        } else if (defRange.valueRangeIndex) {
          if (!existingRange.valueRangeIndex)
            existingRange.valueRangeIndex = numValueRanges++;
          valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
        }
        break;
      }
    }

    // If no existing index could be used, add a new one.
    if (memIndex == 0) {
      allocatedIndices.emplace_back(allocator);
      ByteCodeLiveRange &newRange = allocatedIndices.back();
      newRange.unionWith(defRange);

      // Allocate an index for op/type/value ranges.
      if (defRange.opRangeIndex) {
        newRange.opRangeIndex = numOpRanges;
        valueToRangeIndex[defIt.first] = numOpRanges++;
      } else if (defRange.typeRangeIndex) {
        newRange.typeRangeIndex = numTypeRanges;
        valueToRangeIndex[defIt.first] = numTypeRanges++;
      } else if (defRange.valueRangeIndex) {
        newRange.valueRangeIndex = numValueRanges;
        valueToRangeIndex[defIt.first] = numValueRanges++;
      }

      memIndex = allocatedIndices.size();
      ++numIndices;
    }
  }

  // Print the index usage and ensure that we did not run out of index space.
  // NOTE(review): this check only runs after the loop, and only in
  // assert-enabled builds. By then `memIndex` and `numIndices` (both
  // `ByteCodeField`) would already have been truncated on overflow.
  LLVM_DEBUG({
    llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
                 << "(down from initial " << valueDefRanges.size() << ").\n";
  });
  assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
         "Ran out of memory for allocated indices");

  // Update the max number of indices.
  if (numIndices > maxValueMemoryIndex)
    maxValueMemoryIndex = numIndices;
  if (numOpRanges > maxOpRangeMemoryIndex)
    maxOpRangeMemoryIndex = numOpRanges;
  if (numTypeRanges > maxTypeRangeMemoryIndex)
    maxTypeRangeMemoryIndex = numTypeRanges;
  if (numValueRanges > maxValueRangeMemoryIndex)
    maxValueRangeMemoryIndex = numValueRanges;
}

/// Generate code for every block of a matcher region, in reverse post-order.
/// Only blocks reachable from the entry are visited. Each block's start
/// address is recorded relative to `matcherByteCode`, so this is meant for
/// matcher regions only (the matcher body and nested foreach regions).
void Generator::generate(Region *region, ByteCodeWriter &writer) {
  llvm::ReversePostOrderTraversal<Region *> rpot(region);
  for (Block *block : rpot) {
    // Keep track of where this block begins within the matcher function.
    blockToAddr.try_emplace(block, matcherByteCode.size());
    for (Operation &op : *block)
      generate(&op, writer);
  }
}

/// Dispatch code generation to the overload for the concrete `pdl_interp`
/// operation.
void Generator::generate(Operation *op, ByteCodeWriter &writer) {
  // When LLVM_DEBUG output is enabled for this DEBUG_TYPE, each op that
  // produces bytecode is prefixed with its inline-encoded Location.
  // NOTE(review): this makes the bytecode layout depend on the runtime debug
  // flag. The interpreter presumably reads the location under the same
  // condition; confirm.
  LLVM_DEBUG({
    // The following list must contain all the operations that do not
    // produce any bytecode.
    if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
      writer.appendInline(op->getLoc());
  });
  TypeSwitch<Operation *>(op)
      .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
            pdl_interp::AreEqualOp, pdl_interp::BranchOp,
            pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
            pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
            pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
            pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
            pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
            pdl_interp::CreateTypesOp, pdl_interp::EraseOp,
            pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
            pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
            pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
            pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
            pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
            pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
            pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
            pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
            pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
            pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
            pdl_interp::SwitchResultCountOp>(
          [&](auto interpOp) { this->generate(interpOp, writer); })
      .Default([](Operation *) {
        llvm_unreachable("unknown `pdl_interp` operation");
      });
}

// Layout: ApplyConstraint, constraint index, PDL value list, successors.
void Generator::generate(pdl_interp::ApplyConstraintOp op,
                         ByteCodeWriter &writer) {
  assert(constraintToMemIndex.count(op.getName()) &&
         "expected index for constraint function");
  writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
  writer.appendPDLValueList(op.getArgs());
  writer.append(op.getSuccessors());
}
// Layout: ApplyRewrite, rewrite index, PDL value list, result count, then per
// result: [kind (assert builds only)], [range storage index (ranges only)],
// memory index.
void Generator::generate(pdl_interp::ApplyRewriteOp op,
                         ByteCodeWriter &writer) {
  assert(externalRewriterToMemIndex.count(op.getName()) &&
         "expected index for rewrite function");
  writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
  writer.appendPDLValueList(op.getArgs());

  ResultRange results = op.getResults();
  writer.append(ByteCodeField(results.size()));
  for (Value result : results) {
    // In debug mode we also record the expected kind of the result, so that we
    // can provide extra verification of the native rewrite function.
#ifndef NDEBUG
    writer.appendPDLValueKind(result);
#endif

    // Range results also need to append the range storage index.
    if (result.getType().isa<pdl::RangeType>())
      writer.append(getRangeStorageIndex(result));
    writer.append(result);
  }
}
// Ranges use AreRangesEqual, which additionally carries the range kind.
// Everything else uses the generic AreEqual.
void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
  Value lhs = op.getLhs();
  if (lhs.getType().isa<pdl::RangeType>()) {
    writer.append(OpCode::AreRangesEqual);
    writer.appendPDLValueKind(lhs);
    writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
    return;
  }

  writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
}
void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
}
// An attribute check is encoded as a generic equality against the uniqued
// constant attribute.
void Generator::generate(pdl_interp::CheckAttributeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
                op.getSuccessors());
}
void Generator::generate(pdl_interp::CheckOperandCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
                static_cast<ByteCodeField>(op.getCompareAtLeast()),
                op.getSuccessors());
}
void Generator::generate(pdl_interp::CheckOperationNameOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CheckOperationName, op.getInputOp(),
                OperationName(op.getName(), ctx), op.getSuccessors());
}
void Generator::generate(pdl_interp::CheckResultCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
                static_cast<ByteCodeField>(op.getCompareAtLeast()),
                op.getSuccessors());
}
// A type check is encoded as a generic equality against the uniqued type.
void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
                op.getSuccessors());
}
void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
                op.getSuccessors());
}
// The encoded level is that of the innermost enclosing foreach, which is the
// value the ForEach generator wrote before incrementing `curLoopLevel`.
void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
  assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
  writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
}
void Generator::generate(pdl_interp::CreateAttributeOp op,
                         ByteCodeWriter &writer) {
  // Simply repoint the memory index of the result to the constant. No bytecode
  // is emitted for this op.
  getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
}
// Layout: CreateOperation, result memory index, operation name, operand list,
// attribute count, (name, value) pairs, then either kInferTypesMarker or the
// explicit result-type list.
void Generator::generate(pdl_interp::CreateOperationOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CreateOperation, op.getResultOp(),
                OperationName(op.getName(), ctx));
  writer.appendPDLValueList(op.getInputOperands());

  // Add the attributes.
  OperandRange attributes = op.getInputAttributes();
  writer.append(static_cast<ByteCodeField>(attributes.size()));
  for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
    writer.append(std::get<0>(it), std::get<1>(it));

  // Add the result types. If the operation has inferred results, we use a
  // marker "size" value. Otherwise, we add the list of explicit result types.
  if (op.getInferredResultTypes())
    writer.append(kInferTypesMarker);
  else
    writer.appendPDLValueList(op.getInputResultTypes());
}
void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
  // Simply repoint the memory index of the result to the constant. No bytecode
  // is emitted for this op.
  getMemIndex(op.getResult()) = getMemIndex(op.getValue());
}
void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::CreateTypes, op.getResult(),
                getRangeStorageIndex(op.getResult()), op.getValue());
}
void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::EraseOp, op.getInputOp());
}
// The extract opcode is selected from the result's element type.
void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
  OpCode opCode =
      TypeSwitch<Type, OpCode>(op.getResult().getType())
          .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
          .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
          .Case([](pdl::TypeType) { return OpCode::ExtractType; })
          .Default([](Type) -> OpCode {
            llvm_unreachable("unsupported element type");
          });
  writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
}
void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::Finalize);
}
// Layout: ForEach, range storage index, loop variable memory index, loop
// variable kind, loop level, successor. The loop body is emitted inline right
// after, at the next nesting level.
void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
  BlockArgument arg = op.getLoopVariable();
  writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
  writer.appendPDLValueKind(arg.getType());
  writer.append(curLoopLevel, op.getSuccessor());
  ++curLoopLevel;
  if (curLoopLevel > maxLoopLevel)
    maxLoopLevel = curLoopLevel;
  generate(&op.getRegion(), writer);
  --curLoopLevel;
}
void Generator::generate(pdl_interp::GetAttributeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
                op.getNameAttr());
900 } 901 void Generator::generate(pdl_interp::GetAttributeTypeOp op, 902 ByteCodeWriter &writer) { 903 writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue()); 904 } 905 void Generator::generate(pdl_interp::GetDefiningOpOp op, 906 ByteCodeWriter &writer) { 907 writer.append(OpCode::GetDefiningOp, op.getInputOp()); 908 writer.appendPDLValue(op.getValue()); 909 } 910 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) { 911 uint32_t index = op.getIndex(); 912 if (index < 4) 913 writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index)); 914 else 915 writer.append(OpCode::GetOperandN, index); 916 writer.append(op.getInputOp(), op.getValue()); 917 } 918 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) { 919 Value result = op.getValue(); 920 Optional<uint32_t> index = op.getIndex(); 921 writer.append(OpCode::GetOperands, 922 index.value_or(std::numeric_limits<uint32_t>::max()), 923 op.getInputOp()); 924 if (result.getType().isa<pdl::RangeType>()) 925 writer.append(getRangeStorageIndex(result)); 926 else 927 writer.append(std::numeric_limits<ByteCodeField>::max()); 928 writer.append(result); 929 } 930 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) { 931 uint32_t index = op.getIndex(); 932 if (index < 4) 933 writer.append(static_cast<OpCode>(OpCode::GetResult0 + index)); 934 else 935 writer.append(OpCode::GetResultN, index); 936 writer.append(op.getInputOp(), op.getValue()); 937 } 938 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) { 939 Value result = op.getValue(); 940 Optional<uint32_t> index = op.getIndex(); 941 writer.append(OpCode::GetResults, 942 index.value_or(std::numeric_limits<uint32_t>::max()), 943 op.getInputOp()); 944 if (result.getType().isa<pdl::RangeType>()) 945 writer.append(getRangeStorageIndex(result)); 946 else 947 writer.append(std::numeric_limits<ByteCodeField>::max()); 948 writer.append(result); 949 } 950 void 
Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) { 951 Value operations = op.getOperations(); 952 ByteCodeField rangeIndex = getRangeStorageIndex(operations); 953 writer.append(OpCode::GetUsers, operations, rangeIndex); 954 writer.appendPDLValue(op.getValue()); 955 } 956 void Generator::generate(pdl_interp::GetValueTypeOp op, 957 ByteCodeWriter &writer) { 958 if (op.getType().isa<pdl::RangeType>()) { 959 Value result = op.getResult(); 960 writer.append(OpCode::GetValueRangeTypes, result, 961 getRangeStorageIndex(result), op.getValue()); 962 } else { 963 writer.append(OpCode::GetValueType, op.getResult(), op.getValue()); 964 } 965 } 966 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) { 967 writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors()); 968 } 969 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) { 970 ByteCodeField patternIndex = patterns.size(); 971 patterns.emplace_back(PDLByteCodePattern::create( 972 op, rewriterToAddr[op.getRewriter().getLeafReference().getValue()])); 973 writer.append(OpCode::RecordMatch, patternIndex, 974 SuccessorRange(op.getOperation()), op.getMatchedOps()); 975 writer.appendPDLValueList(op.getInputs()); 976 } 977 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) { 978 writer.append(OpCode::ReplaceOp, op.getInputOp()); 979 writer.appendPDLValueList(op.getReplValues()); 980 } 981 void Generator::generate(pdl_interp::SwitchAttributeOp op, 982 ByteCodeWriter &writer) { 983 writer.append(OpCode::SwitchAttribute, op.getAttribute(), 984 op.getCaseValuesAttr(), op.getSuccessors()); 985 } 986 void Generator::generate(pdl_interp::SwitchOperandCountOp op, 987 ByteCodeWriter &writer) { 988 writer.append(OpCode::SwitchOperandCount, op.getInputOp(), 989 op.getCaseValuesAttr(), op.getSuccessors()); 990 } 991 void Generator::generate(pdl_interp::SwitchOperationNameOp op, 992 ByteCodeWriter &writer) { 993 auto cases = 
llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) { 994 return OperationName(attr.cast<StringAttr>().getValue(), ctx); 995 }); 996 writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases, 997 op.getSuccessors()); 998 } 999 void Generator::generate(pdl_interp::SwitchResultCountOp op, 1000 ByteCodeWriter &writer) { 1001 writer.append(OpCode::SwitchResultCount, op.getInputOp(), 1002 op.getCaseValuesAttr(), op.getSuccessors()); 1003 } 1004 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) { 1005 writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(), 1006 op.getSuccessors()); 1007 } 1008 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) { 1009 writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(), 1010 op.getSuccessors()); 1011 } 1012 1013 //===----------------------------------------------------------------------===// 1014 // PDLByteCode 1015 //===----------------------------------------------------------------------===// 1016 1017 PDLByteCode::PDLByteCode(ModuleOp module, 1018 llvm::StringMap<PDLConstraintFunction> constraintFns, 1019 llvm::StringMap<PDLRewriteFunction> rewriteFns) { 1020 Generator generator(module.getContext(), uniquedData, matcherByteCode, 1021 rewriterByteCode, patterns, maxValueMemoryIndex, 1022 maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount, 1023 maxLoopLevel, constraintFns, rewriteFns); 1024 generator.generate(module); 1025 1026 // Initialize the external functions. 1027 for (auto &it : constraintFns) 1028 constraintFunctions.push_back(std::move(it.second)); 1029 for (auto &it : rewriteFns) 1030 rewriteFunctions.push_back(std::move(it.second)); 1031 } 1032 1033 /// Initialize the given state such that it can be used to execute the current 1034 /// bytecode. 
1035 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { 1036 state.memory.resize(maxValueMemoryIndex, nullptr); 1037 state.opRangeMemory.resize(maxOpRangeCount); 1038 state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange()); 1039 state.valueRangeMemory.resize(maxValueRangeCount, ValueRange()); 1040 state.loopIndex.resize(maxLoopLevel, 0); 1041 state.currentPatternBenefits.reserve(patterns.size()); 1042 for (const PDLByteCodePattern &pattern : patterns) 1043 state.currentPatternBenefits.push_back(pattern.getBenefit()); 1044 } 1045 1046 //===----------------------------------------------------------------------===// 1047 // ByteCode Execution 1048 1049 namespace { 1050 /// This class provides support for executing a bytecode stream. 1051 class ByteCodeExecutor { 1052 public: 1053 ByteCodeExecutor( 1054 const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory, 1055 MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory, 1056 MutableArrayRef<TypeRange> typeRangeMemory, 1057 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory, 1058 MutableArrayRef<ValueRange> valueRangeMemory, 1059 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory, 1060 MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory, 1061 ArrayRef<ByteCodeField> code, 1062 ArrayRef<PatternBenefit> currentPatternBenefits, 1063 ArrayRef<PDLByteCodePattern> patterns, 1064 ArrayRef<PDLConstraintFunction> constraintFunctions, 1065 ArrayRef<PDLRewriteFunction> rewriteFunctions) 1066 : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory), 1067 typeRangeMemory(typeRangeMemory), 1068 allocatedTypeRangeMemory(allocatedTypeRangeMemory), 1069 valueRangeMemory(valueRangeMemory), 1070 allocatedValueRangeMemory(allocatedValueRangeMemory), 1071 loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code), 1072 currentPatternBenefits(currentPatternBenefits), patterns(patterns), 1073 
constraintFunctions(constraintFunctions), 1074 rewriteFunctions(rewriteFunctions) {} 1075 1076 /// Start executing the code at the current bytecode index. `matches` is an 1077 /// optional field provided when this function is executed in a matching 1078 /// context. 1079 void execute(PatternRewriter &rewriter, 1080 SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr, 1081 Optional<Location> mainRewriteLoc = {}); 1082 1083 private: 1084 /// Internal implementation of executing each of the bytecode commands. 1085 void executeApplyConstraint(PatternRewriter &rewriter); 1086 void executeApplyRewrite(PatternRewriter &rewriter); 1087 void executeAreEqual(); 1088 void executeAreRangesEqual(); 1089 void executeBranch(); 1090 void executeCheckOperandCount(); 1091 void executeCheckOperationName(); 1092 void executeCheckResultCount(); 1093 void executeCheckTypes(); 1094 void executeContinue(); 1095 void executeCreateOperation(PatternRewriter &rewriter, 1096 Location mainRewriteLoc); 1097 void executeCreateTypes(); 1098 void executeEraseOp(PatternRewriter &rewriter); 1099 template <typename T, typename Range, PDLValue::Kind kind> 1100 void executeExtract(); 1101 void executeFinalize(); 1102 void executeForEach(); 1103 void executeGetAttribute(); 1104 void executeGetAttributeType(); 1105 void executeGetDefiningOp(); 1106 void executeGetOperand(unsigned index); 1107 void executeGetOperands(); 1108 void executeGetResult(unsigned index); 1109 void executeGetResults(); 1110 void executeGetUsers(); 1111 void executeGetValueType(); 1112 void executeGetValueRangeTypes(); 1113 void executeIsNotNull(); 1114 void executeRecordMatch(PatternRewriter &rewriter, 1115 SmallVectorImpl<PDLByteCode::MatchResult> &matches); 1116 void executeReplaceOp(PatternRewriter &rewriter); 1117 void executeSwitchAttribute(); 1118 void executeSwitchOperandCount(); 1119 void executeSwitchOperationName(); 1120 void executeSwitchResultCount(); 1121 void executeSwitchType(); 1122 void 
executeSwitchTypes(); 1123 1124 /// Pushes a code iterator to the stack. 1125 void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); } 1126 1127 /// Pops a code iterator from the stack, returning true on success. 1128 void popCodeIt() { 1129 assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack"); 1130 curCodeIt = resumeCodeIt.back(); 1131 resumeCodeIt.pop_back(); 1132 } 1133 1134 /// Return the bytecode iterator at the start of the current op code. 1135 const ByteCodeField *getPrevCodeIt() const { 1136 LLVM_DEBUG({ 1137 // Account for the op code and the Location stored inline. 1138 return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField); 1139 }); 1140 1141 // Account for the op code only. 1142 return curCodeIt - 1; 1143 } 1144 1145 /// Read a value from the bytecode buffer, optionally skipping a certain 1146 /// number of prefix values. These methods always update the buffer to point 1147 /// to the next field after the read data. 1148 template <typename T = ByteCodeField> 1149 T read(size_t skipN = 0) { 1150 curCodeIt += skipN; 1151 return readImpl<T>(); 1152 } 1153 ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); } 1154 1155 /// Read a list of values from the bytecode buffer. 1156 template <typename ValueT, typename T> 1157 void readList(SmallVectorImpl<T> &list) { 1158 list.clear(); 1159 for (unsigned i = 0, e = read(); i != e; ++i) 1160 list.push_back(read<ValueT>()); 1161 } 1162 1163 /// Read a list of values from the bytecode buffer. The values may be encoded 1164 /// as either Value or ValueRange elements. 
1165 void readValueList(SmallVectorImpl<Value> &list) { 1166 for (unsigned i = 0, e = read(); i != e; ++i) { 1167 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1168 list.push_back(read<Value>()); 1169 } else { 1170 ValueRange *values = read<ValueRange *>(); 1171 list.append(values->begin(), values->end()); 1172 } 1173 } 1174 } 1175 1176 /// Read a value stored inline as a pointer. 1177 template <typename T> 1178 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T> 1179 readInline() { 1180 const void *pointer; 1181 std::memcpy(&pointer, curCodeIt, sizeof(const void *)); 1182 curCodeIt += sizeof(const void *) / sizeof(ByteCodeField); 1183 return T::getFromOpaquePointer(pointer); 1184 } 1185 1186 /// Jump to a specific successor based on a predicate value. 1187 void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); } 1188 /// Jump to a specific successor based on a destination index. 1189 void selectJump(size_t destIndex) { 1190 curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)]; 1191 } 1192 1193 /// Handle a switch operation with the provided value and cases. 1194 template <typename T, typename RangeT, typename Comparator = std::equal_to<T>> 1195 void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) { 1196 LLVM_DEBUG({ 1197 llvm::dbgs() << " * Value: " << value << "\n" 1198 << " * Cases: "; 1199 llvm::interleaveComma(cases, llvm::dbgs()); 1200 llvm::dbgs() << "\n"; 1201 }); 1202 1203 // Check to see if the attribute value is within the case list. Jump to 1204 // the correct successor index based on the result. 1205 for (auto it = cases.begin(), e = cases.end(); it != e; ++it) 1206 if (cmp(*it, value)) 1207 return selectJump(size_t((it - cases.begin()) + 1)); 1208 selectJump(size_t(0)); 1209 } 1210 1211 /// Store a pointer to memory. 1212 void storeToMemory(unsigned index, const void *value) { 1213 memory[index] = value; 1214 } 1215 1216 /// Store a value to memory as an opaque pointer. 
1217 template <typename T> 1218 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value> 1219 storeToMemory(unsigned index, T value) { 1220 memory[index] = value.getAsOpaquePointer(); 1221 } 1222 1223 /// Internal implementation of reading various data types from the bytecode 1224 /// stream. 1225 template <typename T> 1226 const void *readFromMemory() { 1227 size_t index = *curCodeIt++; 1228 1229 // If this type is an SSA value, it can only be stored in non-const memory. 1230 if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *, 1231 Value>::value || 1232 index < memory.size()) 1233 return memory[index]; 1234 1235 // Otherwise, if this index is not inbounds it is uniqued. 1236 return uniquedMemory[index - memory.size()]; 1237 } 1238 template <typename T> 1239 std::enable_if_t<std::is_pointer<T>::value, T> readImpl() { 1240 return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>())); 1241 } 1242 template <typename T> 1243 std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value, 1244 T> 1245 readImpl() { 1246 return T(T::getFromOpaquePointer(readFromMemory<T>())); 1247 } 1248 template <typename T> 1249 std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() { 1250 switch (read<PDLValue::Kind>()) { 1251 case PDLValue::Kind::Attribute: 1252 return read<Attribute>(); 1253 case PDLValue::Kind::Operation: 1254 return read<Operation *>(); 1255 case PDLValue::Kind::Type: 1256 return read<Type>(); 1257 case PDLValue::Kind::Value: 1258 return read<Value>(); 1259 case PDLValue::Kind::TypeRange: 1260 return read<TypeRange *>(); 1261 case PDLValue::Kind::ValueRange: 1262 return read<ValueRange *>(); 1263 } 1264 llvm_unreachable("unhandled PDLValue::Kind"); 1265 } 1266 template <typename T> 1267 std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() { 1268 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 1269 "unexpected ByteCode address size"); 1270 ByteCodeAddr result; 1271 
std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr)); 1272 curCodeIt += 2; 1273 return result; 1274 } 1275 template <typename T> 1276 std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() { 1277 return *curCodeIt++; 1278 } 1279 template <typename T> 1280 std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() { 1281 return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>()); 1282 } 1283 1284 /// The underlying bytecode buffer. 1285 const ByteCodeField *curCodeIt; 1286 1287 /// The stack of bytecode positions at which to resume operation. 1288 SmallVector<const ByteCodeField *> resumeCodeIt; 1289 1290 /// The current execution memory. 1291 MutableArrayRef<const void *> memory; 1292 MutableArrayRef<OwningOpRange> opRangeMemory; 1293 MutableArrayRef<TypeRange> typeRangeMemory; 1294 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory; 1295 MutableArrayRef<ValueRange> valueRangeMemory; 1296 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory; 1297 1298 /// The current loop indices. 1299 MutableArrayRef<unsigned> loopIndex; 1300 1301 /// References to ByteCode data necessary for execution. 1302 ArrayRef<const void *> uniquedMemory; 1303 ArrayRef<ByteCodeField> code; 1304 ArrayRef<PatternBenefit> currentPatternBenefits; 1305 ArrayRef<PDLByteCodePattern> patterns; 1306 ArrayRef<PDLConstraintFunction> constraintFunctions; 1307 ArrayRef<PDLRewriteFunction> rewriteFunctions; 1308 }; 1309 1310 /// This class is an instantiation of the PDLResultList that provides access to 1311 /// the returned results. This API is not on `PDLResultList` to avoid 1312 /// overexposing access to information specific solely to the ByteCode. 1313 class ByteCodeRewriteResultList : public PDLResultList { 1314 public: 1315 ByteCodeRewriteResultList(unsigned maxNumResults) 1316 : PDLResultList(maxNumResults) {} 1317 1318 /// Return the list of PDL results. 
1319 MutableArrayRef<PDLValue> getResults() { return results; } 1320 1321 /// Return the type ranges allocated by this list. 1322 MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() { 1323 return allocatedTypeRanges; 1324 } 1325 1326 /// Return the value ranges allocated by this list. 1327 MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() { 1328 return allocatedValueRanges; 1329 } 1330 }; 1331 } // namespace 1332 1333 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { 1334 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); 1335 const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; 1336 SmallVector<PDLValue, 16> args; 1337 readList<PDLValue>(args); 1338 1339 LLVM_DEBUG({ 1340 llvm::dbgs() << " * Arguments: "; 1341 llvm::interleaveComma(args, llvm::dbgs()); 1342 }); 1343 1344 // Invoke the constraint and jump to the proper destination. 1345 selectJump(succeeded(constraintFn(rewriter, args))); 1346 } 1347 1348 void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { 1349 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); 1350 const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; 1351 SmallVector<PDLValue, 16> args; 1352 readList<PDLValue>(args); 1353 1354 LLVM_DEBUG({ 1355 llvm::dbgs() << " * Arguments: "; 1356 llvm::interleaveComma(args, llvm::dbgs()); 1357 }); 1358 1359 // Execute the rewrite function. 1360 ByteCodeField numResults = read(); 1361 ByteCodeRewriteResultList results(numResults); 1362 rewriteFn(rewriter, results, args); 1363 1364 assert(results.getResults().size() == numResults && 1365 "native PDL rewrite function returned unexpected number of results"); 1366 1367 // Store the results in the bytecode memory. 1368 for (PDLValue &result : results.getResults()) { 1369 LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); 1370 1371 // In debug mode we also verify the expected kind of the result. 
1372 #ifndef NDEBUG 1373 assert(result.getKind() == read<PDLValue::Kind>() && 1374 "native PDL rewrite function returned an unexpected type of result"); 1375 #endif 1376 1377 // If the result is a range, we need to copy it over to the bytecodes 1378 // range memory. 1379 if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) { 1380 unsigned rangeIndex = read(); 1381 typeRangeMemory[rangeIndex] = *typeRange; 1382 memory[read()] = &typeRangeMemory[rangeIndex]; 1383 } else if (Optional<ValueRange> valueRange = 1384 result.dyn_cast<ValueRange>()) { 1385 unsigned rangeIndex = read(); 1386 valueRangeMemory[rangeIndex] = *valueRange; 1387 memory[read()] = &valueRangeMemory[rangeIndex]; 1388 } else { 1389 memory[read()] = result.getAsOpaquePointer(); 1390 } 1391 } 1392 1393 // Copy over any underlying storage allocated for result ranges. 1394 for (auto &it : results.getAllocatedTypeRanges()) 1395 allocatedTypeRangeMemory.push_back(std::move(it)); 1396 for (auto &it : results.getAllocatedValueRanges()) 1397 allocatedValueRangeMemory.push_back(std::move(it)); 1398 } 1399 1400 void ByteCodeExecutor::executeAreEqual() { 1401 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 1402 const void *lhs = read<const void *>(); 1403 const void *rhs = read<const void *>(); 1404 1405 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n"); 1406 selectJump(lhs == rhs); 1407 } 1408 1409 void ByteCodeExecutor::executeAreRangesEqual() { 1410 LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n"); 1411 PDLValue::Kind valueKind = read<PDLValue::Kind>(); 1412 const void *lhs = read<const void *>(); 1413 const void *rhs = read<const void *>(); 1414 1415 switch (valueKind) { 1416 case PDLValue::Kind::TypeRange: { 1417 const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs); 1418 const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs); 1419 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1420 selectJump(*lhsRange == 
*rhsRange); 1421 break; 1422 } 1423 case PDLValue::Kind::ValueRange: { 1424 const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs); 1425 const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs); 1426 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1427 selectJump(*lhsRange == *rhsRange); 1428 break; 1429 } 1430 default: 1431 llvm_unreachable("unexpected `AreRangesEqual` value kind"); 1432 } 1433 } 1434 1435 void ByteCodeExecutor::executeBranch() { 1436 LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n"); 1437 curCodeIt = &code[read<ByteCodeAddr>()]; 1438 } 1439 1440 void ByteCodeExecutor::executeCheckOperandCount() { 1441 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); 1442 Operation *op = read<Operation *>(); 1443 uint32_t expectedCount = read<uint32_t>(); 1444 bool compareAtLeast = read(); 1445 1446 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" 1447 << " * Expected: " << expectedCount << "\n" 1448 << " * Comparator: " 1449 << (compareAtLeast ? 
">=" : "==") << "\n"); 1450 if (compareAtLeast) 1451 selectJump(op->getNumOperands() >= expectedCount); 1452 else 1453 selectJump(op->getNumOperands() == expectedCount); 1454 } 1455 1456 void ByteCodeExecutor::executeCheckOperationName() { 1457 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n"); 1458 Operation *op = read<Operation *>(); 1459 OperationName expectedName = read<OperationName>(); 1460 1461 LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n" 1462 << " * Expected: \"" << expectedName << "\"\n"); 1463 selectJump(op->getName() == expectedName); 1464 } 1465 1466 void ByteCodeExecutor::executeCheckResultCount() { 1467 LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); 1468 Operation *op = read<Operation *>(); 1469 uint32_t expectedCount = read<uint32_t>(); 1470 bool compareAtLeast = read(); 1471 1472 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" 1473 << " * Expected: " << expectedCount << "\n" 1474 << " * Comparator: " 1475 << (compareAtLeast ? 
">=" : "==") << "\n"); 1476 if (compareAtLeast) 1477 selectJump(op->getNumResults() >= expectedCount); 1478 else 1479 selectJump(op->getNumResults() == expectedCount); 1480 } 1481 1482 void ByteCodeExecutor::executeCheckTypes() { 1483 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 1484 TypeRange *lhs = read<TypeRange *>(); 1485 Attribute rhs = read<Attribute>(); 1486 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1487 1488 selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>()); 1489 } 1490 1491 void ByteCodeExecutor::executeContinue() { 1492 ByteCodeField level = read(); 1493 LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n" 1494 << " * Level: " << level << "\n"); 1495 ++loopIndex[level]; 1496 popCodeIt(); 1497 } 1498 1499 void ByteCodeExecutor::executeCreateTypes() { 1500 LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n"); 1501 unsigned memIndex = read(); 1502 unsigned rangeIndex = read(); 1503 ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>(); 1504 1505 LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n"); 1506 1507 // Allocate a buffer for this type range. 1508 llvm::OwningArrayRef<Type> storage(typesAttr.size()); 1509 llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin()); 1510 allocatedTypeRangeMemory.emplace_back(std::move(storage)); 1511 1512 // Assign this to the range slot and use the range as the value for the 1513 // memory index. 
1514 typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back(); 1515 memory[memIndex] = &typeRangeMemory[rangeIndex]; 1516 } 1517 1518 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, 1519 Location mainRewriteLoc) { 1520 LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); 1521 1522 unsigned memIndex = read(); 1523 OperationState state(mainRewriteLoc, read<OperationName>()); 1524 readValueList(state.operands); 1525 for (unsigned i = 0, e = read(); i != e; ++i) { 1526 StringAttr name = read<StringAttr>(); 1527 if (Attribute attr = read<Attribute>()) 1528 state.addAttribute(name, attr); 1529 } 1530 1531 // Read in the result types. If the "size" is the sentinel value, this 1532 // indicates that the result types should be inferred. 1533 unsigned numResults = read(); 1534 if (numResults == kInferTypesMarker) { 1535 InferTypeOpInterface::Concept *inferInterface = 1536 state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>(); 1537 assert(inferInterface && 1538 "expected operation to provide InferTypeOpInterface"); 1539 1540 // TODO: Handle failure. 1541 if (failed(inferInterface->inferReturnTypes( 1542 state.getContext(), state.location, state.operands, 1543 state.attributes.getDictionary(state.getContext()), state.regions, 1544 state.types))) 1545 return; 1546 } else { 1547 // Otherwise, this is a fixed number of results. 
1548 for (unsigned i = 0; i != numResults; ++i) { 1549 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) { 1550 state.types.push_back(read<Type>()); 1551 } else { 1552 TypeRange *resultTypes = read<TypeRange *>(); 1553 state.types.append(resultTypes->begin(), resultTypes->end()); 1554 } 1555 } 1556 } 1557 1558 Operation *resultOp = rewriter.create(state); 1559 memory[memIndex] = resultOp; 1560 1561 LLVM_DEBUG({ 1562 llvm::dbgs() << " * Attributes: " 1563 << state.attributes.getDictionary(state.getContext()) 1564 << "\n * Operands: "; 1565 llvm::interleaveComma(state.operands, llvm::dbgs()); 1566 llvm::dbgs() << "\n * Result Types: "; 1567 llvm::interleaveComma(state.types, llvm::dbgs()); 1568 llvm::dbgs() << "\n * Result: " << *resultOp << "\n"; 1569 }); 1570 } 1571 1572 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) { 1573 LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n"); 1574 Operation *op = read<Operation *>(); 1575 1576 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1577 rewriter.eraseOp(op); 1578 } 1579 1580 template <typename T, typename Range, PDLValue::Kind kind> 1581 void ByteCodeExecutor::executeExtract() { 1582 LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n"); 1583 Range *range = read<Range *>(); 1584 unsigned index = read<uint32_t>(); 1585 unsigned memIndex = read(); 1586 1587 if (!range) { 1588 memory[memIndex] = nullptr; 1589 return; 1590 } 1591 1592 T result = index < range->size() ? 
(*range)[index] : T(); 1593 LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n" 1594 << " * Index: " << index << "\n" 1595 << " * Result: " << result << "\n"); 1596 storeToMemory(memIndex, result); 1597 } 1598 1599 void ByteCodeExecutor::executeFinalize() { 1600 LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n"); 1601 } 1602 1603 void ByteCodeExecutor::executeForEach() { 1604 LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n"); 1605 const ByteCodeField *prevCodeIt = getPrevCodeIt(); 1606 unsigned rangeIndex = read(); 1607 unsigned memIndex = read(); 1608 const void *value = nullptr; 1609 1610 switch (read<PDLValue::Kind>()) { 1611 case PDLValue::Kind::Operation: { 1612 unsigned &index = loopIndex[read()]; 1613 ArrayRef<Operation *> array = opRangeMemory[rangeIndex]; 1614 assert(index <= array.size() && "iterated past the end"); 1615 if (index < array.size()) { 1616 LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n"); 1617 value = array[index]; 1618 break; 1619 } 1620 1621 LLVM_DEBUG(llvm::dbgs() << " * Done\n"); 1622 index = 0; 1623 selectJump(size_t(0)); 1624 return; 1625 } 1626 default: 1627 llvm_unreachable("unexpected `ForEach` value kind"); 1628 } 1629 1630 // Store the iterate value and the stack address. 1631 memory[memIndex] = value; 1632 pushCodeIt(prevCodeIt); 1633 1634 // Skip over the successor (we will enter the body of the loop). 
1635 read<ByteCodeAddr>(); 1636 } 1637 1638 void ByteCodeExecutor::executeGetAttribute() { 1639 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); 1640 unsigned memIndex = read(); 1641 Operation *op = read<Operation *>(); 1642 StringAttr attrName = read<StringAttr>(); 1643 Attribute attr = op->getAttr(attrName); 1644 1645 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1646 << " * Attribute: " << attrName << "\n" 1647 << " * Result: " << attr << "\n"); 1648 memory[memIndex] = attr.getAsOpaquePointer(); 1649 } 1650 1651 void ByteCodeExecutor::executeGetAttributeType() { 1652 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n"); 1653 unsigned memIndex = read(); 1654 Attribute attr = read<Attribute>(); 1655 Type type = attr ? attr.getType() : Type(); 1656 1657 LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" 1658 << " * Result: " << type << "\n"); 1659 memory[memIndex] = type.getAsOpaquePointer(); 1660 } 1661 1662 void ByteCodeExecutor::executeGetDefiningOp() { 1663 LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); 1664 unsigned memIndex = read(); 1665 Operation *op = nullptr; 1666 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1667 Value value = read<Value>(); 1668 if (value) 1669 op = value.getDefiningOp(); 1670 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1671 } else { 1672 ValueRange *values = read<ValueRange *>(); 1673 if (values && !values->empty()) { 1674 op = values->front().getDefiningOp(); 1675 } 1676 LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n"); 1677 } 1678 1679 LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n"); 1680 memory[memIndex] = op; 1681 } 1682 1683 void ByteCodeExecutor::executeGetOperand(unsigned index) { 1684 Operation *op = read<Operation *>(); 1685 unsigned memIndex = read(); 1686 Value operand = 1687 index < op->getNumOperands() ? 
op->getOperand(index) : Value(); 1688 1689 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1690 << " * Index: " << index << "\n" 1691 << " * Result: " << operand << "\n"); 1692 memory[memIndex] = operand.getAsOpaquePointer(); 1693 } 1694 1695 /// This function is the internal implementation of `GetResults` and 1696 /// `GetOperands` that provides support for extracting a value range from the 1697 /// given operation. 1698 template <template <typename> class AttrSizedSegmentsT, typename RangeT> 1699 static void * 1700 executeGetOperandsResults(RangeT values, Operation *op, unsigned index, 1701 ByteCodeField rangeIndex, StringRef attrSizedSegments, 1702 MutableArrayRef<ValueRange> valueRangeMemory) { 1703 // Check for the sentinel index that signals that all values should be 1704 // returned. 1705 if (index == std::numeric_limits<uint32_t>::max()) { 1706 LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n"); 1707 // `values` is already the full value range. 1708 1709 // Otherwise, check to see if this operation uses AttrSizedSegments. 1710 } else if (op->hasTrait<AttrSizedSegmentsT>()) { 1711 LLVM_DEBUG(llvm::dbgs() 1712 << " * Extracting values from `" << attrSizedSegments << "`\n"); 1713 1714 auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments); 1715 if (!segmentAttr || segmentAttr.getNumElements() <= index) 1716 return nullptr; 1717 1718 auto segments = segmentAttr.getValues<int32_t>(); 1719 unsigned startIndex = 1720 std::accumulate(segments.begin(), segments.begin() + index, 0); 1721 values = values.slice(startIndex, *std::next(segments.begin(), index)); 1722 1723 LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", " 1724 << *std::next(segments.begin(), index) << "]\n"); 1725 1726 // Otherwise, assume this is the last operand group of the operation. 
1727 // FIXME: We currently don't support operations with 1728 // SameVariadicOperandSize/SameVariadicResultSize here given that we don't 1729 // have a way to detect it's presence. 1730 } else if (values.size() >= index) { 1731 LLVM_DEBUG(llvm::dbgs() 1732 << " * Treating values as trailing variadic range\n"); 1733 values = values.drop_front(index); 1734 1735 // If we couldn't detect a way to compute the values, bail out. 1736 } else { 1737 return nullptr; 1738 } 1739 1740 // If the range index is valid, we are returning a range. 1741 if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) { 1742 valueRangeMemory[rangeIndex] = values; 1743 return &valueRangeMemory[rangeIndex]; 1744 } 1745 1746 // If a range index wasn't provided, the range is required to be non-variadic. 1747 return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer(); 1748 } 1749 1750 void ByteCodeExecutor::executeGetOperands() { 1751 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n"); 1752 unsigned index = read<uint32_t>(); 1753 Operation *op = read<Operation *>(); 1754 ByteCodeField rangeIndex = read(); 1755 1756 void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>( 1757 op->getOperands(), op, index, rangeIndex, "operand_segment_sizes", 1758 valueRangeMemory); 1759 if (!result) 1760 LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n"); 1761 memory[read()] = result; 1762 } 1763 1764 void ByteCodeExecutor::executeGetResult(unsigned index) { 1765 Operation *op = read<Operation *>(); 1766 unsigned memIndex = read(); 1767 OpResult result = 1768 index < op->getNumResults() ? 
op->getResult(index) : OpResult(); 1769 1770 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1771 << " * Index: " << index << "\n" 1772 << " * Result: " << result << "\n"); 1773 memory[memIndex] = result.getAsOpaquePointer(); 1774 } 1775 1776 void ByteCodeExecutor::executeGetResults() { 1777 LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n"); 1778 unsigned index = read<uint32_t>(); 1779 Operation *op = read<Operation *>(); 1780 ByteCodeField rangeIndex = read(); 1781 1782 void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>( 1783 op->getResults(), op, index, rangeIndex, "result_segment_sizes", 1784 valueRangeMemory); 1785 if (!result) 1786 LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n"); 1787 memory[read()] = result; 1788 } 1789 1790 void ByteCodeExecutor::executeGetUsers() { 1791 LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n"); 1792 unsigned memIndex = read(); 1793 unsigned rangeIndex = read(); 1794 OwningOpRange &range = opRangeMemory[rangeIndex]; 1795 memory[memIndex] = ⦥ 1796 1797 range = OwningOpRange(); 1798 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1799 // Read the value. 1800 Value value = read<Value>(); 1801 if (!value) 1802 return; 1803 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1804 1805 // Extract the users of a single value. 1806 range = OwningOpRange(std::distance(value.user_begin(), value.user_end())); 1807 llvm::copy(value.getUsers(), range.begin()); 1808 } else { 1809 // Read a range of values. 1810 ValueRange *values = read<ValueRange *>(); 1811 if (!values) 1812 return; 1813 LLVM_DEBUG({ 1814 llvm::dbgs() << " * Values (" << values->size() << "): "; 1815 llvm::interleaveComma(*values, llvm::dbgs()); 1816 llvm::dbgs() << "\n"; 1817 }); 1818 1819 // Extract all the users of a range of values. 
1820 SmallVector<Operation *> users; 1821 for (Value value : *values) 1822 users.append(value.user_begin(), value.user_end()); 1823 range = OwningOpRange(users.size()); 1824 llvm::copy(users, range.begin()); 1825 } 1826 1827 LLVM_DEBUG(llvm::dbgs() << " * Result: " << range.size() << " operations\n"); 1828 } 1829 1830 void ByteCodeExecutor::executeGetValueType() { 1831 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); 1832 unsigned memIndex = read(); 1833 Value value = read<Value>(); 1834 Type type = value ? value.getType() : Type(); 1835 1836 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" 1837 << " * Result: " << type << "\n"); 1838 memory[memIndex] = type.getAsOpaquePointer(); 1839 } 1840 1841 void ByteCodeExecutor::executeGetValueRangeTypes() { 1842 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n"); 1843 unsigned memIndex = read(); 1844 unsigned rangeIndex = read(); 1845 ValueRange *values = read<ValueRange *>(); 1846 if (!values) { 1847 LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n"); 1848 memory[memIndex] = nullptr; 1849 return; 1850 } 1851 1852 LLVM_DEBUG({ 1853 llvm::dbgs() << " * Values (" << values->size() << "): "; 1854 llvm::interleaveComma(*values, llvm::dbgs()); 1855 llvm::dbgs() << "\n * Result: "; 1856 llvm::interleaveComma(values->getType(), llvm::dbgs()); 1857 llvm::dbgs() << "\n"; 1858 }); 1859 typeRangeMemory[rangeIndex] = values->getType(); 1860 memory[memIndex] = &typeRangeMemory[rangeIndex]; 1861 } 1862 1863 void ByteCodeExecutor::executeIsNotNull() { 1864 LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); 1865 const void *value = read<const void *>(); 1866 1867 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1868 selectJump(value != nullptr); 1869 } 1870 1871 void ByteCodeExecutor::executeRecordMatch( 1872 PatternRewriter &rewriter, 1873 SmallVectorImpl<PDLByteCode::MatchResult> &matches) { 1874 LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n"); 1875 unsigned patternIndex = read(); 1876 
PatternBenefit benefit = currentPatternBenefits[patternIndex]; 1877 const ByteCodeField *dest = &code[read<ByteCodeAddr>()]; 1878 1879 // If the benefit of the pattern is impossible, skip the processing of the 1880 // rest of the pattern. 1881 if (benefit.isImpossibleToMatch()) { 1882 LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n"); 1883 curCodeIt = dest; 1884 return; 1885 } 1886 1887 // Create a fused location containing the locations of each of the 1888 // operations used in the match. This will be used as the location for 1889 // created operations during the rewrite that don't already have an 1890 // explicit location set. 1891 unsigned numMatchLocs = read(); 1892 SmallVector<Location, 4> matchLocs; 1893 matchLocs.reserve(numMatchLocs); 1894 for (unsigned i = 0; i != numMatchLocs; ++i) 1895 matchLocs.push_back(read<Operation *>()->getLoc()); 1896 Location matchLoc = rewriter.getFusedLoc(matchLocs); 1897 1898 LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n" 1899 << " * Location: " << matchLoc << "\n"); 1900 matches.emplace_back(matchLoc, patterns[patternIndex], benefit); 1901 PDLByteCode::MatchResult &match = matches.back(); 1902 1903 // Record all of the inputs to the match. If any of the inputs are ranges, we 1904 // will also need to remap the range pointer to memory stored in the match 1905 // state. 
1906 unsigned numInputs = read(); 1907 match.values.reserve(numInputs); 1908 match.typeRangeValues.reserve(numInputs); 1909 match.valueRangeValues.reserve(numInputs); 1910 for (unsigned i = 0; i < numInputs; ++i) { 1911 switch (read<PDLValue::Kind>()) { 1912 case PDLValue::Kind::TypeRange: 1913 match.typeRangeValues.push_back(*read<TypeRange *>()); 1914 match.values.push_back(&match.typeRangeValues.back()); 1915 break; 1916 case PDLValue::Kind::ValueRange: 1917 match.valueRangeValues.push_back(*read<ValueRange *>()); 1918 match.values.push_back(&match.valueRangeValues.back()); 1919 break; 1920 default: 1921 match.values.push_back(read<const void *>()); 1922 break; 1923 } 1924 } 1925 curCodeIt = dest; 1926 } 1927 1928 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) { 1929 LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); 1930 Operation *op = read<Operation *>(); 1931 SmallVector<Value, 16> args; 1932 readValueList(args); 1933 1934 LLVM_DEBUG({ 1935 llvm::dbgs() << " * Operation: " << *op << "\n" 1936 << " * Values: "; 1937 llvm::interleaveComma(args, llvm::dbgs()); 1938 llvm::dbgs() << "\n"; 1939 }); 1940 rewriter.replaceOp(op, args); 1941 } 1942 1943 void ByteCodeExecutor::executeSwitchAttribute() { 1944 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n"); 1945 Attribute value = read<Attribute>(); 1946 ArrayAttr cases = read<ArrayAttr>(); 1947 handleSwitch(value, cases); 1948 } 1949 1950 void ByteCodeExecutor::executeSwitchOperandCount() { 1951 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n"); 1952 Operation *op = read<Operation *>(); 1953 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1954 1955 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1956 handleSwitch(op->getNumOperands(), cases); 1957 } 1958 1959 void ByteCodeExecutor::executeSwitchOperationName() { 1960 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n"); 1961 OperationName value = read<Operation *>()->getName(); 
1962 size_t caseCount = read(); 1963 1964 // The operation names are stored in-line, so to print them out for 1965 // debugging purposes we need to read the array before executing the 1966 // switch so that we can display all of the possible values. 1967 LLVM_DEBUG({ 1968 const ByteCodeField *prevCodeIt = curCodeIt; 1969 llvm::dbgs() << " * Value: " << value << "\n" 1970 << " * Cases: "; 1971 llvm::interleaveComma( 1972 llvm::map_range(llvm::seq<size_t>(0, caseCount), 1973 [&](size_t) { return read<OperationName>(); }), 1974 llvm::dbgs()); 1975 llvm::dbgs() << "\n"; 1976 curCodeIt = prevCodeIt; 1977 }); 1978 1979 // Try to find the switch value within any of the cases. 1980 for (size_t i = 0; i != caseCount; ++i) { 1981 if (read<OperationName>() == value) { 1982 curCodeIt += (caseCount - i - 1); 1983 return selectJump(i + 1); 1984 } 1985 } 1986 selectJump(size_t(0)); 1987 } 1988 1989 void ByteCodeExecutor::executeSwitchResultCount() { 1990 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n"); 1991 Operation *op = read<Operation *>(); 1992 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1993 1994 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1995 handleSwitch(op->getNumResults(), cases); 1996 } 1997 1998 void ByteCodeExecutor::executeSwitchType() { 1999 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n"); 2000 Type value = read<Type>(); 2001 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>(); 2002 handleSwitch(value, cases); 2003 } 2004 2005 void ByteCodeExecutor::executeSwitchTypes() { 2006 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n"); 2007 TypeRange *value = read<TypeRange *>(); 2008 auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>(); 2009 if (!value) { 2010 LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n"); 2011 return selectJump(size_t(0)); 2012 } 2013 handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) { 2014 return value == caseValue.getAsValueRange<TypeAttr>(); 2015 }); 
2016 } 2017 2018 void ByteCodeExecutor::execute( 2019 PatternRewriter &rewriter, 2020 SmallVectorImpl<PDLByteCode::MatchResult> *matches, 2021 Optional<Location> mainRewriteLoc) { 2022 while (true) { 2023 // Print the location of the operation being executed. 2024 LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n"); 2025 2026 OpCode opCode = static_cast<OpCode>(read()); 2027 switch (opCode) { 2028 case ApplyConstraint: 2029 executeApplyConstraint(rewriter); 2030 break; 2031 case ApplyRewrite: 2032 executeApplyRewrite(rewriter); 2033 break; 2034 case AreEqual: 2035 executeAreEqual(); 2036 break; 2037 case AreRangesEqual: 2038 executeAreRangesEqual(); 2039 break; 2040 case Branch: 2041 executeBranch(); 2042 break; 2043 case CheckOperandCount: 2044 executeCheckOperandCount(); 2045 break; 2046 case CheckOperationName: 2047 executeCheckOperationName(); 2048 break; 2049 case CheckResultCount: 2050 executeCheckResultCount(); 2051 break; 2052 case CheckTypes: 2053 executeCheckTypes(); 2054 break; 2055 case Continue: 2056 executeContinue(); 2057 break; 2058 case CreateOperation: 2059 executeCreateOperation(rewriter, *mainRewriteLoc); 2060 break; 2061 case CreateTypes: 2062 executeCreateTypes(); 2063 break; 2064 case EraseOp: 2065 executeEraseOp(rewriter); 2066 break; 2067 case ExtractOp: 2068 executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>(); 2069 break; 2070 case ExtractType: 2071 executeExtract<Type, TypeRange, PDLValue::Kind::Type>(); 2072 break; 2073 case ExtractValue: 2074 executeExtract<Value, ValueRange, PDLValue::Kind::Value>(); 2075 break; 2076 case Finalize: 2077 executeFinalize(); 2078 LLVM_DEBUG(llvm::dbgs() << "\n"); 2079 return; 2080 case ForEach: 2081 executeForEach(); 2082 break; 2083 case GetAttribute: 2084 executeGetAttribute(); 2085 break; 2086 case GetAttributeType: 2087 executeGetAttributeType(); 2088 break; 2089 case GetDefiningOp: 2090 executeGetDefiningOp(); 2091 break; 2092 case GetOperand0: 2093 case GetOperand1: 2094 
case GetOperand2: 2095 case GetOperand3: { 2096 unsigned index = opCode - GetOperand0; 2097 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n"); 2098 executeGetOperand(index); 2099 break; 2100 } 2101 case GetOperandN: 2102 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n"); 2103 executeGetOperand(read<uint32_t>()); 2104 break; 2105 case GetOperands: 2106 executeGetOperands(); 2107 break; 2108 case GetResult0: 2109 case GetResult1: 2110 case GetResult2: 2111 case GetResult3: { 2112 unsigned index = opCode - GetResult0; 2113 LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n"); 2114 executeGetResult(index); 2115 break; 2116 } 2117 case GetResultN: 2118 LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n"); 2119 executeGetResult(read<uint32_t>()); 2120 break; 2121 case GetResults: 2122 executeGetResults(); 2123 break; 2124 case GetUsers: 2125 executeGetUsers(); 2126 break; 2127 case GetValueType: 2128 executeGetValueType(); 2129 break; 2130 case GetValueRangeTypes: 2131 executeGetValueRangeTypes(); 2132 break; 2133 case IsNotNull: 2134 executeIsNotNull(); 2135 break; 2136 case RecordMatch: 2137 assert(matches && 2138 "expected matches to be provided when executing the matcher"); 2139 executeRecordMatch(rewriter, *matches); 2140 break; 2141 case ReplaceOp: 2142 executeReplaceOp(rewriter); 2143 break; 2144 case SwitchAttribute: 2145 executeSwitchAttribute(); 2146 break; 2147 case SwitchOperandCount: 2148 executeSwitchOperandCount(); 2149 break; 2150 case SwitchOperationName: 2151 executeSwitchOperationName(); 2152 break; 2153 case SwitchResultCount: 2154 executeSwitchResultCount(); 2155 break; 2156 case SwitchType: 2157 executeSwitchType(); 2158 break; 2159 case SwitchTypes: 2160 executeSwitchTypes(); 2161 break; 2162 } 2163 LLVM_DEBUG(llvm::dbgs() << "\n"); 2164 } 2165 } 2166 2167 /// Run the pattern matcher on the given root operation, collecting the matched 2168 /// patterns in `matches`. 
2169 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, 2170 SmallVectorImpl<MatchResult> &matches, 2171 PDLByteCodeMutableState &state) const { 2172 // The first memory slot is always the root operation. 2173 state.memory[0] = op; 2174 2175 // The matcher function always starts at code address 0. 2176 ByteCodeExecutor executor( 2177 matcherByteCode.data(), state.memory, state.opRangeMemory, 2178 state.typeRangeMemory, state.allocatedTypeRangeMemory, 2179 state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex, 2180 uniquedData, matcherByteCode, state.currentPatternBenefits, patterns, 2181 constraintFunctions, rewriteFunctions); 2182 executor.execute(rewriter, &matches); 2183 2184 // Order the found matches by benefit. 2185 std::stable_sort(matches.begin(), matches.end(), 2186 [](const MatchResult &lhs, const MatchResult &rhs) { 2187 return lhs.benefit > rhs.benefit; 2188 }); 2189 } 2190 2191 /// Run the rewriter of the given pattern on the root operation `op`. 2192 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, 2193 PDLByteCodeMutableState &state) const { 2194 // The arguments of the rewrite function are stored at the start of the 2195 // memory buffer. 2196 llvm::copy(match.values, state.memory.begin()); 2197 2198 ByteCodeExecutor executor( 2199 &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, 2200 state.opRangeMemory, state.typeRangeMemory, 2201 state.allocatedTypeRangeMemory, state.valueRangeMemory, 2202 state.allocatedValueRangeMemory, state.loopIndex, uniquedData, 2203 rewriterByteCode, state.currentPatternBenefits, patterns, 2204 constraintFunctions, rewriteFunctions); 2205 executor.execute(rewriter, /*matches=*/nullptr, match.location); 2206 } 2207