1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements MLIR to byte-code generation and the interpreter. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "ByteCode.h" 14 #include "mlir/Analysis/Liveness.h" 15 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/RegionGraphTraits.h" 19 #include "llvm/ADT/IntervalMap.h" 20 #include "llvm/ADT/PostOrderIterator.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 #include "llvm/Support/Debug.h" 23 #include "llvm/Support/Format.h" 24 #include "llvm/Support/FormatVariadic.h" 25 #include <numeric> 26 27 #define DEBUG_TYPE "pdl-bytecode" 28 29 using namespace mlir; 30 using namespace mlir::detail; 31 32 //===----------------------------------------------------------------------===// 33 // PDLByteCodePattern 34 //===----------------------------------------------------------------------===// 35 36 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp, 37 ByteCodeAddr rewriterAddr) { 38 SmallVector<StringRef, 8> generatedOps; 39 if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr()) 40 generatedOps = 41 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>()); 42 43 PatternBenefit benefit = matchOp.benefit(); 44 MLIRContext *ctx = matchOp.getContext(); 45 46 // Check to see if this is pattern matches a specific operation type. 47 if (Optional<StringRef> rootKind = matchOp.rootKind()) 48 return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx, 49 generatedOps); 50 return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx, 51 generatedOps); 52 } 53 54 //===----------------------------------------------------------------------===// 55 // PDLByteCodeMutableState 56 //===----------------------------------------------------------------------===// 57 58 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds 59 /// to the position of the pattern within the range returned by 60 /// `PDLByteCode::getPatterns`. 61 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex, 62 PatternBenefit benefit) { 63 currentPatternBenefits[patternIndex] = benefit; 64 } 65 66 /// Cleanup any allocated state after a full match/rewrite has been completed. 67 /// This method should be called irregardless of whether the match+rewrite was a 68 /// success or not. 69 void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() { 70 allocatedTypeRangeMemory.clear(); 71 allocatedValueRangeMemory.clear(); 72 } 73 74 //===----------------------------------------------------------------------===// 75 // Bytecode OpCodes 76 //===----------------------------------------------------------------------===// 77 78 namespace { 79 enum OpCode : ByteCodeField { 80 /// Apply an externally registered constraint. 81 ApplyConstraint, 82 /// Apply an externally registered rewrite. 83 ApplyRewrite, 84 /// Check if two generic values are equal. 85 AreEqual, 86 /// Check if two ranges are equal. 87 AreRangesEqual, 88 /// Unconditional branch. 89 Branch, 90 /// Compare the operand count of an operation with a constant. 91 CheckOperandCount, 92 /// Compare the name of an operation with a constant. 93 CheckOperationName, 94 /// Compare the result count of an operation with a constant. 95 CheckResultCount, 96 /// Compare a range of types to a constant range of types. 97 CheckTypes, 98 /// Continue to the next iteration of a loop. 99 Continue, 100 /// Create an operation. 101 CreateOperation, 102 /// Create a range of types. 103 CreateTypes, 104 /// Erase an operation. 105 EraseOp, 106 /// Extract the op from a range at the specified index. 107 ExtractOp, 108 /// Extract the type from a range at the specified index. 109 ExtractType, 110 /// Extract the value from a range at the specified index. 111 ExtractValue, 112 /// Terminate a matcher or rewrite sequence. 113 Finalize, 114 /// Iterate over a range of values. 115 ForEach, 116 /// Get a specific attribute of an operation. 117 GetAttribute, 118 /// Get the type of an attribute. 119 GetAttributeType, 120 /// Get the defining operation of a value. 121 GetDefiningOp, 122 /// Get a specific operand of an operation. 123 GetOperand0, 124 GetOperand1, 125 GetOperand2, 126 GetOperand3, 127 GetOperandN, 128 /// Get a specific operand group of an operation. 129 GetOperands, 130 /// Get a specific result of an operation. 131 GetResult0, 132 GetResult1, 133 GetResult2, 134 GetResult3, 135 GetResultN, 136 /// Get a specific result group of an operation. 137 GetResults, 138 /// Get the users of a value or a range of values. 139 GetUsers, 140 /// Get the type of a value. 141 GetValueType, 142 /// Get the types of a value range. 143 GetValueRangeTypes, 144 /// Check if a generic value is not null. 145 IsNotNull, 146 /// Record a successful pattern match. 147 RecordMatch, 148 /// Replace an operation. 149 ReplaceOp, 150 /// Compare an attribute with a set of constants. 151 SwitchAttribute, 152 /// Compare the operand count of an operation with a set of constants. 153 SwitchOperandCount, 154 /// Compare the name of an operation with a set of constants. 155 SwitchOperationName, 156 /// Compare the result count of an operation with a set of constants. 157 SwitchResultCount, 158 /// Compare a type with a set of constants. 159 SwitchType, 160 /// Compare a range of types with a set of constants. 161 SwitchTypes, 162 }; 163 } // namespace 164 165 //===----------------------------------------------------------------------===// 166 // ByteCode Generation 167 //===----------------------------------------------------------------------===// 168 169 //===----------------------------------------------------------------------===// 170 // Generator 171 172 namespace { 173 struct ByteCodeLiveRange; 174 struct ByteCodeWriter; 175 176 /// Check if the given class `T` can be converted to an opaque pointer. 177 template <typename T, typename... Args> 178 using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer()); 179 180 /// This class represents the main generator for the pattern bytecode. 181 class Generator { 182 public: 183 Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData, 184 SmallVectorImpl<ByteCodeField> &matcherByteCode, 185 SmallVectorImpl<ByteCodeField> &rewriterByteCode, 186 SmallVectorImpl<PDLByteCodePattern> &patterns, 187 ByteCodeField &maxValueMemoryIndex, 188 ByteCodeField &maxOpRangeMemoryIndex, 189 ByteCodeField &maxTypeRangeMemoryIndex, 190 ByteCodeField &maxValueRangeMemoryIndex, 191 ByteCodeField &maxLoopLevel, 192 llvm::StringMap<PDLConstraintFunction> &constraintFns, 193 llvm::StringMap<PDLRewriteFunction> &rewriteFns) 194 : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), 195 rewriterByteCode(rewriterByteCode), patterns(patterns), 196 maxValueMemoryIndex(maxValueMemoryIndex), 197 maxOpRangeMemoryIndex(maxOpRangeMemoryIndex), 198 maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex), 199 maxValueRangeMemoryIndex(maxValueRangeMemoryIndex), 200 maxLoopLevel(maxLoopLevel) { 201 for (auto it : llvm::enumerate(constraintFns)) 202 constraintToMemIndex.try_emplace(it.value().first(), it.index()); 203 for (auto it : llvm::enumerate(rewriteFns)) 204 externalRewriterToMemIndex.try_emplace(it.value().first(), it.index()); 205 } 206 207 /// Generate the bytecode for the given PDL interpreter module. 208 void generate(ModuleOp module); 209 210 /// Return the memory index to use for the given value. 211 ByteCodeField &getMemIndex(Value value) { 212 assert(valueToMemIndex.count(value) && 213 "expected memory index to be assigned"); 214 return valueToMemIndex[value]; 215 } 216 217 /// Return the range memory index used to store the given range value. 218 ByteCodeField &getRangeStorageIndex(Value value) { 219 assert(valueToRangeIndex.count(value) && 220 "expected range index to be assigned"); 221 return valueToRangeIndex[value]; 222 } 223 224 /// Return an index to use when referring to the given data that is uniqued in 225 /// the MLIR context. 226 template <typename T> 227 std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &> 228 getMemIndex(T val) { 229 const void *opaqueVal = val.getAsOpaquePointer(); 230 231 // Get or insert a reference to this value. 232 auto it = uniquedDataToMemIndex.try_emplace( 233 opaqueVal, maxValueMemoryIndex + uniquedData.size()); 234 if (it.second) 235 uniquedData.push_back(opaqueVal); 236 return it.first->second; 237 } 238 239 private: 240 /// Allocate memory indices for the results of operations within the matcher 241 /// and rewriters. 242 void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule); 243 244 /// Generate the bytecode for the given operation. 245 void generate(Region *region, ByteCodeWriter &writer); 246 void generate(Operation *op, ByteCodeWriter &writer); 247 void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer); 248 void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer); 249 void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer); 250 void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer); 251 void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer); 252 void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer); 253 void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer); 254 void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); 255 void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); 256 void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer); 257 void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer); 258 void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); 259 void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); 260 void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); 261 void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer); 262 void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); 263 void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer); 264 void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); 265 void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer); 266 void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); 267 void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); 268 void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); 269 void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer); 270 void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer); 271 void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); 272 void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer); 273 void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer); 274 void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); 275 void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer); 276 void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); 277 void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer); 278 void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); 279 void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer); 280 void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer); 281 void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer); 282 void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer); 283 void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer); 284 void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer); 285 286 /// Mapping from value to its corresponding memory index. 287 DenseMap<Value, ByteCodeField> valueToMemIndex; 288 289 /// Mapping from a range value to its corresponding range storage index. 290 DenseMap<Value, ByteCodeField> valueToRangeIndex; 291 292 /// Mapping from the name of an externally registered rewrite to its index in 293 /// the bytecode registry. 294 llvm::StringMap<ByteCodeField> externalRewriterToMemIndex; 295 296 /// Mapping from the name of an externally registered constraint to its index 297 /// in the bytecode registry. 298 llvm::StringMap<ByteCodeField> constraintToMemIndex; 299 300 /// Mapping from rewriter function name to the bytecode address of the 301 /// rewriter function in byte. 302 llvm::StringMap<ByteCodeAddr> rewriterToAddr; 303 304 /// Mapping from a uniqued storage object to its memory index within 305 /// `uniquedData`. 306 DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex; 307 308 /// The current level of the foreach loop. 309 ByteCodeField curLoopLevel = 0; 310 311 /// The current MLIR context. 312 MLIRContext *ctx; 313 314 /// Mapping from block to its address. 315 DenseMap<Block *, ByteCodeAddr> blockToAddr; 316 317 /// Data of the ByteCode class to be populated. 318 std::vector<const void *> &uniquedData; 319 SmallVectorImpl<ByteCodeField> &matcherByteCode; 320 SmallVectorImpl<ByteCodeField> &rewriterByteCode; 321 SmallVectorImpl<PDLByteCodePattern> &patterns; 322 ByteCodeField &maxValueMemoryIndex; 323 ByteCodeField &maxOpRangeMemoryIndex; 324 ByteCodeField &maxTypeRangeMemoryIndex; 325 ByteCodeField &maxValueRangeMemoryIndex; 326 ByteCodeField &maxLoopLevel; 327 }; 328 329 /// This class provides utilities for writing a bytecode stream. 330 struct ByteCodeWriter { 331 ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator) 332 : bytecode(bytecode), generator(generator) {} 333 334 /// Append a field to the bytecode. 335 void append(ByteCodeField field) { bytecode.push_back(field); } 336 void append(OpCode opCode) { bytecode.push_back(opCode); } 337 338 /// Append an address to the bytecode. 339 void append(ByteCodeAddr field) { 340 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 341 "unexpected ByteCode address size"); 342 343 ByteCodeField fieldParts[2]; 344 std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr)); 345 bytecode.append({fieldParts[0], fieldParts[1]}); 346 } 347 348 /// Append a single successor to the bytecode, the exact address will need to 349 /// be resolved later. 350 void append(Block *successor) { 351 // Add back a reference to the successor so that the address can be resolved 352 // later. 353 unresolvedSuccessorRefs[successor].push_back(bytecode.size()); 354 append(ByteCodeAddr(0)); 355 } 356 357 /// Append a successor range to the bytecode, the exact address will need to 358 /// be resolved later. 359 void append(SuccessorRange successors) { 360 for (Block *successor : successors) 361 append(successor); 362 } 363 364 /// Append a range of values that will be read as generic PDLValues. 365 void appendPDLValueList(OperandRange values) { 366 bytecode.push_back(values.size()); 367 for (Value value : values) 368 appendPDLValue(value); 369 } 370 371 /// Append a value as a PDLValue. 372 void appendPDLValue(Value value) { 373 appendPDLValueKind(value); 374 append(value); 375 } 376 377 /// Append the PDLValue::Kind of the given value. 378 void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); } 379 380 /// Append the PDLValue::Kind of the given type. 381 void appendPDLValueKind(Type type) { 382 PDLValue::Kind kind = 383 TypeSwitch<Type, PDLValue::Kind>(type) 384 .Case<pdl::AttributeType>( 385 [](Type) { return PDLValue::Kind::Attribute; }) 386 .Case<pdl::OperationType>( 387 [](Type) { return PDLValue::Kind::Operation; }) 388 .Case<pdl::RangeType>([](pdl::RangeType rangeTy) { 389 if (rangeTy.getElementType().isa<pdl::TypeType>()) 390 return PDLValue::Kind::TypeRange; 391 return PDLValue::Kind::ValueRange; 392 }) 393 .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; }) 394 .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; }); 395 bytecode.push_back(static_cast<ByteCodeField>(kind)); 396 } 397 398 /// Append a value that will be stored in a memory slot and not inline within 399 /// the bytecode. 400 template <typename T> 401 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value || 402 std::is_pointer<T>::value> 403 append(T value) { 404 bytecode.push_back(generator.getMemIndex(value)); 405 } 406 407 /// Append a range of values. 408 template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>> 409 std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value> 410 append(T range) { 411 bytecode.push_back(llvm::size(range)); 412 for (auto it : range) 413 append(it); 414 } 415 416 /// Append a variadic number of fields to the bytecode. 417 template <typename FieldTy, typename Field2Ty, typename... FieldTys> 418 void append(FieldTy field, Field2Ty field2, FieldTys... fields) { 419 append(field); 420 append(field2, fields...); 421 } 422 423 /// Appends a value as a pointer, stored inline within the bytecode. 424 template <typename T> 425 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value> 426 appendInline(T value) { 427 constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField); 428 const void *pointer = value.getAsOpaquePointer(); 429 ByteCodeField fieldParts[numParts]; 430 std::memcpy(fieldParts, &pointer, sizeof(const void *)); 431 bytecode.append(fieldParts, fieldParts + numParts); 432 } 433 434 /// Successor references in the bytecode that have yet to be resolved. 435 DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs; 436 437 /// The underlying bytecode buffer. 438 SmallVectorImpl<ByteCodeField> &bytecode; 439 440 /// The main generator producing PDL. 441 Generator &generator; 442 }; 443 444 /// This class represents a live range of PDL Interpreter values, containing 445 /// information about when values are live within a match/rewrite. 446 struct ByteCodeLiveRange { 447 using Set = llvm::IntervalMap<uint64_t, char, 16>; 448 using Allocator = Set::Allocator; 449 450 ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {} 451 452 /// Union this live range with the one provided. 453 void unionWith(const ByteCodeLiveRange &rhs) { 454 for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e; 455 ++it) 456 liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0); 457 } 458 459 /// Returns true if this range overlaps with the one provided. 460 bool overlaps(const ByteCodeLiveRange &rhs) const { 461 return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness) 462 .valid(); 463 } 464 465 /// A map representing the ranges of the match/rewrite that a value is live in 466 /// the interpreter. 467 /// 468 /// We use std::unique_ptr here, because IntervalMap does not provide a 469 /// correct copy or move constructor. We can eliminate the pointer once 470 /// https://reviews.llvm.org/D113240 lands. 471 std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness; 472 473 /// The operation range storage index for this range. 474 Optional<unsigned> opRangeIndex; 475 476 /// The type range storage index for this range. 477 Optional<unsigned> typeRangeIndex; 478 479 /// The value range storage index for this range. 480 Optional<unsigned> valueRangeIndex; 481 }; 482 } // namespace 483 484 void Generator::generate(ModuleOp module) { 485 FuncOp matcherFunc = module.lookupSymbol<FuncOp>( 486 pdl_interp::PDLInterpDialect::getMatcherFunctionName()); 487 ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>( 488 pdl_interp::PDLInterpDialect::getRewriterModuleName()); 489 assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module"); 490 491 // Allocate memory indices for the results of operations within the matcher 492 // and rewriters. 493 allocateMemoryIndices(matcherFunc, rewriterModule); 494 495 // Generate code for the rewriter functions. 496 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this); 497 for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) { 498 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size()); 499 for (Operation &op : rewriterFunc.getOps()) 500 generate(&op, rewriterByteCodeWriter); 501 } 502 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() && 503 "unexpected branches in rewriter function"); 504 505 // Generate code for the matcher function. 506 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this); 507 generate(&matcherFunc.getBody(), matcherByteCodeWriter); 508 509 // Resolve successor references in the matcher. 510 for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) { 511 ByteCodeAddr addr = blockToAddr[it.first]; 512 for (unsigned offsetToFix : it.second) 513 std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr)); 514 } 515 } 516 517 void Generator::allocateMemoryIndices(FuncOp matcherFunc, 518 ModuleOp rewriterModule) { 519 // Rewriters use simplistic allocation scheme that simply assigns an index to 520 // each result. 521 for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) { 522 ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0; 523 auto processRewriterValue = [&](Value val) { 524 valueToMemIndex.try_emplace(val, index++); 525 if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) { 526 Type elementTy = rangeType.getElementType(); 527 if (elementTy.isa<pdl::TypeType>()) 528 valueToRangeIndex.try_emplace(val, typeRangeIndex++); 529 else if (elementTy.isa<pdl::ValueType>()) 530 valueToRangeIndex.try_emplace(val, valueRangeIndex++); 531 } 532 }; 533 534 for (BlockArgument arg : rewriterFunc.getArguments()) 535 processRewriterValue(arg); 536 rewriterFunc.getBody().walk([&](Operation *op) { 537 for (Value result : op->getResults()) 538 processRewriterValue(result); 539 }); 540 if (index > maxValueMemoryIndex) 541 maxValueMemoryIndex = index; 542 if (typeRangeIndex > maxTypeRangeMemoryIndex) 543 maxTypeRangeMemoryIndex = typeRangeIndex; 544 if (valueRangeIndex > maxValueRangeMemoryIndex) 545 maxValueRangeMemoryIndex = valueRangeIndex; 546 } 547 548 // The matcher function uses a more sophisticated numbering that tries to 549 // minimize the number of memory indices assigned. This is done by determining 550 // a live range of the values within the matcher, then the allocation is just 551 // finding the minimal number of overlapping live ranges. This is essentially 552 // a simplified form of register allocation where we don't necessarily have a 553 // limited number of registers, but we still want to minimize the number used. 554 DenseMap<Operation *, unsigned> opToIndex; 555 matcherFunc.getBody().walk([&](Operation *op) { 556 opToIndex.insert(std::make_pair(op, opToIndex.size())); 557 }); 558 559 // Liveness info for each of the defs within the matcher. 560 ByteCodeLiveRange::Allocator allocator; 561 DenseMap<Value, ByteCodeLiveRange> valueDefRanges; 562 563 // Assign the root operation being matched to slot 0. 564 BlockArgument rootOpArg = matcherFunc.getArgument(0); 565 valueToMemIndex[rootOpArg] = 0; 566 567 // Walk each of the blocks, computing the def interval that the value is used. 568 Liveness matcherLiveness(matcherFunc); 569 matcherFunc->walk([&](Block *block) { 570 const LivenessBlockInfo *info = matcherLiveness.getLiveness(block); 571 assert(info && "expected liveness info for block"); 572 auto processValue = [&](Value value, Operation *firstUseOrDef) { 573 // We don't need to process the root op argument, this value is always 574 // assigned to the first memory slot. 575 if (value == rootOpArg) 576 return; 577 578 // Set indices for the range of this block that the value is used. 579 auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; 580 defRangeIt->second.liveness->insert( 581 opToIndex[firstUseOrDef], 582 opToIndex[info->getEndOperation(value, firstUseOrDef)], 583 /*dummyValue*/ 0); 584 585 // Check to see if this value is a range type. 586 if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) { 587 Type eleType = rangeTy.getElementType(); 588 if (eleType.isa<pdl::OperationType>()) 589 defRangeIt->second.opRangeIndex = 0; 590 else if (eleType.isa<pdl::TypeType>()) 591 defRangeIt->second.typeRangeIndex = 0; 592 else if (eleType.isa<pdl::ValueType>()) 593 defRangeIt->second.valueRangeIndex = 0; 594 } 595 }; 596 597 // Process the live-ins of this block. 598 for (Value liveIn : info->in()) { 599 // Only process the value if it has been defined in the current region. 600 // Other values that span across pdl_interp.foreach will be added higher 601 // up. This ensures that the we keep them alive for the entire duration 602 // of the loop. 603 if (liveIn.getParentRegion() == block->getParent()) 604 processValue(liveIn, &block->front()); 605 } 606 607 // Process the block arguments for the entry block (those are not live-in). 608 if (block->isEntryBlock()) { 609 for (Value argument : block->getArguments()) 610 processValue(argument, &block->front()); 611 } 612 613 // Process any new defs within this block. 614 for (Operation &op : *block) 615 for (Value result : op.getResults()) 616 processValue(result, &op); 617 }); 618 619 // Greedily allocate memory slots using the computed def live ranges. 620 std::vector<ByteCodeLiveRange> allocatedIndices; 621 622 // The number of memory indices currently allocated (and its next value). 623 // Recall that the root gets allocated memory index 0. 624 ByteCodeField numIndices = 1; 625 626 // The number of memory ranges of various types (and their next values). 627 ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0; 628 629 for (auto &defIt : valueDefRanges) { 630 ByteCodeField &memIndex = valueToMemIndex[defIt.first]; 631 ByteCodeLiveRange &defRange = defIt.second; 632 633 // Try to allocate to an existing index. 634 for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) { 635 ByteCodeLiveRange &existingRange = existingIndexIt.value(); 636 if (!defRange.overlaps(existingRange)) { 637 existingRange.unionWith(defRange); 638 memIndex = existingIndexIt.index() + 1; 639 640 if (defRange.opRangeIndex) { 641 if (!existingRange.opRangeIndex) 642 existingRange.opRangeIndex = numOpRanges++; 643 valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex; 644 } else if (defRange.typeRangeIndex) { 645 if (!existingRange.typeRangeIndex) 646 existingRange.typeRangeIndex = numTypeRanges++; 647 valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex; 648 } else if (defRange.valueRangeIndex) { 649 if (!existingRange.valueRangeIndex) 650 existingRange.valueRangeIndex = numValueRanges++; 651 valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex; 652 } 653 break; 654 } 655 } 656 657 // If no existing index could be used, add a new one. 658 if (memIndex == 0) { 659 allocatedIndices.emplace_back(allocator); 660 ByteCodeLiveRange &newRange = allocatedIndices.back(); 661 newRange.unionWith(defRange); 662 663 // Allocate an index for op/type/value ranges. 664 if (defRange.opRangeIndex) { 665 newRange.opRangeIndex = numOpRanges; 666 valueToRangeIndex[defIt.first] = numOpRanges++; 667 } else if (defRange.typeRangeIndex) { 668 newRange.typeRangeIndex = numTypeRanges; 669 valueToRangeIndex[defIt.first] = numTypeRanges++; 670 } else if (defRange.valueRangeIndex) { 671 newRange.valueRangeIndex = numValueRanges; 672 valueToRangeIndex[defIt.first] = numValueRanges++; 673 } 674 675 memIndex = allocatedIndices.size(); 676 ++numIndices; 677 } 678 } 679 680 // Print the index usage and ensure that we did not run out of index space. 681 LLVM_DEBUG({ 682 llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices " 683 << "(down from initial " << valueDefRanges.size() << ").\n"; 684 }); 685 assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() && 686 "Ran out of memory for allocated indices"); 687 688 // Update the max number of indices. 689 if (numIndices > maxValueMemoryIndex) 690 maxValueMemoryIndex = numIndices; 691 if (numOpRanges > maxOpRangeMemoryIndex) 692 maxOpRangeMemoryIndex = numOpRanges; 693 if (numTypeRanges > maxTypeRangeMemoryIndex) 694 maxTypeRangeMemoryIndex = numTypeRanges; 695 if (numValueRanges > maxValueRangeMemoryIndex) 696 maxValueRangeMemoryIndex = numValueRanges; 697 } 698 699 void Generator::generate(Region *region, ByteCodeWriter &writer) { 700 llvm::ReversePostOrderTraversal<Region *> rpot(region); 701 for (Block *block : rpot) { 702 // Keep track of where this block begins within the matcher function. 703 blockToAddr.try_emplace(block, matcherByteCode.size()); 704 for (Operation &op : *block) 705 generate(&op, writer); 706 } 707 } 708 709 void Generator::generate(Operation *op, ByteCodeWriter &writer) { 710 LLVM_DEBUG({ 711 // The following list must contain all the operations that do not 712 // produce any bytecode. 713 if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp, 714 pdl_interp::InferredTypesOp>(op)) 715 writer.appendInline(op->getLoc()); 716 }); 717 TypeSwitch<Operation *>(op) 718 .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp, 719 pdl_interp::AreEqualOp, pdl_interp::BranchOp, 720 pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp, 721 pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp, 722 pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp, 723 pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp, 724 pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp, 725 pdl_interp::CreateTypesOp, pdl_interp::EraseOp, 726 pdl_interp::ExtractOp, pdl_interp::FinalizeOp, 727 pdl_interp::ForEachOp, pdl_interp::GetAttributeOp, 728 pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp, 729 pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp, 730 pdl_interp::GetResultOp, pdl_interp::GetResultsOp, 731 pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp, 732 pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp, 733 pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp, 734 pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp, 735 pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp, 736 pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>( 737 [&](auto interpOp) { this->generate(interpOp, writer); }) 738 .Default([](Operation *) { 739 llvm_unreachable("unknown `pdl_interp` operation"); 740 }); 741 } 742 743 void Generator::generate(pdl_interp::ApplyConstraintOp op, 744 ByteCodeWriter &writer) { 745 assert(constraintToMemIndex.count(op.name()) && 746 "expected index for constraint function"); 747 writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()], 748 op.constParamsAttr()); 749 writer.appendPDLValueList(op.args()); 750 writer.append(op.getSuccessors()); 751 } 752 void Generator::generate(pdl_interp::ApplyRewriteOp op, 753 ByteCodeWriter &writer) { 754 assert(externalRewriterToMemIndex.count(op.name()) && 755 "expected index for rewrite function"); 756 writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()], 757 op.constParamsAttr()); 758 writer.appendPDLValueList(op.args()); 759 760 ResultRange results = op.results(); 761 writer.append(ByteCodeField(results.size())); 762 for (Value result : results) { 763 // In debug mode we also record the expected kind of the result, so that we 764 // can provide extra verification of the native rewrite function. 765 #ifndef NDEBUG 766 writer.appendPDLValueKind(result); 767 #endif 768 769 // Range results also need to append the range storage index. 770 if (result.getType().isa<pdl::RangeType>()) 771 writer.append(getRangeStorageIndex(result)); 772 writer.append(result); 773 } 774 } 775 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) { 776 Value lhs = op.lhs(); 777 if (lhs.getType().isa<pdl::RangeType>()) { 778 writer.append(OpCode::AreRangesEqual); 779 writer.appendPDLValueKind(lhs); 780 writer.append(op.lhs(), op.rhs(), op.getSuccessors()); 781 return; 782 } 783 784 writer.append(OpCode::AreEqual, lhs, op.rhs(), op.getSuccessors()); 785 } 786 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) { 787 writer.append(OpCode::Branch, SuccessorRange(op.getOperation())); 788 } 789 void Generator::generate(pdl_interp::CheckAttributeOp op, 790 ByteCodeWriter &writer) { 791 writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(), 792 op.getSuccessors()); 793 } 794 void Generator::generate(pdl_interp::CheckOperandCountOp op, 795 ByteCodeWriter &writer) { 796 writer.append(OpCode::CheckOperandCount, op.operation(), op.count(), 797 static_cast<ByteCodeField>(op.compareAtLeast()), 798 op.getSuccessors()); 799 } 800 void Generator::generate(pdl_interp::CheckOperationNameOp op, 801 ByteCodeWriter &writer) { 802 writer.append(OpCode::CheckOperationName, op.operation(), 803 OperationName(op.name(), ctx), op.getSuccessors()); 804 } 805 void Generator::generate(pdl_interp::CheckResultCountOp op, 806 ByteCodeWriter &writer) { 807 writer.append(OpCode::CheckResultCount, op.operation(), op.count(), 808 static_cast<ByteCodeField>(op.compareAtLeast()), 809 op.getSuccessors()); 810 } 811 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) { 812 writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors()); 813 } 814 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) { 815 writer.append(OpCode::CheckTypes, op.value(), op.types(), op.getSuccessors()); 816 } 817 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) { 818 assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level"); 819 writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1)); 820 } 821 void Generator::generate(pdl_interp::CreateAttributeOp op, 822 ByteCodeWriter &writer) { 823 // Simply repoint the memory index of the result to the constant. 824 getMemIndex(op.attribute()) = getMemIndex(op.value()); 825 } 826 void Generator::generate(pdl_interp::CreateOperationOp op, 827 ByteCodeWriter &writer) { 828 writer.append(OpCode::CreateOperation, op.operation(), 829 OperationName(op.name(), ctx)); 830 writer.appendPDLValueList(op.operands()); 831 832 // Add the attributes. 833 OperandRange attributes = op.attributes(); 834 writer.append(static_cast<ByteCodeField>(attributes.size())); 835 for (auto it : llvm::zip(op.attributeNames(), op.attributes())) 836 writer.append(std::get<0>(it), std::get<1>(it)); 837 writer.appendPDLValueList(op.types()); 838 } 839 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) { 840 // Simply repoint the memory index of the result to the constant. 841 getMemIndex(op.result()) = getMemIndex(op.value()); 842 } 843 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) { 844 writer.append(OpCode::CreateTypes, op.result(), 845 getRangeStorageIndex(op.result()), op.value()); 846 } 847 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) { 848 writer.append(OpCode::EraseOp, op.operation()); 849 } 850 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) { 851 OpCode opCode = 852 TypeSwitch<Type, OpCode>(op.result().getType()) 853 .Case([](pdl::OperationType) { return OpCode::ExtractOp; }) 854 .Case([](pdl::ValueType) { return OpCode::ExtractValue; }) 855 .Case([](pdl::TypeType) { return OpCode::ExtractType; }) 856 .Default([](Type) -> OpCode { 857 llvm_unreachable("unsupported element type"); 858 }); 859 writer.append(opCode, op.range(), op.index(), op.result()); 860 } 861 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) { 862 writer.append(OpCode::Finalize); 863 } 864 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) { 865 BlockArgument arg = op.getLoopVariable(); 866 writer.append(OpCode::ForEach, getRangeStorageIndex(op.values()), arg); 867 writer.appendPDLValueKind(arg.getType()); 868 writer.append(curLoopLevel, op.successor()); 869 ++curLoopLevel; 870 if (curLoopLevel > maxLoopLevel) 871 maxLoopLevel = curLoopLevel; 872 generate(&op.region(), writer); 873 --curLoopLevel; 874 } 875 void Generator::generate(pdl_interp::GetAttributeOp op, 876 ByteCodeWriter &writer) { 877 writer.append(OpCode::GetAttribute, op.attribute(), op.operation(), 878 op.nameAttr()); 879 } 880 void Generator::generate(pdl_interp::GetAttributeTypeOp op, 881 ByteCodeWriter &writer) { 882 writer.append(OpCode::GetAttributeType, op.result(), op.value()); 883 } 884 void Generator::generate(pdl_interp::GetDefiningOpOp op, 885 ByteCodeWriter &writer) { 886 writer.append(OpCode::GetDefiningOp, op.operation()); 887 writer.appendPDLValue(op.value()); 888 } 889 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) { 890 uint32_t index = op.index(); 891 if (index < 4) 892 writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index)); 893 else 894 writer.append(OpCode::GetOperandN, index); 895 writer.append(op.operation(), op.value()); 896 } 897 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) { 898 Value result = op.value(); 899 Optional<uint32_t> index = op.index(); 900 writer.append(OpCode::GetOperands, 901 index.getValueOr(std::numeric_limits<uint32_t>::max()), 902 op.operation()); 903 if (result.getType().isa<pdl::RangeType>()) 904 writer.append(getRangeStorageIndex(result)); 905 else 906 writer.append(std::numeric_limits<ByteCodeField>::max()); 907 writer.append(result); 908 } 909 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) { 910 uint32_t index = op.index(); 911 if (index < 4) 912 writer.append(static_cast<OpCode>(OpCode::GetResult0 + index)); 913 else 914 writer.append(OpCode::GetResultN, index); 915 writer.append(op.operation(), op.value()); 916 } 917 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) { 918 Value result = op.value(); 919 Optional<uint32_t> index = op.index(); 920 writer.append(OpCode::GetResults, 921 index.getValueOr(std::numeric_limits<uint32_t>::max()), 922 op.operation()); 923 if (result.getType().isa<pdl::RangeType>()) 924 writer.append(getRangeStorageIndex(result)); 925 else 926 writer.append(std::numeric_limits<ByteCodeField>::max()); 927 writer.append(result); 928 } 929 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) { 930 Value operations = op.operations(); 931 ByteCodeField rangeIndex = getRangeStorageIndex(operations); 932 writer.append(OpCode::GetUsers, operations, rangeIndex); 933 writer.appendPDLValue(op.value()); 934 } 935 void Generator::generate(pdl_interp::GetValueTypeOp op, 936 ByteCodeWriter &writer) { 937 if (op.getType().isa<pdl::RangeType>()) { 938 Value result = op.result(); 939 writer.append(OpCode::GetValueRangeTypes, result, 940 getRangeStorageIndex(result), op.value()); 941 } else { 942 writer.append(OpCode::GetValueType, op.result(), op.value()); 943 } 944 } 945 946 void Generator::generate(pdl_interp::InferredTypesOp op, 947 ByteCodeWriter &writer) { 948 // InferType maps to a null type as a marker for inferring result types. 949 getMemIndex(op.type()) = getMemIndex(Type()); 950 } 951 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) { 952 writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors()); 953 } 954 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) { 955 ByteCodeField patternIndex = patterns.size(); 956 patterns.emplace_back(PDLByteCodePattern::create( 957 op, rewriterToAddr[op.rewriter().getLeafReference().getValue()])); 958 writer.append(OpCode::RecordMatch, patternIndex, 959 SuccessorRange(op.getOperation()), op.matchedOps()); 960 writer.appendPDLValueList(op.inputs()); 961 } 962 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) { 963 writer.append(OpCode::ReplaceOp, op.operation()); 964 writer.appendPDLValueList(op.replValues()); 965 } 966 void Generator::generate(pdl_interp::SwitchAttributeOp op, 967 ByteCodeWriter &writer) { 968 writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(), 969 op.getSuccessors()); 970 } 971 void Generator::generate(pdl_interp::SwitchOperandCountOp op, 972 ByteCodeWriter &writer) { 973 writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(), 974 op.getSuccessors()); 975 } 976 void Generator::generate(pdl_interp::SwitchOperationNameOp op, 977 ByteCodeWriter &writer) { 978 auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) { 979 return OperationName(attr.cast<StringAttr>().getValue(), ctx); 980 }); 981 writer.append(OpCode::SwitchOperationName, op.operation(), cases, 982 op.getSuccessors()); 983 } 984 void Generator::generate(pdl_interp::SwitchResultCountOp op, 985 ByteCodeWriter &writer) { 986 writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(), 987 op.getSuccessors()); 988 } 989 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) { 990 writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(), 991 op.getSuccessors()); 992 } 993 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) { 994 writer.append(OpCode::SwitchTypes, op.value(), op.caseValuesAttr(), 995 op.getSuccessors()); 996 } 997 998 //===----------------------------------------------------------------------===// 999 // PDLByteCode 1000 //===----------------------------------------------------------------------===// 1001 1002 PDLByteCode::PDLByteCode(ModuleOp module, 1003 llvm::StringMap<PDLConstraintFunction> constraintFns, 1004 llvm::StringMap<PDLRewriteFunction> rewriteFns) { 1005 Generator generator(module.getContext(), uniquedData, matcherByteCode, 1006 rewriterByteCode, patterns, maxValueMemoryIndex, 1007 maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount, 1008 maxLoopLevel, constraintFns, rewriteFns); 1009 generator.generate(module); 1010 1011 // Initialize the external functions. 1012 for (auto &it : constraintFns) 1013 constraintFunctions.push_back(std::move(it.second)); 1014 for (auto &it : rewriteFns) 1015 rewriteFunctions.push_back(std::move(it.second)); 1016 } 1017 1018 /// Initialize the given state such that it can be used to execute the current 1019 /// bytecode. 1020 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { 1021 state.memory.resize(maxValueMemoryIndex, nullptr); 1022 state.opRangeMemory.resize(maxOpRangeCount); 1023 state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange()); 1024 state.valueRangeMemory.resize(maxValueRangeCount, ValueRange()); 1025 state.loopIndex.resize(maxLoopLevel, 0); 1026 state.currentPatternBenefits.reserve(patterns.size()); 1027 for (const PDLByteCodePattern &pattern : patterns) 1028 state.currentPatternBenefits.push_back(pattern.getBenefit()); 1029 } 1030 1031 //===----------------------------------------------------------------------===// 1032 // ByteCode Execution 1033 1034 namespace { 1035 /// This class provides support for executing a bytecode stream. 1036 class ByteCodeExecutor { 1037 public: 1038 ByteCodeExecutor( 1039 const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory, 1040 MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory, 1041 MutableArrayRef<TypeRange> typeRangeMemory, 1042 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory, 1043 MutableArrayRef<ValueRange> valueRangeMemory, 1044 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory, 1045 MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory, 1046 ArrayRef<ByteCodeField> code, 1047 ArrayRef<PatternBenefit> currentPatternBenefits, 1048 ArrayRef<PDLByteCodePattern> patterns, 1049 ArrayRef<PDLConstraintFunction> constraintFunctions, 1050 ArrayRef<PDLRewriteFunction> rewriteFunctions) 1051 : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory), 1052 typeRangeMemory(typeRangeMemory), 1053 allocatedTypeRangeMemory(allocatedTypeRangeMemory), 1054 valueRangeMemory(valueRangeMemory), 1055 allocatedValueRangeMemory(allocatedValueRangeMemory), 1056 loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code), 1057 currentPatternBenefits(currentPatternBenefits), patterns(patterns), 1058 constraintFunctions(constraintFunctions), 1059 rewriteFunctions(rewriteFunctions) {} 1060 1061 /// Start executing the code at the current bytecode index. `matches` is an 1062 /// optional field provided when this function is executed in a matching 1063 /// context. 1064 void execute(PatternRewriter &rewriter, 1065 SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr, 1066 Optional<Location> mainRewriteLoc = {}); 1067 1068 private: 1069 /// Internal implementation of executing each of the bytecode commands. 1070 void executeApplyConstraint(PatternRewriter &rewriter); 1071 void executeApplyRewrite(PatternRewriter &rewriter); 1072 void executeAreEqual(); 1073 void executeAreRangesEqual(); 1074 void executeBranch(); 1075 void executeCheckOperandCount(); 1076 void executeCheckOperationName(); 1077 void executeCheckResultCount(); 1078 void executeCheckTypes(); 1079 void executeContinue(); 1080 void executeCreateOperation(PatternRewriter &rewriter, 1081 Location mainRewriteLoc); 1082 void executeCreateTypes(); 1083 void executeEraseOp(PatternRewriter &rewriter); 1084 template <typename T, typename Range, PDLValue::Kind kind> 1085 void executeExtract(); 1086 void executeFinalize(); 1087 void executeForEach(); 1088 void executeGetAttribute(); 1089 void executeGetAttributeType(); 1090 void executeGetDefiningOp(); 1091 void executeGetOperand(unsigned index); 1092 void executeGetOperands(); 1093 void executeGetResult(unsigned index); 1094 void executeGetResults(); 1095 void executeGetUsers(); 1096 void executeGetValueType(); 1097 void executeGetValueRangeTypes(); 1098 void executeIsNotNull(); 1099 void executeRecordMatch(PatternRewriter &rewriter, 1100 SmallVectorImpl<PDLByteCode::MatchResult> &matches); 1101 void executeReplaceOp(PatternRewriter &rewriter); 1102 void executeSwitchAttribute(); 1103 void executeSwitchOperandCount(); 1104 void executeSwitchOperationName(); 1105 void executeSwitchResultCount(); 1106 void executeSwitchType(); 1107 void executeSwitchTypes(); 1108 1109 /// Pushes a code iterator to the stack. 1110 void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); } 1111 1112 /// Pops a code iterator from the stack, returning true on success. 1113 void popCodeIt() { 1114 assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack"); 1115 curCodeIt = resumeCodeIt.back(); 1116 resumeCodeIt.pop_back(); 1117 } 1118 1119 /// Return the bytecode iterator at the start of the current op code. 1120 const ByteCodeField *getPrevCodeIt() const { 1121 LLVM_DEBUG({ 1122 // Account for the op code and the Location stored inline. 1123 return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField); 1124 }); 1125 1126 // Account for the op code only. 1127 return curCodeIt - 1; 1128 } 1129 1130 /// Read a value from the bytecode buffer, optionally skipping a certain 1131 /// number of prefix values. These methods always update the buffer to point 1132 /// to the next field after the read data. 1133 template <typename T = ByteCodeField> 1134 T read(size_t skipN = 0) { 1135 curCodeIt += skipN; 1136 return readImpl<T>(); 1137 } 1138 ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); } 1139 1140 /// Read a list of values from the bytecode buffer. 1141 template <typename ValueT, typename T> 1142 void readList(SmallVectorImpl<T> &list) { 1143 list.clear(); 1144 for (unsigned i = 0, e = read(); i != e; ++i) 1145 list.push_back(read<ValueT>()); 1146 } 1147 1148 /// Read a list of values from the bytecode buffer. The values may be encoded 1149 /// as either Value or ValueRange elements. 1150 void readValueList(SmallVectorImpl<Value> &list) { 1151 for (unsigned i = 0, e = read(); i != e; ++i) { 1152 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1153 list.push_back(read<Value>()); 1154 } else { 1155 ValueRange *values = read<ValueRange *>(); 1156 list.append(values->begin(), values->end()); 1157 } 1158 } 1159 } 1160 1161 /// Read a value stored inline as a pointer. 1162 template <typename T> 1163 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T> 1164 readInline() { 1165 const void *pointer; 1166 std::memcpy(&pointer, curCodeIt, sizeof(const void *)); 1167 curCodeIt += sizeof(const void *) / sizeof(ByteCodeField); 1168 return T::getFromOpaquePointer(pointer); 1169 } 1170 1171 /// Jump to a specific successor based on a predicate value. 1172 void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); } 1173 /// Jump to a specific successor based on a destination index. 1174 void selectJump(size_t destIndex) { 1175 curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)]; 1176 } 1177 1178 /// Handle a switch operation with the provided value and cases. 1179 template <typename T, typename RangeT, typename Comparator = std::equal_to<T>> 1180 void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) { 1181 LLVM_DEBUG({ 1182 llvm::dbgs() << " * Value: " << value << "\n" 1183 << " * Cases: "; 1184 llvm::interleaveComma(cases, llvm::dbgs()); 1185 llvm::dbgs() << "\n"; 1186 }); 1187 1188 // Check to see if the attribute value is within the case list. Jump to 1189 // the correct successor index based on the result. 1190 for (auto it = cases.begin(), e = cases.end(); it != e; ++it) 1191 if (cmp(*it, value)) 1192 return selectJump(size_t((it - cases.begin()) + 1)); 1193 selectJump(size_t(0)); 1194 } 1195 1196 /// Store a pointer to memory. 1197 void storeToMemory(unsigned index, const void *value) { 1198 memory[index] = value; 1199 } 1200 1201 /// Store a value to memory as an opaque pointer. 1202 template <typename T> 1203 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value> 1204 storeToMemory(unsigned index, T value) { 1205 memory[index] = value.getAsOpaquePointer(); 1206 } 1207 1208 /// Internal implementation of reading various data types from the bytecode 1209 /// stream. 1210 template <typename T> 1211 const void *readFromMemory() { 1212 size_t index = *curCodeIt++; 1213 1214 // If this type is an SSA value, it can only be stored in non-const memory. 1215 if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *, 1216 Value>::value || 1217 index < memory.size()) 1218 return memory[index]; 1219 1220 // Otherwise, if this index is not inbounds it is uniqued. 1221 return uniquedMemory[index - memory.size()]; 1222 } 1223 template <typename T> 1224 std::enable_if_t<std::is_pointer<T>::value, T> readImpl() { 1225 return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>())); 1226 } 1227 template <typename T> 1228 std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value, 1229 T> 1230 readImpl() { 1231 return T(T::getFromOpaquePointer(readFromMemory<T>())); 1232 } 1233 template <typename T> 1234 std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() { 1235 switch (read<PDLValue::Kind>()) { 1236 case PDLValue::Kind::Attribute: 1237 return read<Attribute>(); 1238 case PDLValue::Kind::Operation: 1239 return read<Operation *>(); 1240 case PDLValue::Kind::Type: 1241 return read<Type>(); 1242 case PDLValue::Kind::Value: 1243 return read<Value>(); 1244 case PDLValue::Kind::TypeRange: 1245 return read<TypeRange *>(); 1246 case PDLValue::Kind::ValueRange: 1247 return read<ValueRange *>(); 1248 } 1249 llvm_unreachable("unhandled PDLValue::Kind"); 1250 } 1251 template <typename T> 1252 std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() { 1253 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 1254 "unexpected ByteCode address size"); 1255 ByteCodeAddr result; 1256 std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr)); 1257 curCodeIt += 2; 1258 return result; 1259 } 1260 template <typename T> 1261 std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() { 1262 return *curCodeIt++; 1263 } 1264 template <typename T> 1265 std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() { 1266 return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>()); 1267 } 1268 1269 /// The underlying bytecode buffer. 1270 const ByteCodeField *curCodeIt; 1271 1272 /// The stack of bytecode positions at which to resume operation. 1273 SmallVector<const ByteCodeField *> resumeCodeIt; 1274 1275 /// The current execution memory. 1276 MutableArrayRef<const void *> memory; 1277 MutableArrayRef<OwningOpRange> opRangeMemory; 1278 MutableArrayRef<TypeRange> typeRangeMemory; 1279 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory; 1280 MutableArrayRef<ValueRange> valueRangeMemory; 1281 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory; 1282 1283 /// The current loop indices. 1284 MutableArrayRef<unsigned> loopIndex; 1285 1286 /// References to ByteCode data necessary for execution. 1287 ArrayRef<const void *> uniquedMemory; 1288 ArrayRef<ByteCodeField> code; 1289 ArrayRef<PatternBenefit> currentPatternBenefits; 1290 ArrayRef<PDLByteCodePattern> patterns; 1291 ArrayRef<PDLConstraintFunction> constraintFunctions; 1292 ArrayRef<PDLRewriteFunction> rewriteFunctions; 1293 }; 1294 1295 /// This class is an instantiation of the PDLResultList that provides access to 1296 /// the returned results. This API is not on `PDLResultList` to avoid 1297 /// overexposing access to information specific solely to the ByteCode. 1298 class ByteCodeRewriteResultList : public PDLResultList { 1299 public: 1300 ByteCodeRewriteResultList(unsigned maxNumResults) 1301 : PDLResultList(maxNumResults) {} 1302 1303 /// Return the list of PDL results. 1304 MutableArrayRef<PDLValue> getResults() { return results; } 1305 1306 /// Return the type ranges allocated by this list. 1307 MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() { 1308 return allocatedTypeRanges; 1309 } 1310 1311 /// Return the value ranges allocated by this list. 1312 MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() { 1313 return allocatedValueRanges; 1314 } 1315 }; 1316 } // namespace 1317 1318 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { 1319 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); 1320 const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; 1321 ArrayAttr constParams = read<ArrayAttr>(); 1322 SmallVector<PDLValue, 16> args; 1323 readList<PDLValue>(args); 1324 1325 LLVM_DEBUG({ 1326 llvm::dbgs() << " * Arguments: "; 1327 llvm::interleaveComma(args, llvm::dbgs()); 1328 llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; 1329 }); 1330 1331 // Invoke the constraint and jump to the proper destination. 1332 selectJump(succeeded(constraintFn(args, constParams, rewriter))); 1333 } 1334 1335 void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { 1336 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); 1337 const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; 1338 ArrayAttr constParams = read<ArrayAttr>(); 1339 SmallVector<PDLValue, 16> args; 1340 readList<PDLValue>(args); 1341 1342 LLVM_DEBUG({ 1343 llvm::dbgs() << " * Arguments: "; 1344 llvm::interleaveComma(args, llvm::dbgs()); 1345 llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; 1346 }); 1347 1348 // Execute the rewrite function. 1349 ByteCodeField numResults = read(); 1350 ByteCodeRewriteResultList results(numResults); 1351 rewriteFn(args, constParams, rewriter, results); 1352 1353 assert(results.getResults().size() == numResults && 1354 "native PDL rewrite function returned unexpected number of results"); 1355 1356 // Store the results in the bytecode memory. 1357 for (PDLValue &result : results.getResults()) { 1358 LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); 1359 1360 // In debug mode we also verify the expected kind of the result. 1361 #ifndef NDEBUG 1362 assert(result.getKind() == read<PDLValue::Kind>() && 1363 "native PDL rewrite function returned an unexpected type of result"); 1364 #endif 1365 1366 // If the result is a range, we need to copy it over to the bytecodes 1367 // range memory. 1368 if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) { 1369 unsigned rangeIndex = read(); 1370 typeRangeMemory[rangeIndex] = *typeRange; 1371 memory[read()] = &typeRangeMemory[rangeIndex]; 1372 } else if (Optional<ValueRange> valueRange = 1373 result.dyn_cast<ValueRange>()) { 1374 unsigned rangeIndex = read(); 1375 valueRangeMemory[rangeIndex] = *valueRange; 1376 memory[read()] = &valueRangeMemory[rangeIndex]; 1377 } else { 1378 memory[read()] = result.getAsOpaquePointer(); 1379 } 1380 } 1381 1382 // Copy over any underlying storage allocated for result ranges. 1383 for (auto &it : results.getAllocatedTypeRanges()) 1384 allocatedTypeRangeMemory.push_back(std::move(it)); 1385 for (auto &it : results.getAllocatedValueRanges()) 1386 allocatedValueRangeMemory.push_back(std::move(it)); 1387 } 1388 1389 void ByteCodeExecutor::executeAreEqual() { 1390 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 1391 const void *lhs = read<const void *>(); 1392 const void *rhs = read<const void *>(); 1393 1394 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n"); 1395 selectJump(lhs == rhs); 1396 } 1397 1398 void ByteCodeExecutor::executeAreRangesEqual() { 1399 LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n"); 1400 PDLValue::Kind valueKind = read<PDLValue::Kind>(); 1401 const void *lhs = read<const void *>(); 1402 const void *rhs = read<const void *>(); 1403 1404 switch (valueKind) { 1405 case PDLValue::Kind::TypeRange: { 1406 const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs); 1407 const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs); 1408 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1409 selectJump(*lhsRange == *rhsRange); 1410 break; 1411 } 1412 case PDLValue::Kind::ValueRange: { 1413 const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs); 1414 const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs); 1415 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1416 selectJump(*lhsRange == *rhsRange); 1417 break; 1418 } 1419 default: 1420 llvm_unreachable("unexpected `AreRangesEqual` value kind"); 1421 } 1422 } 1423 1424 void ByteCodeExecutor::executeBranch() { 1425 LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n"); 1426 curCodeIt = &code[read<ByteCodeAddr>()]; 1427 } 1428 1429 void ByteCodeExecutor::executeCheckOperandCount() { 1430 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); 1431 Operation *op = read<Operation *>(); 1432 uint32_t expectedCount = read<uint32_t>(); 1433 bool compareAtLeast = read(); 1434 1435 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" 1436 << " * Expected: " << expectedCount << "\n" 1437 << " * Comparator: " 1438 << (compareAtLeast ? ">=" : "==") << "\n"); 1439 if (compareAtLeast) 1440 selectJump(op->getNumOperands() >= expectedCount); 1441 else 1442 selectJump(op->getNumOperands() == expectedCount); 1443 } 1444 1445 void ByteCodeExecutor::executeCheckOperationName() { 1446 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n"); 1447 Operation *op = read<Operation *>(); 1448 OperationName expectedName = read<OperationName>(); 1449 1450 LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n" 1451 << " * Expected: \"" << expectedName << "\"\n"); 1452 selectJump(op->getName() == expectedName); 1453 } 1454 1455 void ByteCodeExecutor::executeCheckResultCount() { 1456 LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); 1457 Operation *op = read<Operation *>(); 1458 uint32_t expectedCount = read<uint32_t>(); 1459 bool compareAtLeast = read(); 1460 1461 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" 1462 << " * Expected: " << expectedCount << "\n" 1463 << " * Comparator: " 1464 << (compareAtLeast ? ">=" : "==") << "\n"); 1465 if (compareAtLeast) 1466 selectJump(op->getNumResults() >= expectedCount); 1467 else 1468 selectJump(op->getNumResults() == expectedCount); 1469 } 1470 1471 void ByteCodeExecutor::executeCheckTypes() { 1472 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 1473 TypeRange *lhs = read<TypeRange *>(); 1474 Attribute rhs = read<Attribute>(); 1475 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1476 1477 selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>()); 1478 } 1479 1480 void ByteCodeExecutor::executeContinue() { 1481 ByteCodeField level = read(); 1482 LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n" 1483 << " * Level: " << level << "\n"); 1484 ++loopIndex[level]; 1485 popCodeIt(); 1486 } 1487 1488 void ByteCodeExecutor::executeCreateTypes() { 1489 LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n"); 1490 unsigned memIndex = read(); 1491 unsigned rangeIndex = read(); 1492 ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>(); 1493 1494 LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n"); 1495 1496 // Allocate a buffer for this type range. 1497 llvm::OwningArrayRef<Type> storage(typesAttr.size()); 1498 llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin()); 1499 allocatedTypeRangeMemory.emplace_back(std::move(storage)); 1500 1501 // Assign this to the range slot and use the range as the value for the 1502 // memory index. 1503 typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back(); 1504 memory[memIndex] = &typeRangeMemory[rangeIndex]; 1505 } 1506 1507 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, 1508 Location mainRewriteLoc) { 1509 LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); 1510 1511 unsigned memIndex = read(); 1512 OperationState state(mainRewriteLoc, read<OperationName>()); 1513 readValueList(state.operands); 1514 for (unsigned i = 0, e = read(); i != e; ++i) { 1515 StringAttr name = read<StringAttr>(); 1516 if (Attribute attr = read<Attribute>()) 1517 state.addAttribute(name, attr); 1518 } 1519 1520 for (unsigned i = 0, e = read(); i != e; ++i) { 1521 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) { 1522 state.types.push_back(read<Type>()); 1523 continue; 1524 } 1525 1526 // If we find a null range, this signals that the types are infered. 1527 if (TypeRange *resultTypes = read<TypeRange *>()) { 1528 state.types.append(resultTypes->begin(), resultTypes->end()); 1529 continue; 1530 } 1531 1532 // Handle the case where the operation has inferred types. 1533 InferTypeOpInterface::Concept *concept = 1534 state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>(); 1535 1536 // TODO: Handle failure. 1537 state.types.clear(); 1538 if (failed(concept->inferReturnTypes( 1539 state.getContext(), state.location, state.operands, 1540 state.attributes.getDictionary(state.getContext()), state.regions, 1541 state.types))) 1542 return; 1543 break; 1544 } 1545 1546 Operation *resultOp = rewriter.createOperation(state); 1547 memory[memIndex] = resultOp; 1548 1549 LLVM_DEBUG({ 1550 llvm::dbgs() << " * Attributes: " 1551 << state.attributes.getDictionary(state.getContext()) 1552 << "\n * Operands: "; 1553 llvm::interleaveComma(state.operands, llvm::dbgs()); 1554 llvm::dbgs() << "\n * Result Types: "; 1555 llvm::interleaveComma(state.types, llvm::dbgs()); 1556 llvm::dbgs() << "\n * Result: " << *resultOp << "\n"; 1557 }); 1558 } 1559 1560 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) { 1561 LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n"); 1562 Operation *op = read<Operation *>(); 1563 1564 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1565 rewriter.eraseOp(op); 1566 } 1567 1568 template <typename T, typename Range, PDLValue::Kind kind> 1569 void ByteCodeExecutor::executeExtract() { 1570 LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n"); 1571 Range *range = read<Range *>(); 1572 unsigned index = read<uint32_t>(); 1573 unsigned memIndex = read(); 1574 1575 if (!range) { 1576 memory[memIndex] = nullptr; 1577 return; 1578 } 1579 1580 T result = index < range->size() ? (*range)[index] : T(); 1581 LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n" 1582 << " * Index: " << index << "\n" 1583 << " * Result: " << result << "\n"); 1584 storeToMemory(memIndex, result); 1585 } 1586 1587 void ByteCodeExecutor::executeFinalize() { 1588 LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n"); 1589 } 1590 1591 void ByteCodeExecutor::executeForEach() { 1592 LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n"); 1593 const ByteCodeField *prevCodeIt = getPrevCodeIt(); 1594 unsigned rangeIndex = read(); 1595 unsigned memIndex = read(); 1596 const void *value = nullptr; 1597 1598 switch (read<PDLValue::Kind>()) { 1599 case PDLValue::Kind::Operation: { 1600 unsigned &index = loopIndex[read()]; 1601 ArrayRef<Operation *> array = opRangeMemory[rangeIndex]; 1602 assert(index <= array.size() && "iterated past the end"); 1603 if (index < array.size()) { 1604 LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n"); 1605 value = array[index]; 1606 break; 1607 } 1608 1609 LLVM_DEBUG(llvm::dbgs() << " * Done\n"); 1610 index = 0; 1611 selectJump(size_t(0)); 1612 return; 1613 } 1614 default: 1615 llvm_unreachable("unexpected `ForEach` value kind"); 1616 } 1617 1618 // Store the iterate value and the stack address. 1619 memory[memIndex] = value; 1620 pushCodeIt(prevCodeIt); 1621 1622 // Skip over the successor (we will enter the body of the loop). 1623 read<ByteCodeAddr>(); 1624 } 1625 1626 void ByteCodeExecutor::executeGetAttribute() { 1627 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); 1628 unsigned memIndex = read(); 1629 Operation *op = read<Operation *>(); 1630 StringAttr attrName = read<StringAttr>(); 1631 Attribute attr = op->getAttr(attrName); 1632 1633 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1634 << " * Attribute: " << attrName << "\n" 1635 << " * Result: " << attr << "\n"); 1636 memory[memIndex] = attr.getAsOpaquePointer(); 1637 } 1638 1639 void ByteCodeExecutor::executeGetAttributeType() { 1640 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n"); 1641 unsigned memIndex = read(); 1642 Attribute attr = read<Attribute>(); 1643 Type type = attr ? attr.getType() : Type(); 1644 1645 LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" 1646 << " * Result: " << type << "\n"); 1647 memory[memIndex] = type.getAsOpaquePointer(); 1648 } 1649 1650 void ByteCodeExecutor::executeGetDefiningOp() { 1651 LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); 1652 unsigned memIndex = read(); 1653 Operation *op = nullptr; 1654 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1655 Value value = read<Value>(); 1656 if (value) 1657 op = value.getDefiningOp(); 1658 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1659 } else { 1660 ValueRange *values = read<ValueRange *>(); 1661 if (values && !values->empty()) { 1662 op = values->front().getDefiningOp(); 1663 } 1664 LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n"); 1665 } 1666 1667 LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n"); 1668 memory[memIndex] = op; 1669 } 1670 1671 void ByteCodeExecutor::executeGetOperand(unsigned index) { 1672 Operation *op = read<Operation *>(); 1673 unsigned memIndex = read(); 1674 Value operand = 1675 index < op->getNumOperands() ? op->getOperand(index) : Value(); 1676 1677 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1678 << " * Index: " << index << "\n" 1679 << " * Result: " << operand << "\n"); 1680 memory[memIndex] = operand.getAsOpaquePointer(); 1681 } 1682 1683 /// This function is the internal implementation of `GetResults` and 1684 /// `GetOperands` that provides support for extracting a value range from the 1685 /// given operation. 1686 template <template <typename> class AttrSizedSegmentsT, typename RangeT> 1687 static void * 1688 executeGetOperandsResults(RangeT values, Operation *op, unsigned index, 1689 ByteCodeField rangeIndex, StringRef attrSizedSegments, 1690 MutableArrayRef<ValueRange> valueRangeMemory) { 1691 // Check for the sentinel index that signals that all values should be 1692 // returned. 1693 if (index == std::numeric_limits<uint32_t>::max()) { 1694 LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n"); 1695 // `values` is already the full value range. 1696 1697 // Otherwise, check to see if this operation uses AttrSizedSegments. 1698 } else if (op->hasTrait<AttrSizedSegmentsT>()) { 1699 LLVM_DEBUG(llvm::dbgs() 1700 << " * Extracting values from `" << attrSizedSegments << "`\n"); 1701 1702 auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments); 1703 if (!segmentAttr || segmentAttr.getNumElements() <= index) 1704 return nullptr; 1705 1706 auto segments = segmentAttr.getValues<int32_t>(); 1707 unsigned startIndex = 1708 std::accumulate(segments.begin(), segments.begin() + index, 0); 1709 values = values.slice(startIndex, *std::next(segments.begin(), index)); 1710 1711 LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", " 1712 << *std::next(segments.begin(), index) << "]\n"); 1713 1714 // Otherwise, assume this is the last operand group of the operation. 1715 // FIXME: We currently don't support operations with 1716 // SameVariadicOperandSize/SameVariadicResultSize here given that we don't 1717 // have a way to detect it's presence. 1718 } else if (values.size() >= index) { 1719 LLVM_DEBUG(llvm::dbgs() 1720 << " * Treating values as trailing variadic range\n"); 1721 values = values.drop_front(index); 1722 1723 // If we couldn't detect a way to compute the values, bail out. 1724 } else { 1725 return nullptr; 1726 } 1727 1728 // If the range index is valid, we are returning a range. 1729 if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) { 1730 valueRangeMemory[rangeIndex] = values; 1731 return &valueRangeMemory[rangeIndex]; 1732 } 1733 1734 // If a range index wasn't provided, the range is required to be non-variadic. 1735 return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer(); 1736 } 1737 1738 void ByteCodeExecutor::executeGetOperands() { 1739 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n"); 1740 unsigned index = read<uint32_t>(); 1741 Operation *op = read<Operation *>(); 1742 ByteCodeField rangeIndex = read(); 1743 1744 void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>( 1745 op->getOperands(), op, index, rangeIndex, "operand_segment_sizes", 1746 valueRangeMemory); 1747 if (!result) 1748 LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n"); 1749 memory[read()] = result; 1750 } 1751 1752 void ByteCodeExecutor::executeGetResult(unsigned index) { 1753 Operation *op = read<Operation *>(); 1754 unsigned memIndex = read(); 1755 OpResult result = 1756 index < op->getNumResults() ? op->getResult(index) : OpResult(); 1757 1758 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1759 << " * Index: " << index << "\n" 1760 << " * Result: " << result << "\n"); 1761 memory[memIndex] = result.getAsOpaquePointer(); 1762 } 1763 1764 void ByteCodeExecutor::executeGetResults() { 1765 LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n"); 1766 unsigned index = read<uint32_t>(); 1767 Operation *op = read<Operation *>(); 1768 ByteCodeField rangeIndex = read(); 1769 1770 void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>( 1771 op->getResults(), op, index, rangeIndex, "result_segment_sizes", 1772 valueRangeMemory); 1773 if (!result) 1774 LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n"); 1775 memory[read()] = result; 1776 } 1777 1778 void ByteCodeExecutor::executeGetUsers() { 1779 LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n"); 1780 unsigned memIndex = read(); 1781 unsigned rangeIndex = read(); 1782 OwningOpRange &range = opRangeMemory[rangeIndex]; 1783 memory[memIndex] = ⦥ 1784 1785 range = OwningOpRange(); 1786 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1787 // Read the value. 1788 Value value = read<Value>(); 1789 if (!value) 1790 return; 1791 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1792 1793 // Extract the users of a single value. 1794 range = OwningOpRange(std::distance(value.user_begin(), value.user_end())); 1795 llvm::copy(value.getUsers(), range.begin()); 1796 } else { 1797 // Read a range of values. 1798 ValueRange *values = read<ValueRange *>(); 1799 if (!values) 1800 return; 1801 LLVM_DEBUG({ 1802 llvm::dbgs() << " * Values (" << values->size() << "): "; 1803 llvm::interleaveComma(*values, llvm::dbgs()); 1804 llvm::dbgs() << "\n"; 1805 }); 1806 1807 // Extract all the users of a range of values. 1808 SmallVector<Operation *> users; 1809 for (Value value : *values) 1810 users.append(value.user_begin(), value.user_end()); 1811 range = OwningOpRange(users.size()); 1812 llvm::copy(users, range.begin()); 1813 } 1814 1815 LLVM_DEBUG(llvm::dbgs() << " * Result: " << range.size() << " operations\n"); 1816 } 1817 1818 void ByteCodeExecutor::executeGetValueType() { 1819 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); 1820 unsigned memIndex = read(); 1821 Value value = read<Value>(); 1822 Type type = value ? value.getType() : Type(); 1823 1824 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" 1825 << " * Result: " << type << "\n"); 1826 memory[memIndex] = type.getAsOpaquePointer(); 1827 } 1828 1829 void ByteCodeExecutor::executeGetValueRangeTypes() { 1830 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n"); 1831 unsigned memIndex = read(); 1832 unsigned rangeIndex = read(); 1833 ValueRange *values = read<ValueRange *>(); 1834 if (!values) { 1835 LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n"); 1836 memory[memIndex] = nullptr; 1837 return; 1838 } 1839 1840 LLVM_DEBUG({ 1841 llvm::dbgs() << " * Values (" << values->size() << "): "; 1842 llvm::interleaveComma(*values, llvm::dbgs()); 1843 llvm::dbgs() << "\n * Result: "; 1844 llvm::interleaveComma(values->getType(), llvm::dbgs()); 1845 llvm::dbgs() << "\n"; 1846 }); 1847 typeRangeMemory[rangeIndex] = values->getType(); 1848 memory[memIndex] = &typeRangeMemory[rangeIndex]; 1849 } 1850 1851 void ByteCodeExecutor::executeIsNotNull() { 1852 LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); 1853 const void *value = read<const void *>(); 1854 1855 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1856 selectJump(value != nullptr); 1857 } 1858 1859 void ByteCodeExecutor::executeRecordMatch( 1860 PatternRewriter &rewriter, 1861 SmallVectorImpl<PDLByteCode::MatchResult> &matches) { 1862 LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n"); 1863 unsigned patternIndex = read(); 1864 PatternBenefit benefit = currentPatternBenefits[patternIndex]; 1865 const ByteCodeField *dest = &code[read<ByteCodeAddr>()]; 1866 1867 // If the benefit of the pattern is impossible, skip the processing of the 1868 // rest of the pattern. 1869 if (benefit.isImpossibleToMatch()) { 1870 LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n"); 1871 curCodeIt = dest; 1872 return; 1873 } 1874 1875 // Create a fused location containing the locations of each of the 1876 // operations used in the match. This will be used as the location for 1877 // created operations during the rewrite that don't already have an 1878 // explicit location set. 1879 unsigned numMatchLocs = read(); 1880 SmallVector<Location, 4> matchLocs; 1881 matchLocs.reserve(numMatchLocs); 1882 for (unsigned i = 0; i != numMatchLocs; ++i) 1883 matchLocs.push_back(read<Operation *>()->getLoc()); 1884 Location matchLoc = rewriter.getFusedLoc(matchLocs); 1885 1886 LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n" 1887 << " * Location: " << matchLoc << "\n"); 1888 matches.emplace_back(matchLoc, patterns[patternIndex], benefit); 1889 PDLByteCode::MatchResult &match = matches.back(); 1890 1891 // Record all of the inputs to the match. If any of the inputs are ranges, we 1892 // will also need to remap the range pointer to memory stored in the match 1893 // state. 1894 unsigned numInputs = read(); 1895 match.values.reserve(numInputs); 1896 match.typeRangeValues.reserve(numInputs); 1897 match.valueRangeValues.reserve(numInputs); 1898 for (unsigned i = 0; i < numInputs; ++i) { 1899 switch (read<PDLValue::Kind>()) { 1900 case PDLValue::Kind::TypeRange: 1901 match.typeRangeValues.push_back(*read<TypeRange *>()); 1902 match.values.push_back(&match.typeRangeValues.back()); 1903 break; 1904 case PDLValue::Kind::ValueRange: 1905 match.valueRangeValues.push_back(*read<ValueRange *>()); 1906 match.values.push_back(&match.valueRangeValues.back()); 1907 break; 1908 default: 1909 match.values.push_back(read<const void *>()); 1910 break; 1911 } 1912 } 1913 curCodeIt = dest; 1914 } 1915 1916 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) { 1917 LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); 1918 Operation *op = read<Operation *>(); 1919 SmallVector<Value, 16> args; 1920 readValueList(args); 1921 1922 LLVM_DEBUG({ 1923 llvm::dbgs() << " * Operation: " << *op << "\n" 1924 << " * Values: "; 1925 llvm::interleaveComma(args, llvm::dbgs()); 1926 llvm::dbgs() << "\n"; 1927 }); 1928 rewriter.replaceOp(op, args); 1929 } 1930 1931 void ByteCodeExecutor::executeSwitchAttribute() { 1932 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n"); 1933 Attribute value = read<Attribute>(); 1934 ArrayAttr cases = read<ArrayAttr>(); 1935 handleSwitch(value, cases); 1936 } 1937 1938 void ByteCodeExecutor::executeSwitchOperandCount() { 1939 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n"); 1940 Operation *op = read<Operation *>(); 1941 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1942 1943 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1944 handleSwitch(op->getNumOperands(), cases); 1945 } 1946 1947 void ByteCodeExecutor::executeSwitchOperationName() { 1948 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n"); 1949 OperationName value = read<Operation *>()->getName(); 1950 size_t caseCount = read(); 1951 1952 // The operation names are stored in-line, so to print them out for 1953 // debugging purposes we need to read the array before executing the 1954 // switch so that we can display all of the possible values. 1955 LLVM_DEBUG({ 1956 const ByteCodeField *prevCodeIt = curCodeIt; 1957 llvm::dbgs() << " * Value: " << value << "\n" 1958 << " * Cases: "; 1959 llvm::interleaveComma( 1960 llvm::map_range(llvm::seq<size_t>(0, caseCount), 1961 [&](size_t) { return read<OperationName>(); }), 1962 llvm::dbgs()); 1963 llvm::dbgs() << "\n"; 1964 curCodeIt = prevCodeIt; 1965 }); 1966 1967 // Try to find the switch value within any of the cases. 1968 for (size_t i = 0; i != caseCount; ++i) { 1969 if (read<OperationName>() == value) { 1970 curCodeIt += (caseCount - i - 1); 1971 return selectJump(i + 1); 1972 } 1973 } 1974 selectJump(size_t(0)); 1975 } 1976 1977 void ByteCodeExecutor::executeSwitchResultCount() { 1978 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n"); 1979 Operation *op = read<Operation *>(); 1980 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1981 1982 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1983 handleSwitch(op->getNumResults(), cases); 1984 } 1985 1986 void ByteCodeExecutor::executeSwitchType() { 1987 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n"); 1988 Type value = read<Type>(); 1989 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>(); 1990 handleSwitch(value, cases); 1991 } 1992 1993 void ByteCodeExecutor::executeSwitchTypes() { 1994 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n"); 1995 TypeRange *value = read<TypeRange *>(); 1996 auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>(); 1997 if (!value) { 1998 LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n"); 1999 return selectJump(size_t(0)); 2000 } 2001 handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) { 2002 return value == caseValue.getAsValueRange<TypeAttr>(); 2003 }); 2004 } 2005 2006 void ByteCodeExecutor::execute( 2007 PatternRewriter &rewriter, 2008 SmallVectorImpl<PDLByteCode::MatchResult> *matches, 2009 Optional<Location> mainRewriteLoc) { 2010 while (true) { 2011 // Print the location of the operation being executed. 2012 LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n"); 2013 2014 OpCode opCode = static_cast<OpCode>(read()); 2015 switch (opCode) { 2016 case ApplyConstraint: 2017 executeApplyConstraint(rewriter); 2018 break; 2019 case ApplyRewrite: 2020 executeApplyRewrite(rewriter); 2021 break; 2022 case AreEqual: 2023 executeAreEqual(); 2024 break; 2025 case AreRangesEqual: 2026 executeAreRangesEqual(); 2027 break; 2028 case Branch: 2029 executeBranch(); 2030 break; 2031 case CheckOperandCount: 2032 executeCheckOperandCount(); 2033 break; 2034 case CheckOperationName: 2035 executeCheckOperationName(); 2036 break; 2037 case CheckResultCount: 2038 executeCheckResultCount(); 2039 break; 2040 case CheckTypes: 2041 executeCheckTypes(); 2042 break; 2043 case Continue: 2044 executeContinue(); 2045 break; 2046 case CreateOperation: 2047 executeCreateOperation(rewriter, *mainRewriteLoc); 2048 break; 2049 case CreateTypes: 2050 executeCreateTypes(); 2051 break; 2052 case EraseOp: 2053 executeEraseOp(rewriter); 2054 break; 2055 case ExtractOp: 2056 executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>(); 2057 break; 2058 case ExtractType: 2059 executeExtract<Type, TypeRange, PDLValue::Kind::Type>(); 2060 break; 2061 case ExtractValue: 2062 executeExtract<Value, ValueRange, PDLValue::Kind::Value>(); 2063 break; 2064 case Finalize: 2065 executeFinalize(); 2066 LLVM_DEBUG(llvm::dbgs() << "\n"); 2067 return; 2068 case ForEach: 2069 executeForEach(); 2070 break; 2071 case GetAttribute: 2072 executeGetAttribute(); 2073 break; 2074 case GetAttributeType: 2075 executeGetAttributeType(); 2076 break; 2077 case GetDefiningOp: 2078 executeGetDefiningOp(); 2079 break; 2080 case GetOperand0: 2081 case GetOperand1: 2082 case GetOperand2: 2083 case GetOperand3: { 2084 unsigned index = opCode - GetOperand0; 2085 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n"); 2086 executeGetOperand(index); 2087 break; 2088 } 2089 case GetOperandN: 2090 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n"); 2091 executeGetOperand(read<uint32_t>()); 2092 break; 2093 case GetOperands: 2094 executeGetOperands(); 2095 break; 2096 case GetResult0: 2097 case GetResult1: 2098 case GetResult2: 2099 case GetResult3: { 2100 unsigned index = opCode - GetResult0; 2101 LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n"); 2102 executeGetResult(index); 2103 break; 2104 } 2105 case GetResultN: 2106 LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n"); 2107 executeGetResult(read<uint32_t>()); 2108 break; 2109 case GetResults: 2110 executeGetResults(); 2111 break; 2112 case GetUsers: 2113 executeGetUsers(); 2114 break; 2115 case GetValueType: 2116 executeGetValueType(); 2117 break; 2118 case GetValueRangeTypes: 2119 executeGetValueRangeTypes(); 2120 break; 2121 case IsNotNull: 2122 executeIsNotNull(); 2123 break; 2124 case RecordMatch: 2125 assert(matches && 2126 "expected matches to be provided when executing the matcher"); 2127 executeRecordMatch(rewriter, *matches); 2128 break; 2129 case ReplaceOp: 2130 executeReplaceOp(rewriter); 2131 break; 2132 case SwitchAttribute: 2133 executeSwitchAttribute(); 2134 break; 2135 case SwitchOperandCount: 2136 executeSwitchOperandCount(); 2137 break; 2138 case SwitchOperationName: 2139 executeSwitchOperationName(); 2140 break; 2141 case SwitchResultCount: 2142 executeSwitchResultCount(); 2143 break; 2144 case SwitchType: 2145 executeSwitchType(); 2146 break; 2147 case SwitchTypes: 2148 executeSwitchTypes(); 2149 break; 2150 } 2151 LLVM_DEBUG(llvm::dbgs() << "\n"); 2152 } 2153 } 2154 2155 /// Run the pattern matcher on the given root operation, collecting the matched 2156 /// patterns in `matches`. 2157 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, 2158 SmallVectorImpl<MatchResult> &matches, 2159 PDLByteCodeMutableState &state) const { 2160 // The first memory slot is always the root operation. 2161 state.memory[0] = op; 2162 2163 // The matcher function always starts at code address 0. 2164 ByteCodeExecutor executor( 2165 matcherByteCode.data(), state.memory, state.opRangeMemory, 2166 state.typeRangeMemory, state.allocatedTypeRangeMemory, 2167 state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex, 2168 uniquedData, matcherByteCode, state.currentPatternBenefits, patterns, 2169 constraintFunctions, rewriteFunctions); 2170 executor.execute(rewriter, &matches); 2171 2172 // Order the found matches by benefit. 2173 std::stable_sort(matches.begin(), matches.end(), 2174 [](const MatchResult &lhs, const MatchResult &rhs) { 2175 return lhs.benefit > rhs.benefit; 2176 }); 2177 } 2178 2179 /// Run the rewriter of the given pattern on the root operation `op`. 2180 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, 2181 PDLByteCodeMutableState &state) const { 2182 // The arguments of the rewrite function are stored at the start of the 2183 // memory buffer. 2184 llvm::copy(match.values, state.memory.begin()); 2185 2186 ByteCodeExecutor executor( 2187 &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, 2188 state.opRangeMemory, state.typeRangeMemory, 2189 state.allocatedTypeRangeMemory, state.valueRangeMemory, 2190 state.allocatedValueRangeMemory, state.loopIndex, uniquedData, 2191 rewriterByteCode, state.currentPatternBenefits, patterns, 2192 constraintFunctions, rewriteFunctions); 2193 executor.execute(rewriter, /*matches=*/nullptr, match.location); 2194 } 2195