1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements MLIR to byte-code generation and the interpreter. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "ByteCode.h" 14 #include "mlir/Analysis/Liveness.h" 15 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/RegionGraphTraits.h" 19 #include "llvm/ADT/IntervalMap.h" 20 #include "llvm/ADT/PostOrderIterator.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 #include "llvm/Support/Debug.h" 23 #include "llvm/Support/Format.h" 24 #include "llvm/Support/FormatVariadic.h" 25 #include <numeric> 26 27 #define DEBUG_TYPE "pdl-bytecode" 28 29 using namespace mlir; 30 using namespace mlir::detail; 31 32 //===----------------------------------------------------------------------===// 33 // PDLByteCodePattern 34 //===----------------------------------------------------------------------===// 35 36 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp, 37 ByteCodeAddr rewriterAddr) { 38 SmallVector<StringRef, 8> generatedOps; 39 if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr()) 40 generatedOps = 41 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>()); 42 43 PatternBenefit benefit = matchOp.getBenefit(); 44 MLIRContext *ctx = matchOp.getContext(); 45 46 // Check to see if this is pattern matches a specific operation type. 
47 if (Optional<StringRef> rootKind = matchOp.getRootKind()) 48 return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx, 49 generatedOps); 50 return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx, 51 generatedOps); 52 } 53 54 //===----------------------------------------------------------------------===// 55 // PDLByteCodeMutableState 56 //===----------------------------------------------------------------------===// 57 58 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds 59 /// to the position of the pattern within the range returned by 60 /// `PDLByteCode::getPatterns`. 61 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex, 62 PatternBenefit benefit) { 63 currentPatternBenefits[patternIndex] = benefit; 64 } 65 66 /// Cleanup any allocated state after a full match/rewrite has been completed. 67 /// This method should be called irregardless of whether the match+rewrite was a 68 /// success or not. 69 void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() { 70 allocatedTypeRangeMemory.clear(); 71 allocatedValueRangeMemory.clear(); 72 } 73 74 //===----------------------------------------------------------------------===// 75 // Bytecode OpCodes 76 //===----------------------------------------------------------------------===// 77 78 namespace { 79 enum OpCode : ByteCodeField { 80 /// Apply an externally registered constraint. 81 ApplyConstraint, 82 /// Apply an externally registered rewrite. 83 ApplyRewrite, 84 /// Check if two generic values are equal. 85 AreEqual, 86 /// Check if two ranges are equal. 87 AreRangesEqual, 88 /// Unconditional branch. 89 Branch, 90 /// Compare the operand count of an operation with a constant. 91 CheckOperandCount, 92 /// Compare the name of an operation with a constant. 93 CheckOperationName, 94 /// Compare the result count of an operation with a constant. 95 CheckResultCount, 96 /// Compare a range of types to a constant range of types. 
97 CheckTypes, 98 /// Continue to the next iteration of a loop. 99 Continue, 100 /// Create an operation. 101 CreateOperation, 102 /// Create a range of types. 103 CreateTypes, 104 /// Erase an operation. 105 EraseOp, 106 /// Extract the op from a range at the specified index. 107 ExtractOp, 108 /// Extract the type from a range at the specified index. 109 ExtractType, 110 /// Extract the value from a range at the specified index. 111 ExtractValue, 112 /// Terminate a matcher or rewrite sequence. 113 Finalize, 114 /// Iterate over a range of values. 115 ForEach, 116 /// Get a specific attribute of an operation. 117 GetAttribute, 118 /// Get the type of an attribute. 119 GetAttributeType, 120 /// Get the defining operation of a value. 121 GetDefiningOp, 122 /// Get a specific operand of an operation. 123 GetOperand0, 124 GetOperand1, 125 GetOperand2, 126 GetOperand3, 127 GetOperandN, 128 /// Get a specific operand group of an operation. 129 GetOperands, 130 /// Get a specific result of an operation. 131 GetResult0, 132 GetResult1, 133 GetResult2, 134 GetResult3, 135 GetResultN, 136 /// Get a specific result group of an operation. 137 GetResults, 138 /// Get the users of a value or a range of values. 139 GetUsers, 140 /// Get the type of a value. 141 GetValueType, 142 /// Get the types of a value range. 143 GetValueRangeTypes, 144 /// Check if a generic value is not null. 145 IsNotNull, 146 /// Record a successful pattern match. 147 RecordMatch, 148 /// Replace an operation. 149 ReplaceOp, 150 /// Compare an attribute with a set of constants. 151 SwitchAttribute, 152 /// Compare the operand count of an operation with a set of constants. 153 SwitchOperandCount, 154 /// Compare the name of an operation with a set of constants. 155 SwitchOperationName, 156 /// Compare the result count of an operation with a set of constants. 157 SwitchResultCount, 158 /// Compare a type with a set of constants. 159 SwitchType, 160 /// Compare a range of types with a set of constants. 
161 SwitchTypes, 162 }; 163 } // namespace 164 165 //===----------------------------------------------------------------------===// 166 // ByteCode Generation 167 //===----------------------------------------------------------------------===// 168 169 //===----------------------------------------------------------------------===// 170 // Generator 171 172 namespace { 173 struct ByteCodeLiveRange; 174 struct ByteCodeWriter; 175 176 /// Check if the given class `T` can be converted to an opaque pointer. 177 template <typename T, typename... Args> 178 using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer()); 179 180 /// This class represents the main generator for the pattern bytecode. 181 class Generator { 182 public: 183 Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData, 184 SmallVectorImpl<ByteCodeField> &matcherByteCode, 185 SmallVectorImpl<ByteCodeField> &rewriterByteCode, 186 SmallVectorImpl<PDLByteCodePattern> &patterns, 187 ByteCodeField &maxValueMemoryIndex, 188 ByteCodeField &maxOpRangeMemoryIndex, 189 ByteCodeField &maxTypeRangeMemoryIndex, 190 ByteCodeField &maxValueRangeMemoryIndex, 191 ByteCodeField &maxLoopLevel, 192 llvm::StringMap<PDLConstraintFunction> &constraintFns, 193 llvm::StringMap<PDLRewriteFunction> &rewriteFns) 194 : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), 195 rewriterByteCode(rewriterByteCode), patterns(patterns), 196 maxValueMemoryIndex(maxValueMemoryIndex), 197 maxOpRangeMemoryIndex(maxOpRangeMemoryIndex), 198 maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex), 199 maxValueRangeMemoryIndex(maxValueRangeMemoryIndex), 200 maxLoopLevel(maxLoopLevel) { 201 for (const auto &it : llvm::enumerate(constraintFns)) 202 constraintToMemIndex.try_emplace(it.value().first(), it.index()); 203 for (const auto &it : llvm::enumerate(rewriteFns)) 204 externalRewriterToMemIndex.try_emplace(it.value().first(), it.index()); 205 } 206 207 /// Generate the bytecode for the given PDL interpreter 
module. 208 void generate(ModuleOp module); 209 210 /// Return the memory index to use for the given value. 211 ByteCodeField &getMemIndex(Value value) { 212 assert(valueToMemIndex.count(value) && 213 "expected memory index to be assigned"); 214 return valueToMemIndex[value]; 215 } 216 217 /// Return the range memory index used to store the given range value. 218 ByteCodeField &getRangeStorageIndex(Value value) { 219 assert(valueToRangeIndex.count(value) && 220 "expected range index to be assigned"); 221 return valueToRangeIndex[value]; 222 } 223 224 /// Return an index to use when referring to the given data that is uniqued in 225 /// the MLIR context. 226 template <typename T> 227 std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &> 228 getMemIndex(T val) { 229 const void *opaqueVal = val.getAsOpaquePointer(); 230 231 // Get or insert a reference to this value. 232 auto it = uniquedDataToMemIndex.try_emplace( 233 opaqueVal, maxValueMemoryIndex + uniquedData.size()); 234 if (it.second) 235 uniquedData.push_back(opaqueVal); 236 return it.first->second; 237 } 238 239 private: 240 /// Allocate memory indices for the results of operations within the matcher 241 /// and rewriters. 242 void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc, 243 ModuleOp rewriterModule); 244 245 /// Generate the bytecode for the given operation. 
246 void generate(Region *region, ByteCodeWriter &writer); 247 void generate(Operation *op, ByteCodeWriter &writer); 248 void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer); 249 void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer); 250 void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer); 251 void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer); 252 void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer); 253 void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer); 254 void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer); 255 void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); 256 void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); 257 void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer); 258 void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer); 259 void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); 260 void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); 261 void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); 262 void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer); 263 void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); 264 void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer); 265 void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); 266 void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer); 267 void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); 268 void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); 269 void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); 270 void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer); 271 void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer); 272 void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); 273 void generate(pdl_interp::GetResultsOp op, 
ByteCodeWriter &writer); 274 void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer); 275 void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); 276 void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer); 277 void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); 278 void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer); 279 void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); 280 void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer); 281 void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer); 282 void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer); 283 void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer); 284 void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer); 285 void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer); 286 287 /// Mapping from value to its corresponding memory index. 288 DenseMap<Value, ByteCodeField> valueToMemIndex; 289 290 /// Mapping from a range value to its corresponding range storage index. 291 DenseMap<Value, ByteCodeField> valueToRangeIndex; 292 293 /// Mapping from the name of an externally registered rewrite to its index in 294 /// the bytecode registry. 295 llvm::StringMap<ByteCodeField> externalRewriterToMemIndex; 296 297 /// Mapping from the name of an externally registered constraint to its index 298 /// in the bytecode registry. 299 llvm::StringMap<ByteCodeField> constraintToMemIndex; 300 301 /// Mapping from rewriter function name to the bytecode address of the 302 /// rewriter function in byte. 303 llvm::StringMap<ByteCodeAddr> rewriterToAddr; 304 305 /// Mapping from a uniqued storage object to its memory index within 306 /// `uniquedData`. 307 DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex; 308 309 /// The current level of the foreach loop. 310 ByteCodeField curLoopLevel = 0; 311 312 /// The current MLIR context. 
313 MLIRContext *ctx; 314 315 /// Mapping from block to its address. 316 DenseMap<Block *, ByteCodeAddr> blockToAddr; 317 318 /// Data of the ByteCode class to be populated. 319 std::vector<const void *> &uniquedData; 320 SmallVectorImpl<ByteCodeField> &matcherByteCode; 321 SmallVectorImpl<ByteCodeField> &rewriterByteCode; 322 SmallVectorImpl<PDLByteCodePattern> &patterns; 323 ByteCodeField &maxValueMemoryIndex; 324 ByteCodeField &maxOpRangeMemoryIndex; 325 ByteCodeField &maxTypeRangeMemoryIndex; 326 ByteCodeField &maxValueRangeMemoryIndex; 327 ByteCodeField &maxLoopLevel; 328 }; 329 330 /// This class provides utilities for writing a bytecode stream. 331 struct ByteCodeWriter { 332 ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator) 333 : bytecode(bytecode), generator(generator) {} 334 335 /// Append a field to the bytecode. 336 void append(ByteCodeField field) { bytecode.push_back(field); } 337 void append(OpCode opCode) { bytecode.push_back(opCode); } 338 339 /// Append an address to the bytecode. 340 void append(ByteCodeAddr field) { 341 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 342 "unexpected ByteCode address size"); 343 344 ByteCodeField fieldParts[2]; 345 std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr)); 346 bytecode.append({fieldParts[0], fieldParts[1]}); 347 } 348 349 /// Append a single successor to the bytecode, the exact address will need to 350 /// be resolved later. 351 void append(Block *successor) { 352 // Add back a reference to the successor so that the address can be resolved 353 // later. 354 unresolvedSuccessorRefs[successor].push_back(bytecode.size()); 355 append(ByteCodeAddr(0)); 356 } 357 358 /// Append a successor range to the bytecode, the exact address will need to 359 /// be resolved later. 360 void append(SuccessorRange successors) { 361 for (Block *successor : successors) 362 append(successor); 363 } 364 365 /// Append a range of values that will be read as generic PDLValues. 
366 void appendPDLValueList(OperandRange values) { 367 bytecode.push_back(values.size()); 368 for (Value value : values) 369 appendPDLValue(value); 370 } 371 372 /// Append a value as a PDLValue. 373 void appendPDLValue(Value value) { 374 appendPDLValueKind(value); 375 append(value); 376 } 377 378 /// Append the PDLValue::Kind of the given value. 379 void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); } 380 381 /// Append the PDLValue::Kind of the given type. 382 void appendPDLValueKind(Type type) { 383 PDLValue::Kind kind = 384 TypeSwitch<Type, PDLValue::Kind>(type) 385 .Case<pdl::AttributeType>( 386 [](Type) { return PDLValue::Kind::Attribute; }) 387 .Case<pdl::OperationType>( 388 [](Type) { return PDLValue::Kind::Operation; }) 389 .Case<pdl::RangeType>([](pdl::RangeType rangeTy) { 390 if (rangeTy.getElementType().isa<pdl::TypeType>()) 391 return PDLValue::Kind::TypeRange; 392 return PDLValue::Kind::ValueRange; 393 }) 394 .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; }) 395 .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; }); 396 bytecode.push_back(static_cast<ByteCodeField>(kind)); 397 } 398 399 /// Append a value that will be stored in a memory slot and not inline within 400 /// the bytecode. 401 template <typename T> 402 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value || 403 std::is_pointer<T>::value> 404 append(T value) { 405 bytecode.push_back(generator.getMemIndex(value)); 406 } 407 408 /// Append a range of values. 409 template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>> 410 std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value> 411 append(T range) { 412 bytecode.push_back(llvm::size(range)); 413 for (auto it : range) 414 append(it); 415 } 416 417 /// Append a variadic number of fields to the bytecode. 418 template <typename FieldTy, typename Field2Ty, typename... FieldTys> 419 void append(FieldTy field, Field2Ty field2, FieldTys... 
fields) { 420 append(field); 421 append(field2, fields...); 422 } 423 424 /// Appends a value as a pointer, stored inline within the bytecode. 425 template <typename T> 426 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value> 427 appendInline(T value) { 428 constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField); 429 const void *pointer = value.getAsOpaquePointer(); 430 ByteCodeField fieldParts[numParts]; 431 std::memcpy(fieldParts, &pointer, sizeof(const void *)); 432 bytecode.append(fieldParts, fieldParts + numParts); 433 } 434 435 /// Successor references in the bytecode that have yet to be resolved. 436 DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs; 437 438 /// The underlying bytecode buffer. 439 SmallVectorImpl<ByteCodeField> &bytecode; 440 441 /// The main generator producing PDL. 442 Generator &generator; 443 }; 444 445 /// This class represents a live range of PDL Interpreter values, containing 446 /// information about when values are live within a match/rewrite. 447 struct ByteCodeLiveRange { 448 using Set = llvm::IntervalMap<uint64_t, char, 16>; 449 using Allocator = Set::Allocator; 450 451 ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {} 452 453 /// Union this live range with the one provided. 454 void unionWith(const ByteCodeLiveRange &rhs) { 455 for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e; 456 ++it) 457 liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0); 458 } 459 460 /// Returns true if this range overlaps with the one provided. 461 bool overlaps(const ByteCodeLiveRange &rhs) const { 462 return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness) 463 .valid(); 464 } 465 466 /// A map representing the ranges of the match/rewrite that a value is live in 467 /// the interpreter. 468 /// 469 /// We use std::unique_ptr here, because IntervalMap does not provide a 470 /// correct copy or move constructor. 
We can eliminate the pointer once 471 /// https://reviews.llvm.org/D113240 lands. 472 std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness; 473 474 /// The operation range storage index for this range. 475 Optional<unsigned> opRangeIndex; 476 477 /// The type range storage index for this range. 478 Optional<unsigned> typeRangeIndex; 479 480 /// The value range storage index for this range. 481 Optional<unsigned> valueRangeIndex; 482 }; 483 } // namespace 484 485 void Generator::generate(ModuleOp module) { 486 auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>( 487 pdl_interp::PDLInterpDialect::getMatcherFunctionName()); 488 ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>( 489 pdl_interp::PDLInterpDialect::getRewriterModuleName()); 490 assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module"); 491 492 // Allocate memory indices for the results of operations within the matcher 493 // and rewriters. 494 allocateMemoryIndices(matcherFunc, rewriterModule); 495 496 // Generate code for the rewriter functions. 497 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this); 498 for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) { 499 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size()); 500 for (Operation &op : rewriterFunc.getOps()) 501 generate(&op, rewriterByteCodeWriter); 502 } 503 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() && 504 "unexpected branches in rewriter function"); 505 506 // Generate code for the matcher function. 507 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this); 508 generate(&matcherFunc.getBody(), matcherByteCodeWriter); 509 510 // Resolve successor references in the matcher. 
511 for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) { 512 ByteCodeAddr addr = blockToAddr[it.first]; 513 for (unsigned offsetToFix : it.second) 514 std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr)); 515 } 516 } 517 518 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc, 519 ModuleOp rewriterModule) { 520 // Rewriters use simplistic allocation scheme that simply assigns an index to 521 // each result. 522 for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) { 523 ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0; 524 auto processRewriterValue = [&](Value val) { 525 valueToMemIndex.try_emplace(val, index++); 526 if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) { 527 Type elementTy = rangeType.getElementType(); 528 if (elementTy.isa<pdl::TypeType>()) 529 valueToRangeIndex.try_emplace(val, typeRangeIndex++); 530 else if (elementTy.isa<pdl::ValueType>()) 531 valueToRangeIndex.try_emplace(val, valueRangeIndex++); 532 } 533 }; 534 535 for (BlockArgument arg : rewriterFunc.getArguments()) 536 processRewriterValue(arg); 537 rewriterFunc.getBody().walk([&](Operation *op) { 538 for (Value result : op->getResults()) 539 processRewriterValue(result); 540 }); 541 if (index > maxValueMemoryIndex) 542 maxValueMemoryIndex = index; 543 if (typeRangeIndex > maxTypeRangeMemoryIndex) 544 maxTypeRangeMemoryIndex = typeRangeIndex; 545 if (valueRangeIndex > maxValueRangeMemoryIndex) 546 maxValueRangeMemoryIndex = valueRangeIndex; 547 } 548 549 // The matcher function uses a more sophisticated numbering that tries to 550 // minimize the number of memory indices assigned. This is done by determining 551 // a live range of the values within the matcher, then the allocation is just 552 // finding the minimal number of overlapping live ranges. 
This is essentially 553 // a simplified form of register allocation where we don't necessarily have a 554 // limited number of registers, but we still want to minimize the number used. 555 DenseMap<Operation *, unsigned> opToFirstIndex; 556 DenseMap<Operation *, unsigned> opToLastIndex; 557 558 // A custom walk that marks the first and the last index of each operation. 559 // The entry marks the beginning of the liveness range for this operation, 560 // followed by nested operations, followed by the end of the liveness range. 561 unsigned index = 0; 562 llvm::unique_function<void(Operation *)> walk = [&](Operation *op) { 563 opToFirstIndex.try_emplace(op, index++); 564 for (Region ®ion : op->getRegions()) 565 for (Block &block : region.getBlocks()) 566 for (Operation &nested : block) 567 walk(&nested); 568 opToLastIndex.try_emplace(op, index++); 569 }; 570 walk(matcherFunc); 571 572 // Liveness info for each of the defs within the matcher. 573 ByteCodeLiveRange::Allocator allocator; 574 DenseMap<Value, ByteCodeLiveRange> valueDefRanges; 575 576 // Assign the root operation being matched to slot 0. 577 BlockArgument rootOpArg = matcherFunc.getArgument(0); 578 valueToMemIndex[rootOpArg] = 0; 579 580 // Walk each of the blocks, computing the def interval that the value is used. 581 Liveness matcherLiveness(matcherFunc); 582 matcherFunc->walk([&](Block *block) { 583 const LivenessBlockInfo *info = matcherLiveness.getLiveness(block); 584 assert(info && "expected liveness info for block"); 585 auto processValue = [&](Value value, Operation *firstUseOrDef) { 586 // We don't need to process the root op argument, this value is always 587 // assigned to the first memory slot. 588 if (value == rootOpArg) 589 return; 590 591 // Set indices for the range of this block that the value is used. 
592 auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; 593 defRangeIt->second.liveness->insert( 594 opToFirstIndex[firstUseOrDef], 595 opToLastIndex[info->getEndOperation(value, firstUseOrDef)], 596 /*dummyValue*/ 0); 597 598 // Check to see if this value is a range type. 599 if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) { 600 Type eleType = rangeTy.getElementType(); 601 if (eleType.isa<pdl::OperationType>()) 602 defRangeIt->second.opRangeIndex = 0; 603 else if (eleType.isa<pdl::TypeType>()) 604 defRangeIt->second.typeRangeIndex = 0; 605 else if (eleType.isa<pdl::ValueType>()) 606 defRangeIt->second.valueRangeIndex = 0; 607 } 608 }; 609 610 // Process the live-ins of this block. 611 for (Value liveIn : info->in()) { 612 // Only process the value if it has been defined in the current region. 613 // Other values that span across pdl_interp.foreach will be added higher 614 // up. This ensures that the we keep them alive for the entire duration 615 // of the loop. 616 if (liveIn.getParentRegion() == block->getParent()) 617 processValue(liveIn, &block->front()); 618 } 619 620 // Process the block arguments for the entry block (those are not live-in). 621 if (block->isEntryBlock()) { 622 for (Value argument : block->getArguments()) 623 processValue(argument, &block->front()); 624 } 625 626 // Process any new defs within this block. 627 for (Operation &op : *block) 628 for (Value result : op.getResults()) 629 processValue(result, &op); 630 }); 631 632 // Greedily allocate memory slots using the computed def live ranges. 633 std::vector<ByteCodeLiveRange> allocatedIndices; 634 635 // The number of memory indices currently allocated (and its next value). 636 // Recall that the root gets allocated memory index 0. 637 ByteCodeField numIndices = 1; 638 639 // The number of memory ranges of various types (and their next values). 
640 ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0; 641 642 for (auto &defIt : valueDefRanges) { 643 ByteCodeField &memIndex = valueToMemIndex[defIt.first]; 644 ByteCodeLiveRange &defRange = defIt.second; 645 646 // Try to allocate to an existing index. 647 for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) { 648 ByteCodeLiveRange &existingRange = existingIndexIt.value(); 649 if (!defRange.overlaps(existingRange)) { 650 existingRange.unionWith(defRange); 651 memIndex = existingIndexIt.index() + 1; 652 653 if (defRange.opRangeIndex) { 654 if (!existingRange.opRangeIndex) 655 existingRange.opRangeIndex = numOpRanges++; 656 valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex; 657 } else if (defRange.typeRangeIndex) { 658 if (!existingRange.typeRangeIndex) 659 existingRange.typeRangeIndex = numTypeRanges++; 660 valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex; 661 } else if (defRange.valueRangeIndex) { 662 if (!existingRange.valueRangeIndex) 663 existingRange.valueRangeIndex = numValueRanges++; 664 valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex; 665 } 666 break; 667 } 668 } 669 670 // If no existing index could be used, add a new one. 671 if (memIndex == 0) { 672 allocatedIndices.emplace_back(allocator); 673 ByteCodeLiveRange &newRange = allocatedIndices.back(); 674 newRange.unionWith(defRange); 675 676 // Allocate an index for op/type/value ranges. 
677 if (defRange.opRangeIndex) { 678 newRange.opRangeIndex = numOpRanges; 679 valueToRangeIndex[defIt.first] = numOpRanges++; 680 } else if (defRange.typeRangeIndex) { 681 newRange.typeRangeIndex = numTypeRanges; 682 valueToRangeIndex[defIt.first] = numTypeRanges++; 683 } else if (defRange.valueRangeIndex) { 684 newRange.valueRangeIndex = numValueRanges; 685 valueToRangeIndex[defIt.first] = numValueRanges++; 686 } 687 688 memIndex = allocatedIndices.size(); 689 ++numIndices; 690 } 691 } 692 693 // Print the index usage and ensure that we did not run out of index space. 694 LLVM_DEBUG({ 695 llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices " 696 << "(down from initial " << valueDefRanges.size() << ").\n"; 697 }); 698 assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() && 699 "Ran out of memory for allocated indices"); 700 701 // Update the max number of indices. 702 if (numIndices > maxValueMemoryIndex) 703 maxValueMemoryIndex = numIndices; 704 if (numOpRanges > maxOpRangeMemoryIndex) 705 maxOpRangeMemoryIndex = numOpRanges; 706 if (numTypeRanges > maxTypeRangeMemoryIndex) 707 maxTypeRangeMemoryIndex = numTypeRanges; 708 if (numValueRanges > maxValueRangeMemoryIndex) 709 maxValueRangeMemoryIndex = numValueRanges; 710 } 711 712 void Generator::generate(Region *region, ByteCodeWriter &writer) { 713 llvm::ReversePostOrderTraversal<Region *> rpot(region); 714 for (Block *block : rpot) { 715 // Keep track of where this block begins within the matcher function. 716 blockToAddr.try_emplace(block, matcherByteCode.size()); 717 for (Operation &op : *block) 718 generate(&op, writer); 719 } 720 } 721 722 void Generator::generate(Operation *op, ByteCodeWriter &writer) { 723 LLVM_DEBUG({ 724 // The following list must contain all the operations that do not 725 // produce any bytecode. 
// Only ops that emit a bytecode instruction get their Location appended inline.
// The three excluded ops merely alias memory indices in their generators below
// and write nothing to the bytecode stream. NOTE(review): the enclosing scope
// opens before this chunk; it is presumably the LLVM_DEBUG block that
// ByteCodeExecutor::getPrevCodeIt mirrors when skipping the inline Location.
    if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp,
             pdl_interp::InferredTypesOp>(op))
      writer.appendInline(op->getLoc());
  });
  // Dispatch to the `generate` overload for the concrete `pdl_interp` op; any
  // op not listed here is a programming error.
  TypeSwitch<Operation *>(op)
      .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
            pdl_interp::AreEqualOp, pdl_interp::BranchOp,
            pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
            pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
            pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
            pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
            pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
            pdl_interp::CreateTypesOp, pdl_interp::EraseOp,
            pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
            pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
            pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
            pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
            pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
            pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
            pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp,
            pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp,
            pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp,
            pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp,
            pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
          [&](auto interpOp) { this->generate(interpOp, writer); })
      .Default([](Operation *) {
        llvm_unreachable("unknown `pdl_interp` operation");
      });
}

/// Emit: ApplyConstraint, constraint function index, argument list,
/// successors. The executor jumps to the first successor on success and to the
/// second on failure.
void Generator::generate(pdl_interp::ApplyConstraintOp op,
                         ByteCodeWriter &writer) {
  assert(constraintToMemIndex.count(op.getName()) &&
         "expected index for constraint function");
  writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
  writer.appendPDLValueList(op.getArgs());
  writer.append(op.getSuccessors());
}
/// Emit: ApplyRewrite, rewrite function index, argument list, result count,
/// then per result: [expected kind, debug builds only], [range storage index,
/// range results only], result memory index.
void Generator::generate(pdl_interp::ApplyRewriteOp op,
                         ByteCodeWriter &writer) {
  assert(externalRewriterToMemIndex.count(op.getName()) &&
         "expected index for rewrite function");
  writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
  writer.appendPDLValueList(op.getArgs());

  ResultRange results = op.getResults();
  writer.append(ByteCodeField(results.size()));
  for (Value result : results) {
    // In debug mode we also record the expected kind of the result, so that we
    // can provide extra verification of the native rewrite function.
#ifndef NDEBUG
    writer.appendPDLValueKind(result);
#endif

    // Range results also need to append the range storage index.
    if (result.getType().isa<pdl::RangeType>())
      writer.append(getRangeStorageIndex(result));
    writer.append(result);
  }
}
/// Range operands compare by contents (AreRangesEqual plus the value kind);
/// everything else compares by memory slot identity (AreEqual).
void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
  Value lhs = op.getLhs();
  if (lhs.getType().isa<pdl::RangeType>()) {
    writer.append(OpCode::AreRangesEqual);
    writer.appendPDLValueKind(lhs);
    writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
    return;
  }

  writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
}
/// Emit: Branch, destination.
void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
}
/// Lowered to AreEqual against the constant attribute.
void Generator::generate(pdl_interp::CheckAttributeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
                op.getSuccessors());
}
/// Emit: CheckOperandCount, op, expected count, compare-at-least flag,
/// successors.
void Generator::generate(pdl_interp::CheckOperandCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
                static_cast<ByteCodeField>(op.getCompareAtLeast()),
                op.getSuccessors());
}
/// Emit: CheckOperationName, op, expected OperationName, successors.
void Generator::generate(pdl_interp::CheckOperationNameOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CheckOperationName, op.getInputOp(),
                OperationName(op.getName(), ctx), op.getSuccessors());
}
/// Emit: CheckResultCount, op, expected count, compare-at-least flag,
/// successors.
void Generator::generate(pdl_interp::CheckResultCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
                static_cast<ByteCodeField>(op.getCompareAtLeast()),
                op.getSuccessors());
}
/// Lowered to AreEqual against the constant type.
void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
                op.getSuccessors());
}
/// Emit: CheckTypes, type range, constant types attribute, successors.
void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
                op.getSuccessors());
}
/// Emit: Continue, loop level of the innermost enclosing ForEach (the level
/// that ForEach recorded before incrementing `curLoopLevel`).
void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
  assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
  writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
}
/// Emits no instruction: the result simply aliases the constant's slot.
void Generator::generate(pdl_interp::CreateAttributeOp op,
                         ByteCodeWriter &writer) {
  // Simply repoint the memory index of the result to the constant.
  getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
}
/// Emit: CreateOperation, result slot, OperationName, operand list, attribute
/// count followed by (name, value) pairs, result type list.
void Generator::generate(pdl_interp::CreateOperationOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CreateOperation, op.getResultOp(),
                OperationName(op.getName(), ctx));
  writer.appendPDLValueList(op.getInputOperands());

  // Add the attributes.
  OperandRange attributes = op.getInputAttributes();
  writer.append(static_cast<ByteCodeField>(attributes.size()));
  for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
    writer.append(std::get<0>(it), std::get<1>(it));
  writer.appendPDLValueList(op.getInputResultTypes());
}
/// Emits no instruction: the result simply aliases the constant's slot.
void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
  // Simply repoint the memory index of the result to the constant.
  getMemIndex(op.getResult()) = getMemIndex(op.getValue());
}
/// Emit: CreateTypes, result slot, range storage index, types attribute.
void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::CreateTypes, op.getResult(),
                getRangeStorageIndex(op.getResult()), op.getValue());
}
/// Emit: EraseOp, op.
void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::EraseOp, op.getInputOp());
}
/// Emit: Extract{Op,Value,Type} (chosen by element type), range, index,
/// result slot.
void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
  OpCode opCode =
      TypeSwitch<Type, OpCode>(op.getResult().getType())
          .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
          .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
          .Case([](pdl::TypeType) { return OpCode::ExtractType; })
          .Default([](Type) -> OpCode {
            llvm_unreachable("unsupported element type");
          });
  writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
}
/// Emit: Finalize.
void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::Finalize);
}
/// Emit: ForEach, range storage index, loop variable slot, element kind, loop
/// level, exit successor; then the loop body at the next loop level.
/// `maxLoopLevel` tracks the deepest nesting so the executor can size its
/// loop-index storage.
void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
  BlockArgument arg = op.getLoopVariable();
  writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
  writer.appendPDLValueKind(arg.getType());
  writer.append(curLoopLevel, op.getSuccessor());
  ++curLoopLevel;
  if (curLoopLevel > maxLoopLevel)
    maxLoopLevel = curLoopLevel;
  generate(&op.getRegion(), writer);
  --curLoopLevel;
}
/// Emit: GetAttribute, result slot, op, attribute name.
void Generator::generate(pdl_interp::GetAttributeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
                op.getNameAttr());
}
/// Emit: GetAttributeType, result slot, attribute.
void Generator::generate(pdl_interp::GetAttributeTypeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
}
/// Emit: GetDefiningOp, result slot, kind-tagged value (single value or range).
void Generator::generate(pdl_interp::GetDefiningOpOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::GetDefiningOp, op.getInputOp());
  writer.appendPDLValue(op.getValue());
}
/// Indices 0-3 use dedicated GetOperand0..3 opcodes; larger indices use
/// GetOperandN followed by the index. Then: op, result slot.
void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
  uint32_t index = op.getIndex();
  if (index < 4)
    writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
  else
    writer.append(OpCode::GetOperandN, index);
  writer.append(op.getInputOp(), op.getValue());
}
/// Emit: GetOperands, group index (UINT32_MAX = "all operands"), op, range
/// storage index (ByteCodeField max when the result is a single value), result
/// slot.
void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
  Value result = op.getValue();
  Optional<uint32_t> index = op.getIndex();
  writer.append(OpCode::GetOperands,
                index.getValueOr(std::numeric_limits<uint32_t>::max()),
                op.getInputOp());
  if (result.getType().isa<pdl::RangeType>())
    writer.append(getRangeStorageIndex(result));
  else
    writer.append(std::numeric_limits<ByteCodeField>::max());
  writer.append(result);
}
/// Indices 0-3 use dedicated GetResult0..3 opcodes; larger indices use
/// GetResultN followed by the index. Then: op, result slot.
void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
  uint32_t index = op.getIndex();
  if (index < 4)
    writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
  else
    writer.append(OpCode::GetResultN, index);
  writer.append(op.getInputOp(), op.getValue());
}
/// Emit: GetResults, group index (UINT32_MAX = "all results"), op, range
/// storage index (ByteCodeField max when the result is a single value), result
/// slot.
void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
  Value result = op.getValue();
  Optional<uint32_t> index = op.getIndex();
  writer.append(OpCode::GetResults,
                index.getValueOr(std::numeric_limits<uint32_t>::max()),
                op.getInputOp());
  if (result.getType().isa<pdl::RangeType>())
    writer.append(getRangeStorageIndex(result));
  else
    writer.append(std::numeric_limits<ByteCodeField>::max());
  writer.append(result);
}
/// Emit: GetUsers, result slot, range storage index, kind-tagged value.
void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
  Value operations = op.getOperations();
  ByteCodeField rangeIndex = getRangeStorageIndex(operations);
  writer.append(OpCode::GetUsers, operations, rangeIndex);
  writer.appendPDLValue(op.getValue());
}
/// Range inputs emit GetValueRangeTypes (with a range storage index); single
/// values emit GetValueType.
void Generator::generate(pdl_interp::GetValueTypeOp op,
                         ByteCodeWriter &writer) {
  if (op.getType().isa<pdl::RangeType>()) {
    Value result = op.getResult();
    writer.append(OpCode::GetValueRangeTypes, result,
                  getRangeStorageIndex(result), op.getValue());
  } else {
    writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
  }
}

/// Emits no instruction: the result aliases the slot of a null Type, which
/// CreateOperation later reads as a null TypeRange meaning "infer the types".
void Generator::generate(pdl_interp::InferredTypesOp op,
                         ByteCodeWriter &writer) {
  // InferType maps to a null type as a marker for inferring result types.
  getMemIndex(op.getResult()) = getMemIndex(Type());
}
/// Emit: IsNotNull, value, successors.
void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
}
/// Registers a new PDLByteCodePattern (its index is its position in
/// `patterns`), then emits: RecordMatch, pattern index, successors, matched
/// ops, input list. NOTE(review): `rewriterToAddr[...]` relies on the
/// rewriter's address already being recorded when this runs; confirm the
/// generation order.
void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
  ByteCodeField patternIndex = patterns.size();
  patterns.emplace_back(PDLByteCodePattern::create(
      op, rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
  writer.append(OpCode::RecordMatch, patternIndex,
                SuccessorRange(op.getOperation()), op.getMatchedOps());
  writer.appendPDLValueList(op.getInputs());
}
/// Emit: ReplaceOp, op, replacement value list.
void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::ReplaceOp, op.getInputOp());
  writer.appendPDLValueList(op.getReplValues());
}
/// Emit: SwitchAttribute, attribute, case values, successors.
void Generator::generate(pdl_interp::SwitchAttributeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchAttribute, op.getAttribute(),
                op.getCaseValuesAttr(), op.getSuccessors());
}
/// Emit: SwitchOperandCount, op, case values, successors.
void Generator::generate(pdl_interp::SwitchOperandCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
                op.getCaseValuesAttr(), op.getSuccessors());
}
/// Emit: SwitchOperationName, op, case names converted to OperationName,
/// successors.
void Generator::generate(pdl_interp::SwitchOperationNameOp op,
                         ByteCodeWriter &writer) {
  auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
    return OperationName(attr.cast<StringAttr>().getValue(), ctx);
  });
  writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
                op.getSuccessors());
}
/// Emit: SwitchResultCount, op, case values, successors.
void Generator::generate(pdl_interp::SwitchResultCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchResultCount, op.getInputOp(),
                op.getCaseValuesAttr(), op.getSuccessors());
}
/// Emit: SwitchType, type, case values, successors.
void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
                op.getSuccessors());
}
/// Emit: SwitchTypes, type range, case values, successors.
void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
                op.getSuccessors());
}

//===----------------------------------------------------------------------===//
// PDLByteCode
//===----------------------------------------------------------------------===//

/// Generate the bytecode for `module` and take ownership of the external
/// constraint/rewrite functions. The generator fills this object's bytecode
/// buffers, pattern list, and memory-size maxima.
PDLByteCode::PDLByteCode(ModuleOp module,
                         llvm::StringMap<PDLConstraintFunction> constraintFns,
                         llvm::StringMap<PDLRewriteFunction> rewriteFns) {
  Generator generator(module.getContext(), uniquedData, matcherByteCode,
                      rewriterByteCode, patterns, maxValueMemoryIndex,
                      maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
                      maxLoopLevel, constraintFns, rewriteFns);
  generator.generate(module);

  // Initialize the external functions. NOTE(review): the executor indexes
  // these vectors with the indices the Generator assigned; this assumes the
  // Generator numbered them in this same StringMap iteration order -- confirm.
  for (auto &it : constraintFns)
    constraintFunctions.push_back(std::move(it.second));
  for (auto &it : rewriteFns)
    rewriteFunctions.push_back(std::move(it.second));
}

/// Initialize the given state such that it can be used to execute the current
/// bytecode.
1033 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { 1034 state.memory.resize(maxValueMemoryIndex, nullptr); 1035 state.opRangeMemory.resize(maxOpRangeCount); 1036 state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange()); 1037 state.valueRangeMemory.resize(maxValueRangeCount, ValueRange()); 1038 state.loopIndex.resize(maxLoopLevel, 0); 1039 state.currentPatternBenefits.reserve(patterns.size()); 1040 for (const PDLByteCodePattern &pattern : patterns) 1041 state.currentPatternBenefits.push_back(pattern.getBenefit()); 1042 } 1043 1044 //===----------------------------------------------------------------------===// 1045 // ByteCode Execution 1046 1047 namespace { 1048 /// This class provides support for executing a bytecode stream. 1049 class ByteCodeExecutor { 1050 public: 1051 ByteCodeExecutor( 1052 const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory, 1053 MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory, 1054 MutableArrayRef<TypeRange> typeRangeMemory, 1055 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory, 1056 MutableArrayRef<ValueRange> valueRangeMemory, 1057 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory, 1058 MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory, 1059 ArrayRef<ByteCodeField> code, 1060 ArrayRef<PatternBenefit> currentPatternBenefits, 1061 ArrayRef<PDLByteCodePattern> patterns, 1062 ArrayRef<PDLConstraintFunction> constraintFunctions, 1063 ArrayRef<PDLRewriteFunction> rewriteFunctions) 1064 : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory), 1065 typeRangeMemory(typeRangeMemory), 1066 allocatedTypeRangeMemory(allocatedTypeRangeMemory), 1067 valueRangeMemory(valueRangeMemory), 1068 allocatedValueRangeMemory(allocatedValueRangeMemory), 1069 loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code), 1070 currentPatternBenefits(currentPatternBenefits), patterns(patterns), 1071 
constraintFunctions(constraintFunctions), 1072 rewriteFunctions(rewriteFunctions) {} 1073 1074 /// Start executing the code at the current bytecode index. `matches` is an 1075 /// optional field provided when this function is executed in a matching 1076 /// context. 1077 void execute(PatternRewriter &rewriter, 1078 SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr, 1079 Optional<Location> mainRewriteLoc = {}); 1080 1081 private: 1082 /// Internal implementation of executing each of the bytecode commands. 1083 void executeApplyConstraint(PatternRewriter &rewriter); 1084 void executeApplyRewrite(PatternRewriter &rewriter); 1085 void executeAreEqual(); 1086 void executeAreRangesEqual(); 1087 void executeBranch(); 1088 void executeCheckOperandCount(); 1089 void executeCheckOperationName(); 1090 void executeCheckResultCount(); 1091 void executeCheckTypes(); 1092 void executeContinue(); 1093 void executeCreateOperation(PatternRewriter &rewriter, 1094 Location mainRewriteLoc); 1095 void executeCreateTypes(); 1096 void executeEraseOp(PatternRewriter &rewriter); 1097 template <typename T, typename Range, PDLValue::Kind kind> 1098 void executeExtract(); 1099 void executeFinalize(); 1100 void executeForEach(); 1101 void executeGetAttribute(); 1102 void executeGetAttributeType(); 1103 void executeGetDefiningOp(); 1104 void executeGetOperand(unsigned index); 1105 void executeGetOperands(); 1106 void executeGetResult(unsigned index); 1107 void executeGetResults(); 1108 void executeGetUsers(); 1109 void executeGetValueType(); 1110 void executeGetValueRangeTypes(); 1111 void executeIsNotNull(); 1112 void executeRecordMatch(PatternRewriter &rewriter, 1113 SmallVectorImpl<PDLByteCode::MatchResult> &matches); 1114 void executeReplaceOp(PatternRewriter &rewriter); 1115 void executeSwitchAttribute(); 1116 void executeSwitchOperandCount(); 1117 void executeSwitchOperationName(); 1118 void executeSwitchResultCount(); 1119 void executeSwitchType(); 1120 void 
executeSwitchTypes(); 1121 1122 /// Pushes a code iterator to the stack. 1123 void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); } 1124 1125 /// Pops a code iterator from the stack, returning true on success. 1126 void popCodeIt() { 1127 assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack"); 1128 curCodeIt = resumeCodeIt.back(); 1129 resumeCodeIt.pop_back(); 1130 } 1131 1132 /// Return the bytecode iterator at the start of the current op code. 1133 const ByteCodeField *getPrevCodeIt() const { 1134 LLVM_DEBUG({ 1135 // Account for the op code and the Location stored inline. 1136 return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField); 1137 }); 1138 1139 // Account for the op code only. 1140 return curCodeIt - 1; 1141 } 1142 1143 /// Read a value from the bytecode buffer, optionally skipping a certain 1144 /// number of prefix values. These methods always update the buffer to point 1145 /// to the next field after the read data. 1146 template <typename T = ByteCodeField> 1147 T read(size_t skipN = 0) { 1148 curCodeIt += skipN; 1149 return readImpl<T>(); 1150 } 1151 ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); } 1152 1153 /// Read a list of values from the bytecode buffer. 1154 template <typename ValueT, typename T> 1155 void readList(SmallVectorImpl<T> &list) { 1156 list.clear(); 1157 for (unsigned i = 0, e = read(); i != e; ++i) 1158 list.push_back(read<ValueT>()); 1159 } 1160 1161 /// Read a list of values from the bytecode buffer. The values may be encoded 1162 /// as either Value or ValueRange elements. 
1163 void readValueList(SmallVectorImpl<Value> &list) { 1164 for (unsigned i = 0, e = read(); i != e; ++i) { 1165 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1166 list.push_back(read<Value>()); 1167 } else { 1168 ValueRange *values = read<ValueRange *>(); 1169 list.append(values->begin(), values->end()); 1170 } 1171 } 1172 } 1173 1174 /// Read a value stored inline as a pointer. 1175 template <typename T> 1176 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T> 1177 readInline() { 1178 const void *pointer; 1179 std::memcpy(&pointer, curCodeIt, sizeof(const void *)); 1180 curCodeIt += sizeof(const void *) / sizeof(ByteCodeField); 1181 return T::getFromOpaquePointer(pointer); 1182 } 1183 1184 /// Jump to a specific successor based on a predicate value. 1185 void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); } 1186 /// Jump to a specific successor based on a destination index. 1187 void selectJump(size_t destIndex) { 1188 curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)]; 1189 } 1190 1191 /// Handle a switch operation with the provided value and cases. 1192 template <typename T, typename RangeT, typename Comparator = std::equal_to<T>> 1193 void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) { 1194 LLVM_DEBUG({ 1195 llvm::dbgs() << " * Value: " << value << "\n" 1196 << " * Cases: "; 1197 llvm::interleaveComma(cases, llvm::dbgs()); 1198 llvm::dbgs() << "\n"; 1199 }); 1200 1201 // Check to see if the attribute value is within the case list. Jump to 1202 // the correct successor index based on the result. 1203 for (auto it = cases.begin(), e = cases.end(); it != e; ++it) 1204 if (cmp(*it, value)) 1205 return selectJump(size_t((it - cases.begin()) + 1)); 1206 selectJump(size_t(0)); 1207 } 1208 1209 /// Store a pointer to memory. 1210 void storeToMemory(unsigned index, const void *value) { 1211 memory[index] = value; 1212 } 1213 1214 /// Store a value to memory as an opaque pointer. 
1215 template <typename T> 1216 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value> 1217 storeToMemory(unsigned index, T value) { 1218 memory[index] = value.getAsOpaquePointer(); 1219 } 1220 1221 /// Internal implementation of reading various data types from the bytecode 1222 /// stream. 1223 template <typename T> 1224 const void *readFromMemory() { 1225 size_t index = *curCodeIt++; 1226 1227 // If this type is an SSA value, it can only be stored in non-const memory. 1228 if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *, 1229 Value>::value || 1230 index < memory.size()) 1231 return memory[index]; 1232 1233 // Otherwise, if this index is not inbounds it is uniqued. 1234 return uniquedMemory[index - memory.size()]; 1235 } 1236 template <typename T> 1237 std::enable_if_t<std::is_pointer<T>::value, T> readImpl() { 1238 return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>())); 1239 } 1240 template <typename T> 1241 std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value, 1242 T> 1243 readImpl() { 1244 return T(T::getFromOpaquePointer(readFromMemory<T>())); 1245 } 1246 template <typename T> 1247 std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() { 1248 switch (read<PDLValue::Kind>()) { 1249 case PDLValue::Kind::Attribute: 1250 return read<Attribute>(); 1251 case PDLValue::Kind::Operation: 1252 return read<Operation *>(); 1253 case PDLValue::Kind::Type: 1254 return read<Type>(); 1255 case PDLValue::Kind::Value: 1256 return read<Value>(); 1257 case PDLValue::Kind::TypeRange: 1258 return read<TypeRange *>(); 1259 case PDLValue::Kind::ValueRange: 1260 return read<ValueRange *>(); 1261 } 1262 llvm_unreachable("unhandled PDLValue::Kind"); 1263 } 1264 template <typename T> 1265 std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() { 1266 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 1267 "unexpected ByteCode address size"); 1268 ByteCodeAddr result; 1269 
std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr)); 1270 curCodeIt += 2; 1271 return result; 1272 } 1273 template <typename T> 1274 std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() { 1275 return *curCodeIt++; 1276 } 1277 template <typename T> 1278 std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() { 1279 return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>()); 1280 } 1281 1282 /// The underlying bytecode buffer. 1283 const ByteCodeField *curCodeIt; 1284 1285 /// The stack of bytecode positions at which to resume operation. 1286 SmallVector<const ByteCodeField *> resumeCodeIt; 1287 1288 /// The current execution memory. 1289 MutableArrayRef<const void *> memory; 1290 MutableArrayRef<OwningOpRange> opRangeMemory; 1291 MutableArrayRef<TypeRange> typeRangeMemory; 1292 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory; 1293 MutableArrayRef<ValueRange> valueRangeMemory; 1294 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory; 1295 1296 /// The current loop indices. 1297 MutableArrayRef<unsigned> loopIndex; 1298 1299 /// References to ByteCode data necessary for execution. 1300 ArrayRef<const void *> uniquedMemory; 1301 ArrayRef<ByteCodeField> code; 1302 ArrayRef<PatternBenefit> currentPatternBenefits; 1303 ArrayRef<PDLByteCodePattern> patterns; 1304 ArrayRef<PDLConstraintFunction> constraintFunctions; 1305 ArrayRef<PDLRewriteFunction> rewriteFunctions; 1306 }; 1307 1308 /// This class is an instantiation of the PDLResultList that provides access to 1309 /// the returned results. This API is not on `PDLResultList` to avoid 1310 /// overexposing access to information specific solely to the ByteCode. 1311 class ByteCodeRewriteResultList : public PDLResultList { 1312 public: 1313 ByteCodeRewriteResultList(unsigned maxNumResults) 1314 : PDLResultList(maxNumResults) {} 1315 1316 /// Return the list of PDL results. 
1317 MutableArrayRef<PDLValue> getResults() { return results; } 1318 1319 /// Return the type ranges allocated by this list. 1320 MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() { 1321 return allocatedTypeRanges; 1322 } 1323 1324 /// Return the value ranges allocated by this list. 1325 MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() { 1326 return allocatedValueRanges; 1327 } 1328 }; 1329 } // namespace 1330 1331 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { 1332 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); 1333 const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; 1334 SmallVector<PDLValue, 16> args; 1335 readList<PDLValue>(args); 1336 1337 LLVM_DEBUG({ 1338 llvm::dbgs() << " * Arguments: "; 1339 llvm::interleaveComma(args, llvm::dbgs()); 1340 }); 1341 1342 // Invoke the constraint and jump to the proper destination. 1343 selectJump(succeeded(constraintFn(rewriter, args))); 1344 } 1345 1346 void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { 1347 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); 1348 const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; 1349 SmallVector<PDLValue, 16> args; 1350 readList<PDLValue>(args); 1351 1352 LLVM_DEBUG({ 1353 llvm::dbgs() << " * Arguments: "; 1354 llvm::interleaveComma(args, llvm::dbgs()); 1355 }); 1356 1357 // Execute the rewrite function. 1358 ByteCodeField numResults = read(); 1359 ByteCodeRewriteResultList results(numResults); 1360 rewriteFn(rewriter, results, args); 1361 1362 assert(results.getResults().size() == numResults && 1363 "native PDL rewrite function returned unexpected number of results"); 1364 1365 // Store the results in the bytecode memory. 1366 for (PDLValue &result : results.getResults()) { 1367 LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); 1368 1369 // In debug mode we also verify the expected kind of the result. 
1370 #ifndef NDEBUG 1371 assert(result.getKind() == read<PDLValue::Kind>() && 1372 "native PDL rewrite function returned an unexpected type of result"); 1373 #endif 1374 1375 // If the result is a range, we need to copy it over to the bytecodes 1376 // range memory. 1377 if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) { 1378 unsigned rangeIndex = read(); 1379 typeRangeMemory[rangeIndex] = *typeRange; 1380 memory[read()] = &typeRangeMemory[rangeIndex]; 1381 } else if (Optional<ValueRange> valueRange = 1382 result.dyn_cast<ValueRange>()) { 1383 unsigned rangeIndex = read(); 1384 valueRangeMemory[rangeIndex] = *valueRange; 1385 memory[read()] = &valueRangeMemory[rangeIndex]; 1386 } else { 1387 memory[read()] = result.getAsOpaquePointer(); 1388 } 1389 } 1390 1391 // Copy over any underlying storage allocated for result ranges. 1392 for (auto &it : results.getAllocatedTypeRanges()) 1393 allocatedTypeRangeMemory.push_back(std::move(it)); 1394 for (auto &it : results.getAllocatedValueRanges()) 1395 allocatedValueRangeMemory.push_back(std::move(it)); 1396 } 1397 1398 void ByteCodeExecutor::executeAreEqual() { 1399 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 1400 const void *lhs = read<const void *>(); 1401 const void *rhs = read<const void *>(); 1402 1403 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n"); 1404 selectJump(lhs == rhs); 1405 } 1406 1407 void ByteCodeExecutor::executeAreRangesEqual() { 1408 LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n"); 1409 PDLValue::Kind valueKind = read<PDLValue::Kind>(); 1410 const void *lhs = read<const void *>(); 1411 const void *rhs = read<const void *>(); 1412 1413 switch (valueKind) { 1414 case PDLValue::Kind::TypeRange: { 1415 const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs); 1416 const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs); 1417 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1418 selectJump(*lhsRange == 
*rhsRange); 1419 break; 1420 } 1421 case PDLValue::Kind::ValueRange: { 1422 const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs); 1423 const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs); 1424 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1425 selectJump(*lhsRange == *rhsRange); 1426 break; 1427 } 1428 default: 1429 llvm_unreachable("unexpected `AreRangesEqual` value kind"); 1430 } 1431 } 1432 1433 void ByteCodeExecutor::executeBranch() { 1434 LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n"); 1435 curCodeIt = &code[read<ByteCodeAddr>()]; 1436 } 1437 1438 void ByteCodeExecutor::executeCheckOperandCount() { 1439 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); 1440 Operation *op = read<Operation *>(); 1441 uint32_t expectedCount = read<uint32_t>(); 1442 bool compareAtLeast = read(); 1443 1444 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" 1445 << " * Expected: " << expectedCount << "\n" 1446 << " * Comparator: " 1447 << (compareAtLeast ? 
">=" : "==") << "\n"); 1448 if (compareAtLeast) 1449 selectJump(op->getNumOperands() >= expectedCount); 1450 else 1451 selectJump(op->getNumOperands() == expectedCount); 1452 } 1453 1454 void ByteCodeExecutor::executeCheckOperationName() { 1455 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n"); 1456 Operation *op = read<Operation *>(); 1457 OperationName expectedName = read<OperationName>(); 1458 1459 LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n" 1460 << " * Expected: \"" << expectedName << "\"\n"); 1461 selectJump(op->getName() == expectedName); 1462 } 1463 1464 void ByteCodeExecutor::executeCheckResultCount() { 1465 LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); 1466 Operation *op = read<Operation *>(); 1467 uint32_t expectedCount = read<uint32_t>(); 1468 bool compareAtLeast = read(); 1469 1470 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" 1471 << " * Expected: " << expectedCount << "\n" 1472 << " * Comparator: " 1473 << (compareAtLeast ? 
">=" : "==") << "\n"); 1474 if (compareAtLeast) 1475 selectJump(op->getNumResults() >= expectedCount); 1476 else 1477 selectJump(op->getNumResults() == expectedCount); 1478 } 1479 1480 void ByteCodeExecutor::executeCheckTypes() { 1481 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 1482 TypeRange *lhs = read<TypeRange *>(); 1483 Attribute rhs = read<Attribute>(); 1484 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1485 1486 selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>()); 1487 } 1488 1489 void ByteCodeExecutor::executeContinue() { 1490 ByteCodeField level = read(); 1491 LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n" 1492 << " * Level: " << level << "\n"); 1493 ++loopIndex[level]; 1494 popCodeIt(); 1495 } 1496 1497 void ByteCodeExecutor::executeCreateTypes() { 1498 LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n"); 1499 unsigned memIndex = read(); 1500 unsigned rangeIndex = read(); 1501 ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>(); 1502 1503 LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n"); 1504 1505 // Allocate a buffer for this type range. 1506 llvm::OwningArrayRef<Type> storage(typesAttr.size()); 1507 llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin()); 1508 allocatedTypeRangeMemory.emplace_back(std::move(storage)); 1509 1510 // Assign this to the range slot and use the range as the value for the 1511 // memory index. 
1512 typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back(); 1513 memory[memIndex] = &typeRangeMemory[rangeIndex]; 1514 } 1515 1516 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, 1517 Location mainRewriteLoc) { 1518 LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); 1519 1520 unsigned memIndex = read(); 1521 OperationState state(mainRewriteLoc, read<OperationName>()); 1522 readValueList(state.operands); 1523 for (unsigned i = 0, e = read(); i != e; ++i) { 1524 StringAttr name = read<StringAttr>(); 1525 if (Attribute attr = read<Attribute>()) 1526 state.addAttribute(name, attr); 1527 } 1528 1529 for (unsigned i = 0, e = read(); i != e; ++i) { 1530 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) { 1531 state.types.push_back(read<Type>()); 1532 continue; 1533 } 1534 1535 // If we find a null range, this signals that the types are infered. 1536 if (TypeRange *resultTypes = read<TypeRange *>()) { 1537 state.types.append(resultTypes->begin(), resultTypes->end()); 1538 continue; 1539 } 1540 1541 // Handle the case where the operation has inferred types. 1542 InferTypeOpInterface::Concept *inferInterface = 1543 state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>(); 1544 1545 // TODO: Handle failure. 
1546 state.types.clear(); 1547 if (failed(inferInterface->inferReturnTypes( 1548 state.getContext(), state.location, state.operands, 1549 state.attributes.getDictionary(state.getContext()), state.regions, 1550 state.types))) 1551 return; 1552 break; 1553 } 1554 1555 Operation *resultOp = rewriter.create(state); 1556 memory[memIndex] = resultOp; 1557 1558 LLVM_DEBUG({ 1559 llvm::dbgs() << " * Attributes: " 1560 << state.attributes.getDictionary(state.getContext()) 1561 << "\n * Operands: "; 1562 llvm::interleaveComma(state.operands, llvm::dbgs()); 1563 llvm::dbgs() << "\n * Result Types: "; 1564 llvm::interleaveComma(state.types, llvm::dbgs()); 1565 llvm::dbgs() << "\n * Result: " << *resultOp << "\n"; 1566 }); 1567 } 1568 1569 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) { 1570 LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n"); 1571 Operation *op = read<Operation *>(); 1572 1573 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1574 rewriter.eraseOp(op); 1575 } 1576 1577 template <typename T, typename Range, PDLValue::Kind kind> 1578 void ByteCodeExecutor::executeExtract() { 1579 LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n"); 1580 Range *range = read<Range *>(); 1581 unsigned index = read<uint32_t>(); 1582 unsigned memIndex = read(); 1583 1584 if (!range) { 1585 memory[memIndex] = nullptr; 1586 return; 1587 } 1588 1589 T result = index < range->size() ? 
(*range)[index] : T(); 1590 LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n" 1591 << " * Index: " << index << "\n" 1592 << " * Result: " << result << "\n"); 1593 storeToMemory(memIndex, result); 1594 } 1595 1596 void ByteCodeExecutor::executeFinalize() { 1597 LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n"); 1598 } 1599 1600 void ByteCodeExecutor::executeForEach() { 1601 LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n"); 1602 const ByteCodeField *prevCodeIt = getPrevCodeIt(); 1603 unsigned rangeIndex = read(); 1604 unsigned memIndex = read(); 1605 const void *value = nullptr; 1606 1607 switch (read<PDLValue::Kind>()) { 1608 case PDLValue::Kind::Operation: { 1609 unsigned &index = loopIndex[read()]; 1610 ArrayRef<Operation *> array = opRangeMemory[rangeIndex]; 1611 assert(index <= array.size() && "iterated past the end"); 1612 if (index < array.size()) { 1613 LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n"); 1614 value = array[index]; 1615 break; 1616 } 1617 1618 LLVM_DEBUG(llvm::dbgs() << " * Done\n"); 1619 index = 0; 1620 selectJump(size_t(0)); 1621 return; 1622 } 1623 default: 1624 llvm_unreachable("unexpected `ForEach` value kind"); 1625 } 1626 1627 // Store the iterate value and the stack address. 1628 memory[memIndex] = value; 1629 pushCodeIt(prevCodeIt); 1630 1631 // Skip over the successor (we will enter the body of the loop). 
1632 read<ByteCodeAddr>(); 1633 } 1634 1635 void ByteCodeExecutor::executeGetAttribute() { 1636 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); 1637 unsigned memIndex = read(); 1638 Operation *op = read<Operation *>(); 1639 StringAttr attrName = read<StringAttr>(); 1640 Attribute attr = op->getAttr(attrName); 1641 1642 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1643 << " * Attribute: " << attrName << "\n" 1644 << " * Result: " << attr << "\n"); 1645 memory[memIndex] = attr.getAsOpaquePointer(); 1646 } 1647 1648 void ByteCodeExecutor::executeGetAttributeType() { 1649 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n"); 1650 unsigned memIndex = read(); 1651 Attribute attr = read<Attribute>(); 1652 Type type = attr ? attr.getType() : Type(); 1653 1654 LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" 1655 << " * Result: " << type << "\n"); 1656 memory[memIndex] = type.getAsOpaquePointer(); 1657 } 1658 1659 void ByteCodeExecutor::executeGetDefiningOp() { 1660 LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); 1661 unsigned memIndex = read(); 1662 Operation *op = nullptr; 1663 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1664 Value value = read<Value>(); 1665 if (value) 1666 op = value.getDefiningOp(); 1667 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1668 } else { 1669 ValueRange *values = read<ValueRange *>(); 1670 if (values && !values->empty()) { 1671 op = values->front().getDefiningOp(); 1672 } 1673 LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n"); 1674 } 1675 1676 LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n"); 1677 memory[memIndex] = op; 1678 } 1679 1680 void ByteCodeExecutor::executeGetOperand(unsigned index) { 1681 Operation *op = read<Operation *>(); 1682 unsigned memIndex = read(); 1683 Value operand = 1684 index < op->getNumOperands() ? 
op->getOperand(index) : Value(); 1685 1686 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1687 << " * Index: " << index << "\n" 1688 << " * Result: " << operand << "\n"); 1689 memory[memIndex] = operand.getAsOpaquePointer(); 1690 } 1691 1692 /// This function is the internal implementation of `GetResults` and 1693 /// `GetOperands` that provides support for extracting a value range from the 1694 /// given operation. 1695 template <template <typename> class AttrSizedSegmentsT, typename RangeT> 1696 static void * 1697 executeGetOperandsResults(RangeT values, Operation *op, unsigned index, 1698 ByteCodeField rangeIndex, StringRef attrSizedSegments, 1699 MutableArrayRef<ValueRange> valueRangeMemory) { 1700 // Check for the sentinel index that signals that all values should be 1701 // returned. 1702 if (index == std::numeric_limits<uint32_t>::max()) { 1703 LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n"); 1704 // `values` is already the full value range. 1705 1706 // Otherwise, check to see if this operation uses AttrSizedSegments. 1707 } else if (op->hasTrait<AttrSizedSegmentsT>()) { 1708 LLVM_DEBUG(llvm::dbgs() 1709 << " * Extracting values from `" << attrSizedSegments << "`\n"); 1710 1711 auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments); 1712 if (!segmentAttr || segmentAttr.getNumElements() <= index) 1713 return nullptr; 1714 1715 auto segments = segmentAttr.getValues<int32_t>(); 1716 unsigned startIndex = 1717 std::accumulate(segments.begin(), segments.begin() + index, 0); 1718 values = values.slice(startIndex, *std::next(segments.begin(), index)); 1719 1720 LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", " 1721 << *std::next(segments.begin(), index) << "]\n"); 1722 1723 // Otherwise, assume this is the last operand group of the operation. 
1724 // FIXME: We currently don't support operations with 1725 // SameVariadicOperandSize/SameVariadicResultSize here given that we don't 1726 // have a way to detect it's presence. 1727 } else if (values.size() >= index) { 1728 LLVM_DEBUG(llvm::dbgs() 1729 << " * Treating values as trailing variadic range\n"); 1730 values = values.drop_front(index); 1731 1732 // If we couldn't detect a way to compute the values, bail out. 1733 } else { 1734 return nullptr; 1735 } 1736 1737 // If the range index is valid, we are returning a range. 1738 if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) { 1739 valueRangeMemory[rangeIndex] = values; 1740 return &valueRangeMemory[rangeIndex]; 1741 } 1742 1743 // If a range index wasn't provided, the range is required to be non-variadic. 1744 return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer(); 1745 } 1746 1747 void ByteCodeExecutor::executeGetOperands() { 1748 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n"); 1749 unsigned index = read<uint32_t>(); 1750 Operation *op = read<Operation *>(); 1751 ByteCodeField rangeIndex = read(); 1752 1753 void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>( 1754 op->getOperands(), op, index, rangeIndex, "operand_segment_sizes", 1755 valueRangeMemory); 1756 if (!result) 1757 LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n"); 1758 memory[read()] = result; 1759 } 1760 1761 void ByteCodeExecutor::executeGetResult(unsigned index) { 1762 Operation *op = read<Operation *>(); 1763 unsigned memIndex = read(); 1764 OpResult result = 1765 index < op->getNumResults() ? 
op->getResult(index) : OpResult(); 1766 1767 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1768 << " * Index: " << index << "\n" 1769 << " * Result: " << result << "\n"); 1770 memory[memIndex] = result.getAsOpaquePointer(); 1771 } 1772 1773 void ByteCodeExecutor::executeGetResults() { 1774 LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n"); 1775 unsigned index = read<uint32_t>(); 1776 Operation *op = read<Operation *>(); 1777 ByteCodeField rangeIndex = read(); 1778 1779 void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>( 1780 op->getResults(), op, index, rangeIndex, "result_segment_sizes", 1781 valueRangeMemory); 1782 if (!result) 1783 LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n"); 1784 memory[read()] = result; 1785 } 1786 1787 void ByteCodeExecutor::executeGetUsers() { 1788 LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n"); 1789 unsigned memIndex = read(); 1790 unsigned rangeIndex = read(); 1791 OwningOpRange &range = opRangeMemory[rangeIndex]; 1792 memory[memIndex] = ⦥ 1793 1794 range = OwningOpRange(); 1795 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1796 // Read the value. 1797 Value value = read<Value>(); 1798 if (!value) 1799 return; 1800 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1801 1802 // Extract the users of a single value. 1803 range = OwningOpRange(std::distance(value.user_begin(), value.user_end())); 1804 llvm::copy(value.getUsers(), range.begin()); 1805 } else { 1806 // Read a range of values. 1807 ValueRange *values = read<ValueRange *>(); 1808 if (!values) 1809 return; 1810 LLVM_DEBUG({ 1811 llvm::dbgs() << " * Values (" << values->size() << "): "; 1812 llvm::interleaveComma(*values, llvm::dbgs()); 1813 llvm::dbgs() << "\n"; 1814 }); 1815 1816 // Extract all the users of a range of values. 
1817 SmallVector<Operation *> users; 1818 for (Value value : *values) 1819 users.append(value.user_begin(), value.user_end()); 1820 range = OwningOpRange(users.size()); 1821 llvm::copy(users, range.begin()); 1822 } 1823 1824 LLVM_DEBUG(llvm::dbgs() << " * Result: " << range.size() << " operations\n"); 1825 } 1826 1827 void ByteCodeExecutor::executeGetValueType() { 1828 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); 1829 unsigned memIndex = read(); 1830 Value value = read<Value>(); 1831 Type type = value ? value.getType() : Type(); 1832 1833 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" 1834 << " * Result: " << type << "\n"); 1835 memory[memIndex] = type.getAsOpaquePointer(); 1836 } 1837 1838 void ByteCodeExecutor::executeGetValueRangeTypes() { 1839 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n"); 1840 unsigned memIndex = read(); 1841 unsigned rangeIndex = read(); 1842 ValueRange *values = read<ValueRange *>(); 1843 if (!values) { 1844 LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n"); 1845 memory[memIndex] = nullptr; 1846 return; 1847 } 1848 1849 LLVM_DEBUG({ 1850 llvm::dbgs() << " * Values (" << values->size() << "): "; 1851 llvm::interleaveComma(*values, llvm::dbgs()); 1852 llvm::dbgs() << "\n * Result: "; 1853 llvm::interleaveComma(values->getType(), llvm::dbgs()); 1854 llvm::dbgs() << "\n"; 1855 }); 1856 typeRangeMemory[rangeIndex] = values->getType(); 1857 memory[memIndex] = &typeRangeMemory[rangeIndex]; 1858 } 1859 1860 void ByteCodeExecutor::executeIsNotNull() { 1861 LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); 1862 const void *value = read<const void *>(); 1863 1864 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1865 selectJump(value != nullptr); 1866 } 1867 1868 void ByteCodeExecutor::executeRecordMatch( 1869 PatternRewriter &rewriter, 1870 SmallVectorImpl<PDLByteCode::MatchResult> &matches) { 1871 LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n"); 1872 unsigned patternIndex = read(); 1873 
PatternBenefit benefit = currentPatternBenefits[patternIndex]; 1874 const ByteCodeField *dest = &code[read<ByteCodeAddr>()]; 1875 1876 // If the benefit of the pattern is impossible, skip the processing of the 1877 // rest of the pattern. 1878 if (benefit.isImpossibleToMatch()) { 1879 LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n"); 1880 curCodeIt = dest; 1881 return; 1882 } 1883 1884 // Create a fused location containing the locations of each of the 1885 // operations used in the match. This will be used as the location for 1886 // created operations during the rewrite that don't already have an 1887 // explicit location set. 1888 unsigned numMatchLocs = read(); 1889 SmallVector<Location, 4> matchLocs; 1890 matchLocs.reserve(numMatchLocs); 1891 for (unsigned i = 0; i != numMatchLocs; ++i) 1892 matchLocs.push_back(read<Operation *>()->getLoc()); 1893 Location matchLoc = rewriter.getFusedLoc(matchLocs); 1894 1895 LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n" 1896 << " * Location: " << matchLoc << "\n"); 1897 matches.emplace_back(matchLoc, patterns[patternIndex], benefit); 1898 PDLByteCode::MatchResult &match = matches.back(); 1899 1900 // Record all of the inputs to the match. If any of the inputs are ranges, we 1901 // will also need to remap the range pointer to memory stored in the match 1902 // state. 
1903 unsigned numInputs = read(); 1904 match.values.reserve(numInputs); 1905 match.typeRangeValues.reserve(numInputs); 1906 match.valueRangeValues.reserve(numInputs); 1907 for (unsigned i = 0; i < numInputs; ++i) { 1908 switch (read<PDLValue::Kind>()) { 1909 case PDLValue::Kind::TypeRange: 1910 match.typeRangeValues.push_back(*read<TypeRange *>()); 1911 match.values.push_back(&match.typeRangeValues.back()); 1912 break; 1913 case PDLValue::Kind::ValueRange: 1914 match.valueRangeValues.push_back(*read<ValueRange *>()); 1915 match.values.push_back(&match.valueRangeValues.back()); 1916 break; 1917 default: 1918 match.values.push_back(read<const void *>()); 1919 break; 1920 } 1921 } 1922 curCodeIt = dest; 1923 } 1924 1925 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) { 1926 LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); 1927 Operation *op = read<Operation *>(); 1928 SmallVector<Value, 16> args; 1929 readValueList(args); 1930 1931 LLVM_DEBUG({ 1932 llvm::dbgs() << " * Operation: " << *op << "\n" 1933 << " * Values: "; 1934 llvm::interleaveComma(args, llvm::dbgs()); 1935 llvm::dbgs() << "\n"; 1936 }); 1937 rewriter.replaceOp(op, args); 1938 } 1939 1940 void ByteCodeExecutor::executeSwitchAttribute() { 1941 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n"); 1942 Attribute value = read<Attribute>(); 1943 ArrayAttr cases = read<ArrayAttr>(); 1944 handleSwitch(value, cases); 1945 } 1946 1947 void ByteCodeExecutor::executeSwitchOperandCount() { 1948 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n"); 1949 Operation *op = read<Operation *>(); 1950 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1951 1952 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1953 handleSwitch(op->getNumOperands(), cases); 1954 } 1955 1956 void ByteCodeExecutor::executeSwitchOperationName() { 1957 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n"); 1958 OperationName value = read<Operation *>()->getName(); 
1959 size_t caseCount = read(); 1960 1961 // The operation names are stored in-line, so to print them out for 1962 // debugging purposes we need to read the array before executing the 1963 // switch so that we can display all of the possible values. 1964 LLVM_DEBUG({ 1965 const ByteCodeField *prevCodeIt = curCodeIt; 1966 llvm::dbgs() << " * Value: " << value << "\n" 1967 << " * Cases: "; 1968 llvm::interleaveComma( 1969 llvm::map_range(llvm::seq<size_t>(0, caseCount), 1970 [&](size_t) { return read<OperationName>(); }), 1971 llvm::dbgs()); 1972 llvm::dbgs() << "\n"; 1973 curCodeIt = prevCodeIt; 1974 }); 1975 1976 // Try to find the switch value within any of the cases. 1977 for (size_t i = 0; i != caseCount; ++i) { 1978 if (read<OperationName>() == value) { 1979 curCodeIt += (caseCount - i - 1); 1980 return selectJump(i + 1); 1981 } 1982 } 1983 selectJump(size_t(0)); 1984 } 1985 1986 void ByteCodeExecutor::executeSwitchResultCount() { 1987 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n"); 1988 Operation *op = read<Operation *>(); 1989 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1990 1991 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1992 handleSwitch(op->getNumResults(), cases); 1993 } 1994 1995 void ByteCodeExecutor::executeSwitchType() { 1996 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n"); 1997 Type value = read<Type>(); 1998 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>(); 1999 handleSwitch(value, cases); 2000 } 2001 2002 void ByteCodeExecutor::executeSwitchTypes() { 2003 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n"); 2004 TypeRange *value = read<TypeRange *>(); 2005 auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>(); 2006 if (!value) { 2007 LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n"); 2008 return selectJump(size_t(0)); 2009 } 2010 handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) { 2011 return value == caseValue.getAsValueRange<TypeAttr>(); 2012 }); 
2013 } 2014 2015 void ByteCodeExecutor::execute( 2016 PatternRewriter &rewriter, 2017 SmallVectorImpl<PDLByteCode::MatchResult> *matches, 2018 Optional<Location> mainRewriteLoc) { 2019 while (true) { 2020 // Print the location of the operation being executed. 2021 LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n"); 2022 2023 OpCode opCode = static_cast<OpCode>(read()); 2024 switch (opCode) { 2025 case ApplyConstraint: 2026 executeApplyConstraint(rewriter); 2027 break; 2028 case ApplyRewrite: 2029 executeApplyRewrite(rewriter); 2030 break; 2031 case AreEqual: 2032 executeAreEqual(); 2033 break; 2034 case AreRangesEqual: 2035 executeAreRangesEqual(); 2036 break; 2037 case Branch: 2038 executeBranch(); 2039 break; 2040 case CheckOperandCount: 2041 executeCheckOperandCount(); 2042 break; 2043 case CheckOperationName: 2044 executeCheckOperationName(); 2045 break; 2046 case CheckResultCount: 2047 executeCheckResultCount(); 2048 break; 2049 case CheckTypes: 2050 executeCheckTypes(); 2051 break; 2052 case Continue: 2053 executeContinue(); 2054 break; 2055 case CreateOperation: 2056 executeCreateOperation(rewriter, *mainRewriteLoc); 2057 break; 2058 case CreateTypes: 2059 executeCreateTypes(); 2060 break; 2061 case EraseOp: 2062 executeEraseOp(rewriter); 2063 break; 2064 case ExtractOp: 2065 executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>(); 2066 break; 2067 case ExtractType: 2068 executeExtract<Type, TypeRange, PDLValue::Kind::Type>(); 2069 break; 2070 case ExtractValue: 2071 executeExtract<Value, ValueRange, PDLValue::Kind::Value>(); 2072 break; 2073 case Finalize: 2074 executeFinalize(); 2075 LLVM_DEBUG(llvm::dbgs() << "\n"); 2076 return; 2077 case ForEach: 2078 executeForEach(); 2079 break; 2080 case GetAttribute: 2081 executeGetAttribute(); 2082 break; 2083 case GetAttributeType: 2084 executeGetAttributeType(); 2085 break; 2086 case GetDefiningOp: 2087 executeGetDefiningOp(); 2088 break; 2089 case GetOperand0: 2090 case GetOperand1: 2091 
case GetOperand2: 2092 case GetOperand3: { 2093 unsigned index = opCode - GetOperand0; 2094 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n"); 2095 executeGetOperand(index); 2096 break; 2097 } 2098 case GetOperandN: 2099 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n"); 2100 executeGetOperand(read<uint32_t>()); 2101 break; 2102 case GetOperands: 2103 executeGetOperands(); 2104 break; 2105 case GetResult0: 2106 case GetResult1: 2107 case GetResult2: 2108 case GetResult3: { 2109 unsigned index = opCode - GetResult0; 2110 LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n"); 2111 executeGetResult(index); 2112 break; 2113 } 2114 case GetResultN: 2115 LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n"); 2116 executeGetResult(read<uint32_t>()); 2117 break; 2118 case GetResults: 2119 executeGetResults(); 2120 break; 2121 case GetUsers: 2122 executeGetUsers(); 2123 break; 2124 case GetValueType: 2125 executeGetValueType(); 2126 break; 2127 case GetValueRangeTypes: 2128 executeGetValueRangeTypes(); 2129 break; 2130 case IsNotNull: 2131 executeIsNotNull(); 2132 break; 2133 case RecordMatch: 2134 assert(matches && 2135 "expected matches to be provided when executing the matcher"); 2136 executeRecordMatch(rewriter, *matches); 2137 break; 2138 case ReplaceOp: 2139 executeReplaceOp(rewriter); 2140 break; 2141 case SwitchAttribute: 2142 executeSwitchAttribute(); 2143 break; 2144 case SwitchOperandCount: 2145 executeSwitchOperandCount(); 2146 break; 2147 case SwitchOperationName: 2148 executeSwitchOperationName(); 2149 break; 2150 case SwitchResultCount: 2151 executeSwitchResultCount(); 2152 break; 2153 case SwitchType: 2154 executeSwitchType(); 2155 break; 2156 case SwitchTypes: 2157 executeSwitchTypes(); 2158 break; 2159 } 2160 LLVM_DEBUG(llvm::dbgs() << "\n"); 2161 } 2162 } 2163 2164 /// Run the pattern matcher on the given root operation, collecting the matched 2165 /// patterns in `matches`. 
2166 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, 2167 SmallVectorImpl<MatchResult> &matches, 2168 PDLByteCodeMutableState &state) const { 2169 // The first memory slot is always the root operation. 2170 state.memory[0] = op; 2171 2172 // The matcher function always starts at code address 0. 2173 ByteCodeExecutor executor( 2174 matcherByteCode.data(), state.memory, state.opRangeMemory, 2175 state.typeRangeMemory, state.allocatedTypeRangeMemory, 2176 state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex, 2177 uniquedData, matcherByteCode, state.currentPatternBenefits, patterns, 2178 constraintFunctions, rewriteFunctions); 2179 executor.execute(rewriter, &matches); 2180 2181 // Order the found matches by benefit. 2182 std::stable_sort(matches.begin(), matches.end(), 2183 [](const MatchResult &lhs, const MatchResult &rhs) { 2184 return lhs.benefit > rhs.benefit; 2185 }); 2186 } 2187 2188 /// Run the rewriter of the given pattern on the root operation `op`. 2189 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, 2190 PDLByteCodeMutableState &state) const { 2191 // The arguments of the rewrite function are stored at the start of the 2192 // memory buffer. 2193 llvm::copy(match.values, state.memory.begin()); 2194 2195 ByteCodeExecutor executor( 2196 &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, 2197 state.opRangeMemory, state.typeRangeMemory, 2198 state.allocatedTypeRangeMemory, state.valueRangeMemory, 2199 state.allocatedValueRangeMemory, state.loopIndex, uniquedData, 2200 rewriterByteCode, state.currentPatternBenefits, patterns, 2201 constraintFunctions, rewriteFunctions); 2202 executor.execute(rewriter, /*matches=*/nullptr, match.location); 2203 } 2204