1abfd1a8bSRiver Riddle //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===// 2abfd1a8bSRiver Riddle // 3abfd1a8bSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4abfd1a8bSRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 5abfd1a8bSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6abfd1a8bSRiver Riddle // 7abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 8abfd1a8bSRiver Riddle // 9abfd1a8bSRiver Riddle // This file implements MLIR to byte-code generation and the interpreter. 10abfd1a8bSRiver Riddle // 11abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 12abfd1a8bSRiver Riddle 13abfd1a8bSRiver Riddle #include "ByteCode.h" 14abfd1a8bSRiver Riddle #include "mlir/Analysis/Liveness.h" 15abfd1a8bSRiver Riddle #include "mlir/Dialect/PDL/IR/PDLTypes.h" 16abfd1a8bSRiver Riddle #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 17e66c2e25SRiver Riddle #include "mlir/IR/BuiltinOps.h" 18abfd1a8bSRiver Riddle #include "mlir/IR/RegionGraphTraits.h" 19abfd1a8bSRiver Riddle #include "llvm/ADT/IntervalMap.h" 20abfd1a8bSRiver Riddle #include "llvm/ADT/PostOrderIterator.h" 21abfd1a8bSRiver Riddle #include "llvm/ADT/TypeSwitch.h" 22abfd1a8bSRiver Riddle #include "llvm/Support/Debug.h" 2385ab413bSRiver Riddle #include "llvm/Support/Format.h" 2485ab413bSRiver Riddle #include "llvm/Support/FormatVariadic.h" 2585ab413bSRiver Riddle #include <numeric> 26abfd1a8bSRiver Riddle 27abfd1a8bSRiver Riddle #define DEBUG_TYPE "pdl-bytecode" 28abfd1a8bSRiver Riddle 29abfd1a8bSRiver Riddle using namespace mlir; 30abfd1a8bSRiver Riddle using namespace mlir::detail; 31abfd1a8bSRiver Riddle 32abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 33abfd1a8bSRiver Riddle // PDLByteCodePattern 34abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 35abfd1a8bSRiver Riddle 36abfd1a8bSRiver Riddle PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp, 37abfd1a8bSRiver Riddle ByteCodeAddr rewriterAddr) { 38abfd1a8bSRiver Riddle SmallVector<StringRef, 8> generatedOps; 39abfd1a8bSRiver Riddle if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr()) 40abfd1a8bSRiver Riddle generatedOps = 41abfd1a8bSRiver Riddle llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>()); 42abfd1a8bSRiver Riddle 43abfd1a8bSRiver Riddle PatternBenefit benefit = matchOp.benefit(); 44abfd1a8bSRiver Riddle MLIRContext *ctx = matchOp.getContext(); 45abfd1a8bSRiver Riddle 46abfd1a8bSRiver Riddle // Check to see if this is pattern matches a specific operation type. 47abfd1a8bSRiver Riddle if (Optional<StringRef> rootKind = matchOp.rootKind()) 4876f3c2f3SRiver Riddle return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx, 4976f3c2f3SRiver Riddle generatedOps); 5076f3c2f3SRiver Riddle return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx, 5176f3c2f3SRiver Riddle generatedOps); 52abfd1a8bSRiver Riddle } 53abfd1a8bSRiver Riddle 54abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 55abfd1a8bSRiver Riddle // PDLByteCodeMutableState 56abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 57abfd1a8bSRiver Riddle 58abfd1a8bSRiver Riddle /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds 59abfd1a8bSRiver Riddle /// to the position of the pattern within the range returned by 60abfd1a8bSRiver Riddle /// `PDLByteCode::getPatterns`. 61abfd1a8bSRiver Riddle void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex, 62abfd1a8bSRiver Riddle PatternBenefit benefit) { 63abfd1a8bSRiver Riddle currentPatternBenefits[patternIndex] = benefit; 64abfd1a8bSRiver Riddle } 65abfd1a8bSRiver Riddle 6685ab413bSRiver Riddle /// Cleanup any allocated state after a full match/rewrite has been completed. 6785ab413bSRiver Riddle /// This method should be called irregardless of whether the match+rewrite was a 6885ab413bSRiver Riddle /// success or not. 6985ab413bSRiver Riddle void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() { 7085ab413bSRiver Riddle allocatedTypeRangeMemory.clear(); 7185ab413bSRiver Riddle allocatedValueRangeMemory.clear(); 7285ab413bSRiver Riddle } 7385ab413bSRiver Riddle 74abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 75abfd1a8bSRiver Riddle // Bytecode OpCodes 76abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 77abfd1a8bSRiver Riddle 78abfd1a8bSRiver Riddle namespace { 79abfd1a8bSRiver Riddle enum OpCode : ByteCodeField { 80abfd1a8bSRiver Riddle /// Apply an externally registered constraint. 81abfd1a8bSRiver Riddle ApplyConstraint, 82abfd1a8bSRiver Riddle /// Apply an externally registered rewrite. 83abfd1a8bSRiver Riddle ApplyRewrite, 84abfd1a8bSRiver Riddle /// Check if two generic values are equal. 85abfd1a8bSRiver Riddle AreEqual, 8685ab413bSRiver Riddle /// Check if two ranges are equal. 8785ab413bSRiver Riddle AreRangesEqual, 88abfd1a8bSRiver Riddle /// Unconditional branch. 89abfd1a8bSRiver Riddle Branch, 90abfd1a8bSRiver Riddle /// Compare the operand count of an operation with a constant. 91abfd1a8bSRiver Riddle CheckOperandCount, 92abfd1a8bSRiver Riddle /// Compare the name of an operation with a constant. 93abfd1a8bSRiver Riddle CheckOperationName, 94abfd1a8bSRiver Riddle /// Compare the result count of an operation with a constant. 95abfd1a8bSRiver Riddle CheckResultCount, 9685ab413bSRiver Riddle /// Compare a range of types to a constant range of types. 9785ab413bSRiver Riddle CheckTypes, 983eb1647aSStanislav Funiak /// Continue to the next iteration of a loop. 993eb1647aSStanislav Funiak Continue, 100abfd1a8bSRiver Riddle /// Create an operation. 101abfd1a8bSRiver Riddle CreateOperation, 10285ab413bSRiver Riddle /// Create a range of types. 10385ab413bSRiver Riddle CreateTypes, 104abfd1a8bSRiver Riddle /// Erase an operation. 105abfd1a8bSRiver Riddle EraseOp, 1063eb1647aSStanislav Funiak /// Extract the op from a range at the specified index. 1073eb1647aSStanislav Funiak ExtractOp, 1083eb1647aSStanislav Funiak /// Extract the type from a range at the specified index. 1093eb1647aSStanislav Funiak ExtractType, 1103eb1647aSStanislav Funiak /// Extract the value from a range at the specified index. 1113eb1647aSStanislav Funiak ExtractValue, 112abfd1a8bSRiver Riddle /// Terminate a matcher or rewrite sequence. 113abfd1a8bSRiver Riddle Finalize, 1143eb1647aSStanislav Funiak /// Iterate over a range of values. 1153eb1647aSStanislav Funiak ForEach, 116abfd1a8bSRiver Riddle /// Get a specific attribute of an operation. 117abfd1a8bSRiver Riddle GetAttribute, 118abfd1a8bSRiver Riddle /// Get the type of an attribute. 119abfd1a8bSRiver Riddle GetAttributeType, 120abfd1a8bSRiver Riddle /// Get the defining operation of a value. 121abfd1a8bSRiver Riddle GetDefiningOp, 122abfd1a8bSRiver Riddle /// Get a specific operand of an operation. 123abfd1a8bSRiver Riddle GetOperand0, 124abfd1a8bSRiver Riddle GetOperand1, 125abfd1a8bSRiver Riddle GetOperand2, 126abfd1a8bSRiver Riddle GetOperand3, 127abfd1a8bSRiver Riddle GetOperandN, 12885ab413bSRiver Riddle /// Get a specific operand group of an operation. 12985ab413bSRiver Riddle GetOperands, 130abfd1a8bSRiver Riddle /// Get a specific result of an operation. 131abfd1a8bSRiver Riddle GetResult0, 132abfd1a8bSRiver Riddle GetResult1, 133abfd1a8bSRiver Riddle GetResult2, 134abfd1a8bSRiver Riddle GetResult3, 135abfd1a8bSRiver Riddle GetResultN, 13685ab413bSRiver Riddle /// Get a specific result group of an operation. 13785ab413bSRiver Riddle GetResults, 1383eb1647aSStanislav Funiak /// Get the users of a value or a range of values. 1393eb1647aSStanislav Funiak GetUsers, 140abfd1a8bSRiver Riddle /// Get the type of a value. 141abfd1a8bSRiver Riddle GetValueType, 14285ab413bSRiver Riddle /// Get the types of a value range. 14385ab413bSRiver Riddle GetValueRangeTypes, 144abfd1a8bSRiver Riddle /// Check if a generic value is not null. 145abfd1a8bSRiver Riddle IsNotNull, 146abfd1a8bSRiver Riddle /// Record a successful pattern match. 147abfd1a8bSRiver Riddle RecordMatch, 148abfd1a8bSRiver Riddle /// Replace an operation. 149abfd1a8bSRiver Riddle ReplaceOp, 150abfd1a8bSRiver Riddle /// Compare an attribute with a set of constants. 151abfd1a8bSRiver Riddle SwitchAttribute, 152abfd1a8bSRiver Riddle /// Compare the operand count of an operation with a set of constants. 153abfd1a8bSRiver Riddle SwitchOperandCount, 154abfd1a8bSRiver Riddle /// Compare the name of an operation with a set of constants. 155abfd1a8bSRiver Riddle SwitchOperationName, 156abfd1a8bSRiver Riddle /// Compare the result count of an operation with a set of constants. 157abfd1a8bSRiver Riddle SwitchResultCount, 158abfd1a8bSRiver Riddle /// Compare a type with a set of constants. 159abfd1a8bSRiver Riddle SwitchType, 16085ab413bSRiver Riddle /// Compare a range of types with a set of constants. 16185ab413bSRiver Riddle SwitchTypes, 162abfd1a8bSRiver Riddle }; 163be0a7e9fSMehdi Amini } // namespace 164abfd1a8bSRiver Riddle 165abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 166abfd1a8bSRiver Riddle // ByteCode Generation 167abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 168abfd1a8bSRiver Riddle 169abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 170abfd1a8bSRiver Riddle // Generator 171abfd1a8bSRiver Riddle 172abfd1a8bSRiver Riddle namespace { 1733eb1647aSStanislav Funiak struct ByteCodeLiveRange; 174abfd1a8bSRiver Riddle struct ByteCodeWriter; 175abfd1a8bSRiver Riddle 1763eb1647aSStanislav Funiak /// Check if the given class `T` can be converted to an opaque pointer. 1773eb1647aSStanislav Funiak template <typename T, typename... Args> 1783eb1647aSStanislav Funiak using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer()); 1793eb1647aSStanislav Funiak 180abfd1a8bSRiver Riddle /// This class represents the main generator for the pattern bytecode. 181abfd1a8bSRiver Riddle class Generator { 182abfd1a8bSRiver Riddle public: 183abfd1a8bSRiver Riddle Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData, 184abfd1a8bSRiver Riddle SmallVectorImpl<ByteCodeField> &matcherByteCode, 185abfd1a8bSRiver Riddle SmallVectorImpl<ByteCodeField> &rewriterByteCode, 186abfd1a8bSRiver Riddle SmallVectorImpl<PDLByteCodePattern> &patterns, 187abfd1a8bSRiver Riddle ByteCodeField &maxValueMemoryIndex, 1883eb1647aSStanislav Funiak ByteCodeField &maxOpRangeMemoryIndex, 18985ab413bSRiver Riddle ByteCodeField &maxTypeRangeMemoryIndex, 19085ab413bSRiver Riddle ByteCodeField &maxValueRangeMemoryIndex, 1913eb1647aSStanislav Funiak ByteCodeField &maxLoopLevel, 192abfd1a8bSRiver Riddle llvm::StringMap<PDLConstraintFunction> &constraintFns, 193abfd1a8bSRiver Riddle llvm::StringMap<PDLRewriteFunction> &rewriteFns) 194abfd1a8bSRiver Riddle : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), 195abfd1a8bSRiver Riddle rewriterByteCode(rewriterByteCode), patterns(patterns), 19685ab413bSRiver Riddle maxValueMemoryIndex(maxValueMemoryIndex), 1973eb1647aSStanislav Funiak maxOpRangeMemoryIndex(maxOpRangeMemoryIndex), 19885ab413bSRiver Riddle maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex), 1993eb1647aSStanislav Funiak maxValueRangeMemoryIndex(maxValueRangeMemoryIndex), 2003eb1647aSStanislav Funiak maxLoopLevel(maxLoopLevel) { 201e4853be2SMehdi Amini for (const auto &it : llvm::enumerate(constraintFns)) 202abfd1a8bSRiver Riddle constraintToMemIndex.try_emplace(it.value().first(), it.index()); 203e4853be2SMehdi Amini for (const auto &it : llvm::enumerate(rewriteFns)) 204abfd1a8bSRiver Riddle externalRewriterToMemIndex.try_emplace(it.value().first(), it.index()); 205abfd1a8bSRiver Riddle } 206abfd1a8bSRiver Riddle 207abfd1a8bSRiver Riddle /// Generate the bytecode for the given PDL interpreter module. 208abfd1a8bSRiver Riddle void generate(ModuleOp module); 209abfd1a8bSRiver Riddle 210abfd1a8bSRiver Riddle /// Return the memory index to use for the given value. 211abfd1a8bSRiver Riddle ByteCodeField &getMemIndex(Value value) { 212abfd1a8bSRiver Riddle assert(valueToMemIndex.count(value) && 213abfd1a8bSRiver Riddle "expected memory index to be assigned"); 214abfd1a8bSRiver Riddle return valueToMemIndex[value]; 215abfd1a8bSRiver Riddle } 216abfd1a8bSRiver Riddle 21785ab413bSRiver Riddle /// Return the range memory index used to store the given range value. 21885ab413bSRiver Riddle ByteCodeField &getRangeStorageIndex(Value value) { 21985ab413bSRiver Riddle assert(valueToRangeIndex.count(value) && 22085ab413bSRiver Riddle "expected range index to be assigned"); 22185ab413bSRiver Riddle return valueToRangeIndex[value]; 22285ab413bSRiver Riddle } 22385ab413bSRiver Riddle 224abfd1a8bSRiver Riddle /// Return an index to use when referring to the given data that is uniqued in 225abfd1a8bSRiver Riddle /// the MLIR context. 226abfd1a8bSRiver Riddle template <typename T> 227abfd1a8bSRiver Riddle std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &> 228abfd1a8bSRiver Riddle getMemIndex(T val) { 229abfd1a8bSRiver Riddle const void *opaqueVal = val.getAsOpaquePointer(); 230abfd1a8bSRiver Riddle 231abfd1a8bSRiver Riddle // Get or insert a reference to this value. 232abfd1a8bSRiver Riddle auto it = uniquedDataToMemIndex.try_emplace( 233abfd1a8bSRiver Riddle opaqueVal, maxValueMemoryIndex + uniquedData.size()); 234abfd1a8bSRiver Riddle if (it.second) 235abfd1a8bSRiver Riddle uniquedData.push_back(opaqueVal); 236abfd1a8bSRiver Riddle return it.first->second; 237abfd1a8bSRiver Riddle } 238abfd1a8bSRiver Riddle 239abfd1a8bSRiver Riddle private: 240abfd1a8bSRiver Riddle /// Allocate memory indices for the results of operations within the matcher 241abfd1a8bSRiver Riddle /// and rewriters. 242abfd1a8bSRiver Riddle void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule); 243abfd1a8bSRiver Riddle 244abfd1a8bSRiver Riddle /// Generate the bytecode for the given operation. 2453eb1647aSStanislav Funiak void generate(Region *region, ByteCodeWriter &writer); 246abfd1a8bSRiver Riddle void generate(Operation *op, ByteCodeWriter &writer); 247abfd1a8bSRiver Riddle void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer); 248abfd1a8bSRiver Riddle void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer); 249abfd1a8bSRiver Riddle void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer); 250abfd1a8bSRiver Riddle void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer); 251abfd1a8bSRiver Riddle void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer); 252abfd1a8bSRiver Riddle void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer); 253abfd1a8bSRiver Riddle void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer); 254abfd1a8bSRiver Riddle void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); 255abfd1a8bSRiver Riddle void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); 25685ab413bSRiver Riddle void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer); 2573eb1647aSStanislav Funiak void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer); 258abfd1a8bSRiver Riddle void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); 259abfd1a8bSRiver Riddle void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); 260abfd1a8bSRiver Riddle void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); 26185ab413bSRiver Riddle void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer); 262abfd1a8bSRiver Riddle void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); 2633eb1647aSStanislav Funiak void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer); 264abfd1a8bSRiver Riddle void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); 2653eb1647aSStanislav Funiak void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer); 266abfd1a8bSRiver Riddle void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); 267abfd1a8bSRiver Riddle void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); 268abfd1a8bSRiver Riddle void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); 269abfd1a8bSRiver Riddle void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer); 27085ab413bSRiver Riddle void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer); 271abfd1a8bSRiver Riddle void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); 27285ab413bSRiver Riddle void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer); 2733eb1647aSStanislav Funiak void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer); 274abfd1a8bSRiver Riddle void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); 2753a833a0eSRiver Riddle void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer); 276abfd1a8bSRiver Riddle void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); 277abfd1a8bSRiver Riddle void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer); 278abfd1a8bSRiver Riddle void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); 279abfd1a8bSRiver Riddle void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer); 280abfd1a8bSRiver Riddle void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer); 28185ab413bSRiver Riddle void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer); 282abfd1a8bSRiver Riddle void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer); 283abfd1a8bSRiver Riddle void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer); 284abfd1a8bSRiver Riddle void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer); 285abfd1a8bSRiver Riddle 286abfd1a8bSRiver Riddle /// Mapping from value to its corresponding memory index. 287abfd1a8bSRiver Riddle DenseMap<Value, ByteCodeField> valueToMemIndex; 288abfd1a8bSRiver Riddle 28985ab413bSRiver Riddle /// Mapping from a range value to its corresponding range storage index. 29085ab413bSRiver Riddle DenseMap<Value, ByteCodeField> valueToRangeIndex; 29185ab413bSRiver Riddle 292abfd1a8bSRiver Riddle /// Mapping from the name of an externally registered rewrite to its index in 293abfd1a8bSRiver Riddle /// the bytecode registry. 294abfd1a8bSRiver Riddle llvm::StringMap<ByteCodeField> externalRewriterToMemIndex; 295abfd1a8bSRiver Riddle 296abfd1a8bSRiver Riddle /// Mapping from the name of an externally registered constraint to its index 297abfd1a8bSRiver Riddle /// in the bytecode registry. 298abfd1a8bSRiver Riddle llvm::StringMap<ByteCodeField> constraintToMemIndex; 299abfd1a8bSRiver Riddle 300abfd1a8bSRiver Riddle /// Mapping from rewriter function name to the bytecode address of the 301abfd1a8bSRiver Riddle /// rewriter function in byte. 302abfd1a8bSRiver Riddle llvm::StringMap<ByteCodeAddr> rewriterToAddr; 303abfd1a8bSRiver Riddle 304abfd1a8bSRiver Riddle /// Mapping from a uniqued storage object to its memory index within 305abfd1a8bSRiver Riddle /// `uniquedData`. 306abfd1a8bSRiver Riddle DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex; 307abfd1a8bSRiver Riddle 3083eb1647aSStanislav Funiak /// The current level of the foreach loop. 3093eb1647aSStanislav Funiak ByteCodeField curLoopLevel = 0; 3103eb1647aSStanislav Funiak 311abfd1a8bSRiver Riddle /// The current MLIR context. 312abfd1a8bSRiver Riddle MLIRContext *ctx; 313abfd1a8bSRiver Riddle 3143eb1647aSStanislav Funiak /// Mapping from block to its address. 3153eb1647aSStanislav Funiak DenseMap<Block *, ByteCodeAddr> blockToAddr; 3163eb1647aSStanislav Funiak 317abfd1a8bSRiver Riddle /// Data of the ByteCode class to be populated. 318abfd1a8bSRiver Riddle std::vector<const void *> &uniquedData; 319abfd1a8bSRiver Riddle SmallVectorImpl<ByteCodeField> &matcherByteCode; 320abfd1a8bSRiver Riddle SmallVectorImpl<ByteCodeField> &rewriterByteCode; 321abfd1a8bSRiver Riddle SmallVectorImpl<PDLByteCodePattern> &patterns; 322abfd1a8bSRiver Riddle ByteCodeField &maxValueMemoryIndex; 3233eb1647aSStanislav Funiak ByteCodeField &maxOpRangeMemoryIndex; 32485ab413bSRiver Riddle ByteCodeField &maxTypeRangeMemoryIndex; 32585ab413bSRiver Riddle ByteCodeField &maxValueRangeMemoryIndex; 3263eb1647aSStanislav Funiak ByteCodeField &maxLoopLevel; 327abfd1a8bSRiver Riddle }; 328abfd1a8bSRiver Riddle 329abfd1a8bSRiver Riddle /// This class provides utilities for writing a bytecode stream. 330abfd1a8bSRiver Riddle struct ByteCodeWriter { 331abfd1a8bSRiver Riddle ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator) 332abfd1a8bSRiver Riddle : bytecode(bytecode), generator(generator) {} 333abfd1a8bSRiver Riddle 334abfd1a8bSRiver Riddle /// Append a field to the bytecode. 335abfd1a8bSRiver Riddle void append(ByteCodeField field) { bytecode.push_back(field); } 336fa20ab7bSRiver Riddle void append(OpCode opCode) { bytecode.push_back(opCode); } 337abfd1a8bSRiver Riddle 338abfd1a8bSRiver Riddle /// Append an address to the bytecode. 339abfd1a8bSRiver Riddle void append(ByteCodeAddr field) { 340abfd1a8bSRiver Riddle static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 341abfd1a8bSRiver Riddle "unexpected ByteCode address size"); 342abfd1a8bSRiver Riddle 343abfd1a8bSRiver Riddle ByteCodeField fieldParts[2]; 344abfd1a8bSRiver Riddle std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr)); 345abfd1a8bSRiver Riddle bytecode.append({fieldParts[0], fieldParts[1]}); 346abfd1a8bSRiver Riddle } 347abfd1a8bSRiver Riddle 3483eb1647aSStanislav Funiak /// Append a single successor to the bytecode, the exact address will need to 349abfd1a8bSRiver Riddle /// be resolved later. 3503eb1647aSStanislav Funiak void append(Block *successor) { 3513eb1647aSStanislav Funiak // Add back a reference to the successor so that the address can be resolved 3523eb1647aSStanislav Funiak // later. 353abfd1a8bSRiver Riddle unresolvedSuccessorRefs[successor].push_back(bytecode.size()); 354abfd1a8bSRiver Riddle append(ByteCodeAddr(0)); 355abfd1a8bSRiver Riddle } 3563eb1647aSStanislav Funiak 3573eb1647aSStanislav Funiak /// Append a successor range to the bytecode, the exact address will need to 3583eb1647aSStanislav Funiak /// be resolved later. 3593eb1647aSStanislav Funiak void append(SuccessorRange successors) { 3603eb1647aSStanislav Funiak for (Block *successor : successors) 3613eb1647aSStanislav Funiak append(successor); 362abfd1a8bSRiver Riddle } 363abfd1a8bSRiver Riddle 364abfd1a8bSRiver Riddle /// Append a range of values that will be read as generic PDLValues. 365abfd1a8bSRiver Riddle void appendPDLValueList(OperandRange values) { 366abfd1a8bSRiver Riddle bytecode.push_back(values.size()); 36785ab413bSRiver Riddle for (Value value : values) 36885ab413bSRiver Riddle appendPDLValue(value); 36985ab413bSRiver Riddle } 37085ab413bSRiver Riddle 37185ab413bSRiver Riddle /// Append a value as a PDLValue. 37285ab413bSRiver Riddle void appendPDLValue(Value value) { 37385ab413bSRiver Riddle appendPDLValueKind(value); 374abfd1a8bSRiver Riddle append(value); 375abfd1a8bSRiver Riddle } 37685ab413bSRiver Riddle 37785ab413bSRiver Riddle /// Append the PDLValue::Kind of the given value. 3783eb1647aSStanislav Funiak void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); } 3793eb1647aSStanislav Funiak 3803eb1647aSStanislav Funiak /// Append the PDLValue::Kind of the given type. 3813eb1647aSStanislav Funiak void appendPDLValueKind(Type type) { 38285ab413bSRiver Riddle PDLValue::Kind kind = 3833eb1647aSStanislav Funiak TypeSwitch<Type, PDLValue::Kind>(type) 38485ab413bSRiver Riddle .Case<pdl::AttributeType>( 38585ab413bSRiver Riddle [](Type) { return PDLValue::Kind::Attribute; }) 38685ab413bSRiver Riddle .Case<pdl::OperationType>( 38785ab413bSRiver Riddle [](Type) { return PDLValue::Kind::Operation; }) 38885ab413bSRiver Riddle .Case<pdl::RangeType>([](pdl::RangeType rangeTy) { 38985ab413bSRiver Riddle if (rangeTy.getElementType().isa<pdl::TypeType>()) 39085ab413bSRiver Riddle return PDLValue::Kind::TypeRange; 39185ab413bSRiver Riddle return PDLValue::Kind::ValueRange; 39285ab413bSRiver Riddle }) 39385ab413bSRiver Riddle .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; }) 39485ab413bSRiver Riddle .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; }); 39585ab413bSRiver Riddle bytecode.push_back(static_cast<ByteCodeField>(kind)); 396abfd1a8bSRiver Riddle } 397abfd1a8bSRiver Riddle 398abfd1a8bSRiver Riddle /// Append a value that will be stored in a memory slot and not inline within 399abfd1a8bSRiver Riddle /// the bytecode. 400abfd1a8bSRiver Riddle template <typename T> 401abfd1a8bSRiver Riddle std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value || 402abfd1a8bSRiver Riddle std::is_pointer<T>::value> 403abfd1a8bSRiver Riddle append(T value) { 404abfd1a8bSRiver Riddle bytecode.push_back(generator.getMemIndex(value)); 405abfd1a8bSRiver Riddle } 406abfd1a8bSRiver Riddle 407abfd1a8bSRiver Riddle /// Append a range of values. 408abfd1a8bSRiver Riddle template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>> 409abfd1a8bSRiver Riddle std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value> 410abfd1a8bSRiver Riddle append(T range) { 411abfd1a8bSRiver Riddle bytecode.push_back(llvm::size(range)); 412abfd1a8bSRiver Riddle for (auto it : range) 413abfd1a8bSRiver Riddle append(it); 414abfd1a8bSRiver Riddle } 415abfd1a8bSRiver Riddle 416abfd1a8bSRiver Riddle /// Append a variadic number of fields to the bytecode. 417abfd1a8bSRiver Riddle template <typename FieldTy, typename Field2Ty, typename... FieldTys> 418abfd1a8bSRiver Riddle void append(FieldTy field, Field2Ty field2, FieldTys... fields) { 419abfd1a8bSRiver Riddle append(field); 420abfd1a8bSRiver Riddle append(field2, fields...); 421abfd1a8bSRiver Riddle } 422abfd1a8bSRiver Riddle 423d35f1190SStanislav Funiak /// Appends a value as a pointer, stored inline within the bytecode. 424d35f1190SStanislav Funiak template <typename T> 425d35f1190SStanislav Funiak std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value> 426d35f1190SStanislav Funiak appendInline(T value) { 427d35f1190SStanislav Funiak constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField); 428d35f1190SStanislav Funiak const void *pointer = value.getAsOpaquePointer(); 429d35f1190SStanislav Funiak ByteCodeField fieldParts[numParts]; 430d35f1190SStanislav Funiak std::memcpy(fieldParts, &pointer, sizeof(const void *)); 431d35f1190SStanislav Funiak bytecode.append(fieldParts, fieldParts + numParts); 432d35f1190SStanislav Funiak } 433d35f1190SStanislav Funiak 434abfd1a8bSRiver Riddle /// Successor references in the bytecode that have yet to be resolved. 435abfd1a8bSRiver Riddle DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs; 436abfd1a8bSRiver Riddle 437abfd1a8bSRiver Riddle /// The underlying bytecode buffer. 438abfd1a8bSRiver Riddle SmallVectorImpl<ByteCodeField> &bytecode; 439abfd1a8bSRiver Riddle 440abfd1a8bSRiver Riddle /// The main generator producing PDL. 441abfd1a8bSRiver Riddle Generator &generator; 442abfd1a8bSRiver Riddle }; 44385ab413bSRiver Riddle 44485ab413bSRiver Riddle /// This class represents a live range of PDL Interpreter values, containing 44585ab413bSRiver Riddle /// information about when values are live within a match/rewrite. 44685ab413bSRiver Riddle struct ByteCodeLiveRange { 4473eb1647aSStanislav Funiak using Set = llvm::IntervalMap<uint64_t, char, 16>; 44885ab413bSRiver Riddle using Allocator = Set::Allocator; 44985ab413bSRiver Riddle 4503eb1647aSStanislav Funiak ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {} 45185ab413bSRiver Riddle 45285ab413bSRiver Riddle /// Union this live range with the one provided. 45385ab413bSRiver Riddle void unionWith(const ByteCodeLiveRange &rhs) { 4543eb1647aSStanislav Funiak for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e; 4553eb1647aSStanislav Funiak ++it) 4563eb1647aSStanislav Funiak liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0); 45785ab413bSRiver Riddle } 45885ab413bSRiver Riddle 45985ab413bSRiver Riddle /// Returns true if this range overlaps with the one provided. 46085ab413bSRiver Riddle bool overlaps(const ByteCodeLiveRange &rhs) const { 4613eb1647aSStanislav Funiak return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness) 4623eb1647aSStanislav Funiak .valid(); 46385ab413bSRiver Riddle } 46485ab413bSRiver Riddle 46585ab413bSRiver Riddle /// A map representing the ranges of the match/rewrite that a value is live in 46685ab413bSRiver Riddle /// the interpreter. 4673eb1647aSStanislav Funiak /// 4683eb1647aSStanislav Funiak /// We use std::unique_ptr here, because IntervalMap does not provide a 4693eb1647aSStanislav Funiak /// correct copy or move constructor. We can eliminate the pointer once 4703eb1647aSStanislav Funiak /// https://reviews.llvm.org/D113240 lands. 4713eb1647aSStanislav Funiak std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness; 4723eb1647aSStanislav Funiak 4733eb1647aSStanislav Funiak /// The operation range storage index for this range. 4743eb1647aSStanislav Funiak Optional<unsigned> opRangeIndex; 47585ab413bSRiver Riddle 47685ab413bSRiver Riddle /// The type range storage index for this range. 47785ab413bSRiver Riddle Optional<unsigned> typeRangeIndex; 47885ab413bSRiver Riddle 47985ab413bSRiver Riddle /// The value range storage index for this range. 48085ab413bSRiver Riddle Optional<unsigned> valueRangeIndex; 48185ab413bSRiver Riddle }; 482be0a7e9fSMehdi Amini } // namespace 483abfd1a8bSRiver Riddle 484abfd1a8bSRiver Riddle void Generator::generate(ModuleOp module) { 485abfd1a8bSRiver Riddle FuncOp matcherFunc = module.lookupSymbol<FuncOp>( 486abfd1a8bSRiver Riddle pdl_interp::PDLInterpDialect::getMatcherFunctionName()); 487abfd1a8bSRiver Riddle ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>( 488abfd1a8bSRiver Riddle pdl_interp::PDLInterpDialect::getRewriterModuleName()); 489abfd1a8bSRiver Riddle assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module"); 490abfd1a8bSRiver Riddle 491abfd1a8bSRiver Riddle // Allocate memory indices for the results of operations within the matcher 492abfd1a8bSRiver Riddle // and rewriters. 493abfd1a8bSRiver Riddle allocateMemoryIndices(matcherFunc, rewriterModule); 494abfd1a8bSRiver Riddle 495abfd1a8bSRiver Riddle // Generate code for the rewriter functions. 496abfd1a8bSRiver Riddle ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this); 497abfd1a8bSRiver Riddle for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) { 498abfd1a8bSRiver Riddle rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size()); 499abfd1a8bSRiver Riddle for (Operation &op : rewriterFunc.getOps()) 500abfd1a8bSRiver Riddle generate(&op, rewriterByteCodeWriter); 501abfd1a8bSRiver Riddle } 502abfd1a8bSRiver Riddle assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() && 503abfd1a8bSRiver Riddle "unexpected branches in rewriter function"); 504abfd1a8bSRiver Riddle 505abfd1a8bSRiver Riddle // Generate code for the matcher function. 506abfd1a8bSRiver Riddle ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this); 5073eb1647aSStanislav Funiak generate(&matcherFunc.getBody(), matcherByteCodeWriter); 508abfd1a8bSRiver Riddle 509abfd1a8bSRiver Riddle // Resolve successor references in the matcher. 510abfd1a8bSRiver Riddle for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) { 511abfd1a8bSRiver Riddle ByteCodeAddr addr = blockToAddr[it.first]; 512abfd1a8bSRiver Riddle for (unsigned offsetToFix : it.second) 513abfd1a8bSRiver Riddle std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr)); 514abfd1a8bSRiver Riddle } 515abfd1a8bSRiver Riddle } 516abfd1a8bSRiver Riddle 517abfd1a8bSRiver Riddle void Generator::allocateMemoryIndices(FuncOp matcherFunc, 518abfd1a8bSRiver Riddle ModuleOp rewriterModule) { 519abfd1a8bSRiver Riddle // Rewriters use simplistic allocation scheme that simply assigns an index to 520abfd1a8bSRiver Riddle // each result. 521abfd1a8bSRiver Riddle for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) { 52285ab413bSRiver Riddle ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0; 52385ab413bSRiver Riddle auto processRewriterValue = [&](Value val) { 52485ab413bSRiver Riddle valueToMemIndex.try_emplace(val, index++); 52585ab413bSRiver Riddle if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) { 52685ab413bSRiver Riddle Type elementTy = rangeType.getElementType(); 52785ab413bSRiver Riddle if (elementTy.isa<pdl::TypeType>()) 52885ab413bSRiver Riddle valueToRangeIndex.try_emplace(val, typeRangeIndex++); 52985ab413bSRiver Riddle else if (elementTy.isa<pdl::ValueType>()) 53085ab413bSRiver Riddle valueToRangeIndex.try_emplace(val, valueRangeIndex++); 53185ab413bSRiver Riddle } 53285ab413bSRiver Riddle }; 53385ab413bSRiver Riddle 534abfd1a8bSRiver Riddle for (BlockArgument arg : rewriterFunc.getArguments()) 53585ab413bSRiver Riddle processRewriterValue(arg); 536abfd1a8bSRiver Riddle rewriterFunc.getBody().walk([&](Operation *op) { 537abfd1a8bSRiver Riddle for (Value result : op->getResults()) 53885ab413bSRiver Riddle processRewriterValue(result); 539abfd1a8bSRiver Riddle }); 540abfd1a8bSRiver Riddle if (index > maxValueMemoryIndex) 541abfd1a8bSRiver Riddle maxValueMemoryIndex = index; 54285ab413bSRiver Riddle if (typeRangeIndex > maxTypeRangeMemoryIndex) 54385ab413bSRiver Riddle maxTypeRangeMemoryIndex = typeRangeIndex; 54485ab413bSRiver Riddle if (valueRangeIndex > maxValueRangeMemoryIndex) 54585ab413bSRiver Riddle maxValueRangeMemoryIndex = valueRangeIndex; 546abfd1a8bSRiver Riddle } 547abfd1a8bSRiver Riddle 548abfd1a8bSRiver Riddle // The matcher function uses a more sophisticated numbering that tries to 549abfd1a8bSRiver Riddle // minimize the number of memory indices assigned. This is done by determining 550abfd1a8bSRiver Riddle // a live range of the values within the matcher, then the allocation is just 551abfd1a8bSRiver Riddle // finding the minimal number of overlapping live ranges. This is essentially 552abfd1a8bSRiver Riddle // a simplified form of register allocation where we don't necessarily have a 553abfd1a8bSRiver Riddle // limited number of registers, but we still want to minimize the number used. 554b4130e9eSStanislav Funiak DenseMap<Operation *, unsigned> opToFirstIndex; 555b4130e9eSStanislav Funiak DenseMap<Operation *, unsigned> opToLastIndex; 556b4130e9eSStanislav Funiak 557b4130e9eSStanislav Funiak // A custom walk that marks the first and the last index of each operation. 558b4130e9eSStanislav Funiak // The entry marks the beginning of the liveness range for this operation, 559b4130e9eSStanislav Funiak // followed by nested operations, followed by the end of the liveness range. 560b4130e9eSStanislav Funiak unsigned index = 0; 561b4130e9eSStanislav Funiak llvm::unique_function<void(Operation *)> walk = [&](Operation *op) { 562b4130e9eSStanislav Funiak opToFirstIndex.try_emplace(op, index++); 563b4130e9eSStanislav Funiak for (Region ®ion : op->getRegions()) 564b4130e9eSStanislav Funiak for (Block &block : region.getBlocks()) 565b4130e9eSStanislav Funiak for (Operation &nested : block) 566b4130e9eSStanislav Funiak walk(&nested); 567b4130e9eSStanislav Funiak opToLastIndex.try_emplace(op, index++); 568b4130e9eSStanislav Funiak }; 569b4130e9eSStanislav Funiak walk(matcherFunc); 570abfd1a8bSRiver Riddle 571abfd1a8bSRiver Riddle // Liveness info for each of the defs within the matcher. 57285ab413bSRiver Riddle ByteCodeLiveRange::Allocator allocator; 57385ab413bSRiver Riddle DenseMap<Value, ByteCodeLiveRange> valueDefRanges; 574abfd1a8bSRiver Riddle 575abfd1a8bSRiver Riddle // Assign the root operation being matched to slot 0. 576abfd1a8bSRiver Riddle BlockArgument rootOpArg = matcherFunc.getArgument(0); 577abfd1a8bSRiver Riddle valueToMemIndex[rootOpArg] = 0; 578abfd1a8bSRiver Riddle 579abfd1a8bSRiver Riddle // Walk each of the blocks, computing the def interval that the value is used. 580abfd1a8bSRiver Riddle Liveness matcherLiveness(matcherFunc); 5813eb1647aSStanislav Funiak matcherFunc->walk([&](Block *block) { 5823eb1647aSStanislav Funiak const LivenessBlockInfo *info = matcherLiveness.getLiveness(block); 583abfd1a8bSRiver Riddle assert(info && "expected liveness info for block"); 584abfd1a8bSRiver Riddle auto processValue = [&](Value value, Operation *firstUseOrDef) { 585abfd1a8bSRiver Riddle // We don't need to process the root op argument, this value is always 586abfd1a8bSRiver Riddle // assigned to the first memory slot. 587abfd1a8bSRiver Riddle if (value == rootOpArg) 588abfd1a8bSRiver Riddle return; 589abfd1a8bSRiver Riddle 590abfd1a8bSRiver Riddle // Set indices for the range of this block that the value is used. 591abfd1a8bSRiver Riddle auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; 5923eb1647aSStanislav Funiak defRangeIt->second.liveness->insert( 593b4130e9eSStanislav Funiak opToFirstIndex[firstUseOrDef], 594b4130e9eSStanislav Funiak opToLastIndex[info->getEndOperation(value, firstUseOrDef)], 595abfd1a8bSRiver Riddle /*dummyValue*/ 0); 59685ab413bSRiver Riddle 59785ab413bSRiver Riddle // Check to see if this value is a range type. 59885ab413bSRiver Riddle if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) { 59985ab413bSRiver Riddle Type eleType = rangeTy.getElementType(); 6003eb1647aSStanislav Funiak if (eleType.isa<pdl::OperationType>()) 6013eb1647aSStanislav Funiak defRangeIt->second.opRangeIndex = 0; 6023eb1647aSStanislav Funiak else if (eleType.isa<pdl::TypeType>()) 60385ab413bSRiver Riddle defRangeIt->second.typeRangeIndex = 0; 60485ab413bSRiver Riddle else if (eleType.isa<pdl::ValueType>()) 60585ab413bSRiver Riddle defRangeIt->second.valueRangeIndex = 0; 60685ab413bSRiver Riddle } 607abfd1a8bSRiver Riddle }; 608abfd1a8bSRiver Riddle 609abfd1a8bSRiver Riddle // Process the live-ins of this block. 6103eb1647aSStanislav Funiak for (Value liveIn : info->in()) { 6113eb1647aSStanislav Funiak // Only process the value if it has been defined in the current region. 6123eb1647aSStanislav Funiak // Other values that span across pdl_interp.foreach will be added higher 6133eb1647aSStanislav Funiak // up. This ensures that the we keep them alive for the entire duration 6143eb1647aSStanislav Funiak // of the loop. 6153eb1647aSStanislav Funiak if (liveIn.getParentRegion() == block->getParent()) 6163eb1647aSStanislav Funiak processValue(liveIn, &block->front()); 6173eb1647aSStanislav Funiak } 6183eb1647aSStanislav Funiak 6193eb1647aSStanislav Funiak // Process the block arguments for the entry block (those are not live-in). 6203eb1647aSStanislav Funiak if (block->isEntryBlock()) { 6213eb1647aSStanislav Funiak for (Value argument : block->getArguments()) 6223eb1647aSStanislav Funiak processValue(argument, &block->front()); 6233eb1647aSStanislav Funiak } 624abfd1a8bSRiver Riddle 625abfd1a8bSRiver Riddle // Process any new defs within this block. 6263eb1647aSStanislav Funiak for (Operation &op : *block) 627abfd1a8bSRiver Riddle for (Value result : op.getResults()) 628abfd1a8bSRiver Riddle processValue(result, &op); 6293eb1647aSStanislav Funiak }); 630abfd1a8bSRiver Riddle 631abfd1a8bSRiver Riddle // Greedily allocate memory slots using the computed def live ranges. 63285ab413bSRiver Riddle std::vector<ByteCodeLiveRange> allocatedIndices; 6333eb1647aSStanislav Funiak 6343eb1647aSStanislav Funiak // The number of memory indices currently allocated (and its next value). 6353eb1647aSStanislav Funiak // Recall that the root gets allocated memory index 0. 6363eb1647aSStanislav Funiak ByteCodeField numIndices = 1; 6373eb1647aSStanislav Funiak 6383eb1647aSStanislav Funiak // The number of memory ranges of various types (and their next values). 6393eb1647aSStanislav Funiak ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0; 6403eb1647aSStanislav Funiak 641abfd1a8bSRiver Riddle for (auto &defIt : valueDefRanges) { 642abfd1a8bSRiver Riddle ByteCodeField &memIndex = valueToMemIndex[defIt.first]; 64385ab413bSRiver Riddle ByteCodeLiveRange &defRange = defIt.second; 644abfd1a8bSRiver Riddle 645abfd1a8bSRiver Riddle // Try to allocate to an existing index. 646e4853be2SMehdi Amini for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) { 64785ab413bSRiver Riddle ByteCodeLiveRange &existingRange = existingIndexIt.value(); 64885ab413bSRiver Riddle if (!defRange.overlaps(existingRange)) { 64985ab413bSRiver Riddle existingRange.unionWith(defRange); 650abfd1a8bSRiver Riddle memIndex = existingIndexIt.index() + 1; 65185ab413bSRiver Riddle 6523eb1647aSStanislav Funiak if (defRange.opRangeIndex) { 6533eb1647aSStanislav Funiak if (!existingRange.opRangeIndex) 6543eb1647aSStanislav Funiak existingRange.opRangeIndex = numOpRanges++; 6553eb1647aSStanislav Funiak valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex; 6563eb1647aSStanislav Funiak } else if (defRange.typeRangeIndex) { 65785ab413bSRiver Riddle if (!existingRange.typeRangeIndex) 65885ab413bSRiver Riddle existingRange.typeRangeIndex = numTypeRanges++; 65985ab413bSRiver Riddle valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex; 66085ab413bSRiver Riddle } else if (defRange.valueRangeIndex) { 66185ab413bSRiver Riddle if (!existingRange.valueRangeIndex) 66285ab413bSRiver Riddle existingRange.valueRangeIndex = numValueRanges++; 66385ab413bSRiver Riddle valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex; 66485ab413bSRiver Riddle } 66585ab413bSRiver Riddle break; 66685ab413bSRiver Riddle } 667abfd1a8bSRiver Riddle } 668abfd1a8bSRiver Riddle 669abfd1a8bSRiver Riddle // If no existing index could be used, add a new one. 670abfd1a8bSRiver Riddle if (memIndex == 0) { 671abfd1a8bSRiver Riddle allocatedIndices.emplace_back(allocator); 67285ab413bSRiver Riddle ByteCodeLiveRange &newRange = allocatedIndices.back(); 67385ab413bSRiver Riddle newRange.unionWith(defRange); 67485ab413bSRiver Riddle 6753eb1647aSStanislav Funiak // Allocate an index for op/type/value ranges. 6763eb1647aSStanislav Funiak if (defRange.opRangeIndex) { 6773eb1647aSStanislav Funiak newRange.opRangeIndex = numOpRanges; 6783eb1647aSStanislav Funiak valueToRangeIndex[defIt.first] = numOpRanges++; 6793eb1647aSStanislav Funiak } else if (defRange.typeRangeIndex) { 68085ab413bSRiver Riddle newRange.typeRangeIndex = numTypeRanges; 68185ab413bSRiver Riddle valueToRangeIndex[defIt.first] = numTypeRanges++; 68285ab413bSRiver Riddle } else if (defRange.valueRangeIndex) { 68385ab413bSRiver Riddle newRange.valueRangeIndex = numValueRanges; 68485ab413bSRiver Riddle valueToRangeIndex[defIt.first] = numValueRanges++; 68585ab413bSRiver Riddle } 68685ab413bSRiver Riddle 687abfd1a8bSRiver Riddle memIndex = allocatedIndices.size(); 68885ab413bSRiver Riddle ++numIndices; 689abfd1a8bSRiver Riddle } 690abfd1a8bSRiver Riddle } 691abfd1a8bSRiver Riddle 6923eb1647aSStanislav Funiak // Print the index usage and ensure that we did not run out of index space. 6933eb1647aSStanislav Funiak LLVM_DEBUG({ 6943eb1647aSStanislav Funiak llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices " 6953eb1647aSStanislav Funiak << "(down from initial " << valueDefRanges.size() << ").\n"; 6963eb1647aSStanislav Funiak }); 6973eb1647aSStanislav Funiak assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() && 6983eb1647aSStanislav Funiak "Ran out of memory for allocated indices"); 6993eb1647aSStanislav Funiak 700abfd1a8bSRiver Riddle // Update the max number of indices. 70185ab413bSRiver Riddle if (numIndices > maxValueMemoryIndex) 70285ab413bSRiver Riddle maxValueMemoryIndex = numIndices; 7033eb1647aSStanislav Funiak if (numOpRanges > maxOpRangeMemoryIndex) 7043eb1647aSStanislav Funiak maxOpRangeMemoryIndex = numOpRanges; 70585ab413bSRiver Riddle if (numTypeRanges > maxTypeRangeMemoryIndex) 70685ab413bSRiver Riddle maxTypeRangeMemoryIndex = numTypeRanges; 70785ab413bSRiver Riddle if (numValueRanges > maxValueRangeMemoryIndex) 70885ab413bSRiver Riddle maxValueRangeMemoryIndex = numValueRanges; 709abfd1a8bSRiver Riddle } 710abfd1a8bSRiver Riddle 7113eb1647aSStanislav Funiak void Generator::generate(Region *region, ByteCodeWriter &writer) { 7123eb1647aSStanislav Funiak llvm::ReversePostOrderTraversal<Region *> rpot(region); 7133eb1647aSStanislav Funiak for (Block *block : rpot) { 7143eb1647aSStanislav Funiak // Keep track of where this block begins within the matcher function. 7153eb1647aSStanislav Funiak blockToAddr.try_emplace(block, matcherByteCode.size()); 7163eb1647aSStanislav Funiak for (Operation &op : *block) 7173eb1647aSStanislav Funiak generate(&op, writer); 7183eb1647aSStanislav Funiak } 7193eb1647aSStanislav Funiak } 7203eb1647aSStanislav Funiak 721abfd1a8bSRiver Riddle void Generator::generate(Operation *op, ByteCodeWriter &writer) { 722d35f1190SStanislav Funiak LLVM_DEBUG({ 723d35f1190SStanislav Funiak // The following list must contain all the operations that do not 724d35f1190SStanislav Funiak // produce any bytecode. 725d35f1190SStanislav Funiak if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp, 726d35f1190SStanislav Funiak pdl_interp::InferredTypesOp>(op)) 727d35f1190SStanislav Funiak writer.appendInline(op->getLoc()); 728d35f1190SStanislav Funiak }); 729abfd1a8bSRiver Riddle TypeSwitch<Operation *>(op) 730abfd1a8bSRiver Riddle .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp, 731abfd1a8bSRiver Riddle pdl_interp::AreEqualOp, pdl_interp::BranchOp, 732abfd1a8bSRiver Riddle pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp, 733abfd1a8bSRiver Riddle pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp, 73485ab413bSRiver Riddle pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp, 7353eb1647aSStanislav Funiak pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp, 7363eb1647aSStanislav Funiak pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp, 7373eb1647aSStanislav Funiak pdl_interp::CreateTypesOp, pdl_interp::EraseOp, 7383eb1647aSStanislav Funiak pdl_interp::ExtractOp, pdl_interp::FinalizeOp, 7393eb1647aSStanislav Funiak pdl_interp::ForEachOp, pdl_interp::GetAttributeOp, 7403eb1647aSStanislav Funiak pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp, 7413eb1647aSStanislav Funiak pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp, 7423eb1647aSStanislav Funiak pdl_interp::GetResultOp, pdl_interp::GetResultsOp, 7433eb1647aSStanislav Funiak pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp, 7443a833a0eSRiver Riddle pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp, 74502c4c0d5SRiver Riddle pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp, 74602c4c0d5SRiver Riddle pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp, 74785ab413bSRiver Riddle pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp, 74885ab413bSRiver Riddle pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>( 749abfd1a8bSRiver Riddle [&](auto interpOp) { this->generate(interpOp, writer); }) 750abfd1a8bSRiver Riddle .Default([](Operation *) { 751abfd1a8bSRiver Riddle llvm_unreachable("unknown `pdl_interp` operation"); 752abfd1a8bSRiver Riddle }); 753abfd1a8bSRiver Riddle } 754abfd1a8bSRiver Riddle 755abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::ApplyConstraintOp op, 756abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 757abfd1a8bSRiver Riddle assert(constraintToMemIndex.count(op.name()) && 758abfd1a8bSRiver Riddle "expected index for constraint function"); 759abfd1a8bSRiver Riddle writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()], 760abfd1a8bSRiver Riddle op.constParamsAttr()); 761abfd1a8bSRiver Riddle writer.appendPDLValueList(op.args()); 762abfd1a8bSRiver Riddle writer.append(op.getSuccessors()); 763abfd1a8bSRiver Riddle } 764abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::ApplyRewriteOp op, 765abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 766abfd1a8bSRiver Riddle assert(externalRewriterToMemIndex.count(op.name()) && 767abfd1a8bSRiver Riddle "expected index for rewrite function"); 768abfd1a8bSRiver Riddle writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()], 76902c4c0d5SRiver Riddle op.constParamsAttr()); 770abfd1a8bSRiver Riddle writer.appendPDLValueList(op.args()); 77102c4c0d5SRiver Riddle 77285ab413bSRiver Riddle ResultRange results = op.results(); 77385ab413bSRiver Riddle writer.append(ByteCodeField(results.size())); 77485ab413bSRiver Riddle for (Value result : results) { 77585ab413bSRiver Riddle // In debug mode we also record the expected kind of the result, so that we 77685ab413bSRiver Riddle // can provide extra verification of the native rewrite function. 77702c4c0d5SRiver Riddle #ifndef NDEBUG 77885ab413bSRiver Riddle writer.appendPDLValueKind(result); 77902c4c0d5SRiver Riddle #endif 78085ab413bSRiver Riddle 78185ab413bSRiver Riddle // Range results also need to append the range storage index. 78285ab413bSRiver Riddle if (result.getType().isa<pdl::RangeType>()) 78385ab413bSRiver Riddle writer.append(getRangeStorageIndex(result)); 78402c4c0d5SRiver Riddle writer.append(result); 785abfd1a8bSRiver Riddle } 78685ab413bSRiver Riddle } 787abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) { 78885ab413bSRiver Riddle Value lhs = op.lhs(); 78985ab413bSRiver Riddle if (lhs.getType().isa<pdl::RangeType>()) { 79085ab413bSRiver Riddle writer.append(OpCode::AreRangesEqual); 79185ab413bSRiver Riddle writer.appendPDLValueKind(lhs); 79285ab413bSRiver Riddle writer.append(op.lhs(), op.rhs(), op.getSuccessors()); 79385ab413bSRiver Riddle return; 79485ab413bSRiver Riddle } 79585ab413bSRiver Riddle 79685ab413bSRiver Riddle writer.append(OpCode::AreEqual, lhs, op.rhs(), op.getSuccessors()); 797abfd1a8bSRiver Riddle } 798abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) { 7998affe881SRiver Riddle writer.append(OpCode::Branch, SuccessorRange(op.getOperation())); 800abfd1a8bSRiver Riddle } 801abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckAttributeOp op, 802abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 803abfd1a8bSRiver Riddle writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(), 804abfd1a8bSRiver Riddle op.getSuccessors()); 805abfd1a8bSRiver Riddle } 806abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckOperandCountOp op, 807abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 808abfd1a8bSRiver Riddle writer.append(OpCode::CheckOperandCount, op.operation(), op.count(), 80985ab413bSRiver Riddle static_cast<ByteCodeField>(op.compareAtLeast()), 810abfd1a8bSRiver Riddle op.getSuccessors()); 811abfd1a8bSRiver Riddle } 812abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckOperationNameOp op, 813abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 814abfd1a8bSRiver Riddle writer.append(OpCode::CheckOperationName, op.operation(), 815abfd1a8bSRiver Riddle OperationName(op.name(), ctx), op.getSuccessors()); 816abfd1a8bSRiver Riddle } 817abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckResultCountOp op, 818abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 819abfd1a8bSRiver Riddle writer.append(OpCode::CheckResultCount, op.operation(), op.count(), 82085ab413bSRiver Riddle static_cast<ByteCodeField>(op.compareAtLeast()), 821abfd1a8bSRiver Riddle op.getSuccessors()); 822abfd1a8bSRiver Riddle } 823abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) { 824abfd1a8bSRiver Riddle writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors()); 825abfd1a8bSRiver Riddle } 82685ab413bSRiver Riddle void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) { 82785ab413bSRiver Riddle writer.append(OpCode::CheckTypes, op.value(), op.types(), op.getSuccessors()); 82885ab413bSRiver Riddle } 8293eb1647aSStanislav Funiak void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) { 8303eb1647aSStanislav Funiak assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level"); 8313eb1647aSStanislav Funiak writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1)); 8323eb1647aSStanislav Funiak } 833abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateAttributeOp op, 834abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 835abfd1a8bSRiver Riddle // Simply repoint the memory index of the result to the constant. 836abfd1a8bSRiver Riddle getMemIndex(op.attribute()) = getMemIndex(op.value()); 837abfd1a8bSRiver Riddle } 838abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateOperationOp op, 839abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 840abfd1a8bSRiver Riddle writer.append(OpCode::CreateOperation, op.operation(), 84185ab413bSRiver Riddle OperationName(op.name(), ctx)); 84285ab413bSRiver Riddle writer.appendPDLValueList(op.operands()); 843abfd1a8bSRiver Riddle 844abfd1a8bSRiver Riddle // Add the attributes. 845abfd1a8bSRiver Riddle OperandRange attributes = op.attributes(); 846abfd1a8bSRiver Riddle writer.append(static_cast<ByteCodeField>(attributes.size())); 847195730a6SRiver Riddle for (auto it : llvm::zip(op.attributeNames(), op.attributes())) 848195730a6SRiver Riddle writer.append(std::get<0>(it), std::get<1>(it)); 84985ab413bSRiver Riddle writer.appendPDLValueList(op.types()); 850abfd1a8bSRiver Riddle } 851abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) { 852abfd1a8bSRiver Riddle // Simply repoint the memory index of the result to the constant. 853abfd1a8bSRiver Riddle getMemIndex(op.result()) = getMemIndex(op.value()); 854abfd1a8bSRiver Riddle } 85585ab413bSRiver Riddle void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) { 85685ab413bSRiver Riddle writer.append(OpCode::CreateTypes, op.result(), 85785ab413bSRiver Riddle getRangeStorageIndex(op.result()), op.value()); 85885ab413bSRiver Riddle } 859abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) { 860abfd1a8bSRiver Riddle writer.append(OpCode::EraseOp, op.operation()); 861abfd1a8bSRiver Riddle } 8623eb1647aSStanislav Funiak void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) { 8633eb1647aSStanislav Funiak OpCode opCode = 8643eb1647aSStanislav Funiak TypeSwitch<Type, OpCode>(op.result().getType()) 8653eb1647aSStanislav Funiak .Case([](pdl::OperationType) { return OpCode::ExtractOp; }) 8663eb1647aSStanislav Funiak .Case([](pdl::ValueType) { return OpCode::ExtractValue; }) 8673eb1647aSStanislav Funiak .Case([](pdl::TypeType) { return OpCode::ExtractType; }) 8683eb1647aSStanislav Funiak .Default([](Type) -> OpCode { 8693eb1647aSStanislav Funiak llvm_unreachable("unsupported element type"); 8703eb1647aSStanislav Funiak }); 8713eb1647aSStanislav Funiak writer.append(opCode, op.range(), op.index(), op.result()); 8723eb1647aSStanislav Funiak } 873abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) { 874abfd1a8bSRiver Riddle writer.append(OpCode::Finalize); 875abfd1a8bSRiver Riddle } 8763eb1647aSStanislav Funiak void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) { 8773eb1647aSStanislav Funiak BlockArgument arg = op.getLoopVariable(); 8783eb1647aSStanislav Funiak writer.append(OpCode::ForEach, getRangeStorageIndex(op.values()), arg); 8793eb1647aSStanislav Funiak writer.appendPDLValueKind(arg.getType()); 8803eb1647aSStanislav Funiak writer.append(curLoopLevel, op.successor()); 8813eb1647aSStanislav Funiak ++curLoopLevel; 8823eb1647aSStanislav Funiak if (curLoopLevel > maxLoopLevel) 8833eb1647aSStanislav Funiak maxLoopLevel = curLoopLevel; 8843eb1647aSStanislav Funiak generate(&op.region(), writer); 8853eb1647aSStanislav Funiak --curLoopLevel; 8863eb1647aSStanislav Funiak } 887abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetAttributeOp op, 888abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 889abfd1a8bSRiver Riddle writer.append(OpCode::GetAttribute, op.attribute(), op.operation(), 890195730a6SRiver Riddle op.nameAttr()); 891abfd1a8bSRiver Riddle } 892abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetAttributeTypeOp op, 893abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 894abfd1a8bSRiver Riddle writer.append(OpCode::GetAttributeType, op.result(), op.value()); 895abfd1a8bSRiver Riddle } 896abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetDefiningOpOp op, 897abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 89885ab413bSRiver Riddle writer.append(OpCode::GetDefiningOp, op.operation()); 89985ab413bSRiver Riddle writer.appendPDLValue(op.value()); 900abfd1a8bSRiver Riddle } 901abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) { 902abfd1a8bSRiver Riddle uint32_t index = op.index(); 903abfd1a8bSRiver Riddle if (index < 4) 904abfd1a8bSRiver Riddle writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index)); 905abfd1a8bSRiver Riddle else 906abfd1a8bSRiver Riddle writer.append(OpCode::GetOperandN, index); 907abfd1a8bSRiver Riddle writer.append(op.operation(), op.value()); 908abfd1a8bSRiver Riddle } 90985ab413bSRiver Riddle void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) { 91085ab413bSRiver Riddle Value result = op.value(); 91185ab413bSRiver Riddle Optional<uint32_t> index = op.index(); 91285ab413bSRiver Riddle writer.append(OpCode::GetOperands, 91385ab413bSRiver Riddle index.getValueOr(std::numeric_limits<uint32_t>::max()), 91485ab413bSRiver Riddle op.operation()); 91585ab413bSRiver Riddle if (result.getType().isa<pdl::RangeType>()) 91685ab413bSRiver Riddle writer.append(getRangeStorageIndex(result)); 91785ab413bSRiver Riddle else 91885ab413bSRiver Riddle writer.append(std::numeric_limits<ByteCodeField>::max()); 91985ab413bSRiver Riddle writer.append(result); 92085ab413bSRiver Riddle } 921abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) { 922abfd1a8bSRiver Riddle uint32_t index = op.index(); 923abfd1a8bSRiver Riddle if (index < 4) 924abfd1a8bSRiver Riddle writer.append(static_cast<OpCode>(OpCode::GetResult0 + index)); 925abfd1a8bSRiver Riddle else 926abfd1a8bSRiver Riddle writer.append(OpCode::GetResultN, index); 927abfd1a8bSRiver Riddle writer.append(op.operation(), op.value()); 928abfd1a8bSRiver Riddle } 92985ab413bSRiver Riddle void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) { 93085ab413bSRiver Riddle Value result = op.value(); 93185ab413bSRiver Riddle Optional<uint32_t> index = op.index(); 93285ab413bSRiver Riddle writer.append(OpCode::GetResults, 93385ab413bSRiver Riddle index.getValueOr(std::numeric_limits<uint32_t>::max()), 93485ab413bSRiver Riddle op.operation()); 93585ab413bSRiver Riddle if (result.getType().isa<pdl::RangeType>()) 93685ab413bSRiver Riddle writer.append(getRangeStorageIndex(result)); 93785ab413bSRiver Riddle else 93885ab413bSRiver Riddle writer.append(std::numeric_limits<ByteCodeField>::max()); 93985ab413bSRiver Riddle writer.append(result); 94085ab413bSRiver Riddle } 9413eb1647aSStanislav Funiak void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) { 9423eb1647aSStanislav Funiak Value operations = op.operations(); 9433eb1647aSStanislav Funiak ByteCodeField rangeIndex = getRangeStorageIndex(operations); 9443eb1647aSStanislav Funiak writer.append(OpCode::GetUsers, operations, rangeIndex); 9453eb1647aSStanislav Funiak writer.appendPDLValue(op.value()); 9463eb1647aSStanislav Funiak } 947abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetValueTypeOp op, 948abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 94985ab413bSRiver Riddle if (op.getType().isa<pdl::RangeType>()) { 95085ab413bSRiver Riddle Value result = op.result(); 95185ab413bSRiver Riddle writer.append(OpCode::GetValueRangeTypes, result, 95285ab413bSRiver Riddle getRangeStorageIndex(result), op.value()); 95385ab413bSRiver Riddle } else { 954abfd1a8bSRiver Riddle writer.append(OpCode::GetValueType, op.result(), op.value()); 955abfd1a8bSRiver Riddle } 95685ab413bSRiver Riddle } 95785ab413bSRiver Riddle 9583a833a0eSRiver Riddle void Generator::generate(pdl_interp::InferredTypesOp op, 959abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 9603a833a0eSRiver Riddle // InferType maps to a null type as a marker for inferring result types. 961abfd1a8bSRiver Riddle getMemIndex(op.type()) = getMemIndex(Type()); 962abfd1a8bSRiver Riddle } 963abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) { 964abfd1a8bSRiver Riddle writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors()); 965abfd1a8bSRiver Riddle } 966abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) { 967abfd1a8bSRiver Riddle ByteCodeField patternIndex = patterns.size(); 968abfd1a8bSRiver Riddle patterns.emplace_back(PDLByteCodePattern::create( 96941d4aa7dSChris Lattner op, rewriterToAddr[op.rewriter().getLeafReference().getValue()])); 9708affe881SRiver Riddle writer.append(OpCode::RecordMatch, patternIndex, 97185ab413bSRiver Riddle SuccessorRange(op.getOperation()), op.matchedOps()); 97285ab413bSRiver Riddle writer.appendPDLValueList(op.inputs()); 973abfd1a8bSRiver Riddle } 974abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) { 97585ab413bSRiver Riddle writer.append(OpCode::ReplaceOp, op.operation()); 97685ab413bSRiver Riddle writer.appendPDLValueList(op.replValues()); 977abfd1a8bSRiver Riddle } 978abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchAttributeOp op, 979abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 980abfd1a8bSRiver Riddle writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(), 981abfd1a8bSRiver Riddle op.getSuccessors()); 982abfd1a8bSRiver Riddle } 983abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchOperandCountOp op, 984abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 985abfd1a8bSRiver Riddle writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(), 986abfd1a8bSRiver Riddle op.getSuccessors()); 987abfd1a8bSRiver Riddle } 988abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchOperationNameOp op, 989abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 990abfd1a8bSRiver Riddle auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) { 991abfd1a8bSRiver Riddle return OperationName(attr.cast<StringAttr>().getValue(), ctx); 992abfd1a8bSRiver Riddle }); 993abfd1a8bSRiver Riddle writer.append(OpCode::SwitchOperationName, op.operation(), cases, 994abfd1a8bSRiver Riddle op.getSuccessors()); 995abfd1a8bSRiver Riddle } 996abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchResultCountOp op, 997abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 998abfd1a8bSRiver Riddle writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(), 999abfd1a8bSRiver Riddle op.getSuccessors()); 1000abfd1a8bSRiver Riddle } 1001abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) { 1002abfd1a8bSRiver Riddle writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(), 1003abfd1a8bSRiver Riddle op.getSuccessors()); 1004abfd1a8bSRiver Riddle } 100585ab413bSRiver Riddle void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) { 100685ab413bSRiver Riddle writer.append(OpCode::SwitchTypes, op.value(), op.caseValuesAttr(), 100785ab413bSRiver Riddle op.getSuccessors()); 100885ab413bSRiver Riddle } 1009abfd1a8bSRiver Riddle 1010abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 1011abfd1a8bSRiver Riddle // PDLByteCode 1012abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 1013abfd1a8bSRiver Riddle 1014abfd1a8bSRiver Riddle PDLByteCode::PDLByteCode(ModuleOp module, 1015abfd1a8bSRiver Riddle llvm::StringMap<PDLConstraintFunction> constraintFns, 1016abfd1a8bSRiver Riddle llvm::StringMap<PDLRewriteFunction> rewriteFns) { 1017abfd1a8bSRiver Riddle Generator generator(module.getContext(), uniquedData, matcherByteCode, 1018abfd1a8bSRiver Riddle rewriterByteCode, patterns, maxValueMemoryIndex, 10193eb1647aSStanislav Funiak maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount, 10203eb1647aSStanislav Funiak maxLoopLevel, constraintFns, rewriteFns); 1021abfd1a8bSRiver Riddle generator.generate(module); 1022abfd1a8bSRiver Riddle 1023abfd1a8bSRiver Riddle // Initialize the external functions. 1024abfd1a8bSRiver Riddle for (auto &it : constraintFns) 1025abfd1a8bSRiver Riddle constraintFunctions.push_back(std::move(it.second)); 1026abfd1a8bSRiver Riddle for (auto &it : rewriteFns) 1027abfd1a8bSRiver Riddle rewriteFunctions.push_back(std::move(it.second)); 1028abfd1a8bSRiver Riddle } 1029abfd1a8bSRiver Riddle 1030abfd1a8bSRiver Riddle /// Initialize the given state such that it can be used to execute the current 1031abfd1a8bSRiver Riddle /// bytecode. 1032abfd1a8bSRiver Riddle void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { 1033abfd1a8bSRiver Riddle state.memory.resize(maxValueMemoryIndex, nullptr); 10343eb1647aSStanislav Funiak state.opRangeMemory.resize(maxOpRangeCount); 103585ab413bSRiver Riddle state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange()); 103685ab413bSRiver Riddle state.valueRangeMemory.resize(maxValueRangeCount, ValueRange()); 10373eb1647aSStanislav Funiak state.loopIndex.resize(maxLoopLevel, 0); 1038abfd1a8bSRiver Riddle state.currentPatternBenefits.reserve(patterns.size()); 1039abfd1a8bSRiver Riddle for (const PDLByteCodePattern &pattern : patterns) 1040abfd1a8bSRiver Riddle state.currentPatternBenefits.push_back(pattern.getBenefit()); 1041abfd1a8bSRiver Riddle } 1042abfd1a8bSRiver Riddle 1043abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 1044abfd1a8bSRiver Riddle // ByteCode Execution 1045abfd1a8bSRiver Riddle 1046abfd1a8bSRiver Riddle namespace { 1047abfd1a8bSRiver Riddle /// This class provides support for executing a bytecode stream. 1048abfd1a8bSRiver Riddle class ByteCodeExecutor { 1049abfd1a8bSRiver Riddle public: 105085ab413bSRiver Riddle ByteCodeExecutor( 105185ab413bSRiver Riddle const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory, 10523eb1647aSStanislav Funiak MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory, 105385ab413bSRiver Riddle MutableArrayRef<TypeRange> typeRangeMemory, 105485ab413bSRiver Riddle std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory, 105585ab413bSRiver Riddle MutableArrayRef<ValueRange> valueRangeMemory, 105685ab413bSRiver Riddle std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory, 10573eb1647aSStanislav Funiak MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory, 10583eb1647aSStanislav Funiak ArrayRef<ByteCodeField> code, 1059abfd1a8bSRiver Riddle ArrayRef<PatternBenefit> currentPatternBenefits, 1060abfd1a8bSRiver Riddle ArrayRef<PDLByteCodePattern> patterns, 1061abfd1a8bSRiver Riddle ArrayRef<PDLConstraintFunction> constraintFunctions, 1062abfd1a8bSRiver Riddle ArrayRef<PDLRewriteFunction> rewriteFunctions) 10633eb1647aSStanislav Funiak : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory), 10643eb1647aSStanislav Funiak typeRangeMemory(typeRangeMemory), 106585ab413bSRiver Riddle allocatedTypeRangeMemory(allocatedTypeRangeMemory), 106685ab413bSRiver Riddle valueRangeMemory(valueRangeMemory), 106785ab413bSRiver Riddle allocatedValueRangeMemory(allocatedValueRangeMemory), 10683eb1647aSStanislav Funiak loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code), 106985ab413bSRiver Riddle currentPatternBenefits(currentPatternBenefits), patterns(patterns), 107085ab413bSRiver Riddle constraintFunctions(constraintFunctions), 107102c4c0d5SRiver Riddle rewriteFunctions(rewriteFunctions) {} 1072abfd1a8bSRiver Riddle 1073abfd1a8bSRiver Riddle /// Start executing the code at the current bytecode index. `matches` is an 1074abfd1a8bSRiver Riddle /// optional field provided when this function is executed in a matching 1075abfd1a8bSRiver Riddle /// context. 1076abfd1a8bSRiver Riddle void execute(PatternRewriter &rewriter, 1077abfd1a8bSRiver Riddle SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr, 1078abfd1a8bSRiver Riddle Optional<Location> mainRewriteLoc = {}); 1079abfd1a8bSRiver Riddle 1080abfd1a8bSRiver Riddle private: 1081154cabe7SRiver Riddle /// Internal implementation of executing each of the bytecode commands. 1082154cabe7SRiver Riddle void executeApplyConstraint(PatternRewriter &rewriter); 1083154cabe7SRiver Riddle void executeApplyRewrite(PatternRewriter &rewriter); 1084154cabe7SRiver Riddle void executeAreEqual(); 108585ab413bSRiver Riddle void executeAreRangesEqual(); 1086154cabe7SRiver Riddle void executeBranch(); 1087154cabe7SRiver Riddle void executeCheckOperandCount(); 1088154cabe7SRiver Riddle void executeCheckOperationName(); 1089154cabe7SRiver Riddle void executeCheckResultCount(); 109085ab413bSRiver Riddle void executeCheckTypes(); 10913eb1647aSStanislav Funiak void executeContinue(); 1092154cabe7SRiver Riddle void executeCreateOperation(PatternRewriter &rewriter, 1093154cabe7SRiver Riddle Location mainRewriteLoc); 109485ab413bSRiver Riddle void executeCreateTypes(); 1095154cabe7SRiver Riddle void executeEraseOp(PatternRewriter &rewriter); 10963eb1647aSStanislav Funiak template <typename T, typename Range, PDLValue::Kind kind> 10973eb1647aSStanislav Funiak void executeExtract(); 10983eb1647aSStanislav Funiak void executeFinalize(); 10993eb1647aSStanislav Funiak void executeForEach(); 1100154cabe7SRiver Riddle void executeGetAttribute(); 1101154cabe7SRiver Riddle void executeGetAttributeType(); 1102154cabe7SRiver Riddle void executeGetDefiningOp(); 1103154cabe7SRiver Riddle void executeGetOperand(unsigned index); 110485ab413bSRiver Riddle void executeGetOperands(); 1105154cabe7SRiver Riddle void executeGetResult(unsigned index); 110685ab413bSRiver Riddle void executeGetResults(); 11073eb1647aSStanislav Funiak void executeGetUsers(); 1108154cabe7SRiver Riddle void executeGetValueType(); 110985ab413bSRiver Riddle void executeGetValueRangeTypes(); 1110154cabe7SRiver Riddle void executeIsNotNull(); 1111154cabe7SRiver Riddle void executeRecordMatch(PatternRewriter &rewriter, 1112154cabe7SRiver Riddle SmallVectorImpl<PDLByteCode::MatchResult> &matches); 1113154cabe7SRiver Riddle void executeReplaceOp(PatternRewriter &rewriter); 1114154cabe7SRiver Riddle void executeSwitchAttribute(); 1115154cabe7SRiver Riddle void executeSwitchOperandCount(); 1116154cabe7SRiver Riddle void executeSwitchOperationName(); 1117154cabe7SRiver Riddle void executeSwitchResultCount(); 1118154cabe7SRiver Riddle void executeSwitchType(); 111985ab413bSRiver Riddle void executeSwitchTypes(); 1120154cabe7SRiver Riddle 11213eb1647aSStanislav Funiak /// Pushes a code iterator to the stack. 11223eb1647aSStanislav Funiak void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); } 11233eb1647aSStanislav Funiak 11243eb1647aSStanislav Funiak /// Pops a code iterator from the stack, returning true on success. 11253eb1647aSStanislav Funiak void popCodeIt() { 11263eb1647aSStanislav Funiak assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack"); 11273eb1647aSStanislav Funiak curCodeIt = resumeCodeIt.back(); 11283eb1647aSStanislav Funiak resumeCodeIt.pop_back(); 11293eb1647aSStanislav Funiak } 11303eb1647aSStanislav Funiak 1131d35f1190SStanislav Funiak /// Return the bytecode iterator at the start of the current op code. 1132d35f1190SStanislav Funiak const ByteCodeField *getPrevCodeIt() const { 1133d35f1190SStanislav Funiak LLVM_DEBUG({ 1134d35f1190SStanislav Funiak // Account for the op code and the Location stored inline. 1135d35f1190SStanislav Funiak return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField); 1136d35f1190SStanislav Funiak }); 1137d35f1190SStanislav Funiak 1138d35f1190SStanislav Funiak // Account for the op code only. 1139d35f1190SStanislav Funiak return curCodeIt - 1; 1140d35f1190SStanislav Funiak } 1141d35f1190SStanislav Funiak 1142abfd1a8bSRiver Riddle /// Read a value from the bytecode buffer, optionally skipping a certain 1143abfd1a8bSRiver Riddle /// number of prefix values. These methods always update the buffer to point 1144abfd1a8bSRiver Riddle /// to the next field after the read data. 1145abfd1a8bSRiver Riddle template <typename T = ByteCodeField> 1146abfd1a8bSRiver Riddle T read(size_t skipN = 0) { 1147abfd1a8bSRiver Riddle curCodeIt += skipN; 1148abfd1a8bSRiver Riddle return readImpl<T>(); 1149abfd1a8bSRiver Riddle } 1150abfd1a8bSRiver Riddle ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); } 1151abfd1a8bSRiver Riddle 1152abfd1a8bSRiver Riddle /// Read a list of values from the bytecode buffer. 1153abfd1a8bSRiver Riddle template <typename ValueT, typename T> 1154abfd1a8bSRiver Riddle void readList(SmallVectorImpl<T> &list) { 1155abfd1a8bSRiver Riddle list.clear(); 1156abfd1a8bSRiver Riddle for (unsigned i = 0, e = read(); i != e; ++i) 1157abfd1a8bSRiver Riddle list.push_back(read<ValueT>()); 1158abfd1a8bSRiver Riddle } 1159abfd1a8bSRiver Riddle 116085ab413bSRiver Riddle /// Read a list of values from the bytecode buffer. The values may be encoded 116185ab413bSRiver Riddle /// as either Value or ValueRange elements. 116285ab413bSRiver Riddle void readValueList(SmallVectorImpl<Value> &list) { 116385ab413bSRiver Riddle for (unsigned i = 0, e = read(); i != e; ++i) { 116485ab413bSRiver Riddle if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 116585ab413bSRiver Riddle list.push_back(read<Value>()); 116685ab413bSRiver Riddle } else { 116785ab413bSRiver Riddle ValueRange *values = read<ValueRange *>(); 116885ab413bSRiver Riddle list.append(values->begin(), values->end()); 116985ab413bSRiver Riddle } 117085ab413bSRiver Riddle } 117185ab413bSRiver Riddle } 117285ab413bSRiver Riddle 1173d35f1190SStanislav Funiak /// Read a value stored inline as a pointer. 1174d35f1190SStanislav Funiak template <typename T> 1175d35f1190SStanislav Funiak std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T> 1176d35f1190SStanislav Funiak readInline() { 1177d35f1190SStanislav Funiak const void *pointer; 1178d35f1190SStanislav Funiak std::memcpy(&pointer, curCodeIt, sizeof(const void *)); 1179d35f1190SStanislav Funiak curCodeIt += sizeof(const void *) / sizeof(ByteCodeField); 1180d35f1190SStanislav Funiak return T::getFromOpaquePointer(pointer); 1181d35f1190SStanislav Funiak } 1182d35f1190SStanislav Funiak 1183abfd1a8bSRiver Riddle /// Jump to a specific successor based on a predicate value. 1184abfd1a8bSRiver Riddle void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); } 1185abfd1a8bSRiver Riddle /// Jump to a specific successor based on a destination index. 1186abfd1a8bSRiver Riddle void selectJump(size_t destIndex) { 1187abfd1a8bSRiver Riddle curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)]; 1188abfd1a8bSRiver Riddle } 1189abfd1a8bSRiver Riddle 1190abfd1a8bSRiver Riddle /// Handle a switch operation with the provided value and cases. 119185ab413bSRiver Riddle template <typename T, typename RangeT, typename Comparator = std::equal_to<T>> 119285ab413bSRiver Riddle void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) { 1193abfd1a8bSRiver Riddle LLVM_DEBUG({ 1194abfd1a8bSRiver Riddle llvm::dbgs() << " * Value: " << value << "\n" 1195abfd1a8bSRiver Riddle << " * Cases: "; 1196abfd1a8bSRiver Riddle llvm::interleaveComma(cases, llvm::dbgs()); 1197154cabe7SRiver Riddle llvm::dbgs() << "\n"; 1198abfd1a8bSRiver Riddle }); 1199abfd1a8bSRiver Riddle 1200abfd1a8bSRiver Riddle // Check to see if the attribute value is within the case list. Jump to 1201abfd1a8bSRiver Riddle // the correct successor index based on the result. 1202f80b6304SRiver Riddle for (auto it = cases.begin(), e = cases.end(); it != e; ++it) 120385ab413bSRiver Riddle if (cmp(*it, value)) 1204f80b6304SRiver Riddle return selectJump(size_t((it - cases.begin()) + 1)); 1205f80b6304SRiver Riddle selectJump(size_t(0)); 1206abfd1a8bSRiver Riddle } 1207abfd1a8bSRiver Riddle 12083eb1647aSStanislav Funiak /// Store a pointer to memory. 12093eb1647aSStanislav Funiak void storeToMemory(unsigned index, const void *value) { 12103eb1647aSStanislav Funiak memory[index] = value; 12113eb1647aSStanislav Funiak } 12123eb1647aSStanislav Funiak 12133eb1647aSStanislav Funiak /// Store a value to memory as an opaque pointer. 12143eb1647aSStanislav Funiak template <typename T> 12153eb1647aSStanislav Funiak std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value> 12163eb1647aSStanislav Funiak storeToMemory(unsigned index, T value) { 12173eb1647aSStanislav Funiak memory[index] = value.getAsOpaquePointer(); 12183eb1647aSStanislav Funiak } 12193eb1647aSStanislav Funiak 1220abfd1a8bSRiver Riddle /// Internal implementation of reading various data types from the bytecode 1221abfd1a8bSRiver Riddle /// stream. 1222abfd1a8bSRiver Riddle template <typename T> 1223abfd1a8bSRiver Riddle const void *readFromMemory() { 1224abfd1a8bSRiver Riddle size_t index = *curCodeIt++; 1225abfd1a8bSRiver Riddle 1226abfd1a8bSRiver Riddle // If this type is an SSA value, it can only be stored in non-const memory. 122785ab413bSRiver Riddle if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *, 122885ab413bSRiver Riddle Value>::value || 122985ab413bSRiver Riddle index < memory.size()) 1230abfd1a8bSRiver Riddle return memory[index]; 1231abfd1a8bSRiver Riddle 1232abfd1a8bSRiver Riddle // Otherwise, if this index is not inbounds it is uniqued. 1233abfd1a8bSRiver Riddle return uniquedMemory[index - memory.size()]; 1234abfd1a8bSRiver Riddle } 1235abfd1a8bSRiver Riddle template <typename T> 1236abfd1a8bSRiver Riddle std::enable_if_t<std::is_pointer<T>::value, T> readImpl() { 1237abfd1a8bSRiver Riddle return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>())); 1238abfd1a8bSRiver Riddle } 1239abfd1a8bSRiver Riddle template <typename T> 1240abfd1a8bSRiver Riddle std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value, 1241abfd1a8bSRiver Riddle T> 1242abfd1a8bSRiver Riddle readImpl() { 1243abfd1a8bSRiver Riddle return T(T::getFromOpaquePointer(readFromMemory<T>())); 1244abfd1a8bSRiver Riddle } 1245abfd1a8bSRiver Riddle template <typename T> 1246abfd1a8bSRiver Riddle std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() { 124785ab413bSRiver Riddle switch (read<PDLValue::Kind>()) { 124885ab413bSRiver Riddle case PDLValue::Kind::Attribute: 1249abfd1a8bSRiver Riddle return read<Attribute>(); 125085ab413bSRiver Riddle case PDLValue::Kind::Operation: 1251abfd1a8bSRiver Riddle return read<Operation *>(); 125285ab413bSRiver Riddle case PDLValue::Kind::Type: 1253abfd1a8bSRiver Riddle return read<Type>(); 125485ab413bSRiver Riddle case PDLValue::Kind::Value: 1255abfd1a8bSRiver Riddle return read<Value>(); 125685ab413bSRiver Riddle case PDLValue::Kind::TypeRange: 125785ab413bSRiver Riddle return read<TypeRange *>(); 125885ab413bSRiver Riddle case PDLValue::Kind::ValueRange: 125985ab413bSRiver Riddle return read<ValueRange *>(); 1260abfd1a8bSRiver Riddle } 126185ab413bSRiver Riddle llvm_unreachable("unhandled PDLValue::Kind"); 1262abfd1a8bSRiver Riddle } 1263abfd1a8bSRiver Riddle template <typename T> 1264abfd1a8bSRiver Riddle std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() { 1265abfd1a8bSRiver Riddle static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 1266abfd1a8bSRiver Riddle "unexpected ByteCode address size"); 1267abfd1a8bSRiver Riddle ByteCodeAddr result; 1268abfd1a8bSRiver Riddle std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr)); 1269abfd1a8bSRiver Riddle curCodeIt += 2; 1270abfd1a8bSRiver Riddle return result; 1271abfd1a8bSRiver Riddle } 1272abfd1a8bSRiver Riddle template <typename T> 1273abfd1a8bSRiver Riddle std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() { 1274abfd1a8bSRiver Riddle return *curCodeIt++; 1275abfd1a8bSRiver Riddle } 127685ab413bSRiver Riddle template <typename T> 127785ab413bSRiver Riddle std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() { 127885ab413bSRiver Riddle return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>()); 127985ab413bSRiver Riddle } 1280abfd1a8bSRiver Riddle 1281abfd1a8bSRiver Riddle /// The underlying bytecode buffer. 1282abfd1a8bSRiver Riddle const ByteCodeField *curCodeIt; 1283abfd1a8bSRiver Riddle 12843eb1647aSStanislav Funiak /// The stack of bytecode positions at which to resume operation. 12853eb1647aSStanislav Funiak SmallVector<const ByteCodeField *> resumeCodeIt; 12863eb1647aSStanislav Funiak 1287abfd1a8bSRiver Riddle /// The current execution memory. 1288abfd1a8bSRiver Riddle MutableArrayRef<const void *> memory; 12893eb1647aSStanislav Funiak MutableArrayRef<OwningOpRange> opRangeMemory; 129085ab413bSRiver Riddle MutableArrayRef<TypeRange> typeRangeMemory; 129185ab413bSRiver Riddle std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory; 129285ab413bSRiver Riddle MutableArrayRef<ValueRange> valueRangeMemory; 129385ab413bSRiver Riddle std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory; 1294abfd1a8bSRiver Riddle 12953eb1647aSStanislav Funiak /// The current loop indices. 12963eb1647aSStanislav Funiak MutableArrayRef<unsigned> loopIndex; 12973eb1647aSStanislav Funiak 1298abfd1a8bSRiver Riddle /// References to ByteCode data necessary for execution. 1299abfd1a8bSRiver Riddle ArrayRef<const void *> uniquedMemory; 1300abfd1a8bSRiver Riddle ArrayRef<ByteCodeField> code; 1301abfd1a8bSRiver Riddle ArrayRef<PatternBenefit> currentPatternBenefits; 1302abfd1a8bSRiver Riddle ArrayRef<PDLByteCodePattern> patterns; 1303abfd1a8bSRiver Riddle ArrayRef<PDLConstraintFunction> constraintFunctions; 1304abfd1a8bSRiver Riddle ArrayRef<PDLRewriteFunction> rewriteFunctions; 1305abfd1a8bSRiver Riddle }; 130602c4c0d5SRiver Riddle 130702c4c0d5SRiver Riddle /// This class is an instantiation of the PDLResultList that provides access to 130802c4c0d5SRiver Riddle /// the returned results. This API is not on `PDLResultList` to avoid 130902c4c0d5SRiver Riddle /// overexposing access to information specific solely to the ByteCode. 131002c4c0d5SRiver Riddle class ByteCodeRewriteResultList : public PDLResultList { 131102c4c0d5SRiver Riddle public: 131285ab413bSRiver Riddle ByteCodeRewriteResultList(unsigned maxNumResults) 131385ab413bSRiver Riddle : PDLResultList(maxNumResults) {} 131485ab413bSRiver Riddle 131502c4c0d5SRiver Riddle /// Return the list of PDL results. 131602c4c0d5SRiver Riddle MutableArrayRef<PDLValue> getResults() { return results; } 131785ab413bSRiver Riddle 131885ab413bSRiver Riddle /// Return the type ranges allocated by this list. 131985ab413bSRiver Riddle MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() { 132085ab413bSRiver Riddle return allocatedTypeRanges; 132185ab413bSRiver Riddle } 132285ab413bSRiver Riddle 132385ab413bSRiver Riddle /// Return the value ranges allocated by this list. 132485ab413bSRiver Riddle MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() { 132585ab413bSRiver Riddle return allocatedValueRanges; 132685ab413bSRiver Riddle } 132702c4c0d5SRiver Riddle }; 1328be0a7e9fSMehdi Amini } // namespace 1329abfd1a8bSRiver Riddle 1330154cabe7SRiver Riddle void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { 1331abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); 1332abfd1a8bSRiver Riddle const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; 1333abfd1a8bSRiver Riddle ArrayAttr constParams = read<ArrayAttr>(); 1334abfd1a8bSRiver Riddle SmallVector<PDLValue, 16> args; 1335abfd1a8bSRiver Riddle readList<PDLValue>(args); 1336154cabe7SRiver Riddle 1337abfd1a8bSRiver Riddle LLVM_DEBUG({ 1338abfd1a8bSRiver Riddle llvm::dbgs() << " * Arguments: "; 1339abfd1a8bSRiver Riddle llvm::interleaveComma(args, llvm::dbgs()); 1340154cabe7SRiver Riddle llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; 1341abfd1a8bSRiver Riddle }); 1342abfd1a8bSRiver Riddle 1343abfd1a8bSRiver Riddle // Invoke the constraint and jump to the proper destination. 1344abfd1a8bSRiver Riddle selectJump(succeeded(constraintFn(args, constParams, rewriter))); 1345abfd1a8bSRiver Riddle } 1346154cabe7SRiver Riddle 1347154cabe7SRiver Riddle void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { 1348abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); 1349abfd1a8bSRiver Riddle const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; 1350abfd1a8bSRiver Riddle ArrayAttr constParams = read<ArrayAttr>(); 1351abfd1a8bSRiver Riddle SmallVector<PDLValue, 16> args; 1352abfd1a8bSRiver Riddle readList<PDLValue>(args); 1353abfd1a8bSRiver Riddle 1354abfd1a8bSRiver Riddle LLVM_DEBUG({ 135502c4c0d5SRiver Riddle llvm::dbgs() << " * Arguments: "; 1356abfd1a8bSRiver Riddle llvm::interleaveComma(args, llvm::dbgs()); 1357154cabe7SRiver Riddle llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; 1358abfd1a8bSRiver Riddle }); 135985ab413bSRiver Riddle 136085ab413bSRiver Riddle // Execute the rewrite function. 136185ab413bSRiver Riddle ByteCodeField numResults = read(); 136285ab413bSRiver Riddle ByteCodeRewriteResultList results(numResults); 136302c4c0d5SRiver Riddle rewriteFn(args, constParams, rewriter, results); 1364154cabe7SRiver Riddle 136585ab413bSRiver Riddle assert(results.getResults().size() == numResults && 136602c4c0d5SRiver Riddle "native PDL rewrite function returned unexpected number of results"); 136702c4c0d5SRiver Riddle 136802c4c0d5SRiver Riddle // Store the results in the bytecode memory. 136902c4c0d5SRiver Riddle for (PDLValue &result : results.getResults()) { 137002c4c0d5SRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); 137185ab413bSRiver Riddle 137285ab413bSRiver Riddle // In debug mode we also verify the expected kind of the result. 137385ab413bSRiver Riddle #ifndef NDEBUG 137485ab413bSRiver Riddle assert(result.getKind() == read<PDLValue::Kind>() && 137585ab413bSRiver Riddle "native PDL rewrite function returned an unexpected type of result"); 137685ab413bSRiver Riddle #endif 137785ab413bSRiver Riddle 137885ab413bSRiver Riddle // If the result is a range, we need to copy it over to the bytecodes 137985ab413bSRiver Riddle // range memory. 138085ab413bSRiver Riddle if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) { 138185ab413bSRiver Riddle unsigned rangeIndex = read(); 138285ab413bSRiver Riddle typeRangeMemory[rangeIndex] = *typeRange; 138385ab413bSRiver Riddle memory[read()] = &typeRangeMemory[rangeIndex]; 138485ab413bSRiver Riddle } else if (Optional<ValueRange> valueRange = 138585ab413bSRiver Riddle result.dyn_cast<ValueRange>()) { 138685ab413bSRiver Riddle unsigned rangeIndex = read(); 138785ab413bSRiver Riddle valueRangeMemory[rangeIndex] = *valueRange; 138885ab413bSRiver Riddle memory[read()] = &valueRangeMemory[rangeIndex]; 138985ab413bSRiver Riddle } else { 139002c4c0d5SRiver Riddle memory[read()] = result.getAsOpaquePointer(); 139102c4c0d5SRiver Riddle } 1392abfd1a8bSRiver Riddle } 1393154cabe7SRiver Riddle 139485ab413bSRiver Riddle // Copy over any underlying storage allocated for result ranges. 139585ab413bSRiver Riddle for (auto &it : results.getAllocatedTypeRanges()) 139685ab413bSRiver Riddle allocatedTypeRangeMemory.push_back(std::move(it)); 139785ab413bSRiver Riddle for (auto &it : results.getAllocatedValueRanges()) 139885ab413bSRiver Riddle allocatedValueRangeMemory.push_back(std::move(it)); 139985ab413bSRiver Riddle } 140085ab413bSRiver Riddle 1401154cabe7SRiver Riddle void ByteCodeExecutor::executeAreEqual() { 1402abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 1403abfd1a8bSRiver Riddle const void *lhs = read<const void *>(); 1404abfd1a8bSRiver Riddle const void *rhs = read<const void *>(); 1405abfd1a8bSRiver Riddle 1406154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n"); 1407abfd1a8bSRiver Riddle selectJump(lhs == rhs); 1408abfd1a8bSRiver Riddle } 1409154cabe7SRiver Riddle 141085ab413bSRiver Riddle void ByteCodeExecutor::executeAreRangesEqual() { 141185ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n"); 141285ab413bSRiver Riddle PDLValue::Kind valueKind = read<PDLValue::Kind>(); 141385ab413bSRiver Riddle const void *lhs = read<const void *>(); 141485ab413bSRiver Riddle const void *rhs = read<const void *>(); 141585ab413bSRiver Riddle 141685ab413bSRiver Riddle switch (valueKind) { 141785ab413bSRiver Riddle case PDLValue::Kind::TypeRange: { 141885ab413bSRiver Riddle const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs); 141985ab413bSRiver Riddle const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs); 142085ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 142185ab413bSRiver Riddle selectJump(*lhsRange == *rhsRange); 142285ab413bSRiver Riddle break; 142385ab413bSRiver Riddle } 142485ab413bSRiver Riddle case PDLValue::Kind::ValueRange: { 142585ab413bSRiver Riddle const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs); 142685ab413bSRiver Riddle const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs); 142785ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 142885ab413bSRiver Riddle selectJump(*lhsRange == *rhsRange); 142985ab413bSRiver Riddle break; 143085ab413bSRiver Riddle } 143185ab413bSRiver Riddle default: 143285ab413bSRiver Riddle llvm_unreachable("unexpected `AreRangesEqual` value kind"); 143385ab413bSRiver Riddle } 143485ab413bSRiver Riddle } 143585ab413bSRiver Riddle 1436154cabe7SRiver Riddle void ByteCodeExecutor::executeBranch() { 1437154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n"); 1438abfd1a8bSRiver Riddle curCodeIt = &code[read<ByteCodeAddr>()]; 1439abfd1a8bSRiver Riddle } 1440154cabe7SRiver Riddle 1441154cabe7SRiver Riddle void ByteCodeExecutor::executeCheckOperandCount() { 1442abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); 1443abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1444abfd1a8bSRiver Riddle uint32_t expectedCount = read<uint32_t>(); 144585ab413bSRiver Riddle bool compareAtLeast = read(); 1446abfd1a8bSRiver Riddle 1447abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" 144885ab413bSRiver Riddle << " * Expected: " << expectedCount << "\n" 144985ab413bSRiver Riddle << " * Comparator: " 145085ab413bSRiver Riddle << (compareAtLeast ? ">=" : "==") << "\n"); 145185ab413bSRiver Riddle if (compareAtLeast) 145285ab413bSRiver Riddle selectJump(op->getNumOperands() >= expectedCount); 145385ab413bSRiver Riddle else 1454abfd1a8bSRiver Riddle selectJump(op->getNumOperands() == expectedCount); 1455abfd1a8bSRiver Riddle } 1456154cabe7SRiver Riddle 1457154cabe7SRiver Riddle void ByteCodeExecutor::executeCheckOperationName() { 1458abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n"); 1459abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1460abfd1a8bSRiver Riddle OperationName expectedName = read<OperationName>(); 1461abfd1a8bSRiver Riddle 1462154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n" 1463154cabe7SRiver Riddle << " * Expected: \"" << expectedName << "\"\n"); 1464abfd1a8bSRiver Riddle selectJump(op->getName() == expectedName); 1465abfd1a8bSRiver Riddle } 1466154cabe7SRiver Riddle 1467154cabe7SRiver Riddle void ByteCodeExecutor::executeCheckResultCount() { 1468abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); 1469abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1470abfd1a8bSRiver Riddle uint32_t expectedCount = read<uint32_t>(); 147185ab413bSRiver Riddle bool compareAtLeast = read(); 1472abfd1a8bSRiver Riddle 1473abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" 147485ab413bSRiver Riddle << " * Expected: " << expectedCount << "\n" 147585ab413bSRiver Riddle << " * Comparator: " 147685ab413bSRiver Riddle << (compareAtLeast ? ">=" : "==") << "\n"); 147785ab413bSRiver Riddle if (compareAtLeast) 147885ab413bSRiver Riddle selectJump(op->getNumResults() >= expectedCount); 147985ab413bSRiver Riddle else 1480abfd1a8bSRiver Riddle selectJump(op->getNumResults() == expectedCount); 1481abfd1a8bSRiver Riddle } 1482154cabe7SRiver Riddle 148385ab413bSRiver Riddle void ByteCodeExecutor::executeCheckTypes() { 148485ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 148585ab413bSRiver Riddle TypeRange *lhs = read<TypeRange *>(); 148685ab413bSRiver Riddle Attribute rhs = read<Attribute>(); 148785ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 148885ab413bSRiver Riddle 148985ab413bSRiver Riddle selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>()); 149085ab413bSRiver Riddle } 149185ab413bSRiver Riddle 14923eb1647aSStanislav Funiak void ByteCodeExecutor::executeContinue() { 14933eb1647aSStanislav Funiak ByteCodeField level = read(); 14943eb1647aSStanislav Funiak LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n" 14953eb1647aSStanislav Funiak << " * Level: " << level << "\n"); 14963eb1647aSStanislav Funiak ++loopIndex[level]; 14973eb1647aSStanislav Funiak popCodeIt(); 14983eb1647aSStanislav Funiak } 14993eb1647aSStanislav Funiak 150085ab413bSRiver Riddle void ByteCodeExecutor::executeCreateTypes() { 150185ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n"); 150285ab413bSRiver Riddle unsigned memIndex = read(); 150385ab413bSRiver Riddle unsigned rangeIndex = read(); 150485ab413bSRiver Riddle ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>(); 150585ab413bSRiver Riddle 150685ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n"); 150785ab413bSRiver Riddle 150885ab413bSRiver Riddle // Allocate a buffer for this type range. 150985ab413bSRiver Riddle llvm::OwningArrayRef<Type> storage(typesAttr.size()); 151085ab413bSRiver Riddle llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin()); 151185ab413bSRiver Riddle allocatedTypeRangeMemory.emplace_back(std::move(storage)); 151285ab413bSRiver Riddle 151385ab413bSRiver Riddle // Assign this to the range slot and use the range as the value for the 151485ab413bSRiver Riddle // memory index. 151585ab413bSRiver Riddle typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back(); 151685ab413bSRiver Riddle memory[memIndex] = &typeRangeMemory[rangeIndex]; 151785ab413bSRiver Riddle } 151885ab413bSRiver Riddle 1519154cabe7SRiver Riddle void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, 1520154cabe7SRiver Riddle Location mainRewriteLoc) { 1521abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); 1522abfd1a8bSRiver Riddle 1523abfd1a8bSRiver Riddle unsigned memIndex = read(); 1524154cabe7SRiver Riddle OperationState state(mainRewriteLoc, read<OperationName>()); 152585ab413bSRiver Riddle readValueList(state.operands); 1526abfd1a8bSRiver Riddle for (unsigned i = 0, e = read(); i != e; ++i) { 1527195730a6SRiver Riddle StringAttr name = read<StringAttr>(); 1528abfd1a8bSRiver Riddle if (Attribute attr = read<Attribute>()) 1529abfd1a8bSRiver Riddle state.addAttribute(name, attr); 1530abfd1a8bSRiver Riddle } 1531abfd1a8bSRiver Riddle 1532abfd1a8bSRiver Riddle for (unsigned i = 0, e = read(); i != e; ++i) { 153385ab413bSRiver Riddle if (read<PDLValue::Kind>() == PDLValue::Kind::Type) { 153485ab413bSRiver Riddle state.types.push_back(read<Type>()); 153585ab413bSRiver Riddle continue; 153685ab413bSRiver Riddle } 153785ab413bSRiver Riddle 153885ab413bSRiver Riddle // If we find a null range, this signals that the types are infered. 153985ab413bSRiver Riddle if (TypeRange *resultTypes = read<TypeRange *>()) { 154085ab413bSRiver Riddle state.types.append(resultTypes->begin(), resultTypes->end()); 154185ab413bSRiver Riddle continue; 1542abfd1a8bSRiver Riddle } 1543abfd1a8bSRiver Riddle 1544abfd1a8bSRiver Riddle // Handle the case where the operation has inferred types. 1545*ea7be7e3SBenjamin Kramer InferTypeOpInterface::Concept *inferInterface = 1546edc6c0ecSRiver Riddle state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>(); 1547abfd1a8bSRiver Riddle 1548abfd1a8bSRiver Riddle // TODO: Handle failure. 15493a833a0eSRiver Riddle state.types.clear(); 1550*ea7be7e3SBenjamin Kramer if (failed(inferInterface->inferReturnTypes( 1551abfd1a8bSRiver Riddle state.getContext(), state.location, state.operands, 1552154cabe7SRiver Riddle state.attributes.getDictionary(state.getContext()), state.regions, 15533a833a0eSRiver Riddle state.types))) 1554abfd1a8bSRiver Riddle return; 155585ab413bSRiver Riddle break; 1556abfd1a8bSRiver Riddle } 155785ab413bSRiver Riddle 1558abfd1a8bSRiver Riddle Operation *resultOp = rewriter.createOperation(state); 1559abfd1a8bSRiver Riddle memory[memIndex] = resultOp; 1560abfd1a8bSRiver Riddle 1561abfd1a8bSRiver Riddle LLVM_DEBUG({ 1562abfd1a8bSRiver Riddle llvm::dbgs() << " * Attributes: " 1563abfd1a8bSRiver Riddle << state.attributes.getDictionary(state.getContext()) 1564abfd1a8bSRiver Riddle << "\n * Operands: "; 1565abfd1a8bSRiver Riddle llvm::interleaveComma(state.operands, llvm::dbgs()); 1566abfd1a8bSRiver Riddle llvm::dbgs() << "\n * Result Types: "; 1567abfd1a8bSRiver Riddle llvm::interleaveComma(state.types, llvm::dbgs()); 1568154cabe7SRiver Riddle llvm::dbgs() << "\n * Result: " << *resultOp << "\n"; 1569abfd1a8bSRiver Riddle }); 1570abfd1a8bSRiver Riddle } 1571154cabe7SRiver Riddle 1572154cabe7SRiver Riddle void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) { 1573abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n"); 1574abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1575abfd1a8bSRiver Riddle 1576154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1577abfd1a8bSRiver Riddle rewriter.eraseOp(op); 1578abfd1a8bSRiver Riddle } 1579154cabe7SRiver Riddle 15803eb1647aSStanislav Funiak template <typename T, typename Range, PDLValue::Kind kind> 15813eb1647aSStanislav Funiak void ByteCodeExecutor::executeExtract() { 15823eb1647aSStanislav Funiak LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n"); 15833eb1647aSStanislav Funiak Range *range = read<Range *>(); 15843eb1647aSStanislav Funiak unsigned index = read<uint32_t>(); 15853eb1647aSStanislav Funiak unsigned memIndex = read(); 15863eb1647aSStanislav Funiak 15873eb1647aSStanislav Funiak if (!range) { 15883eb1647aSStanislav Funiak memory[memIndex] = nullptr; 15893eb1647aSStanislav Funiak return; 15903eb1647aSStanislav Funiak } 15913eb1647aSStanislav Funiak 15923eb1647aSStanislav Funiak T result = index < range->size() ? (*range)[index] : T(); 15933eb1647aSStanislav Funiak LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n" 15943eb1647aSStanislav Funiak << " * Index: " << index << "\n" 15953eb1647aSStanislav Funiak << " * Result: " << result << "\n"); 15963eb1647aSStanislav Funiak storeToMemory(memIndex, result); 15973eb1647aSStanislav Funiak } 15983eb1647aSStanislav Funiak 15993eb1647aSStanislav Funiak void ByteCodeExecutor::executeFinalize() { 16003eb1647aSStanislav Funiak LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n"); 16013eb1647aSStanislav Funiak } 16023eb1647aSStanislav Funiak 16033eb1647aSStanislav Funiak void ByteCodeExecutor::executeForEach() { 16043eb1647aSStanislav Funiak LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n"); 1605d35f1190SStanislav Funiak const ByteCodeField *prevCodeIt = getPrevCodeIt(); 16063eb1647aSStanislav Funiak unsigned rangeIndex = read(); 16073eb1647aSStanislav Funiak unsigned memIndex = read(); 16083eb1647aSStanislav Funiak const void *value = nullptr; 16093eb1647aSStanislav Funiak 16103eb1647aSStanislav Funiak switch (read<PDLValue::Kind>()) { 16113eb1647aSStanislav Funiak case PDLValue::Kind::Operation: { 16123eb1647aSStanislav Funiak unsigned &index = loopIndex[read()]; 16133eb1647aSStanislav Funiak ArrayRef<Operation *> array = opRangeMemory[rangeIndex]; 16143eb1647aSStanislav Funiak assert(index <= array.size() && "iterated past the end"); 16153eb1647aSStanislav Funiak if (index < array.size()) { 16163eb1647aSStanislav Funiak LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n"); 16173eb1647aSStanislav Funiak value = array[index]; 16183eb1647aSStanislav Funiak break; 16193eb1647aSStanislav Funiak } 16203eb1647aSStanislav Funiak 16213eb1647aSStanislav Funiak LLVM_DEBUG(llvm::dbgs() << " * Done\n"); 16223eb1647aSStanislav Funiak index = 0; 16233eb1647aSStanislav Funiak selectJump(size_t(0)); 16243eb1647aSStanislav Funiak return; 16253eb1647aSStanislav Funiak } 16263eb1647aSStanislav Funiak default: 16273eb1647aSStanislav Funiak llvm_unreachable("unexpected `ForEach` value kind"); 16283eb1647aSStanislav Funiak } 16293eb1647aSStanislav Funiak 16303eb1647aSStanislav Funiak // Store the iterate value and the stack address. 16313eb1647aSStanislav Funiak memory[memIndex] = value; 1632d35f1190SStanislav Funiak pushCodeIt(prevCodeIt); 16333eb1647aSStanislav Funiak 16343eb1647aSStanislav Funiak // Skip over the successor (we will enter the body of the loop). 16353eb1647aSStanislav Funiak read<ByteCodeAddr>(); 16363eb1647aSStanislav Funiak } 16373eb1647aSStanislav Funiak 1638154cabe7SRiver Riddle void ByteCodeExecutor::executeGetAttribute() { 1639abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); 1640abfd1a8bSRiver Riddle unsigned memIndex = read(); 1641abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1642195730a6SRiver Riddle StringAttr attrName = read<StringAttr>(); 1643abfd1a8bSRiver Riddle Attribute attr = op->getAttr(attrName); 1644abfd1a8bSRiver Riddle 1645abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1646abfd1a8bSRiver Riddle << " * Attribute: " << attrName << "\n" 1647154cabe7SRiver Riddle << " * Result: " << attr << "\n"); 1648abfd1a8bSRiver Riddle memory[memIndex] = attr.getAsOpaquePointer(); 1649abfd1a8bSRiver Riddle } 1650154cabe7SRiver Riddle 1651154cabe7SRiver Riddle void ByteCodeExecutor::executeGetAttributeType() { 1652abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n"); 1653abfd1a8bSRiver Riddle unsigned memIndex = read(); 1654abfd1a8bSRiver Riddle Attribute attr = read<Attribute>(); 1655154cabe7SRiver Riddle Type type = attr ? attr.getType() : Type(); 1656abfd1a8bSRiver Riddle 1657abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" 1658154cabe7SRiver Riddle << " * Result: " << type << "\n"); 1659154cabe7SRiver Riddle memory[memIndex] = type.getAsOpaquePointer(); 1660abfd1a8bSRiver Riddle } 1661154cabe7SRiver Riddle 1662154cabe7SRiver Riddle void ByteCodeExecutor::executeGetDefiningOp() { 1663abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); 1664abfd1a8bSRiver Riddle unsigned memIndex = read(); 166585ab413bSRiver Riddle Operation *op = nullptr; 166685ab413bSRiver Riddle if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1667abfd1a8bSRiver Riddle Value value = read<Value>(); 166885ab413bSRiver Riddle if (value) 166985ab413bSRiver Riddle op = value.getDefiningOp(); 167085ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 167185ab413bSRiver Riddle } else { 167285ab413bSRiver Riddle ValueRange *values = read<ValueRange *>(); 167385ab413bSRiver Riddle if (values && !values->empty()) { 167485ab413bSRiver Riddle op = values->front().getDefiningOp(); 167585ab413bSRiver Riddle } 167685ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n"); 167785ab413bSRiver Riddle } 1678abfd1a8bSRiver Riddle 167985ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n"); 1680abfd1a8bSRiver Riddle memory[memIndex] = op; 1681abfd1a8bSRiver Riddle } 1682154cabe7SRiver Riddle 1683154cabe7SRiver Riddle void ByteCodeExecutor::executeGetOperand(unsigned index) { 1684abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1685abfd1a8bSRiver Riddle unsigned memIndex = read(); 1686abfd1a8bSRiver Riddle Value operand = 1687abfd1a8bSRiver Riddle index < op->getNumOperands() ? op->getOperand(index) : Value(); 1688abfd1a8bSRiver Riddle 1689abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1690abfd1a8bSRiver Riddle << " * Index: " << index << "\n" 1691154cabe7SRiver Riddle << " * Result: " << operand << "\n"); 1692abfd1a8bSRiver Riddle memory[memIndex] = operand.getAsOpaquePointer(); 1693abfd1a8bSRiver Riddle } 1694154cabe7SRiver Riddle 169585ab413bSRiver Riddle /// This function is the internal implementation of `GetResults` and 169685ab413bSRiver Riddle /// `GetOperands` that provides support for extracting a value range from the 169785ab413bSRiver Riddle /// given operation. 169885ab413bSRiver Riddle template <template <typename> class AttrSizedSegmentsT, typename RangeT> 169985ab413bSRiver Riddle static void * 170085ab413bSRiver Riddle executeGetOperandsResults(RangeT values, Operation *op, unsigned index, 170185ab413bSRiver Riddle ByteCodeField rangeIndex, StringRef attrSizedSegments, 17023eb1647aSStanislav Funiak MutableArrayRef<ValueRange> valueRangeMemory) { 170385ab413bSRiver Riddle // Check for the sentinel index that signals that all values should be 170485ab413bSRiver Riddle // returned. 170585ab413bSRiver Riddle if (index == std::numeric_limits<uint32_t>::max()) { 170685ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n"); 170785ab413bSRiver Riddle // `values` is already the full value range. 170885ab413bSRiver Riddle 170985ab413bSRiver Riddle // Otherwise, check to see if this operation uses AttrSizedSegments. 171085ab413bSRiver Riddle } else if (op->hasTrait<AttrSizedSegmentsT>()) { 171185ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() 171285ab413bSRiver Riddle << " * Extracting values from `" << attrSizedSegments << "`\n"); 171385ab413bSRiver Riddle 171485ab413bSRiver Riddle auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments); 171585ab413bSRiver Riddle if (!segmentAttr || segmentAttr.getNumElements() <= index) 171685ab413bSRiver Riddle return nullptr; 171785ab413bSRiver Riddle 171885ab413bSRiver Riddle auto segments = segmentAttr.getValues<int32_t>(); 171985ab413bSRiver Riddle unsigned startIndex = 172085ab413bSRiver Riddle std::accumulate(segments.begin(), segments.begin() + index, 0); 172185ab413bSRiver Riddle values = values.slice(startIndex, *std::next(segments.begin(), index)); 172285ab413bSRiver Riddle 172385ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", " 172485ab413bSRiver Riddle << *std::next(segments.begin(), index) << "]\n"); 172585ab413bSRiver Riddle 172685ab413bSRiver Riddle // Otherwise, assume this is the last operand group of the operation. 172785ab413bSRiver Riddle // FIXME: We currently don't support operations with 172885ab413bSRiver Riddle // SameVariadicOperandSize/SameVariadicResultSize here given that we don't 172985ab413bSRiver Riddle // have a way to detect it's presence. 173085ab413bSRiver Riddle } else if (values.size() >= index) { 173185ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() 173285ab413bSRiver Riddle << " * Treating values as trailing variadic range\n"); 173385ab413bSRiver Riddle values = values.drop_front(index); 173485ab413bSRiver Riddle 173585ab413bSRiver Riddle // If we couldn't detect a way to compute the values, bail out. 173685ab413bSRiver Riddle } else { 173785ab413bSRiver Riddle return nullptr; 173885ab413bSRiver Riddle } 173985ab413bSRiver Riddle 174085ab413bSRiver Riddle // If the range index is valid, we are returning a range. 174185ab413bSRiver Riddle if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) { 174285ab413bSRiver Riddle valueRangeMemory[rangeIndex] = values; 174385ab413bSRiver Riddle return &valueRangeMemory[rangeIndex]; 174485ab413bSRiver Riddle } 174585ab413bSRiver Riddle 174685ab413bSRiver Riddle // If a range index wasn't provided, the range is required to be non-variadic. 174785ab413bSRiver Riddle return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer(); 174885ab413bSRiver Riddle } 174985ab413bSRiver Riddle 175085ab413bSRiver Riddle void ByteCodeExecutor::executeGetOperands() { 175185ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n"); 175285ab413bSRiver Riddle unsigned index = read<uint32_t>(); 175385ab413bSRiver Riddle Operation *op = read<Operation *>(); 175485ab413bSRiver Riddle ByteCodeField rangeIndex = read(); 175585ab413bSRiver Riddle 175685ab413bSRiver Riddle void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>( 175785ab413bSRiver Riddle op->getOperands(), op, index, rangeIndex, "operand_segment_sizes", 175885ab413bSRiver Riddle valueRangeMemory); 175985ab413bSRiver Riddle if (!result) 176085ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n"); 176185ab413bSRiver Riddle memory[read()] = result; 176285ab413bSRiver Riddle } 176385ab413bSRiver Riddle 1764154cabe7SRiver Riddle void ByteCodeExecutor::executeGetResult(unsigned index) { 1765abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1766abfd1a8bSRiver Riddle unsigned memIndex = read(); 1767abfd1a8bSRiver Riddle OpResult result = 1768abfd1a8bSRiver Riddle index < op->getNumResults() ? op->getResult(index) : OpResult(); 1769abfd1a8bSRiver Riddle 1770abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1771abfd1a8bSRiver Riddle << " * Index: " << index << "\n" 1772154cabe7SRiver Riddle << " * Result: " << result << "\n"); 1773abfd1a8bSRiver Riddle memory[memIndex] = result.getAsOpaquePointer(); 1774abfd1a8bSRiver Riddle } 1775154cabe7SRiver Riddle 177685ab413bSRiver Riddle void ByteCodeExecutor::executeGetResults() { 177785ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n"); 177885ab413bSRiver Riddle unsigned index = read<uint32_t>(); 177985ab413bSRiver Riddle Operation *op = read<Operation *>(); 178085ab413bSRiver Riddle ByteCodeField rangeIndex = read(); 178185ab413bSRiver Riddle 178285ab413bSRiver Riddle void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>( 178385ab413bSRiver Riddle op->getResults(), op, index, rangeIndex, "result_segment_sizes", 178485ab413bSRiver Riddle valueRangeMemory); 178585ab413bSRiver Riddle if (!result) 178685ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n"); 178785ab413bSRiver Riddle memory[read()] = result; 178885ab413bSRiver Riddle } 178985ab413bSRiver Riddle 17903eb1647aSStanislav Funiak void ByteCodeExecutor::executeGetUsers() { 17913eb1647aSStanislav Funiak LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n"); 17923eb1647aSStanislav Funiak unsigned memIndex = read(); 17933eb1647aSStanislav Funiak unsigned rangeIndex = read(); 17943eb1647aSStanislav Funiak OwningOpRange &range = opRangeMemory[rangeIndex]; 17953eb1647aSStanislav Funiak memory[memIndex] = ⦥ 17963eb1647aSStanislav Funiak 17973eb1647aSStanislav Funiak range = OwningOpRange(); 17983eb1647aSStanislav Funiak if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 17993eb1647aSStanislav Funiak // Read the value. 18003eb1647aSStanislav Funiak Value value = read<Value>(); 18013eb1647aSStanislav Funiak if (!value) 18023eb1647aSStanislav Funiak return; 18033eb1647aSStanislav Funiak LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 18043eb1647aSStanislav Funiak 18053eb1647aSStanislav Funiak // Extract the users of a single value. 18063eb1647aSStanislav Funiak range = OwningOpRange(std::distance(value.user_begin(), value.user_end())); 18073eb1647aSStanislav Funiak llvm::copy(value.getUsers(), range.begin()); 18083eb1647aSStanislav Funiak } else { 18093eb1647aSStanislav Funiak // Read a range of values. 18103eb1647aSStanislav Funiak ValueRange *values = read<ValueRange *>(); 18113eb1647aSStanislav Funiak if (!values) 18123eb1647aSStanislav Funiak return; 18133eb1647aSStanislav Funiak LLVM_DEBUG({ 18143eb1647aSStanislav Funiak llvm::dbgs() << " * Values (" << values->size() << "): "; 18153eb1647aSStanislav Funiak llvm::interleaveComma(*values, llvm::dbgs()); 18163eb1647aSStanislav Funiak llvm::dbgs() << "\n"; 18173eb1647aSStanislav Funiak }); 18183eb1647aSStanislav Funiak 18193eb1647aSStanislav Funiak // Extract all the users of a range of values. 18203eb1647aSStanislav Funiak SmallVector<Operation *> users; 18213eb1647aSStanislav Funiak for (Value value : *values) 18223eb1647aSStanislav Funiak users.append(value.user_begin(), value.user_end()); 18233eb1647aSStanislav Funiak range = OwningOpRange(users.size()); 18243eb1647aSStanislav Funiak llvm::copy(users, range.begin()); 18253eb1647aSStanislav Funiak } 18263eb1647aSStanislav Funiak 18273eb1647aSStanislav Funiak LLVM_DEBUG(llvm::dbgs() << " * Result: " << range.size() << " operations\n"); 18283eb1647aSStanislav Funiak } 18293eb1647aSStanislav Funiak 1830154cabe7SRiver Riddle void ByteCodeExecutor::executeGetValueType() { 1831abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); 1832abfd1a8bSRiver Riddle unsigned memIndex = read(); 1833abfd1a8bSRiver Riddle Value value = read<Value>(); 1834154cabe7SRiver Riddle Type type = value ? value.getType() : Type(); 1835abfd1a8bSRiver Riddle 1836abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" 1837154cabe7SRiver Riddle << " * Result: " << type << "\n"); 1838154cabe7SRiver Riddle memory[memIndex] = type.getAsOpaquePointer(); 1839abfd1a8bSRiver Riddle } 1840154cabe7SRiver Riddle 184185ab413bSRiver Riddle void ByteCodeExecutor::executeGetValueRangeTypes() { 184285ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n"); 184385ab413bSRiver Riddle unsigned memIndex = read(); 184485ab413bSRiver Riddle unsigned rangeIndex = read(); 184585ab413bSRiver Riddle ValueRange *values = read<ValueRange *>(); 184685ab413bSRiver Riddle if (!values) { 184785ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n"); 184885ab413bSRiver Riddle memory[memIndex] = nullptr; 184985ab413bSRiver Riddle return; 185085ab413bSRiver Riddle } 185185ab413bSRiver Riddle 185285ab413bSRiver Riddle LLVM_DEBUG({ 185385ab413bSRiver Riddle llvm::dbgs() << " * Values (" << values->size() << "): "; 185485ab413bSRiver Riddle llvm::interleaveComma(*values, llvm::dbgs()); 185585ab413bSRiver Riddle llvm::dbgs() << "\n * Result: "; 185685ab413bSRiver Riddle llvm::interleaveComma(values->getType(), llvm::dbgs()); 185785ab413bSRiver Riddle llvm::dbgs() << "\n"; 185885ab413bSRiver Riddle }); 185985ab413bSRiver Riddle typeRangeMemory[rangeIndex] = values->getType(); 186085ab413bSRiver Riddle memory[memIndex] = &typeRangeMemory[rangeIndex]; 186185ab413bSRiver Riddle } 186285ab413bSRiver Riddle 1863154cabe7SRiver Riddle void ByteCodeExecutor::executeIsNotNull() { 1864abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); 1865abfd1a8bSRiver Riddle const void *value = read<const void *>(); 1866abfd1a8bSRiver Riddle 1867154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1868abfd1a8bSRiver Riddle selectJump(value != nullptr); 1869abfd1a8bSRiver Riddle } 1870154cabe7SRiver Riddle 1871154cabe7SRiver Riddle void ByteCodeExecutor::executeRecordMatch( 1872154cabe7SRiver Riddle PatternRewriter &rewriter, 1873154cabe7SRiver Riddle SmallVectorImpl<PDLByteCode::MatchResult> &matches) { 1874abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n"); 1875abfd1a8bSRiver Riddle unsigned patternIndex = read(); 1876abfd1a8bSRiver Riddle PatternBenefit benefit = currentPatternBenefits[patternIndex]; 1877abfd1a8bSRiver Riddle const ByteCodeField *dest = &code[read<ByteCodeAddr>()]; 1878abfd1a8bSRiver Riddle 1879abfd1a8bSRiver Riddle // If the benefit of the pattern is impossible, skip the processing of the 1880abfd1a8bSRiver Riddle // rest of the pattern. 1881abfd1a8bSRiver Riddle if (benefit.isImpossibleToMatch()) { 1882154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n"); 1883abfd1a8bSRiver Riddle curCodeIt = dest; 1884154cabe7SRiver Riddle return; 1885abfd1a8bSRiver Riddle } 1886abfd1a8bSRiver Riddle 1887abfd1a8bSRiver Riddle // Create a fused location containing the locations of each of the 1888abfd1a8bSRiver Riddle // operations used in the match. This will be used as the location for 1889abfd1a8bSRiver Riddle // created operations during the rewrite that don't already have an 1890abfd1a8bSRiver Riddle // explicit location set. 1891abfd1a8bSRiver Riddle unsigned numMatchLocs = read(); 1892abfd1a8bSRiver Riddle SmallVector<Location, 4> matchLocs; 1893abfd1a8bSRiver Riddle matchLocs.reserve(numMatchLocs); 1894abfd1a8bSRiver Riddle for (unsigned i = 0; i != numMatchLocs; ++i) 1895abfd1a8bSRiver Riddle matchLocs.push_back(read<Operation *>()->getLoc()); 1896abfd1a8bSRiver Riddle Location matchLoc = rewriter.getFusedLoc(matchLocs); 1897abfd1a8bSRiver Riddle 1898abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n" 1899154cabe7SRiver Riddle << " * Location: " << matchLoc << "\n"); 1900154cabe7SRiver Riddle matches.emplace_back(matchLoc, patterns[patternIndex], benefit); 190185ab413bSRiver Riddle PDLByteCode::MatchResult &match = matches.back(); 190285ab413bSRiver Riddle 190385ab413bSRiver Riddle // Record all of the inputs to the match. If any of the inputs are ranges, we 190485ab413bSRiver Riddle // will also need to remap the range pointer to memory stored in the match 190585ab413bSRiver Riddle // state. 190685ab413bSRiver Riddle unsigned numInputs = read(); 190785ab413bSRiver Riddle match.values.reserve(numInputs); 190885ab413bSRiver Riddle match.typeRangeValues.reserve(numInputs); 190985ab413bSRiver Riddle match.valueRangeValues.reserve(numInputs); 191085ab413bSRiver Riddle for (unsigned i = 0; i < numInputs; ++i) { 191185ab413bSRiver Riddle switch (read<PDLValue::Kind>()) { 191285ab413bSRiver Riddle case PDLValue::Kind::TypeRange: 191385ab413bSRiver Riddle match.typeRangeValues.push_back(*read<TypeRange *>()); 191485ab413bSRiver Riddle match.values.push_back(&match.typeRangeValues.back()); 191585ab413bSRiver Riddle break; 191685ab413bSRiver Riddle case PDLValue::Kind::ValueRange: 191785ab413bSRiver Riddle match.valueRangeValues.push_back(*read<ValueRange *>()); 191885ab413bSRiver Riddle match.values.push_back(&match.valueRangeValues.back()); 191985ab413bSRiver Riddle break; 192085ab413bSRiver Riddle default: 192185ab413bSRiver Riddle match.values.push_back(read<const void *>()); 192285ab413bSRiver Riddle break; 192385ab413bSRiver Riddle } 192485ab413bSRiver Riddle } 1925abfd1a8bSRiver Riddle curCodeIt = dest; 1926abfd1a8bSRiver Riddle } 1927154cabe7SRiver Riddle 1928154cabe7SRiver Riddle void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) { 1929abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); 1930abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1931abfd1a8bSRiver Riddle SmallVector<Value, 16> args; 193285ab413bSRiver Riddle readValueList(args); 1933abfd1a8bSRiver Riddle 1934abfd1a8bSRiver Riddle LLVM_DEBUG({ 1935abfd1a8bSRiver Riddle llvm::dbgs() << " * Operation: " << *op << "\n" 1936abfd1a8bSRiver Riddle << " * Values: "; 1937abfd1a8bSRiver Riddle llvm::interleaveComma(args, llvm::dbgs()); 1938154cabe7SRiver Riddle llvm::dbgs() << "\n"; 1939abfd1a8bSRiver Riddle }); 1940abfd1a8bSRiver Riddle rewriter.replaceOp(op, args); 1941abfd1a8bSRiver Riddle } 1942154cabe7SRiver Riddle 1943154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchAttribute() { 1944abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n"); 1945abfd1a8bSRiver Riddle Attribute value = read<Attribute>(); 1946abfd1a8bSRiver Riddle ArrayAttr cases = read<ArrayAttr>(); 1947abfd1a8bSRiver Riddle handleSwitch(value, cases); 1948abfd1a8bSRiver Riddle } 1949154cabe7SRiver Riddle 1950154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchOperandCount() { 1951abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n"); 1952abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1953abfd1a8bSRiver Riddle auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1954abfd1a8bSRiver Riddle 1955abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1956abfd1a8bSRiver Riddle handleSwitch(op->getNumOperands(), cases); 1957abfd1a8bSRiver Riddle } 1958154cabe7SRiver Riddle 1959154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchOperationName() { 1960abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n"); 1961abfd1a8bSRiver Riddle OperationName value = read<Operation *>()->getName(); 1962abfd1a8bSRiver Riddle size_t caseCount = read(); 1963abfd1a8bSRiver Riddle 1964abfd1a8bSRiver Riddle // The operation names are stored in-line, so to print them out for 1965abfd1a8bSRiver Riddle // debugging purposes we need to read the array before executing the 1966abfd1a8bSRiver Riddle // switch so that we can display all of the possible values. 1967abfd1a8bSRiver Riddle LLVM_DEBUG({ 1968abfd1a8bSRiver Riddle const ByteCodeField *prevCodeIt = curCodeIt; 1969abfd1a8bSRiver Riddle llvm::dbgs() << " * Value: " << value << "\n" 1970abfd1a8bSRiver Riddle << " * Cases: "; 1971abfd1a8bSRiver Riddle llvm::interleaveComma( 1972abfd1a8bSRiver Riddle llvm::map_range(llvm::seq<size_t>(0, caseCount), 1973154cabe7SRiver Riddle [&](size_t) { return read<OperationName>(); }), 1974abfd1a8bSRiver Riddle llvm::dbgs()); 1975154cabe7SRiver Riddle llvm::dbgs() << "\n"; 1976abfd1a8bSRiver Riddle curCodeIt = prevCodeIt; 1977abfd1a8bSRiver Riddle }); 1978abfd1a8bSRiver Riddle 1979abfd1a8bSRiver Riddle // Try to find the switch value within any of the cases. 1980abfd1a8bSRiver Riddle for (size_t i = 0; i != caseCount; ++i) { 1981abfd1a8bSRiver Riddle if (read<OperationName>() == value) { 1982abfd1a8bSRiver Riddle curCodeIt += (caseCount - i - 1); 1983154cabe7SRiver Riddle return selectJump(i + 1); 1984abfd1a8bSRiver Riddle } 1985abfd1a8bSRiver Riddle } 1986154cabe7SRiver Riddle selectJump(size_t(0)); 1987abfd1a8bSRiver Riddle } 1988154cabe7SRiver Riddle 1989154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchResultCount() { 1990abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n"); 1991abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1992abfd1a8bSRiver Riddle auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1993abfd1a8bSRiver Riddle 1994abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1995abfd1a8bSRiver Riddle handleSwitch(op->getNumResults(), cases); 1996abfd1a8bSRiver Riddle } 1997154cabe7SRiver Riddle 1998154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchType() { 1999abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n"); 2000abfd1a8bSRiver Riddle Type value = read<Type>(); 2001abfd1a8bSRiver Riddle auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>(); 2002abfd1a8bSRiver Riddle handleSwitch(value, cases); 2003154cabe7SRiver Riddle } 2004154cabe7SRiver Riddle 200585ab413bSRiver Riddle void ByteCodeExecutor::executeSwitchTypes() { 200685ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n"); 200785ab413bSRiver Riddle TypeRange *value = read<TypeRange *>(); 200885ab413bSRiver Riddle auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>(); 200985ab413bSRiver Riddle if (!value) { 201085ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n"); 201185ab413bSRiver Riddle return selectJump(size_t(0)); 201285ab413bSRiver Riddle } 201385ab413bSRiver Riddle handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) { 201485ab413bSRiver Riddle return value == caseValue.getAsValueRange<TypeAttr>(); 201585ab413bSRiver Riddle }); 201685ab413bSRiver Riddle } 201785ab413bSRiver Riddle 2018154cabe7SRiver Riddle void ByteCodeExecutor::execute( 2019154cabe7SRiver Riddle PatternRewriter &rewriter, 2020154cabe7SRiver Riddle SmallVectorImpl<PDLByteCode::MatchResult> *matches, 2021154cabe7SRiver Riddle Optional<Location> mainRewriteLoc) { 2022154cabe7SRiver Riddle while (true) { 2023d35f1190SStanislav Funiak // Print the location of the operation being executed. 2024d35f1190SStanislav Funiak LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n"); 2025d35f1190SStanislav Funiak 2026154cabe7SRiver Riddle OpCode opCode = static_cast<OpCode>(read()); 2027154cabe7SRiver Riddle switch (opCode) { 2028154cabe7SRiver Riddle case ApplyConstraint: 2029154cabe7SRiver Riddle executeApplyConstraint(rewriter); 2030154cabe7SRiver Riddle break; 2031154cabe7SRiver Riddle case ApplyRewrite: 2032154cabe7SRiver Riddle executeApplyRewrite(rewriter); 2033154cabe7SRiver Riddle break; 2034154cabe7SRiver Riddle case AreEqual: 2035154cabe7SRiver Riddle executeAreEqual(); 2036154cabe7SRiver Riddle break; 203785ab413bSRiver Riddle case AreRangesEqual: 203885ab413bSRiver Riddle executeAreRangesEqual(); 203985ab413bSRiver Riddle break; 2040154cabe7SRiver Riddle case Branch: 2041154cabe7SRiver Riddle executeBranch(); 2042154cabe7SRiver Riddle break; 2043154cabe7SRiver Riddle case CheckOperandCount: 2044154cabe7SRiver Riddle executeCheckOperandCount(); 2045154cabe7SRiver Riddle break; 2046154cabe7SRiver Riddle case CheckOperationName: 2047154cabe7SRiver Riddle executeCheckOperationName(); 2048154cabe7SRiver Riddle break; 2049154cabe7SRiver Riddle case CheckResultCount: 2050154cabe7SRiver Riddle executeCheckResultCount(); 2051154cabe7SRiver Riddle break; 205285ab413bSRiver Riddle case CheckTypes: 205385ab413bSRiver Riddle executeCheckTypes(); 205485ab413bSRiver Riddle break; 20553eb1647aSStanislav Funiak case Continue: 20563eb1647aSStanislav Funiak executeContinue(); 20573eb1647aSStanislav Funiak break; 2058154cabe7SRiver Riddle case CreateOperation: 2059154cabe7SRiver Riddle executeCreateOperation(rewriter, *mainRewriteLoc); 2060154cabe7SRiver Riddle break; 206185ab413bSRiver Riddle case CreateTypes: 206285ab413bSRiver Riddle executeCreateTypes(); 206385ab413bSRiver Riddle break; 2064154cabe7SRiver Riddle case EraseOp: 2065154cabe7SRiver Riddle executeEraseOp(rewriter); 2066154cabe7SRiver Riddle break; 20673eb1647aSStanislav Funiak case ExtractOp: 20683eb1647aSStanislav Funiak executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>(); 20693eb1647aSStanislav Funiak break; 20703eb1647aSStanislav Funiak case ExtractType: 20713eb1647aSStanislav Funiak executeExtract<Type, TypeRange, PDLValue::Kind::Type>(); 20723eb1647aSStanislav Funiak break; 20733eb1647aSStanislav Funiak case ExtractValue: 20743eb1647aSStanislav Funiak executeExtract<Value, ValueRange, PDLValue::Kind::Value>(); 20753eb1647aSStanislav Funiak break; 2076154cabe7SRiver Riddle case Finalize: 20773eb1647aSStanislav Funiak executeFinalize(); 20783eb1647aSStanislav Funiak LLVM_DEBUG(llvm::dbgs() << "\n"); 2079154cabe7SRiver Riddle return; 20803eb1647aSStanislav Funiak case ForEach: 20813eb1647aSStanislav Funiak executeForEach(); 20823eb1647aSStanislav Funiak break; 2083154cabe7SRiver Riddle case GetAttribute: 2084154cabe7SRiver Riddle executeGetAttribute(); 2085154cabe7SRiver Riddle break; 2086154cabe7SRiver Riddle case GetAttributeType: 2087154cabe7SRiver Riddle executeGetAttributeType(); 2088154cabe7SRiver Riddle break; 2089154cabe7SRiver Riddle case GetDefiningOp: 2090154cabe7SRiver Riddle executeGetDefiningOp(); 2091154cabe7SRiver Riddle break; 2092154cabe7SRiver Riddle case GetOperand0: 2093154cabe7SRiver Riddle case GetOperand1: 2094154cabe7SRiver Riddle case GetOperand2: 2095154cabe7SRiver Riddle case GetOperand3: { 2096154cabe7SRiver Riddle unsigned index = opCode - GetOperand0; 2097154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n"); 20981fff7c89SFrederik Gossen executeGetOperand(index); 2099abfd1a8bSRiver Riddle break; 2100abfd1a8bSRiver Riddle } 2101154cabe7SRiver Riddle case GetOperandN: 2102154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n"); 2103154cabe7SRiver Riddle executeGetOperand(read<uint32_t>()); 2104154cabe7SRiver Riddle break; 210585ab413bSRiver Riddle case GetOperands: 210685ab413bSRiver Riddle executeGetOperands(); 210785ab413bSRiver Riddle break; 2108154cabe7SRiver Riddle case GetResult0: 2109154cabe7SRiver Riddle case GetResult1: 2110154cabe7SRiver Riddle case GetResult2: 2111154cabe7SRiver Riddle case GetResult3: { 2112154cabe7SRiver Riddle unsigned index = opCode - GetResult0; 2113154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n"); 21141fff7c89SFrederik Gossen executeGetResult(index); 2115154cabe7SRiver Riddle break; 2116abfd1a8bSRiver Riddle } 2117154cabe7SRiver Riddle case GetResultN: 2118154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n"); 2119154cabe7SRiver Riddle executeGetResult(read<uint32_t>()); 2120154cabe7SRiver Riddle break; 212185ab413bSRiver Riddle case GetResults: 212285ab413bSRiver Riddle executeGetResults(); 212385ab413bSRiver Riddle break; 21243eb1647aSStanislav Funiak case GetUsers: 21253eb1647aSStanislav Funiak executeGetUsers(); 21263eb1647aSStanislav Funiak break; 2127154cabe7SRiver Riddle case GetValueType: 2128154cabe7SRiver Riddle executeGetValueType(); 2129154cabe7SRiver Riddle break; 213085ab413bSRiver Riddle case GetValueRangeTypes: 213185ab413bSRiver Riddle executeGetValueRangeTypes(); 213285ab413bSRiver Riddle break; 2133154cabe7SRiver Riddle case IsNotNull: 2134154cabe7SRiver Riddle executeIsNotNull(); 2135154cabe7SRiver Riddle break; 2136154cabe7SRiver Riddle case RecordMatch: 2137154cabe7SRiver Riddle assert(matches && 2138154cabe7SRiver Riddle "expected matches to be provided when executing the matcher"); 2139154cabe7SRiver Riddle executeRecordMatch(rewriter, *matches); 2140154cabe7SRiver Riddle break; 2141154cabe7SRiver Riddle case ReplaceOp: 2142154cabe7SRiver Riddle executeReplaceOp(rewriter); 2143154cabe7SRiver Riddle break; 2144154cabe7SRiver Riddle case SwitchAttribute: 2145154cabe7SRiver Riddle executeSwitchAttribute(); 2146154cabe7SRiver Riddle break; 2147154cabe7SRiver Riddle case SwitchOperandCount: 2148154cabe7SRiver Riddle executeSwitchOperandCount(); 2149154cabe7SRiver Riddle break; 2150154cabe7SRiver Riddle case SwitchOperationName: 2151154cabe7SRiver Riddle executeSwitchOperationName(); 2152154cabe7SRiver Riddle break; 2153154cabe7SRiver Riddle case SwitchResultCount: 2154154cabe7SRiver Riddle executeSwitchResultCount(); 2155154cabe7SRiver Riddle break; 2156154cabe7SRiver Riddle case SwitchType: 2157154cabe7SRiver Riddle executeSwitchType(); 2158154cabe7SRiver Riddle break; 215985ab413bSRiver Riddle case SwitchTypes: 216085ab413bSRiver Riddle executeSwitchTypes(); 216185ab413bSRiver Riddle break; 2162154cabe7SRiver Riddle } 2163154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "\n"); 2164abfd1a8bSRiver Riddle } 2165abfd1a8bSRiver Riddle } 2166abfd1a8bSRiver Riddle 2167abfd1a8bSRiver Riddle /// Run the pattern matcher on the given root operation, collecting the matched 2168abfd1a8bSRiver Riddle /// patterns in `matches`. 2169abfd1a8bSRiver Riddle void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, 2170abfd1a8bSRiver Riddle SmallVectorImpl<MatchResult> &matches, 2171abfd1a8bSRiver Riddle PDLByteCodeMutableState &state) const { 2172abfd1a8bSRiver Riddle // The first memory slot is always the root operation. 2173abfd1a8bSRiver Riddle state.memory[0] = op; 2174abfd1a8bSRiver Riddle 2175abfd1a8bSRiver Riddle // The matcher function always starts at code address 0. 217685ab413bSRiver Riddle ByteCodeExecutor executor( 21773eb1647aSStanislav Funiak matcherByteCode.data(), state.memory, state.opRangeMemory, 21783eb1647aSStanislav Funiak state.typeRangeMemory, state.allocatedTypeRangeMemory, 21793eb1647aSStanislav Funiak state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex, 21803eb1647aSStanislav Funiak uniquedData, matcherByteCode, state.currentPatternBenefits, patterns, 21813eb1647aSStanislav Funiak constraintFunctions, rewriteFunctions); 2182abfd1a8bSRiver Riddle executor.execute(rewriter, &matches); 2183abfd1a8bSRiver Riddle 2184abfd1a8bSRiver Riddle // Order the found matches by benefit. 2185abfd1a8bSRiver Riddle std::stable_sort(matches.begin(), matches.end(), 2186abfd1a8bSRiver Riddle [](const MatchResult &lhs, const MatchResult &rhs) { 2187abfd1a8bSRiver Riddle return lhs.benefit > rhs.benefit; 2188abfd1a8bSRiver Riddle }); 2189abfd1a8bSRiver Riddle } 2190abfd1a8bSRiver Riddle 2191abfd1a8bSRiver Riddle /// Run the rewriter of the given pattern on the root operation `op`. 2192abfd1a8bSRiver Riddle void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, 2193abfd1a8bSRiver Riddle PDLByteCodeMutableState &state) const { 2194abfd1a8bSRiver Riddle // The arguments of the rewrite function are stored at the start of the 2195abfd1a8bSRiver Riddle // memory buffer. 2196abfd1a8bSRiver Riddle llvm::copy(match.values, state.memory.begin()); 2197abfd1a8bSRiver Riddle 219885ab413bSRiver Riddle ByteCodeExecutor executor( 219985ab413bSRiver Riddle &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, 22003eb1647aSStanislav Funiak state.opRangeMemory, state.typeRangeMemory, 22013eb1647aSStanislav Funiak state.allocatedTypeRangeMemory, state.valueRangeMemory, 22023eb1647aSStanislav Funiak state.allocatedValueRangeMemory, state.loopIndex, uniquedData, 220385ab413bSRiver Riddle rewriterByteCode, state.currentPatternBenefits, patterns, 220402c4c0d5SRiver Riddle constraintFunctions, rewriteFunctions); 2205abfd1a8bSRiver Riddle executor.execute(rewriter, /*matches=*/nullptr, match.location); 2206abfd1a8bSRiver Riddle } 2207