1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements MLIR to byte-code generation and the interpreter. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "ByteCode.h" 14 #include "mlir/Analysis/Liveness.h" 15 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/RegionGraphTraits.h" 19 #include "llvm/ADT/IntervalMap.h" 20 #include "llvm/ADT/PostOrderIterator.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 #include "llvm/Support/Debug.h" 23 #include "llvm/Support/Format.h" 24 #include "llvm/Support/FormatVariadic.h" 25 #include <numeric> 26 27 #define DEBUG_TYPE "pdl-bytecode" 28 29 using namespace mlir; 30 using namespace mlir::detail; 31 32 //===----------------------------------------------------------------------===// 33 // PDLByteCodePattern 34 //===----------------------------------------------------------------------===// 35 36 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp, 37 ByteCodeAddr rewriterAddr) { 38 SmallVector<StringRef, 8> generatedOps; 39 if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr()) 40 generatedOps = 41 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>()); 42 43 PatternBenefit benefit = matchOp.benefit(); 44 MLIRContext *ctx = matchOp.getContext(); 45 46 // Check to see if this is pattern matches a specific operation type. 47 if (Optional<StringRef> rootKind = matchOp.rootKind()) 48 return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx, 49 generatedOps); 50 return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx, 51 generatedOps); 52 } 53 54 //===----------------------------------------------------------------------===// 55 // PDLByteCodeMutableState 56 //===----------------------------------------------------------------------===// 57 58 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds 59 /// to the position of the pattern within the range returned by 60 /// `PDLByteCode::getPatterns`. 61 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex, 62 PatternBenefit benefit) { 63 currentPatternBenefits[patternIndex] = benefit; 64 } 65 66 /// Cleanup any allocated state after a full match/rewrite has been completed. 67 /// This method should be called irregardless of whether the match+rewrite was a 68 /// success or not. 69 void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() { 70 allocatedTypeRangeMemory.clear(); 71 allocatedValueRangeMemory.clear(); 72 } 73 74 //===----------------------------------------------------------------------===// 75 // Bytecode OpCodes 76 //===----------------------------------------------------------------------===// 77 78 namespace { 79 enum OpCode : ByteCodeField { 80 /// Apply an externally registered constraint. 81 ApplyConstraint, 82 /// Apply an externally registered rewrite. 83 ApplyRewrite, 84 /// Check if two generic values are equal. 85 AreEqual, 86 /// Check if two ranges are equal. 87 AreRangesEqual, 88 /// Unconditional branch. 89 Branch, 90 /// Compare the operand count of an operation with a constant. 91 CheckOperandCount, 92 /// Compare the name of an operation with a constant. 93 CheckOperationName, 94 /// Compare the result count of an operation with a constant. 95 CheckResultCount, 96 /// Compare a range of types to a constant range of types. 97 CheckTypes, 98 /// Create an operation. 99 CreateOperation, 100 /// Create a range of types. 101 CreateTypes, 102 /// Erase an operation. 103 EraseOp, 104 /// Terminate a matcher or rewrite sequence. 105 Finalize, 106 /// Get a specific attribute of an operation. 107 GetAttribute, 108 /// Get the type of an attribute. 109 GetAttributeType, 110 /// Get the defining operation of a value. 111 GetDefiningOp, 112 /// Get a specific operand of an operation. 113 GetOperand0, 114 GetOperand1, 115 GetOperand2, 116 GetOperand3, 117 GetOperandN, 118 /// Get a specific operand group of an operation. 119 GetOperands, 120 /// Get a specific result of an operation. 121 GetResult0, 122 GetResult1, 123 GetResult2, 124 GetResult3, 125 GetResultN, 126 /// Get a specific result group of an operation. 127 GetResults, 128 /// Get the type of a value. 129 GetValueType, 130 /// Get the types of a value range. 131 GetValueRangeTypes, 132 /// Check if a generic value is not null. 133 IsNotNull, 134 /// Record a successful pattern match. 135 RecordMatch, 136 /// Replace an operation. 137 ReplaceOp, 138 /// Compare an attribute with a set of constants. 139 SwitchAttribute, 140 /// Compare the operand count of an operation with a set of constants. 141 SwitchOperandCount, 142 /// Compare the name of an operation with a set of constants. 143 SwitchOperationName, 144 /// Compare the result count of an operation with a set of constants. 145 SwitchResultCount, 146 /// Compare a type with a set of constants. 147 SwitchType, 148 /// Compare a range of types with a set of constants. 149 SwitchTypes, 150 }; 151 } // end anonymous namespace 152 153 //===----------------------------------------------------------------------===// 154 // ByteCode Generation 155 //===----------------------------------------------------------------------===// 156 157 //===----------------------------------------------------------------------===// 158 // Generator 159 160 namespace { 161 struct ByteCodeWriter; 162 163 /// This class represents the main generator for the pattern bytecode. 164 class Generator { 165 public: 166 Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData, 167 SmallVectorImpl<ByteCodeField> &matcherByteCode, 168 SmallVectorImpl<ByteCodeField> &rewriterByteCode, 169 SmallVectorImpl<PDLByteCodePattern> &patterns, 170 ByteCodeField &maxValueMemoryIndex, 171 ByteCodeField &maxTypeRangeMemoryIndex, 172 ByteCodeField &maxValueRangeMemoryIndex, 173 llvm::StringMap<PDLConstraintFunction> &constraintFns, 174 llvm::StringMap<PDLRewriteFunction> &rewriteFns) 175 : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), 176 rewriterByteCode(rewriterByteCode), patterns(patterns), 177 maxValueMemoryIndex(maxValueMemoryIndex), 178 maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex), 179 maxValueRangeMemoryIndex(maxValueRangeMemoryIndex) { 180 for (auto it : llvm::enumerate(constraintFns)) 181 constraintToMemIndex.try_emplace(it.value().first(), it.index()); 182 for (auto it : llvm::enumerate(rewriteFns)) 183 externalRewriterToMemIndex.try_emplace(it.value().first(), it.index()); 184 } 185 186 /// Generate the bytecode for the given PDL interpreter module. 187 void generate(ModuleOp module); 188 189 /// Return the memory index to use for the given value. 190 ByteCodeField &getMemIndex(Value value) { 191 assert(valueToMemIndex.count(value) && 192 "expected memory index to be assigned"); 193 return valueToMemIndex[value]; 194 } 195 196 /// Return the range memory index used to store the given range value. 197 ByteCodeField &getRangeStorageIndex(Value value) { 198 assert(valueToRangeIndex.count(value) && 199 "expected range index to be assigned"); 200 return valueToRangeIndex[value]; 201 } 202 203 /// Return an index to use when referring to the given data that is uniqued in 204 /// the MLIR context. 205 template <typename T> 206 std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &> 207 getMemIndex(T val) { 208 const void *opaqueVal = val.getAsOpaquePointer(); 209 210 // Get or insert a reference to this value. 211 auto it = uniquedDataToMemIndex.try_emplace( 212 opaqueVal, maxValueMemoryIndex + uniquedData.size()); 213 if (it.second) 214 uniquedData.push_back(opaqueVal); 215 return it.first->second; 216 } 217 218 private: 219 /// Allocate memory indices for the results of operations within the matcher 220 /// and rewriters. 221 void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule); 222 223 /// Generate the bytecode for the given operation. 224 void generate(Operation *op, ByteCodeWriter &writer); 225 void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer); 226 void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer); 227 void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer); 228 void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer); 229 void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer); 230 void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer); 231 void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer); 232 void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); 233 void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); 234 void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer); 235 void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); 236 void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); 237 void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); 238 void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer); 239 void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); 240 void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); 241 void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); 242 void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); 243 void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); 244 void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer); 245 void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer); 246 void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); 247 void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer); 248 void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); 249 void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer); 250 void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); 251 void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer); 252 void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); 253 void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer); 254 void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer); 255 void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer); 256 void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer); 257 void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer); 258 void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer); 259 260 /// Mapping from value to its corresponding memory index. 261 DenseMap<Value, ByteCodeField> valueToMemIndex; 262 263 /// Mapping from a range value to its corresponding range storage index. 264 DenseMap<Value, ByteCodeField> valueToRangeIndex; 265 266 /// Mapping from the name of an externally registered rewrite to its index in 267 /// the bytecode registry. 268 llvm::StringMap<ByteCodeField> externalRewriterToMemIndex; 269 270 /// Mapping from the name of an externally registered constraint to its index 271 /// in the bytecode registry. 272 llvm::StringMap<ByteCodeField> constraintToMemIndex; 273 274 /// Mapping from rewriter function name to the bytecode address of the 275 /// rewriter function in byte. 276 llvm::StringMap<ByteCodeAddr> rewriterToAddr; 277 278 /// Mapping from a uniqued storage object to its memory index within 279 /// `uniquedData`. 280 DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex; 281 282 /// The current MLIR context. 283 MLIRContext *ctx; 284 285 /// Data of the ByteCode class to be populated. 286 std::vector<const void *> &uniquedData; 287 SmallVectorImpl<ByteCodeField> &matcherByteCode; 288 SmallVectorImpl<ByteCodeField> &rewriterByteCode; 289 SmallVectorImpl<PDLByteCodePattern> &patterns; 290 ByteCodeField &maxValueMemoryIndex; 291 ByteCodeField &maxTypeRangeMemoryIndex; 292 ByteCodeField &maxValueRangeMemoryIndex; 293 }; 294 295 /// This class provides utilities for writing a bytecode stream. 296 struct ByteCodeWriter { 297 ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator) 298 : bytecode(bytecode), generator(generator) {} 299 300 /// Append a field to the bytecode. 301 void append(ByteCodeField field) { bytecode.push_back(field); } 302 void append(OpCode opCode) { bytecode.push_back(opCode); } 303 304 /// Append an address to the bytecode. 305 void append(ByteCodeAddr field) { 306 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 307 "unexpected ByteCode address size"); 308 309 ByteCodeField fieldParts[2]; 310 std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr)); 311 bytecode.append({fieldParts[0], fieldParts[1]}); 312 } 313 314 /// Append a successor range to the bytecode, the exact address will need to 315 /// be resolved later. 316 void append(SuccessorRange successors) { 317 // Add back references to the any successors so that the address can be 318 // resolved later. 319 for (Block *successor : successors) { 320 unresolvedSuccessorRefs[successor].push_back(bytecode.size()); 321 append(ByteCodeAddr(0)); 322 } 323 } 324 325 /// Append a range of values that will be read as generic PDLValues. 326 void appendPDLValueList(OperandRange values) { 327 bytecode.push_back(values.size()); 328 for (Value value : values) 329 appendPDLValue(value); 330 } 331 332 /// Append a value as a PDLValue. 333 void appendPDLValue(Value value) { 334 appendPDLValueKind(value); 335 append(value); 336 } 337 338 /// Append the PDLValue::Kind of the given value. 339 void appendPDLValueKind(Value value) { 340 // Append the type of the value in addition to the value itself. 341 PDLValue::Kind kind = 342 TypeSwitch<Type, PDLValue::Kind>(value.getType()) 343 .Case<pdl::AttributeType>( 344 [](Type) { return PDLValue::Kind::Attribute; }) 345 .Case<pdl::OperationType>( 346 [](Type) { return PDLValue::Kind::Operation; }) 347 .Case<pdl::RangeType>([](pdl::RangeType rangeTy) { 348 if (rangeTy.getElementType().isa<pdl::TypeType>()) 349 return PDLValue::Kind::TypeRange; 350 return PDLValue::Kind::ValueRange; 351 }) 352 .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; }) 353 .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; }); 354 bytecode.push_back(static_cast<ByteCodeField>(kind)); 355 } 356 357 /// Check if the given class `T` has an iterator type. 358 template <typename T, typename... Args> 359 using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer()); 360 361 /// Append a value that will be stored in a memory slot and not inline within 362 /// the bytecode. 363 template <typename T> 364 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value || 365 std::is_pointer<T>::value> 366 append(T value) { 367 bytecode.push_back(generator.getMemIndex(value)); 368 } 369 370 /// Append a range of values. 371 template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>> 372 std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value> 373 append(T range) { 374 bytecode.push_back(llvm::size(range)); 375 for (auto it : range) 376 append(it); 377 } 378 379 /// Append a variadic number of fields to the bytecode. 380 template <typename FieldTy, typename Field2Ty, typename... FieldTys> 381 void append(FieldTy field, Field2Ty field2, FieldTys... fields) { 382 append(field); 383 append(field2, fields...); 384 } 385 386 /// Successor references in the bytecode that have yet to be resolved. 387 DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs; 388 389 /// The underlying bytecode buffer. 390 SmallVectorImpl<ByteCodeField> &bytecode; 391 392 /// The main generator producing PDL. 393 Generator &generator; 394 }; 395 396 /// This class represents a live range of PDL Interpreter values, containing 397 /// information about when values are live within a match/rewrite. 398 struct ByteCodeLiveRange { 399 using Set = llvm::IntervalMap<ByteCodeField, char, 16>; 400 using Allocator = Set::Allocator; 401 402 ByteCodeLiveRange(Allocator &alloc) : liveness(alloc) {} 403 404 /// Union this live range with the one provided. 405 void unionWith(const ByteCodeLiveRange &rhs) { 406 for (auto it = rhs.liveness.begin(), e = rhs.liveness.end(); it != e; ++it) 407 liveness.insert(it.start(), it.stop(), /*dummyValue*/ 0); 408 } 409 410 /// Returns true if this range overlaps with the one provided. 411 bool overlaps(const ByteCodeLiveRange &rhs) const { 412 return llvm::IntervalMapOverlaps<Set, Set>(liveness, rhs.liveness).valid(); 413 } 414 415 /// A map representing the ranges of the match/rewrite that a value is live in 416 /// the interpreter. 417 llvm::IntervalMap<ByteCodeField, char, 16> liveness; 418 419 /// The type range storage index for this range. 420 Optional<unsigned> typeRangeIndex; 421 422 /// The value range storage index for this range. 423 Optional<unsigned> valueRangeIndex; 424 }; 425 } // end anonymous namespace 426 427 void Generator::generate(ModuleOp module) { 428 FuncOp matcherFunc = module.lookupSymbol<FuncOp>( 429 pdl_interp::PDLInterpDialect::getMatcherFunctionName()); 430 ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>( 431 pdl_interp::PDLInterpDialect::getRewriterModuleName()); 432 assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module"); 433 434 // Allocate memory indices for the results of operations within the matcher 435 // and rewriters. 436 allocateMemoryIndices(matcherFunc, rewriterModule); 437 438 // Generate code for the rewriter functions. 439 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this); 440 for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) { 441 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size()); 442 for (Operation &op : rewriterFunc.getOps()) 443 generate(&op, rewriterByteCodeWriter); 444 } 445 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() && 446 "unexpected branches in rewriter function"); 447 448 // Generate code for the matcher function. 449 DenseMap<Block *, ByteCodeAddr> blockToAddr; 450 llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody()); 451 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this); 452 for (Block *block : rpot) { 453 // Keep track of where this block begins within the matcher function. 454 blockToAddr.try_emplace(block, matcherByteCode.size()); 455 for (Operation &op : *block) 456 generate(&op, matcherByteCodeWriter); 457 } 458 459 // Resolve successor references in the matcher. 460 for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) { 461 ByteCodeAddr addr = blockToAddr[it.first]; 462 for (unsigned offsetToFix : it.second) 463 std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr)); 464 } 465 } 466 467 void Generator::allocateMemoryIndices(FuncOp matcherFunc, 468 ModuleOp rewriterModule) { 469 // Rewriters use simplistic allocation scheme that simply assigns an index to 470 // each result. 471 for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) { 472 ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0; 473 auto processRewriterValue = [&](Value val) { 474 valueToMemIndex.try_emplace(val, index++); 475 if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) { 476 Type elementTy = rangeType.getElementType(); 477 if (elementTy.isa<pdl::TypeType>()) 478 valueToRangeIndex.try_emplace(val, typeRangeIndex++); 479 else if (elementTy.isa<pdl::ValueType>()) 480 valueToRangeIndex.try_emplace(val, valueRangeIndex++); 481 } 482 }; 483 484 for (BlockArgument arg : rewriterFunc.getArguments()) 485 processRewriterValue(arg); 486 rewriterFunc.getBody().walk([&](Operation *op) { 487 for (Value result : op->getResults()) 488 processRewriterValue(result); 489 }); 490 if (index > maxValueMemoryIndex) 491 maxValueMemoryIndex = index; 492 if (typeRangeIndex > maxTypeRangeMemoryIndex) 493 maxTypeRangeMemoryIndex = typeRangeIndex; 494 if (valueRangeIndex > maxValueRangeMemoryIndex) 495 maxValueRangeMemoryIndex = valueRangeIndex; 496 } 497 498 // The matcher function uses a more sophisticated numbering that tries to 499 // minimize the number of memory indices assigned. This is done by determining 500 // a live range of the values within the matcher, then the allocation is just 501 // finding the minimal number of overlapping live ranges. This is essentially 502 // a simplified form of register allocation where we don't necessarily have a 503 // limited number of registers, but we still want to minimize the number used. 504 DenseMap<Operation *, ByteCodeField> opToIndex; 505 matcherFunc.getBody().walk([&](Operation *op) { 506 opToIndex.insert(std::make_pair(op, opToIndex.size())); 507 }); 508 509 // Liveness info for each of the defs within the matcher. 510 ByteCodeLiveRange::Allocator allocator; 511 DenseMap<Value, ByteCodeLiveRange> valueDefRanges; 512 513 // Assign the root operation being matched to slot 0. 514 BlockArgument rootOpArg = matcherFunc.getArgument(0); 515 valueToMemIndex[rootOpArg] = 0; 516 517 // Walk each of the blocks, computing the def interval that the value is used. 518 Liveness matcherLiveness(matcherFunc); 519 for (Block &block : matcherFunc.getBody()) { 520 const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block); 521 assert(info && "expected liveness info for block"); 522 auto processValue = [&](Value value, Operation *firstUseOrDef) { 523 // We don't need to process the root op argument, this value is always 524 // assigned to the first memory slot. 525 if (value == rootOpArg) 526 return; 527 528 // Set indices for the range of this block that the value is used. 529 auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; 530 defRangeIt->second.liveness.insert( 531 opToIndex[firstUseOrDef], 532 opToIndex[info->getEndOperation(value, firstUseOrDef)], 533 /*dummyValue*/ 0); 534 535 // Check to see if this value is a range type. 536 if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) { 537 Type eleType = rangeTy.getElementType(); 538 if (eleType.isa<pdl::TypeType>()) 539 defRangeIt->second.typeRangeIndex = 0; 540 else if (eleType.isa<pdl::ValueType>()) 541 defRangeIt->second.valueRangeIndex = 0; 542 } 543 }; 544 545 // Process the live-ins of this block. 546 for (Value liveIn : info->in()) 547 processValue(liveIn, &block.front()); 548 549 // Process any new defs within this block. 550 for (Operation &op : block) 551 for (Value result : op.getResults()) 552 processValue(result, &op); 553 } 554 555 // Greedily allocate memory slots using the computed def live ranges. 556 std::vector<ByteCodeLiveRange> allocatedIndices; 557 ByteCodeField numIndices = 1, numTypeRanges = 0, numValueRanges = 0; 558 for (auto &defIt : valueDefRanges) { 559 ByteCodeField &memIndex = valueToMemIndex[defIt.first]; 560 ByteCodeLiveRange &defRange = defIt.second; 561 562 // Try to allocate to an existing index. 563 for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) { 564 ByteCodeLiveRange &existingRange = existingIndexIt.value(); 565 if (!defRange.overlaps(existingRange)) { 566 existingRange.unionWith(defRange); 567 memIndex = existingIndexIt.index() + 1; 568 569 if (defRange.typeRangeIndex) { 570 if (!existingRange.typeRangeIndex) 571 existingRange.typeRangeIndex = numTypeRanges++; 572 valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex; 573 } else if (defRange.valueRangeIndex) { 574 if (!existingRange.valueRangeIndex) 575 existingRange.valueRangeIndex = numValueRanges++; 576 valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex; 577 } 578 break; 579 } 580 } 581 582 // If no existing index could be used, add a new one. 583 if (memIndex == 0) { 584 allocatedIndices.emplace_back(allocator); 585 ByteCodeLiveRange &newRange = allocatedIndices.back(); 586 newRange.unionWith(defRange); 587 588 // Allocate an index for type/value ranges. 589 if (defRange.typeRangeIndex) { 590 newRange.typeRangeIndex = numTypeRanges; 591 valueToRangeIndex[defIt.first] = numTypeRanges++; 592 } else if (defRange.valueRangeIndex) { 593 newRange.valueRangeIndex = numValueRanges; 594 valueToRangeIndex[defIt.first] = numValueRanges++; 595 } 596 597 memIndex = allocatedIndices.size(); 598 ++numIndices; 599 } 600 } 601 602 // Update the max number of indices. 603 if (numIndices > maxValueMemoryIndex) 604 maxValueMemoryIndex = numIndices; 605 if (numTypeRanges > maxTypeRangeMemoryIndex) 606 maxTypeRangeMemoryIndex = numTypeRanges; 607 if (numValueRanges > maxValueRangeMemoryIndex) 608 maxValueRangeMemoryIndex = numValueRanges; 609 } 610 611 void Generator::generate(Operation *op, ByteCodeWriter &writer) { 612 TypeSwitch<Operation *>(op) 613 .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp, 614 pdl_interp::AreEqualOp, pdl_interp::BranchOp, 615 pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp, 616 pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp, 617 pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp, 618 pdl_interp::CreateAttributeOp, pdl_interp::CreateOperationOp, 619 pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp, 620 pdl_interp::EraseOp, pdl_interp::FinalizeOp, 621 pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp, 622 pdl_interp::GetDefiningOpOp, pdl_interp::GetOperandOp, 623 pdl_interp::GetOperandsOp, pdl_interp::GetResultOp, 624 pdl_interp::GetResultsOp, pdl_interp::GetValueTypeOp, 625 pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp, 626 pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp, 627 pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp, 628 pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp, 629 pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>( 630 [&](auto interpOp) { this->generate(interpOp, writer); }) 631 .Default([](Operation *) { 632 llvm_unreachable("unknown `pdl_interp` operation"); 633 }); 634 } 635 636 void Generator::generate(pdl_interp::ApplyConstraintOp op, 637 ByteCodeWriter &writer) { 638 assert(constraintToMemIndex.count(op.name()) && 639 "expected index for constraint function"); 640 writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()], 641 op.constParamsAttr()); 642 writer.appendPDLValueList(op.args()); 643 writer.append(op.getSuccessors()); 644 } 645 void Generator::generate(pdl_interp::ApplyRewriteOp op, 646 ByteCodeWriter &writer) { 647 assert(externalRewriterToMemIndex.count(op.name()) && 648 "expected index for rewrite function"); 649 writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()], 650 op.constParamsAttr()); 651 writer.appendPDLValueList(op.args()); 652 653 ResultRange results = op.results(); 654 writer.append(ByteCodeField(results.size())); 655 for (Value result : results) { 656 // In debug mode we also record the expected kind of the result, so that we 657 // can provide extra verification of the native rewrite function. 658 #ifndef NDEBUG 659 writer.appendPDLValueKind(result); 660 #endif 661 662 // Range results also need to append the range storage index. 663 if (result.getType().isa<pdl::RangeType>()) 664 writer.append(getRangeStorageIndex(result)); 665 writer.append(result); 666 } 667 } 668 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) { 669 Value lhs = op.lhs(); 670 if (lhs.getType().isa<pdl::RangeType>()) { 671 writer.append(OpCode::AreRangesEqual); 672 writer.appendPDLValueKind(lhs); 673 writer.append(op.lhs(), op.rhs(), op.getSuccessors()); 674 return; 675 } 676 677 writer.append(OpCode::AreEqual, lhs, op.rhs(), op.getSuccessors()); 678 } 679 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) { 680 writer.append(OpCode::Branch, SuccessorRange(op.getOperation())); 681 } 682 void Generator::generate(pdl_interp::CheckAttributeOp op, 683 ByteCodeWriter &writer) { 684 writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(), 685 op.getSuccessors()); 686 } 687 void Generator::generate(pdl_interp::CheckOperandCountOp op, 688 ByteCodeWriter &writer) { 689 writer.append(OpCode::CheckOperandCount, op.operation(), op.count(), 690 static_cast<ByteCodeField>(op.compareAtLeast()), 691 op.getSuccessors()); 692 } 693 void Generator::generate(pdl_interp::CheckOperationNameOp op, 694 ByteCodeWriter &writer) { 695 writer.append(OpCode::CheckOperationName, op.operation(), 696 OperationName(op.name(), ctx), op.getSuccessors()); 697 } 698 void Generator::generate(pdl_interp::CheckResultCountOp op, 699 ByteCodeWriter &writer) { 700 writer.append(OpCode::CheckResultCount, op.operation(), op.count(), 701 static_cast<ByteCodeField>(op.compareAtLeast()), 702 op.getSuccessors()); 703 } 704 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) { 705 writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors()); 706 } 707 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) { 708 writer.append(OpCode::CheckTypes, op.value(), op.types(), op.getSuccessors()); 709 } 710 void Generator::generate(pdl_interp::CreateAttributeOp op, 711 ByteCodeWriter &writer) { 712 // Simply repoint the memory index of the result to the constant. 713 getMemIndex(op.attribute()) = getMemIndex(op.value()); 714 } 715 void Generator::generate(pdl_interp::CreateOperationOp op, 716 ByteCodeWriter &writer) { 717 writer.append(OpCode::CreateOperation, op.operation(), 718 OperationName(op.name(), ctx)); 719 writer.appendPDLValueList(op.operands()); 720 721 // Add the attributes. 722 OperandRange attributes = op.attributes(); 723 writer.append(static_cast<ByteCodeField>(attributes.size())); 724 for (auto it : llvm::zip(op.attributeNames(), op.attributes())) { 725 writer.append( 726 Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx), 727 std::get<1>(it)); 728 } 729 writer.appendPDLValueList(op.types()); 730 } 731 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) { 732 // Simply repoint the memory index of the result to the constant. 733 getMemIndex(op.result()) = getMemIndex(op.value()); 734 } 735 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) { 736 writer.append(OpCode::CreateTypes, op.result(), 737 getRangeStorageIndex(op.result()), op.value()); 738 } 739 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) { 740 writer.append(OpCode::EraseOp, op.operation()); 741 } 742 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) { 743 writer.append(OpCode::Finalize); 744 } 745 void Generator::generate(pdl_interp::GetAttributeOp op, 746 ByteCodeWriter &writer) { 747 writer.append(OpCode::GetAttribute, op.attribute(), op.operation(), 748 Identifier::get(op.name(), ctx)); 749 } 750 void Generator::generate(pdl_interp::GetAttributeTypeOp op, 751 ByteCodeWriter &writer) { 752 writer.append(OpCode::GetAttributeType, op.result(), op.value()); 753 } 754 void Generator::generate(pdl_interp::GetDefiningOpOp op, 755 ByteCodeWriter &writer) { 756 writer.append(OpCode::GetDefiningOp, op.operation()); 757 writer.appendPDLValue(op.value()); 758 } 759 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) { 760 uint32_t index = op.index(); 761 if (index < 4) 762 writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index)); 763 else 764 writer.append(OpCode::GetOperandN, index); 765 writer.append(op.operation(), op.value()); 766 } 767 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) { 768 Value result = op.value(); 769 Optional<uint32_t> index = op.index(); 770 writer.append(OpCode::GetOperands, 771 index.getValueOr(std::numeric_limits<uint32_t>::max()), 772 op.operation()); 773 if (result.getType().isa<pdl::RangeType>()) 774 writer.append(getRangeStorageIndex(result)); 775 else 776 writer.append(std::numeric_limits<ByteCodeField>::max()); 777 writer.append(result); 778 } 779 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) { 780 uint32_t index = op.index(); 781 if (index < 4) 782 writer.append(static_cast<OpCode>(OpCode::GetResult0 + index)); 783 else 784 writer.append(OpCode::GetResultN, index); 785 writer.append(op.operation(), op.value()); 786 } 787 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) { 788 Value result = op.value(); 789 Optional<uint32_t> index = op.index(); 790 writer.append(OpCode::GetResults, 791 index.getValueOr(std::numeric_limits<uint32_t>::max()), 792 op.operation()); 793 if (result.getType().isa<pdl::RangeType>()) 794 writer.append(getRangeStorageIndex(result)); 795 else 796 writer.append(std::numeric_limits<ByteCodeField>::max()); 797 writer.append(result); 798 } 799 void Generator::generate(pdl_interp::GetValueTypeOp op, 800 ByteCodeWriter &writer) { 801 if (op.getType().isa<pdl::RangeType>()) { 802 Value result = op.result(); 803 writer.append(OpCode::GetValueRangeTypes, result, 804 getRangeStorageIndex(result), op.value()); 805 } else { 806 writer.append(OpCode::GetValueType, op.result(), op.value()); 807 } 808 } 809 810 void Generator::generate(pdl_interp::InferredTypesOp op, 811 ByteCodeWriter &writer) { 812 // InferType maps to a null type as a marker for inferring result types. 813 getMemIndex(op.type()) = getMemIndex(Type()); 814 } 815 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) { 816 writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors()); 817 } 818 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) { 819 ByteCodeField patternIndex = patterns.size(); 820 patterns.emplace_back(PDLByteCodePattern::create( 821 op, rewriterToAddr[op.rewriter().getLeafReference().getValue()])); 822 writer.append(OpCode::RecordMatch, patternIndex, 823 SuccessorRange(op.getOperation()), op.matchedOps()); 824 writer.appendPDLValueList(op.inputs()); 825 } 826 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) { 827 writer.append(OpCode::ReplaceOp, op.operation()); 828 writer.appendPDLValueList(op.replValues()); 829 } 830 void Generator::generate(pdl_interp::SwitchAttributeOp op, 831 ByteCodeWriter &writer) { 832 writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(), 833 op.getSuccessors()); 834 } 835 void Generator::generate(pdl_interp::SwitchOperandCountOp op, 836 ByteCodeWriter &writer) { 837 writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(), 838 op.getSuccessors()); 839 } 840 void Generator::generate(pdl_interp::SwitchOperationNameOp op, 841 ByteCodeWriter &writer) { 842 auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) { 843 return OperationName(attr.cast<StringAttr>().getValue(), ctx); 844 }); 845 writer.append(OpCode::SwitchOperationName, op.operation(), cases, 846 op.getSuccessors()); 847 } 848 void Generator::generate(pdl_interp::SwitchResultCountOp op, 849 ByteCodeWriter &writer) { 850 writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(), 851 op.getSuccessors()); 852 } 853 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) { 854 writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(), 855 op.getSuccessors()); 856 } 857 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) { 858 writer.append(OpCode::SwitchTypes, op.value(), op.caseValuesAttr(), 859 op.getSuccessors()); 860 } 861 862 //===----------------------------------------------------------------------===// 863 // PDLByteCode 864 //===----------------------------------------------------------------------===// 865 866 PDLByteCode::PDLByteCode(ModuleOp module, 867 llvm::StringMap<PDLConstraintFunction> constraintFns, 868 llvm::StringMap<PDLRewriteFunction> rewriteFns) { 869 Generator generator(module.getContext(), uniquedData, matcherByteCode, 870 rewriterByteCode, patterns, maxValueMemoryIndex, 871 maxTypeRangeCount, maxValueRangeCount, constraintFns, 872 rewriteFns); 873 generator.generate(module); 874 875 // Initialize the external functions. 876 for (auto &it : constraintFns) 877 constraintFunctions.push_back(std::move(it.second)); 878 for (auto &it : rewriteFns) 879 rewriteFunctions.push_back(std::move(it.second)); 880 } 881 882 /// Initialize the given state such that it can be used to execute the current 883 /// bytecode. 884 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { 885 state.memory.resize(maxValueMemoryIndex, nullptr); 886 state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange()); 887 state.valueRangeMemory.resize(maxValueRangeCount, ValueRange()); 888 state.currentPatternBenefits.reserve(patterns.size()); 889 for (const PDLByteCodePattern &pattern : patterns) 890 state.currentPatternBenefits.push_back(pattern.getBenefit()); 891 } 892 893 //===----------------------------------------------------------------------===// 894 // ByteCode Execution 895 896 namespace { 897 /// This class provides support for executing a bytecode stream. 898 class ByteCodeExecutor { 899 public: 900 ByteCodeExecutor( 901 const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory, 902 MutableArrayRef<TypeRange> typeRangeMemory, 903 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory, 904 MutableArrayRef<ValueRange> valueRangeMemory, 905 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory, 906 ArrayRef<const void *> uniquedMemory, ArrayRef<ByteCodeField> code, 907 ArrayRef<PatternBenefit> currentPatternBenefits, 908 ArrayRef<PDLByteCodePattern> patterns, 909 ArrayRef<PDLConstraintFunction> constraintFunctions, 910 ArrayRef<PDLRewriteFunction> rewriteFunctions) 911 : curCodeIt(curCodeIt), memory(memory), typeRangeMemory(typeRangeMemory), 912 allocatedTypeRangeMemory(allocatedTypeRangeMemory), 913 valueRangeMemory(valueRangeMemory), 914 allocatedValueRangeMemory(allocatedValueRangeMemory), 915 uniquedMemory(uniquedMemory), code(code), 916 currentPatternBenefits(currentPatternBenefits), patterns(patterns), 917 constraintFunctions(constraintFunctions), 918 rewriteFunctions(rewriteFunctions) {} 919 920 /// Start executing the code at the current bytecode index. `matches` is an 921 /// optional field provided when this function is executed in a matching 922 /// context. 923 void execute(PatternRewriter &rewriter, 924 SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr, 925 Optional<Location> mainRewriteLoc = {}); 926 927 private: 928 /// Internal implementation of executing each of the bytecode commands. 929 void executeApplyConstraint(PatternRewriter &rewriter); 930 void executeApplyRewrite(PatternRewriter &rewriter); 931 void executeAreEqual(); 932 void executeAreRangesEqual(); 933 void executeBranch(); 934 void executeCheckOperandCount(); 935 void executeCheckOperationName(); 936 void executeCheckResultCount(); 937 void executeCheckTypes(); 938 void executeCreateOperation(PatternRewriter &rewriter, 939 Location mainRewriteLoc); 940 void executeCreateTypes(); 941 void executeEraseOp(PatternRewriter &rewriter); 942 void executeGetAttribute(); 943 void executeGetAttributeType(); 944 void executeGetDefiningOp(); 945 void executeGetOperand(unsigned index); 946 void executeGetOperands(); 947 void executeGetResult(unsigned index); 948 void executeGetResults(); 949 void executeGetValueType(); 950 void executeGetValueRangeTypes(); 951 void executeIsNotNull(); 952 void executeRecordMatch(PatternRewriter &rewriter, 953 SmallVectorImpl<PDLByteCode::MatchResult> &matches); 954 void executeReplaceOp(PatternRewriter &rewriter); 955 void executeSwitchAttribute(); 956 void executeSwitchOperandCount(); 957 void executeSwitchOperationName(); 958 void executeSwitchResultCount(); 959 void executeSwitchType(); 960 void executeSwitchTypes(); 961 962 /// Read a value from the bytecode buffer, optionally skipping a certain 963 /// number of prefix values. These methods always update the buffer to point 964 /// to the next field after the read data. 965 template <typename T = ByteCodeField> 966 T read(size_t skipN = 0) { 967 curCodeIt += skipN; 968 return readImpl<T>(); 969 } 970 ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); } 971 972 /// Read a list of values from the bytecode buffer. 973 template <typename ValueT, typename T> 974 void readList(SmallVectorImpl<T> &list) { 975 list.clear(); 976 for (unsigned i = 0, e = read(); i != e; ++i) 977 list.push_back(read<ValueT>()); 978 } 979 980 /// Read a list of values from the bytecode buffer. The values may be encoded 981 /// as either Value or ValueRange elements. 982 void readValueList(SmallVectorImpl<Value> &list) { 983 for (unsigned i = 0, e = read(); i != e; ++i) { 984 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 985 list.push_back(read<Value>()); 986 } else { 987 ValueRange *values = read<ValueRange *>(); 988 list.append(values->begin(), values->end()); 989 } 990 } 991 } 992 993 /// Jump to a specific successor based on a predicate value. 994 void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); } 995 /// Jump to a specific successor based on a destination index. 996 void selectJump(size_t destIndex) { 997 curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)]; 998 } 999 1000 /// Handle a switch operation with the provided value and cases. 1001 template <typename T, typename RangeT, typename Comparator = std::equal_to<T>> 1002 void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) { 1003 LLVM_DEBUG({ 1004 llvm::dbgs() << " * Value: " << value << "\n" 1005 << " * Cases: "; 1006 llvm::interleaveComma(cases, llvm::dbgs()); 1007 llvm::dbgs() << "\n"; 1008 }); 1009 1010 // Check to see if the attribute value is within the case list. Jump to 1011 // the correct successor index based on the result. 1012 for (auto it = cases.begin(), e = cases.end(); it != e; ++it) 1013 if (cmp(*it, value)) 1014 return selectJump(size_t((it - cases.begin()) + 1)); 1015 selectJump(size_t(0)); 1016 } 1017 1018 /// Internal implementation of reading various data types from the bytecode 1019 /// stream. 1020 template <typename T> 1021 const void *readFromMemory() { 1022 size_t index = *curCodeIt++; 1023 1024 // If this type is an SSA value, it can only be stored in non-const memory. 1025 if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *, 1026 Value>::value || 1027 index < memory.size()) 1028 return memory[index]; 1029 1030 // Otherwise, if this index is not inbounds it is uniqued. 1031 return uniquedMemory[index - memory.size()]; 1032 } 1033 template <typename T> 1034 std::enable_if_t<std::is_pointer<T>::value, T> readImpl() { 1035 return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>())); 1036 } 1037 template <typename T> 1038 std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value, 1039 T> 1040 readImpl() { 1041 return T(T::getFromOpaquePointer(readFromMemory<T>())); 1042 } 1043 template <typename T> 1044 std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() { 1045 switch (read<PDLValue::Kind>()) { 1046 case PDLValue::Kind::Attribute: 1047 return read<Attribute>(); 1048 case PDLValue::Kind::Operation: 1049 return read<Operation *>(); 1050 case PDLValue::Kind::Type: 1051 return read<Type>(); 1052 case PDLValue::Kind::Value: 1053 return read<Value>(); 1054 case PDLValue::Kind::TypeRange: 1055 return read<TypeRange *>(); 1056 case PDLValue::Kind::ValueRange: 1057 return read<ValueRange *>(); 1058 } 1059 llvm_unreachable("unhandled PDLValue::Kind"); 1060 } 1061 template <typename T> 1062 std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() { 1063 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 1064 "unexpected ByteCode address size"); 1065 ByteCodeAddr result; 1066 std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr)); 1067 curCodeIt += 2; 1068 return result; 1069 } 1070 template <typename T> 1071 std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() { 1072 return *curCodeIt++; 1073 } 1074 template <typename T> 1075 std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() { 1076 return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>()); 1077 } 1078 1079 /// The underlying bytecode buffer. 1080 const ByteCodeField *curCodeIt; 1081 1082 /// The current execution memory. 1083 MutableArrayRef<const void *> memory; 1084 MutableArrayRef<TypeRange> typeRangeMemory; 1085 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory; 1086 MutableArrayRef<ValueRange> valueRangeMemory; 1087 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory; 1088 1089 /// References to ByteCode data necessary for execution. 1090 ArrayRef<const void *> uniquedMemory; 1091 ArrayRef<ByteCodeField> code; 1092 ArrayRef<PatternBenefit> currentPatternBenefits; 1093 ArrayRef<PDLByteCodePattern> patterns; 1094 ArrayRef<PDLConstraintFunction> constraintFunctions; 1095 ArrayRef<PDLRewriteFunction> rewriteFunctions; 1096 }; 1097 1098 /// This class is an instantiation of the PDLResultList that provides access to 1099 /// the returned results. This API is not on `PDLResultList` to avoid 1100 /// overexposing access to information specific solely to the ByteCode. 1101 class ByteCodeRewriteResultList : public PDLResultList { 1102 public: 1103 ByteCodeRewriteResultList(unsigned maxNumResults) 1104 : PDLResultList(maxNumResults) {} 1105 1106 /// Return the list of PDL results. 1107 MutableArrayRef<PDLValue> getResults() { return results; } 1108 1109 /// Return the type ranges allocated by this list. 1110 MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() { 1111 return allocatedTypeRanges; 1112 } 1113 1114 /// Return the value ranges allocated by this list. 1115 MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() { 1116 return allocatedValueRanges; 1117 } 1118 }; 1119 } // end anonymous namespace 1120 1121 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { 1122 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); 1123 const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; 1124 ArrayAttr constParams = read<ArrayAttr>(); 1125 SmallVector<PDLValue, 16> args; 1126 readList<PDLValue>(args); 1127 1128 LLVM_DEBUG({ 1129 llvm::dbgs() << " * Arguments: "; 1130 llvm::interleaveComma(args, llvm::dbgs()); 1131 llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; 1132 }); 1133 1134 // Invoke the constraint and jump to the proper destination. 1135 selectJump(succeeded(constraintFn(args, constParams, rewriter))); 1136 } 1137 1138 void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { 1139 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); 1140 const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; 1141 ArrayAttr constParams = read<ArrayAttr>(); 1142 SmallVector<PDLValue, 16> args; 1143 readList<PDLValue>(args); 1144 1145 LLVM_DEBUG({ 1146 llvm::dbgs() << " * Arguments: "; 1147 llvm::interleaveComma(args, llvm::dbgs()); 1148 llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; 1149 }); 1150 1151 // Execute the rewrite function. 1152 ByteCodeField numResults = read(); 1153 ByteCodeRewriteResultList results(numResults); 1154 rewriteFn(args, constParams, rewriter, results); 1155 1156 assert(results.getResults().size() == numResults && 1157 "native PDL rewrite function returned unexpected number of results"); 1158 1159 // Store the results in the bytecode memory. 1160 for (PDLValue &result : results.getResults()) { 1161 LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); 1162 1163 // In debug mode we also verify the expected kind of the result. 1164 #ifndef NDEBUG 1165 assert(result.getKind() == read<PDLValue::Kind>() && 1166 "native PDL rewrite function returned an unexpected type of result"); 1167 #endif 1168 1169 // If the result is a range, we need to copy it over to the bytecodes 1170 // range memory. 1171 if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) { 1172 unsigned rangeIndex = read(); 1173 typeRangeMemory[rangeIndex] = *typeRange; 1174 memory[read()] = &typeRangeMemory[rangeIndex]; 1175 } else if (Optional<ValueRange> valueRange = 1176 result.dyn_cast<ValueRange>()) { 1177 unsigned rangeIndex = read(); 1178 valueRangeMemory[rangeIndex] = *valueRange; 1179 memory[read()] = &valueRangeMemory[rangeIndex]; 1180 } else { 1181 memory[read()] = result.getAsOpaquePointer(); 1182 } 1183 } 1184 1185 // Copy over any underlying storage allocated for result ranges. 1186 for (auto &it : results.getAllocatedTypeRanges()) 1187 allocatedTypeRangeMemory.push_back(std::move(it)); 1188 for (auto &it : results.getAllocatedValueRanges()) 1189 allocatedValueRangeMemory.push_back(std::move(it)); 1190 } 1191 1192 void ByteCodeExecutor::executeAreEqual() { 1193 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 1194 const void *lhs = read<const void *>(); 1195 const void *rhs = read<const void *>(); 1196 1197 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n"); 1198 selectJump(lhs == rhs); 1199 } 1200 1201 void ByteCodeExecutor::executeAreRangesEqual() { 1202 LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n"); 1203 PDLValue::Kind valueKind = read<PDLValue::Kind>(); 1204 const void *lhs = read<const void *>(); 1205 const void *rhs = read<const void *>(); 1206 1207 switch (valueKind) { 1208 case PDLValue::Kind::TypeRange: { 1209 const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs); 1210 const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs); 1211 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1212 selectJump(*lhsRange == *rhsRange); 1213 break; 1214 } 1215 case PDLValue::Kind::ValueRange: { 1216 const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs); 1217 const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs); 1218 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1219 selectJump(*lhsRange == *rhsRange); 1220 break; 1221 } 1222 default: 1223 llvm_unreachable("unexpected `AreRangesEqual` value kind"); 1224 } 1225 } 1226 1227 void ByteCodeExecutor::executeBranch() { 1228 LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n"); 1229 curCodeIt = &code[read<ByteCodeAddr>()]; 1230 } 1231 1232 void ByteCodeExecutor::executeCheckOperandCount() { 1233 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); 1234 Operation *op = read<Operation *>(); 1235 uint32_t expectedCount = read<uint32_t>(); 1236 bool compareAtLeast = read(); 1237 1238 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" 1239 << " * Expected: " << expectedCount << "\n" 1240 << " * Comparator: " 1241 << (compareAtLeast ? ">=" : "==") << "\n"); 1242 if (compareAtLeast) 1243 selectJump(op->getNumOperands() >= expectedCount); 1244 else 1245 selectJump(op->getNumOperands() == expectedCount); 1246 } 1247 1248 void ByteCodeExecutor::executeCheckOperationName() { 1249 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n"); 1250 Operation *op = read<Operation *>(); 1251 OperationName expectedName = read<OperationName>(); 1252 1253 LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n" 1254 << " * Expected: \"" << expectedName << "\"\n"); 1255 selectJump(op->getName() == expectedName); 1256 } 1257 1258 void ByteCodeExecutor::executeCheckResultCount() { 1259 LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); 1260 Operation *op = read<Operation *>(); 1261 uint32_t expectedCount = read<uint32_t>(); 1262 bool compareAtLeast = read(); 1263 1264 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" 1265 << " * Expected: " << expectedCount << "\n" 1266 << " * Comparator: " 1267 << (compareAtLeast ? ">=" : "==") << "\n"); 1268 if (compareAtLeast) 1269 selectJump(op->getNumResults() >= expectedCount); 1270 else 1271 selectJump(op->getNumResults() == expectedCount); 1272 } 1273 1274 void ByteCodeExecutor::executeCheckTypes() { 1275 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 1276 TypeRange *lhs = read<TypeRange *>(); 1277 Attribute rhs = read<Attribute>(); 1278 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1279 1280 selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>()); 1281 } 1282 1283 void ByteCodeExecutor::executeCreateTypes() { 1284 LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n"); 1285 unsigned memIndex = read(); 1286 unsigned rangeIndex = read(); 1287 ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>(); 1288 1289 LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n"); 1290 1291 // Allocate a buffer for this type range. 1292 llvm::OwningArrayRef<Type> storage(typesAttr.size()); 1293 llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin()); 1294 allocatedTypeRangeMemory.emplace_back(std::move(storage)); 1295 1296 // Assign this to the range slot and use the range as the value for the 1297 // memory index. 1298 typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back(); 1299 memory[memIndex] = &typeRangeMemory[rangeIndex]; 1300 } 1301 1302 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, 1303 Location mainRewriteLoc) { 1304 LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); 1305 1306 unsigned memIndex = read(); 1307 OperationState state(mainRewriteLoc, read<OperationName>()); 1308 readValueList(state.operands); 1309 for (unsigned i = 0, e = read(); i != e; ++i) { 1310 Identifier name = read<Identifier>(); 1311 if (Attribute attr = read<Attribute>()) 1312 state.addAttribute(name, attr); 1313 } 1314 1315 for (unsigned i = 0, e = read(); i != e; ++i) { 1316 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) { 1317 state.types.push_back(read<Type>()); 1318 continue; 1319 } 1320 1321 // If we find a null range, this signals that the types are infered. 1322 if (TypeRange *resultTypes = read<TypeRange *>()) { 1323 state.types.append(resultTypes->begin(), resultTypes->end()); 1324 continue; 1325 } 1326 1327 // Handle the case where the operation has inferred types. 1328 InferTypeOpInterface::Concept *concept = 1329 state.name.getAbstractOperation()->getInterface<InferTypeOpInterface>(); 1330 1331 // TODO: Handle failure. 1332 state.types.clear(); 1333 if (failed(concept->inferReturnTypes( 1334 state.getContext(), state.location, state.operands, 1335 state.attributes.getDictionary(state.getContext()), state.regions, 1336 state.types))) 1337 return; 1338 break; 1339 } 1340 1341 Operation *resultOp = rewriter.createOperation(state); 1342 memory[memIndex] = resultOp; 1343 1344 LLVM_DEBUG({ 1345 llvm::dbgs() << " * Attributes: " 1346 << state.attributes.getDictionary(state.getContext()) 1347 << "\n * Operands: "; 1348 llvm::interleaveComma(state.operands, llvm::dbgs()); 1349 llvm::dbgs() << "\n * Result Types: "; 1350 llvm::interleaveComma(state.types, llvm::dbgs()); 1351 llvm::dbgs() << "\n * Result: " << *resultOp << "\n"; 1352 }); 1353 } 1354 1355 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) { 1356 LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n"); 1357 Operation *op = read<Operation *>(); 1358 1359 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1360 rewriter.eraseOp(op); 1361 } 1362 1363 void ByteCodeExecutor::executeGetAttribute() { 1364 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); 1365 unsigned memIndex = read(); 1366 Operation *op = read<Operation *>(); 1367 Identifier attrName = read<Identifier>(); 1368 Attribute attr = op->getAttr(attrName); 1369 1370 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1371 << " * Attribute: " << attrName << "\n" 1372 << " * Result: " << attr << "\n"); 1373 memory[memIndex] = attr.getAsOpaquePointer(); 1374 } 1375 1376 void ByteCodeExecutor::executeGetAttributeType() { 1377 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n"); 1378 unsigned memIndex = read(); 1379 Attribute attr = read<Attribute>(); 1380 Type type = attr ? attr.getType() : Type(); 1381 1382 LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" 1383 << " * Result: " << type << "\n"); 1384 memory[memIndex] = type.getAsOpaquePointer(); 1385 } 1386 1387 void ByteCodeExecutor::executeGetDefiningOp() { 1388 LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); 1389 unsigned memIndex = read(); 1390 Operation *op = nullptr; 1391 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1392 Value value = read<Value>(); 1393 if (value) 1394 op = value.getDefiningOp(); 1395 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1396 } else { 1397 ValueRange *values = read<ValueRange *>(); 1398 if (values && !values->empty()) { 1399 op = values->front().getDefiningOp(); 1400 } 1401 LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n"); 1402 } 1403 1404 LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n"); 1405 memory[memIndex] = op; 1406 } 1407 1408 void ByteCodeExecutor::executeGetOperand(unsigned index) { 1409 Operation *op = read<Operation *>(); 1410 unsigned memIndex = read(); 1411 Value operand = 1412 index < op->getNumOperands() ? op->getOperand(index) : Value(); 1413 1414 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1415 << " * Index: " << index << "\n" 1416 << " * Result: " << operand << "\n"); 1417 memory[memIndex] = operand.getAsOpaquePointer(); 1418 } 1419 1420 /// This function is the internal implementation of `GetResults` and 1421 /// `GetOperands` that provides support for extracting a value range from the 1422 /// given operation. 1423 template <template <typename> class AttrSizedSegmentsT, typename RangeT> 1424 static void * 1425 executeGetOperandsResults(RangeT values, Operation *op, unsigned index, 1426 ByteCodeField rangeIndex, StringRef attrSizedSegments, 1427 MutableArrayRef<ValueRange> &valueRangeMemory) { 1428 // Check for the sentinel index that signals that all values should be 1429 // returned. 1430 if (index == std::numeric_limits<uint32_t>::max()) { 1431 LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n"); 1432 // `values` is already the full value range. 1433 1434 // Otherwise, check to see if this operation uses AttrSizedSegments. 1435 } else if (op->hasTrait<AttrSizedSegmentsT>()) { 1436 LLVM_DEBUG(llvm::dbgs() 1437 << " * Extracting values from `" << attrSizedSegments << "`\n"); 1438 1439 auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments); 1440 if (!segmentAttr || segmentAttr.getNumElements() <= index) 1441 return nullptr; 1442 1443 auto segments = segmentAttr.getValues<int32_t>(); 1444 unsigned startIndex = 1445 std::accumulate(segments.begin(), segments.begin() + index, 0); 1446 values = values.slice(startIndex, *std::next(segments.begin(), index)); 1447 1448 LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", " 1449 << *std::next(segments.begin(), index) << "]\n"); 1450 1451 // Otherwise, assume this is the last operand group of the operation. 1452 // FIXME: We currently don't support operations with 1453 // SameVariadicOperandSize/SameVariadicResultSize here given that we don't 1454 // have a way to detect it's presence. 1455 } else if (values.size() >= index) { 1456 LLVM_DEBUG(llvm::dbgs() 1457 << " * Treating values as trailing variadic range\n"); 1458 values = values.drop_front(index); 1459 1460 // If we couldn't detect a way to compute the values, bail out. 1461 } else { 1462 return nullptr; 1463 } 1464 1465 // If the range index is valid, we are returning a range. 1466 if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) { 1467 valueRangeMemory[rangeIndex] = values; 1468 return &valueRangeMemory[rangeIndex]; 1469 } 1470 1471 // If a range index wasn't provided, the range is required to be non-variadic. 1472 return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer(); 1473 } 1474 1475 void ByteCodeExecutor::executeGetOperands() { 1476 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n"); 1477 unsigned index = read<uint32_t>(); 1478 Operation *op = read<Operation *>(); 1479 ByteCodeField rangeIndex = read(); 1480 1481 void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>( 1482 op->getOperands(), op, index, rangeIndex, "operand_segment_sizes", 1483 valueRangeMemory); 1484 if (!result) 1485 LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n"); 1486 memory[read()] = result; 1487 } 1488 1489 void ByteCodeExecutor::executeGetResult(unsigned index) { 1490 Operation *op = read<Operation *>(); 1491 unsigned memIndex = read(); 1492 OpResult result = 1493 index < op->getNumResults() ? op->getResult(index) : OpResult(); 1494 1495 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1496 << " * Index: " << index << "\n" 1497 << " * Result: " << result << "\n"); 1498 memory[memIndex] = result.getAsOpaquePointer(); 1499 } 1500 1501 void ByteCodeExecutor::executeGetResults() { 1502 LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n"); 1503 unsigned index = read<uint32_t>(); 1504 Operation *op = read<Operation *>(); 1505 ByteCodeField rangeIndex = read(); 1506 1507 void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>( 1508 op->getResults(), op, index, rangeIndex, "result_segment_sizes", 1509 valueRangeMemory); 1510 if (!result) 1511 LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n"); 1512 memory[read()] = result; 1513 } 1514 1515 void ByteCodeExecutor::executeGetValueType() { 1516 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); 1517 unsigned memIndex = read(); 1518 Value value = read<Value>(); 1519 Type type = value ? value.getType() : Type(); 1520 1521 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" 1522 << " * Result: " << type << "\n"); 1523 memory[memIndex] = type.getAsOpaquePointer(); 1524 } 1525 1526 void ByteCodeExecutor::executeGetValueRangeTypes() { 1527 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n"); 1528 unsigned memIndex = read(); 1529 unsigned rangeIndex = read(); 1530 ValueRange *values = read<ValueRange *>(); 1531 if (!values) { 1532 LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n"); 1533 memory[memIndex] = nullptr; 1534 return; 1535 } 1536 1537 LLVM_DEBUG({ 1538 llvm::dbgs() << " * Values (" << values->size() << "): "; 1539 llvm::interleaveComma(*values, llvm::dbgs()); 1540 llvm::dbgs() << "\n * Result: "; 1541 llvm::interleaveComma(values->getType(), llvm::dbgs()); 1542 llvm::dbgs() << "\n"; 1543 }); 1544 typeRangeMemory[rangeIndex] = values->getType(); 1545 memory[memIndex] = &typeRangeMemory[rangeIndex]; 1546 } 1547 1548 void ByteCodeExecutor::executeIsNotNull() { 1549 LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); 1550 const void *value = read<const void *>(); 1551 1552 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1553 selectJump(value != nullptr); 1554 } 1555 1556 void ByteCodeExecutor::executeRecordMatch( 1557 PatternRewriter &rewriter, 1558 SmallVectorImpl<PDLByteCode::MatchResult> &matches) { 1559 LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n"); 1560 unsigned patternIndex = read(); 1561 PatternBenefit benefit = currentPatternBenefits[patternIndex]; 1562 const ByteCodeField *dest = &code[read<ByteCodeAddr>()]; 1563 1564 // If the benefit of the pattern is impossible, skip the processing of the 1565 // rest of the pattern. 1566 if (benefit.isImpossibleToMatch()) { 1567 LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n"); 1568 curCodeIt = dest; 1569 return; 1570 } 1571 1572 // Create a fused location containing the locations of each of the 1573 // operations used in the match. This will be used as the location for 1574 // created operations during the rewrite that don't already have an 1575 // explicit location set. 1576 unsigned numMatchLocs = read(); 1577 SmallVector<Location, 4> matchLocs; 1578 matchLocs.reserve(numMatchLocs); 1579 for (unsigned i = 0; i != numMatchLocs; ++i) 1580 matchLocs.push_back(read<Operation *>()->getLoc()); 1581 Location matchLoc = rewriter.getFusedLoc(matchLocs); 1582 1583 LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n" 1584 << " * Location: " << matchLoc << "\n"); 1585 matches.emplace_back(matchLoc, patterns[patternIndex], benefit); 1586 PDLByteCode::MatchResult &match = matches.back(); 1587 1588 // Record all of the inputs to the match. If any of the inputs are ranges, we 1589 // will also need to remap the range pointer to memory stored in the match 1590 // state. 1591 unsigned numInputs = read(); 1592 match.values.reserve(numInputs); 1593 match.typeRangeValues.reserve(numInputs); 1594 match.valueRangeValues.reserve(numInputs); 1595 for (unsigned i = 0; i < numInputs; ++i) { 1596 switch (read<PDLValue::Kind>()) { 1597 case PDLValue::Kind::TypeRange: 1598 match.typeRangeValues.push_back(*read<TypeRange *>()); 1599 match.values.push_back(&match.typeRangeValues.back()); 1600 break; 1601 case PDLValue::Kind::ValueRange: 1602 match.valueRangeValues.push_back(*read<ValueRange *>()); 1603 match.values.push_back(&match.valueRangeValues.back()); 1604 break; 1605 default: 1606 match.values.push_back(read<const void *>()); 1607 break; 1608 } 1609 } 1610 curCodeIt = dest; 1611 } 1612 1613 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) { 1614 LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); 1615 Operation *op = read<Operation *>(); 1616 SmallVector<Value, 16> args; 1617 readValueList(args); 1618 1619 LLVM_DEBUG({ 1620 llvm::dbgs() << " * Operation: " << *op << "\n" 1621 << " * Values: "; 1622 llvm::interleaveComma(args, llvm::dbgs()); 1623 llvm::dbgs() << "\n"; 1624 }); 1625 rewriter.replaceOp(op, args); 1626 } 1627 1628 void ByteCodeExecutor::executeSwitchAttribute() { 1629 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n"); 1630 Attribute value = read<Attribute>(); 1631 ArrayAttr cases = read<ArrayAttr>(); 1632 handleSwitch(value, cases); 1633 } 1634 1635 void ByteCodeExecutor::executeSwitchOperandCount() { 1636 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n"); 1637 Operation *op = read<Operation *>(); 1638 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1639 1640 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1641 handleSwitch(op->getNumOperands(), cases); 1642 } 1643 1644 void ByteCodeExecutor::executeSwitchOperationName() { 1645 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n"); 1646 OperationName value = read<Operation *>()->getName(); 1647 size_t caseCount = read(); 1648 1649 // The operation names are stored in-line, so to print them out for 1650 // debugging purposes we need to read the array before executing the 1651 // switch so that we can display all of the possible values. 1652 LLVM_DEBUG({ 1653 const ByteCodeField *prevCodeIt = curCodeIt; 1654 llvm::dbgs() << " * Value: " << value << "\n" 1655 << " * Cases: "; 1656 llvm::interleaveComma( 1657 llvm::map_range(llvm::seq<size_t>(0, caseCount), 1658 [&](size_t) { return read<OperationName>(); }), 1659 llvm::dbgs()); 1660 llvm::dbgs() << "\n"; 1661 curCodeIt = prevCodeIt; 1662 }); 1663 1664 // Try to find the switch value within any of the cases. 1665 for (size_t i = 0; i != caseCount; ++i) { 1666 if (read<OperationName>() == value) { 1667 curCodeIt += (caseCount - i - 1); 1668 return selectJump(i + 1); 1669 } 1670 } 1671 selectJump(size_t(0)); 1672 } 1673 1674 void ByteCodeExecutor::executeSwitchResultCount() { 1675 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n"); 1676 Operation *op = read<Operation *>(); 1677 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1678 1679 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1680 handleSwitch(op->getNumResults(), cases); 1681 } 1682 1683 void ByteCodeExecutor::executeSwitchType() { 1684 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n"); 1685 Type value = read<Type>(); 1686 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>(); 1687 handleSwitch(value, cases); 1688 } 1689 1690 void ByteCodeExecutor::executeSwitchTypes() { 1691 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n"); 1692 TypeRange *value = read<TypeRange *>(); 1693 auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>(); 1694 if (!value) { 1695 LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n"); 1696 return selectJump(size_t(0)); 1697 } 1698 handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) { 1699 return value == caseValue.getAsValueRange<TypeAttr>(); 1700 }); 1701 } 1702 1703 void ByteCodeExecutor::execute( 1704 PatternRewriter &rewriter, 1705 SmallVectorImpl<PDLByteCode::MatchResult> *matches, 1706 Optional<Location> mainRewriteLoc) { 1707 while (true) { 1708 OpCode opCode = static_cast<OpCode>(read()); 1709 switch (opCode) { 1710 case ApplyConstraint: 1711 executeApplyConstraint(rewriter); 1712 break; 1713 case ApplyRewrite: 1714 executeApplyRewrite(rewriter); 1715 break; 1716 case AreEqual: 1717 executeAreEqual(); 1718 break; 1719 case AreRangesEqual: 1720 executeAreRangesEqual(); 1721 break; 1722 case Branch: 1723 executeBranch(); 1724 break; 1725 case CheckOperandCount: 1726 executeCheckOperandCount(); 1727 break; 1728 case CheckOperationName: 1729 executeCheckOperationName(); 1730 break; 1731 case CheckResultCount: 1732 executeCheckResultCount(); 1733 break; 1734 case CheckTypes: 1735 executeCheckTypes(); 1736 break; 1737 case CreateOperation: 1738 executeCreateOperation(rewriter, *mainRewriteLoc); 1739 break; 1740 case CreateTypes: 1741 executeCreateTypes(); 1742 break; 1743 case EraseOp: 1744 executeEraseOp(rewriter); 1745 break; 1746 case Finalize: 1747 LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n"); 1748 return; 1749 case GetAttribute: 1750 executeGetAttribute(); 1751 break; 1752 case GetAttributeType: 1753 executeGetAttributeType(); 1754 break; 1755 case GetDefiningOp: 1756 executeGetDefiningOp(); 1757 break; 1758 case GetOperand0: 1759 case GetOperand1: 1760 case GetOperand2: 1761 case GetOperand3: { 1762 unsigned index = opCode - GetOperand0; 1763 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n"); 1764 executeGetOperand(index); 1765 break; 1766 } 1767 case GetOperandN: 1768 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n"); 1769 executeGetOperand(read<uint32_t>()); 1770 break; 1771 case GetOperands: 1772 executeGetOperands(); 1773 break; 1774 case GetResult0: 1775 case GetResult1: 1776 case GetResult2: 1777 case GetResult3: { 1778 unsigned index = opCode - GetResult0; 1779 LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n"); 1780 executeGetResult(index); 1781 break; 1782 } 1783 case GetResultN: 1784 LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n"); 1785 executeGetResult(read<uint32_t>()); 1786 break; 1787 case GetResults: 1788 executeGetResults(); 1789 break; 1790 case GetValueType: 1791 executeGetValueType(); 1792 break; 1793 case GetValueRangeTypes: 1794 executeGetValueRangeTypes(); 1795 break; 1796 case IsNotNull: 1797 executeIsNotNull(); 1798 break; 1799 case RecordMatch: 1800 assert(matches && 1801 "expected matches to be provided when executing the matcher"); 1802 executeRecordMatch(rewriter, *matches); 1803 break; 1804 case ReplaceOp: 1805 executeReplaceOp(rewriter); 1806 break; 1807 case SwitchAttribute: 1808 executeSwitchAttribute(); 1809 break; 1810 case SwitchOperandCount: 1811 executeSwitchOperandCount(); 1812 break; 1813 case SwitchOperationName: 1814 executeSwitchOperationName(); 1815 break; 1816 case SwitchResultCount: 1817 executeSwitchResultCount(); 1818 break; 1819 case SwitchType: 1820 executeSwitchType(); 1821 break; 1822 case SwitchTypes: 1823 executeSwitchTypes(); 1824 break; 1825 } 1826 LLVM_DEBUG(llvm::dbgs() << "\n"); 1827 } 1828 } 1829 1830 /// Run the pattern matcher on the given root operation, collecting the matched 1831 /// patterns in `matches`. 1832 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, 1833 SmallVectorImpl<MatchResult> &matches, 1834 PDLByteCodeMutableState &state) const { 1835 // The first memory slot is always the root operation. 1836 state.memory[0] = op; 1837 1838 // The matcher function always starts at code address 0. 1839 ByteCodeExecutor executor( 1840 matcherByteCode.data(), state.memory, state.typeRangeMemory, 1841 state.allocatedTypeRangeMemory, state.valueRangeMemory, 1842 state.allocatedValueRangeMemory, uniquedData, matcherByteCode, 1843 state.currentPatternBenefits, patterns, constraintFunctions, 1844 rewriteFunctions); 1845 executor.execute(rewriter, &matches); 1846 1847 // Order the found matches by benefit. 1848 std::stable_sort(matches.begin(), matches.end(), 1849 [](const MatchResult &lhs, const MatchResult &rhs) { 1850 return lhs.benefit > rhs.benefit; 1851 }); 1852 } 1853 1854 /// Run the rewriter of the given pattern on the root operation `op`. 1855 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, 1856 PDLByteCodeMutableState &state) const { 1857 // The arguments of the rewrite function are stored at the start of the 1858 // memory buffer. 1859 llvm::copy(match.values, state.memory.begin()); 1860 1861 ByteCodeExecutor executor( 1862 &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, 1863 state.typeRangeMemory, state.allocatedTypeRangeMemory, 1864 state.valueRangeMemory, state.allocatedValueRangeMemory, uniquedData, 1865 rewriterByteCode, state.currentPatternBenefits, patterns, 1866 constraintFunctions, rewriteFunctions); 1867 executor.execute(rewriter, /*matches=*/nullptr, match.location); 1868 } 1869