//===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements MLIR to byte-code generation and the interpreter.
//
//===----------------------------------------------------------------------===//

#include "ByteCode.h"
#include "mlir/Analysis/Liveness.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/RegionGraphTraits.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"
#include <numeric>

#define DEBUG_TYPE "pdl-bytecode"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
// PDLByteCodePattern
//===----------------------------------------------------------------------===//

/// Build a bytecode pattern from the given `pdl_interp` record-match operation
/// and the bytecode address of its rewriter. The benefit, the optional list of
/// generated operation names, and the optional root operation kind are all
/// read from `matchOp`.
PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
                                              ByteCodeAddr rewriterAddr) {
  // The generated-ops attribute is optional; leave the list empty when absent.
  SmallVector<StringRef, 8> generatedOps;
  if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr())
    generatedOps =
        llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());

  PatternBenefit benefit = matchOp.benefit();
  MLIRContext *ctx = matchOp.getContext();

  // Check to see if this pattern matches a specific operation type. If no
  // root kind was recorded, the pattern is created as matching any operation.
  if (Optional<StringRef> rootKind = matchOp.rootKind())
    return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx,
                              generatedOps);
  return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx,
                            generatedOps);
}

//===----------------------------------------------------------------------===//
// PDLByteCodeMutableState
//===----------------------------------------------------------------------===//

/// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
/// to the position of the pattern within the range returned by
/// `PDLByteCode::getPatterns`.
void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
                                                   PatternBenefit benefit) {
  currentPatternBenefits[patternIndex] = benefit;
}

/// Cleanup any allocated state after a full match/rewrite has been completed.
/// This method should be called regardless of whether the match+rewrite was a
/// success or not.
void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() {
  allocatedTypeRangeMemory.clear();
  allocatedValueRangeMemory.clear();
}

//===----------------------------------------------------------------------===//
// Bytecode OpCodes
//===----------------------------------------------------------------------===//

namespace {
/// The set of opcodes emitted by the generator below. Each opcode is stored in
/// the bytecode stream as a single `ByteCodeField`.
enum OpCode : ByteCodeField {
  /// Apply an externally registered constraint.
  ApplyConstraint,
  /// Apply an externally registered rewrite.
  ApplyRewrite,
  /// Check if two generic values are equal.
  AreEqual,
  /// Check if two ranges are equal.
  AreRangesEqual,
  /// Unconditional branch.
  Branch,
  /// Compare the operand count of an operation with a constant.
  CheckOperandCount,
  /// Compare the name of an operation with a constant.
  CheckOperationName,
  /// Compare the result count of an operation with a constant.
  CheckResultCount,
  /// Compare a range of types to a constant range of types.
  CheckTypes,
  /// Create an operation.
  CreateOperation,
  /// Create a range of types.
  CreateTypes,
  /// Erase an operation.
  EraseOp,
  /// Terminate a matcher or rewrite sequence.
  Finalize,
  /// Get a specific attribute of an operation.
  GetAttribute,
  /// Get the type of an attribute.
  GetAttributeType,
  /// Get the defining operation of a value.
  GetDefiningOp,
  /// Get a specific operand of an operation. The numbered variants must stay
  /// contiguous and in order: the generator computes the opcode as
  /// `GetOperand0 + index` for indices below 4, and uses `GetOperandN` with an
  /// inline index otherwise.
  GetOperand0,
  GetOperand1,
  GetOperand2,
  GetOperand3,
  GetOperandN,
  /// Get a specific operand group of an operation.
  GetOperands,
  /// Get a specific result of an operation. As with the operand opcodes, the
  /// numbered variants must stay contiguous and in order because the generator
  /// computes the opcode as `GetResult0 + index` for indices below 4.
  GetResult0,
  GetResult1,
  GetResult2,
  GetResult3,
  GetResultN,
  /// Get a specific result group of an operation.
  GetResults,
  /// Get the type of a value.
  GetValueType,
  /// Get the types of a value range.
  GetValueRangeTypes,
  /// Check if a generic value is not null.
  IsNotNull,
  /// Record a successful pattern match.
  RecordMatch,
  /// Replace an operation.
  ReplaceOp,
  /// Compare an attribute with a set of constants.
  SwitchAttribute,
  /// Compare the operand count of an operation with a set of constants.
  SwitchOperandCount,
  /// Compare the name of an operation with a set of constants.
  SwitchOperationName,
  /// Compare the result count of an operation with a set of constants.
  SwitchResultCount,
  /// Compare a type with a set of constants.
  SwitchType,
  /// Compare a range of types with a set of constants.
  SwitchTypes,
};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
// ByteCode Generation
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// Generator

namespace {
struct ByteCodeWriter;

/// This class represents the main generator for the pattern bytecode.
class Generator {
public:
  /// Construct a generator that populates the provided bytecode buffers,
  /// pattern list, uniqued data list and memory index maxima. Every externally
  /// registered constraint and rewrite function is assigned an index equal to
  /// its position when iterating the corresponding map.
  Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
            SmallVectorImpl<ByteCodeField> &matcherByteCode,
            SmallVectorImpl<ByteCodeField> &rewriterByteCode,
            SmallVectorImpl<PDLByteCodePattern> &patterns,
            ByteCodeField &maxValueMemoryIndex,
            ByteCodeField &maxTypeRangeMemoryIndex,
            ByteCodeField &maxValueRangeMemoryIndex,
            llvm::StringMap<PDLConstraintFunction> &constraintFns,
            llvm::StringMap<PDLRewriteFunction> &rewriteFns)
      : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
        rewriterByteCode(rewriterByteCode), patterns(patterns),
        maxValueMemoryIndex(maxValueMemoryIndex),
        maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
        maxValueRangeMemoryIndex(maxValueRangeMemoryIndex) {
    for (auto it : llvm::enumerate(constraintFns))
      constraintToMemIndex.try_emplace(it.value().first(), it.index());
    for (auto it : llvm::enumerate(rewriteFns))
      externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
  }

  /// Generate the bytecode for the given PDL interpreter module.
  void generate(ModuleOp module);

  /// Return the memory index to use for the given value. The index must have
  /// been assigned already; the returned reference allows it to be repointed.
  ByteCodeField &getMemIndex(Value value) {
    assert(valueToMemIndex.count(value) &&
           "expected memory index to be assigned");
    return valueToMemIndex[value];
  }

  /// Return the range memory index used to store the given range value.
  ByteCodeField &getRangeStorageIndex(Value value) {
    assert(valueToRangeIndex.count(value) &&
           "expected range index to be assigned");
    return valueToRangeIndex[value];
  }

  /// Return an index to use when referring to the given data that is uniqued in
  /// the MLIR context. This overload is selected for any type that is not
  /// convertible to `Value` and is keyed on `val.getAsOpaquePointer()`.
  template <typename T>
  std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
  getMemIndex(T val) {
    const void *opaqueVal = val.getAsOpaquePointer();

    // Get or insert a reference to this value. Uniqued data is indexed after
    // the value memory: a new entry receives the current
    // `maxValueMemoryIndex` plus the number of entries already recorded.
    auto it = uniquedDataToMemIndex.try_emplace(
        opaqueVal, maxValueMemoryIndex + uniquedData.size());
    if (it.second)
      uniquedData.push_back(opaqueVal);
    return it.first->second;
  }

private:
  /// Allocate memory indices for the results of operations within the matcher
  /// and rewriters.
  void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule);

  /// Generate the bytecode for the given operation. The `Operation *` overload
  /// dispatches to the overload for the concrete `pdl_interp` operation.
  void generate(Operation *op, ByteCodeWriter &writer);
  void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);

  /// Mapping from value to its corresponding memory index.
  DenseMap<Value, ByteCodeField> valueToMemIndex;

  /// Mapping from a range value to its corresponding range storage index.
  DenseMap<Value, ByteCodeField> valueToRangeIndex;

  /// Mapping from the name of an externally registered rewrite to its index in
  /// the bytecode registry.
  llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;

  /// Mapping from the name of an externally registered constraint to its index
  /// in the bytecode registry.
  llvm::StringMap<ByteCodeField> constraintToMemIndex;

  /// Mapping from rewriter function name to the address at which that
  /// function's code begins within the rewriter bytecode.
  llvm::StringMap<ByteCodeAddr> rewriterToAddr;

  /// Mapping from a uniqued storage object to its memory index within
  /// `uniquedData`.
  DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;

  /// The current MLIR context.
  MLIRContext *ctx;

  /// Data of the ByteCode class to be populated.
  std::vector<const void *> &uniquedData;
  SmallVectorImpl<ByteCodeField> &matcherByteCode;
  SmallVectorImpl<ByteCodeField> &rewriterByteCode;
  SmallVectorImpl<PDLByteCodePattern> &patterns;
  ByteCodeField &maxValueMemoryIndex;
  ByteCodeField &maxTypeRangeMemoryIndex;
  ByteCodeField &maxValueRangeMemoryIndex;
};

/// This class provides utilities for writing a bytecode stream.
struct ByteCodeWriter {
  ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
      : bytecode(bytecode), generator(generator) {}

  /// Append a field to the bytecode.
  void append(ByteCodeField field) { bytecode.push_back(field); }
  void append(OpCode opCode) { bytecode.push_back(opCode); }

  /// Append an address to the bytecode. An address occupies exactly two
  /// consecutive fields, holding the raw bytes of the address.
  void append(ByteCodeAddr field) {
    static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
                  "unexpected ByteCode address size");

    ByteCodeField fieldParts[2];
    std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
    bytecode.append({fieldParts[0], fieldParts[1]});
  }

  /// Append a successor range to the bytecode, the exact address will need to
  /// be resolved later.
  void append(SuccessorRange successors) {
    // Add back references to any successors so that the address can be
    // resolved later. A zero address is written as a placeholder, and the
    // offset of that placeholder is remembered per successor block.
    for (Block *successor : successors) {
      unresolvedSuccessorRefs[successor].push_back(bytecode.size());
      append(ByteCodeAddr(0));
    }
  }

  /// Append a range of values that will be read as generic PDLValues. The
  /// number of values is written first, followed by each value.
  void appendPDLValueList(OperandRange values) {
    bytecode.push_back(values.size());
    for (Value value : values)
      appendPDLValue(value);
  }

  /// Append a value as a PDLValue, i.e. its kind followed by its memory index.
  void appendPDLValue(Value value) {
    appendPDLValueKind(value);
    append(value);
  }

  /// Append the PDLValue::Kind of the given value.
  void appendPDLValueKind(Value value) {
    // Append the type of the value in addition to the value itself. The kind
    // is derived from the PDL type of the value; ranges are split into type
    // ranges and value ranges based on their element type.
    PDLValue::Kind kind =
        TypeSwitch<Type, PDLValue::Kind>(value.getType())
            .Case<pdl::AttributeType>(
                [](Type) { return PDLValue::Kind::Attribute; })
            .Case<pdl::OperationType>(
                [](Type) { return PDLValue::Kind::Operation; })
            .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
              if (rangeTy.getElementType().isa<pdl::TypeType>())
                return PDLValue::Kind::TypeRange;
              return PDLValue::Kind::ValueRange;
            })
            .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
            .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
    bytecode.push_back(static_cast<ByteCodeField>(kind));
  }

  /// Detect whether the given class `T` provides a `getAsOpaquePointer()`
  /// method.
  template <typename T, typename... Args>
  using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());

  /// Append a value that will be stored in a memory slot and not inline within
  /// the bytecode. Only the memory index provided by the generator is written.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
                   std::is_pointer<T>::value>
  append(T value) {
    bytecode.push_back(generator.getMemIndex(value));
  }

  /// Append a range of values. The size of the range is written first,
  /// followed by each element.
  template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
  std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
  append(T range) {
    bytecode.push_back(llvm::size(range));
    for (auto it : range)
      append(it);
  }

  /// Append a variadic number of fields to the bytecode. The fields are
  /// appended one at a time, in argument order.
  template <typename FieldTy, typename Field2Ty, typename... FieldTys>
  void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
    append(field);
    append(field2, fields...);
  }

  /// Successor references in the bytecode that have yet to be resolved. Each
  /// entry maps a block to the bytecode offsets of its placeholder addresses.
  DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;

  /// The underlying bytecode buffer.
  SmallVectorImpl<ByteCodeField> &bytecode;

  /// The main generator, used to look up the memory indices of appended
  /// values.
  Generator &generator;
};

/// This class represents a live range of PDL Interpreter values, containing
/// information about when values are live within a match/rewrite.
struct ByteCodeLiveRange {
  using Set = llvm::IntervalMap<ByteCodeField, char, 16>;
  using Allocator = Set::Allocator;

  ByteCodeLiveRange(Allocator &alloc) : liveness(alloc) {}

  /// Union this live range with the one provided.
  void unionWith(const ByteCodeLiveRange &rhs) {
    for (auto it = rhs.liveness.begin(), e = rhs.liveness.end(); it != e; ++it)
      liveness.insert(it.start(), it.stop(), /*dummyValue*/ 0);
  }

  /// Returns true if this range overlaps with the one provided.
  bool overlaps(const ByteCodeLiveRange &rhs) const {
    return llvm::IntervalMapOverlaps<Set, Set>(liveness, rhs.liveness).valid();
  }

  /// A map representing the ranges of the match/rewrite that a value is live in
  /// the interpreter. Only the intervals matter; the mapped value is a dummy.
  llvm::IntervalMap<ByteCodeField, char, 16> liveness;

  /// The type range storage index for this range.
  Optional<unsigned> typeRangeIndex;

  /// The value range storage index for this range.
  Optional<unsigned> valueRangeIndex;
};
} // end anonymous namespace

/// Generate the bytecode for the given PDL interpreter module: allocate memory
/// indices, emit every rewriter function, emit the matcher function, and
/// finally patch the branch target addresses within the matcher.
void Generator::generate(ModuleOp module) {
  FuncOp matcherFunc = module.lookupSymbol<FuncOp>(
      pdl_interp::PDLInterpDialect::getMatcherFunctionName());
  ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
      pdl_interp::PDLInterpDialect::getRewriterModuleName());
  assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");

  // Allocate memory indices for the results of operations within the matcher
  // and rewriters.
  allocateMemoryIndices(matcherFunc, rewriterModule);

  // Generate code for the rewriter functions. The address of each function is
  // the size of the rewriter bytecode at the point its emission begins.
  ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
  for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
    rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
    for (Operation &op : rewriterFunc.getOps())
      generate(&op, rewriterByteCodeWriter);
  }
  assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
         "unexpected branches in rewriter function");

  // Generate code for the matcher function. Blocks are emitted in reverse
  // post-order of the matcher body.
  DenseMap<Block *, ByteCodeAddr> blockToAddr;
  llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody());
  ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
  for (Block *block : rpot) {
    // Keep track of where this block begins within the matcher function.
    blockToAddr.try_emplace(block, matcherByteCode.size());
    for (Operation &op : *block)
      generate(&op, matcherByteCodeWriter);
  }

  // Resolve successor references in the matcher by overwriting each
  // placeholder with the now-known address of the successor block.
  for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
    ByteCodeAddr addr = blockToAddr[it.first];
    for (unsigned offsetToFix : it.second)
      std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
  }
}

/// Assign a memory index to every value of the matcher and rewriter functions,
/// and a range storage index to every range value. The `max*MemoryIndex`
/// members are raised to cover the largest number of slots required.
void Generator::allocateMemoryIndices(FuncOp matcherFunc,
                                      ModuleOp rewriterModule) {
  // Rewriters use a simplistic allocation scheme that simply assigns an index
  // to each result. The counters restart at zero for every rewriter function.
  for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
    ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
    auto processRewriterValue = [&](Value val) {
      valueToMemIndex.try_emplace(val, index++);
      if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) {
        Type elementTy = rangeType.getElementType();
        if (elementTy.isa<pdl::TypeType>())
          valueToRangeIndex.try_emplace(val, typeRangeIndex++);
        else if (elementTy.isa<pdl::ValueType>())
          valueToRangeIndex.try_emplace(val, valueRangeIndex++);
      }
    };

    for (BlockArgument arg : rewriterFunc.getArguments())
      processRewriterValue(arg);
    rewriterFunc.getBody().walk([&](Operation *op) {
      for (Value result : op->getResults())
        processRewriterValue(result);
    });
    if (index > maxValueMemoryIndex)
      maxValueMemoryIndex = index;
    if (typeRangeIndex > maxTypeRangeMemoryIndex)
      maxTypeRangeMemoryIndex = typeRangeIndex;
    if (valueRangeIndex > maxValueRangeMemoryIndex)
      maxValueRangeMemoryIndex = valueRangeIndex;
  }

  // The matcher function uses a more sophisticated numbering that tries to
  // minimize the number of memory indices assigned. This is done by determining
  // a live range of the values within the matcher, then the allocation is just
  // finding the minimal number of overlapping live ranges. This is essentially
  // a simplified form of register allocation where we don't necessarily have a
  // limited number of registers, but we still want to minimize the number used.
  // Operations are numbered in walk order; these numbers are the coordinates
  // of the live intervals below.
  DenseMap<Operation *, ByteCodeField> opToIndex;
  matcherFunc.getBody().walk([&](Operation *op) {
    opToIndex.insert(std::make_pair(op, opToIndex.size()));
  });

  // Liveness info for each of the defs within the matcher.
  ByteCodeLiveRange::Allocator allocator;
  DenseMap<Value, ByteCodeLiveRange> valueDefRanges;

  // Assign the root operation being matched to slot 0.
  BlockArgument rootOpArg = matcherFunc.getArgument(0);
  valueToMemIndex[rootOpArg] = 0;

  // Walk each of the blocks, computing the interval over which each value is
  // live within that block.
  Liveness matcherLiveness(matcherFunc);
  for (Block &block : matcherFunc.getBody()) {
    const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block);
    assert(info && "expected liveness info for block");
    auto processValue = [&](Value value, Operation *firstUseOrDef) {
      // We don't need to process the root op argument, this value is always
      // assigned to the first memory slot.
      if (value == rootOpArg)
        return;

      // Set indices for the range of this block that the value is used.
      auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
      defRangeIt->second.liveness.insert(
          opToIndex[firstUseOrDef],
          opToIndex[info->getEndOperation(value, firstUseOrDef)],
          /*dummyValue*/ 0);

      // Check to see if this value is a range type. The index stored here is
      // only a placeholder marking which kind of range storage the value
      // needs; the actual index is chosen by the greedy allocation below.
      if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) {
        Type eleType = rangeTy.getElementType();
        if (eleType.isa<pdl::TypeType>())
          defRangeIt->second.typeRangeIndex = 0;
        else if (eleType.isa<pdl::ValueType>())
          defRangeIt->second.valueRangeIndex = 0;
      }
    };

    // Process the live-ins of this block.
    for (Value liveIn : info->in())
      processValue(liveIn, &block.front());

    // Process any new defs within this block.
    for (Operation &op : block)
      for (Value result : op.getResults())
        processValue(result, &op);
  }

  // Greedily allocate memory slots using the computed def live ranges. The
  // slot count starts at 1 because slot 0 is taken by the root operation.
  std::vector<ByteCodeLiveRange> allocatedIndices;
  ByteCodeField numIndices = 1, numTypeRanges = 0, numValueRanges = 0;
  for (auto &defIt : valueDefRanges) {
    // The map lookup default-initializes the index to 0. Allocated slots start
    // at 1, so 0 doubles as the "not yet assigned" marker checked below.
    ByteCodeField &memIndex = valueToMemIndex[defIt.first];
    ByteCodeLiveRange &defRange = defIt.second;

    // Try to allocate to an existing index, i.e. the first slot whose merged
    // live range does not overlap with the range of this value.
    for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) {
      ByteCodeLiveRange &existingRange = existingIndexIt.value();
      if (!defRange.overlaps(existingRange)) {
        existingRange.unionWith(defRange);
        memIndex = existingIndexIt.index() + 1;

        // Reuse the range storage already attached to this slot, creating it
        // on first use.
        if (defRange.typeRangeIndex) {
          if (!existingRange.typeRangeIndex)
            existingRange.typeRangeIndex = numTypeRanges++;
          valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
        } else if (defRange.valueRangeIndex) {
          if (!existingRange.valueRangeIndex)
            existingRange.valueRangeIndex = numValueRanges++;
          valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
        }
        break;
      }
    }

    // If no existing index could be used, add a new one.
    if (memIndex == 0) {
      allocatedIndices.emplace_back(allocator);
      ByteCodeLiveRange &newRange = allocatedIndices.back();
      newRange.unionWith(defRange);

      // Allocate an index for type/value ranges.
      if (defRange.typeRangeIndex) {
        newRange.typeRangeIndex = numTypeRanges;
        valueToRangeIndex[defIt.first] = numTypeRanges++;
      } else if (defRange.valueRangeIndex) {
        newRange.valueRangeIndex = numValueRanges;
        valueToRangeIndex[defIt.first] = numValueRanges++;
      }

      memIndex = allocatedIndices.size();
      ++numIndices;
    }
  }

  // Update the max number of indices.
  if (numIndices > maxValueMemoryIndex)
    maxValueMemoryIndex = numIndices;
  if (numTypeRanges > maxTypeRangeMemoryIndex)
    maxTypeRangeMemoryIndex = numTypeRanges;
  if (numValueRanges > maxValueRangeMemoryIndex)
    maxValueRangeMemoryIndex = numValueRanges;
}

/// Dispatch to the `generate` overload for the concrete `pdl_interp`
/// operation. Any operation not listed here is treated as unreachable.
void Generator::generate(Operation *op, ByteCodeWriter &writer) {
  TypeSwitch<Operation *>(op)
      .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
            pdl_interp::AreEqualOp, pdl_interp::BranchOp,
            pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
            pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
            pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
            pdl_interp::CreateAttributeOp, pdl_interp::CreateOperationOp,
            pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
            pdl_interp::EraseOp, pdl_interp::FinalizeOp,
            pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp,
            pdl_interp::GetDefiningOpOp, pdl_interp::GetOperandOp,
            pdl_interp::GetOperandsOp, pdl_interp::GetResultOp,
            pdl_interp::GetResultsOp, pdl_interp::GetValueTypeOp,
            pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp,
            pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp,
            pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp,
            pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp,
            pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
          [&](auto interpOp) { this->generate(interpOp, writer); })
      .Default([](Operation *) {
        llvm_unreachable("unknown `pdl_interp` operation");
      });
}

/// Emit: opcode, constraint function index, constant parameters, the argument
/// list as PDLValues, then the successor addresses.
void Generator::generate(pdl_interp::ApplyConstraintOp op,
                         ByteCodeWriter &writer) {
  assert(constraintToMemIndex.count(op.name()) &&
         "expected index for constraint function");
  writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()],
                op.constParamsAttr());
  writer.appendPDLValueList(op.args());
  writer.append(op.getSuccessors());
}
/// Emit: opcode, rewrite function index, constant parameters, the argument
/// list as PDLValues, the number of results, then one entry per result.
void Generator::generate(pdl_interp::ApplyRewriteOp op,
                         ByteCodeWriter &writer) {
  assert(externalRewriterToMemIndex.count(op.name()) &&
         "expected index for rewrite function");
  writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()],
                op.constParamsAttr());
  writer.appendPDLValueList(op.args());

  ResultRange results = op.results();
  writer.append(ByteCodeField(results.size()));
  for (Value result : results) {
    // In debug mode we also record the expected kind of the result, so that we
    // can provide extra verification of the native rewrite function. Note that
    // this makes the emitted layout differ between builds with and without
    // NDEBUG.
#ifndef NDEBUG
    writer.appendPDLValueKind(result);
#endif

    // Range results also need to append the range storage index.
    if (result.getType().isa<pdl::RangeType>())
      writer.append(getRangeStorageIndex(result));
    writer.append(result);
  }
}
/// Emit an equality check. Ranges use a dedicated opcode that is followed by
/// the PDLValue kind of the operands; everything else uses `AreEqual`.
void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
  Value lhs = op.lhs();
  if (lhs.getType().isa<pdl::RangeType>()) {
    writer.append(OpCode::AreRangesEqual);
    writer.appendPDLValueKind(lhs);
    writer.append(op.lhs(), op.rhs(), op.getSuccessors());
    return;
  }

  writer.append(OpCode::AreEqual, lhs, op.rhs(), op.getSuccessors());
}
/// Emit an unconditional branch to the successor of the operation.
void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
}
/// Emit an attribute check. This reuses the generic `AreEqual` opcode,
/// comparing the attribute against the constant value.
void Generator::generate(pdl_interp::CheckAttributeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(),
                op.getSuccessors());
}
/// Emit: opcode, operation, expected count, the "compare at least" flag, then
/// the successor addresses.
void Generator::generate(pdl_interp::CheckOperandCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CheckOperandCount, op.operation(), op.count(),
                static_cast<ByteCodeField>(op.compareAtLeast()),
                op.getSuccessors());
}
/// Emit: opcode, operation, expected operation name, then the successor
/// addresses.
void Generator::generate(pdl_interp::CheckOperationNameOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CheckOperationName, op.operation(),
                OperationName(op.name(), ctx), op.getSuccessors());
}
/// Emit: opcode, operation, expected count, the "compare at least" flag, then
/// the successor addresses.
void Generator::generate(pdl_interp::CheckResultCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CheckResultCount, op.operation(), op.count(),
                static_cast<ByteCodeField>(op.compareAtLeast()),
                op.getSuccessors());
}
/// Emit a type check. This reuses the generic `AreEqual` opcode, comparing the
/// value against the constant type.
void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors());
}
/// Emit: opcode, the type range value, the constant types, then the successor
/// addresses.
void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::CheckTypes, op.value(), op.types(), op.getSuccessors());
}
/// No bytecode is emitted for attribute creation; the result simply aliases
/// the memory slot of the uniqued constant.
void Generator::generate(pdl_interp::CreateAttributeOp op,
                         ByteCodeWriter &writer) {
  // Simply repoint the memory index of the result to the constant.
  getMemIndex(op.attribute()) = getMemIndex(op.value());
}
/// Emit: opcode, result operation, operation name, the operand list, the
/// attribute count followed by (name, value) pairs, then the result type list.
void Generator::generate(pdl_interp::CreateOperationOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CreateOperation, op.operation(),
                OperationName(op.name(), ctx));
  writer.appendPDLValueList(op.operands());

  // Add the attributes.
  OperandRange attributes = op.attributes();
  writer.append(static_cast<ByteCodeField>(attributes.size()));
  for (auto it : llvm::zip(op.attributeNames(), op.attributes()))
    writer.append(std::get<0>(it), std::get<1>(it));
  writer.appendPDLValueList(op.types());
}
/// No bytecode is emitted for type creation; the result simply aliases the
/// memory slot of the uniqued constant.
void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
  // Simply repoint the memory index of the result to the constant.
  getMemIndex(op.result()) = getMemIndex(op.value());
}
/// Emit: opcode, result, the range storage index of the result, then the
/// constant types.
void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::CreateTypes, op.result(),
                getRangeStorageIndex(op.result()), op.value());
}
/// Emit: opcode, then the operation to erase.
void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::EraseOp, op.operation());
}
/// Emit the terminator opcode; it takes no operands.
void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::Finalize);
}
/// Emit: opcode, result attribute, operation, then the attribute name.
void Generator::generate(pdl_interp::GetAttributeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::GetAttribute, op.attribute(), op.operation(),
                op.nameAttr());
}
/// Emit: opcode, result type, then the attribute being queried.
void Generator::generate(pdl_interp::GetAttributeTypeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::GetAttributeType, op.result(), op.value());
}
/// Emit: opcode, result operation, then the queried value as a PDLValue (kind
/// followed by memory index).
void Generator::generate(pdl_interp::GetDefiningOpOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::GetDefiningOp, op.operation());
  writer.appendPDLValue(op.value());
}
/// Emit an operand access. Indices below 4 use the dedicated opcodes starting
/// at `GetOperand0`; any other index uses `GetOperandN` followed by the
/// index. In both cases the operation and the result value follow.
void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
  uint32_t index = op.index();
  if (index < 4)
    writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
  else
    writer.append(OpCode::GetOperandN, index);
  writer.append(op.operation(), op.value());
}

/// Emit `GetOperands`: the group index, the operation, a range storage index
/// and the result. A missing group index is written as the maximum
/// `uint32_t`. When the result is not a range, the maximum `ByteCodeField`
/// is written in place of the range storage index.
void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
  Value result = op.value();
  Optional<uint32_t> index = op.index();
  writer.append(OpCode::GetOperands,
                index.getValueOr(std::numeric_limits<uint32_t>::max()),
                op.operation());
  if (result.getType().isa<pdl::RangeType>())
    writer.append(getRangeStorageIndex(result));
  else
    writer.append(std::numeric_limits<ByteCodeField>::max());
  writer.append(result);
}

/// Emit a result access. Indices below 4 use the dedicated opcodes starting
/// at `GetResult0`; any other index uses `GetResultN` followed by the index.
/// In both cases the operation and the result value follow.
void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
  uint32_t index = op.index();
  if (index < 4)
    writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
  else
    writer.append(OpCode::GetResultN, index);
  writer.append(op.operation(), op.value());
}

/// Emit `GetResults`: the group index, the operation, a range storage index
/// and the result. The sentinel values match those used for `GetOperands`:
/// the maximum `uint32_t` for a missing group index, and the maximum
/// `ByteCodeField` when the result is not a range.
void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
  Value result = op.value();
  Optional<uint32_t> index = op.index();
  writer.append(OpCode::GetResults,
                index.getValueOr(std::numeric_limits<uint32_t>::max()),
                op.operation());
  if (result.getType().isa<pdl::RangeType>())
    writer.append(getRangeStorageIndex(result));
  else
    writer.append(std::numeric_limits<ByteCodeField>::max());
  writer.append(result);
}

/// Emit a value-type query. A range result uses `GetValueRangeTypes`, which
/// also carries the range storage index of the result; a single result uses
/// `GetValueType`.
void Generator::generate(pdl_interp::GetValueTypeOp op,
                         ByteCodeWriter &writer) {
  if (op.getType().isa<pdl::RangeType>()) {
    Value result = op.result();
    writer.append(OpCode::GetValueRangeTypes, result,
                  getRangeStorageIndex(result), op.value());
  } else {
    writer.append(OpCode::GetValueType, op.result(), op.value());
  }
}

/// Handle an inferred-types marker. Nothing is written to `writer`; the
/// memory index of the result is aliased to the index of the null `Type`.
void Generator::generate(pdl_interp::InferredTypesOp op,
                         ByteCodeWriter &writer) {
  // InferType maps to a null type as a marker for inferring result types.
  getMemIndex(op.type()) = getMemIndex(Type());
}

/// Emit `IsNotNull`: the value to test and the successors.
void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors());
}

/// Register a new pattern and emit `RecordMatch`. The new pattern is appended
/// to `patterns`, so its index is the size of `patterns` before the append.
/// The emitted fields are the pattern index, the successor of `op`, the
/// matched operations and the kind-tagged list of inputs.
void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
  ByteCodeField patternIndex = patterns.size();
  // The rewriter address is looked up by the leaf name of the rewriter symbol.
  patterns.emplace_back(PDLByteCodePattern::create(
      op, rewriterToAddr[op.rewriter().getLeafReference().getValue()]));
  writer.append(OpCode::RecordMatch, patternIndex,
                SuccessorRange(op.getOperation()), op.matchedOps());
  writer.appendPDLValueList(op.inputs());
}

/// Emit `ReplaceOp`: the operation to replace followed by the kind-tagged
/// list of replacement values.
void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::ReplaceOp, op.operation());
  writer.appendPDLValueList(op.replValues());
}

/// Emit `SwitchAttribute`: the attribute, the case values and the successors.
void Generator::generate(pdl_interp::SwitchAttributeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(),
                op.getSuccessors());
}

/// Emit `SwitchOperandCount`: the operation, the case values and the
/// successors.
void Generator::generate(pdl_interp::SwitchOperandCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(),
                op.getSuccessors());
}

/// Emit `SwitchOperationName`: the operation, the case names and the
/// successors. Each case string is converted to an `OperationName` before it
/// is appended.
void Generator::generate(pdl_interp::SwitchOperationNameOp op,
                         ByteCodeWriter &writer) {
  auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) {
    return OperationName(attr.cast<StringAttr>().getValue(), ctx);
  });
  writer.append(OpCode::SwitchOperationName, op.operation(), cases,
                op.getSuccessors());
}

/// Emit `SwitchResultCount`: the operation, the case values and the
/// successors.
void Generator::generate(pdl_interp::SwitchResultCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(),
                op.getSuccessors());
}

/// Emit `SwitchType`: the value, the case values and the successors.
void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(),
                op.getSuccessors());
}

/// Emit `SwitchTypes`: the value, the case values and the successors.
void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchTypes, op.value(), op.caseValuesAttr(),
                op.getSuccessors());
}

//===----------------------------------------------------------------------===//
// PDLByteCode
//===----------------------------------------------------------------------===//

/// Build the bytecode for `module`. A `Generator` fills in the members of
/// this object (uniqued data, matcher and rewriter bytecode, patterns and the
/// memory size counters). The values of the two function maps are then moved
/// into `constraintFunctions` and `rewriteFunctions` in map iteration order.
/// NOTE(review): presumably the generator numbers external functions in the
/// same iteration order, since the executor indexes these vectors by a value
/// read from the bytecode -- confirm against `Generator`.
PDLByteCode::PDLByteCode(ModuleOp module,
                         llvm::StringMap<PDLConstraintFunction> constraintFns,
                         llvm::StringMap<PDLRewriteFunction> rewriteFns) {
  Generator generator(module.getContext(), uniquedData, matcherByteCode,
                      rewriterByteCode, patterns, maxValueMemoryIndex,
                      maxTypeRangeCount, maxValueRangeCount, constraintFns,
                      rewriteFns);
  generator.generate(module);

  // Initialize the external functions.
  for (auto &it : constraintFns)
    constraintFunctions.push_back(std::move(it.second));
  for (auto &it : rewriteFns)
    rewriteFunctions.push_back(std::move(it.second));
}

/// Initialize the given state such that it can be used to execute the current
/// bytecode.
///
/// The value memory and both range memories are resized to the maxima
/// recorded for this bytecode, with new slots set to null or empty ranges.
/// The benefit of every pattern is then appended to
/// `state.currentPatternBenefits`.
void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
  state.memory.resize(maxValueMemoryIndex, nullptr);
  state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
  state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
  state.currentPatternBenefits.reserve(patterns.size());
  // Benefits are appended, not assigned: an existing list is not cleared.
  for (const PDLByteCodePattern &pattern : patterns)
    state.currentPatternBenefits.push_back(pattern.getBenefit());
}

//===----------------------------------------------------------------------===//
// ByteCode Execution
//===----------------------------------------------------------------------===//

namespace {
/// This class provides support for executing a bytecode stream.
895 class ByteCodeExecutor { 896 public: 897 ByteCodeExecutor( 898 const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory, 899 MutableArrayRef<TypeRange> typeRangeMemory, 900 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory, 901 MutableArrayRef<ValueRange> valueRangeMemory, 902 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory, 903 ArrayRef<const void *> uniquedMemory, ArrayRef<ByteCodeField> code, 904 ArrayRef<PatternBenefit> currentPatternBenefits, 905 ArrayRef<PDLByteCodePattern> patterns, 906 ArrayRef<PDLConstraintFunction> constraintFunctions, 907 ArrayRef<PDLRewriteFunction> rewriteFunctions) 908 : curCodeIt(curCodeIt), memory(memory), typeRangeMemory(typeRangeMemory), 909 allocatedTypeRangeMemory(allocatedTypeRangeMemory), 910 valueRangeMemory(valueRangeMemory), 911 allocatedValueRangeMemory(allocatedValueRangeMemory), 912 uniquedMemory(uniquedMemory), code(code), 913 currentPatternBenefits(currentPatternBenefits), patterns(patterns), 914 constraintFunctions(constraintFunctions), 915 rewriteFunctions(rewriteFunctions) {} 916 917 /// Start executing the code at the current bytecode index. `matches` is an 918 /// optional field provided when this function is executed in a matching 919 /// context. 920 void execute(PatternRewriter &rewriter, 921 SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr, 922 Optional<Location> mainRewriteLoc = {}); 923 924 private: 925 /// Internal implementation of executing each of the bytecode commands. 
926 void executeApplyConstraint(PatternRewriter &rewriter); 927 void executeApplyRewrite(PatternRewriter &rewriter); 928 void executeAreEqual(); 929 void executeAreRangesEqual(); 930 void executeBranch(); 931 void executeCheckOperandCount(); 932 void executeCheckOperationName(); 933 void executeCheckResultCount(); 934 void executeCheckTypes(); 935 void executeCreateOperation(PatternRewriter &rewriter, 936 Location mainRewriteLoc); 937 void executeCreateTypes(); 938 void executeEraseOp(PatternRewriter &rewriter); 939 void executeGetAttribute(); 940 void executeGetAttributeType(); 941 void executeGetDefiningOp(); 942 void executeGetOperand(unsigned index); 943 void executeGetOperands(); 944 void executeGetResult(unsigned index); 945 void executeGetResults(); 946 void executeGetValueType(); 947 void executeGetValueRangeTypes(); 948 void executeIsNotNull(); 949 void executeRecordMatch(PatternRewriter &rewriter, 950 SmallVectorImpl<PDLByteCode::MatchResult> &matches); 951 void executeReplaceOp(PatternRewriter &rewriter); 952 void executeSwitchAttribute(); 953 void executeSwitchOperandCount(); 954 void executeSwitchOperationName(); 955 void executeSwitchResultCount(); 956 void executeSwitchType(); 957 void executeSwitchTypes(); 958 959 /// Read a value from the bytecode buffer, optionally skipping a certain 960 /// number of prefix values. These methods always update the buffer to point 961 /// to the next field after the read data. 962 template <typename T = ByteCodeField> 963 T read(size_t skipN = 0) { 964 curCodeIt += skipN; 965 return readImpl<T>(); 966 } 967 ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); } 968 969 /// Read a list of values from the bytecode buffer. 970 template <typename ValueT, typename T> 971 void readList(SmallVectorImpl<T> &list) { 972 list.clear(); 973 for (unsigned i = 0, e = read(); i != e; ++i) 974 list.push_back(read<ValueT>()); 975 } 976 977 /// Read a list of values from the bytecode buffer. 
The values may be encoded 978 /// as either Value or ValueRange elements. 979 void readValueList(SmallVectorImpl<Value> &list) { 980 for (unsigned i = 0, e = read(); i != e; ++i) { 981 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 982 list.push_back(read<Value>()); 983 } else { 984 ValueRange *values = read<ValueRange *>(); 985 list.append(values->begin(), values->end()); 986 } 987 } 988 } 989 990 /// Jump to a specific successor based on a predicate value. 991 void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); } 992 /// Jump to a specific successor based on a destination index. 993 void selectJump(size_t destIndex) { 994 curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)]; 995 } 996 997 /// Handle a switch operation with the provided value and cases. 998 template <typename T, typename RangeT, typename Comparator = std::equal_to<T>> 999 void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) { 1000 LLVM_DEBUG({ 1001 llvm::dbgs() << " * Value: " << value << "\n" 1002 << " * Cases: "; 1003 llvm::interleaveComma(cases, llvm::dbgs()); 1004 llvm::dbgs() << "\n"; 1005 }); 1006 1007 // Check to see if the attribute value is within the case list. Jump to 1008 // the correct successor index based on the result. 1009 for (auto it = cases.begin(), e = cases.end(); it != e; ++it) 1010 if (cmp(*it, value)) 1011 return selectJump(size_t((it - cases.begin()) + 1)); 1012 selectJump(size_t(0)); 1013 } 1014 1015 /// Internal implementation of reading various data types from the bytecode 1016 /// stream. 1017 template <typename T> 1018 const void *readFromMemory() { 1019 size_t index = *curCodeIt++; 1020 1021 // If this type is an SSA value, it can only be stored in non-const memory. 1022 if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *, 1023 Value>::value || 1024 index < memory.size()) 1025 return memory[index]; 1026 1027 // Otherwise, if this index is not inbounds it is uniqued. 
1028 return uniquedMemory[index - memory.size()]; 1029 } 1030 template <typename T> 1031 std::enable_if_t<std::is_pointer<T>::value, T> readImpl() { 1032 return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>())); 1033 } 1034 template <typename T> 1035 std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value, 1036 T> 1037 readImpl() { 1038 return T(T::getFromOpaquePointer(readFromMemory<T>())); 1039 } 1040 template <typename T> 1041 std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() { 1042 switch (read<PDLValue::Kind>()) { 1043 case PDLValue::Kind::Attribute: 1044 return read<Attribute>(); 1045 case PDLValue::Kind::Operation: 1046 return read<Operation *>(); 1047 case PDLValue::Kind::Type: 1048 return read<Type>(); 1049 case PDLValue::Kind::Value: 1050 return read<Value>(); 1051 case PDLValue::Kind::TypeRange: 1052 return read<TypeRange *>(); 1053 case PDLValue::Kind::ValueRange: 1054 return read<ValueRange *>(); 1055 } 1056 llvm_unreachable("unhandled PDLValue::Kind"); 1057 } 1058 template <typename T> 1059 std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() { 1060 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 1061 "unexpected ByteCode address size"); 1062 ByteCodeAddr result; 1063 std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr)); 1064 curCodeIt += 2; 1065 return result; 1066 } 1067 template <typename T> 1068 std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() { 1069 return *curCodeIt++; 1070 } 1071 template <typename T> 1072 std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() { 1073 return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>()); 1074 } 1075 1076 /// The underlying bytecode buffer. 1077 const ByteCodeField *curCodeIt; 1078 1079 /// The current execution memory. 
1080 MutableArrayRef<const void *> memory; 1081 MutableArrayRef<TypeRange> typeRangeMemory; 1082 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory; 1083 MutableArrayRef<ValueRange> valueRangeMemory; 1084 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory; 1085 1086 /// References to ByteCode data necessary for execution. 1087 ArrayRef<const void *> uniquedMemory; 1088 ArrayRef<ByteCodeField> code; 1089 ArrayRef<PatternBenefit> currentPatternBenefits; 1090 ArrayRef<PDLByteCodePattern> patterns; 1091 ArrayRef<PDLConstraintFunction> constraintFunctions; 1092 ArrayRef<PDLRewriteFunction> rewriteFunctions; 1093 }; 1094 1095 /// This class is an instantiation of the PDLResultList that provides access to 1096 /// the returned results. This API is not on `PDLResultList` to avoid 1097 /// overexposing access to information specific solely to the ByteCode. 1098 class ByteCodeRewriteResultList : public PDLResultList { 1099 public: 1100 ByteCodeRewriteResultList(unsigned maxNumResults) 1101 : PDLResultList(maxNumResults) {} 1102 1103 /// Return the list of PDL results. 1104 MutableArrayRef<PDLValue> getResults() { return results; } 1105 1106 /// Return the type ranges allocated by this list. 1107 MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() { 1108 return allocatedTypeRanges; 1109 } 1110 1111 /// Return the value ranges allocated by this list. 
1112 MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() { 1113 return allocatedValueRanges; 1114 } 1115 }; 1116 } // end anonymous namespace 1117 1118 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { 1119 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); 1120 const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; 1121 ArrayAttr constParams = read<ArrayAttr>(); 1122 SmallVector<PDLValue, 16> args; 1123 readList<PDLValue>(args); 1124 1125 LLVM_DEBUG({ 1126 llvm::dbgs() << " * Arguments: "; 1127 llvm::interleaveComma(args, llvm::dbgs()); 1128 llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; 1129 }); 1130 1131 // Invoke the constraint and jump to the proper destination. 1132 selectJump(succeeded(constraintFn(args, constParams, rewriter))); 1133 } 1134 1135 void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { 1136 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); 1137 const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; 1138 ArrayAttr constParams = read<ArrayAttr>(); 1139 SmallVector<PDLValue, 16> args; 1140 readList<PDLValue>(args); 1141 1142 LLVM_DEBUG({ 1143 llvm::dbgs() << " * Arguments: "; 1144 llvm::interleaveComma(args, llvm::dbgs()); 1145 llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; 1146 }); 1147 1148 // Execute the rewrite function. 1149 ByteCodeField numResults = read(); 1150 ByteCodeRewriteResultList results(numResults); 1151 rewriteFn(args, constParams, rewriter, results); 1152 1153 assert(results.getResults().size() == numResults && 1154 "native PDL rewrite function returned unexpected number of results"); 1155 1156 // Store the results in the bytecode memory. 1157 for (PDLValue &result : results.getResults()) { 1158 LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); 1159 1160 // In debug mode we also verify the expected kind of the result. 
1161 #ifndef NDEBUG 1162 assert(result.getKind() == read<PDLValue::Kind>() && 1163 "native PDL rewrite function returned an unexpected type of result"); 1164 #endif 1165 1166 // If the result is a range, we need to copy it over to the bytecodes 1167 // range memory. 1168 if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) { 1169 unsigned rangeIndex = read(); 1170 typeRangeMemory[rangeIndex] = *typeRange; 1171 memory[read()] = &typeRangeMemory[rangeIndex]; 1172 } else if (Optional<ValueRange> valueRange = 1173 result.dyn_cast<ValueRange>()) { 1174 unsigned rangeIndex = read(); 1175 valueRangeMemory[rangeIndex] = *valueRange; 1176 memory[read()] = &valueRangeMemory[rangeIndex]; 1177 } else { 1178 memory[read()] = result.getAsOpaquePointer(); 1179 } 1180 } 1181 1182 // Copy over any underlying storage allocated for result ranges. 1183 for (auto &it : results.getAllocatedTypeRanges()) 1184 allocatedTypeRangeMemory.push_back(std::move(it)); 1185 for (auto &it : results.getAllocatedValueRanges()) 1186 allocatedValueRangeMemory.push_back(std::move(it)); 1187 } 1188 1189 void ByteCodeExecutor::executeAreEqual() { 1190 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 1191 const void *lhs = read<const void *>(); 1192 const void *rhs = read<const void *>(); 1193 1194 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n"); 1195 selectJump(lhs == rhs); 1196 } 1197 1198 void ByteCodeExecutor::executeAreRangesEqual() { 1199 LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n"); 1200 PDLValue::Kind valueKind = read<PDLValue::Kind>(); 1201 const void *lhs = read<const void *>(); 1202 const void *rhs = read<const void *>(); 1203 1204 switch (valueKind) { 1205 case PDLValue::Kind::TypeRange: { 1206 const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs); 1207 const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs); 1208 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1209 selectJump(*lhsRange == 
*rhsRange); 1210 break; 1211 } 1212 case PDLValue::Kind::ValueRange: { 1213 const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs); 1214 const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs); 1215 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1216 selectJump(*lhsRange == *rhsRange); 1217 break; 1218 } 1219 default: 1220 llvm_unreachable("unexpected `AreRangesEqual` value kind"); 1221 } 1222 } 1223 1224 void ByteCodeExecutor::executeBranch() { 1225 LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n"); 1226 curCodeIt = &code[read<ByteCodeAddr>()]; 1227 } 1228 1229 void ByteCodeExecutor::executeCheckOperandCount() { 1230 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); 1231 Operation *op = read<Operation *>(); 1232 uint32_t expectedCount = read<uint32_t>(); 1233 bool compareAtLeast = read(); 1234 1235 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" 1236 << " * Expected: " << expectedCount << "\n" 1237 << " * Comparator: " 1238 << (compareAtLeast ? 
">=" : "==") << "\n"); 1239 if (compareAtLeast) 1240 selectJump(op->getNumOperands() >= expectedCount); 1241 else 1242 selectJump(op->getNumOperands() == expectedCount); 1243 } 1244 1245 void ByteCodeExecutor::executeCheckOperationName() { 1246 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n"); 1247 Operation *op = read<Operation *>(); 1248 OperationName expectedName = read<OperationName>(); 1249 1250 LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n" 1251 << " * Expected: \"" << expectedName << "\"\n"); 1252 selectJump(op->getName() == expectedName); 1253 } 1254 1255 void ByteCodeExecutor::executeCheckResultCount() { 1256 LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); 1257 Operation *op = read<Operation *>(); 1258 uint32_t expectedCount = read<uint32_t>(); 1259 bool compareAtLeast = read(); 1260 1261 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" 1262 << " * Expected: " << expectedCount << "\n" 1263 << " * Comparator: " 1264 << (compareAtLeast ? ">=" : "==") << "\n"); 1265 if (compareAtLeast) 1266 selectJump(op->getNumResults() >= expectedCount); 1267 else 1268 selectJump(op->getNumResults() == expectedCount); 1269 } 1270 1271 void ByteCodeExecutor::executeCheckTypes() { 1272 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 1273 TypeRange *lhs = read<TypeRange *>(); 1274 Attribute rhs = read<Attribute>(); 1275 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1276 1277 selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>()); 1278 } 1279 1280 void ByteCodeExecutor::executeCreateTypes() { 1281 LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n"); 1282 unsigned memIndex = read(); 1283 unsigned rangeIndex = read(); 1284 ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>(); 1285 1286 LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n"); 1287 1288 // Allocate a buffer for this type range. 
1289 llvm::OwningArrayRef<Type> storage(typesAttr.size()); 1290 llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin()); 1291 allocatedTypeRangeMemory.emplace_back(std::move(storage)); 1292 1293 // Assign this to the range slot and use the range as the value for the 1294 // memory index. 1295 typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back(); 1296 memory[memIndex] = &typeRangeMemory[rangeIndex]; 1297 } 1298 1299 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, 1300 Location mainRewriteLoc) { 1301 LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); 1302 1303 unsigned memIndex = read(); 1304 OperationState state(mainRewriteLoc, read<OperationName>()); 1305 readValueList(state.operands); 1306 for (unsigned i = 0, e = read(); i != e; ++i) { 1307 StringAttr name = read<StringAttr>(); 1308 if (Attribute attr = read<Attribute>()) 1309 state.addAttribute(name, attr); 1310 } 1311 1312 for (unsigned i = 0, e = read(); i != e; ++i) { 1313 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) { 1314 state.types.push_back(read<Type>()); 1315 continue; 1316 } 1317 1318 // If we find a null range, this signals that the types are infered. 1319 if (TypeRange *resultTypes = read<TypeRange *>()) { 1320 state.types.append(resultTypes->begin(), resultTypes->end()); 1321 continue; 1322 } 1323 1324 // Handle the case where the operation has inferred types. 1325 InferTypeOpInterface::Concept *concept = 1326 state.name.getAbstractOperation()->getInterface<InferTypeOpInterface>(); 1327 1328 // TODO: Handle failure. 
1329 state.types.clear(); 1330 if (failed(concept->inferReturnTypes( 1331 state.getContext(), state.location, state.operands, 1332 state.attributes.getDictionary(state.getContext()), state.regions, 1333 state.types))) 1334 return; 1335 break; 1336 } 1337 1338 Operation *resultOp = rewriter.createOperation(state); 1339 memory[memIndex] = resultOp; 1340 1341 LLVM_DEBUG({ 1342 llvm::dbgs() << " * Attributes: " 1343 << state.attributes.getDictionary(state.getContext()) 1344 << "\n * Operands: "; 1345 llvm::interleaveComma(state.operands, llvm::dbgs()); 1346 llvm::dbgs() << "\n * Result Types: "; 1347 llvm::interleaveComma(state.types, llvm::dbgs()); 1348 llvm::dbgs() << "\n * Result: " << *resultOp << "\n"; 1349 }); 1350 } 1351 1352 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) { 1353 LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n"); 1354 Operation *op = read<Operation *>(); 1355 1356 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1357 rewriter.eraseOp(op); 1358 } 1359 1360 void ByteCodeExecutor::executeGetAttribute() { 1361 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); 1362 unsigned memIndex = read(); 1363 Operation *op = read<Operation *>(); 1364 StringAttr attrName = read<StringAttr>(); 1365 Attribute attr = op->getAttr(attrName); 1366 1367 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1368 << " * Attribute: " << attrName << "\n" 1369 << " * Result: " << attr << "\n"); 1370 memory[memIndex] = attr.getAsOpaquePointer(); 1371 } 1372 1373 void ByteCodeExecutor::executeGetAttributeType() { 1374 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n"); 1375 unsigned memIndex = read(); 1376 Attribute attr = read<Attribute>(); 1377 Type type = attr ? 
attr.getType() : Type(); 1378 1379 LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" 1380 << " * Result: " << type << "\n"); 1381 memory[memIndex] = type.getAsOpaquePointer(); 1382 } 1383 1384 void ByteCodeExecutor::executeGetDefiningOp() { 1385 LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); 1386 unsigned memIndex = read(); 1387 Operation *op = nullptr; 1388 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1389 Value value = read<Value>(); 1390 if (value) 1391 op = value.getDefiningOp(); 1392 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1393 } else { 1394 ValueRange *values = read<ValueRange *>(); 1395 if (values && !values->empty()) { 1396 op = values->front().getDefiningOp(); 1397 } 1398 LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n"); 1399 } 1400 1401 LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n"); 1402 memory[memIndex] = op; 1403 } 1404 1405 void ByteCodeExecutor::executeGetOperand(unsigned index) { 1406 Operation *op = read<Operation *>(); 1407 unsigned memIndex = read(); 1408 Value operand = 1409 index < op->getNumOperands() ? op->getOperand(index) : Value(); 1410 1411 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1412 << " * Index: " << index << "\n" 1413 << " * Result: " << operand << "\n"); 1414 memory[memIndex] = operand.getAsOpaquePointer(); 1415 } 1416 1417 /// This function is the internal implementation of `GetResults` and 1418 /// `GetOperands` that provides support for extracting a value range from the 1419 /// given operation. 1420 template <template <typename> class AttrSizedSegmentsT, typename RangeT> 1421 static void * 1422 executeGetOperandsResults(RangeT values, Operation *op, unsigned index, 1423 ByteCodeField rangeIndex, StringRef attrSizedSegments, 1424 MutableArrayRef<ValueRange> &valueRangeMemory) { 1425 // Check for the sentinel index that signals that all values should be 1426 // returned. 
1427 if (index == std::numeric_limits<uint32_t>::max()) { 1428 LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n"); 1429 // `values` is already the full value range. 1430 1431 // Otherwise, check to see if this operation uses AttrSizedSegments. 1432 } else if (op->hasTrait<AttrSizedSegmentsT>()) { 1433 LLVM_DEBUG(llvm::dbgs() 1434 << " * Extracting values from `" << attrSizedSegments << "`\n"); 1435 1436 auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments); 1437 if (!segmentAttr || segmentAttr.getNumElements() <= index) 1438 return nullptr; 1439 1440 auto segments = segmentAttr.getValues<int32_t>(); 1441 unsigned startIndex = 1442 std::accumulate(segments.begin(), segments.begin() + index, 0); 1443 values = values.slice(startIndex, *std::next(segments.begin(), index)); 1444 1445 LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", " 1446 << *std::next(segments.begin(), index) << "]\n"); 1447 1448 // Otherwise, assume this is the last operand group of the operation. 1449 // FIXME: We currently don't support operations with 1450 // SameVariadicOperandSize/SameVariadicResultSize here given that we don't 1451 // have a way to detect it's presence. 1452 } else if (values.size() >= index) { 1453 LLVM_DEBUG(llvm::dbgs() 1454 << " * Treating values as trailing variadic range\n"); 1455 values = values.drop_front(index); 1456 1457 // If we couldn't detect a way to compute the values, bail out. 1458 } else { 1459 return nullptr; 1460 } 1461 1462 // If the range index is valid, we are returning a range. 1463 if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) { 1464 valueRangeMemory[rangeIndex] = values; 1465 return &valueRangeMemory[rangeIndex]; 1466 } 1467 1468 // If a range index wasn't provided, the range is required to be non-variadic. 1469 return values.size() != 1 ? 
nullptr : values.front().getAsOpaquePointer(); 1470 } 1471 1472 void ByteCodeExecutor::executeGetOperands() { 1473 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n"); 1474 unsigned index = read<uint32_t>(); 1475 Operation *op = read<Operation *>(); 1476 ByteCodeField rangeIndex = read(); 1477 1478 void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>( 1479 op->getOperands(), op, index, rangeIndex, "operand_segment_sizes", 1480 valueRangeMemory); 1481 if (!result) 1482 LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n"); 1483 memory[read()] = result; 1484 } 1485 1486 void ByteCodeExecutor::executeGetResult(unsigned index) { 1487 Operation *op = read<Operation *>(); 1488 unsigned memIndex = read(); 1489 OpResult result = 1490 index < op->getNumResults() ? op->getResult(index) : OpResult(); 1491 1492 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1493 << " * Index: " << index << "\n" 1494 << " * Result: " << result << "\n"); 1495 memory[memIndex] = result.getAsOpaquePointer(); 1496 } 1497 1498 void ByteCodeExecutor::executeGetResults() { 1499 LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n"); 1500 unsigned index = read<uint32_t>(); 1501 Operation *op = read<Operation *>(); 1502 ByteCodeField rangeIndex = read(); 1503 1504 void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>( 1505 op->getResults(), op, index, rangeIndex, "result_segment_sizes", 1506 valueRangeMemory); 1507 if (!result) 1508 LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n"); 1509 memory[read()] = result; 1510 } 1511 1512 void ByteCodeExecutor::executeGetValueType() { 1513 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); 1514 unsigned memIndex = read(); 1515 Value value = read<Value>(); 1516 Type type = value ? 
value.getType() : Type(); 1517 1518 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" 1519 << " * Result: " << type << "\n"); 1520 memory[memIndex] = type.getAsOpaquePointer(); 1521 } 1522 1523 void ByteCodeExecutor::executeGetValueRangeTypes() { 1524 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n"); 1525 unsigned memIndex = read(); 1526 unsigned rangeIndex = read(); 1527 ValueRange *values = read<ValueRange *>(); 1528 if (!values) { 1529 LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n"); 1530 memory[memIndex] = nullptr; 1531 return; 1532 } 1533 1534 LLVM_DEBUG({ 1535 llvm::dbgs() << " * Values (" << values->size() << "): "; 1536 llvm::interleaveComma(*values, llvm::dbgs()); 1537 llvm::dbgs() << "\n * Result: "; 1538 llvm::interleaveComma(values->getType(), llvm::dbgs()); 1539 llvm::dbgs() << "\n"; 1540 }); 1541 typeRangeMemory[rangeIndex] = values->getType(); 1542 memory[memIndex] = &typeRangeMemory[rangeIndex]; 1543 } 1544 1545 void ByteCodeExecutor::executeIsNotNull() { 1546 LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); 1547 const void *value = read<const void *>(); 1548 1549 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1550 selectJump(value != nullptr); 1551 } 1552 1553 void ByteCodeExecutor::executeRecordMatch( 1554 PatternRewriter &rewriter, 1555 SmallVectorImpl<PDLByteCode::MatchResult> &matches) { 1556 LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n"); 1557 unsigned patternIndex = read(); 1558 PatternBenefit benefit = currentPatternBenefits[patternIndex]; 1559 const ByteCodeField *dest = &code[read<ByteCodeAddr>()]; 1560 1561 // If the benefit of the pattern is impossible, skip the processing of the 1562 // rest of the pattern. 1563 if (benefit.isImpossibleToMatch()) { 1564 LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n"); 1565 curCodeIt = dest; 1566 return; 1567 } 1568 1569 // Create a fused location containing the locations of each of the 1570 // operations used in the match. 
This will be used as the location for 1571 // created operations during the rewrite that don't already have an 1572 // explicit location set. 1573 unsigned numMatchLocs = read(); 1574 SmallVector<Location, 4> matchLocs; 1575 matchLocs.reserve(numMatchLocs); 1576 for (unsigned i = 0; i != numMatchLocs; ++i) 1577 matchLocs.push_back(read<Operation *>()->getLoc()); 1578 Location matchLoc = rewriter.getFusedLoc(matchLocs); 1579 1580 LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n" 1581 << " * Location: " << matchLoc << "\n"); 1582 matches.emplace_back(matchLoc, patterns[patternIndex], benefit); 1583 PDLByteCode::MatchResult &match = matches.back(); 1584 1585 // Record all of the inputs to the match. If any of the inputs are ranges, we 1586 // will also need to remap the range pointer to memory stored in the match 1587 // state. 1588 unsigned numInputs = read(); 1589 match.values.reserve(numInputs); 1590 match.typeRangeValues.reserve(numInputs); 1591 match.valueRangeValues.reserve(numInputs); 1592 for (unsigned i = 0; i < numInputs; ++i) { 1593 switch (read<PDLValue::Kind>()) { 1594 case PDLValue::Kind::TypeRange: 1595 match.typeRangeValues.push_back(*read<TypeRange *>()); 1596 match.values.push_back(&match.typeRangeValues.back()); 1597 break; 1598 case PDLValue::Kind::ValueRange: 1599 match.valueRangeValues.push_back(*read<ValueRange *>()); 1600 match.values.push_back(&match.valueRangeValues.back()); 1601 break; 1602 default: 1603 match.values.push_back(read<const void *>()); 1604 break; 1605 } 1606 } 1607 curCodeIt = dest; 1608 } 1609 1610 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) { 1611 LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); 1612 Operation *op = read<Operation *>(); 1613 SmallVector<Value, 16> args; 1614 readValueList(args); 1615 1616 LLVM_DEBUG({ 1617 llvm::dbgs() << " * Operation: " << *op << "\n" 1618 << " * Values: "; 1619 llvm::interleaveComma(args, llvm::dbgs()); 1620 llvm::dbgs() << "\n"; 1621 
}); 1622 rewriter.replaceOp(op, args); 1623 } 1624 1625 void ByteCodeExecutor::executeSwitchAttribute() { 1626 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n"); 1627 Attribute value = read<Attribute>(); 1628 ArrayAttr cases = read<ArrayAttr>(); 1629 handleSwitch(value, cases); 1630 } 1631 1632 void ByteCodeExecutor::executeSwitchOperandCount() { 1633 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n"); 1634 Operation *op = read<Operation *>(); 1635 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1636 1637 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1638 handleSwitch(op->getNumOperands(), cases); 1639 } 1640 1641 void ByteCodeExecutor::executeSwitchOperationName() { 1642 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n"); 1643 OperationName value = read<Operation *>()->getName(); 1644 size_t caseCount = read(); 1645 1646 // The operation names are stored in-line, so to print them out for 1647 // debugging purposes we need to read the array before executing the 1648 // switch so that we can display all of the possible values. 1649 LLVM_DEBUG({ 1650 const ByteCodeField *prevCodeIt = curCodeIt; 1651 llvm::dbgs() << " * Value: " << value << "\n" 1652 << " * Cases: "; 1653 llvm::interleaveComma( 1654 llvm::map_range(llvm::seq<size_t>(0, caseCount), 1655 [&](size_t) { return read<OperationName>(); }), 1656 llvm::dbgs()); 1657 llvm::dbgs() << "\n"; 1658 curCodeIt = prevCodeIt; 1659 }); 1660 1661 // Try to find the switch value within any of the cases. 
1662 for (size_t i = 0; i != caseCount; ++i) { 1663 if (read<OperationName>() == value) { 1664 curCodeIt += (caseCount - i - 1); 1665 return selectJump(i + 1); 1666 } 1667 } 1668 selectJump(size_t(0)); 1669 } 1670 1671 void ByteCodeExecutor::executeSwitchResultCount() { 1672 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n"); 1673 Operation *op = read<Operation *>(); 1674 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1675 1676 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1677 handleSwitch(op->getNumResults(), cases); 1678 } 1679 1680 void ByteCodeExecutor::executeSwitchType() { 1681 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n"); 1682 Type value = read<Type>(); 1683 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>(); 1684 handleSwitch(value, cases); 1685 } 1686 1687 void ByteCodeExecutor::executeSwitchTypes() { 1688 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n"); 1689 TypeRange *value = read<TypeRange *>(); 1690 auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>(); 1691 if (!value) { 1692 LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n"); 1693 return selectJump(size_t(0)); 1694 } 1695 handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) { 1696 return value == caseValue.getAsValueRange<TypeAttr>(); 1697 }); 1698 } 1699 1700 void ByteCodeExecutor::execute( 1701 PatternRewriter &rewriter, 1702 SmallVectorImpl<PDLByteCode::MatchResult> *matches, 1703 Optional<Location> mainRewriteLoc) { 1704 while (true) { 1705 OpCode opCode = static_cast<OpCode>(read()); 1706 switch (opCode) { 1707 case ApplyConstraint: 1708 executeApplyConstraint(rewriter); 1709 break; 1710 case ApplyRewrite: 1711 executeApplyRewrite(rewriter); 1712 break; 1713 case AreEqual: 1714 executeAreEqual(); 1715 break; 1716 case AreRangesEqual: 1717 executeAreRangesEqual(); 1718 break; 1719 case Branch: 1720 executeBranch(); 1721 break; 1722 case CheckOperandCount: 1723 executeCheckOperandCount(); 1724 break; 1725 
case CheckOperationName: 1726 executeCheckOperationName(); 1727 break; 1728 case CheckResultCount: 1729 executeCheckResultCount(); 1730 break; 1731 case CheckTypes: 1732 executeCheckTypes(); 1733 break; 1734 case CreateOperation: 1735 executeCreateOperation(rewriter, *mainRewriteLoc); 1736 break; 1737 case CreateTypes: 1738 executeCreateTypes(); 1739 break; 1740 case EraseOp: 1741 executeEraseOp(rewriter); 1742 break; 1743 case Finalize: 1744 LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n"); 1745 return; 1746 case GetAttribute: 1747 executeGetAttribute(); 1748 break; 1749 case GetAttributeType: 1750 executeGetAttributeType(); 1751 break; 1752 case GetDefiningOp: 1753 executeGetDefiningOp(); 1754 break; 1755 case GetOperand0: 1756 case GetOperand1: 1757 case GetOperand2: 1758 case GetOperand3: { 1759 unsigned index = opCode - GetOperand0; 1760 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n"); 1761 executeGetOperand(index); 1762 break; 1763 } 1764 case GetOperandN: 1765 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n"); 1766 executeGetOperand(read<uint32_t>()); 1767 break; 1768 case GetOperands: 1769 executeGetOperands(); 1770 break; 1771 case GetResult0: 1772 case GetResult1: 1773 case GetResult2: 1774 case GetResult3: { 1775 unsigned index = opCode - GetResult0; 1776 LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n"); 1777 executeGetResult(index); 1778 break; 1779 } 1780 case GetResultN: 1781 LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n"); 1782 executeGetResult(read<uint32_t>()); 1783 break; 1784 case GetResults: 1785 executeGetResults(); 1786 break; 1787 case GetValueType: 1788 executeGetValueType(); 1789 break; 1790 case GetValueRangeTypes: 1791 executeGetValueRangeTypes(); 1792 break; 1793 case IsNotNull: 1794 executeIsNotNull(); 1795 break; 1796 case RecordMatch: 1797 assert(matches && 1798 "expected matches to be provided when executing the matcher"); 1799 executeRecordMatch(rewriter, *matches); 1800 break; 
1801 case ReplaceOp: 1802 executeReplaceOp(rewriter); 1803 break; 1804 case SwitchAttribute: 1805 executeSwitchAttribute(); 1806 break; 1807 case SwitchOperandCount: 1808 executeSwitchOperandCount(); 1809 break; 1810 case SwitchOperationName: 1811 executeSwitchOperationName(); 1812 break; 1813 case SwitchResultCount: 1814 executeSwitchResultCount(); 1815 break; 1816 case SwitchType: 1817 executeSwitchType(); 1818 break; 1819 case SwitchTypes: 1820 executeSwitchTypes(); 1821 break; 1822 } 1823 LLVM_DEBUG(llvm::dbgs() << "\n"); 1824 } 1825 } 1826 1827 /// Run the pattern matcher on the given root operation, collecting the matched 1828 /// patterns in `matches`. 1829 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, 1830 SmallVectorImpl<MatchResult> &matches, 1831 PDLByteCodeMutableState &state) const { 1832 // The first memory slot is always the root operation. 1833 state.memory[0] = op; 1834 1835 // The matcher function always starts at code address 0. 1836 ByteCodeExecutor executor( 1837 matcherByteCode.data(), state.memory, state.typeRangeMemory, 1838 state.allocatedTypeRangeMemory, state.valueRangeMemory, 1839 state.allocatedValueRangeMemory, uniquedData, matcherByteCode, 1840 state.currentPatternBenefits, patterns, constraintFunctions, 1841 rewriteFunctions); 1842 executor.execute(rewriter, &matches); 1843 1844 // Order the found matches by benefit. 1845 std::stable_sort(matches.begin(), matches.end(), 1846 [](const MatchResult &lhs, const MatchResult &rhs) { 1847 return lhs.benefit > rhs.benefit; 1848 }); 1849 } 1850 1851 /// Run the rewriter of the given pattern on the root operation `op`. 1852 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, 1853 PDLByteCodeMutableState &state) const { 1854 // The arguments of the rewrite function are stored at the start of the 1855 // memory buffer. 
1856 llvm::copy(match.values, state.memory.begin()); 1857 1858 ByteCodeExecutor executor( 1859 &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, 1860 state.typeRangeMemory, state.allocatedTypeRangeMemory, 1861 state.valueRangeMemory, state.allocatedValueRangeMemory, uniquedData, 1862 rewriterByteCode, state.currentPatternBenefits, patterns, 1863 constraintFunctions, rewriteFunctions); 1864 executor.execute(rewriter, /*matches=*/nullptr, match.location); 1865 } 1866