1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements MLIR to byte-code generation and the interpreter. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "ByteCode.h" 14 #include "mlir/Analysis/Liveness.h" 15 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/RegionGraphTraits.h" 19 #include "llvm/ADT/IntervalMap.h" 20 #include "llvm/ADT/PostOrderIterator.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 #include "llvm/Support/Debug.h" 23 24 #define DEBUG_TYPE "pdl-bytecode" 25 26 using namespace mlir; 27 using namespace mlir::detail; 28 29 //===----------------------------------------------------------------------===// 30 // PDLByteCodePattern 31 //===----------------------------------------------------------------------===// 32 33 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp, 34 ByteCodeAddr rewriterAddr) { 35 SmallVector<StringRef, 8> generatedOps; 36 if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr()) 37 generatedOps = 38 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>()); 39 40 PatternBenefit benefit = matchOp.benefit(); 41 MLIRContext *ctx = matchOp.getContext(); 42 43 // Check to see if this is pattern matches a specific operation type. 44 if (Optional<StringRef> rootKind = matchOp.rootKind()) 45 return PDLByteCodePattern(rewriterAddr, *rootKind, generatedOps, benefit, 46 ctx); 47 return PDLByteCodePattern(rewriterAddr, generatedOps, benefit, ctx, 48 MatchAnyOpTypeTag()); 49 } 50 51 //===----------------------------------------------------------------------===// 52 // PDLByteCodeMutableState 53 //===----------------------------------------------------------------------===// 54 55 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds 56 /// to the position of the pattern within the range returned by 57 /// `PDLByteCode::getPatterns`. 58 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex, 59 PatternBenefit benefit) { 60 currentPatternBenefits[patternIndex] = benefit; 61 } 62 63 //===----------------------------------------------------------------------===// 64 // Bytecode OpCodes 65 //===----------------------------------------------------------------------===// 66 67 namespace { 68 enum OpCode : ByteCodeField { 69 /// Apply an externally registered constraint. 70 ApplyConstraint, 71 /// Apply an externally registered rewrite. 72 ApplyRewrite, 73 /// Check if two generic values are equal. 74 AreEqual, 75 /// Unconditional branch. 76 Branch, 77 /// Compare the operand count of an operation with a constant. 78 CheckOperandCount, 79 /// Compare the name of an operation with a constant. 80 CheckOperationName, 81 /// Compare the result count of an operation with a constant. 82 CheckResultCount, 83 /// Invoke a native creation method. 84 CreateNative, 85 /// Create an operation. 86 CreateOperation, 87 /// Erase an operation. 88 EraseOp, 89 /// Terminate a matcher or rewrite sequence. 90 Finalize, 91 /// Get a specific attribute of an operation. 92 GetAttribute, 93 /// Get the type of an attribute. 94 GetAttributeType, 95 /// Get the defining operation of a value. 96 GetDefiningOp, 97 /// Get a specific operand of an operation. 98 GetOperand0, 99 GetOperand1, 100 GetOperand2, 101 GetOperand3, 102 GetOperandN, 103 /// Get a specific result of an operation. 104 GetResult0, 105 GetResult1, 106 GetResult2, 107 GetResult3, 108 GetResultN, 109 /// Get the type of a value. 110 GetValueType, 111 /// Check if a generic value is not null. 112 IsNotNull, 113 /// Record a successful pattern match. 114 RecordMatch, 115 /// Replace an operation. 116 ReplaceOp, 117 /// Compare an attribute with a set of constants. 118 SwitchAttribute, 119 /// Compare the operand count of an operation with a set of constants. 120 SwitchOperandCount, 121 /// Compare the name of an operation with a set of constants. 122 SwitchOperationName, 123 /// Compare the result count of an operation with a set of constants. 124 SwitchResultCount, 125 /// Compare a type with a set of constants. 126 SwitchType, 127 }; 128 129 enum class PDLValueKind { Attribute, Operation, Type, Value }; 130 } // end anonymous namespace 131 132 //===----------------------------------------------------------------------===// 133 // ByteCode Generation 134 //===----------------------------------------------------------------------===// 135 136 //===----------------------------------------------------------------------===// 137 // Generator 138 139 namespace { 140 struct ByteCodeWriter; 141 142 /// This class represents the main generator for the pattern bytecode. 143 class Generator { 144 public: 145 Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData, 146 SmallVectorImpl<ByteCodeField> &matcherByteCode, 147 SmallVectorImpl<ByteCodeField> &rewriterByteCode, 148 SmallVectorImpl<PDLByteCodePattern> &patterns, 149 ByteCodeField &maxValueMemoryIndex, 150 llvm::StringMap<PDLConstraintFunction> &constraintFns, 151 llvm::StringMap<PDLCreateFunction> &createFns, 152 llvm::StringMap<PDLRewriteFunction> &rewriteFns) 153 : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), 154 rewriterByteCode(rewriterByteCode), patterns(patterns), 155 maxValueMemoryIndex(maxValueMemoryIndex) { 156 for (auto it : llvm::enumerate(constraintFns)) 157 constraintToMemIndex.try_emplace(it.value().first(), it.index()); 158 for (auto it : llvm::enumerate(createFns)) 159 nativeCreateToMemIndex.try_emplace(it.value().first(), it.index()); 160 for (auto it : llvm::enumerate(rewriteFns)) 161 externalRewriterToMemIndex.try_emplace(it.value().first(), it.index()); 162 } 163 164 /// Generate the bytecode for the given PDL interpreter module. 165 void generate(ModuleOp module); 166 167 /// Return the memory index to use for the given value. 168 ByteCodeField &getMemIndex(Value value) { 169 assert(valueToMemIndex.count(value) && 170 "expected memory index to be assigned"); 171 return valueToMemIndex[value]; 172 } 173 174 /// Return an index to use when referring to the given data that is uniqued in 175 /// the MLIR context. 176 template <typename T> 177 std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &> 178 getMemIndex(T val) { 179 const void *opaqueVal = val.getAsOpaquePointer(); 180 181 // Get or insert a reference to this value. 182 auto it = uniquedDataToMemIndex.try_emplace( 183 opaqueVal, maxValueMemoryIndex + uniquedData.size()); 184 if (it.second) 185 uniquedData.push_back(opaqueVal); 186 return it.first->second; 187 } 188 189 private: 190 /// Allocate memory indices for the results of operations within the matcher 191 /// and rewriters. 192 void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule); 193 194 /// Generate the bytecode for the given operation. 195 void generate(Operation *op, ByteCodeWriter &writer); 196 void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer); 197 void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer); 198 void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer); 199 void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer); 200 void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer); 201 void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer); 202 void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer); 203 void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); 204 void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); 205 void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); 206 void generate(pdl_interp::CreateNativeOp op, ByteCodeWriter &writer); 207 void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); 208 void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); 209 void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); 210 void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); 211 void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); 212 void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); 213 void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); 214 void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer); 215 void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); 216 void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); 217 void generate(pdl_interp::InferredTypeOp op, ByteCodeWriter &writer); 218 void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); 219 void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer); 220 void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); 221 void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer); 222 void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer); 223 void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer); 224 void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer); 225 void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer); 226 227 /// Mapping from value to its corresponding memory index. 228 DenseMap<Value, ByteCodeField> valueToMemIndex; 229 230 /// Mapping from the name of an externally registered rewrite to its index in 231 /// the bytecode registry. 232 llvm::StringMap<ByteCodeField> externalRewriterToMemIndex; 233 234 /// Mapping from the name of an externally registered constraint to its index 235 /// in the bytecode registry. 236 llvm::StringMap<ByteCodeField> constraintToMemIndex; 237 238 /// Mapping from the name of an externally registered creation method to its 239 /// index in the bytecode registry. 240 llvm::StringMap<ByteCodeField> nativeCreateToMemIndex; 241 242 /// Mapping from rewriter function name to the bytecode address of the 243 /// rewriter function in byte. 244 llvm::StringMap<ByteCodeAddr> rewriterToAddr; 245 246 /// Mapping from a uniqued storage object to its memory index within 247 /// `uniquedData`. 248 DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex; 249 250 /// The current MLIR context. 251 MLIRContext *ctx; 252 253 /// Data of the ByteCode class to be populated. 254 std::vector<const void *> &uniquedData; 255 SmallVectorImpl<ByteCodeField> &matcherByteCode; 256 SmallVectorImpl<ByteCodeField> &rewriterByteCode; 257 SmallVectorImpl<PDLByteCodePattern> &patterns; 258 ByteCodeField &maxValueMemoryIndex; 259 }; 260 261 /// This class provides utilities for writing a bytecode stream. 262 struct ByteCodeWriter { 263 ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator) 264 : bytecode(bytecode), generator(generator) {} 265 266 /// Append a field to the bytecode. 267 void append(ByteCodeField field) { bytecode.push_back(field); } 268 void append(OpCode opCode) { bytecode.push_back(opCode); } 269 270 /// Append an address to the bytecode. 271 void append(ByteCodeAddr field) { 272 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 273 "unexpected ByteCode address size"); 274 275 ByteCodeField fieldParts[2]; 276 std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr)); 277 bytecode.append({fieldParts[0], fieldParts[1]}); 278 } 279 280 /// Append a successor range to the bytecode, the exact address will need to 281 /// be resolved later. 282 void append(SuccessorRange successors) { 283 // Add back references to the any successors so that the address can be 284 // resolved later. 285 for (Block *successor : successors) { 286 unresolvedSuccessorRefs[successor].push_back(bytecode.size()); 287 append(ByteCodeAddr(0)); 288 } 289 } 290 291 /// Append a range of values that will be read as generic PDLValues. 292 void appendPDLValueList(OperandRange values) { 293 bytecode.push_back(values.size()); 294 for (Value value : values) { 295 // Append the type of the value in addition to the value itself. 296 PDLValueKind kind = 297 TypeSwitch<Type, PDLValueKind>(value.getType()) 298 .Case<pdl::AttributeType>( 299 [](Type) { return PDLValueKind::Attribute; }) 300 .Case<pdl::OperationType>( 301 [](Type) { return PDLValueKind::Operation; }) 302 .Case<pdl::TypeType>([](Type) { return PDLValueKind::Type; }) 303 .Case<pdl::ValueType>([](Type) { return PDLValueKind::Value; }); 304 bytecode.push_back(static_cast<ByteCodeField>(kind)); 305 append(value); 306 } 307 } 308 309 /// Check if the given class `T` has an iterator type. 310 template <typename T, typename... Args> 311 using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer()); 312 313 /// Append a value that will be stored in a memory slot and not inline within 314 /// the bytecode. 315 template <typename T> 316 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value || 317 std::is_pointer<T>::value> 318 append(T value) { 319 bytecode.push_back(generator.getMemIndex(value)); 320 } 321 322 /// Append a range of values. 323 template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>> 324 std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value> 325 append(T range) { 326 bytecode.push_back(llvm::size(range)); 327 for (auto it : range) 328 append(it); 329 } 330 331 /// Append a variadic number of fields to the bytecode. 332 template <typename FieldTy, typename Field2Ty, typename... FieldTys> 333 void append(FieldTy field, Field2Ty field2, FieldTys... fields) { 334 append(field); 335 append(field2, fields...); 336 } 337 338 /// Successor references in the bytecode that have yet to be resolved. 339 DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs; 340 341 /// The underlying bytecode buffer. 342 SmallVectorImpl<ByteCodeField> &bytecode; 343 344 /// The main generator producing PDL. 345 Generator &generator; 346 }; 347 } // end anonymous namespace 348 349 void Generator::generate(ModuleOp module) { 350 FuncOp matcherFunc = module.lookupSymbol<FuncOp>( 351 pdl_interp::PDLInterpDialect::getMatcherFunctionName()); 352 ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>( 353 pdl_interp::PDLInterpDialect::getRewriterModuleName()); 354 assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module"); 355 356 // Allocate memory indices for the results of operations within the matcher 357 // and rewriters. 358 allocateMemoryIndices(matcherFunc, rewriterModule); 359 360 // Generate code for the rewriter functions. 361 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this); 362 for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) { 363 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size()); 364 for (Operation &op : rewriterFunc.getOps()) 365 generate(&op, rewriterByteCodeWriter); 366 } 367 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() && 368 "unexpected branches in rewriter function"); 369 370 // Generate code for the matcher function. 371 DenseMap<Block *, ByteCodeAddr> blockToAddr; 372 llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody()); 373 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this); 374 for (Block *block : rpot) { 375 // Keep track of where this block begins within the matcher function. 376 blockToAddr.try_emplace(block, matcherByteCode.size()); 377 for (Operation &op : *block) 378 generate(&op, matcherByteCodeWriter); 379 } 380 381 // Resolve successor references in the matcher. 382 for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) { 383 ByteCodeAddr addr = blockToAddr[it.first]; 384 for (unsigned offsetToFix : it.second) 385 std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr)); 386 } 387 } 388 389 void Generator::allocateMemoryIndices(FuncOp matcherFunc, 390 ModuleOp rewriterModule) { 391 // Rewriters use simplistic allocation scheme that simply assigns an index to 392 // each result. 393 for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) { 394 ByteCodeField index = 0; 395 for (BlockArgument arg : rewriterFunc.getArguments()) 396 valueToMemIndex.try_emplace(arg, index++); 397 rewriterFunc.getBody().walk([&](Operation *op) { 398 for (Value result : op->getResults()) 399 valueToMemIndex.try_emplace(result, index++); 400 }); 401 if (index > maxValueMemoryIndex) 402 maxValueMemoryIndex = index; 403 } 404 405 // The matcher function uses a more sophisticated numbering that tries to 406 // minimize the number of memory indices assigned. This is done by determining 407 // a live range of the values within the matcher, then the allocation is just 408 // finding the minimal number of overlapping live ranges. This is essentially 409 // a simplified form of register allocation where we don't necessarily have a 410 // limited number of registers, but we still want to minimize the number used. 411 DenseMap<Operation *, ByteCodeField> opToIndex; 412 matcherFunc.getBody().walk([&](Operation *op) { 413 opToIndex.insert(std::make_pair(op, opToIndex.size())); 414 }); 415 416 // Liveness info for each of the defs within the matcher. 417 using LivenessSet = llvm::IntervalMap<ByteCodeField, char, 16>; 418 LivenessSet::Allocator allocator; 419 DenseMap<Value, LivenessSet> valueDefRanges; 420 421 // Assign the root operation being matched to slot 0. 422 BlockArgument rootOpArg = matcherFunc.getArgument(0); 423 valueToMemIndex[rootOpArg] = 0; 424 425 // Walk each of the blocks, computing the def interval that the value is used. 426 Liveness matcherLiveness(matcherFunc); 427 for (Block &block : matcherFunc.getBody()) { 428 const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block); 429 assert(info && "expected liveness info for block"); 430 auto processValue = [&](Value value, Operation *firstUseOrDef) { 431 // We don't need to process the root op argument, this value is always 432 // assigned to the first memory slot. 433 if (value == rootOpArg) 434 return; 435 436 // Set indices for the range of this block that the value is used. 437 auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; 438 defRangeIt->second.insert( 439 opToIndex[firstUseOrDef], 440 opToIndex[info->getEndOperation(value, firstUseOrDef)], 441 /*dummyValue*/ 0); 442 }; 443 444 // Process the live-ins of this block. 445 for (Value liveIn : info->in()) 446 processValue(liveIn, &block.front()); 447 448 // Process any new defs within this block. 449 for (Operation &op : block) 450 for (Value result : op.getResults()) 451 processValue(result, &op); 452 } 453 454 // Greedily allocate memory slots using the computed def live ranges. 455 std::vector<LivenessSet> allocatedIndices; 456 for (auto &defIt : valueDefRanges) { 457 ByteCodeField &memIndex = valueToMemIndex[defIt.first]; 458 LivenessSet &defSet = defIt.second; 459 460 // Try to allocate to an existing index. 461 for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) { 462 LivenessSet &existingIndex = existingIndexIt.value(); 463 llvm::IntervalMapOverlaps<LivenessSet, LivenessSet> overlaps( 464 defIt.second, existingIndex); 465 if (overlaps.valid()) 466 continue; 467 // Union the range of the def within the existing index. 468 for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it) 469 existingIndex.insert(it.start(), it.stop(), /*dummyValue*/ 0); 470 memIndex = existingIndexIt.index() + 1; 471 } 472 473 // If no existing index could be used, add a new one. 474 if (memIndex == 0) { 475 allocatedIndices.emplace_back(allocator); 476 for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it) 477 allocatedIndices.back().insert(it.start(), it.stop(), /*dummyValue*/ 0); 478 memIndex = allocatedIndices.size(); 479 } 480 } 481 482 // Update the max number of indices. 483 ByteCodeField numMatcherIndices = allocatedIndices.size() + 1; 484 if (numMatcherIndices > maxValueMemoryIndex) 485 maxValueMemoryIndex = numMatcherIndices; 486 } 487 488 void Generator::generate(Operation *op, ByteCodeWriter &writer) { 489 TypeSwitch<Operation *>(op) 490 .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp, 491 pdl_interp::AreEqualOp, pdl_interp::BranchOp, 492 pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp, 493 pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp, 494 pdl_interp::CheckTypeOp, pdl_interp::CreateAttributeOp, 495 pdl_interp::CreateNativeOp, pdl_interp::CreateOperationOp, 496 pdl_interp::CreateTypeOp, pdl_interp::EraseOp, 497 pdl_interp::FinalizeOp, pdl_interp::GetAttributeOp, 498 pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp, 499 pdl_interp::GetOperandOp, pdl_interp::GetResultOp, 500 pdl_interp::GetValueTypeOp, pdl_interp::InferredTypeOp, 501 pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp, 502 pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp, 503 pdl_interp::SwitchTypeOp, pdl_interp::SwitchOperandCountOp, 504 pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>( 505 [&](auto interpOp) { this->generate(interpOp, writer); }) 506 .Default([](Operation *) { 507 llvm_unreachable("unknown `pdl_interp` operation"); 508 }); 509 } 510 511 void Generator::generate(pdl_interp::ApplyConstraintOp op, 512 ByteCodeWriter &writer) { 513 assert(constraintToMemIndex.count(op.name()) && 514 "expected index for constraint function"); 515 writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()], 516 op.constParamsAttr()); 517 writer.appendPDLValueList(op.args()); 518 writer.append(op.getSuccessors()); 519 } 520 void Generator::generate(pdl_interp::ApplyRewriteOp op, 521 ByteCodeWriter &writer) { 522 assert(externalRewriterToMemIndex.count(op.name()) && 523 "expected index for rewrite function"); 524 writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()], 525 op.constParamsAttr(), op.root()); 526 writer.appendPDLValueList(op.args()); 527 } 528 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) { 529 writer.append(OpCode::AreEqual, op.lhs(), op.rhs(), op.getSuccessors()); 530 } 531 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) { 532 writer.append(OpCode::Branch, SuccessorRange(op.getOperation())); 533 } 534 void Generator::generate(pdl_interp::CheckAttributeOp op, 535 ByteCodeWriter &writer) { 536 writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(), 537 op.getSuccessors()); 538 } 539 void Generator::generate(pdl_interp::CheckOperandCountOp op, 540 ByteCodeWriter &writer) { 541 writer.append(OpCode::CheckOperandCount, op.operation(), op.count(), 542 op.getSuccessors()); 543 } 544 void Generator::generate(pdl_interp::CheckOperationNameOp op, 545 ByteCodeWriter &writer) { 546 writer.append(OpCode::CheckOperationName, op.operation(), 547 OperationName(op.name(), ctx), op.getSuccessors()); 548 } 549 void Generator::generate(pdl_interp::CheckResultCountOp op, 550 ByteCodeWriter &writer) { 551 writer.append(OpCode::CheckResultCount, op.operation(), op.count(), 552 op.getSuccessors()); 553 } 554 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) { 555 writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors()); 556 } 557 void Generator::generate(pdl_interp::CreateAttributeOp op, 558 ByteCodeWriter &writer) { 559 // Simply repoint the memory index of the result to the constant. 560 getMemIndex(op.attribute()) = getMemIndex(op.value()); 561 } 562 void Generator::generate(pdl_interp::CreateNativeOp op, 563 ByteCodeWriter &writer) { 564 assert(nativeCreateToMemIndex.count(op.name()) && 565 "expected index for creation function"); 566 writer.append(OpCode::CreateNative, nativeCreateToMemIndex[op.name()], 567 op.result(), op.constParamsAttr()); 568 writer.appendPDLValueList(op.args()); 569 } 570 void Generator::generate(pdl_interp::CreateOperationOp op, 571 ByteCodeWriter &writer) { 572 writer.append(OpCode::CreateOperation, op.operation(), 573 OperationName(op.name(), ctx), op.operands()); 574 575 // Add the attributes. 576 OperandRange attributes = op.attributes(); 577 writer.append(static_cast<ByteCodeField>(attributes.size())); 578 for (auto it : llvm::zip(op.attributeNames(), op.attributes())) { 579 writer.append( 580 Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx), 581 std::get<1>(it)); 582 } 583 writer.append(op.types()); 584 } 585 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) { 586 // Simply repoint the memory index of the result to the constant. 587 getMemIndex(op.result()) = getMemIndex(op.value()); 588 } 589 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) { 590 writer.append(OpCode::EraseOp, op.operation()); 591 } 592 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) { 593 writer.append(OpCode::Finalize); 594 } 595 void Generator::generate(pdl_interp::GetAttributeOp op, 596 ByteCodeWriter &writer) { 597 writer.append(OpCode::GetAttribute, op.attribute(), op.operation(), 598 Identifier::get(op.name(), ctx)); 599 } 600 void Generator::generate(pdl_interp::GetAttributeTypeOp op, 601 ByteCodeWriter &writer) { 602 writer.append(OpCode::GetAttributeType, op.result(), op.value()); 603 } 604 void Generator::generate(pdl_interp::GetDefiningOpOp op, 605 ByteCodeWriter &writer) { 606 writer.append(OpCode::GetDefiningOp, op.operation(), op.value()); 607 } 608 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) { 609 uint32_t index = op.index(); 610 if (index < 4) 611 writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index)); 612 else 613 writer.append(OpCode::GetOperandN, index); 614 writer.append(op.operation(), op.value()); 615 } 616 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) { 617 uint32_t index = op.index(); 618 if (index < 4) 619 writer.append(static_cast<OpCode>(OpCode::GetResult0 + index)); 620 else 621 writer.append(OpCode::GetResultN, index); 622 writer.append(op.operation(), op.value()); 623 } 624 void Generator::generate(pdl_interp::GetValueTypeOp op, 625 ByteCodeWriter &writer) { 626 writer.append(OpCode::GetValueType, op.result(), op.value()); 627 } 628 void Generator::generate(pdl_interp::InferredTypeOp op, 629 ByteCodeWriter &writer) { 630 // InferType maps to a null type as a marker for inferring a result type. 631 getMemIndex(op.type()) = getMemIndex(Type()); 632 } 633 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) { 634 writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors()); 635 } 636 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) { 637 ByteCodeField patternIndex = patterns.size(); 638 patterns.emplace_back(PDLByteCodePattern::create( 639 op, rewriterToAddr[op.rewriter().getLeafReference()])); 640 writer.append(OpCode::RecordMatch, patternIndex, 641 SuccessorRange(op.getOperation()), op.matchedOps(), 642 op.inputs()); 643 } 644 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) { 645 writer.append(OpCode::ReplaceOp, op.operation(), op.replValues()); 646 } 647 void Generator::generate(pdl_interp::SwitchAttributeOp op, 648 ByteCodeWriter &writer) { 649 writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(), 650 op.getSuccessors()); 651 } 652 void Generator::generate(pdl_interp::SwitchOperandCountOp op, 653 ByteCodeWriter &writer) { 654 writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(), 655 op.getSuccessors()); 656 } 657 void Generator::generate(pdl_interp::SwitchOperationNameOp op, 658 ByteCodeWriter &writer) { 659 auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) { 660 return OperationName(attr.cast<StringAttr>().getValue(), ctx); 661 }); 662 writer.append(OpCode::SwitchOperationName, op.operation(), cases, 663 op.getSuccessors()); 664 } 665 void Generator::generate(pdl_interp::SwitchResultCountOp op, 666 ByteCodeWriter &writer) { 667 writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(), 668 op.getSuccessors()); 669 } 670 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) { 671 writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(), 672 op.getSuccessors()); 673 } 674 675 //===----------------------------------------------------------------------===// 676 // PDLByteCode 677 //===----------------------------------------------------------------------===// 678 679 PDLByteCode::PDLByteCode(ModuleOp module, 680 llvm::StringMap<PDLConstraintFunction> constraintFns, 681 llvm::StringMap<PDLCreateFunction> createFns, 682 llvm::StringMap<PDLRewriteFunction> rewriteFns) { 683 Generator generator(module.getContext(), uniquedData, matcherByteCode, 684 rewriterByteCode, patterns, maxValueMemoryIndex, 685 constraintFns, createFns, rewriteFns); 686 generator.generate(module); 687 688 // Initialize the external functions. 689 for (auto &it : constraintFns) 690 constraintFunctions.push_back(std::move(it.second)); 691 for (auto &it : createFns) 692 createFunctions.push_back(std::move(it.second)); 693 for (auto &it : rewriteFns) 694 rewriteFunctions.push_back(std::move(it.second)); 695 } 696 697 /// Initialize the given state such that it can be used to execute the current 698 /// bytecode. 699 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { 700 state.memory.resize(maxValueMemoryIndex, nullptr); 701 state.currentPatternBenefits.reserve(patterns.size()); 702 for (const PDLByteCodePattern &pattern : patterns) 703 state.currentPatternBenefits.push_back(pattern.getBenefit()); 704 } 705 706 //===----------------------------------------------------------------------===// 707 // ByteCode Execution 708 709 namespace { 710 /// This class provides support for executing a bytecode stream. 711 class ByteCodeExecutor { 712 public: 713 ByteCodeExecutor(const ByteCodeField *curCodeIt, 714 MutableArrayRef<const void *> memory, 715 ArrayRef<const void *> uniquedMemory, 716 ArrayRef<ByteCodeField> code, 717 ArrayRef<PatternBenefit> currentPatternBenefits, 718 ArrayRef<PDLByteCodePattern> patterns, 719 ArrayRef<PDLConstraintFunction> constraintFunctions, 720 ArrayRef<PDLCreateFunction> createFunctions, 721 ArrayRef<PDLRewriteFunction> rewriteFunctions) 722 : curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory), 723 code(code), currentPatternBenefits(currentPatternBenefits), 724 patterns(patterns), constraintFunctions(constraintFunctions), 725 createFunctions(createFunctions), rewriteFunctions(rewriteFunctions) {} 726 727 /// Start executing the code at the current bytecode index. `matches` is an 728 /// optional field provided when this function is executed in a matching 729 /// context. 730 void execute(PatternRewriter &rewriter, 731 SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr, 732 Optional<Location> mainRewriteLoc = {}); 733 734 private: 735 /// Internal implementation of executing each of the bytecode commands. 736 void executeApplyConstraint(PatternRewriter &rewriter); 737 void executeApplyRewrite(PatternRewriter &rewriter); 738 void executeAreEqual(); 739 void executeBranch(); 740 void executeCheckOperandCount(); 741 void executeCheckOperationName(); 742 void executeCheckResultCount(); 743 void executeCreateNative(PatternRewriter &rewriter); 744 void executeCreateOperation(PatternRewriter &rewriter, 745 Location mainRewriteLoc); 746 void executeEraseOp(PatternRewriter &rewriter); 747 void executeGetAttribute(); 748 void executeGetAttributeType(); 749 void executeGetDefiningOp(); 750 void executeGetOperand(unsigned index); 751 void executeGetResult(unsigned index); 752 void executeGetValueType(); 753 void executeIsNotNull(); 754 void executeRecordMatch(PatternRewriter &rewriter, 755 SmallVectorImpl<PDLByteCode::MatchResult> &matches); 756 void executeReplaceOp(PatternRewriter &rewriter); 757 void executeSwitchAttribute(); 758 void executeSwitchOperandCount(); 759 void executeSwitchOperationName(); 760 void executeSwitchResultCount(); 761 void executeSwitchType(); 762 763 /// Read a value from the bytecode buffer, optionally skipping a certain 764 /// number of prefix values. These methods always update the buffer to point 765 /// to the next field after the read data. 766 template <typename T = ByteCodeField> 767 T read(size_t skipN = 0) { 768 curCodeIt += skipN; 769 return readImpl<T>(); 770 } 771 ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); } 772 773 /// Read a list of values from the bytecode buffer. 774 template <typename ValueT, typename T> 775 void readList(SmallVectorImpl<T> &list) { 776 list.clear(); 777 for (unsigned i = 0, e = read(); i != e; ++i) 778 list.push_back(read<ValueT>()); 779 } 780 781 /// Jump to a specific successor based on a predicate value. 782 void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); } 783 /// Jump to a specific successor based on a destination index. 784 void selectJump(size_t destIndex) { 785 curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)]; 786 } 787 788 /// Handle a switch operation with the provided value and cases. 789 template <typename T, typename RangeT> 790 void handleSwitch(const T &value, RangeT &&cases) { 791 LLVM_DEBUG({ 792 llvm::dbgs() << " * Value: " << value << "\n" 793 << " * Cases: "; 794 llvm::interleaveComma(cases, llvm::dbgs()); 795 llvm::dbgs() << "\n"; 796 }); 797 798 // Check to see if the attribute value is within the case list. Jump to 799 // the correct successor index based on the result. 800 for (auto it = cases.begin(), e = cases.end(); it != e; ++it) 801 if (*it == value) 802 return selectJump(size_t((it - cases.begin()) + 1)); 803 selectJump(size_t(0)); 804 } 805 806 /// Internal implementation of reading various data types from the bytecode 807 /// stream. 808 template <typename T> 809 const void *readFromMemory() { 810 size_t index = *curCodeIt++; 811 812 // If this type is an SSA value, it can only be stored in non-const memory. 813 if (llvm::is_one_of<T, Operation *, Value>::value || index < memory.size()) 814 return memory[index]; 815 816 // Otherwise, if this index is not inbounds it is uniqued. 817 return uniquedMemory[index - memory.size()]; 818 } 819 template <typename T> 820 std::enable_if_t<std::is_pointer<T>::value, T> readImpl() { 821 return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>())); 822 } 823 template <typename T> 824 std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value, 825 T> 826 readImpl() { 827 return T(T::getFromOpaquePointer(readFromMemory<T>())); 828 } 829 template <typename T> 830 std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() { 831 switch (static_cast<PDLValueKind>(read())) { 832 case PDLValueKind::Attribute: 833 return read<Attribute>(); 834 case PDLValueKind::Operation: 835 return read<Operation *>(); 836 case PDLValueKind::Type: 837 return read<Type>(); 838 case PDLValueKind::Value: 839 return read<Value>(); 840 } 841 llvm_unreachable("unhandled PDLValueKind"); 842 } 843 template <typename T> 844 std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() { 845 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 846 "unexpected ByteCode address size"); 847 ByteCodeAddr result; 848 std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr)); 849 curCodeIt += 2; 850 return result; 851 } 852 template <typename T> 853 std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() { 854 return *curCodeIt++; 855 } 856 857 /// The underlying bytecode buffer. 858 const ByteCodeField *curCodeIt; 859 860 /// The current execution memory. 861 MutableArrayRef<const void *> memory; 862 863 /// References to ByteCode data necessary for execution. 864 ArrayRef<const void *> uniquedMemory; 865 ArrayRef<ByteCodeField> code; 866 ArrayRef<PatternBenefit> currentPatternBenefits; 867 ArrayRef<PDLByteCodePattern> patterns; 868 ArrayRef<PDLConstraintFunction> constraintFunctions; 869 ArrayRef<PDLCreateFunction> createFunctions; 870 ArrayRef<PDLRewriteFunction> rewriteFunctions; 871 }; 872 } // end anonymous namespace 873 874 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { 875 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); 876 const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; 877 ArrayAttr constParams = read<ArrayAttr>(); 878 SmallVector<PDLValue, 16> args; 879 readList<PDLValue>(args); 880 881 LLVM_DEBUG({ 882 llvm::dbgs() << " * Arguments: "; 883 llvm::interleaveComma(args, llvm::dbgs()); 884 llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; 885 }); 886 887 // Invoke the constraint and jump to the proper destination. 888 selectJump(succeeded(constraintFn(args, constParams, rewriter))); 889 } 890 891 void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { 892 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); 893 const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; 894 ArrayAttr constParams = read<ArrayAttr>(); 895 Operation *root = read<Operation *>(); 896 SmallVector<PDLValue, 16> args; 897 readList<PDLValue>(args); 898 899 LLVM_DEBUG({ 900 llvm::dbgs() << " * Root: " << *root << "\n * Arguments: "; 901 llvm::interleaveComma(args, llvm::dbgs()); 902 llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; 903 }); 904 905 // Invoke the native rewrite function. 906 rewriteFn(root, args, constParams, rewriter); 907 } 908 909 void ByteCodeExecutor::executeAreEqual() { 910 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 911 const void *lhs = read<const void *>(); 912 const void *rhs = read<const void *>(); 913 914 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n"); 915 selectJump(lhs == rhs); 916 } 917 918 void ByteCodeExecutor::executeBranch() { 919 LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n"); 920 curCodeIt = &code[read<ByteCodeAddr>()]; 921 } 922 923 void ByteCodeExecutor::executeCheckOperandCount() { 924 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); 925 Operation *op = read<Operation *>(); 926 uint32_t expectedCount = read<uint32_t>(); 927 928 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" 929 << " * Expected: " << expectedCount << "\n"); 930 selectJump(op->getNumOperands() == expectedCount); 931 } 932 933 void ByteCodeExecutor::executeCheckOperationName() { 934 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n"); 935 Operation *op = read<Operation *>(); 936 OperationName expectedName = read<OperationName>(); 937 938 LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n" 939 << " * Expected: \"" << expectedName << "\"\n"); 940 selectJump(op->getName() == expectedName); 941 } 942 943 void ByteCodeExecutor::executeCheckResultCount() { 944 LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); 945 Operation *op = read<Operation *>(); 946 uint32_t expectedCount = read<uint32_t>(); 947 948 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" 949 << " * Expected: " << expectedCount << "\n"); 950 selectJump(op->getNumResults() == expectedCount); 951 } 952 953 void ByteCodeExecutor::executeCreateNative(PatternRewriter &rewriter) { 954 LLVM_DEBUG(llvm::dbgs() << "Executing CreateNative:\n"); 955 const PDLCreateFunction &createFn = createFunctions[read()]; 956 ByteCodeField resultIndex = read(); 957 ArrayAttr constParams = read<ArrayAttr>(); 958 SmallVector<PDLValue, 16> args; 959 readList<PDLValue>(args); 960 961 LLVM_DEBUG({ 962 llvm::dbgs() << " * Arguments: "; 963 llvm::interleaveComma(args, llvm::dbgs()); 964 llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; 965 }); 966 967 PDLValue result = createFn(args, constParams, rewriter); 968 memory[resultIndex] = result.getAsOpaquePointer(); 969 970 LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); 971 } 972 973 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, 974 Location mainRewriteLoc) { 975 LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); 976 977 unsigned memIndex = read(); 978 OperationState state(mainRewriteLoc, read<OperationName>()); 979 readList<Value>(state.operands); 980 for (unsigned i = 0, e = read(); i != e; ++i) { 981 Identifier name = read<Identifier>(); 982 if (Attribute attr = read<Attribute>()) 983 state.addAttribute(name, attr); 984 } 985 986 bool hasInferredTypes = false; 987 for (unsigned i = 0, e = read(); i != e; ++i) { 988 Type resultType = read<Type>(); 989 hasInferredTypes |= !resultType; 990 state.types.push_back(resultType); 991 } 992 993 // Handle the case where the operation has inferred types. 994 if (hasInferredTypes) { 995 InferTypeOpInterface::Concept *concept = 996 state.name.getAbstractOperation()->getInterface<InferTypeOpInterface>(); 997 998 // TODO: Handle failure. 999 SmallVector<Type, 2> inferredTypes; 1000 if (failed(concept->inferReturnTypes( 1001 state.getContext(), state.location, state.operands, 1002 state.attributes.getDictionary(state.getContext()), state.regions, 1003 inferredTypes))) 1004 return; 1005 1006 for (unsigned i = 0, e = state.types.size(); i != e; ++i) 1007 if (!state.types[i]) 1008 state.types[i] = inferredTypes[i]; 1009 } 1010 Operation *resultOp = rewriter.createOperation(state); 1011 memory[memIndex] = resultOp; 1012 1013 LLVM_DEBUG({ 1014 llvm::dbgs() << " * Attributes: " 1015 << state.attributes.getDictionary(state.getContext()) 1016 << "\n * Operands: "; 1017 llvm::interleaveComma(state.operands, llvm::dbgs()); 1018 llvm::dbgs() << "\n * Result Types: "; 1019 llvm::interleaveComma(state.types, llvm::dbgs()); 1020 llvm::dbgs() << "\n * Result: " << *resultOp << "\n"; 1021 }); 1022 } 1023 1024 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) { 1025 LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n"); 1026 Operation *op = read<Operation *>(); 1027 1028 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1029 rewriter.eraseOp(op); 1030 } 1031 1032 void ByteCodeExecutor::executeGetAttribute() { 1033 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); 1034 unsigned memIndex = read(); 1035 Operation *op = read<Operation *>(); 1036 Identifier attrName = read<Identifier>(); 1037 Attribute attr = op->getAttr(attrName); 1038 1039 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1040 << " * Attribute: " << attrName << "\n" 1041 << " * Result: " << attr << "\n"); 1042 memory[memIndex] = attr.getAsOpaquePointer(); 1043 } 1044 1045 void ByteCodeExecutor::executeGetAttributeType() { 1046 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n"); 1047 unsigned memIndex = read(); 1048 Attribute attr = read<Attribute>(); 1049 Type type = attr ? attr.getType() : Type(); 1050 1051 LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" 1052 << " * Result: " << type << "\n"); 1053 memory[memIndex] = type.getAsOpaquePointer(); 1054 } 1055 1056 void ByteCodeExecutor::executeGetDefiningOp() { 1057 LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); 1058 unsigned memIndex = read(); 1059 Value value = read<Value>(); 1060 Operation *op = value ? value.getDefiningOp() : nullptr; 1061 1062 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" 1063 << " * Result: " << *op << "\n"); 1064 memory[memIndex] = op; 1065 } 1066 1067 void ByteCodeExecutor::executeGetOperand(unsigned index) { 1068 Operation *op = read<Operation *>(); 1069 unsigned memIndex = read(); 1070 Value operand = 1071 index < op->getNumOperands() ? op->getOperand(index) : Value(); 1072 1073 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1074 << " * Index: " << index << "\n" 1075 << " * Result: " << operand << "\n"); 1076 memory[memIndex] = operand.getAsOpaquePointer(); 1077 } 1078 1079 void ByteCodeExecutor::executeGetResult(unsigned index) { 1080 Operation *op = read<Operation *>(); 1081 unsigned memIndex = read(); 1082 OpResult result = 1083 index < op->getNumResults() ? op->getResult(index) : OpResult(); 1084 1085 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1086 << " * Index: " << index << "\n" 1087 << " * Result: " << result << "\n"); 1088 memory[memIndex] = result.getAsOpaquePointer(); 1089 } 1090 1091 void ByteCodeExecutor::executeGetValueType() { 1092 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); 1093 unsigned memIndex = read(); 1094 Value value = read<Value>(); 1095 Type type = value ? value.getType() : Type(); 1096 1097 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" 1098 << " * Result: " << type << "\n"); 1099 memory[memIndex] = type.getAsOpaquePointer(); 1100 } 1101 1102 void ByteCodeExecutor::executeIsNotNull() { 1103 LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); 1104 const void *value = read<const void *>(); 1105 1106 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1107 selectJump(value != nullptr); 1108 } 1109 1110 void ByteCodeExecutor::executeRecordMatch( 1111 PatternRewriter &rewriter, 1112 SmallVectorImpl<PDLByteCode::MatchResult> &matches) { 1113 LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n"); 1114 unsigned patternIndex = read(); 1115 PatternBenefit benefit = currentPatternBenefits[patternIndex]; 1116 const ByteCodeField *dest = &code[read<ByteCodeAddr>()]; 1117 1118 // If the benefit of the pattern is impossible, skip the processing of the 1119 // rest of the pattern. 1120 if (benefit.isImpossibleToMatch()) { 1121 LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n"); 1122 curCodeIt = dest; 1123 return; 1124 } 1125 1126 // Create a fused location containing the locations of each of the 1127 // operations used in the match. This will be used as the location for 1128 // created operations during the rewrite that don't already have an 1129 // explicit location set. 1130 unsigned numMatchLocs = read(); 1131 SmallVector<Location, 4> matchLocs; 1132 matchLocs.reserve(numMatchLocs); 1133 for (unsigned i = 0; i != numMatchLocs; ++i) 1134 matchLocs.push_back(read<Operation *>()->getLoc()); 1135 Location matchLoc = rewriter.getFusedLoc(matchLocs); 1136 1137 LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n" 1138 << " * Location: " << matchLoc << "\n"); 1139 matches.emplace_back(matchLoc, patterns[patternIndex], benefit); 1140 readList<const void *>(matches.back().values); 1141 curCodeIt = dest; 1142 } 1143 1144 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) { 1145 LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); 1146 Operation *op = read<Operation *>(); 1147 SmallVector<Value, 16> args; 1148 readList<Value>(args); 1149 1150 LLVM_DEBUG({ 1151 llvm::dbgs() << " * Operation: " << *op << "\n" 1152 << " * Values: "; 1153 llvm::interleaveComma(args, llvm::dbgs()); 1154 llvm::dbgs() << "\n"; 1155 }); 1156 rewriter.replaceOp(op, args); 1157 } 1158 1159 void ByteCodeExecutor::executeSwitchAttribute() { 1160 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n"); 1161 Attribute value = read<Attribute>(); 1162 ArrayAttr cases = read<ArrayAttr>(); 1163 handleSwitch(value, cases); 1164 } 1165 1166 void ByteCodeExecutor::executeSwitchOperandCount() { 1167 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n"); 1168 Operation *op = read<Operation *>(); 1169 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1170 1171 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1172 handleSwitch(op->getNumOperands(), cases); 1173 } 1174 1175 void ByteCodeExecutor::executeSwitchOperationName() { 1176 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n"); 1177 OperationName value = read<Operation *>()->getName(); 1178 size_t caseCount = read(); 1179 1180 // The operation names are stored in-line, so to print them out for 1181 // debugging purposes we need to read the array before executing the 1182 // switch so that we can display all of the possible values. 1183 LLVM_DEBUG({ 1184 const ByteCodeField *prevCodeIt = curCodeIt; 1185 llvm::dbgs() << " * Value: " << value << "\n" 1186 << " * Cases: "; 1187 llvm::interleaveComma( 1188 llvm::map_range(llvm::seq<size_t>(0, caseCount), 1189 [&](size_t) { return read<OperationName>(); }), 1190 llvm::dbgs()); 1191 llvm::dbgs() << "\n"; 1192 curCodeIt = prevCodeIt; 1193 }); 1194 1195 // Try to find the switch value within any of the cases. 1196 for (size_t i = 0; i != caseCount; ++i) { 1197 if (read<OperationName>() == value) { 1198 curCodeIt += (caseCount - i - 1); 1199 return selectJump(i + 1); 1200 } 1201 } 1202 selectJump(size_t(0)); 1203 } 1204 1205 void ByteCodeExecutor::executeSwitchResultCount() { 1206 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n"); 1207 Operation *op = read<Operation *>(); 1208 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1209 1210 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1211 handleSwitch(op->getNumResults(), cases); 1212 } 1213 1214 void ByteCodeExecutor::executeSwitchType() { 1215 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n"); 1216 Type value = read<Type>(); 1217 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>(); 1218 handleSwitch(value, cases); 1219 } 1220 1221 void ByteCodeExecutor::execute( 1222 PatternRewriter &rewriter, 1223 SmallVectorImpl<PDLByteCode::MatchResult> *matches, 1224 Optional<Location> mainRewriteLoc) { 1225 while (true) { 1226 OpCode opCode = static_cast<OpCode>(read()); 1227 switch (opCode) { 1228 case ApplyConstraint: 1229 executeApplyConstraint(rewriter); 1230 break; 1231 case ApplyRewrite: 1232 executeApplyRewrite(rewriter); 1233 break; 1234 case AreEqual: 1235 executeAreEqual(); 1236 break; 1237 case Branch: 1238 executeBranch(); 1239 break; 1240 case CheckOperandCount: 1241 executeCheckOperandCount(); 1242 break; 1243 case CheckOperationName: 1244 executeCheckOperationName(); 1245 break; 1246 case CheckResultCount: 1247 executeCheckResultCount(); 1248 break; 1249 case CreateNative: 1250 executeCreateNative(rewriter); 1251 break; 1252 case CreateOperation: 1253 executeCreateOperation(rewriter, *mainRewriteLoc); 1254 break; 1255 case EraseOp: 1256 executeEraseOp(rewriter); 1257 break; 1258 case Finalize: 1259 LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n"); 1260 return; 1261 case GetAttribute: 1262 executeGetAttribute(); 1263 break; 1264 case GetAttributeType: 1265 executeGetAttributeType(); 1266 break; 1267 case GetDefiningOp: 1268 executeGetDefiningOp(); 1269 break; 1270 case GetOperand0: 1271 case GetOperand1: 1272 case GetOperand2: 1273 case GetOperand3: { 1274 unsigned index = opCode - GetOperand0; 1275 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n"); 1276 executeGetOperand(index); 1277 break; 1278 } 1279 case GetOperandN: 1280 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n"); 1281 executeGetOperand(read<uint32_t>()); 1282 break; 1283 case GetResult0: 1284 case GetResult1: 1285 case GetResult2: 1286 case GetResult3: { 1287 unsigned index = opCode - GetResult0; 1288 LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n"); 1289 executeGetResult(index); 1290 break; 1291 } 1292 case GetResultN: 1293 LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n"); 1294 executeGetResult(read<uint32_t>()); 1295 break; 1296 case GetValueType: 1297 executeGetValueType(); 1298 break; 1299 case IsNotNull: 1300 executeIsNotNull(); 1301 break; 1302 case RecordMatch: 1303 assert(matches && 1304 "expected matches to be provided when executing the matcher"); 1305 executeRecordMatch(rewriter, *matches); 1306 break; 1307 case ReplaceOp: 1308 executeReplaceOp(rewriter); 1309 break; 1310 case SwitchAttribute: 1311 executeSwitchAttribute(); 1312 break; 1313 case SwitchOperandCount: 1314 executeSwitchOperandCount(); 1315 break; 1316 case SwitchOperationName: 1317 executeSwitchOperationName(); 1318 break; 1319 case SwitchResultCount: 1320 executeSwitchResultCount(); 1321 break; 1322 case SwitchType: 1323 executeSwitchType(); 1324 break; 1325 } 1326 LLVM_DEBUG(llvm::dbgs() << "\n"); 1327 } 1328 } 1329 1330 /// Run the pattern matcher on the given root operation, collecting the matched 1331 /// patterns in `matches`. 1332 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, 1333 SmallVectorImpl<MatchResult> &matches, 1334 PDLByteCodeMutableState &state) const { 1335 // The first memory slot is always the root operation. 1336 state.memory[0] = op; 1337 1338 // The matcher function always starts at code address 0. 1339 ByteCodeExecutor executor(matcherByteCode.data(), state.memory, uniquedData, 1340 matcherByteCode, state.currentPatternBenefits, 1341 patterns, constraintFunctions, createFunctions, 1342 rewriteFunctions); 1343 executor.execute(rewriter, &matches); 1344 1345 // Order the found matches by benefit. 1346 std::stable_sort(matches.begin(), matches.end(), 1347 [](const MatchResult &lhs, const MatchResult &rhs) { 1348 return lhs.benefit > rhs.benefit; 1349 }); 1350 } 1351 1352 /// Run the rewriter of the given pattern on the root operation `op`. 1353 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, 1354 PDLByteCodeMutableState &state) const { 1355 // The arguments of the rewrite function are stored at the start of the 1356 // memory buffer. 1357 llvm::copy(match.values, state.memory.begin()); 1358 1359 ByteCodeExecutor executor( 1360 &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, 1361 uniquedData, rewriterByteCode, state.currentPatternBenefits, patterns, 1362 constraintFunctions, createFunctions, rewriteFunctions); 1363 executor.execute(rewriter, /*matches=*/nullptr, match.location); 1364 } 1365