1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements MLIR to byte-code generation and the interpreter. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "ByteCode.h" 14 #include "mlir/Analysis/Liveness.h" 15 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/RegionGraphTraits.h" 19 #include "llvm/ADT/IntervalMap.h" 20 #include "llvm/ADT/PostOrderIterator.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 #include "llvm/Support/Debug.h" 23 24 #define DEBUG_TYPE "pdl-bytecode" 25 26 using namespace mlir; 27 using namespace mlir::detail; 28 29 //===----------------------------------------------------------------------===// 30 // PDLByteCodePattern 31 //===----------------------------------------------------------------------===// 32 33 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp, 34 ByteCodeAddr rewriterAddr) { 35 SmallVector<StringRef, 8> generatedOps; 36 if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr()) 37 generatedOps = 38 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>()); 39 40 PatternBenefit benefit = matchOp.benefit(); 41 MLIRContext *ctx = matchOp.getContext(); 42 43 // Check to see if this is pattern matches a specific operation type. 44 if (Optional<StringRef> rootKind = matchOp.rootKind()) 45 return PDLByteCodePattern(rewriterAddr, *rootKind, generatedOps, benefit, 46 ctx); 47 return PDLByteCodePattern(rewriterAddr, generatedOps, benefit, ctx, 48 MatchAnyOpTypeTag()); 49 } 50 51 //===----------------------------------------------------------------------===// 52 // PDLByteCodeMutableState 53 //===----------------------------------------------------------------------===// 54 55 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds 56 /// to the position of the pattern within the range returned by 57 /// `PDLByteCode::getPatterns`. 58 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex, 59 PatternBenefit benefit) { 60 currentPatternBenefits[patternIndex] = benefit; 61 } 62 63 //===----------------------------------------------------------------------===// 64 // Bytecode OpCodes 65 //===----------------------------------------------------------------------===// 66 67 namespace { 68 enum OpCode : ByteCodeField { 69 /// Apply an externally registered constraint. 70 ApplyConstraint, 71 /// Apply an externally registered rewrite. 72 ApplyRewrite, 73 /// Check if two generic values are equal. 74 AreEqual, 75 /// Unconditional branch. 76 Branch, 77 /// Compare the operand count of an operation with a constant. 78 CheckOperandCount, 79 /// Compare the name of an operation with a constant. 80 CheckOperationName, 81 /// Compare the result count of an operation with a constant. 82 CheckResultCount, 83 /// Invoke a native creation method. 84 CreateNative, 85 /// Create an operation. 86 CreateOperation, 87 /// Erase an operation. 88 EraseOp, 89 /// Terminate a matcher or rewrite sequence. 90 Finalize, 91 /// Get a specific attribute of an operation. 92 GetAttribute, 93 /// Get the type of an attribute. 94 GetAttributeType, 95 /// Get the defining operation of a value. 96 GetDefiningOp, 97 /// Get a specific operand of an operation. 98 GetOperand0, 99 GetOperand1, 100 GetOperand2, 101 GetOperand3, 102 GetOperandN, 103 /// Get a specific result of an operation. 104 GetResult0, 105 GetResult1, 106 GetResult2, 107 GetResult3, 108 GetResultN, 109 /// Get the type of a value. 110 GetValueType, 111 /// Check if a generic value is not null. 112 IsNotNull, 113 /// Record a successful pattern match. 114 RecordMatch, 115 /// Replace an operation. 116 ReplaceOp, 117 /// Compare an attribute with a set of constants. 118 SwitchAttribute, 119 /// Compare the operand count of an operation with a set of constants. 120 SwitchOperandCount, 121 /// Compare the name of an operation with a set of constants. 122 SwitchOperationName, 123 /// Compare the result count of an operation with a set of constants. 124 SwitchResultCount, 125 /// Compare a type with a set of constants. 126 SwitchType, 127 }; 128 129 enum class PDLValueKind { Attribute, Operation, Type, Value }; 130 } // end anonymous namespace 131 132 //===----------------------------------------------------------------------===// 133 // ByteCode Generation 134 //===----------------------------------------------------------------------===// 135 136 //===----------------------------------------------------------------------===// 137 // Generator 138 139 namespace { 140 struct ByteCodeWriter; 141 142 /// This class represents the main generator for the pattern bytecode. 143 class Generator { 144 public: 145 Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData, 146 SmallVectorImpl<ByteCodeField> &matcherByteCode, 147 SmallVectorImpl<ByteCodeField> &rewriterByteCode, 148 SmallVectorImpl<PDLByteCodePattern> &patterns, 149 ByteCodeField &maxValueMemoryIndex, 150 llvm::StringMap<PDLConstraintFunction> &constraintFns, 151 llvm::StringMap<PDLCreateFunction> &createFns, 152 llvm::StringMap<PDLRewriteFunction> &rewriteFns) 153 : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), 154 rewriterByteCode(rewriterByteCode), patterns(patterns), 155 maxValueMemoryIndex(maxValueMemoryIndex) { 156 for (auto it : llvm::enumerate(constraintFns)) 157 constraintToMemIndex.try_emplace(it.value().first(), it.index()); 158 for (auto it : llvm::enumerate(createFns)) 159 nativeCreateToMemIndex.try_emplace(it.value().first(), it.index()); 160 for (auto it : llvm::enumerate(rewriteFns)) 161 externalRewriterToMemIndex.try_emplace(it.value().first(), it.index()); 162 } 163 164 /// Generate the bytecode for the given PDL interpreter module. 165 void generate(ModuleOp module); 166 167 /// Return the memory index to use for the given value. 168 ByteCodeField &getMemIndex(Value value) { 169 assert(valueToMemIndex.count(value) && 170 "expected memory index to be assigned"); 171 return valueToMemIndex[value]; 172 } 173 174 /// Return an index to use when referring to the given data that is uniqued in 175 /// the MLIR context. 176 template <typename T> 177 std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &> 178 getMemIndex(T val) { 179 const void *opaqueVal = val.getAsOpaquePointer(); 180 181 // Get or insert a reference to this value. 182 auto it = uniquedDataToMemIndex.try_emplace( 183 opaqueVal, maxValueMemoryIndex + uniquedData.size()); 184 if (it.second) 185 uniquedData.push_back(opaqueVal); 186 return it.first->second; 187 } 188 189 private: 190 /// Allocate memory indices for the results of operations within the matcher 191 /// and rewriters. 192 void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule); 193 194 /// Generate the bytecode for the given operation. 195 void generate(Operation *op, ByteCodeWriter &writer); 196 void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer); 197 void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer); 198 void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer); 199 void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer); 200 void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer); 201 void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer); 202 void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer); 203 void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); 204 void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); 205 void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); 206 void generate(pdl_interp::CreateNativeOp op, ByteCodeWriter &writer); 207 void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); 208 void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); 209 void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); 210 void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); 211 void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); 212 void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); 213 void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); 214 void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer); 215 void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); 216 void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); 217 void generate(pdl_interp::InferredTypeOp op, ByteCodeWriter &writer); 218 void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); 219 void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer); 220 void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); 221 void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer); 222 void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer); 223 void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer); 224 void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer); 225 void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer); 226 227 /// Mapping from value to its corresponding memory index. 228 DenseMap<Value, ByteCodeField> valueToMemIndex; 229 230 /// Mapping from the name of an externally registered rewrite to its index in 231 /// the bytecode registry. 232 llvm::StringMap<ByteCodeField> externalRewriterToMemIndex; 233 234 /// Mapping from the name of an externally registered constraint to its index 235 /// in the bytecode registry. 236 llvm::StringMap<ByteCodeField> constraintToMemIndex; 237 238 /// Mapping from the name of an externally registered creation method to its 239 /// index in the bytecode registry. 240 llvm::StringMap<ByteCodeField> nativeCreateToMemIndex; 241 242 /// Mapping from rewriter function name to the bytecode address of the 243 /// rewriter function in byte. 244 llvm::StringMap<ByteCodeAddr> rewriterToAddr; 245 246 /// Mapping from a uniqued storage object to its memory index within 247 /// `uniquedData`. 248 DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex; 249 250 /// The current MLIR context. 251 MLIRContext *ctx; 252 253 /// Data of the ByteCode class to be populated. 254 std::vector<const void *> &uniquedData; 255 SmallVectorImpl<ByteCodeField> &matcherByteCode; 256 SmallVectorImpl<ByteCodeField> &rewriterByteCode; 257 SmallVectorImpl<PDLByteCodePattern> &patterns; 258 ByteCodeField &maxValueMemoryIndex; 259 }; 260 261 /// This class provides utilities for writing a bytecode stream. 262 struct ByteCodeWriter { 263 ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator) 264 : bytecode(bytecode), generator(generator) {} 265 266 /// Append a field to the bytecode. 267 void append(ByteCodeField field) { bytecode.push_back(field); } 268 void append(OpCode opCode) { bytecode.push_back(opCode); } 269 270 /// Append an address to the bytecode. 271 void append(ByteCodeAddr field) { 272 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 273 "unexpected ByteCode address size"); 274 275 ByteCodeField fieldParts[2]; 276 std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr)); 277 bytecode.append({fieldParts[0], fieldParts[1]}); 278 } 279 280 /// Append a successor range to the bytecode, the exact address will need to 281 /// be resolved later. 282 void append(SuccessorRange successors) { 283 // Add back references to the any successors so that the address can be 284 // resolved later. 285 for (Block *successor : successors) { 286 unresolvedSuccessorRefs[successor].push_back(bytecode.size()); 287 append(ByteCodeAddr(0)); 288 } 289 } 290 291 /// Append a range of values that will be read as generic PDLValues. 292 void appendPDLValueList(OperandRange values) { 293 bytecode.push_back(values.size()); 294 for (Value value : values) { 295 // Append the type of the value in addition to the value itself. 296 PDLValueKind kind = 297 TypeSwitch<Type, PDLValueKind>(value.getType()) 298 .Case<pdl::AttributeType>( 299 [](Type) { return PDLValueKind::Attribute; }) 300 .Case<pdl::OperationType>( 301 [](Type) { return PDLValueKind::Operation; }) 302 .Case<pdl::TypeType>([](Type) { return PDLValueKind::Type; }) 303 .Case<pdl::ValueType>([](Type) { return PDLValueKind::Value; }); 304 bytecode.push_back(static_cast<ByteCodeField>(kind)); 305 append(value); 306 } 307 } 308 309 /// Check if the given class `T` has an iterator type. 310 template <typename T, typename... Args> 311 using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer()); 312 313 /// Append a value that will be stored in a memory slot and not inline within 314 /// the bytecode. 315 template <typename T> 316 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value || 317 std::is_pointer<T>::value> 318 append(T value) { 319 bytecode.push_back(generator.getMemIndex(value)); 320 } 321 322 /// Append a range of values. 323 template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>> 324 std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value> 325 append(T range) { 326 bytecode.push_back(llvm::size(range)); 327 for (auto it : range) 328 append(it); 329 } 330 331 /// Append a variadic number of fields to the bytecode. 332 template <typename FieldTy, typename Field2Ty, typename... FieldTys> 333 void append(FieldTy field, Field2Ty field2, FieldTys... fields) { 334 append(field); 335 append(field2, fields...); 336 } 337 338 /// Successor references in the bytecode that have yet to be resolved. 339 DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs; 340 341 /// The underlying bytecode buffer. 342 SmallVectorImpl<ByteCodeField> &bytecode; 343 344 /// The main generator producing PDL. 345 Generator &generator; 346 }; 347 } // end anonymous namespace 348 349 void Generator::generate(ModuleOp module) { 350 FuncOp matcherFunc = module.lookupSymbol<FuncOp>( 351 pdl_interp::PDLInterpDialect::getMatcherFunctionName()); 352 ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>( 353 pdl_interp::PDLInterpDialect::getRewriterModuleName()); 354 assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module"); 355 356 // Allocate memory indices for the results of operations within the matcher 357 // and rewriters. 358 allocateMemoryIndices(matcherFunc, rewriterModule); 359 360 // Generate code for the rewriter functions. 361 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this); 362 for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) { 363 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size()); 364 for (Operation &op : rewriterFunc.getOps()) 365 generate(&op, rewriterByteCodeWriter); 366 } 367 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() && 368 "unexpected branches in rewriter function"); 369 370 // Generate code for the matcher function. 371 DenseMap<Block *, ByteCodeAddr> blockToAddr; 372 llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody()); 373 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this); 374 for (Block *block : rpot) { 375 // Keep track of where this block begins within the matcher function. 376 blockToAddr.try_emplace(block, matcherByteCode.size()); 377 for (Operation &op : *block) 378 generate(&op, matcherByteCodeWriter); 379 } 380 381 // Resolve successor references in the matcher. 382 for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) { 383 ByteCodeAddr addr = blockToAddr[it.first]; 384 for (unsigned offsetToFix : it.second) 385 std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr)); 386 } 387 } 388 389 void Generator::allocateMemoryIndices(FuncOp matcherFunc, 390 ModuleOp rewriterModule) { 391 // Rewriters use simplistic allocation scheme that simply assigns an index to 392 // each result. 393 for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) { 394 ByteCodeField index = 0; 395 for (BlockArgument arg : rewriterFunc.getArguments()) 396 valueToMemIndex.try_emplace(arg, index++); 397 rewriterFunc.getBody().walk([&](Operation *op) { 398 for (Value result : op->getResults()) 399 valueToMemIndex.try_emplace(result, index++); 400 }); 401 if (index > maxValueMemoryIndex) 402 maxValueMemoryIndex = index; 403 } 404 405 // The matcher function uses a more sophisticated numbering that tries to 406 // minimize the number of memory indices assigned. This is done by determining 407 // a live range of the values within the matcher, then the allocation is just 408 // finding the minimal number of overlapping live ranges. This is essentially 409 // a simplified form of register allocation where we don't necessarily have a 410 // limited number of registers, but we still want to minimize the number used. 411 DenseMap<Operation *, ByteCodeField> opToIndex; 412 matcherFunc.getBody().walk([&](Operation *op) { 413 opToIndex.insert(std::make_pair(op, opToIndex.size())); 414 }); 415 416 // Liveness info for each of the defs within the matcher. 417 using LivenessSet = llvm::IntervalMap<ByteCodeField, char, 16>; 418 LivenessSet::Allocator allocator; 419 DenseMap<Value, LivenessSet> valueDefRanges; 420 421 // Assign the root operation being matched to slot 0. 422 BlockArgument rootOpArg = matcherFunc.getArgument(0); 423 valueToMemIndex[rootOpArg] = 0; 424 425 // Walk each of the blocks, computing the def interval that the value is used. 426 Liveness matcherLiveness(matcherFunc); 427 for (Block &block : matcherFunc.getBody()) { 428 const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block); 429 assert(info && "expected liveness info for block"); 430 auto processValue = [&](Value value, Operation *firstUseOrDef) { 431 // We don't need to process the root op argument, this value is always 432 // assigned to the first memory slot. 433 if (value == rootOpArg) 434 return; 435 436 // Set indices for the range of this block that the value is used. 437 auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; 438 defRangeIt->second.insert( 439 opToIndex[firstUseOrDef], 440 opToIndex[info->getEndOperation(value, firstUseOrDef)], 441 /*dummyValue*/ 0); 442 }; 443 444 // Process the live-ins of this block. 445 for (Value liveIn : info->in()) 446 processValue(liveIn, &block.front()); 447 448 // Process any new defs within this block. 449 for (Operation &op : block) 450 for (Value result : op.getResults()) 451 processValue(result, &op); 452 } 453 454 // Greedily allocate memory slots using the computed def live ranges. 455 std::vector<LivenessSet> allocatedIndices; 456 for (auto &defIt : valueDefRanges) { 457 ByteCodeField &memIndex = valueToMemIndex[defIt.first]; 458 LivenessSet &defSet = defIt.second; 459 460 // Try to allocate to an existing index. 461 for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) { 462 LivenessSet &existingIndex = existingIndexIt.value(); 463 llvm::IntervalMapOverlaps<LivenessSet, LivenessSet> overlaps( 464 defIt.second, existingIndex); 465 if (overlaps.valid()) 466 continue; 467 // Union the range of the def within the existing index. 468 for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it) 469 existingIndex.insert(it.start(), it.stop(), /*dummyValue*/ 0); 470 memIndex = existingIndexIt.index() + 1; 471 } 472 473 // If no existing index could be used, add a new one. 474 if (memIndex == 0) { 475 allocatedIndices.emplace_back(allocator); 476 for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it) 477 allocatedIndices.back().insert(it.start(), it.stop(), /*dummyValue*/ 0); 478 memIndex = allocatedIndices.size(); 479 } 480 } 481 482 // Update the max number of indices. 483 ByteCodeField numMatcherIndices = allocatedIndices.size() + 1; 484 if (numMatcherIndices > maxValueMemoryIndex) 485 maxValueMemoryIndex = numMatcherIndices; 486 } 487 488 void Generator::generate(Operation *op, ByteCodeWriter &writer) { 489 TypeSwitch<Operation *>(op) 490 .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp, 491 pdl_interp::AreEqualOp, pdl_interp::BranchOp, 492 pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp, 493 pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp, 494 pdl_interp::CheckTypeOp, pdl_interp::CreateAttributeOp, 495 pdl_interp::CreateNativeOp, pdl_interp::CreateOperationOp, 496 pdl_interp::CreateTypeOp, pdl_interp::EraseOp, 497 pdl_interp::FinalizeOp, pdl_interp::GetAttributeOp, 498 pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp, 499 pdl_interp::GetOperandOp, pdl_interp::GetResultOp, 500 pdl_interp::GetValueTypeOp, pdl_interp::InferredTypeOp, 501 pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp, 502 pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp, 503 pdl_interp::SwitchTypeOp, pdl_interp::SwitchOperandCountOp, 504 pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>( 505 [&](auto interpOp) { this->generate(interpOp, writer); }) 506 .Default([](Operation *) { 507 llvm_unreachable("unknown `pdl_interp` operation"); 508 }); 509 } 510 511 void Generator::generate(pdl_interp::ApplyConstraintOp op, 512 ByteCodeWriter &writer) { 513 assert(constraintToMemIndex.count(op.name()) && 514 "expected index for constraint function"); 515 writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()], 516 op.constParamsAttr()); 517 writer.appendPDLValueList(op.args()); 518 writer.append(op.getSuccessors()); 519 } 520 void Generator::generate(pdl_interp::ApplyRewriteOp op, 521 ByteCodeWriter &writer) { 522 assert(externalRewriterToMemIndex.count(op.name()) && 523 "expected index for rewrite function"); 524 writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()], 525 op.constParamsAttr(), op.root()); 526 writer.appendPDLValueList(op.args()); 527 } 528 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) { 529 writer.append(OpCode::AreEqual, op.lhs(), op.rhs(), op.getSuccessors()); 530 } 531 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) { 532 writer.append(OpCode::Branch, SuccessorRange(op.getOperation())); 533 } 534 void Generator::generate(pdl_interp::CheckAttributeOp op, 535 ByteCodeWriter &writer) { 536 writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(), 537 op.getSuccessors()); 538 } 539 void Generator::generate(pdl_interp::CheckOperandCountOp op, 540 ByteCodeWriter &writer) { 541 writer.append(OpCode::CheckOperandCount, op.operation(), op.count(), 542 op.getSuccessors()); 543 } 544 void Generator::generate(pdl_interp::CheckOperationNameOp op, 545 ByteCodeWriter &writer) { 546 writer.append(OpCode::CheckOperationName, op.operation(), 547 OperationName(op.name(), ctx), op.getSuccessors()); 548 } 549 void Generator::generate(pdl_interp::CheckResultCountOp op, 550 ByteCodeWriter &writer) { 551 writer.append(OpCode::CheckResultCount, op.operation(), op.count(), 552 op.getSuccessors()); 553 } 554 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) { 555 writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors()); 556 } 557 void Generator::generate(pdl_interp::CreateAttributeOp op, 558 ByteCodeWriter &writer) { 559 // Simply repoint the memory index of the result to the constant. 560 getMemIndex(op.attribute()) = getMemIndex(op.value()); 561 } 562 void Generator::generate(pdl_interp::CreateNativeOp op, 563 ByteCodeWriter &writer) { 564 assert(nativeCreateToMemIndex.count(op.name()) && 565 "expected index for creation function"); 566 writer.append(OpCode::CreateNative, nativeCreateToMemIndex[op.name()], 567 op.result(), op.constParamsAttr()); 568 writer.appendPDLValueList(op.args()); 569 } 570 void Generator::generate(pdl_interp::CreateOperationOp op, 571 ByteCodeWriter &writer) { 572 writer.append(OpCode::CreateOperation, op.operation(), 573 OperationName(op.name(), ctx), op.operands()); 574 575 // Add the attributes. 576 OperandRange attributes = op.attributes(); 577 writer.append(static_cast<ByteCodeField>(attributes.size())); 578 for (auto it : llvm::zip(op.attributeNames(), op.attributes())) { 579 writer.append( 580 Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx), 581 std::get<1>(it)); 582 } 583 writer.append(op.types()); 584 } 585 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) { 586 // Simply repoint the memory index of the result to the constant. 587 getMemIndex(op.result()) = getMemIndex(op.value()); 588 } 589 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) { 590 writer.append(OpCode::EraseOp, op.operation()); 591 } 592 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) { 593 writer.append(OpCode::Finalize); 594 } 595 void Generator::generate(pdl_interp::GetAttributeOp op, 596 ByteCodeWriter &writer) { 597 writer.append(OpCode::GetAttribute, op.attribute(), op.operation(), 598 Identifier::get(op.name(), ctx)); 599 } 600 void Generator::generate(pdl_interp::GetAttributeTypeOp op, 601 ByteCodeWriter &writer) { 602 writer.append(OpCode::GetAttributeType, op.result(), op.value()); 603 } 604 void Generator::generate(pdl_interp::GetDefiningOpOp op, 605 ByteCodeWriter &writer) { 606 writer.append(OpCode::GetDefiningOp, op.operation(), op.value()); 607 } 608 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) { 609 uint32_t index = op.index(); 610 if (index < 4) 611 writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index)); 612 else 613 writer.append(OpCode::GetOperandN, index); 614 writer.append(op.operation(), op.value()); 615 } 616 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) { 617 uint32_t index = op.index(); 618 if (index < 4) 619 writer.append(static_cast<OpCode>(OpCode::GetResult0 + index)); 620 else 621 writer.append(OpCode::GetResultN, index); 622 writer.append(op.operation(), op.value()); 623 } 624 void Generator::generate(pdl_interp::GetValueTypeOp op, 625 ByteCodeWriter &writer) { 626 writer.append(OpCode::GetValueType, op.result(), op.value()); 627 } 628 void Generator::generate(pdl_interp::InferredTypeOp op, 629 ByteCodeWriter &writer) { 630 // InferType maps to a null type as a marker for inferring a result type. 631 getMemIndex(op.type()) = getMemIndex(Type()); 632 } 633 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) { 634 writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors()); 635 } 636 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) { 637 ByteCodeField patternIndex = patterns.size(); 638 patterns.emplace_back(PDLByteCodePattern::create( 639 op, rewriterToAddr[op.rewriter().getLeafReference()])); 640 writer.append(OpCode::RecordMatch, patternIndex, 641 SuccessorRange(op.getOperation()), op.matchedOps(), 642 op.inputs()); 643 } 644 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) { 645 writer.append(OpCode::ReplaceOp, op.operation(), op.replValues()); 646 } 647 void Generator::generate(pdl_interp::SwitchAttributeOp op, 648 ByteCodeWriter &writer) { 649 writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(), 650 op.getSuccessors()); 651 } 652 void Generator::generate(pdl_interp::SwitchOperandCountOp op, 653 ByteCodeWriter &writer) { 654 writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(), 655 op.getSuccessors()); 656 } 657 void Generator::generate(pdl_interp::SwitchOperationNameOp op, 658 ByteCodeWriter &writer) { 659 auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) { 660 return OperationName(attr.cast<StringAttr>().getValue(), ctx); 661 }); 662 writer.append(OpCode::SwitchOperationName, op.operation(), cases, 663 op.getSuccessors()); 664 } 665 void Generator::generate(pdl_interp::SwitchResultCountOp op, 666 ByteCodeWriter &writer) { 667 writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(), 668 op.getSuccessors()); 669 } 670 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) { 671 writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(), 672 op.getSuccessors()); 673 } 674 675 //===----------------------------------------------------------------------===// 676 // PDLByteCode 677 //===----------------------------------------------------------------------===// 678 679 PDLByteCode::PDLByteCode(ModuleOp module, 680 llvm::StringMap<PDLConstraintFunction> constraintFns, 681 llvm::StringMap<PDLCreateFunction> createFns, 682 llvm::StringMap<PDLRewriteFunction> rewriteFns) { 683 Generator generator(module.getContext(), uniquedData, matcherByteCode, 684 rewriterByteCode, patterns, maxValueMemoryIndex, 685 constraintFns, createFns, rewriteFns); 686 generator.generate(module); 687 688 // Initialize the external functions. 689 for (auto &it : constraintFns) 690 constraintFunctions.push_back(std::move(it.second)); 691 for (auto &it : createFns) 692 createFunctions.push_back(std::move(it.second)); 693 for (auto &it : rewriteFns) 694 rewriteFunctions.push_back(std::move(it.second)); 695 } 696 697 /// Initialize the given state such that it can be used to execute the current 698 /// bytecode. 699 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { 700 state.memory.resize(maxValueMemoryIndex, nullptr); 701 state.currentPatternBenefits.reserve(patterns.size()); 702 for (const PDLByteCodePattern &pattern : patterns) 703 state.currentPatternBenefits.push_back(pattern.getBenefit()); 704 } 705 706 //===----------------------------------------------------------------------===// 707 // ByteCode Execution 708 709 namespace { 710 /// This class provides support for executing a bytecode stream. 711 class ByteCodeExecutor { 712 public: 713 ByteCodeExecutor(const ByteCodeField *curCodeIt, 714 MutableArrayRef<const void *> memory, 715 ArrayRef<const void *> uniquedMemory, 716 ArrayRef<ByteCodeField> code, 717 ArrayRef<PatternBenefit> currentPatternBenefits, 718 ArrayRef<PDLByteCodePattern> patterns, 719 ArrayRef<PDLConstraintFunction> constraintFunctions, 720 ArrayRef<PDLCreateFunction> createFunctions, 721 ArrayRef<PDLRewriteFunction> rewriteFunctions) 722 : curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory), 723 code(code), currentPatternBenefits(currentPatternBenefits), 724 patterns(patterns), constraintFunctions(constraintFunctions), 725 createFunctions(createFunctions), rewriteFunctions(rewriteFunctions) {} 726 727 /// Start executing the code at the current bytecode index. `matches` is an 728 /// optional field provided when this function is executed in a matching 729 /// context. 730 void execute(PatternRewriter &rewriter, 731 SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr, 732 Optional<Location> mainRewriteLoc = {}); 733 734 private: 735 /// Read a value from the bytecode buffer, optionally skipping a certain 736 /// number of prefix values. These methods always update the buffer to point 737 /// to the next field after the read data. 738 template <typename T = ByteCodeField> 739 T read(size_t skipN = 0) { 740 curCodeIt += skipN; 741 return readImpl<T>(); 742 } 743 ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); } 744 745 /// Read a list of values from the bytecode buffer. 746 template <typename ValueT, typename T> 747 void readList(SmallVectorImpl<T> &list) { 748 list.clear(); 749 for (unsigned i = 0, e = read(); i != e; ++i) 750 list.push_back(read<ValueT>()); 751 } 752 753 /// Jump to a specific successor based on a predicate value. 754 void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); } 755 /// Jump to a specific successor based on a destination index. 756 void selectJump(size_t destIndex) { 757 curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)]; 758 } 759 760 /// Handle a switch operation with the provided value and cases. 761 template <typename T, typename RangeT> 762 void handleSwitch(const T &value, RangeT &&cases) { 763 LLVM_DEBUG({ 764 llvm::dbgs() << " * Value: " << value << "\n" 765 << " * Cases: "; 766 llvm::interleaveComma(cases, llvm::dbgs()); 767 llvm::dbgs() << "\n\n"; 768 }); 769 770 // Check to see if the attribute value is within the case list. Jump to 771 // the correct successor index based on the result. 772 for (auto it = cases.begin(), e = cases.end(); it != e; ++it) 773 if (*it == value) 774 return selectJump(size_t((it - cases.begin()) + 1)); 775 selectJump(size_t(0)); 776 } 777 778 /// Internal implementation of reading various data types from the bytecode 779 /// stream. 780 template <typename T> 781 const void *readFromMemory() { 782 size_t index = *curCodeIt++; 783 784 // If this type is an SSA value, it can only be stored in non-const memory. 785 if (llvm::is_one_of<T, Operation *, Value>::value || index < memory.size()) 786 return memory[index]; 787 788 // Otherwise, if this index is not inbounds it is uniqued. 789 return uniquedMemory[index - memory.size()]; 790 } 791 template <typename T> 792 std::enable_if_t<std::is_pointer<T>::value, T> readImpl() { 793 return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>())); 794 } 795 template <typename T> 796 std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value, 797 T> 798 readImpl() { 799 return T(T::getFromOpaquePointer(readFromMemory<T>())); 800 } 801 template <typename T> 802 std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() { 803 switch (static_cast<PDLValueKind>(read())) { 804 case PDLValueKind::Attribute: 805 return read<Attribute>(); 806 case PDLValueKind::Operation: 807 return read<Operation *>(); 808 case PDLValueKind::Type: 809 return read<Type>(); 810 case PDLValueKind::Value: 811 return read<Value>(); 812 } 813 llvm_unreachable("unhandled PDLValueKind"); 814 } 815 template <typename T> 816 std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() { 817 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 818 "unexpected ByteCode address size"); 819 ByteCodeAddr result; 820 std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr)); 821 curCodeIt += 2; 822 return result; 823 } 824 template <typename T> 825 std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() { 826 return *curCodeIt++; 827 } 828 829 /// The underlying bytecode buffer. 830 const ByteCodeField *curCodeIt; 831 832 /// The current execution memory. 833 MutableArrayRef<const void *> memory; 834 835 /// References to ByteCode data necessary for execution. 836 ArrayRef<const void *> uniquedMemory; 837 ArrayRef<ByteCodeField> code; 838 ArrayRef<PatternBenefit> currentPatternBenefits; 839 ArrayRef<PDLByteCodePattern> patterns; 840 ArrayRef<PDLConstraintFunction> constraintFunctions; 841 ArrayRef<PDLCreateFunction> createFunctions; 842 ArrayRef<PDLRewriteFunction> rewriteFunctions; 843 }; 844 } // end anonymous namespace 845 846 void ByteCodeExecutor::execute( 847 PatternRewriter &rewriter, 848 SmallVectorImpl<PDLByteCode::MatchResult> *matches, 849 Optional<Location> mainRewriteLoc) { 850 while (true) { 851 OpCode opCode = static_cast<OpCode>(read()); 852 switch (opCode) { 853 case ApplyConstraint: { 854 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); 855 const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; 856 ArrayAttr constParams = read<ArrayAttr>(); 857 SmallVector<PDLValue, 16> args; 858 readList<PDLValue>(args); 859 LLVM_DEBUG({ 860 llvm::dbgs() << " * Arguments: "; 861 llvm::interleaveComma(args, llvm::dbgs()); 862 llvm::dbgs() << "\n * Parameters: " << constParams << "\n\n"; 863 }); 864 865 // Invoke the constraint and jump to the proper destination. 866 selectJump(succeeded(constraintFn(args, constParams, rewriter))); 867 break; 868 } 869 case ApplyRewrite: { 870 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); 871 const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; 872 ArrayAttr constParams = read<ArrayAttr>(); 873 Operation *root = read<Operation *>(); 874 SmallVector<PDLValue, 16> args; 875 readList<PDLValue>(args); 876 877 LLVM_DEBUG({ 878 llvm::dbgs() << " * Root: " << *root << "\n" 879 << " * Arguments: "; 880 llvm::interleaveComma(args, llvm::dbgs()); 881 llvm::dbgs() << "\n * Parameters: " << constParams << "\n\n"; 882 }); 883 rewriteFn(root, args, constParams, rewriter); 884 break; 885 } 886 case AreEqual: { 887 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 888 const void *lhs = read<const void *>(); 889 const void *rhs = read<const void *>(); 890 891 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 892 selectJump(lhs == rhs); 893 break; 894 } 895 case Branch: { 896 LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n\n"); 897 curCodeIt = &code[read<ByteCodeAddr>()]; 898 break; 899 } 900 case CheckOperandCount: { 901 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); 902 Operation *op = read<Operation *>(); 903 uint32_t expectedCount = read<uint32_t>(); 904 905 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" 906 << " * Expected: " << expectedCount << "\n\n"); 907 selectJump(op->getNumOperands() == expectedCount); 908 break; 909 } 910 case CheckOperationName: { 911 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n"); 912 Operation *op = read<Operation *>(); 913 OperationName expectedName = read<OperationName>(); 914 915 LLVM_DEBUG(llvm::dbgs() 916 << " * Found: \"" << op->getName() << "\"\n" 917 << " * Expected: \"" << expectedName << "\"\n\n"); 918 selectJump(op->getName() == expectedName); 919 break; 920 } 921 case CheckResultCount: { 922 LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); 923 Operation *op = read<Operation *>(); 924 uint32_t expectedCount = read<uint32_t>(); 925 926 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" 927 << " * Expected: " << expectedCount << "\n\n"); 928 selectJump(op->getNumResults() == expectedCount); 929 break; 930 } 931 case CreateNative: { 932 LLVM_DEBUG(llvm::dbgs() << "Executing CreateNative:\n"); 933 const PDLCreateFunction &createFn = createFunctions[read()]; 934 ByteCodeField resultIndex = read(); 935 ArrayAttr constParams = read<ArrayAttr>(); 936 SmallVector<PDLValue, 16> args; 937 readList<PDLValue>(args); 938 939 LLVM_DEBUG({ 940 llvm::dbgs() << " * Arguments: "; 941 llvm::interleaveComma(args, llvm::dbgs()); 942 llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; 943 }); 944 945 PDLValue result = createFn(args, constParams, rewriter); 946 memory[resultIndex] = result.getAsOpaquePointer(); 947 948 LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n\n"); 949 break; 950 } 951 case CreateOperation: { 952 LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); 953 assert(mainRewriteLoc && "expected rewrite loc to be provided when " 954 "executing the rewriter bytecode"); 955 956 unsigned memIndex = read(); 957 OperationState state(*mainRewriteLoc, read<OperationName>()); 958 readList<Value>(state.operands); 959 for (unsigned i = 0, e = read(); i != e; ++i) { 960 Identifier name = read<Identifier>(); 961 if (Attribute attr = read<Attribute>()) 962 state.addAttribute(name, attr); 963 } 964 965 bool hasInferredTypes = false; 966 for (unsigned i = 0, e = read(); i != e; ++i) { 967 Type resultType = read<Type>(); 968 hasInferredTypes |= !resultType; 969 state.types.push_back(resultType); 970 } 971 972 // Handle the case where the operation has inferred types. 973 if (hasInferredTypes) { 974 InferTypeOpInterface::Concept *concept = 975 state.name.getAbstractOperation() 976 ->getInterface<InferTypeOpInterface>(); 977 978 // TODO: Handle failure. 979 SmallVector<Type, 2> inferredTypes; 980 if (failed(concept->inferReturnTypes( 981 state.getContext(), state.location, state.operands, 982 state.attributes.getDictionary(state.getContext()), 983 state.regions, inferredTypes))) 984 return; 985 986 for (unsigned i = 0, e = state.types.size(); i != e; ++i) 987 if (!state.types[i]) 988 state.types[i] = inferredTypes[i]; 989 } 990 Operation *resultOp = rewriter.createOperation(state); 991 memory[memIndex] = resultOp; 992 993 LLVM_DEBUG({ 994 llvm::dbgs() << " * Attributes: " 995 << state.attributes.getDictionary(state.getContext()) 996 << "\n * Operands: "; 997 llvm::interleaveComma(state.operands, llvm::dbgs()); 998 llvm::dbgs() << "\n * Result Types: "; 999 llvm::interleaveComma(state.types, llvm::dbgs()); 1000 llvm::dbgs() << "\n * Result: " << *resultOp << "\n\n"; 1001 }); 1002 break; 1003 } 1004 case EraseOp: { 1005 LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n"); 1006 Operation *op = read<Operation *>(); 1007 1008 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n\n"); 1009 rewriter.eraseOp(op); 1010 break; 1011 } 1012 case Finalize: { 1013 LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n"); 1014 return; 1015 } 1016 case GetAttribute: { 1017 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); 1018 unsigned memIndex = read(); 1019 Operation *op = read<Operation *>(); 1020 Identifier attrName = read<Identifier>(); 1021 Attribute attr = op->getAttr(attrName); 1022 1023 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1024 << " * Attribute: " << attrName << "\n" 1025 << " * Result: " << attr << "\n\n"); 1026 memory[memIndex] = attr.getAsOpaquePointer(); 1027 break; 1028 } 1029 case GetAttributeType: { 1030 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n"); 1031 unsigned memIndex = read(); 1032 Attribute attr = read<Attribute>(); 1033 1034 LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" 1035 << " * Result: " << attr.getType() << "\n\n"); 1036 memory[memIndex] = attr.getType().getAsOpaquePointer(); 1037 break; 1038 } 1039 case GetDefiningOp: { 1040 LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); 1041 unsigned memIndex = read(); 1042 Value value = read<Value>(); 1043 Operation *op = value ? value.getDefiningOp() : nullptr; 1044 1045 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" 1046 << " * Result: " << *op << "\n\n"); 1047 memory[memIndex] = op; 1048 break; 1049 } 1050 case GetOperand0: 1051 case GetOperand1: 1052 case GetOperand2: 1053 case GetOperand3: 1054 case GetOperandN: { 1055 LLVM_DEBUG({ 1056 llvm::dbgs() << "Executing GetOperand" 1057 << (opCode == GetOperandN ? Twine("N") 1058 : Twine(opCode - GetOperand0)) 1059 << ":\n"; 1060 }); 1061 unsigned index = 1062 opCode == GetOperandN ? read<uint32_t>() : (opCode - GetOperand0); 1063 Operation *op = read<Operation *>(); 1064 unsigned memIndex = read(); 1065 Value operand = 1066 index < op->getNumOperands() ? op->getOperand(index) : Value(); 1067 1068 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1069 << " * Index: " << index << "\n" 1070 << " * Result: " << operand << "\n\n"); 1071 memory[memIndex] = operand.getAsOpaquePointer(); 1072 break; 1073 } 1074 case GetResult0: 1075 case GetResult1: 1076 case GetResult2: 1077 case GetResult3: 1078 case GetResultN: { 1079 LLVM_DEBUG({ 1080 llvm::dbgs() << "Executing GetResult" 1081 << (opCode == GetResultN ? Twine("N") 1082 : Twine(opCode - GetResult0)) 1083 << ":\n"; 1084 }); 1085 unsigned index = 1086 opCode == GetResultN ? read<uint32_t>() : (opCode - GetResult0); 1087 Operation *op = read<Operation *>(); 1088 unsigned memIndex = read(); 1089 OpResult result = 1090 index < op->getNumResults() ? op->getResult(index) : OpResult(); 1091 1092 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1093 << " * Index: " << index << "\n" 1094 << " * Result: " << result << "\n\n"); 1095 memory[memIndex] = result.getAsOpaquePointer(); 1096 break; 1097 } 1098 case GetValueType: { 1099 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); 1100 unsigned memIndex = read(); 1101 Value value = read<Value>(); 1102 1103 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" 1104 << " * Result: " << value.getType() << "\n\n"); 1105 memory[memIndex] = value.getType().getAsOpaquePointer(); 1106 break; 1107 } 1108 case IsNotNull: { 1109 LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); 1110 const void *value = read<const void *>(); 1111 1112 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n\n"); 1113 selectJump(value != nullptr); 1114 break; 1115 } 1116 case RecordMatch: { 1117 LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n"); 1118 assert(matches && 1119 "expected matches to be provided when executing the matcher"); 1120 unsigned patternIndex = read(); 1121 PatternBenefit benefit = currentPatternBenefits[patternIndex]; 1122 const ByteCodeField *dest = &code[read<ByteCodeAddr>()]; 1123 1124 // If the benefit of the pattern is impossible, skip the processing of the 1125 // rest of the pattern. 1126 if (benefit.isImpossibleToMatch()) { 1127 LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n\n"); 1128 curCodeIt = dest; 1129 break; 1130 } 1131 1132 // Create a fused location containing the locations of each of the 1133 // operations used in the match. This will be used as the location for 1134 // created operations during the rewrite that don't already have an 1135 // explicit location set. 1136 unsigned numMatchLocs = read(); 1137 SmallVector<Location, 4> matchLocs; 1138 matchLocs.reserve(numMatchLocs); 1139 for (unsigned i = 0; i != numMatchLocs; ++i) 1140 matchLocs.push_back(read<Operation *>()->getLoc()); 1141 Location matchLoc = rewriter.getFusedLoc(matchLocs); 1142 1143 LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n" 1144 << " * Location: " << matchLoc << "\n\n"); 1145 matches->emplace_back(matchLoc, patterns[patternIndex], benefit); 1146 readList<const void *>(matches->back().values); 1147 curCodeIt = dest; 1148 break; 1149 } 1150 case ReplaceOp: { 1151 LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); 1152 Operation *op = read<Operation *>(); 1153 SmallVector<Value, 16> args; 1154 readList<Value>(args); 1155 1156 LLVM_DEBUG({ 1157 llvm::dbgs() << " * Operation: " << *op << "\n" 1158 << " * Values: "; 1159 llvm::interleaveComma(args, llvm::dbgs()); 1160 llvm::dbgs() << "\n\n"; 1161 }); 1162 rewriter.replaceOp(op, args); 1163 break; 1164 } 1165 case SwitchAttribute: { 1166 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n"); 1167 Attribute value = read<Attribute>(); 1168 ArrayAttr cases = read<ArrayAttr>(); 1169 handleSwitch(value, cases); 1170 break; 1171 } 1172 case SwitchOperandCount: { 1173 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n"); 1174 Operation *op = read<Operation *>(); 1175 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1176 1177 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1178 handleSwitch(op->getNumOperands(), cases); 1179 break; 1180 } 1181 case SwitchOperationName: { 1182 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n"); 1183 OperationName value = read<Operation *>()->getName(); 1184 size_t caseCount = read(); 1185 1186 // The operation names are stored in-line, so to print them out for 1187 // debugging purposes we need to read the array before executing the 1188 // switch so that we can display all of the possible values. 1189 LLVM_DEBUG({ 1190 const ByteCodeField *prevCodeIt = curCodeIt; 1191 llvm::dbgs() << " * Value: " << value << "\n" 1192 << " * Cases: "; 1193 llvm::interleaveComma( 1194 llvm::map_range(llvm::seq<size_t>(0, caseCount), 1195 [&](size_t i) { return read<OperationName>(); }), 1196 llvm::dbgs()); 1197 llvm::dbgs() << "\n\n"; 1198 curCodeIt = prevCodeIt; 1199 }); 1200 1201 // Try to find the switch value within any of the cases. 1202 size_t jumpDest = 0; 1203 for (size_t i = 0; i != caseCount; ++i) { 1204 if (read<OperationName>() == value) { 1205 curCodeIt += (caseCount - i - 1); 1206 jumpDest = i + 1; 1207 break; 1208 } 1209 } 1210 selectJump(jumpDest); 1211 break; 1212 } 1213 case SwitchResultCount: { 1214 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n"); 1215 Operation *op = read<Operation *>(); 1216 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1217 1218 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1219 handleSwitch(op->getNumResults(), cases); 1220 break; 1221 } 1222 case SwitchType: { 1223 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n"); 1224 Type value = read<Type>(); 1225 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>(); 1226 handleSwitch(value, cases); 1227 break; 1228 } 1229 } 1230 } 1231 } 1232 1233 /// Run the pattern matcher on the given root operation, collecting the matched 1234 /// patterns in `matches`. 1235 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, 1236 SmallVectorImpl<MatchResult> &matches, 1237 PDLByteCodeMutableState &state) const { 1238 // The first memory slot is always the root operation. 1239 state.memory[0] = op; 1240 1241 // The matcher function always starts at code address 0. 1242 ByteCodeExecutor executor(matcherByteCode.data(), state.memory, uniquedData, 1243 matcherByteCode, state.currentPatternBenefits, 1244 patterns, constraintFunctions, createFunctions, 1245 rewriteFunctions); 1246 executor.execute(rewriter, &matches); 1247 1248 // Order the found matches by benefit. 1249 std::stable_sort(matches.begin(), matches.end(), 1250 [](const MatchResult &lhs, const MatchResult &rhs) { 1251 return lhs.benefit > rhs.benefit; 1252 }); 1253 } 1254 1255 /// Run the rewriter of the given pattern on the root operation `op`. 1256 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, 1257 PDLByteCodeMutableState &state) const { 1258 // The arguments of the rewrite function are stored at the start of the 1259 // memory buffer. 1260 llvm::copy(match.values, state.memory.begin()); 1261 1262 ByteCodeExecutor executor( 1263 &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, 1264 uniquedData, rewriterByteCode, state.currentPatternBenefits, patterns, 1265 constraintFunctions, createFunctions, rewriteFunctions); 1266 executor.execute(rewriter, /*matches=*/nullptr, match.location); 1267 } 1268