1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements MLIR to byte-code generation and the interpreter. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "ByteCode.h" 14 #include "mlir/Analysis/Liveness.h" 15 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/RegionGraphTraits.h" 19 #include "llvm/ADT/IntervalMap.h" 20 #include "llvm/ADT/PostOrderIterator.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 #include "llvm/Support/Debug.h" 23 #include "llvm/Support/Format.h" 24 #include "llvm/Support/FormatVariadic.h" 25 #include <numeric> 26 27 #define DEBUG_TYPE "pdl-bytecode" 28 29 using namespace mlir; 30 using namespace mlir::detail; 31 32 //===----------------------------------------------------------------------===// 33 // PDLByteCodePattern 34 //===----------------------------------------------------------------------===// 35 36 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp, 37 ByteCodeAddr rewriterAddr) { 38 SmallVector<StringRef, 8> generatedOps; 39 if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr()) 40 generatedOps = 41 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>()); 42 43 PatternBenefit benefit = matchOp.benefit(); 44 MLIRContext *ctx = matchOp.getContext(); 45 46 // Check to see if this is pattern matches a specific operation type. 47 if (Optional<StringRef> rootKind = matchOp.rootKind()) 48 return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx, 49 generatedOps); 50 return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx, 51 generatedOps); 52 } 53 54 //===----------------------------------------------------------------------===// 55 // PDLByteCodeMutableState 56 //===----------------------------------------------------------------------===// 57 58 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds 59 /// to the position of the pattern within the range returned by 60 /// `PDLByteCode::getPatterns`. 61 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex, 62 PatternBenefit benefit) { 63 currentPatternBenefits[patternIndex] = benefit; 64 } 65 66 /// Cleanup any allocated state after a full match/rewrite has been completed. 67 /// This method should be called irregardless of whether the match+rewrite was a 68 /// success or not. 69 void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() { 70 allocatedTypeRangeMemory.clear(); 71 allocatedValueRangeMemory.clear(); 72 } 73 74 //===----------------------------------------------------------------------===// 75 // Bytecode OpCodes 76 //===----------------------------------------------------------------------===// 77 78 namespace { 79 enum OpCode : ByteCodeField { 80 /// Apply an externally registered constraint. 81 ApplyConstraint, 82 /// Apply an externally registered rewrite. 83 ApplyRewrite, 84 /// Check if two generic values are equal. 85 AreEqual, 86 /// Check if two ranges are equal. 87 AreRangesEqual, 88 /// Unconditional branch. 89 Branch, 90 /// Compare the operand count of an operation with a constant. 91 CheckOperandCount, 92 /// Compare the name of an operation with a constant. 93 CheckOperationName, 94 /// Compare the result count of an operation with a constant. 95 CheckResultCount, 96 /// Compare a range of types to a constant range of types. 97 CheckTypes, 98 /// Continue to the next iteration of a loop. 99 Continue, 100 /// Create an operation. 101 CreateOperation, 102 /// Create a range of types. 103 CreateTypes, 104 /// Erase an operation. 105 EraseOp, 106 /// Extract the op from a range at the specified index. 107 ExtractOp, 108 /// Extract the type from a range at the specified index. 109 ExtractType, 110 /// Extract the value from a range at the specified index. 111 ExtractValue, 112 /// Terminate a matcher or rewrite sequence. 113 Finalize, 114 /// Iterate over a range of values. 115 ForEach, 116 /// Get a specific attribute of an operation. 117 GetAttribute, 118 /// Get the type of an attribute. 119 GetAttributeType, 120 /// Get the defining operation of a value. 121 GetDefiningOp, 122 /// Get a specific operand of an operation. 123 GetOperand0, 124 GetOperand1, 125 GetOperand2, 126 GetOperand3, 127 GetOperandN, 128 /// Get a specific operand group of an operation. 129 GetOperands, 130 /// Get a specific result of an operation. 131 GetResult0, 132 GetResult1, 133 GetResult2, 134 GetResult3, 135 GetResultN, 136 /// Get a specific result group of an operation. 137 GetResults, 138 /// Get the users of a value or a range of values. 139 GetUsers, 140 /// Get the type of a value. 141 GetValueType, 142 /// Get the types of a value range. 143 GetValueRangeTypes, 144 /// Check if a generic value is not null. 145 IsNotNull, 146 /// Record a successful pattern match. 147 RecordMatch, 148 /// Replace an operation. 149 ReplaceOp, 150 /// Compare an attribute with a set of constants. 151 SwitchAttribute, 152 /// Compare the operand count of an operation with a set of constants. 153 SwitchOperandCount, 154 /// Compare the name of an operation with a set of constants. 155 SwitchOperationName, 156 /// Compare the result count of an operation with a set of constants. 157 SwitchResultCount, 158 /// Compare a type with a set of constants. 159 SwitchType, 160 /// Compare a range of types with a set of constants. 161 SwitchTypes, 162 }; 163 } // end anonymous namespace 164 165 //===----------------------------------------------------------------------===// 166 // ByteCode Generation 167 //===----------------------------------------------------------------------===// 168 169 //===----------------------------------------------------------------------===// 170 // Generator 171 172 namespace { 173 struct ByteCodeLiveRange; 174 struct ByteCodeWriter; 175 176 /// Check if the given class `T` can be converted to an opaque pointer. 177 template <typename T, typename... Args> 178 using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer()); 179 180 /// This class represents the main generator for the pattern bytecode. 181 class Generator { 182 public: 183 Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData, 184 SmallVectorImpl<ByteCodeField> &matcherByteCode, 185 SmallVectorImpl<ByteCodeField> &rewriterByteCode, 186 SmallVectorImpl<PDLByteCodePattern> &patterns, 187 ByteCodeField &maxValueMemoryIndex, 188 ByteCodeField &maxOpRangeMemoryIndex, 189 ByteCodeField &maxTypeRangeMemoryIndex, 190 ByteCodeField &maxValueRangeMemoryIndex, 191 ByteCodeField &maxLoopLevel, 192 llvm::StringMap<PDLConstraintFunction> &constraintFns, 193 llvm::StringMap<PDLRewriteFunction> &rewriteFns) 194 : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), 195 rewriterByteCode(rewriterByteCode), patterns(patterns), 196 maxValueMemoryIndex(maxValueMemoryIndex), 197 maxOpRangeMemoryIndex(maxOpRangeMemoryIndex), 198 maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex), 199 maxValueRangeMemoryIndex(maxValueRangeMemoryIndex), 200 maxLoopLevel(maxLoopLevel) { 201 for (auto it : llvm::enumerate(constraintFns)) 202 constraintToMemIndex.try_emplace(it.value().first(), it.index()); 203 for (auto it : llvm::enumerate(rewriteFns)) 204 externalRewriterToMemIndex.try_emplace(it.value().first(), it.index()); 205 } 206 207 /// Generate the bytecode for the given PDL interpreter module. 208 void generate(ModuleOp module); 209 210 /// Return the memory index to use for the given value. 211 ByteCodeField &getMemIndex(Value value) { 212 assert(valueToMemIndex.count(value) && 213 "expected memory index to be assigned"); 214 return valueToMemIndex[value]; 215 } 216 217 /// Return the range memory index used to store the given range value. 218 ByteCodeField &getRangeStorageIndex(Value value) { 219 assert(valueToRangeIndex.count(value) && 220 "expected range index to be assigned"); 221 return valueToRangeIndex[value]; 222 } 223 224 /// Return an index to use when referring to the given data that is uniqued in 225 /// the MLIR context. 226 template <typename T> 227 std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &> 228 getMemIndex(T val) { 229 const void *opaqueVal = val.getAsOpaquePointer(); 230 231 // Get or insert a reference to this value. 232 auto it = uniquedDataToMemIndex.try_emplace( 233 opaqueVal, maxValueMemoryIndex + uniquedData.size()); 234 if (it.second) 235 uniquedData.push_back(opaqueVal); 236 return it.first->second; 237 } 238 239 private: 240 /// Allocate memory indices for the results of operations within the matcher 241 /// and rewriters. 242 void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule); 243 244 /// Generate the bytecode for the given operation. 245 void generate(Region *region, ByteCodeWriter &writer); 246 void generate(Operation *op, ByteCodeWriter &writer); 247 void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer); 248 void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer); 249 void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer); 250 void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer); 251 void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer); 252 void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer); 253 void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer); 254 void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); 255 void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); 256 void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer); 257 void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer); 258 void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); 259 void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); 260 void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); 261 void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer); 262 void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); 263 void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer); 264 void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); 265 void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer); 266 void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); 267 void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); 268 void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); 269 void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer); 270 void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer); 271 void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); 272 void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer); 273 void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer); 274 void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); 275 void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer); 276 void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); 277 void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer); 278 void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); 279 void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer); 280 void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer); 281 void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer); 282 void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer); 283 void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer); 284 void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer); 285 286 /// Mapping from value to its corresponding memory index. 287 DenseMap<Value, ByteCodeField> valueToMemIndex; 288 289 /// Mapping from a range value to its corresponding range storage index. 290 DenseMap<Value, ByteCodeField> valueToRangeIndex; 291 292 /// Mapping from the name of an externally registered rewrite to its index in 293 /// the bytecode registry. 294 llvm::StringMap<ByteCodeField> externalRewriterToMemIndex; 295 296 /// Mapping from the name of an externally registered constraint to its index 297 /// in the bytecode registry. 298 llvm::StringMap<ByteCodeField> constraintToMemIndex; 299 300 /// Mapping from rewriter function name to the bytecode address of the 301 /// rewriter function in byte. 302 llvm::StringMap<ByteCodeAddr> rewriterToAddr; 303 304 /// Mapping from a uniqued storage object to its memory index within 305 /// `uniquedData`. 306 DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex; 307 308 /// The current level of the foreach loop. 309 ByteCodeField curLoopLevel = 0; 310 311 /// The current MLIR context. 312 MLIRContext *ctx; 313 314 /// Mapping from block to its address. 315 DenseMap<Block *, ByteCodeAddr> blockToAddr; 316 317 /// Data of the ByteCode class to be populated. 318 std::vector<const void *> &uniquedData; 319 SmallVectorImpl<ByteCodeField> &matcherByteCode; 320 SmallVectorImpl<ByteCodeField> &rewriterByteCode; 321 SmallVectorImpl<PDLByteCodePattern> &patterns; 322 ByteCodeField &maxValueMemoryIndex; 323 ByteCodeField &maxOpRangeMemoryIndex; 324 ByteCodeField &maxTypeRangeMemoryIndex; 325 ByteCodeField &maxValueRangeMemoryIndex; 326 ByteCodeField &maxLoopLevel; 327 }; 328 329 /// This class provides utilities for writing a bytecode stream. 330 struct ByteCodeWriter { 331 ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator) 332 : bytecode(bytecode), generator(generator) {} 333 334 /// Append a field to the bytecode. 335 void append(ByteCodeField field) { bytecode.push_back(field); } 336 void append(OpCode opCode) { bytecode.push_back(opCode); } 337 338 /// Append an address to the bytecode. 339 void append(ByteCodeAddr field) { 340 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 341 "unexpected ByteCode address size"); 342 343 ByteCodeField fieldParts[2]; 344 std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr)); 345 bytecode.append({fieldParts[0], fieldParts[1]}); 346 } 347 348 /// Append a single successor to the bytecode, the exact address will need to 349 /// be resolved later. 350 void append(Block *successor) { 351 // Add back a reference to the successor so that the address can be resolved 352 // later. 353 unresolvedSuccessorRefs[successor].push_back(bytecode.size()); 354 append(ByteCodeAddr(0)); 355 } 356 357 /// Append a successor range to the bytecode, the exact address will need to 358 /// be resolved later. 359 void append(SuccessorRange successors) { 360 for (Block *successor : successors) 361 append(successor); 362 } 363 364 /// Append a range of values that will be read as generic PDLValues. 365 void appendPDLValueList(OperandRange values) { 366 bytecode.push_back(values.size()); 367 for (Value value : values) 368 appendPDLValue(value); 369 } 370 371 /// Append a value as a PDLValue. 372 void appendPDLValue(Value value) { 373 appendPDLValueKind(value); 374 append(value); 375 } 376 377 /// Append the PDLValue::Kind of the given value. 378 void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); } 379 380 /// Append the PDLValue::Kind of the given type. 381 void appendPDLValueKind(Type type) { 382 PDLValue::Kind kind = 383 TypeSwitch<Type, PDLValue::Kind>(type) 384 .Case<pdl::AttributeType>( 385 [](Type) { return PDLValue::Kind::Attribute; }) 386 .Case<pdl::OperationType>( 387 [](Type) { return PDLValue::Kind::Operation; }) 388 .Case<pdl::RangeType>([](pdl::RangeType rangeTy) { 389 if (rangeTy.getElementType().isa<pdl::TypeType>()) 390 return PDLValue::Kind::TypeRange; 391 return PDLValue::Kind::ValueRange; 392 }) 393 .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; }) 394 .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; }); 395 bytecode.push_back(static_cast<ByteCodeField>(kind)); 396 } 397 398 /// Append a value that will be stored in a memory slot and not inline within 399 /// the bytecode. 400 template <typename T> 401 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value || 402 std::is_pointer<T>::value> 403 append(T value) { 404 bytecode.push_back(generator.getMemIndex(value)); 405 } 406 407 /// Append a range of values. 408 template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>> 409 std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value> 410 append(T range) { 411 bytecode.push_back(llvm::size(range)); 412 for (auto it : range) 413 append(it); 414 } 415 416 /// Append a variadic number of fields to the bytecode. 417 template <typename FieldTy, typename Field2Ty, typename... FieldTys> 418 void append(FieldTy field, Field2Ty field2, FieldTys... fields) { 419 append(field); 420 append(field2, fields...); 421 } 422 423 /// Successor references in the bytecode that have yet to be resolved. 424 DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs; 425 426 /// The underlying bytecode buffer. 427 SmallVectorImpl<ByteCodeField> &bytecode; 428 429 /// The main generator producing PDL. 430 Generator &generator; 431 }; 432 433 /// This class represents a live range of PDL Interpreter values, containing 434 /// information about when values are live within a match/rewrite. 435 struct ByteCodeLiveRange { 436 using Set = llvm::IntervalMap<uint64_t, char, 16>; 437 using Allocator = Set::Allocator; 438 439 ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {} 440 441 /// Union this live range with the one provided. 442 void unionWith(const ByteCodeLiveRange &rhs) { 443 for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e; 444 ++it) 445 liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0); 446 } 447 448 /// Returns true if this range overlaps with the one provided. 449 bool overlaps(const ByteCodeLiveRange &rhs) const { 450 return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness) 451 .valid(); 452 } 453 454 /// A map representing the ranges of the match/rewrite that a value is live in 455 /// the interpreter. 456 /// 457 /// We use std::unique_ptr here, because IntervalMap does not provide a 458 /// correct copy or move constructor. We can eliminate the pointer once 459 /// https://reviews.llvm.org/D113240 lands. 460 std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness; 461 462 /// The operation range storage index for this range. 463 Optional<unsigned> opRangeIndex; 464 465 /// The type range storage index for this range. 466 Optional<unsigned> typeRangeIndex; 467 468 /// The value range storage index for this range. 469 Optional<unsigned> valueRangeIndex; 470 }; 471 } // end anonymous namespace 472 473 void Generator::generate(ModuleOp module) { 474 FuncOp matcherFunc = module.lookupSymbol<FuncOp>( 475 pdl_interp::PDLInterpDialect::getMatcherFunctionName()); 476 ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>( 477 pdl_interp::PDLInterpDialect::getRewriterModuleName()); 478 assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module"); 479 480 // Allocate memory indices for the results of operations within the matcher 481 // and rewriters. 482 allocateMemoryIndices(matcherFunc, rewriterModule); 483 484 // Generate code for the rewriter functions. 485 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this); 486 for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) { 487 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size()); 488 for (Operation &op : rewriterFunc.getOps()) 489 generate(&op, rewriterByteCodeWriter); 490 } 491 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() && 492 "unexpected branches in rewriter function"); 493 494 // Generate code for the matcher function. 495 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this); 496 generate(&matcherFunc.getBody(), matcherByteCodeWriter); 497 498 // Resolve successor references in the matcher. 499 for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) { 500 ByteCodeAddr addr = blockToAddr[it.first]; 501 for (unsigned offsetToFix : it.second) 502 std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr)); 503 } 504 } 505 506 void Generator::allocateMemoryIndices(FuncOp matcherFunc, 507 ModuleOp rewriterModule) { 508 // Rewriters use simplistic allocation scheme that simply assigns an index to 509 // each result. 510 for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) { 511 ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0; 512 auto processRewriterValue = [&](Value val) { 513 valueToMemIndex.try_emplace(val, index++); 514 if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) { 515 Type elementTy = rangeType.getElementType(); 516 if (elementTy.isa<pdl::TypeType>()) 517 valueToRangeIndex.try_emplace(val, typeRangeIndex++); 518 else if (elementTy.isa<pdl::ValueType>()) 519 valueToRangeIndex.try_emplace(val, valueRangeIndex++); 520 } 521 }; 522 523 for (BlockArgument arg : rewriterFunc.getArguments()) 524 processRewriterValue(arg); 525 rewriterFunc.getBody().walk([&](Operation *op) { 526 for (Value result : op->getResults()) 527 processRewriterValue(result); 528 }); 529 if (index > maxValueMemoryIndex) 530 maxValueMemoryIndex = index; 531 if (typeRangeIndex > maxTypeRangeMemoryIndex) 532 maxTypeRangeMemoryIndex = typeRangeIndex; 533 if (valueRangeIndex > maxValueRangeMemoryIndex) 534 maxValueRangeMemoryIndex = valueRangeIndex; 535 } 536 537 // The matcher function uses a more sophisticated numbering that tries to 538 // minimize the number of memory indices assigned. This is done by determining 539 // a live range of the values within the matcher, then the allocation is just 540 // finding the minimal number of overlapping live ranges. This is essentially 541 // a simplified form of register allocation where we don't necessarily have a 542 // limited number of registers, but we still want to minimize the number used. 543 DenseMap<Operation *, unsigned> opToIndex; 544 matcherFunc.getBody().walk([&](Operation *op) { 545 opToIndex.insert(std::make_pair(op, opToIndex.size())); 546 }); 547 548 // Liveness info for each of the defs within the matcher. 549 ByteCodeLiveRange::Allocator allocator; 550 DenseMap<Value, ByteCodeLiveRange> valueDefRanges; 551 552 // Assign the root operation being matched to slot 0. 553 BlockArgument rootOpArg = matcherFunc.getArgument(0); 554 valueToMemIndex[rootOpArg] = 0; 555 556 // Walk each of the blocks, computing the def interval that the value is used. 557 Liveness matcherLiveness(matcherFunc); 558 matcherFunc->walk([&](Block *block) { 559 const LivenessBlockInfo *info = matcherLiveness.getLiveness(block); 560 assert(info && "expected liveness info for block"); 561 auto processValue = [&](Value value, Operation *firstUseOrDef) { 562 // We don't need to process the root op argument, this value is always 563 // assigned to the first memory slot. 564 if (value == rootOpArg) 565 return; 566 567 // Set indices for the range of this block that the value is used. 568 auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; 569 defRangeIt->second.liveness->insert( 570 opToIndex[firstUseOrDef], 571 opToIndex[info->getEndOperation(value, firstUseOrDef)], 572 /*dummyValue*/ 0); 573 574 // Check to see if this value is a range type. 575 if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) { 576 Type eleType = rangeTy.getElementType(); 577 if (eleType.isa<pdl::OperationType>()) 578 defRangeIt->second.opRangeIndex = 0; 579 else if (eleType.isa<pdl::TypeType>()) 580 defRangeIt->second.typeRangeIndex = 0; 581 else if (eleType.isa<pdl::ValueType>()) 582 defRangeIt->second.valueRangeIndex = 0; 583 } 584 }; 585 586 // Process the live-ins of this block. 587 for (Value liveIn : info->in()) { 588 // Only process the value if it has been defined in the current region. 589 // Other values that span across pdl_interp.foreach will be added higher 590 // up. This ensures that the we keep them alive for the entire duration 591 // of the loop. 592 if (liveIn.getParentRegion() == block->getParent()) 593 processValue(liveIn, &block->front()); 594 } 595 596 // Process the block arguments for the entry block (those are not live-in). 597 if (block->isEntryBlock()) { 598 for (Value argument : block->getArguments()) 599 processValue(argument, &block->front()); 600 } 601 602 // Process any new defs within this block. 603 for (Operation &op : *block) 604 for (Value result : op.getResults()) 605 processValue(result, &op); 606 }); 607 608 // Greedily allocate memory slots using the computed def live ranges. 609 std::vector<ByteCodeLiveRange> allocatedIndices; 610 611 // The number of memory indices currently allocated (and its next value). 612 // Recall that the root gets allocated memory index 0. 613 ByteCodeField numIndices = 1; 614 615 // The number of memory ranges of various types (and their next values). 616 ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0; 617 618 for (auto &defIt : valueDefRanges) { 619 ByteCodeField &memIndex = valueToMemIndex[defIt.first]; 620 ByteCodeLiveRange &defRange = defIt.second; 621 622 // Try to allocate to an existing index. 623 for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) { 624 ByteCodeLiveRange &existingRange = existingIndexIt.value(); 625 if (!defRange.overlaps(existingRange)) { 626 existingRange.unionWith(defRange); 627 memIndex = existingIndexIt.index() + 1; 628 629 if (defRange.opRangeIndex) { 630 if (!existingRange.opRangeIndex) 631 existingRange.opRangeIndex = numOpRanges++; 632 valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex; 633 } else if (defRange.typeRangeIndex) { 634 if (!existingRange.typeRangeIndex) 635 existingRange.typeRangeIndex = numTypeRanges++; 636 valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex; 637 } else if (defRange.valueRangeIndex) { 638 if (!existingRange.valueRangeIndex) 639 existingRange.valueRangeIndex = numValueRanges++; 640 valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex; 641 } 642 break; 643 } 644 } 645 646 // If no existing index could be used, add a new one. 647 if (memIndex == 0) { 648 allocatedIndices.emplace_back(allocator); 649 ByteCodeLiveRange &newRange = allocatedIndices.back(); 650 newRange.unionWith(defRange); 651 652 // Allocate an index for op/type/value ranges. 653 if (defRange.opRangeIndex) { 654 newRange.opRangeIndex = numOpRanges; 655 valueToRangeIndex[defIt.first] = numOpRanges++; 656 } else if (defRange.typeRangeIndex) { 657 newRange.typeRangeIndex = numTypeRanges; 658 valueToRangeIndex[defIt.first] = numTypeRanges++; 659 } else if (defRange.valueRangeIndex) { 660 newRange.valueRangeIndex = numValueRanges; 661 valueToRangeIndex[defIt.first] = numValueRanges++; 662 } 663 664 memIndex = allocatedIndices.size(); 665 ++numIndices; 666 } 667 } 668 669 // Print the index usage and ensure that we did not run out of index space. 670 LLVM_DEBUG({ 671 llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices " 672 << "(down from initial " << valueDefRanges.size() << ").\n"; 673 }); 674 assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() && 675 "Ran out of memory for allocated indices"); 676 677 // Update the max number of indices. 678 if (numIndices > maxValueMemoryIndex) 679 maxValueMemoryIndex = numIndices; 680 if (numOpRanges > maxOpRangeMemoryIndex) 681 maxOpRangeMemoryIndex = numOpRanges; 682 if (numTypeRanges > maxTypeRangeMemoryIndex) 683 maxTypeRangeMemoryIndex = numTypeRanges; 684 if (numValueRanges > maxValueRangeMemoryIndex) 685 maxValueRangeMemoryIndex = numValueRanges; 686 } 687 688 void Generator::generate(Region *region, ByteCodeWriter &writer) { 689 llvm::ReversePostOrderTraversal<Region *> rpot(region); 690 for (Block *block : rpot) { 691 // Keep track of where this block begins within the matcher function. 692 blockToAddr.try_emplace(block, matcherByteCode.size()); 693 for (Operation &op : *block) 694 generate(&op, writer); 695 } 696 } 697 698 void Generator::generate(Operation *op, ByteCodeWriter &writer) { 699 TypeSwitch<Operation *>(op) 700 .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp, 701 pdl_interp::AreEqualOp, pdl_interp::BranchOp, 702 pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp, 703 pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp, 704 pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp, 705 pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp, 706 pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp, 707 pdl_interp::CreateTypesOp, pdl_interp::EraseOp, 708 pdl_interp::ExtractOp, pdl_interp::FinalizeOp, 709 pdl_interp::ForEachOp, pdl_interp::GetAttributeOp, 710 pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp, 711 pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp, 712 pdl_interp::GetResultOp, pdl_interp::GetResultsOp, 713 pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp, 714 pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp, 715 pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp, 716 pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp, 717 pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp, 718 pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>( 719 [&](auto interpOp) { this->generate(interpOp, writer); }) 720 .Default([](Operation *) { 721 llvm_unreachable("unknown `pdl_interp` operation"); 722 }); 723 } 724 725 void Generator::generate(pdl_interp::ApplyConstraintOp op, 726 ByteCodeWriter &writer) { 727 assert(constraintToMemIndex.count(op.name()) && 728 "expected index for constraint function"); 729 writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()], 730 op.constParamsAttr()); 731 writer.appendPDLValueList(op.args()); 732 writer.append(op.getSuccessors()); 733 } 734 void Generator::generate(pdl_interp::ApplyRewriteOp op, 735 ByteCodeWriter &writer) { 736 assert(externalRewriterToMemIndex.count(op.name()) && 737 "expected index for rewrite function"); 738 writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()], 739 op.constParamsAttr()); 740 writer.appendPDLValueList(op.args()); 741 742 ResultRange results = op.results(); 743 writer.append(ByteCodeField(results.size())); 744 for (Value result : results) { 745 // In debug mode we also record the expected kind of the result, so that we 746 // can provide extra verification of the native rewrite function. 747 #ifndef NDEBUG 748 writer.appendPDLValueKind(result); 749 #endif 750 751 // Range results also need to append the range storage index. 752 if (result.getType().isa<pdl::RangeType>()) 753 writer.append(getRangeStorageIndex(result)); 754 writer.append(result); 755 } 756 } 757 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) { 758 Value lhs = op.lhs(); 759 if (lhs.getType().isa<pdl::RangeType>()) { 760 writer.append(OpCode::AreRangesEqual); 761 writer.appendPDLValueKind(lhs); 762 writer.append(op.lhs(), op.rhs(), op.getSuccessors()); 763 return; 764 } 765 766 writer.append(OpCode::AreEqual, lhs, op.rhs(), op.getSuccessors()); 767 } 768 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) { 769 writer.append(OpCode::Branch, SuccessorRange(op.getOperation())); 770 } 771 void Generator::generate(pdl_interp::CheckAttributeOp op, 772 ByteCodeWriter &writer) { 773 writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(), 774 op.getSuccessors()); 775 } 776 void Generator::generate(pdl_interp::CheckOperandCountOp op, 777 ByteCodeWriter &writer) { 778 writer.append(OpCode::CheckOperandCount, op.operation(), op.count(), 779 static_cast<ByteCodeField>(op.compareAtLeast()), 780 op.getSuccessors()); 781 } 782 void Generator::generate(pdl_interp::CheckOperationNameOp op, 783 ByteCodeWriter &writer) { 784 writer.append(OpCode::CheckOperationName, op.operation(), 785 OperationName(op.name(), ctx), op.getSuccessors()); 786 } 787 void Generator::generate(pdl_interp::CheckResultCountOp op, 788 ByteCodeWriter &writer) { 789 writer.append(OpCode::CheckResultCount, op.operation(), op.count(), 790 static_cast<ByteCodeField>(op.compareAtLeast()), 791 op.getSuccessors()); 792 } 793 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) { 794 writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors()); 795 } 796 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) { 797 writer.append(OpCode::CheckTypes, op.value(), op.types(), op.getSuccessors()); 798 } 799 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) { 800 assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level"); 801 writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1)); 802 } 803 void Generator::generate(pdl_interp::CreateAttributeOp op, 804 ByteCodeWriter &writer) { 805 // Simply repoint the memory index of the result to the constant. 806 getMemIndex(op.attribute()) = getMemIndex(op.value()); 807 } 808 void Generator::generate(pdl_interp::CreateOperationOp op, 809 ByteCodeWriter &writer) { 810 writer.append(OpCode::CreateOperation, op.operation(), 811 OperationName(op.name(), ctx)); 812 writer.appendPDLValueList(op.operands()); 813 814 // Add the attributes. 815 OperandRange attributes = op.attributes(); 816 writer.append(static_cast<ByteCodeField>(attributes.size())); 817 for (auto it : llvm::zip(op.attributeNames(), op.attributes())) 818 writer.append(std::get<0>(it), std::get<1>(it)); 819 writer.appendPDLValueList(op.types()); 820 } 821 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) { 822 // Simply repoint the memory index of the result to the constant. 823 getMemIndex(op.result()) = getMemIndex(op.value()); 824 } 825 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) { 826 writer.append(OpCode::CreateTypes, op.result(), 827 getRangeStorageIndex(op.result()), op.value()); 828 } 829 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) { 830 writer.append(OpCode::EraseOp, op.operation()); 831 } 832 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) { 833 OpCode opCode = 834 TypeSwitch<Type, OpCode>(op.result().getType()) 835 .Case([](pdl::OperationType) { return OpCode::ExtractOp; }) 836 .Case([](pdl::ValueType) { return OpCode::ExtractValue; }) 837 .Case([](pdl::TypeType) { return OpCode::ExtractType; }) 838 .Default([](Type) -> OpCode { 839 llvm_unreachable("unsupported element type"); 840 }); 841 writer.append(opCode, op.range(), op.index(), op.result()); 842 } 843 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) { 844 writer.append(OpCode::Finalize); 845 } 846 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) { 847 BlockArgument arg = op.getLoopVariable(); 848 writer.append(OpCode::ForEach, getRangeStorageIndex(op.values()), arg); 849 writer.appendPDLValueKind(arg.getType()); 850 writer.append(curLoopLevel, op.successor()); 851 ++curLoopLevel; 852 if (curLoopLevel > maxLoopLevel) 853 maxLoopLevel = curLoopLevel; 854 generate(&op.region(), writer); 855 --curLoopLevel; 856 } 857 void Generator::generate(pdl_interp::GetAttributeOp op, 858 ByteCodeWriter &writer) { 859 writer.append(OpCode::GetAttribute, op.attribute(), op.operation(), 860 op.nameAttr()); 861 } 862 void Generator::generate(pdl_interp::GetAttributeTypeOp op, 863 ByteCodeWriter &writer) { 864 writer.append(OpCode::GetAttributeType, op.result(), op.value()); 865 } 866 void Generator::generate(pdl_interp::GetDefiningOpOp op, 867 ByteCodeWriter &writer) { 868 writer.append(OpCode::GetDefiningOp, op.operation()); 869 writer.appendPDLValue(op.value()); 870 } 871 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) { 872 uint32_t index = op.index(); 873 if (index < 4) 874 writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index)); 875 else 876 writer.append(OpCode::GetOperandN, index); 877 writer.append(op.operation(), op.value()); 878 } 879 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) { 880 Value result = op.value(); 881 Optional<uint32_t> index = op.index(); 882 writer.append(OpCode::GetOperands, 883 index.getValueOr(std::numeric_limits<uint32_t>::max()), 884 op.operation()); 885 if (result.getType().isa<pdl::RangeType>()) 886 writer.append(getRangeStorageIndex(result)); 887 else 888 writer.append(std::numeric_limits<ByteCodeField>::max()); 889 writer.append(result); 890 } 891 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) { 892 uint32_t index = op.index(); 893 if (index < 4) 894 writer.append(static_cast<OpCode>(OpCode::GetResult0 + index)); 895 else 896 writer.append(OpCode::GetResultN, index); 897 writer.append(op.operation(), op.value()); 898 } 899 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) { 900 Value result = op.value(); 901 Optional<uint32_t> index = op.index(); 902 writer.append(OpCode::GetResults, 903 index.getValueOr(std::numeric_limits<uint32_t>::max()), 904 op.operation()); 905 if (result.getType().isa<pdl::RangeType>()) 906 writer.append(getRangeStorageIndex(result)); 907 else 908 writer.append(std::numeric_limits<ByteCodeField>::max()); 909 writer.append(result); 910 } 911 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) { 912 Value operations = op.operations(); 913 ByteCodeField rangeIndex = getRangeStorageIndex(operations); 914 writer.append(OpCode::GetUsers, operations, rangeIndex); 915 writer.appendPDLValue(op.value()); 916 } 917 void Generator::generate(pdl_interp::GetValueTypeOp op, 918 ByteCodeWriter &writer) { 919 if (op.getType().isa<pdl::RangeType>()) { 920 Value result = op.result(); 921 writer.append(OpCode::GetValueRangeTypes, result, 922 getRangeStorageIndex(result), op.value()); 923 } else { 924 writer.append(OpCode::GetValueType, op.result(), op.value()); 925 } 926 } 927 928 void Generator::generate(pdl_interp::InferredTypesOp op, 929 ByteCodeWriter &writer) { 930 // InferType maps to a null type as a marker for inferring result types. 931 getMemIndex(op.type()) = getMemIndex(Type()); 932 } 933 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) { 934 writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors()); 935 } 936 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) { 937 ByteCodeField patternIndex = patterns.size(); 938 patterns.emplace_back(PDLByteCodePattern::create( 939 op, rewriterToAddr[op.rewriter().getLeafReference().getValue()])); 940 writer.append(OpCode::RecordMatch, patternIndex, 941 SuccessorRange(op.getOperation()), op.matchedOps()); 942 writer.appendPDLValueList(op.inputs()); 943 } 944 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) { 945 writer.append(OpCode::ReplaceOp, op.operation()); 946 writer.appendPDLValueList(op.replValues()); 947 } 948 void Generator::generate(pdl_interp::SwitchAttributeOp op, 949 ByteCodeWriter &writer) { 950 writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(), 951 op.getSuccessors()); 952 } 953 void Generator::generate(pdl_interp::SwitchOperandCountOp op, 954 ByteCodeWriter &writer) { 955 writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(), 956 op.getSuccessors()); 957 } 958 void Generator::generate(pdl_interp::SwitchOperationNameOp op, 959 ByteCodeWriter &writer) { 960 auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) { 961 return OperationName(attr.cast<StringAttr>().getValue(), ctx); 962 }); 963 writer.append(OpCode::SwitchOperationName, op.operation(), cases, 964 op.getSuccessors()); 965 } 966 void Generator::generate(pdl_interp::SwitchResultCountOp op, 967 ByteCodeWriter &writer) { 968 writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(), 969 op.getSuccessors()); 970 } 971 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) { 972 writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(), 973 op.getSuccessors()); 974 } 975 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) { 976 writer.append(OpCode::SwitchTypes, op.value(), op.caseValuesAttr(), 977 op.getSuccessors()); 978 } 979 980 //===----------------------------------------------------------------------===// 981 // PDLByteCode 982 //===----------------------------------------------------------------------===// 983 984 PDLByteCode::PDLByteCode(ModuleOp module, 985 llvm::StringMap<PDLConstraintFunction> constraintFns, 986 llvm::StringMap<PDLRewriteFunction> rewriteFns) { 987 Generator generator(module.getContext(), uniquedData, matcherByteCode, 988 rewriterByteCode, patterns, maxValueMemoryIndex, 989 maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount, 990 maxLoopLevel, constraintFns, rewriteFns); 991 generator.generate(module); 992 993 // Initialize the external functions. 994 for (auto &it : constraintFns) 995 constraintFunctions.push_back(std::move(it.second)); 996 for (auto &it : rewriteFns) 997 rewriteFunctions.push_back(std::move(it.second)); 998 } 999 1000 /// Initialize the given state such that it can be used to execute the current 1001 /// bytecode. 1002 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { 1003 state.memory.resize(maxValueMemoryIndex, nullptr); 1004 state.opRangeMemory.resize(maxOpRangeCount); 1005 state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange()); 1006 state.valueRangeMemory.resize(maxValueRangeCount, ValueRange()); 1007 state.loopIndex.resize(maxLoopLevel, 0); 1008 state.currentPatternBenefits.reserve(patterns.size()); 1009 for (const PDLByteCodePattern &pattern : patterns) 1010 state.currentPatternBenefits.push_back(pattern.getBenefit()); 1011 } 1012 1013 //===----------------------------------------------------------------------===// 1014 // ByteCode Execution 1015 1016 namespace { 1017 /// This class provides support for executing a bytecode stream. 1018 class ByteCodeExecutor { 1019 public: 1020 ByteCodeExecutor( 1021 const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory, 1022 MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory, 1023 MutableArrayRef<TypeRange> typeRangeMemory, 1024 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory, 1025 MutableArrayRef<ValueRange> valueRangeMemory, 1026 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory, 1027 MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory, 1028 ArrayRef<ByteCodeField> code, 1029 ArrayRef<PatternBenefit> currentPatternBenefits, 1030 ArrayRef<PDLByteCodePattern> patterns, 1031 ArrayRef<PDLConstraintFunction> constraintFunctions, 1032 ArrayRef<PDLRewriteFunction> rewriteFunctions) 1033 : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory), 1034 typeRangeMemory(typeRangeMemory), 1035 allocatedTypeRangeMemory(allocatedTypeRangeMemory), 1036 valueRangeMemory(valueRangeMemory), 1037 allocatedValueRangeMemory(allocatedValueRangeMemory), 1038 loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code), 1039 currentPatternBenefits(currentPatternBenefits), patterns(patterns), 1040 constraintFunctions(constraintFunctions), 1041 rewriteFunctions(rewriteFunctions) {} 1042 1043 /// Start executing the code at the current bytecode index. `matches` is an 1044 /// optional field provided when this function is executed in a matching 1045 /// context. 1046 void execute(PatternRewriter &rewriter, 1047 SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr, 1048 Optional<Location> mainRewriteLoc = {}); 1049 1050 private: 1051 /// Internal implementation of executing each of the bytecode commands. 1052 void executeApplyConstraint(PatternRewriter &rewriter); 1053 void executeApplyRewrite(PatternRewriter &rewriter); 1054 void executeAreEqual(); 1055 void executeAreRangesEqual(); 1056 void executeBranch(); 1057 void executeCheckOperandCount(); 1058 void executeCheckOperationName(); 1059 void executeCheckResultCount(); 1060 void executeCheckTypes(); 1061 void executeContinue(); 1062 void executeCreateOperation(PatternRewriter &rewriter, 1063 Location mainRewriteLoc); 1064 void executeCreateTypes(); 1065 void executeEraseOp(PatternRewriter &rewriter); 1066 template <typename T, typename Range, PDLValue::Kind kind> 1067 void executeExtract(); 1068 void executeFinalize(); 1069 void executeForEach(); 1070 void executeGetAttribute(); 1071 void executeGetAttributeType(); 1072 void executeGetDefiningOp(); 1073 void executeGetOperand(unsigned index); 1074 void executeGetOperands(); 1075 void executeGetResult(unsigned index); 1076 void executeGetResults(); 1077 void executeGetUsers(); 1078 void executeGetValueType(); 1079 void executeGetValueRangeTypes(); 1080 void executeIsNotNull(); 1081 void executeRecordMatch(PatternRewriter &rewriter, 1082 SmallVectorImpl<PDLByteCode::MatchResult> &matches); 1083 void executeReplaceOp(PatternRewriter &rewriter); 1084 void executeSwitchAttribute(); 1085 void executeSwitchOperandCount(); 1086 void executeSwitchOperationName(); 1087 void executeSwitchResultCount(); 1088 void executeSwitchType(); 1089 void executeSwitchTypes(); 1090 1091 /// Pushes a code iterator to the stack. 1092 void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); } 1093 1094 /// Pops a code iterator from the stack, returning true on success. 1095 void popCodeIt() { 1096 assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack"); 1097 curCodeIt = resumeCodeIt.back(); 1098 resumeCodeIt.pop_back(); 1099 } 1100 1101 /// Read a value from the bytecode buffer, optionally skipping a certain 1102 /// number of prefix values. These methods always update the buffer to point 1103 /// to the next field after the read data. 1104 template <typename T = ByteCodeField> 1105 T read(size_t skipN = 0) { 1106 curCodeIt += skipN; 1107 return readImpl<T>(); 1108 } 1109 ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); } 1110 1111 /// Read a list of values from the bytecode buffer. 1112 template <typename ValueT, typename T> 1113 void readList(SmallVectorImpl<T> &list) { 1114 list.clear(); 1115 for (unsigned i = 0, e = read(); i != e; ++i) 1116 list.push_back(read<ValueT>()); 1117 } 1118 1119 /// Read a list of values from the bytecode buffer. The values may be encoded 1120 /// as either Value or ValueRange elements. 1121 void readValueList(SmallVectorImpl<Value> &list) { 1122 for (unsigned i = 0, e = read(); i != e; ++i) { 1123 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1124 list.push_back(read<Value>()); 1125 } else { 1126 ValueRange *values = read<ValueRange *>(); 1127 list.append(values->begin(), values->end()); 1128 } 1129 } 1130 } 1131 1132 /// Jump to a specific successor based on a predicate value. 1133 void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); } 1134 /// Jump to a specific successor based on a destination index. 1135 void selectJump(size_t destIndex) { 1136 curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)]; 1137 } 1138 1139 /// Handle a switch operation with the provided value and cases. 1140 template <typename T, typename RangeT, typename Comparator = std::equal_to<T>> 1141 void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) { 1142 LLVM_DEBUG({ 1143 llvm::dbgs() << " * Value: " << value << "\n" 1144 << " * Cases: "; 1145 llvm::interleaveComma(cases, llvm::dbgs()); 1146 llvm::dbgs() << "\n"; 1147 }); 1148 1149 // Check to see if the attribute value is within the case list. Jump to 1150 // the correct successor index based on the result. 1151 for (auto it = cases.begin(), e = cases.end(); it != e; ++it) 1152 if (cmp(*it, value)) 1153 return selectJump(size_t((it - cases.begin()) + 1)); 1154 selectJump(size_t(0)); 1155 } 1156 1157 /// Store a pointer to memory. 1158 void storeToMemory(unsigned index, const void *value) { 1159 memory[index] = value; 1160 } 1161 1162 /// Store a value to memory as an opaque pointer. 1163 template <typename T> 1164 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value> 1165 storeToMemory(unsigned index, T value) { 1166 memory[index] = value.getAsOpaquePointer(); 1167 } 1168 1169 /// Internal implementation of reading various data types from the bytecode 1170 /// stream. 1171 template <typename T> 1172 const void *readFromMemory() { 1173 size_t index = *curCodeIt++; 1174 1175 // If this type is an SSA value, it can only be stored in non-const memory. 1176 if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *, 1177 Value>::value || 1178 index < memory.size()) 1179 return memory[index]; 1180 1181 // Otherwise, if this index is not inbounds it is uniqued. 1182 return uniquedMemory[index - memory.size()]; 1183 } 1184 template <typename T> 1185 std::enable_if_t<std::is_pointer<T>::value, T> readImpl() { 1186 return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>())); 1187 } 1188 template <typename T> 1189 std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value, 1190 T> 1191 readImpl() { 1192 return T(T::getFromOpaquePointer(readFromMemory<T>())); 1193 } 1194 template <typename T> 1195 std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() { 1196 switch (read<PDLValue::Kind>()) { 1197 case PDLValue::Kind::Attribute: 1198 return read<Attribute>(); 1199 case PDLValue::Kind::Operation: 1200 return read<Operation *>(); 1201 case PDLValue::Kind::Type: 1202 return read<Type>(); 1203 case PDLValue::Kind::Value: 1204 return read<Value>(); 1205 case PDLValue::Kind::TypeRange: 1206 return read<TypeRange *>(); 1207 case PDLValue::Kind::ValueRange: 1208 return read<ValueRange *>(); 1209 } 1210 llvm_unreachable("unhandled PDLValue::Kind"); 1211 } 1212 template <typename T> 1213 std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() { 1214 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 1215 "unexpected ByteCode address size"); 1216 ByteCodeAddr result; 1217 std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr)); 1218 curCodeIt += 2; 1219 return result; 1220 } 1221 template <typename T> 1222 std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() { 1223 return *curCodeIt++; 1224 } 1225 template <typename T> 1226 std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() { 1227 return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>()); 1228 } 1229 1230 /// The underlying bytecode buffer. 1231 const ByteCodeField *curCodeIt; 1232 1233 /// The stack of bytecode positions at which to resume operation. 1234 SmallVector<const ByteCodeField *> resumeCodeIt; 1235 1236 /// The current execution memory. 1237 MutableArrayRef<const void *> memory; 1238 MutableArrayRef<OwningOpRange> opRangeMemory; 1239 MutableArrayRef<TypeRange> typeRangeMemory; 1240 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory; 1241 MutableArrayRef<ValueRange> valueRangeMemory; 1242 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory; 1243 1244 /// The current loop indices. 1245 MutableArrayRef<unsigned> loopIndex; 1246 1247 /// References to ByteCode data necessary for execution. 1248 ArrayRef<const void *> uniquedMemory; 1249 ArrayRef<ByteCodeField> code; 1250 ArrayRef<PatternBenefit> currentPatternBenefits; 1251 ArrayRef<PDLByteCodePattern> patterns; 1252 ArrayRef<PDLConstraintFunction> constraintFunctions; 1253 ArrayRef<PDLRewriteFunction> rewriteFunctions; 1254 }; 1255 1256 /// This class is an instantiation of the PDLResultList that provides access to 1257 /// the returned results. This API is not on `PDLResultList` to avoid 1258 /// overexposing access to information specific solely to the ByteCode. 1259 class ByteCodeRewriteResultList : public PDLResultList { 1260 public: 1261 ByteCodeRewriteResultList(unsigned maxNumResults) 1262 : PDLResultList(maxNumResults) {} 1263 1264 /// Return the list of PDL results. 1265 MutableArrayRef<PDLValue> getResults() { return results; } 1266 1267 /// Return the type ranges allocated by this list. 1268 MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() { 1269 return allocatedTypeRanges; 1270 } 1271 1272 /// Return the value ranges allocated by this list. 1273 MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() { 1274 return allocatedValueRanges; 1275 } 1276 }; 1277 } // end anonymous namespace 1278 1279 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { 1280 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); 1281 const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; 1282 ArrayAttr constParams = read<ArrayAttr>(); 1283 SmallVector<PDLValue, 16> args; 1284 readList<PDLValue>(args); 1285 1286 LLVM_DEBUG({ 1287 llvm::dbgs() << " * Arguments: "; 1288 llvm::interleaveComma(args, llvm::dbgs()); 1289 llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; 1290 }); 1291 1292 // Invoke the constraint and jump to the proper destination. 1293 selectJump(succeeded(constraintFn(args, constParams, rewriter))); 1294 } 1295 1296 void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { 1297 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); 1298 const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; 1299 ArrayAttr constParams = read<ArrayAttr>(); 1300 SmallVector<PDLValue, 16> args; 1301 readList<PDLValue>(args); 1302 1303 LLVM_DEBUG({ 1304 llvm::dbgs() << " * Arguments: "; 1305 llvm::interleaveComma(args, llvm::dbgs()); 1306 llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; 1307 }); 1308 1309 // Execute the rewrite function. 1310 ByteCodeField numResults = read(); 1311 ByteCodeRewriteResultList results(numResults); 1312 rewriteFn(args, constParams, rewriter, results); 1313 1314 assert(results.getResults().size() == numResults && 1315 "native PDL rewrite function returned unexpected number of results"); 1316 1317 // Store the results in the bytecode memory. 1318 for (PDLValue &result : results.getResults()) { 1319 LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); 1320 1321 // In debug mode we also verify the expected kind of the result. 1322 #ifndef NDEBUG 1323 assert(result.getKind() == read<PDLValue::Kind>() && 1324 "native PDL rewrite function returned an unexpected type of result"); 1325 #endif 1326 1327 // If the result is a range, we need to copy it over to the bytecodes 1328 // range memory. 1329 if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) { 1330 unsigned rangeIndex = read(); 1331 typeRangeMemory[rangeIndex] = *typeRange; 1332 memory[read()] = &typeRangeMemory[rangeIndex]; 1333 } else if (Optional<ValueRange> valueRange = 1334 result.dyn_cast<ValueRange>()) { 1335 unsigned rangeIndex = read(); 1336 valueRangeMemory[rangeIndex] = *valueRange; 1337 memory[read()] = &valueRangeMemory[rangeIndex]; 1338 } else { 1339 memory[read()] = result.getAsOpaquePointer(); 1340 } 1341 } 1342 1343 // Copy over any underlying storage allocated for result ranges. 1344 for (auto &it : results.getAllocatedTypeRanges()) 1345 allocatedTypeRangeMemory.push_back(std::move(it)); 1346 for (auto &it : results.getAllocatedValueRanges()) 1347 allocatedValueRangeMemory.push_back(std::move(it)); 1348 } 1349 1350 void ByteCodeExecutor::executeAreEqual() { 1351 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 1352 const void *lhs = read<const void *>(); 1353 const void *rhs = read<const void *>(); 1354 1355 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n"); 1356 selectJump(lhs == rhs); 1357 } 1358 1359 void ByteCodeExecutor::executeAreRangesEqual() { 1360 LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n"); 1361 PDLValue::Kind valueKind = read<PDLValue::Kind>(); 1362 const void *lhs = read<const void *>(); 1363 const void *rhs = read<const void *>(); 1364 1365 switch (valueKind) { 1366 case PDLValue::Kind::TypeRange: { 1367 const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs); 1368 const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs); 1369 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1370 selectJump(*lhsRange == *rhsRange); 1371 break; 1372 } 1373 case PDLValue::Kind::ValueRange: { 1374 const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs); 1375 const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs); 1376 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1377 selectJump(*lhsRange == *rhsRange); 1378 break; 1379 } 1380 default: 1381 llvm_unreachable("unexpected `AreRangesEqual` value kind"); 1382 } 1383 } 1384 1385 void ByteCodeExecutor::executeBranch() { 1386 LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n"); 1387 curCodeIt = &code[read<ByteCodeAddr>()]; 1388 } 1389 1390 void ByteCodeExecutor::executeCheckOperandCount() { 1391 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); 1392 Operation *op = read<Operation *>(); 1393 uint32_t expectedCount = read<uint32_t>(); 1394 bool compareAtLeast = read(); 1395 1396 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" 1397 << " * Expected: " << expectedCount << "\n" 1398 << " * Comparator: " 1399 << (compareAtLeast ? ">=" : "==") << "\n"); 1400 if (compareAtLeast) 1401 selectJump(op->getNumOperands() >= expectedCount); 1402 else 1403 selectJump(op->getNumOperands() == expectedCount); 1404 } 1405 1406 void ByteCodeExecutor::executeCheckOperationName() { 1407 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n"); 1408 Operation *op = read<Operation *>(); 1409 OperationName expectedName = read<OperationName>(); 1410 1411 LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n" 1412 << " * Expected: \"" << expectedName << "\"\n"); 1413 selectJump(op->getName() == expectedName); 1414 } 1415 1416 void ByteCodeExecutor::executeCheckResultCount() { 1417 LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); 1418 Operation *op = read<Operation *>(); 1419 uint32_t expectedCount = read<uint32_t>(); 1420 bool compareAtLeast = read(); 1421 1422 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" 1423 << " * Expected: " << expectedCount << "\n" 1424 << " * Comparator: " 1425 << (compareAtLeast ? ">=" : "==") << "\n"); 1426 if (compareAtLeast) 1427 selectJump(op->getNumResults() >= expectedCount); 1428 else 1429 selectJump(op->getNumResults() == expectedCount); 1430 } 1431 1432 void ByteCodeExecutor::executeCheckTypes() { 1433 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 1434 TypeRange *lhs = read<TypeRange *>(); 1435 Attribute rhs = read<Attribute>(); 1436 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1437 1438 selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>()); 1439 } 1440 1441 void ByteCodeExecutor::executeContinue() { 1442 ByteCodeField level = read(); 1443 LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n" 1444 << " * Level: " << level << "\n"); 1445 ++loopIndex[level]; 1446 popCodeIt(); 1447 } 1448 1449 void ByteCodeExecutor::executeCreateTypes() { 1450 LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n"); 1451 unsigned memIndex = read(); 1452 unsigned rangeIndex = read(); 1453 ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>(); 1454 1455 LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n"); 1456 1457 // Allocate a buffer for this type range. 1458 llvm::OwningArrayRef<Type> storage(typesAttr.size()); 1459 llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin()); 1460 allocatedTypeRangeMemory.emplace_back(std::move(storage)); 1461 1462 // Assign this to the range slot and use the range as the value for the 1463 // memory index. 1464 typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back(); 1465 memory[memIndex] = &typeRangeMemory[rangeIndex]; 1466 } 1467 1468 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, 1469 Location mainRewriteLoc) { 1470 LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); 1471 1472 unsigned memIndex = read(); 1473 OperationState state(mainRewriteLoc, read<OperationName>()); 1474 readValueList(state.operands); 1475 for (unsigned i = 0, e = read(); i != e; ++i) { 1476 StringAttr name = read<StringAttr>(); 1477 if (Attribute attr = read<Attribute>()) 1478 state.addAttribute(name, attr); 1479 } 1480 1481 for (unsigned i = 0, e = read(); i != e; ++i) { 1482 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) { 1483 state.types.push_back(read<Type>()); 1484 continue; 1485 } 1486 1487 // If we find a null range, this signals that the types are infered. 1488 if (TypeRange *resultTypes = read<TypeRange *>()) { 1489 state.types.append(resultTypes->begin(), resultTypes->end()); 1490 continue; 1491 } 1492 1493 // Handle the case where the operation has inferred types. 1494 InferTypeOpInterface::Concept *concept = 1495 state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>(); 1496 1497 // TODO: Handle failure. 1498 state.types.clear(); 1499 if (failed(concept->inferReturnTypes( 1500 state.getContext(), state.location, state.operands, 1501 state.attributes.getDictionary(state.getContext()), state.regions, 1502 state.types))) 1503 return; 1504 break; 1505 } 1506 1507 Operation *resultOp = rewriter.createOperation(state); 1508 memory[memIndex] = resultOp; 1509 1510 LLVM_DEBUG({ 1511 llvm::dbgs() << " * Attributes: " 1512 << state.attributes.getDictionary(state.getContext()) 1513 << "\n * Operands: "; 1514 llvm::interleaveComma(state.operands, llvm::dbgs()); 1515 llvm::dbgs() << "\n * Result Types: "; 1516 llvm::interleaveComma(state.types, llvm::dbgs()); 1517 llvm::dbgs() << "\n * Result: " << *resultOp << "\n"; 1518 }); 1519 } 1520 1521 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) { 1522 LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n"); 1523 Operation *op = read<Operation *>(); 1524 1525 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1526 rewriter.eraseOp(op); 1527 } 1528 1529 template <typename T, typename Range, PDLValue::Kind kind> 1530 void ByteCodeExecutor::executeExtract() { 1531 LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n"); 1532 Range *range = read<Range *>(); 1533 unsigned index = read<uint32_t>(); 1534 unsigned memIndex = read(); 1535 1536 if (!range) { 1537 memory[memIndex] = nullptr; 1538 return; 1539 } 1540 1541 T result = index < range->size() ? (*range)[index] : T(); 1542 LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n" 1543 << " * Index: " << index << "\n" 1544 << " * Result: " << result << "\n"); 1545 storeToMemory(memIndex, result); 1546 } 1547 1548 void ByteCodeExecutor::executeFinalize() { 1549 LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n"); 1550 } 1551 1552 void ByteCodeExecutor::executeForEach() { 1553 LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n"); 1554 // Subtract 1 for the op code. 1555 const ByteCodeField *it = curCodeIt - 1; 1556 unsigned rangeIndex = read(); 1557 unsigned memIndex = read(); 1558 const void *value = nullptr; 1559 1560 switch (read<PDLValue::Kind>()) { 1561 case PDLValue::Kind::Operation: { 1562 unsigned &index = loopIndex[read()]; 1563 ArrayRef<Operation *> array = opRangeMemory[rangeIndex]; 1564 assert(index <= array.size() && "iterated past the end"); 1565 if (index < array.size()) { 1566 LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n"); 1567 value = array[index]; 1568 break; 1569 } 1570 1571 LLVM_DEBUG(llvm::dbgs() << " * Done\n"); 1572 index = 0; 1573 selectJump(size_t(0)); 1574 return; 1575 } 1576 default: 1577 llvm_unreachable("unexpected `ForEach` value kind"); 1578 } 1579 1580 // Store the iterate value and the stack address. 1581 memory[memIndex] = value; 1582 pushCodeIt(it); 1583 1584 // Skip over the successor (we will enter the body of the loop). 1585 read<ByteCodeAddr>(); 1586 } 1587 1588 void ByteCodeExecutor::executeGetAttribute() { 1589 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); 1590 unsigned memIndex = read(); 1591 Operation *op = read<Operation *>(); 1592 StringAttr attrName = read<StringAttr>(); 1593 Attribute attr = op->getAttr(attrName); 1594 1595 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1596 << " * Attribute: " << attrName << "\n" 1597 << " * Result: " << attr << "\n"); 1598 memory[memIndex] = attr.getAsOpaquePointer(); 1599 } 1600 1601 void ByteCodeExecutor::executeGetAttributeType() { 1602 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n"); 1603 unsigned memIndex = read(); 1604 Attribute attr = read<Attribute>(); 1605 Type type = attr ? attr.getType() : Type(); 1606 1607 LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" 1608 << " * Result: " << type << "\n"); 1609 memory[memIndex] = type.getAsOpaquePointer(); 1610 } 1611 1612 void ByteCodeExecutor::executeGetDefiningOp() { 1613 LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); 1614 unsigned memIndex = read(); 1615 Operation *op = nullptr; 1616 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1617 Value value = read<Value>(); 1618 if (value) 1619 op = value.getDefiningOp(); 1620 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1621 } else { 1622 ValueRange *values = read<ValueRange *>(); 1623 if (values && !values->empty()) { 1624 op = values->front().getDefiningOp(); 1625 } 1626 LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n"); 1627 } 1628 1629 LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n"); 1630 memory[memIndex] = op; 1631 } 1632 1633 void ByteCodeExecutor::executeGetOperand(unsigned index) { 1634 Operation *op = read<Operation *>(); 1635 unsigned memIndex = read(); 1636 Value operand = 1637 index < op->getNumOperands() ? op->getOperand(index) : Value(); 1638 1639 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1640 << " * Index: " << index << "\n" 1641 << " * Result: " << operand << "\n"); 1642 memory[memIndex] = operand.getAsOpaquePointer(); 1643 } 1644 1645 /// This function is the internal implementation of `GetResults` and 1646 /// `GetOperands` that provides support for extracting a value range from the 1647 /// given operation. 1648 template <template <typename> class AttrSizedSegmentsT, typename RangeT> 1649 static void * 1650 executeGetOperandsResults(RangeT values, Operation *op, unsigned index, 1651 ByteCodeField rangeIndex, StringRef attrSizedSegments, 1652 MutableArrayRef<ValueRange> valueRangeMemory) { 1653 // Check for the sentinel index that signals that all values should be 1654 // returned. 1655 if (index == std::numeric_limits<uint32_t>::max()) { 1656 LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n"); 1657 // `values` is already the full value range. 1658 1659 // Otherwise, check to see if this operation uses AttrSizedSegments. 1660 } else if (op->hasTrait<AttrSizedSegmentsT>()) { 1661 LLVM_DEBUG(llvm::dbgs() 1662 << " * Extracting values from `" << attrSizedSegments << "`\n"); 1663 1664 auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments); 1665 if (!segmentAttr || segmentAttr.getNumElements() <= index) 1666 return nullptr; 1667 1668 auto segments = segmentAttr.getValues<int32_t>(); 1669 unsigned startIndex = 1670 std::accumulate(segments.begin(), segments.begin() + index, 0); 1671 values = values.slice(startIndex, *std::next(segments.begin(), index)); 1672 1673 LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", " 1674 << *std::next(segments.begin(), index) << "]\n"); 1675 1676 // Otherwise, assume this is the last operand group of the operation. 1677 // FIXME: We currently don't support operations with 1678 // SameVariadicOperandSize/SameVariadicResultSize here given that we don't 1679 // have a way to detect it's presence. 1680 } else if (values.size() >= index) { 1681 LLVM_DEBUG(llvm::dbgs() 1682 << " * Treating values as trailing variadic range\n"); 1683 values = values.drop_front(index); 1684 1685 // If we couldn't detect a way to compute the values, bail out. 1686 } else { 1687 return nullptr; 1688 } 1689 1690 // If the range index is valid, we are returning a range. 1691 if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) { 1692 valueRangeMemory[rangeIndex] = values; 1693 return &valueRangeMemory[rangeIndex]; 1694 } 1695 1696 // If a range index wasn't provided, the range is required to be non-variadic. 1697 return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer(); 1698 } 1699 1700 void ByteCodeExecutor::executeGetOperands() { 1701 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n"); 1702 unsigned index = read<uint32_t>(); 1703 Operation *op = read<Operation *>(); 1704 ByteCodeField rangeIndex = read(); 1705 1706 void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>( 1707 op->getOperands(), op, index, rangeIndex, "operand_segment_sizes", 1708 valueRangeMemory); 1709 if (!result) 1710 LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n"); 1711 memory[read()] = result; 1712 } 1713 1714 void ByteCodeExecutor::executeGetResult(unsigned index) { 1715 Operation *op = read<Operation *>(); 1716 unsigned memIndex = read(); 1717 OpResult result = 1718 index < op->getNumResults() ? op->getResult(index) : OpResult(); 1719 1720 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1721 << " * Index: " << index << "\n" 1722 << " * Result: " << result << "\n"); 1723 memory[memIndex] = result.getAsOpaquePointer(); 1724 } 1725 1726 void ByteCodeExecutor::executeGetResults() { 1727 LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n"); 1728 unsigned index = read<uint32_t>(); 1729 Operation *op = read<Operation *>(); 1730 ByteCodeField rangeIndex = read(); 1731 1732 void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>( 1733 op->getResults(), op, index, rangeIndex, "result_segment_sizes", 1734 valueRangeMemory); 1735 if (!result) 1736 LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n"); 1737 memory[read()] = result; 1738 } 1739 1740 void ByteCodeExecutor::executeGetUsers() { 1741 LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n"); 1742 unsigned memIndex = read(); 1743 unsigned rangeIndex = read(); 1744 OwningOpRange &range = opRangeMemory[rangeIndex]; 1745 memory[memIndex] = ⦥ 1746 1747 range = OwningOpRange(); 1748 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1749 // Read the value. 1750 Value value = read<Value>(); 1751 if (!value) 1752 return; 1753 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1754 1755 // Extract the users of a single value. 1756 range = OwningOpRange(std::distance(value.user_begin(), value.user_end())); 1757 llvm::copy(value.getUsers(), range.begin()); 1758 } else { 1759 // Read a range of values. 1760 ValueRange *values = read<ValueRange *>(); 1761 if (!values) 1762 return; 1763 LLVM_DEBUG({ 1764 llvm::dbgs() << " * Values (" << values->size() << "): "; 1765 llvm::interleaveComma(*values, llvm::dbgs()); 1766 llvm::dbgs() << "\n"; 1767 }); 1768 1769 // Extract all the users of a range of values. 1770 SmallVector<Operation *> users; 1771 for (Value value : *values) 1772 users.append(value.user_begin(), value.user_end()); 1773 range = OwningOpRange(users.size()); 1774 llvm::copy(users, range.begin()); 1775 } 1776 1777 LLVM_DEBUG(llvm::dbgs() << " * Result: " << range.size() << " operations\n"); 1778 } 1779 1780 void ByteCodeExecutor::executeGetValueType() { 1781 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); 1782 unsigned memIndex = read(); 1783 Value value = read<Value>(); 1784 Type type = value ? value.getType() : Type(); 1785 1786 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" 1787 << " * Result: " << type << "\n"); 1788 memory[memIndex] = type.getAsOpaquePointer(); 1789 } 1790 1791 void ByteCodeExecutor::executeGetValueRangeTypes() { 1792 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n"); 1793 unsigned memIndex = read(); 1794 unsigned rangeIndex = read(); 1795 ValueRange *values = read<ValueRange *>(); 1796 if (!values) { 1797 LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n"); 1798 memory[memIndex] = nullptr; 1799 return; 1800 } 1801 1802 LLVM_DEBUG({ 1803 llvm::dbgs() << " * Values (" << values->size() << "): "; 1804 llvm::interleaveComma(*values, llvm::dbgs()); 1805 llvm::dbgs() << "\n * Result: "; 1806 llvm::interleaveComma(values->getType(), llvm::dbgs()); 1807 llvm::dbgs() << "\n"; 1808 }); 1809 typeRangeMemory[rangeIndex] = values->getType(); 1810 memory[memIndex] = &typeRangeMemory[rangeIndex]; 1811 } 1812 1813 void ByteCodeExecutor::executeIsNotNull() { 1814 LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); 1815 const void *value = read<const void *>(); 1816 1817 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1818 selectJump(value != nullptr); 1819 } 1820 1821 void ByteCodeExecutor::executeRecordMatch( 1822 PatternRewriter &rewriter, 1823 SmallVectorImpl<PDLByteCode::MatchResult> &matches) { 1824 LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n"); 1825 unsigned patternIndex = read(); 1826 PatternBenefit benefit = currentPatternBenefits[patternIndex]; 1827 const ByteCodeField *dest = &code[read<ByteCodeAddr>()]; 1828 1829 // If the benefit of the pattern is impossible, skip the processing of the 1830 // rest of the pattern. 1831 if (benefit.isImpossibleToMatch()) { 1832 LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n"); 1833 curCodeIt = dest; 1834 return; 1835 } 1836 1837 // Create a fused location containing the locations of each of the 1838 // operations used in the match. This will be used as the location for 1839 // created operations during the rewrite that don't already have an 1840 // explicit location set. 1841 unsigned numMatchLocs = read(); 1842 SmallVector<Location, 4> matchLocs; 1843 matchLocs.reserve(numMatchLocs); 1844 for (unsigned i = 0; i != numMatchLocs; ++i) 1845 matchLocs.push_back(read<Operation *>()->getLoc()); 1846 Location matchLoc = rewriter.getFusedLoc(matchLocs); 1847 1848 LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n" 1849 << " * Location: " << matchLoc << "\n"); 1850 matches.emplace_back(matchLoc, patterns[patternIndex], benefit); 1851 PDLByteCode::MatchResult &match = matches.back(); 1852 1853 // Record all of the inputs to the match. If any of the inputs are ranges, we 1854 // will also need to remap the range pointer to memory stored in the match 1855 // state. 1856 unsigned numInputs = read(); 1857 match.values.reserve(numInputs); 1858 match.typeRangeValues.reserve(numInputs); 1859 match.valueRangeValues.reserve(numInputs); 1860 for (unsigned i = 0; i < numInputs; ++i) { 1861 switch (read<PDLValue::Kind>()) { 1862 case PDLValue::Kind::TypeRange: 1863 match.typeRangeValues.push_back(*read<TypeRange *>()); 1864 match.values.push_back(&match.typeRangeValues.back()); 1865 break; 1866 case PDLValue::Kind::ValueRange: 1867 match.valueRangeValues.push_back(*read<ValueRange *>()); 1868 match.values.push_back(&match.valueRangeValues.back()); 1869 break; 1870 default: 1871 match.values.push_back(read<const void *>()); 1872 break; 1873 } 1874 } 1875 curCodeIt = dest; 1876 } 1877 1878 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) { 1879 LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); 1880 Operation *op = read<Operation *>(); 1881 SmallVector<Value, 16> args; 1882 readValueList(args); 1883 1884 LLVM_DEBUG({ 1885 llvm::dbgs() << " * Operation: " << *op << "\n" 1886 << " * Values: "; 1887 llvm::interleaveComma(args, llvm::dbgs()); 1888 llvm::dbgs() << "\n"; 1889 }); 1890 rewriter.replaceOp(op, args); 1891 } 1892 1893 void ByteCodeExecutor::executeSwitchAttribute() { 1894 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n"); 1895 Attribute value = read<Attribute>(); 1896 ArrayAttr cases = read<ArrayAttr>(); 1897 handleSwitch(value, cases); 1898 } 1899 1900 void ByteCodeExecutor::executeSwitchOperandCount() { 1901 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n"); 1902 Operation *op = read<Operation *>(); 1903 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1904 1905 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1906 handleSwitch(op->getNumOperands(), cases); 1907 } 1908 1909 void ByteCodeExecutor::executeSwitchOperationName() { 1910 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n"); 1911 OperationName value = read<Operation *>()->getName(); 1912 size_t caseCount = read(); 1913 1914 // The operation names are stored in-line, so to print them out for 1915 // debugging purposes we need to read the array before executing the 1916 // switch so that we can display all of the possible values. 1917 LLVM_DEBUG({ 1918 const ByteCodeField *prevCodeIt = curCodeIt; 1919 llvm::dbgs() << " * Value: " << value << "\n" 1920 << " * Cases: "; 1921 llvm::interleaveComma( 1922 llvm::map_range(llvm::seq<size_t>(0, caseCount), 1923 [&](size_t) { return read<OperationName>(); }), 1924 llvm::dbgs()); 1925 llvm::dbgs() << "\n"; 1926 curCodeIt = prevCodeIt; 1927 }); 1928 1929 // Try to find the switch value within any of the cases. 1930 for (size_t i = 0; i != caseCount; ++i) { 1931 if (read<OperationName>() == value) { 1932 curCodeIt += (caseCount - i - 1); 1933 return selectJump(i + 1); 1934 } 1935 } 1936 selectJump(size_t(0)); 1937 } 1938 1939 void ByteCodeExecutor::executeSwitchResultCount() { 1940 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n"); 1941 Operation *op = read<Operation *>(); 1942 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1943 1944 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1945 handleSwitch(op->getNumResults(), cases); 1946 } 1947 1948 void ByteCodeExecutor::executeSwitchType() { 1949 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n"); 1950 Type value = read<Type>(); 1951 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>(); 1952 handleSwitch(value, cases); 1953 } 1954 1955 void ByteCodeExecutor::executeSwitchTypes() { 1956 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n"); 1957 TypeRange *value = read<TypeRange *>(); 1958 auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>(); 1959 if (!value) { 1960 LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n"); 1961 return selectJump(size_t(0)); 1962 } 1963 handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) { 1964 return value == caseValue.getAsValueRange<TypeAttr>(); 1965 }); 1966 } 1967 1968 void ByteCodeExecutor::execute( 1969 PatternRewriter &rewriter, 1970 SmallVectorImpl<PDLByteCode::MatchResult> *matches, 1971 Optional<Location> mainRewriteLoc) { 1972 while (true) { 1973 OpCode opCode = static_cast<OpCode>(read()); 1974 switch (opCode) { 1975 case ApplyConstraint: 1976 executeApplyConstraint(rewriter); 1977 break; 1978 case ApplyRewrite: 1979 executeApplyRewrite(rewriter); 1980 break; 1981 case AreEqual: 1982 executeAreEqual(); 1983 break; 1984 case AreRangesEqual: 1985 executeAreRangesEqual(); 1986 break; 1987 case Branch: 1988 executeBranch(); 1989 break; 1990 case CheckOperandCount: 1991 executeCheckOperandCount(); 1992 break; 1993 case CheckOperationName: 1994 executeCheckOperationName(); 1995 break; 1996 case CheckResultCount: 1997 executeCheckResultCount(); 1998 break; 1999 case CheckTypes: 2000 executeCheckTypes(); 2001 break; 2002 case Continue: 2003 executeContinue(); 2004 break; 2005 case CreateOperation: 2006 executeCreateOperation(rewriter, *mainRewriteLoc); 2007 break; 2008 case CreateTypes: 2009 executeCreateTypes(); 2010 break; 2011 case EraseOp: 2012 executeEraseOp(rewriter); 2013 break; 2014 case ExtractOp: 2015 executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>(); 2016 break; 2017 case ExtractType: 2018 executeExtract<Type, TypeRange, PDLValue::Kind::Type>(); 2019 break; 2020 case ExtractValue: 2021 executeExtract<Value, ValueRange, PDLValue::Kind::Value>(); 2022 break; 2023 case Finalize: 2024 executeFinalize(); 2025 LLVM_DEBUG(llvm::dbgs() << "\n"); 2026 return; 2027 case ForEach: 2028 executeForEach(); 2029 break; 2030 case GetAttribute: 2031 executeGetAttribute(); 2032 break; 2033 case GetAttributeType: 2034 executeGetAttributeType(); 2035 break; 2036 case GetDefiningOp: 2037 executeGetDefiningOp(); 2038 break; 2039 case GetOperand0: 2040 case GetOperand1: 2041 case GetOperand2: 2042 case GetOperand3: { 2043 unsigned index = opCode - GetOperand0; 2044 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n"); 2045 executeGetOperand(index); 2046 break; 2047 } 2048 case GetOperandN: 2049 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n"); 2050 executeGetOperand(read<uint32_t>()); 2051 break; 2052 case GetOperands: 2053 executeGetOperands(); 2054 break; 2055 case GetResult0: 2056 case GetResult1: 2057 case GetResult2: 2058 case GetResult3: { 2059 unsigned index = opCode - GetResult0; 2060 LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n"); 2061 executeGetResult(index); 2062 break; 2063 } 2064 case GetResultN: 2065 LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n"); 2066 executeGetResult(read<uint32_t>()); 2067 break; 2068 case GetResults: 2069 executeGetResults(); 2070 break; 2071 case GetUsers: 2072 executeGetUsers(); 2073 break; 2074 case GetValueType: 2075 executeGetValueType(); 2076 break; 2077 case GetValueRangeTypes: 2078 executeGetValueRangeTypes(); 2079 break; 2080 case IsNotNull: 2081 executeIsNotNull(); 2082 break; 2083 case RecordMatch: 2084 assert(matches && 2085 "expected matches to be provided when executing the matcher"); 2086 executeRecordMatch(rewriter, *matches); 2087 break; 2088 case ReplaceOp: 2089 executeReplaceOp(rewriter); 2090 break; 2091 case SwitchAttribute: 2092 executeSwitchAttribute(); 2093 break; 2094 case SwitchOperandCount: 2095 executeSwitchOperandCount(); 2096 break; 2097 case SwitchOperationName: 2098 executeSwitchOperationName(); 2099 break; 2100 case SwitchResultCount: 2101 executeSwitchResultCount(); 2102 break; 2103 case SwitchType: 2104 executeSwitchType(); 2105 break; 2106 case SwitchTypes: 2107 executeSwitchTypes(); 2108 break; 2109 } 2110 LLVM_DEBUG(llvm::dbgs() << "\n"); 2111 } 2112 } 2113 2114 /// Run the pattern matcher on the given root operation, collecting the matched 2115 /// patterns in `matches`. 2116 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, 2117 SmallVectorImpl<MatchResult> &matches, 2118 PDLByteCodeMutableState &state) const { 2119 // The first memory slot is always the root operation. 2120 state.memory[0] = op; 2121 2122 // The matcher function always starts at code address 0. 2123 ByteCodeExecutor executor( 2124 matcherByteCode.data(), state.memory, state.opRangeMemory, 2125 state.typeRangeMemory, state.allocatedTypeRangeMemory, 2126 state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex, 2127 uniquedData, matcherByteCode, state.currentPatternBenefits, patterns, 2128 constraintFunctions, rewriteFunctions); 2129 executor.execute(rewriter, &matches); 2130 2131 // Order the found matches by benefit. 2132 std::stable_sort(matches.begin(), matches.end(), 2133 [](const MatchResult &lhs, const MatchResult &rhs) { 2134 return lhs.benefit > rhs.benefit; 2135 }); 2136 } 2137 2138 /// Run the rewriter of the given pattern on the root operation `op`. 2139 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, 2140 PDLByteCodeMutableState &state) const { 2141 // The arguments of the rewrite function are stored at the start of the 2142 // memory buffer. 2143 llvm::copy(match.values, state.memory.begin()); 2144 2145 ByteCodeExecutor executor( 2146 &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, 2147 state.opRangeMemory, state.typeRangeMemory, 2148 state.allocatedTypeRangeMemory, state.valueRangeMemory, 2149 state.allocatedValueRangeMemory, state.loopIndex, uniquedData, 2150 rewriterByteCode, state.currentPatternBenefits, patterns, 2151 constraintFunctions, rewriteFunctions); 2152 executor.execute(rewriter, /*matches=*/nullptr, match.location); 2153 } 2154