1abfd1a8bSRiver Riddle //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===// 2abfd1a8bSRiver Riddle // 3abfd1a8bSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4abfd1a8bSRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 5abfd1a8bSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6abfd1a8bSRiver Riddle // 7abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 8abfd1a8bSRiver Riddle // 9abfd1a8bSRiver Riddle // This file implements MLIR to byte-code generation and the interpreter. 10abfd1a8bSRiver Riddle // 11abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 12abfd1a8bSRiver Riddle 13abfd1a8bSRiver Riddle #include "ByteCode.h" 14abfd1a8bSRiver Riddle #include "mlir/Analysis/Liveness.h" 15abfd1a8bSRiver Riddle #include "mlir/Dialect/PDL/IR/PDLTypes.h" 16abfd1a8bSRiver Riddle #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 17e66c2e25SRiver Riddle #include "mlir/IR/BuiltinOps.h" 18abfd1a8bSRiver Riddle #include "mlir/IR/RegionGraphTraits.h" 19abfd1a8bSRiver Riddle #include "llvm/ADT/IntervalMap.h" 20abfd1a8bSRiver Riddle #include "llvm/ADT/PostOrderIterator.h" 21abfd1a8bSRiver Riddle #include "llvm/ADT/TypeSwitch.h" 22abfd1a8bSRiver Riddle #include "llvm/Support/Debug.h" 2385ab413bSRiver Riddle #include "llvm/Support/Format.h" 2485ab413bSRiver Riddle #include "llvm/Support/FormatVariadic.h" 2585ab413bSRiver Riddle #include <numeric> 26abfd1a8bSRiver Riddle 27abfd1a8bSRiver Riddle #define DEBUG_TYPE "pdl-bytecode" 28abfd1a8bSRiver Riddle 29abfd1a8bSRiver Riddle using namespace mlir; 30abfd1a8bSRiver Riddle using namespace mlir::detail; 31abfd1a8bSRiver Riddle 32abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 33abfd1a8bSRiver Riddle // PDLByteCodePattern 34abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 35abfd1a8bSRiver Riddle 36abfd1a8bSRiver Riddle PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp, 37abfd1a8bSRiver Riddle ByteCodeAddr rewriterAddr) { 38abfd1a8bSRiver Riddle SmallVector<StringRef, 8> generatedOps; 39abfd1a8bSRiver Riddle if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr()) 40abfd1a8bSRiver Riddle generatedOps = 41abfd1a8bSRiver Riddle llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>()); 42abfd1a8bSRiver Riddle 43abfd1a8bSRiver Riddle PatternBenefit benefit = matchOp.benefit(); 44abfd1a8bSRiver Riddle MLIRContext *ctx = matchOp.getContext(); 45abfd1a8bSRiver Riddle 46abfd1a8bSRiver Riddle // Check to see if this is pattern matches a specific operation type. 47abfd1a8bSRiver Riddle if (Optional<StringRef> rootKind = matchOp.rootKind()) 4876f3c2f3SRiver Riddle return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx, 4976f3c2f3SRiver Riddle generatedOps); 5076f3c2f3SRiver Riddle return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx, 5176f3c2f3SRiver Riddle generatedOps); 52abfd1a8bSRiver Riddle } 53abfd1a8bSRiver Riddle 54abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 55abfd1a8bSRiver Riddle // PDLByteCodeMutableState 56abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 57abfd1a8bSRiver Riddle 58abfd1a8bSRiver Riddle /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds 59abfd1a8bSRiver Riddle /// to the position of the pattern within the range returned by 60abfd1a8bSRiver Riddle /// `PDLByteCode::getPatterns`. 61abfd1a8bSRiver Riddle void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex, 62abfd1a8bSRiver Riddle PatternBenefit benefit) { 63abfd1a8bSRiver Riddle currentPatternBenefits[patternIndex] = benefit; 64abfd1a8bSRiver Riddle } 65abfd1a8bSRiver Riddle 6685ab413bSRiver Riddle /// Cleanup any allocated state after a full match/rewrite has been completed. 6785ab413bSRiver Riddle /// This method should be called irregardless of whether the match+rewrite was a 6885ab413bSRiver Riddle /// success or not. 6985ab413bSRiver Riddle void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() { 7085ab413bSRiver Riddle allocatedTypeRangeMemory.clear(); 7185ab413bSRiver Riddle allocatedValueRangeMemory.clear(); 7285ab413bSRiver Riddle } 7385ab413bSRiver Riddle 74abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 75abfd1a8bSRiver Riddle // Bytecode OpCodes 76abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 77abfd1a8bSRiver Riddle 78abfd1a8bSRiver Riddle namespace { 79abfd1a8bSRiver Riddle enum OpCode : ByteCodeField { 80abfd1a8bSRiver Riddle /// Apply an externally registered constraint. 81abfd1a8bSRiver Riddle ApplyConstraint, 82abfd1a8bSRiver Riddle /// Apply an externally registered rewrite. 83abfd1a8bSRiver Riddle ApplyRewrite, 84abfd1a8bSRiver Riddle /// Check if two generic values are equal. 85abfd1a8bSRiver Riddle AreEqual, 8685ab413bSRiver Riddle /// Check if two ranges are equal. 8785ab413bSRiver Riddle AreRangesEqual, 88abfd1a8bSRiver Riddle /// Unconditional branch. 89abfd1a8bSRiver Riddle Branch, 90abfd1a8bSRiver Riddle /// Compare the operand count of an operation with a constant. 91abfd1a8bSRiver Riddle CheckOperandCount, 92abfd1a8bSRiver Riddle /// Compare the name of an operation with a constant. 93abfd1a8bSRiver Riddle CheckOperationName, 94abfd1a8bSRiver Riddle /// Compare the result count of an operation with a constant. 95abfd1a8bSRiver Riddle CheckResultCount, 9685ab413bSRiver Riddle /// Compare a range of types to a constant range of types. 9785ab413bSRiver Riddle CheckTypes, 98abfd1a8bSRiver Riddle /// Create an operation. 99abfd1a8bSRiver Riddle CreateOperation, 10085ab413bSRiver Riddle /// Create a range of types. 10185ab413bSRiver Riddle CreateTypes, 102abfd1a8bSRiver Riddle /// Erase an operation. 103abfd1a8bSRiver Riddle EraseOp, 104abfd1a8bSRiver Riddle /// Terminate a matcher or rewrite sequence. 105abfd1a8bSRiver Riddle Finalize, 106abfd1a8bSRiver Riddle /// Get a specific attribute of an operation. 107abfd1a8bSRiver Riddle GetAttribute, 108abfd1a8bSRiver Riddle /// Get the type of an attribute. 109abfd1a8bSRiver Riddle GetAttributeType, 110abfd1a8bSRiver Riddle /// Get the defining operation of a value. 111abfd1a8bSRiver Riddle GetDefiningOp, 112abfd1a8bSRiver Riddle /// Get a specific operand of an operation. 113abfd1a8bSRiver Riddle GetOperand0, 114abfd1a8bSRiver Riddle GetOperand1, 115abfd1a8bSRiver Riddle GetOperand2, 116abfd1a8bSRiver Riddle GetOperand3, 117abfd1a8bSRiver Riddle GetOperandN, 11885ab413bSRiver Riddle /// Get a specific operand group of an operation. 11985ab413bSRiver Riddle GetOperands, 120abfd1a8bSRiver Riddle /// Get a specific result of an operation. 121abfd1a8bSRiver Riddle GetResult0, 122abfd1a8bSRiver Riddle GetResult1, 123abfd1a8bSRiver Riddle GetResult2, 124abfd1a8bSRiver Riddle GetResult3, 125abfd1a8bSRiver Riddle GetResultN, 12685ab413bSRiver Riddle /// Get a specific result group of an operation. 12785ab413bSRiver Riddle GetResults, 128abfd1a8bSRiver Riddle /// Get the type of a value. 129abfd1a8bSRiver Riddle GetValueType, 13085ab413bSRiver Riddle /// Get the types of a value range. 13185ab413bSRiver Riddle GetValueRangeTypes, 132abfd1a8bSRiver Riddle /// Check if a generic value is not null. 133abfd1a8bSRiver Riddle IsNotNull, 134abfd1a8bSRiver Riddle /// Record a successful pattern match. 135abfd1a8bSRiver Riddle RecordMatch, 136abfd1a8bSRiver Riddle /// Replace an operation. 137abfd1a8bSRiver Riddle ReplaceOp, 138abfd1a8bSRiver Riddle /// Compare an attribute with a set of constants. 139abfd1a8bSRiver Riddle SwitchAttribute, 140abfd1a8bSRiver Riddle /// Compare the operand count of an operation with a set of constants. 141abfd1a8bSRiver Riddle SwitchOperandCount, 142abfd1a8bSRiver Riddle /// Compare the name of an operation with a set of constants. 143abfd1a8bSRiver Riddle SwitchOperationName, 144abfd1a8bSRiver Riddle /// Compare the result count of an operation with a set of constants. 145abfd1a8bSRiver Riddle SwitchResultCount, 146abfd1a8bSRiver Riddle /// Compare a type with a set of constants. 147abfd1a8bSRiver Riddle SwitchType, 14885ab413bSRiver Riddle /// Compare a range of types with a set of constants. 14985ab413bSRiver Riddle SwitchTypes, 150abfd1a8bSRiver Riddle }; 151abfd1a8bSRiver Riddle } // end anonymous namespace 152abfd1a8bSRiver Riddle 153abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 154abfd1a8bSRiver Riddle // ByteCode Generation 155abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 156abfd1a8bSRiver Riddle 157abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 158abfd1a8bSRiver Riddle // Generator 159abfd1a8bSRiver Riddle 160abfd1a8bSRiver Riddle namespace { 161abfd1a8bSRiver Riddle struct ByteCodeWriter; 162abfd1a8bSRiver Riddle 163abfd1a8bSRiver Riddle /// This class represents the main generator for the pattern bytecode. 164abfd1a8bSRiver Riddle class Generator { 165abfd1a8bSRiver Riddle public: 166abfd1a8bSRiver Riddle Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData, 167abfd1a8bSRiver Riddle SmallVectorImpl<ByteCodeField> &matcherByteCode, 168abfd1a8bSRiver Riddle SmallVectorImpl<ByteCodeField> &rewriterByteCode, 169abfd1a8bSRiver Riddle SmallVectorImpl<PDLByteCodePattern> &patterns, 170abfd1a8bSRiver Riddle ByteCodeField &maxValueMemoryIndex, 17185ab413bSRiver Riddle ByteCodeField &maxTypeRangeMemoryIndex, 17285ab413bSRiver Riddle ByteCodeField &maxValueRangeMemoryIndex, 173abfd1a8bSRiver Riddle llvm::StringMap<PDLConstraintFunction> &constraintFns, 174abfd1a8bSRiver Riddle llvm::StringMap<PDLRewriteFunction> &rewriteFns) 175abfd1a8bSRiver Riddle : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), 176abfd1a8bSRiver Riddle rewriterByteCode(rewriterByteCode), patterns(patterns), 17785ab413bSRiver Riddle maxValueMemoryIndex(maxValueMemoryIndex), 17885ab413bSRiver Riddle maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex), 17985ab413bSRiver Riddle maxValueRangeMemoryIndex(maxValueRangeMemoryIndex) { 180abfd1a8bSRiver Riddle for (auto it : llvm::enumerate(constraintFns)) 181abfd1a8bSRiver Riddle constraintToMemIndex.try_emplace(it.value().first(), it.index()); 182abfd1a8bSRiver Riddle for (auto it : llvm::enumerate(rewriteFns)) 183abfd1a8bSRiver Riddle externalRewriterToMemIndex.try_emplace(it.value().first(), it.index()); 184abfd1a8bSRiver Riddle } 185abfd1a8bSRiver Riddle 186abfd1a8bSRiver Riddle /// Generate the bytecode for the given PDL interpreter module. 187abfd1a8bSRiver Riddle void generate(ModuleOp module); 188abfd1a8bSRiver Riddle 189abfd1a8bSRiver Riddle /// Return the memory index to use for the given value. 190abfd1a8bSRiver Riddle ByteCodeField &getMemIndex(Value value) { 191abfd1a8bSRiver Riddle assert(valueToMemIndex.count(value) && 192abfd1a8bSRiver Riddle "expected memory index to be assigned"); 193abfd1a8bSRiver Riddle return valueToMemIndex[value]; 194abfd1a8bSRiver Riddle } 195abfd1a8bSRiver Riddle 19685ab413bSRiver Riddle /// Return the range memory index used to store the given range value. 19785ab413bSRiver Riddle ByteCodeField &getRangeStorageIndex(Value value) { 19885ab413bSRiver Riddle assert(valueToRangeIndex.count(value) && 19985ab413bSRiver Riddle "expected range index to be assigned"); 20085ab413bSRiver Riddle return valueToRangeIndex[value]; 20185ab413bSRiver Riddle } 20285ab413bSRiver Riddle 203abfd1a8bSRiver Riddle /// Return an index to use when referring to the given data that is uniqued in 204abfd1a8bSRiver Riddle /// the MLIR context. 205abfd1a8bSRiver Riddle template <typename T> 206abfd1a8bSRiver Riddle std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &> 207abfd1a8bSRiver Riddle getMemIndex(T val) { 208abfd1a8bSRiver Riddle const void *opaqueVal = val.getAsOpaquePointer(); 209abfd1a8bSRiver Riddle 210abfd1a8bSRiver Riddle // Get or insert a reference to this value. 211abfd1a8bSRiver Riddle auto it = uniquedDataToMemIndex.try_emplace( 212abfd1a8bSRiver Riddle opaqueVal, maxValueMemoryIndex + uniquedData.size()); 213abfd1a8bSRiver Riddle if (it.second) 214abfd1a8bSRiver Riddle uniquedData.push_back(opaqueVal); 215abfd1a8bSRiver Riddle return it.first->second; 216abfd1a8bSRiver Riddle } 217abfd1a8bSRiver Riddle 218abfd1a8bSRiver Riddle private: 219abfd1a8bSRiver Riddle /// Allocate memory indices for the results of operations within the matcher 220abfd1a8bSRiver Riddle /// and rewriters. 221abfd1a8bSRiver Riddle void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule); 222abfd1a8bSRiver Riddle 223abfd1a8bSRiver Riddle /// Generate the bytecode for the given operation. 224abfd1a8bSRiver Riddle void generate(Operation *op, ByteCodeWriter &writer); 225abfd1a8bSRiver Riddle void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer); 226abfd1a8bSRiver Riddle void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer); 227abfd1a8bSRiver Riddle void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer); 228abfd1a8bSRiver Riddle void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer); 229abfd1a8bSRiver Riddle void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer); 230abfd1a8bSRiver Riddle void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer); 231abfd1a8bSRiver Riddle void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer); 232abfd1a8bSRiver Riddle void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); 233abfd1a8bSRiver Riddle void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); 23485ab413bSRiver Riddle void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer); 235abfd1a8bSRiver Riddle void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); 236abfd1a8bSRiver Riddle void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); 237abfd1a8bSRiver Riddle void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); 23885ab413bSRiver Riddle void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer); 239abfd1a8bSRiver Riddle void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); 240abfd1a8bSRiver Riddle void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); 241abfd1a8bSRiver Riddle void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); 242abfd1a8bSRiver Riddle void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); 243abfd1a8bSRiver Riddle void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); 244abfd1a8bSRiver Riddle void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer); 24585ab413bSRiver Riddle void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer); 246abfd1a8bSRiver Riddle void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); 24785ab413bSRiver Riddle void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer); 248abfd1a8bSRiver Riddle void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); 2493a833a0eSRiver Riddle void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer); 250abfd1a8bSRiver Riddle void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); 251abfd1a8bSRiver Riddle void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer); 252abfd1a8bSRiver Riddle void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); 253abfd1a8bSRiver Riddle void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer); 254abfd1a8bSRiver Riddle void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer); 25585ab413bSRiver Riddle void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer); 256abfd1a8bSRiver Riddle void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer); 257abfd1a8bSRiver Riddle void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer); 258abfd1a8bSRiver Riddle void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer); 259abfd1a8bSRiver Riddle 260abfd1a8bSRiver Riddle /// Mapping from value to its corresponding memory index. 261abfd1a8bSRiver Riddle DenseMap<Value, ByteCodeField> valueToMemIndex; 262abfd1a8bSRiver Riddle 26385ab413bSRiver Riddle /// Mapping from a range value to its corresponding range storage index. 26485ab413bSRiver Riddle DenseMap<Value, ByteCodeField> valueToRangeIndex; 26585ab413bSRiver Riddle 266abfd1a8bSRiver Riddle /// Mapping from the name of an externally registered rewrite to its index in 267abfd1a8bSRiver Riddle /// the bytecode registry. 268abfd1a8bSRiver Riddle llvm::StringMap<ByteCodeField> externalRewriterToMemIndex; 269abfd1a8bSRiver Riddle 270abfd1a8bSRiver Riddle /// Mapping from the name of an externally registered constraint to its index 271abfd1a8bSRiver Riddle /// in the bytecode registry. 272abfd1a8bSRiver Riddle llvm::StringMap<ByteCodeField> constraintToMemIndex; 273abfd1a8bSRiver Riddle 274abfd1a8bSRiver Riddle /// Mapping from rewriter function name to the bytecode address of the 275abfd1a8bSRiver Riddle /// rewriter function in byte. 276abfd1a8bSRiver Riddle llvm::StringMap<ByteCodeAddr> rewriterToAddr; 277abfd1a8bSRiver Riddle 278abfd1a8bSRiver Riddle /// Mapping from a uniqued storage object to its memory index within 279abfd1a8bSRiver Riddle /// `uniquedData`. 280abfd1a8bSRiver Riddle DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex; 281abfd1a8bSRiver Riddle 282abfd1a8bSRiver Riddle /// The current MLIR context. 283abfd1a8bSRiver Riddle MLIRContext *ctx; 284abfd1a8bSRiver Riddle 285abfd1a8bSRiver Riddle /// Data of the ByteCode class to be populated. 286abfd1a8bSRiver Riddle std::vector<const void *> &uniquedData; 287abfd1a8bSRiver Riddle SmallVectorImpl<ByteCodeField> &matcherByteCode; 288abfd1a8bSRiver Riddle SmallVectorImpl<ByteCodeField> &rewriterByteCode; 289abfd1a8bSRiver Riddle SmallVectorImpl<PDLByteCodePattern> &patterns; 290abfd1a8bSRiver Riddle ByteCodeField &maxValueMemoryIndex; 29185ab413bSRiver Riddle ByteCodeField &maxTypeRangeMemoryIndex; 29285ab413bSRiver Riddle ByteCodeField &maxValueRangeMemoryIndex; 293abfd1a8bSRiver Riddle }; 294abfd1a8bSRiver Riddle 295abfd1a8bSRiver Riddle /// This class provides utilities for writing a bytecode stream. 296abfd1a8bSRiver Riddle struct ByteCodeWriter { 297abfd1a8bSRiver Riddle ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator) 298abfd1a8bSRiver Riddle : bytecode(bytecode), generator(generator) {} 299abfd1a8bSRiver Riddle 300abfd1a8bSRiver Riddle /// Append a field to the bytecode. 301abfd1a8bSRiver Riddle void append(ByteCodeField field) { bytecode.push_back(field); } 302fa20ab7bSRiver Riddle void append(OpCode opCode) { bytecode.push_back(opCode); } 303abfd1a8bSRiver Riddle 304abfd1a8bSRiver Riddle /// Append an address to the bytecode. 305abfd1a8bSRiver Riddle void append(ByteCodeAddr field) { 306abfd1a8bSRiver Riddle static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 307abfd1a8bSRiver Riddle "unexpected ByteCode address size"); 308abfd1a8bSRiver Riddle 309abfd1a8bSRiver Riddle ByteCodeField fieldParts[2]; 310abfd1a8bSRiver Riddle std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr)); 311abfd1a8bSRiver Riddle bytecode.append({fieldParts[0], fieldParts[1]}); 312abfd1a8bSRiver Riddle } 313abfd1a8bSRiver Riddle 314abfd1a8bSRiver Riddle /// Append a successor range to the bytecode, the exact address will need to 315abfd1a8bSRiver Riddle /// be resolved later. 316abfd1a8bSRiver Riddle void append(SuccessorRange successors) { 317abfd1a8bSRiver Riddle // Add back references to the any successors so that the address can be 318abfd1a8bSRiver Riddle // resolved later. 319abfd1a8bSRiver Riddle for (Block *successor : successors) { 320abfd1a8bSRiver Riddle unresolvedSuccessorRefs[successor].push_back(bytecode.size()); 321abfd1a8bSRiver Riddle append(ByteCodeAddr(0)); 322abfd1a8bSRiver Riddle } 323abfd1a8bSRiver Riddle } 324abfd1a8bSRiver Riddle 325abfd1a8bSRiver Riddle /// Append a range of values that will be read as generic PDLValues. 326abfd1a8bSRiver Riddle void appendPDLValueList(OperandRange values) { 327abfd1a8bSRiver Riddle bytecode.push_back(values.size()); 32885ab413bSRiver Riddle for (Value value : values) 32985ab413bSRiver Riddle appendPDLValue(value); 33085ab413bSRiver Riddle } 33185ab413bSRiver Riddle 33285ab413bSRiver Riddle /// Append a value as a PDLValue. 33385ab413bSRiver Riddle void appendPDLValue(Value value) { 33485ab413bSRiver Riddle appendPDLValueKind(value); 335abfd1a8bSRiver Riddle append(value); 336abfd1a8bSRiver Riddle } 33785ab413bSRiver Riddle 33885ab413bSRiver Riddle /// Append the PDLValue::Kind of the given value. 33985ab413bSRiver Riddle void appendPDLValueKind(Value value) { 34085ab413bSRiver Riddle // Append the type of the value in addition to the value itself. 34185ab413bSRiver Riddle PDLValue::Kind kind = 34285ab413bSRiver Riddle TypeSwitch<Type, PDLValue::Kind>(value.getType()) 34385ab413bSRiver Riddle .Case<pdl::AttributeType>( 34485ab413bSRiver Riddle [](Type) { return PDLValue::Kind::Attribute; }) 34585ab413bSRiver Riddle .Case<pdl::OperationType>( 34685ab413bSRiver Riddle [](Type) { return PDLValue::Kind::Operation; }) 34785ab413bSRiver Riddle .Case<pdl::RangeType>([](pdl::RangeType rangeTy) { 34885ab413bSRiver Riddle if (rangeTy.getElementType().isa<pdl::TypeType>()) 34985ab413bSRiver Riddle return PDLValue::Kind::TypeRange; 35085ab413bSRiver Riddle return PDLValue::Kind::ValueRange; 35185ab413bSRiver Riddle }) 35285ab413bSRiver Riddle .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; }) 35385ab413bSRiver Riddle .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; }); 35485ab413bSRiver Riddle bytecode.push_back(static_cast<ByteCodeField>(kind)); 355abfd1a8bSRiver Riddle } 356abfd1a8bSRiver Riddle 357abfd1a8bSRiver Riddle /// Check if the given class `T` has an iterator type. 358abfd1a8bSRiver Riddle template <typename T, typename... Args> 359abfd1a8bSRiver Riddle using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer()); 360abfd1a8bSRiver Riddle 361abfd1a8bSRiver Riddle /// Append a value that will be stored in a memory slot and not inline within 362abfd1a8bSRiver Riddle /// the bytecode. 363abfd1a8bSRiver Riddle template <typename T> 364abfd1a8bSRiver Riddle std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value || 365abfd1a8bSRiver Riddle std::is_pointer<T>::value> 366abfd1a8bSRiver Riddle append(T value) { 367abfd1a8bSRiver Riddle bytecode.push_back(generator.getMemIndex(value)); 368abfd1a8bSRiver Riddle } 369abfd1a8bSRiver Riddle 370abfd1a8bSRiver Riddle /// Append a range of values. 371abfd1a8bSRiver Riddle template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>> 372abfd1a8bSRiver Riddle std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value> 373abfd1a8bSRiver Riddle append(T range) { 374abfd1a8bSRiver Riddle bytecode.push_back(llvm::size(range)); 375abfd1a8bSRiver Riddle for (auto it : range) 376abfd1a8bSRiver Riddle append(it); 377abfd1a8bSRiver Riddle } 378abfd1a8bSRiver Riddle 379abfd1a8bSRiver Riddle /// Append a variadic number of fields to the bytecode. 380abfd1a8bSRiver Riddle template <typename FieldTy, typename Field2Ty, typename... FieldTys> 381abfd1a8bSRiver Riddle void append(FieldTy field, Field2Ty field2, FieldTys... fields) { 382abfd1a8bSRiver Riddle append(field); 383abfd1a8bSRiver Riddle append(field2, fields...); 384abfd1a8bSRiver Riddle } 385abfd1a8bSRiver Riddle 386abfd1a8bSRiver Riddle /// Successor references in the bytecode that have yet to be resolved. 387abfd1a8bSRiver Riddle DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs; 388abfd1a8bSRiver Riddle 389abfd1a8bSRiver Riddle /// The underlying bytecode buffer. 390abfd1a8bSRiver Riddle SmallVectorImpl<ByteCodeField> &bytecode; 391abfd1a8bSRiver Riddle 392abfd1a8bSRiver Riddle /// The main generator producing PDL. 393abfd1a8bSRiver Riddle Generator &generator; 394abfd1a8bSRiver Riddle }; 39585ab413bSRiver Riddle 39685ab413bSRiver Riddle /// This class represents a live range of PDL Interpreter values, containing 39785ab413bSRiver Riddle /// information about when values are live within a match/rewrite. 39885ab413bSRiver Riddle struct ByteCodeLiveRange { 39985ab413bSRiver Riddle using Set = llvm::IntervalMap<ByteCodeField, char, 16>; 40085ab413bSRiver Riddle using Allocator = Set::Allocator; 40185ab413bSRiver Riddle 40285ab413bSRiver Riddle ByteCodeLiveRange(Allocator &alloc) : liveness(alloc) {} 40385ab413bSRiver Riddle 40485ab413bSRiver Riddle /// Union this live range with the one provided. 40585ab413bSRiver Riddle void unionWith(const ByteCodeLiveRange &rhs) { 40685ab413bSRiver Riddle for (auto it = rhs.liveness.begin(), e = rhs.liveness.end(); it != e; ++it) 40785ab413bSRiver Riddle liveness.insert(it.start(), it.stop(), /*dummyValue*/ 0); 40885ab413bSRiver Riddle } 40985ab413bSRiver Riddle 41085ab413bSRiver Riddle /// Returns true if this range overlaps with the one provided. 41185ab413bSRiver Riddle bool overlaps(const ByteCodeLiveRange &rhs) const { 41285ab413bSRiver Riddle return llvm::IntervalMapOverlaps<Set, Set>(liveness, rhs.liveness).valid(); 41385ab413bSRiver Riddle } 41485ab413bSRiver Riddle 41585ab413bSRiver Riddle /// A map representing the ranges of the match/rewrite that a value is live in 41685ab413bSRiver Riddle /// the interpreter. 41785ab413bSRiver Riddle llvm::IntervalMap<ByteCodeField, char, 16> liveness; 41885ab413bSRiver Riddle 41985ab413bSRiver Riddle /// The type range storage index for this range. 42085ab413bSRiver Riddle Optional<unsigned> typeRangeIndex; 42185ab413bSRiver Riddle 42285ab413bSRiver Riddle /// The value range storage index for this range. 42385ab413bSRiver Riddle Optional<unsigned> valueRangeIndex; 42485ab413bSRiver Riddle }; 425abfd1a8bSRiver Riddle } // end anonymous namespace 426abfd1a8bSRiver Riddle 427abfd1a8bSRiver Riddle void Generator::generate(ModuleOp module) { 428abfd1a8bSRiver Riddle FuncOp matcherFunc = module.lookupSymbol<FuncOp>( 429abfd1a8bSRiver Riddle pdl_interp::PDLInterpDialect::getMatcherFunctionName()); 430abfd1a8bSRiver Riddle ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>( 431abfd1a8bSRiver Riddle pdl_interp::PDLInterpDialect::getRewriterModuleName()); 432abfd1a8bSRiver Riddle assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module"); 433abfd1a8bSRiver Riddle 434abfd1a8bSRiver Riddle // Allocate memory indices for the results of operations within the matcher 435abfd1a8bSRiver Riddle // and rewriters. 436abfd1a8bSRiver Riddle allocateMemoryIndices(matcherFunc, rewriterModule); 437abfd1a8bSRiver Riddle 438abfd1a8bSRiver Riddle // Generate code for the rewriter functions. 439abfd1a8bSRiver Riddle ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this); 440abfd1a8bSRiver Riddle for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) { 441abfd1a8bSRiver Riddle rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size()); 442abfd1a8bSRiver Riddle for (Operation &op : rewriterFunc.getOps()) 443abfd1a8bSRiver Riddle generate(&op, rewriterByteCodeWriter); 444abfd1a8bSRiver Riddle } 445abfd1a8bSRiver Riddle assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() && 446abfd1a8bSRiver Riddle "unexpected branches in rewriter function"); 447abfd1a8bSRiver Riddle 448abfd1a8bSRiver Riddle // Generate code for the matcher function. 449abfd1a8bSRiver Riddle DenseMap<Block *, ByteCodeAddr> blockToAddr; 450abfd1a8bSRiver Riddle llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody()); 451abfd1a8bSRiver Riddle ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this); 452abfd1a8bSRiver Riddle for (Block *block : rpot) { 453abfd1a8bSRiver Riddle // Keep track of where this block begins within the matcher function. 454abfd1a8bSRiver Riddle blockToAddr.try_emplace(block, matcherByteCode.size()); 455abfd1a8bSRiver Riddle for (Operation &op : *block) 456abfd1a8bSRiver Riddle generate(&op, matcherByteCodeWriter); 457abfd1a8bSRiver Riddle } 458abfd1a8bSRiver Riddle 459abfd1a8bSRiver Riddle // Resolve successor references in the matcher. 460abfd1a8bSRiver Riddle for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) { 461abfd1a8bSRiver Riddle ByteCodeAddr addr = blockToAddr[it.first]; 462abfd1a8bSRiver Riddle for (unsigned offsetToFix : it.second) 463abfd1a8bSRiver Riddle std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr)); 464abfd1a8bSRiver Riddle } 465abfd1a8bSRiver Riddle } 466abfd1a8bSRiver Riddle 467abfd1a8bSRiver Riddle void Generator::allocateMemoryIndices(FuncOp matcherFunc, 468abfd1a8bSRiver Riddle ModuleOp rewriterModule) { 469abfd1a8bSRiver Riddle // Rewriters use simplistic allocation scheme that simply assigns an index to 470abfd1a8bSRiver Riddle // each result. 471abfd1a8bSRiver Riddle for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) { 47285ab413bSRiver Riddle ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0; 47385ab413bSRiver Riddle auto processRewriterValue = [&](Value val) { 47485ab413bSRiver Riddle valueToMemIndex.try_emplace(val, index++); 47585ab413bSRiver Riddle if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) { 47685ab413bSRiver Riddle Type elementTy = rangeType.getElementType(); 47785ab413bSRiver Riddle if (elementTy.isa<pdl::TypeType>()) 47885ab413bSRiver Riddle valueToRangeIndex.try_emplace(val, typeRangeIndex++); 47985ab413bSRiver Riddle else if (elementTy.isa<pdl::ValueType>()) 48085ab413bSRiver Riddle valueToRangeIndex.try_emplace(val, valueRangeIndex++); 48185ab413bSRiver Riddle } 48285ab413bSRiver Riddle }; 48385ab413bSRiver Riddle 484abfd1a8bSRiver Riddle for (BlockArgument arg : rewriterFunc.getArguments()) 48585ab413bSRiver Riddle processRewriterValue(arg); 486abfd1a8bSRiver Riddle rewriterFunc.getBody().walk([&](Operation *op) { 487abfd1a8bSRiver Riddle for (Value result : op->getResults()) 48885ab413bSRiver Riddle processRewriterValue(result); 489abfd1a8bSRiver Riddle }); 490abfd1a8bSRiver Riddle if (index > maxValueMemoryIndex) 491abfd1a8bSRiver Riddle maxValueMemoryIndex = index; 49285ab413bSRiver Riddle if (typeRangeIndex > maxTypeRangeMemoryIndex) 49385ab413bSRiver Riddle maxTypeRangeMemoryIndex = typeRangeIndex; 49485ab413bSRiver Riddle if (valueRangeIndex > maxValueRangeMemoryIndex) 49585ab413bSRiver Riddle maxValueRangeMemoryIndex = valueRangeIndex; 496abfd1a8bSRiver Riddle } 497abfd1a8bSRiver Riddle 498abfd1a8bSRiver Riddle // The matcher function uses a more sophisticated numbering that tries to 499abfd1a8bSRiver Riddle // minimize the number of memory indices assigned. This is done by determining 500abfd1a8bSRiver Riddle // a live range of the values within the matcher, then the allocation is just 501abfd1a8bSRiver Riddle // finding the minimal number of overlapping live ranges. This is essentially 502abfd1a8bSRiver Riddle // a simplified form of register allocation where we don't necessarily have a 503abfd1a8bSRiver Riddle // limited number of registers, but we still want to minimize the number used. 504abfd1a8bSRiver Riddle DenseMap<Operation *, ByteCodeField> opToIndex; 505abfd1a8bSRiver Riddle matcherFunc.getBody().walk([&](Operation *op) { 506abfd1a8bSRiver Riddle opToIndex.insert(std::make_pair(op, opToIndex.size())); 507abfd1a8bSRiver Riddle }); 508abfd1a8bSRiver Riddle 509abfd1a8bSRiver Riddle // Liveness info for each of the defs within the matcher. 51085ab413bSRiver Riddle ByteCodeLiveRange::Allocator allocator; 51185ab413bSRiver Riddle DenseMap<Value, ByteCodeLiveRange> valueDefRanges; 512abfd1a8bSRiver Riddle 513abfd1a8bSRiver Riddle // Assign the root operation being matched to slot 0. 514abfd1a8bSRiver Riddle BlockArgument rootOpArg = matcherFunc.getArgument(0); 515abfd1a8bSRiver Riddle valueToMemIndex[rootOpArg] = 0; 516abfd1a8bSRiver Riddle 517abfd1a8bSRiver Riddle // Walk each of the blocks, computing the def interval that the value is used. 518abfd1a8bSRiver Riddle Liveness matcherLiveness(matcherFunc); 519abfd1a8bSRiver Riddle for (Block &block : matcherFunc.getBody()) { 520abfd1a8bSRiver Riddle const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block); 521abfd1a8bSRiver Riddle assert(info && "expected liveness info for block"); 522abfd1a8bSRiver Riddle auto processValue = [&](Value value, Operation *firstUseOrDef) { 523abfd1a8bSRiver Riddle // We don't need to process the root op argument, this value is always 524abfd1a8bSRiver Riddle // assigned to the first memory slot. 525abfd1a8bSRiver Riddle if (value == rootOpArg) 526abfd1a8bSRiver Riddle return; 527abfd1a8bSRiver Riddle 528abfd1a8bSRiver Riddle // Set indices for the range of this block that the value is used. 529abfd1a8bSRiver Riddle auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; 53085ab413bSRiver Riddle defRangeIt->second.liveness.insert( 531abfd1a8bSRiver Riddle opToIndex[firstUseOrDef], 532abfd1a8bSRiver Riddle opToIndex[info->getEndOperation(value, firstUseOrDef)], 533abfd1a8bSRiver Riddle /*dummyValue*/ 0); 53485ab413bSRiver Riddle 53585ab413bSRiver Riddle // Check to see if this value is a range type. 53685ab413bSRiver Riddle if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) { 53785ab413bSRiver Riddle Type eleType = rangeTy.getElementType(); 53885ab413bSRiver Riddle if (eleType.isa<pdl::TypeType>()) 53985ab413bSRiver Riddle defRangeIt->second.typeRangeIndex = 0; 54085ab413bSRiver Riddle else if (eleType.isa<pdl::ValueType>()) 54185ab413bSRiver Riddle defRangeIt->second.valueRangeIndex = 0; 54285ab413bSRiver Riddle } 543abfd1a8bSRiver Riddle }; 544abfd1a8bSRiver Riddle 545abfd1a8bSRiver Riddle // Process the live-ins of this block. 546abfd1a8bSRiver Riddle for (Value liveIn : info->in()) 547abfd1a8bSRiver Riddle processValue(liveIn, &block.front()); 548abfd1a8bSRiver Riddle 549abfd1a8bSRiver Riddle // Process any new defs within this block. 550abfd1a8bSRiver Riddle for (Operation &op : block) 551abfd1a8bSRiver Riddle for (Value result : op.getResults()) 552abfd1a8bSRiver Riddle processValue(result, &op); 553abfd1a8bSRiver Riddle } 554abfd1a8bSRiver Riddle 555abfd1a8bSRiver Riddle // Greedily allocate memory slots using the computed def live ranges. 55685ab413bSRiver Riddle std::vector<ByteCodeLiveRange> allocatedIndices; 55785ab413bSRiver Riddle ByteCodeField numIndices = 1, numTypeRanges = 0, numValueRanges = 0; 558abfd1a8bSRiver Riddle for (auto &defIt : valueDefRanges) { 559abfd1a8bSRiver Riddle ByteCodeField &memIndex = valueToMemIndex[defIt.first]; 56085ab413bSRiver Riddle ByteCodeLiveRange &defRange = defIt.second; 561abfd1a8bSRiver Riddle 562abfd1a8bSRiver Riddle // Try to allocate to an existing index. 563abfd1a8bSRiver Riddle for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) { 56485ab413bSRiver Riddle ByteCodeLiveRange &existingRange = existingIndexIt.value(); 56585ab413bSRiver Riddle if (!defRange.overlaps(existingRange)) { 56685ab413bSRiver Riddle existingRange.unionWith(defRange); 567abfd1a8bSRiver Riddle memIndex = existingIndexIt.index() + 1; 56885ab413bSRiver Riddle 56985ab413bSRiver Riddle if (defRange.typeRangeIndex) { 57085ab413bSRiver Riddle if (!existingRange.typeRangeIndex) 57185ab413bSRiver Riddle existingRange.typeRangeIndex = numTypeRanges++; 57285ab413bSRiver Riddle valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex; 57385ab413bSRiver Riddle } else if (defRange.valueRangeIndex) { 57485ab413bSRiver Riddle if (!existingRange.valueRangeIndex) 57585ab413bSRiver Riddle existingRange.valueRangeIndex = numValueRanges++; 57685ab413bSRiver Riddle valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex; 57785ab413bSRiver Riddle } 57885ab413bSRiver Riddle break; 57985ab413bSRiver Riddle } 580abfd1a8bSRiver Riddle } 581abfd1a8bSRiver Riddle 582abfd1a8bSRiver Riddle // If no existing index could be used, add a new one. 583abfd1a8bSRiver Riddle if (memIndex == 0) { 584abfd1a8bSRiver Riddle allocatedIndices.emplace_back(allocator); 58585ab413bSRiver Riddle ByteCodeLiveRange &newRange = allocatedIndices.back(); 58685ab413bSRiver Riddle newRange.unionWith(defRange); 58785ab413bSRiver Riddle 58885ab413bSRiver Riddle // Allocate an index for type/value ranges. 58985ab413bSRiver Riddle if (defRange.typeRangeIndex) { 59085ab413bSRiver Riddle newRange.typeRangeIndex = numTypeRanges; 59185ab413bSRiver Riddle valueToRangeIndex[defIt.first] = numTypeRanges++; 59285ab413bSRiver Riddle } else if (defRange.valueRangeIndex) { 59385ab413bSRiver Riddle newRange.valueRangeIndex = numValueRanges; 59485ab413bSRiver Riddle valueToRangeIndex[defIt.first] = numValueRanges++; 59585ab413bSRiver Riddle } 59685ab413bSRiver Riddle 597abfd1a8bSRiver Riddle memIndex = allocatedIndices.size(); 59885ab413bSRiver Riddle ++numIndices; 599abfd1a8bSRiver Riddle } 600abfd1a8bSRiver Riddle } 601abfd1a8bSRiver Riddle 602abfd1a8bSRiver Riddle // Update the max number of indices. 60385ab413bSRiver Riddle if (numIndices > maxValueMemoryIndex) 60485ab413bSRiver Riddle maxValueMemoryIndex = numIndices; 60585ab413bSRiver Riddle if (numTypeRanges > maxTypeRangeMemoryIndex) 60685ab413bSRiver Riddle maxTypeRangeMemoryIndex = numTypeRanges; 60785ab413bSRiver Riddle if (numValueRanges > maxValueRangeMemoryIndex) 60885ab413bSRiver Riddle maxValueRangeMemoryIndex = numValueRanges; 609abfd1a8bSRiver Riddle } 610abfd1a8bSRiver Riddle 611abfd1a8bSRiver Riddle void Generator::generate(Operation *op, ByteCodeWriter &writer) { 612abfd1a8bSRiver Riddle TypeSwitch<Operation *>(op) 613abfd1a8bSRiver Riddle .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp, 614abfd1a8bSRiver Riddle pdl_interp::AreEqualOp, pdl_interp::BranchOp, 615abfd1a8bSRiver Riddle pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp, 616abfd1a8bSRiver Riddle pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp, 61785ab413bSRiver Riddle pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp, 61885ab413bSRiver Riddle pdl_interp::CreateAttributeOp, pdl_interp::CreateOperationOp, 61985ab413bSRiver Riddle pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp, 62002c4c0d5SRiver Riddle pdl_interp::EraseOp, pdl_interp::FinalizeOp, 62102c4c0d5SRiver Riddle pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp, 62202c4c0d5SRiver Riddle pdl_interp::GetDefiningOpOp, pdl_interp::GetOperandOp, 62385ab413bSRiver Riddle pdl_interp::GetOperandsOp, pdl_interp::GetResultOp, 62485ab413bSRiver Riddle pdl_interp::GetResultsOp, pdl_interp::GetValueTypeOp, 6253a833a0eSRiver Riddle pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp, 62602c4c0d5SRiver Riddle pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp, 62702c4c0d5SRiver Riddle pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp, 62885ab413bSRiver Riddle pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp, 62985ab413bSRiver Riddle pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>( 630abfd1a8bSRiver Riddle [&](auto interpOp) { this->generate(interpOp, writer); }) 631abfd1a8bSRiver Riddle .Default([](Operation *) { 632abfd1a8bSRiver Riddle llvm_unreachable("unknown `pdl_interp` operation"); 633abfd1a8bSRiver Riddle }); 634abfd1a8bSRiver Riddle } 635abfd1a8bSRiver Riddle 636abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::ApplyConstraintOp op, 637abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 638abfd1a8bSRiver Riddle assert(constraintToMemIndex.count(op.name()) && 639abfd1a8bSRiver Riddle "expected index for constraint function"); 640abfd1a8bSRiver Riddle writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()], 641abfd1a8bSRiver Riddle op.constParamsAttr()); 642abfd1a8bSRiver Riddle writer.appendPDLValueList(op.args()); 643abfd1a8bSRiver Riddle writer.append(op.getSuccessors()); 644abfd1a8bSRiver Riddle } 645abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::ApplyRewriteOp op, 646abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 647abfd1a8bSRiver Riddle assert(externalRewriterToMemIndex.count(op.name()) && 648abfd1a8bSRiver Riddle "expected index for rewrite function"); 649abfd1a8bSRiver Riddle writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()], 65002c4c0d5SRiver Riddle op.constParamsAttr()); 651abfd1a8bSRiver Riddle writer.appendPDLValueList(op.args()); 65202c4c0d5SRiver Riddle 65385ab413bSRiver Riddle ResultRange results = op.results(); 65485ab413bSRiver Riddle writer.append(ByteCodeField(results.size())); 65585ab413bSRiver Riddle for (Value result : results) { 65685ab413bSRiver Riddle // In debug mode we also record the expected kind of the result, so that we 65785ab413bSRiver Riddle // can provide extra verification of the native rewrite function. 65802c4c0d5SRiver Riddle #ifndef NDEBUG 65985ab413bSRiver Riddle writer.appendPDLValueKind(result); 66002c4c0d5SRiver Riddle #endif 66185ab413bSRiver Riddle 66285ab413bSRiver Riddle // Range results also need to append the range storage index. 66385ab413bSRiver Riddle if (result.getType().isa<pdl::RangeType>()) 66485ab413bSRiver Riddle writer.append(getRangeStorageIndex(result)); 66502c4c0d5SRiver Riddle writer.append(result); 666abfd1a8bSRiver Riddle } 66785ab413bSRiver Riddle } 668abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) { 66985ab413bSRiver Riddle Value lhs = op.lhs(); 67085ab413bSRiver Riddle if (lhs.getType().isa<pdl::RangeType>()) { 67185ab413bSRiver Riddle writer.append(OpCode::AreRangesEqual); 67285ab413bSRiver Riddle writer.appendPDLValueKind(lhs); 67385ab413bSRiver Riddle writer.append(op.lhs(), op.rhs(), op.getSuccessors()); 67485ab413bSRiver Riddle return; 67585ab413bSRiver Riddle } 67685ab413bSRiver Riddle 67785ab413bSRiver Riddle writer.append(OpCode::AreEqual, lhs, op.rhs(), op.getSuccessors()); 678abfd1a8bSRiver Riddle } 679abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) { 6808affe881SRiver Riddle writer.append(OpCode::Branch, SuccessorRange(op.getOperation())); 681abfd1a8bSRiver Riddle } 682abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckAttributeOp op, 683abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 684abfd1a8bSRiver Riddle writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(), 685abfd1a8bSRiver Riddle op.getSuccessors()); 686abfd1a8bSRiver Riddle } 687abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckOperandCountOp op, 688abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 689abfd1a8bSRiver Riddle writer.append(OpCode::CheckOperandCount, op.operation(), op.count(), 69085ab413bSRiver Riddle static_cast<ByteCodeField>(op.compareAtLeast()), 691abfd1a8bSRiver Riddle op.getSuccessors()); 692abfd1a8bSRiver Riddle } 693abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckOperationNameOp op, 694abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 695abfd1a8bSRiver Riddle writer.append(OpCode::CheckOperationName, op.operation(), 696abfd1a8bSRiver Riddle OperationName(op.name(), ctx), op.getSuccessors()); 697abfd1a8bSRiver Riddle } 698abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckResultCountOp op, 699abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 700abfd1a8bSRiver Riddle writer.append(OpCode::CheckResultCount, op.operation(), op.count(), 70185ab413bSRiver Riddle static_cast<ByteCodeField>(op.compareAtLeast()), 702abfd1a8bSRiver Riddle op.getSuccessors()); 703abfd1a8bSRiver Riddle } 704abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) { 705abfd1a8bSRiver Riddle writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors()); 706abfd1a8bSRiver Riddle } 70785ab413bSRiver Riddle void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) { 70885ab413bSRiver Riddle writer.append(OpCode::CheckTypes, op.value(), op.types(), op.getSuccessors()); 70985ab413bSRiver Riddle } 710abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateAttributeOp op, 711abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 712abfd1a8bSRiver Riddle // Simply repoint the memory index of the result to the constant. 713abfd1a8bSRiver Riddle getMemIndex(op.attribute()) = getMemIndex(op.value()); 714abfd1a8bSRiver Riddle } 715abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateOperationOp op, 716abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 717abfd1a8bSRiver Riddle writer.append(OpCode::CreateOperation, op.operation(), 71885ab413bSRiver Riddle OperationName(op.name(), ctx)); 71985ab413bSRiver Riddle writer.appendPDLValueList(op.operands()); 720abfd1a8bSRiver Riddle 721abfd1a8bSRiver Riddle // Add the attributes. 722abfd1a8bSRiver Riddle OperandRange attributes = op.attributes(); 723abfd1a8bSRiver Riddle writer.append(static_cast<ByteCodeField>(attributes.size())); 724abfd1a8bSRiver Riddle for (auto it : llvm::zip(op.attributeNames(), op.attributes())) { 725abfd1a8bSRiver Riddle writer.append( 726abfd1a8bSRiver Riddle Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx), 727abfd1a8bSRiver Riddle std::get<1>(it)); 728abfd1a8bSRiver Riddle } 72985ab413bSRiver Riddle writer.appendPDLValueList(op.types()); 730abfd1a8bSRiver Riddle } 731abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) { 732abfd1a8bSRiver Riddle // Simply repoint the memory index of the result to the constant. 733abfd1a8bSRiver Riddle getMemIndex(op.result()) = getMemIndex(op.value()); 734abfd1a8bSRiver Riddle } 73585ab413bSRiver Riddle void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) { 73685ab413bSRiver Riddle writer.append(OpCode::CreateTypes, op.result(), 73785ab413bSRiver Riddle getRangeStorageIndex(op.result()), op.value()); 73885ab413bSRiver Riddle } 739abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) { 740abfd1a8bSRiver Riddle writer.append(OpCode::EraseOp, op.operation()); 741abfd1a8bSRiver Riddle } 742abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) { 743abfd1a8bSRiver Riddle writer.append(OpCode::Finalize); 744abfd1a8bSRiver Riddle } 745abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetAttributeOp op, 746abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 747abfd1a8bSRiver Riddle writer.append(OpCode::GetAttribute, op.attribute(), op.operation(), 748abfd1a8bSRiver Riddle Identifier::get(op.name(), ctx)); 749abfd1a8bSRiver Riddle } 750abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetAttributeTypeOp op, 751abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 752abfd1a8bSRiver Riddle writer.append(OpCode::GetAttributeType, op.result(), op.value()); 753abfd1a8bSRiver Riddle } 754abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetDefiningOpOp op, 755abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 75685ab413bSRiver Riddle writer.append(OpCode::GetDefiningOp, op.operation()); 75785ab413bSRiver Riddle writer.appendPDLValue(op.value()); 758abfd1a8bSRiver Riddle } 759abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) { 760abfd1a8bSRiver Riddle uint32_t index = op.index(); 761abfd1a8bSRiver Riddle if (index < 4) 762abfd1a8bSRiver Riddle writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index)); 763abfd1a8bSRiver Riddle else 764abfd1a8bSRiver Riddle writer.append(OpCode::GetOperandN, index); 765abfd1a8bSRiver Riddle writer.append(op.operation(), op.value()); 766abfd1a8bSRiver Riddle } 76785ab413bSRiver Riddle void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) { 76885ab413bSRiver Riddle Value result = op.value(); 76985ab413bSRiver Riddle Optional<uint32_t> index = op.index(); 77085ab413bSRiver Riddle writer.append(OpCode::GetOperands, 77185ab413bSRiver Riddle index.getValueOr(std::numeric_limits<uint32_t>::max()), 77285ab413bSRiver Riddle op.operation()); 77385ab413bSRiver Riddle if (result.getType().isa<pdl::RangeType>()) 77485ab413bSRiver Riddle writer.append(getRangeStorageIndex(result)); 77585ab413bSRiver Riddle else 77685ab413bSRiver Riddle writer.append(std::numeric_limits<ByteCodeField>::max()); 77785ab413bSRiver Riddle writer.append(result); 77885ab413bSRiver Riddle } 779abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) { 780abfd1a8bSRiver Riddle uint32_t index = op.index(); 781abfd1a8bSRiver Riddle if (index < 4) 782abfd1a8bSRiver Riddle writer.append(static_cast<OpCode>(OpCode::GetResult0 + index)); 783abfd1a8bSRiver Riddle else 784abfd1a8bSRiver Riddle writer.append(OpCode::GetResultN, index); 785abfd1a8bSRiver Riddle writer.append(op.operation(), op.value()); 786abfd1a8bSRiver Riddle } 78785ab413bSRiver Riddle void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) { 78885ab413bSRiver Riddle Value result = op.value(); 78985ab413bSRiver Riddle Optional<uint32_t> index = op.index(); 79085ab413bSRiver Riddle writer.append(OpCode::GetResults, 79185ab413bSRiver Riddle index.getValueOr(std::numeric_limits<uint32_t>::max()), 79285ab413bSRiver Riddle op.operation()); 79385ab413bSRiver Riddle if (result.getType().isa<pdl::RangeType>()) 79485ab413bSRiver Riddle writer.append(getRangeStorageIndex(result)); 79585ab413bSRiver Riddle else 79685ab413bSRiver Riddle writer.append(std::numeric_limits<ByteCodeField>::max()); 79785ab413bSRiver Riddle writer.append(result); 79885ab413bSRiver Riddle } 799abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetValueTypeOp op, 800abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 80185ab413bSRiver Riddle if (op.getType().isa<pdl::RangeType>()) { 80285ab413bSRiver Riddle Value result = op.result(); 80385ab413bSRiver Riddle writer.append(OpCode::GetValueRangeTypes, result, 80485ab413bSRiver Riddle getRangeStorageIndex(result), op.value()); 80585ab413bSRiver Riddle } else { 806abfd1a8bSRiver Riddle writer.append(OpCode::GetValueType, op.result(), op.value()); 807abfd1a8bSRiver Riddle } 80885ab413bSRiver Riddle } 80985ab413bSRiver Riddle 8103a833a0eSRiver Riddle void Generator::generate(pdl_interp::InferredTypesOp op, 811abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 8123a833a0eSRiver Riddle // InferType maps to a null type as a marker for inferring result types. 813abfd1a8bSRiver Riddle getMemIndex(op.type()) = getMemIndex(Type()); 814abfd1a8bSRiver Riddle } 815abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) { 816abfd1a8bSRiver Riddle writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors()); 817abfd1a8bSRiver Riddle } 818abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) { 819abfd1a8bSRiver Riddle ByteCodeField patternIndex = patterns.size(); 820abfd1a8bSRiver Riddle patterns.emplace_back(PDLByteCodePattern::create( 821*41d4aa7dSChris Lattner op, rewriterToAddr[op.rewriter().getLeafReference().getValue()])); 8228affe881SRiver Riddle writer.append(OpCode::RecordMatch, patternIndex, 82385ab413bSRiver Riddle SuccessorRange(op.getOperation()), op.matchedOps()); 82485ab413bSRiver Riddle writer.appendPDLValueList(op.inputs()); 825abfd1a8bSRiver Riddle } 826abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) { 82785ab413bSRiver Riddle writer.append(OpCode::ReplaceOp, op.operation()); 82885ab413bSRiver Riddle writer.appendPDLValueList(op.replValues()); 829abfd1a8bSRiver Riddle } 830abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchAttributeOp op, 831abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 832abfd1a8bSRiver Riddle writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(), 833abfd1a8bSRiver Riddle op.getSuccessors()); 834abfd1a8bSRiver Riddle } 835abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchOperandCountOp op, 836abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 837abfd1a8bSRiver Riddle writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(), 838abfd1a8bSRiver Riddle op.getSuccessors()); 839abfd1a8bSRiver Riddle } 840abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchOperationNameOp op, 841abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 842abfd1a8bSRiver Riddle auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) { 843abfd1a8bSRiver Riddle return OperationName(attr.cast<StringAttr>().getValue(), ctx); 844abfd1a8bSRiver Riddle }); 845abfd1a8bSRiver Riddle writer.append(OpCode::SwitchOperationName, op.operation(), cases, 846abfd1a8bSRiver Riddle op.getSuccessors()); 847abfd1a8bSRiver Riddle } 848abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchResultCountOp op, 849abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 850abfd1a8bSRiver Riddle writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(), 851abfd1a8bSRiver Riddle op.getSuccessors()); 852abfd1a8bSRiver Riddle } 853abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) { 854abfd1a8bSRiver Riddle writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(), 855abfd1a8bSRiver Riddle op.getSuccessors()); 856abfd1a8bSRiver Riddle } 85785ab413bSRiver Riddle void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) { 85885ab413bSRiver Riddle writer.append(OpCode::SwitchTypes, op.value(), op.caseValuesAttr(), 85985ab413bSRiver Riddle op.getSuccessors()); 86085ab413bSRiver Riddle } 861abfd1a8bSRiver Riddle 862abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 863abfd1a8bSRiver Riddle // PDLByteCode 864abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 865abfd1a8bSRiver Riddle 866abfd1a8bSRiver Riddle PDLByteCode::PDLByteCode(ModuleOp module, 867abfd1a8bSRiver Riddle llvm::StringMap<PDLConstraintFunction> constraintFns, 868abfd1a8bSRiver Riddle llvm::StringMap<PDLRewriteFunction> rewriteFns) { 869abfd1a8bSRiver Riddle Generator generator(module.getContext(), uniquedData, matcherByteCode, 870abfd1a8bSRiver Riddle rewriterByteCode, patterns, maxValueMemoryIndex, 87185ab413bSRiver Riddle maxTypeRangeCount, maxValueRangeCount, constraintFns, 87285ab413bSRiver Riddle rewriteFns); 873abfd1a8bSRiver Riddle generator.generate(module); 874abfd1a8bSRiver Riddle 875abfd1a8bSRiver Riddle // Initialize the external functions. 876abfd1a8bSRiver Riddle for (auto &it : constraintFns) 877abfd1a8bSRiver Riddle constraintFunctions.push_back(std::move(it.second)); 878abfd1a8bSRiver Riddle for (auto &it : rewriteFns) 879abfd1a8bSRiver Riddle rewriteFunctions.push_back(std::move(it.second)); 880abfd1a8bSRiver Riddle } 881abfd1a8bSRiver Riddle 882abfd1a8bSRiver Riddle /// Initialize the given state such that it can be used to execute the current 883abfd1a8bSRiver Riddle /// bytecode. 884abfd1a8bSRiver Riddle void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { 885abfd1a8bSRiver Riddle state.memory.resize(maxValueMemoryIndex, nullptr); 88685ab413bSRiver Riddle state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange()); 88785ab413bSRiver Riddle state.valueRangeMemory.resize(maxValueRangeCount, ValueRange()); 888abfd1a8bSRiver Riddle state.currentPatternBenefits.reserve(patterns.size()); 889abfd1a8bSRiver Riddle for (const PDLByteCodePattern &pattern : patterns) 890abfd1a8bSRiver Riddle state.currentPatternBenefits.push_back(pattern.getBenefit()); 891abfd1a8bSRiver Riddle } 892abfd1a8bSRiver Riddle 893abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 894abfd1a8bSRiver Riddle // ByteCode Execution 895abfd1a8bSRiver Riddle 896abfd1a8bSRiver Riddle namespace { 897abfd1a8bSRiver Riddle /// This class provides support for executing a bytecode stream. 898abfd1a8bSRiver Riddle class ByteCodeExecutor { 899abfd1a8bSRiver Riddle public: 90085ab413bSRiver Riddle ByteCodeExecutor( 90185ab413bSRiver Riddle const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory, 90285ab413bSRiver Riddle MutableArrayRef<TypeRange> typeRangeMemory, 90385ab413bSRiver Riddle std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory, 90485ab413bSRiver Riddle MutableArrayRef<ValueRange> valueRangeMemory, 90585ab413bSRiver Riddle std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory, 90685ab413bSRiver Riddle ArrayRef<const void *> uniquedMemory, ArrayRef<ByteCodeField> code, 907abfd1a8bSRiver Riddle ArrayRef<PatternBenefit> currentPatternBenefits, 908abfd1a8bSRiver Riddle ArrayRef<PDLByteCodePattern> patterns, 909abfd1a8bSRiver Riddle ArrayRef<PDLConstraintFunction> constraintFunctions, 910abfd1a8bSRiver Riddle ArrayRef<PDLRewriteFunction> rewriteFunctions) 91185ab413bSRiver Riddle : curCodeIt(curCodeIt), memory(memory), typeRangeMemory(typeRangeMemory), 91285ab413bSRiver Riddle allocatedTypeRangeMemory(allocatedTypeRangeMemory), 91385ab413bSRiver Riddle valueRangeMemory(valueRangeMemory), 91485ab413bSRiver Riddle allocatedValueRangeMemory(allocatedValueRangeMemory), 91585ab413bSRiver Riddle uniquedMemory(uniquedMemory), code(code), 91685ab413bSRiver Riddle currentPatternBenefits(currentPatternBenefits), patterns(patterns), 91785ab413bSRiver Riddle constraintFunctions(constraintFunctions), 91802c4c0d5SRiver Riddle rewriteFunctions(rewriteFunctions) {} 919abfd1a8bSRiver Riddle 920abfd1a8bSRiver Riddle /// Start executing the code at the current bytecode index. `matches` is an 921abfd1a8bSRiver Riddle /// optional field provided when this function is executed in a matching 922abfd1a8bSRiver Riddle /// context. 923abfd1a8bSRiver Riddle void execute(PatternRewriter &rewriter, 924abfd1a8bSRiver Riddle SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr, 925abfd1a8bSRiver Riddle Optional<Location> mainRewriteLoc = {}); 926abfd1a8bSRiver Riddle 927abfd1a8bSRiver Riddle private: 928154cabe7SRiver Riddle /// Internal implementation of executing each of the bytecode commands. 929154cabe7SRiver Riddle void executeApplyConstraint(PatternRewriter &rewriter); 930154cabe7SRiver Riddle void executeApplyRewrite(PatternRewriter &rewriter); 931154cabe7SRiver Riddle void executeAreEqual(); 93285ab413bSRiver Riddle void executeAreRangesEqual(); 933154cabe7SRiver Riddle void executeBranch(); 934154cabe7SRiver Riddle void executeCheckOperandCount(); 935154cabe7SRiver Riddle void executeCheckOperationName(); 936154cabe7SRiver Riddle void executeCheckResultCount(); 93785ab413bSRiver Riddle void executeCheckTypes(); 938154cabe7SRiver Riddle void executeCreateOperation(PatternRewriter &rewriter, 939154cabe7SRiver Riddle Location mainRewriteLoc); 94085ab413bSRiver Riddle void executeCreateTypes(); 941154cabe7SRiver Riddle void executeEraseOp(PatternRewriter &rewriter); 942154cabe7SRiver Riddle void executeGetAttribute(); 943154cabe7SRiver Riddle void executeGetAttributeType(); 944154cabe7SRiver Riddle void executeGetDefiningOp(); 945154cabe7SRiver Riddle void executeGetOperand(unsigned index); 94685ab413bSRiver Riddle void executeGetOperands(); 947154cabe7SRiver Riddle void executeGetResult(unsigned index); 94885ab413bSRiver Riddle void executeGetResults(); 949154cabe7SRiver Riddle void executeGetValueType(); 95085ab413bSRiver Riddle void executeGetValueRangeTypes(); 951154cabe7SRiver Riddle void executeIsNotNull(); 952154cabe7SRiver Riddle void executeRecordMatch(PatternRewriter &rewriter, 953154cabe7SRiver Riddle SmallVectorImpl<PDLByteCode::MatchResult> &matches); 954154cabe7SRiver Riddle void executeReplaceOp(PatternRewriter &rewriter); 955154cabe7SRiver Riddle void executeSwitchAttribute(); 956154cabe7SRiver Riddle void executeSwitchOperandCount(); 957154cabe7SRiver Riddle void executeSwitchOperationName(); 958154cabe7SRiver Riddle void executeSwitchResultCount(); 959154cabe7SRiver Riddle void executeSwitchType(); 96085ab413bSRiver Riddle void executeSwitchTypes(); 961154cabe7SRiver Riddle 962abfd1a8bSRiver Riddle /// Read a value from the bytecode buffer, optionally skipping a certain 963abfd1a8bSRiver Riddle /// number of prefix values. These methods always update the buffer to point 964abfd1a8bSRiver Riddle /// to the next field after the read data. 965abfd1a8bSRiver Riddle template <typename T = ByteCodeField> 966abfd1a8bSRiver Riddle T read(size_t skipN = 0) { 967abfd1a8bSRiver Riddle curCodeIt += skipN; 968abfd1a8bSRiver Riddle return readImpl<T>(); 969abfd1a8bSRiver Riddle } 970abfd1a8bSRiver Riddle ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); } 971abfd1a8bSRiver Riddle 972abfd1a8bSRiver Riddle /// Read a list of values from the bytecode buffer. 973abfd1a8bSRiver Riddle template <typename ValueT, typename T> 974abfd1a8bSRiver Riddle void readList(SmallVectorImpl<T> &list) { 975abfd1a8bSRiver Riddle list.clear(); 976abfd1a8bSRiver Riddle for (unsigned i = 0, e = read(); i != e; ++i) 977abfd1a8bSRiver Riddle list.push_back(read<ValueT>()); 978abfd1a8bSRiver Riddle } 979abfd1a8bSRiver Riddle 98085ab413bSRiver Riddle /// Read a list of values from the bytecode buffer. The values may be encoded 98185ab413bSRiver Riddle /// as either Value or ValueRange elements. 98285ab413bSRiver Riddle void readValueList(SmallVectorImpl<Value> &list) { 98385ab413bSRiver Riddle for (unsigned i = 0, e = read(); i != e; ++i) { 98485ab413bSRiver Riddle if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 98585ab413bSRiver Riddle list.push_back(read<Value>()); 98685ab413bSRiver Riddle } else { 98785ab413bSRiver Riddle ValueRange *values = read<ValueRange *>(); 98885ab413bSRiver Riddle list.append(values->begin(), values->end()); 98985ab413bSRiver Riddle } 99085ab413bSRiver Riddle } 99185ab413bSRiver Riddle } 99285ab413bSRiver Riddle 993abfd1a8bSRiver Riddle /// Jump to a specific successor based on a predicate value. 994abfd1a8bSRiver Riddle void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); } 995abfd1a8bSRiver Riddle /// Jump to a specific successor based on a destination index. 996abfd1a8bSRiver Riddle void selectJump(size_t destIndex) { 997abfd1a8bSRiver Riddle curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)]; 998abfd1a8bSRiver Riddle } 999abfd1a8bSRiver Riddle 1000abfd1a8bSRiver Riddle /// Handle a switch operation with the provided value and cases. 100185ab413bSRiver Riddle template <typename T, typename RangeT, typename Comparator = std::equal_to<T>> 100285ab413bSRiver Riddle void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) { 1003abfd1a8bSRiver Riddle LLVM_DEBUG({ 1004abfd1a8bSRiver Riddle llvm::dbgs() << " * Value: " << value << "\n" 1005abfd1a8bSRiver Riddle << " * Cases: "; 1006abfd1a8bSRiver Riddle llvm::interleaveComma(cases, llvm::dbgs()); 1007154cabe7SRiver Riddle llvm::dbgs() << "\n"; 1008abfd1a8bSRiver Riddle }); 1009abfd1a8bSRiver Riddle 1010abfd1a8bSRiver Riddle // Check to see if the attribute value is within the case list. Jump to 1011abfd1a8bSRiver Riddle // the correct successor index based on the result. 1012f80b6304SRiver Riddle for (auto it = cases.begin(), e = cases.end(); it != e; ++it) 101385ab413bSRiver Riddle if (cmp(*it, value)) 1014f80b6304SRiver Riddle return selectJump(size_t((it - cases.begin()) + 1)); 1015f80b6304SRiver Riddle selectJump(size_t(0)); 1016abfd1a8bSRiver Riddle } 1017abfd1a8bSRiver Riddle 1018abfd1a8bSRiver Riddle /// Internal implementation of reading various data types from the bytecode 1019abfd1a8bSRiver Riddle /// stream. 1020abfd1a8bSRiver Riddle template <typename T> 1021abfd1a8bSRiver Riddle const void *readFromMemory() { 1022abfd1a8bSRiver Riddle size_t index = *curCodeIt++; 1023abfd1a8bSRiver Riddle 1024abfd1a8bSRiver Riddle // If this type is an SSA value, it can only be stored in non-const memory. 102585ab413bSRiver Riddle if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *, 102685ab413bSRiver Riddle Value>::value || 102785ab413bSRiver Riddle index < memory.size()) 1028abfd1a8bSRiver Riddle return memory[index]; 1029abfd1a8bSRiver Riddle 1030abfd1a8bSRiver Riddle // Otherwise, if this index is not inbounds it is uniqued. 1031abfd1a8bSRiver Riddle return uniquedMemory[index - memory.size()]; 1032abfd1a8bSRiver Riddle } 1033abfd1a8bSRiver Riddle template <typename T> 1034abfd1a8bSRiver Riddle std::enable_if_t<std::is_pointer<T>::value, T> readImpl() { 1035abfd1a8bSRiver Riddle return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>())); 1036abfd1a8bSRiver Riddle } 1037abfd1a8bSRiver Riddle template <typename T> 1038abfd1a8bSRiver Riddle std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value, 1039abfd1a8bSRiver Riddle T> 1040abfd1a8bSRiver Riddle readImpl() { 1041abfd1a8bSRiver Riddle return T(T::getFromOpaquePointer(readFromMemory<T>())); 1042abfd1a8bSRiver Riddle } 1043abfd1a8bSRiver Riddle template <typename T> 1044abfd1a8bSRiver Riddle std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() { 104585ab413bSRiver Riddle switch (read<PDLValue::Kind>()) { 104685ab413bSRiver Riddle case PDLValue::Kind::Attribute: 1047abfd1a8bSRiver Riddle return read<Attribute>(); 104885ab413bSRiver Riddle case PDLValue::Kind::Operation: 1049abfd1a8bSRiver Riddle return read<Operation *>(); 105085ab413bSRiver Riddle case PDLValue::Kind::Type: 1051abfd1a8bSRiver Riddle return read<Type>(); 105285ab413bSRiver Riddle case PDLValue::Kind::Value: 1053abfd1a8bSRiver Riddle return read<Value>(); 105485ab413bSRiver Riddle case PDLValue::Kind::TypeRange: 105585ab413bSRiver Riddle return read<TypeRange *>(); 105685ab413bSRiver Riddle case PDLValue::Kind::ValueRange: 105785ab413bSRiver Riddle return read<ValueRange *>(); 1058abfd1a8bSRiver Riddle } 105985ab413bSRiver Riddle llvm_unreachable("unhandled PDLValue::Kind"); 1060abfd1a8bSRiver Riddle } 1061abfd1a8bSRiver Riddle template <typename T> 1062abfd1a8bSRiver Riddle std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() { 1063abfd1a8bSRiver Riddle static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 1064abfd1a8bSRiver Riddle "unexpected ByteCode address size"); 1065abfd1a8bSRiver Riddle ByteCodeAddr result; 1066abfd1a8bSRiver Riddle std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr)); 1067abfd1a8bSRiver Riddle curCodeIt += 2; 1068abfd1a8bSRiver Riddle return result; 1069abfd1a8bSRiver Riddle } 1070abfd1a8bSRiver Riddle template <typename T> 1071abfd1a8bSRiver Riddle std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() { 1072abfd1a8bSRiver Riddle return *curCodeIt++; 1073abfd1a8bSRiver Riddle } 107485ab413bSRiver Riddle template <typename T> 107585ab413bSRiver Riddle std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() { 107685ab413bSRiver Riddle return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>()); 107785ab413bSRiver Riddle } 1078abfd1a8bSRiver Riddle 1079abfd1a8bSRiver Riddle /// The underlying bytecode buffer. 1080abfd1a8bSRiver Riddle const ByteCodeField *curCodeIt; 1081abfd1a8bSRiver Riddle 1082abfd1a8bSRiver Riddle /// The current execution memory. 1083abfd1a8bSRiver Riddle MutableArrayRef<const void *> memory; 108485ab413bSRiver Riddle MutableArrayRef<TypeRange> typeRangeMemory; 108585ab413bSRiver Riddle std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory; 108685ab413bSRiver Riddle MutableArrayRef<ValueRange> valueRangeMemory; 108785ab413bSRiver Riddle std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory; 1088abfd1a8bSRiver Riddle 1089abfd1a8bSRiver Riddle /// References to ByteCode data necessary for execution. 1090abfd1a8bSRiver Riddle ArrayRef<const void *> uniquedMemory; 1091abfd1a8bSRiver Riddle ArrayRef<ByteCodeField> code; 1092abfd1a8bSRiver Riddle ArrayRef<PatternBenefit> currentPatternBenefits; 1093abfd1a8bSRiver Riddle ArrayRef<PDLByteCodePattern> patterns; 1094abfd1a8bSRiver Riddle ArrayRef<PDLConstraintFunction> constraintFunctions; 1095abfd1a8bSRiver Riddle ArrayRef<PDLRewriteFunction> rewriteFunctions; 1096abfd1a8bSRiver Riddle }; 109702c4c0d5SRiver Riddle 109802c4c0d5SRiver Riddle /// This class is an instantiation of the PDLResultList that provides access to 109902c4c0d5SRiver Riddle /// the returned results. This API is not on `PDLResultList` to avoid 110002c4c0d5SRiver Riddle /// overexposing access to information specific solely to the ByteCode. 110102c4c0d5SRiver Riddle class ByteCodeRewriteResultList : public PDLResultList { 110202c4c0d5SRiver Riddle public: 110385ab413bSRiver Riddle ByteCodeRewriteResultList(unsigned maxNumResults) 110485ab413bSRiver Riddle : PDLResultList(maxNumResults) {} 110585ab413bSRiver Riddle 110602c4c0d5SRiver Riddle /// Return the list of PDL results. 110702c4c0d5SRiver Riddle MutableArrayRef<PDLValue> getResults() { return results; } 110885ab413bSRiver Riddle 110985ab413bSRiver Riddle /// Return the type ranges allocated by this list. 111085ab413bSRiver Riddle MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() { 111185ab413bSRiver Riddle return allocatedTypeRanges; 111285ab413bSRiver Riddle } 111385ab413bSRiver Riddle 111485ab413bSRiver Riddle /// Return the value ranges allocated by this list. 111585ab413bSRiver Riddle MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() { 111685ab413bSRiver Riddle return allocatedValueRanges; 111785ab413bSRiver Riddle } 111802c4c0d5SRiver Riddle }; 1119abfd1a8bSRiver Riddle } // end anonymous namespace 1120abfd1a8bSRiver Riddle 1121154cabe7SRiver Riddle void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { 1122abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); 1123abfd1a8bSRiver Riddle const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; 1124abfd1a8bSRiver Riddle ArrayAttr constParams = read<ArrayAttr>(); 1125abfd1a8bSRiver Riddle SmallVector<PDLValue, 16> args; 1126abfd1a8bSRiver Riddle readList<PDLValue>(args); 1127154cabe7SRiver Riddle 1128abfd1a8bSRiver Riddle LLVM_DEBUG({ 1129abfd1a8bSRiver Riddle llvm::dbgs() << " * Arguments: "; 1130abfd1a8bSRiver Riddle llvm::interleaveComma(args, llvm::dbgs()); 1131154cabe7SRiver Riddle llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; 1132abfd1a8bSRiver Riddle }); 1133abfd1a8bSRiver Riddle 1134abfd1a8bSRiver Riddle // Invoke the constraint and jump to the proper destination. 1135abfd1a8bSRiver Riddle selectJump(succeeded(constraintFn(args, constParams, rewriter))); 1136abfd1a8bSRiver Riddle } 1137154cabe7SRiver Riddle 1138154cabe7SRiver Riddle void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { 1139abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); 1140abfd1a8bSRiver Riddle const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; 1141abfd1a8bSRiver Riddle ArrayAttr constParams = read<ArrayAttr>(); 1142abfd1a8bSRiver Riddle SmallVector<PDLValue, 16> args; 1143abfd1a8bSRiver Riddle readList<PDLValue>(args); 1144abfd1a8bSRiver Riddle 1145abfd1a8bSRiver Riddle LLVM_DEBUG({ 114602c4c0d5SRiver Riddle llvm::dbgs() << " * Arguments: "; 1147abfd1a8bSRiver Riddle llvm::interleaveComma(args, llvm::dbgs()); 1148154cabe7SRiver Riddle llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; 1149abfd1a8bSRiver Riddle }); 115085ab413bSRiver Riddle 115185ab413bSRiver Riddle // Execute the rewrite function. 115285ab413bSRiver Riddle ByteCodeField numResults = read(); 115385ab413bSRiver Riddle ByteCodeRewriteResultList results(numResults); 115402c4c0d5SRiver Riddle rewriteFn(args, constParams, rewriter, results); 1155154cabe7SRiver Riddle 115685ab413bSRiver Riddle assert(results.getResults().size() == numResults && 115702c4c0d5SRiver Riddle "native PDL rewrite function returned unexpected number of results"); 115802c4c0d5SRiver Riddle 115902c4c0d5SRiver Riddle // Store the results in the bytecode memory. 116002c4c0d5SRiver Riddle for (PDLValue &result : results.getResults()) { 116102c4c0d5SRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); 116285ab413bSRiver Riddle 116385ab413bSRiver Riddle // In debug mode we also verify the expected kind of the result. 116485ab413bSRiver Riddle #ifndef NDEBUG 116585ab413bSRiver Riddle assert(result.getKind() == read<PDLValue::Kind>() && 116685ab413bSRiver Riddle "native PDL rewrite function returned an unexpected type of result"); 116785ab413bSRiver Riddle #endif 116885ab413bSRiver Riddle 116985ab413bSRiver Riddle // If the result is a range, we need to copy it over to the bytecodes 117085ab413bSRiver Riddle // range memory. 117185ab413bSRiver Riddle if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) { 117285ab413bSRiver Riddle unsigned rangeIndex = read(); 117385ab413bSRiver Riddle typeRangeMemory[rangeIndex] = *typeRange; 117485ab413bSRiver Riddle memory[read()] = &typeRangeMemory[rangeIndex]; 117585ab413bSRiver Riddle } else if (Optional<ValueRange> valueRange = 117685ab413bSRiver Riddle result.dyn_cast<ValueRange>()) { 117785ab413bSRiver Riddle unsigned rangeIndex = read(); 117885ab413bSRiver Riddle valueRangeMemory[rangeIndex] = *valueRange; 117985ab413bSRiver Riddle memory[read()] = &valueRangeMemory[rangeIndex]; 118085ab413bSRiver Riddle } else { 118102c4c0d5SRiver Riddle memory[read()] = result.getAsOpaquePointer(); 118202c4c0d5SRiver Riddle } 1183abfd1a8bSRiver Riddle } 1184154cabe7SRiver Riddle 118585ab413bSRiver Riddle // Copy over any underlying storage allocated for result ranges. 118685ab413bSRiver Riddle for (auto &it : results.getAllocatedTypeRanges()) 118785ab413bSRiver Riddle allocatedTypeRangeMemory.push_back(std::move(it)); 118885ab413bSRiver Riddle for (auto &it : results.getAllocatedValueRanges()) 118985ab413bSRiver Riddle allocatedValueRangeMemory.push_back(std::move(it)); 119085ab413bSRiver Riddle } 119185ab413bSRiver Riddle 1192154cabe7SRiver Riddle void ByteCodeExecutor::executeAreEqual() { 1193abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 1194abfd1a8bSRiver Riddle const void *lhs = read<const void *>(); 1195abfd1a8bSRiver Riddle const void *rhs = read<const void *>(); 1196abfd1a8bSRiver Riddle 1197154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n"); 1198abfd1a8bSRiver Riddle selectJump(lhs == rhs); 1199abfd1a8bSRiver Riddle } 1200154cabe7SRiver Riddle 120185ab413bSRiver Riddle void ByteCodeExecutor::executeAreRangesEqual() { 120285ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n"); 120385ab413bSRiver Riddle PDLValue::Kind valueKind = read<PDLValue::Kind>(); 120485ab413bSRiver Riddle const void *lhs = read<const void *>(); 120585ab413bSRiver Riddle const void *rhs = read<const void *>(); 120685ab413bSRiver Riddle 120785ab413bSRiver Riddle switch (valueKind) { 120885ab413bSRiver Riddle case PDLValue::Kind::TypeRange: { 120985ab413bSRiver Riddle const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs); 121085ab413bSRiver Riddle const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs); 121185ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 121285ab413bSRiver Riddle selectJump(*lhsRange == *rhsRange); 121385ab413bSRiver Riddle break; 121485ab413bSRiver Riddle } 121585ab413bSRiver Riddle case PDLValue::Kind::ValueRange: { 121685ab413bSRiver Riddle const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs); 121785ab413bSRiver Riddle const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs); 121885ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 121985ab413bSRiver Riddle selectJump(*lhsRange == *rhsRange); 122085ab413bSRiver Riddle break; 122185ab413bSRiver Riddle } 122285ab413bSRiver Riddle default: 122385ab413bSRiver Riddle llvm_unreachable("unexpected `AreRangesEqual` value kind"); 122485ab413bSRiver Riddle } 122585ab413bSRiver Riddle } 122685ab413bSRiver Riddle 1227154cabe7SRiver Riddle void ByteCodeExecutor::executeBranch() { 1228154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n"); 1229abfd1a8bSRiver Riddle curCodeIt = &code[read<ByteCodeAddr>()]; 1230abfd1a8bSRiver Riddle } 1231154cabe7SRiver Riddle 1232154cabe7SRiver Riddle void ByteCodeExecutor::executeCheckOperandCount() { 1233abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); 1234abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1235abfd1a8bSRiver Riddle uint32_t expectedCount = read<uint32_t>(); 123685ab413bSRiver Riddle bool compareAtLeast = read(); 1237abfd1a8bSRiver Riddle 1238abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" 123985ab413bSRiver Riddle << " * Expected: " << expectedCount << "\n" 124085ab413bSRiver Riddle << " * Comparator: " 124185ab413bSRiver Riddle << (compareAtLeast ? ">=" : "==") << "\n"); 124285ab413bSRiver Riddle if (compareAtLeast) 124385ab413bSRiver Riddle selectJump(op->getNumOperands() >= expectedCount); 124485ab413bSRiver Riddle else 1245abfd1a8bSRiver Riddle selectJump(op->getNumOperands() == expectedCount); 1246abfd1a8bSRiver Riddle } 1247154cabe7SRiver Riddle 1248154cabe7SRiver Riddle void ByteCodeExecutor::executeCheckOperationName() { 1249abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n"); 1250abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1251abfd1a8bSRiver Riddle OperationName expectedName = read<OperationName>(); 1252abfd1a8bSRiver Riddle 1253154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n" 1254154cabe7SRiver Riddle << " * Expected: \"" << expectedName << "\"\n"); 1255abfd1a8bSRiver Riddle selectJump(op->getName() == expectedName); 1256abfd1a8bSRiver Riddle } 1257154cabe7SRiver Riddle 1258154cabe7SRiver Riddle void ByteCodeExecutor::executeCheckResultCount() { 1259abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); 1260abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1261abfd1a8bSRiver Riddle uint32_t expectedCount = read<uint32_t>(); 126285ab413bSRiver Riddle bool compareAtLeast = read(); 1263abfd1a8bSRiver Riddle 1264abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" 126585ab413bSRiver Riddle << " * Expected: " << expectedCount << "\n" 126685ab413bSRiver Riddle << " * Comparator: " 126785ab413bSRiver Riddle << (compareAtLeast ? ">=" : "==") << "\n"); 126885ab413bSRiver Riddle if (compareAtLeast) 126985ab413bSRiver Riddle selectJump(op->getNumResults() >= expectedCount); 127085ab413bSRiver Riddle else 1271abfd1a8bSRiver Riddle selectJump(op->getNumResults() == expectedCount); 1272abfd1a8bSRiver Riddle } 1273154cabe7SRiver Riddle 127485ab413bSRiver Riddle void ByteCodeExecutor::executeCheckTypes() { 127585ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 127685ab413bSRiver Riddle TypeRange *lhs = read<TypeRange *>(); 127785ab413bSRiver Riddle Attribute rhs = read<Attribute>(); 127885ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 127985ab413bSRiver Riddle 128085ab413bSRiver Riddle selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>()); 128185ab413bSRiver Riddle } 128285ab413bSRiver Riddle 128385ab413bSRiver Riddle void ByteCodeExecutor::executeCreateTypes() { 128485ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n"); 128585ab413bSRiver Riddle unsigned memIndex = read(); 128685ab413bSRiver Riddle unsigned rangeIndex = read(); 128785ab413bSRiver Riddle ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>(); 128885ab413bSRiver Riddle 128985ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n"); 129085ab413bSRiver Riddle 129185ab413bSRiver Riddle // Allocate a buffer for this type range. 129285ab413bSRiver Riddle llvm::OwningArrayRef<Type> storage(typesAttr.size()); 129385ab413bSRiver Riddle llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin()); 129485ab413bSRiver Riddle allocatedTypeRangeMemory.emplace_back(std::move(storage)); 129585ab413bSRiver Riddle 129685ab413bSRiver Riddle // Assign this to the range slot and use the range as the value for the 129785ab413bSRiver Riddle // memory index. 129885ab413bSRiver Riddle typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back(); 129985ab413bSRiver Riddle memory[memIndex] = &typeRangeMemory[rangeIndex]; 130085ab413bSRiver Riddle } 130185ab413bSRiver Riddle 1302154cabe7SRiver Riddle void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, 1303154cabe7SRiver Riddle Location mainRewriteLoc) { 1304abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); 1305abfd1a8bSRiver Riddle 1306abfd1a8bSRiver Riddle unsigned memIndex = read(); 1307154cabe7SRiver Riddle OperationState state(mainRewriteLoc, read<OperationName>()); 130885ab413bSRiver Riddle readValueList(state.operands); 1309abfd1a8bSRiver Riddle for (unsigned i = 0, e = read(); i != e; ++i) { 1310abfd1a8bSRiver Riddle Identifier name = read<Identifier>(); 1311abfd1a8bSRiver Riddle if (Attribute attr = read<Attribute>()) 1312abfd1a8bSRiver Riddle state.addAttribute(name, attr); 1313abfd1a8bSRiver Riddle } 1314abfd1a8bSRiver Riddle 1315abfd1a8bSRiver Riddle for (unsigned i = 0, e = read(); i != e; ++i) { 131685ab413bSRiver Riddle if (read<PDLValue::Kind>() == PDLValue::Kind::Type) { 131785ab413bSRiver Riddle state.types.push_back(read<Type>()); 131885ab413bSRiver Riddle continue; 131985ab413bSRiver Riddle } 132085ab413bSRiver Riddle 132185ab413bSRiver Riddle // If we find a null range, this signals that the types are infered. 132285ab413bSRiver Riddle if (TypeRange *resultTypes = read<TypeRange *>()) { 132385ab413bSRiver Riddle state.types.append(resultTypes->begin(), resultTypes->end()); 132485ab413bSRiver Riddle continue; 1325abfd1a8bSRiver Riddle } 1326abfd1a8bSRiver Riddle 1327abfd1a8bSRiver Riddle // Handle the case where the operation has inferred types. 1328abfd1a8bSRiver Riddle InferTypeOpInterface::Concept *concept = 1329154cabe7SRiver Riddle state.name.getAbstractOperation()->getInterface<InferTypeOpInterface>(); 1330abfd1a8bSRiver Riddle 1331abfd1a8bSRiver Riddle // TODO: Handle failure. 13323a833a0eSRiver Riddle state.types.clear(); 1333abfd1a8bSRiver Riddle if (failed(concept->inferReturnTypes( 1334abfd1a8bSRiver Riddle state.getContext(), state.location, state.operands, 1335154cabe7SRiver Riddle state.attributes.getDictionary(state.getContext()), state.regions, 13363a833a0eSRiver Riddle state.types))) 1337abfd1a8bSRiver Riddle return; 133885ab413bSRiver Riddle break; 1339abfd1a8bSRiver Riddle } 134085ab413bSRiver Riddle 1341abfd1a8bSRiver Riddle Operation *resultOp = rewriter.createOperation(state); 1342abfd1a8bSRiver Riddle memory[memIndex] = resultOp; 1343abfd1a8bSRiver Riddle 1344abfd1a8bSRiver Riddle LLVM_DEBUG({ 1345abfd1a8bSRiver Riddle llvm::dbgs() << " * Attributes: " 1346abfd1a8bSRiver Riddle << state.attributes.getDictionary(state.getContext()) 1347abfd1a8bSRiver Riddle << "\n * Operands: "; 1348abfd1a8bSRiver Riddle llvm::interleaveComma(state.operands, llvm::dbgs()); 1349abfd1a8bSRiver Riddle llvm::dbgs() << "\n * Result Types: "; 1350abfd1a8bSRiver Riddle llvm::interleaveComma(state.types, llvm::dbgs()); 1351154cabe7SRiver Riddle llvm::dbgs() << "\n * Result: " << *resultOp << "\n"; 1352abfd1a8bSRiver Riddle }); 1353abfd1a8bSRiver Riddle } 1354154cabe7SRiver Riddle 1355154cabe7SRiver Riddle void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) { 1356abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n"); 1357abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1358abfd1a8bSRiver Riddle 1359154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1360abfd1a8bSRiver Riddle rewriter.eraseOp(op); 1361abfd1a8bSRiver Riddle } 1362154cabe7SRiver Riddle 1363154cabe7SRiver Riddle void ByteCodeExecutor::executeGetAttribute() { 1364abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); 1365abfd1a8bSRiver Riddle unsigned memIndex = read(); 1366abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1367abfd1a8bSRiver Riddle Identifier attrName = read<Identifier>(); 1368abfd1a8bSRiver Riddle Attribute attr = op->getAttr(attrName); 1369abfd1a8bSRiver Riddle 1370abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1371abfd1a8bSRiver Riddle << " * Attribute: " << attrName << "\n" 1372154cabe7SRiver Riddle << " * Result: " << attr << "\n"); 1373abfd1a8bSRiver Riddle memory[memIndex] = attr.getAsOpaquePointer(); 1374abfd1a8bSRiver Riddle } 1375154cabe7SRiver Riddle 1376154cabe7SRiver Riddle void ByteCodeExecutor::executeGetAttributeType() { 1377abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n"); 1378abfd1a8bSRiver Riddle unsigned memIndex = read(); 1379abfd1a8bSRiver Riddle Attribute attr = read<Attribute>(); 1380154cabe7SRiver Riddle Type type = attr ? attr.getType() : Type(); 1381abfd1a8bSRiver Riddle 1382abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" 1383154cabe7SRiver Riddle << " * Result: " << type << "\n"); 1384154cabe7SRiver Riddle memory[memIndex] = type.getAsOpaquePointer(); 1385abfd1a8bSRiver Riddle } 1386154cabe7SRiver Riddle 1387154cabe7SRiver Riddle void ByteCodeExecutor::executeGetDefiningOp() { 1388abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); 1389abfd1a8bSRiver Riddle unsigned memIndex = read(); 139085ab413bSRiver Riddle Operation *op = nullptr; 139185ab413bSRiver Riddle if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1392abfd1a8bSRiver Riddle Value value = read<Value>(); 139385ab413bSRiver Riddle if (value) 139485ab413bSRiver Riddle op = value.getDefiningOp(); 139585ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 139685ab413bSRiver Riddle } else { 139785ab413bSRiver Riddle ValueRange *values = read<ValueRange *>(); 139885ab413bSRiver Riddle if (values && !values->empty()) { 139985ab413bSRiver Riddle op = values->front().getDefiningOp(); 140085ab413bSRiver Riddle } 140185ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n"); 140285ab413bSRiver Riddle } 1403abfd1a8bSRiver Riddle 140485ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n"); 1405abfd1a8bSRiver Riddle memory[memIndex] = op; 1406abfd1a8bSRiver Riddle } 1407154cabe7SRiver Riddle 1408154cabe7SRiver Riddle void ByteCodeExecutor::executeGetOperand(unsigned index) { 1409abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1410abfd1a8bSRiver Riddle unsigned memIndex = read(); 1411abfd1a8bSRiver Riddle Value operand = 1412abfd1a8bSRiver Riddle index < op->getNumOperands() ? op->getOperand(index) : Value(); 1413abfd1a8bSRiver Riddle 1414abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1415abfd1a8bSRiver Riddle << " * Index: " << index << "\n" 1416154cabe7SRiver Riddle << " * Result: " << operand << "\n"); 1417abfd1a8bSRiver Riddle memory[memIndex] = operand.getAsOpaquePointer(); 1418abfd1a8bSRiver Riddle } 1419154cabe7SRiver Riddle 142085ab413bSRiver Riddle /// This function is the internal implementation of `GetResults` and 142185ab413bSRiver Riddle /// `GetOperands` that provides support for extracting a value range from the 142285ab413bSRiver Riddle /// given operation. 142385ab413bSRiver Riddle template <template <typename> class AttrSizedSegmentsT, typename RangeT> 142485ab413bSRiver Riddle static void * 142585ab413bSRiver Riddle executeGetOperandsResults(RangeT values, Operation *op, unsigned index, 142685ab413bSRiver Riddle ByteCodeField rangeIndex, StringRef attrSizedSegments, 142785ab413bSRiver Riddle MutableArrayRef<ValueRange> &valueRangeMemory) { 142885ab413bSRiver Riddle // Check for the sentinel index that signals that all values should be 142985ab413bSRiver Riddle // returned. 143085ab413bSRiver Riddle if (index == std::numeric_limits<uint32_t>::max()) { 143185ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n"); 143285ab413bSRiver Riddle // `values` is already the full value range. 143385ab413bSRiver Riddle 143485ab413bSRiver Riddle // Otherwise, check to see if this operation uses AttrSizedSegments. 143585ab413bSRiver Riddle } else if (op->hasTrait<AttrSizedSegmentsT>()) { 143685ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() 143785ab413bSRiver Riddle << " * Extracting values from `" << attrSizedSegments << "`\n"); 143885ab413bSRiver Riddle 143985ab413bSRiver Riddle auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments); 144085ab413bSRiver Riddle if (!segmentAttr || segmentAttr.getNumElements() <= index) 144185ab413bSRiver Riddle return nullptr; 144285ab413bSRiver Riddle 144385ab413bSRiver Riddle auto segments = segmentAttr.getValues<int32_t>(); 144485ab413bSRiver Riddle unsigned startIndex = 144585ab413bSRiver Riddle std::accumulate(segments.begin(), segments.begin() + index, 0); 144685ab413bSRiver Riddle values = values.slice(startIndex, *std::next(segments.begin(), index)); 144785ab413bSRiver Riddle 144885ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", " 144985ab413bSRiver Riddle << *std::next(segments.begin(), index) << "]\n"); 145085ab413bSRiver Riddle 145185ab413bSRiver Riddle // Otherwise, assume this is the last operand group of the operation. 145285ab413bSRiver Riddle // FIXME: We currently don't support operations with 145385ab413bSRiver Riddle // SameVariadicOperandSize/SameVariadicResultSize here given that we don't 145485ab413bSRiver Riddle // have a way to detect it's presence. 145585ab413bSRiver Riddle } else if (values.size() >= index) { 145685ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() 145785ab413bSRiver Riddle << " * Treating values as trailing variadic range\n"); 145885ab413bSRiver Riddle values = values.drop_front(index); 145985ab413bSRiver Riddle 146085ab413bSRiver Riddle // If we couldn't detect a way to compute the values, bail out. 146185ab413bSRiver Riddle } else { 146285ab413bSRiver Riddle return nullptr; 146385ab413bSRiver Riddle } 146485ab413bSRiver Riddle 146585ab413bSRiver Riddle // If the range index is valid, we are returning a range. 146685ab413bSRiver Riddle if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) { 146785ab413bSRiver Riddle valueRangeMemory[rangeIndex] = values; 146885ab413bSRiver Riddle return &valueRangeMemory[rangeIndex]; 146985ab413bSRiver Riddle } 147085ab413bSRiver Riddle 147185ab413bSRiver Riddle // If a range index wasn't provided, the range is required to be non-variadic. 147285ab413bSRiver Riddle return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer(); 147385ab413bSRiver Riddle } 147485ab413bSRiver Riddle 147585ab413bSRiver Riddle void ByteCodeExecutor::executeGetOperands() { 147685ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n"); 147785ab413bSRiver Riddle unsigned index = read<uint32_t>(); 147885ab413bSRiver Riddle Operation *op = read<Operation *>(); 147985ab413bSRiver Riddle ByteCodeField rangeIndex = read(); 148085ab413bSRiver Riddle 148185ab413bSRiver Riddle void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>( 148285ab413bSRiver Riddle op->getOperands(), op, index, rangeIndex, "operand_segment_sizes", 148385ab413bSRiver Riddle valueRangeMemory); 148485ab413bSRiver Riddle if (!result) 148585ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n"); 148685ab413bSRiver Riddle memory[read()] = result; 148785ab413bSRiver Riddle } 148885ab413bSRiver Riddle 1489154cabe7SRiver Riddle void ByteCodeExecutor::executeGetResult(unsigned index) { 1490abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1491abfd1a8bSRiver Riddle unsigned memIndex = read(); 1492abfd1a8bSRiver Riddle OpResult result = 1493abfd1a8bSRiver Riddle index < op->getNumResults() ? op->getResult(index) : OpResult(); 1494abfd1a8bSRiver Riddle 1495abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1496abfd1a8bSRiver Riddle << " * Index: " << index << "\n" 1497154cabe7SRiver Riddle << " * Result: " << result << "\n"); 1498abfd1a8bSRiver Riddle memory[memIndex] = result.getAsOpaquePointer(); 1499abfd1a8bSRiver Riddle } 1500154cabe7SRiver Riddle 150185ab413bSRiver Riddle void ByteCodeExecutor::executeGetResults() { 150285ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n"); 150385ab413bSRiver Riddle unsigned index = read<uint32_t>(); 150485ab413bSRiver Riddle Operation *op = read<Operation *>(); 150585ab413bSRiver Riddle ByteCodeField rangeIndex = read(); 150685ab413bSRiver Riddle 150785ab413bSRiver Riddle void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>( 150885ab413bSRiver Riddle op->getResults(), op, index, rangeIndex, "result_segment_sizes", 150985ab413bSRiver Riddle valueRangeMemory); 151085ab413bSRiver Riddle if (!result) 151185ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n"); 151285ab413bSRiver Riddle memory[read()] = result; 151385ab413bSRiver Riddle } 151485ab413bSRiver Riddle 1515154cabe7SRiver Riddle void ByteCodeExecutor::executeGetValueType() { 1516abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); 1517abfd1a8bSRiver Riddle unsigned memIndex = read(); 1518abfd1a8bSRiver Riddle Value value = read<Value>(); 1519154cabe7SRiver Riddle Type type = value ? value.getType() : Type(); 1520abfd1a8bSRiver Riddle 1521abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" 1522154cabe7SRiver Riddle << " * Result: " << type << "\n"); 1523154cabe7SRiver Riddle memory[memIndex] = type.getAsOpaquePointer(); 1524abfd1a8bSRiver Riddle } 1525154cabe7SRiver Riddle 152685ab413bSRiver Riddle void ByteCodeExecutor::executeGetValueRangeTypes() { 152785ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n"); 152885ab413bSRiver Riddle unsigned memIndex = read(); 152985ab413bSRiver Riddle unsigned rangeIndex = read(); 153085ab413bSRiver Riddle ValueRange *values = read<ValueRange *>(); 153185ab413bSRiver Riddle if (!values) { 153285ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n"); 153385ab413bSRiver Riddle memory[memIndex] = nullptr; 153485ab413bSRiver Riddle return; 153585ab413bSRiver Riddle } 153685ab413bSRiver Riddle 153785ab413bSRiver Riddle LLVM_DEBUG({ 153885ab413bSRiver Riddle llvm::dbgs() << " * Values (" << values->size() << "): "; 153985ab413bSRiver Riddle llvm::interleaveComma(*values, llvm::dbgs()); 154085ab413bSRiver Riddle llvm::dbgs() << "\n * Result: "; 154185ab413bSRiver Riddle llvm::interleaveComma(values->getType(), llvm::dbgs()); 154285ab413bSRiver Riddle llvm::dbgs() << "\n"; 154385ab413bSRiver Riddle }); 154485ab413bSRiver Riddle typeRangeMemory[rangeIndex] = values->getType(); 154585ab413bSRiver Riddle memory[memIndex] = &typeRangeMemory[rangeIndex]; 154685ab413bSRiver Riddle } 154785ab413bSRiver Riddle 1548154cabe7SRiver Riddle void ByteCodeExecutor::executeIsNotNull() { 1549abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); 1550abfd1a8bSRiver Riddle const void *value = read<const void *>(); 1551abfd1a8bSRiver Riddle 1552154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1553abfd1a8bSRiver Riddle selectJump(value != nullptr); 1554abfd1a8bSRiver Riddle } 1555154cabe7SRiver Riddle 1556154cabe7SRiver Riddle void ByteCodeExecutor::executeRecordMatch( 1557154cabe7SRiver Riddle PatternRewriter &rewriter, 1558154cabe7SRiver Riddle SmallVectorImpl<PDLByteCode::MatchResult> &matches) { 1559abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n"); 1560abfd1a8bSRiver Riddle unsigned patternIndex = read(); 1561abfd1a8bSRiver Riddle PatternBenefit benefit = currentPatternBenefits[patternIndex]; 1562abfd1a8bSRiver Riddle const ByteCodeField *dest = &code[read<ByteCodeAddr>()]; 1563abfd1a8bSRiver Riddle 1564abfd1a8bSRiver Riddle // If the benefit of the pattern is impossible, skip the processing of the 1565abfd1a8bSRiver Riddle // rest of the pattern. 1566abfd1a8bSRiver Riddle if (benefit.isImpossibleToMatch()) { 1567154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n"); 1568abfd1a8bSRiver Riddle curCodeIt = dest; 1569154cabe7SRiver Riddle return; 1570abfd1a8bSRiver Riddle } 1571abfd1a8bSRiver Riddle 1572abfd1a8bSRiver Riddle // Create a fused location containing the locations of each of the 1573abfd1a8bSRiver Riddle // operations used in the match. This will be used as the location for 1574abfd1a8bSRiver Riddle // created operations during the rewrite that don't already have an 1575abfd1a8bSRiver Riddle // explicit location set. 1576abfd1a8bSRiver Riddle unsigned numMatchLocs = read(); 1577abfd1a8bSRiver Riddle SmallVector<Location, 4> matchLocs; 1578abfd1a8bSRiver Riddle matchLocs.reserve(numMatchLocs); 1579abfd1a8bSRiver Riddle for (unsigned i = 0; i != numMatchLocs; ++i) 1580abfd1a8bSRiver Riddle matchLocs.push_back(read<Operation *>()->getLoc()); 1581abfd1a8bSRiver Riddle Location matchLoc = rewriter.getFusedLoc(matchLocs); 1582abfd1a8bSRiver Riddle 1583abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n" 1584154cabe7SRiver Riddle << " * Location: " << matchLoc << "\n"); 1585154cabe7SRiver Riddle matches.emplace_back(matchLoc, patterns[patternIndex], benefit); 158685ab413bSRiver Riddle PDLByteCode::MatchResult &match = matches.back(); 158785ab413bSRiver Riddle 158885ab413bSRiver Riddle // Record all of the inputs to the match. If any of the inputs are ranges, we 158985ab413bSRiver Riddle // will also need to remap the range pointer to memory stored in the match 159085ab413bSRiver Riddle // state. 159185ab413bSRiver Riddle unsigned numInputs = read(); 159285ab413bSRiver Riddle match.values.reserve(numInputs); 159385ab413bSRiver Riddle match.typeRangeValues.reserve(numInputs); 159485ab413bSRiver Riddle match.valueRangeValues.reserve(numInputs); 159585ab413bSRiver Riddle for (unsigned i = 0; i < numInputs; ++i) { 159685ab413bSRiver Riddle switch (read<PDLValue::Kind>()) { 159785ab413bSRiver Riddle case PDLValue::Kind::TypeRange: 159885ab413bSRiver Riddle match.typeRangeValues.push_back(*read<TypeRange *>()); 159985ab413bSRiver Riddle match.values.push_back(&match.typeRangeValues.back()); 160085ab413bSRiver Riddle break; 160185ab413bSRiver Riddle case PDLValue::Kind::ValueRange: 160285ab413bSRiver Riddle match.valueRangeValues.push_back(*read<ValueRange *>()); 160385ab413bSRiver Riddle match.values.push_back(&match.valueRangeValues.back()); 160485ab413bSRiver Riddle break; 160585ab413bSRiver Riddle default: 160685ab413bSRiver Riddle match.values.push_back(read<const void *>()); 160785ab413bSRiver Riddle break; 160885ab413bSRiver Riddle } 160985ab413bSRiver Riddle } 1610abfd1a8bSRiver Riddle curCodeIt = dest; 1611abfd1a8bSRiver Riddle } 1612154cabe7SRiver Riddle 1613154cabe7SRiver Riddle void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) { 1614abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); 1615abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1616abfd1a8bSRiver Riddle SmallVector<Value, 16> args; 161785ab413bSRiver Riddle readValueList(args); 1618abfd1a8bSRiver Riddle 1619abfd1a8bSRiver Riddle LLVM_DEBUG({ 1620abfd1a8bSRiver Riddle llvm::dbgs() << " * Operation: " << *op << "\n" 1621abfd1a8bSRiver Riddle << " * Values: "; 1622abfd1a8bSRiver Riddle llvm::interleaveComma(args, llvm::dbgs()); 1623154cabe7SRiver Riddle llvm::dbgs() << "\n"; 1624abfd1a8bSRiver Riddle }); 1625abfd1a8bSRiver Riddle rewriter.replaceOp(op, args); 1626abfd1a8bSRiver Riddle } 1627154cabe7SRiver Riddle 1628154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchAttribute() { 1629abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n"); 1630abfd1a8bSRiver Riddle Attribute value = read<Attribute>(); 1631abfd1a8bSRiver Riddle ArrayAttr cases = read<ArrayAttr>(); 1632abfd1a8bSRiver Riddle handleSwitch(value, cases); 1633abfd1a8bSRiver Riddle } 1634154cabe7SRiver Riddle 1635154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchOperandCount() { 1636abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n"); 1637abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1638abfd1a8bSRiver Riddle auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1639abfd1a8bSRiver Riddle 1640abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1641abfd1a8bSRiver Riddle handleSwitch(op->getNumOperands(), cases); 1642abfd1a8bSRiver Riddle } 1643154cabe7SRiver Riddle 1644154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchOperationName() { 1645abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n"); 1646abfd1a8bSRiver Riddle OperationName value = read<Operation *>()->getName(); 1647abfd1a8bSRiver Riddle size_t caseCount = read(); 1648abfd1a8bSRiver Riddle 1649abfd1a8bSRiver Riddle // The operation names are stored in-line, so to print them out for 1650abfd1a8bSRiver Riddle // debugging purposes we need to read the array before executing the 1651abfd1a8bSRiver Riddle // switch so that we can display all of the possible values. 1652abfd1a8bSRiver Riddle LLVM_DEBUG({ 1653abfd1a8bSRiver Riddle const ByteCodeField *prevCodeIt = curCodeIt; 1654abfd1a8bSRiver Riddle llvm::dbgs() << " * Value: " << value << "\n" 1655abfd1a8bSRiver Riddle << " * Cases: "; 1656abfd1a8bSRiver Riddle llvm::interleaveComma( 1657abfd1a8bSRiver Riddle llvm::map_range(llvm::seq<size_t>(0, caseCount), 1658154cabe7SRiver Riddle [&](size_t) { return read<OperationName>(); }), 1659abfd1a8bSRiver Riddle llvm::dbgs()); 1660154cabe7SRiver Riddle llvm::dbgs() << "\n"; 1661abfd1a8bSRiver Riddle curCodeIt = prevCodeIt; 1662abfd1a8bSRiver Riddle }); 1663abfd1a8bSRiver Riddle 1664abfd1a8bSRiver Riddle // Try to find the switch value within any of the cases. 1665abfd1a8bSRiver Riddle for (size_t i = 0; i != caseCount; ++i) { 1666abfd1a8bSRiver Riddle if (read<OperationName>() == value) { 1667abfd1a8bSRiver Riddle curCodeIt += (caseCount - i - 1); 1668154cabe7SRiver Riddle return selectJump(i + 1); 1669abfd1a8bSRiver Riddle } 1670abfd1a8bSRiver Riddle } 1671154cabe7SRiver Riddle selectJump(size_t(0)); 1672abfd1a8bSRiver Riddle } 1673154cabe7SRiver Riddle 1674154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchResultCount() { 1675abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n"); 1676abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1677abfd1a8bSRiver Riddle auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1678abfd1a8bSRiver Riddle 1679abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1680abfd1a8bSRiver Riddle handleSwitch(op->getNumResults(), cases); 1681abfd1a8bSRiver Riddle } 1682154cabe7SRiver Riddle 1683154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchType() { 1684abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n"); 1685abfd1a8bSRiver Riddle Type value = read<Type>(); 1686abfd1a8bSRiver Riddle auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>(); 1687abfd1a8bSRiver Riddle handleSwitch(value, cases); 1688154cabe7SRiver Riddle } 1689154cabe7SRiver Riddle 169085ab413bSRiver Riddle void ByteCodeExecutor::executeSwitchTypes() { 169185ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n"); 169285ab413bSRiver Riddle TypeRange *value = read<TypeRange *>(); 169385ab413bSRiver Riddle auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>(); 169485ab413bSRiver Riddle if (!value) { 169585ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n"); 169685ab413bSRiver Riddle return selectJump(size_t(0)); 169785ab413bSRiver Riddle } 169885ab413bSRiver Riddle handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) { 169985ab413bSRiver Riddle return value == caseValue.getAsValueRange<TypeAttr>(); 170085ab413bSRiver Riddle }); 170185ab413bSRiver Riddle } 170285ab413bSRiver Riddle 1703154cabe7SRiver Riddle void ByteCodeExecutor::execute( 1704154cabe7SRiver Riddle PatternRewriter &rewriter, 1705154cabe7SRiver Riddle SmallVectorImpl<PDLByteCode::MatchResult> *matches, 1706154cabe7SRiver Riddle Optional<Location> mainRewriteLoc) { 1707154cabe7SRiver Riddle while (true) { 1708154cabe7SRiver Riddle OpCode opCode = static_cast<OpCode>(read()); 1709154cabe7SRiver Riddle switch (opCode) { 1710154cabe7SRiver Riddle case ApplyConstraint: 1711154cabe7SRiver Riddle executeApplyConstraint(rewriter); 1712154cabe7SRiver Riddle break; 1713154cabe7SRiver Riddle case ApplyRewrite: 1714154cabe7SRiver Riddle executeApplyRewrite(rewriter); 1715154cabe7SRiver Riddle break; 1716154cabe7SRiver Riddle case AreEqual: 1717154cabe7SRiver Riddle executeAreEqual(); 1718154cabe7SRiver Riddle break; 171985ab413bSRiver Riddle case AreRangesEqual: 172085ab413bSRiver Riddle executeAreRangesEqual(); 172185ab413bSRiver Riddle break; 1722154cabe7SRiver Riddle case Branch: 1723154cabe7SRiver Riddle executeBranch(); 1724154cabe7SRiver Riddle break; 1725154cabe7SRiver Riddle case CheckOperandCount: 1726154cabe7SRiver Riddle executeCheckOperandCount(); 1727154cabe7SRiver Riddle break; 1728154cabe7SRiver Riddle case CheckOperationName: 1729154cabe7SRiver Riddle executeCheckOperationName(); 1730154cabe7SRiver Riddle break; 1731154cabe7SRiver Riddle case CheckResultCount: 1732154cabe7SRiver Riddle executeCheckResultCount(); 1733154cabe7SRiver Riddle break; 173485ab413bSRiver Riddle case CheckTypes: 173585ab413bSRiver Riddle executeCheckTypes(); 173685ab413bSRiver Riddle break; 1737154cabe7SRiver Riddle case CreateOperation: 1738154cabe7SRiver Riddle executeCreateOperation(rewriter, *mainRewriteLoc); 1739154cabe7SRiver Riddle break; 174085ab413bSRiver Riddle case CreateTypes: 174185ab413bSRiver Riddle executeCreateTypes(); 174285ab413bSRiver Riddle break; 1743154cabe7SRiver Riddle case EraseOp: 1744154cabe7SRiver Riddle executeEraseOp(rewriter); 1745154cabe7SRiver Riddle break; 1746154cabe7SRiver Riddle case Finalize: 1747154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n"); 1748154cabe7SRiver Riddle return; 1749154cabe7SRiver Riddle case GetAttribute: 1750154cabe7SRiver Riddle executeGetAttribute(); 1751154cabe7SRiver Riddle break; 1752154cabe7SRiver Riddle case GetAttributeType: 1753154cabe7SRiver Riddle executeGetAttributeType(); 1754154cabe7SRiver Riddle break; 1755154cabe7SRiver Riddle case GetDefiningOp: 1756154cabe7SRiver Riddle executeGetDefiningOp(); 1757154cabe7SRiver Riddle break; 1758154cabe7SRiver Riddle case GetOperand0: 1759154cabe7SRiver Riddle case GetOperand1: 1760154cabe7SRiver Riddle case GetOperand2: 1761154cabe7SRiver Riddle case GetOperand3: { 1762154cabe7SRiver Riddle unsigned index = opCode - GetOperand0; 1763154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n"); 17641fff7c89SFrederik Gossen executeGetOperand(index); 1765abfd1a8bSRiver Riddle break; 1766abfd1a8bSRiver Riddle } 1767154cabe7SRiver Riddle case GetOperandN: 1768154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n"); 1769154cabe7SRiver Riddle executeGetOperand(read<uint32_t>()); 1770154cabe7SRiver Riddle break; 177185ab413bSRiver Riddle case GetOperands: 177285ab413bSRiver Riddle executeGetOperands(); 177385ab413bSRiver Riddle break; 1774154cabe7SRiver Riddle case GetResult0: 1775154cabe7SRiver Riddle case GetResult1: 1776154cabe7SRiver Riddle case GetResult2: 1777154cabe7SRiver Riddle case GetResult3: { 1778154cabe7SRiver Riddle unsigned index = opCode - GetResult0; 1779154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n"); 17801fff7c89SFrederik Gossen executeGetResult(index); 1781154cabe7SRiver Riddle break; 1782abfd1a8bSRiver Riddle } 1783154cabe7SRiver Riddle case GetResultN: 1784154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n"); 1785154cabe7SRiver Riddle executeGetResult(read<uint32_t>()); 1786154cabe7SRiver Riddle break; 178785ab413bSRiver Riddle case GetResults: 178885ab413bSRiver Riddle executeGetResults(); 178985ab413bSRiver Riddle break; 1790154cabe7SRiver Riddle case GetValueType: 1791154cabe7SRiver Riddle executeGetValueType(); 1792154cabe7SRiver Riddle break; 179385ab413bSRiver Riddle case GetValueRangeTypes: 179485ab413bSRiver Riddle executeGetValueRangeTypes(); 179585ab413bSRiver Riddle break; 1796154cabe7SRiver Riddle case IsNotNull: 1797154cabe7SRiver Riddle executeIsNotNull(); 1798154cabe7SRiver Riddle break; 1799154cabe7SRiver Riddle case RecordMatch: 1800154cabe7SRiver Riddle assert(matches && 1801154cabe7SRiver Riddle "expected matches to be provided when executing the matcher"); 1802154cabe7SRiver Riddle executeRecordMatch(rewriter, *matches); 1803154cabe7SRiver Riddle break; 1804154cabe7SRiver Riddle case ReplaceOp: 1805154cabe7SRiver Riddle executeReplaceOp(rewriter); 1806154cabe7SRiver Riddle break; 1807154cabe7SRiver Riddle case SwitchAttribute: 1808154cabe7SRiver Riddle executeSwitchAttribute(); 1809154cabe7SRiver Riddle break; 1810154cabe7SRiver Riddle case SwitchOperandCount: 1811154cabe7SRiver Riddle executeSwitchOperandCount(); 1812154cabe7SRiver Riddle break; 1813154cabe7SRiver Riddle case SwitchOperationName: 1814154cabe7SRiver Riddle executeSwitchOperationName(); 1815154cabe7SRiver Riddle break; 1816154cabe7SRiver Riddle case SwitchResultCount: 1817154cabe7SRiver Riddle executeSwitchResultCount(); 1818154cabe7SRiver Riddle break; 1819154cabe7SRiver Riddle case SwitchType: 1820154cabe7SRiver Riddle executeSwitchType(); 1821154cabe7SRiver Riddle break; 182285ab413bSRiver Riddle case SwitchTypes: 182385ab413bSRiver Riddle executeSwitchTypes(); 182485ab413bSRiver Riddle break; 1825154cabe7SRiver Riddle } 1826154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "\n"); 1827abfd1a8bSRiver Riddle } 1828abfd1a8bSRiver Riddle } 1829abfd1a8bSRiver Riddle 1830abfd1a8bSRiver Riddle /// Run the pattern matcher on the given root operation, collecting the matched 1831abfd1a8bSRiver Riddle /// patterns in `matches`. 1832abfd1a8bSRiver Riddle void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, 1833abfd1a8bSRiver Riddle SmallVectorImpl<MatchResult> &matches, 1834abfd1a8bSRiver Riddle PDLByteCodeMutableState &state) const { 1835abfd1a8bSRiver Riddle // The first memory slot is always the root operation. 1836abfd1a8bSRiver Riddle state.memory[0] = op; 1837abfd1a8bSRiver Riddle 1838abfd1a8bSRiver Riddle // The matcher function always starts at code address 0. 183985ab413bSRiver Riddle ByteCodeExecutor executor( 184085ab413bSRiver Riddle matcherByteCode.data(), state.memory, state.typeRangeMemory, 184185ab413bSRiver Riddle state.allocatedTypeRangeMemory, state.valueRangeMemory, 184285ab413bSRiver Riddle state.allocatedValueRangeMemory, uniquedData, matcherByteCode, 184385ab413bSRiver Riddle state.currentPatternBenefits, patterns, constraintFunctions, 184485ab413bSRiver Riddle rewriteFunctions); 1845abfd1a8bSRiver Riddle executor.execute(rewriter, &matches); 1846abfd1a8bSRiver Riddle 1847abfd1a8bSRiver Riddle // Order the found matches by benefit. 1848abfd1a8bSRiver Riddle std::stable_sort(matches.begin(), matches.end(), 1849abfd1a8bSRiver Riddle [](const MatchResult &lhs, const MatchResult &rhs) { 1850abfd1a8bSRiver Riddle return lhs.benefit > rhs.benefit; 1851abfd1a8bSRiver Riddle }); 1852abfd1a8bSRiver Riddle } 1853abfd1a8bSRiver Riddle 1854abfd1a8bSRiver Riddle /// Run the rewriter of the given pattern on the root operation `op`. 1855abfd1a8bSRiver Riddle void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, 1856abfd1a8bSRiver Riddle PDLByteCodeMutableState &state) const { 1857abfd1a8bSRiver Riddle // The arguments of the rewrite function are stored at the start of the 1858abfd1a8bSRiver Riddle // memory buffer. 1859abfd1a8bSRiver Riddle llvm::copy(match.values, state.memory.begin()); 1860abfd1a8bSRiver Riddle 186185ab413bSRiver Riddle ByteCodeExecutor executor( 186285ab413bSRiver Riddle &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, 186385ab413bSRiver Riddle state.typeRangeMemory, state.allocatedTypeRangeMemory, 186485ab413bSRiver Riddle state.valueRangeMemory, state.allocatedValueRangeMemory, uniquedData, 186585ab413bSRiver Riddle rewriterByteCode, state.currentPatternBenefits, patterns, 186602c4c0d5SRiver Riddle constraintFunctions, rewriteFunctions); 1867abfd1a8bSRiver Riddle executor.execute(rewriter, /*matches=*/nullptr, match.location); 1868abfd1a8bSRiver Riddle } 1869