1abfd1a8bSRiver Riddle //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===// 2abfd1a8bSRiver Riddle // 3abfd1a8bSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4abfd1a8bSRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 5abfd1a8bSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6abfd1a8bSRiver Riddle // 7abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 8abfd1a8bSRiver Riddle // 9abfd1a8bSRiver Riddle // This file implements MLIR to byte-code generation and the interpreter. 10abfd1a8bSRiver Riddle // 11abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 12abfd1a8bSRiver Riddle 13abfd1a8bSRiver Riddle #include "ByteCode.h" 14abfd1a8bSRiver Riddle #include "mlir/Analysis/Liveness.h" 15abfd1a8bSRiver Riddle #include "mlir/Dialect/PDL/IR/PDLTypes.h" 16abfd1a8bSRiver Riddle #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 17e66c2e25SRiver Riddle #include "mlir/IR/BuiltinOps.h" 18abfd1a8bSRiver Riddle #include "mlir/IR/RegionGraphTraits.h" 19abfd1a8bSRiver Riddle #include "llvm/ADT/IntervalMap.h" 20abfd1a8bSRiver Riddle #include "llvm/ADT/PostOrderIterator.h" 21abfd1a8bSRiver Riddle #include "llvm/ADT/TypeSwitch.h" 22abfd1a8bSRiver Riddle #include "llvm/Support/Debug.h" 23*85ab413bSRiver Riddle #include "llvm/Support/Format.h" 24*85ab413bSRiver Riddle #include "llvm/Support/FormatVariadic.h" 25*85ab413bSRiver Riddle #include <numeric> 26abfd1a8bSRiver Riddle 27abfd1a8bSRiver Riddle #define DEBUG_TYPE "pdl-bytecode" 28abfd1a8bSRiver Riddle 29abfd1a8bSRiver Riddle using namespace mlir; 30abfd1a8bSRiver Riddle using namespace mlir::detail; 31abfd1a8bSRiver Riddle 32abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 33abfd1a8bSRiver Riddle // PDLByteCodePattern 34abfd1a8bSRiver Riddle 
//===----------------------------------------------------------------------===// 35abfd1a8bSRiver Riddle 36abfd1a8bSRiver Riddle PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp, 37abfd1a8bSRiver Riddle ByteCodeAddr rewriterAddr) { 38abfd1a8bSRiver Riddle SmallVector<StringRef, 8> generatedOps; 39abfd1a8bSRiver Riddle if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr()) 40abfd1a8bSRiver Riddle generatedOps = 41abfd1a8bSRiver Riddle llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>()); 42abfd1a8bSRiver Riddle 43abfd1a8bSRiver Riddle PatternBenefit benefit = matchOp.benefit(); 44abfd1a8bSRiver Riddle MLIRContext *ctx = matchOp.getContext(); 45abfd1a8bSRiver Riddle 46abfd1a8bSRiver Riddle // Check to see if this is pattern matches a specific operation type. 47abfd1a8bSRiver Riddle if (Optional<StringRef> rootKind = matchOp.rootKind()) 48abfd1a8bSRiver Riddle return PDLByteCodePattern(rewriterAddr, *rootKind, generatedOps, benefit, 49abfd1a8bSRiver Riddle ctx); 50abfd1a8bSRiver Riddle return PDLByteCodePattern(rewriterAddr, generatedOps, benefit, ctx, 51abfd1a8bSRiver Riddle MatchAnyOpTypeTag()); 52abfd1a8bSRiver Riddle } 53abfd1a8bSRiver Riddle 54abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 55abfd1a8bSRiver Riddle // PDLByteCodeMutableState 56abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 57abfd1a8bSRiver Riddle 58abfd1a8bSRiver Riddle /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds 59abfd1a8bSRiver Riddle /// to the position of the pattern within the range returned by 60abfd1a8bSRiver Riddle /// `PDLByteCode::getPatterns`. 
61abfd1a8bSRiver Riddle void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex, 62abfd1a8bSRiver Riddle PatternBenefit benefit) { 63abfd1a8bSRiver Riddle currentPatternBenefits[patternIndex] = benefit; 64abfd1a8bSRiver Riddle } 65abfd1a8bSRiver Riddle 66*85ab413bSRiver Riddle /// Cleanup any allocated state after a full match/rewrite has been completed. 67*85ab413bSRiver Riddle /// This method should be called irregardless of whether the match+rewrite was a 68*85ab413bSRiver Riddle /// success or not. 69*85ab413bSRiver Riddle void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() { 70*85ab413bSRiver Riddle allocatedTypeRangeMemory.clear(); 71*85ab413bSRiver Riddle allocatedValueRangeMemory.clear(); 72*85ab413bSRiver Riddle } 73*85ab413bSRiver Riddle 74abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 75abfd1a8bSRiver Riddle // Bytecode OpCodes 76abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 77abfd1a8bSRiver Riddle 78abfd1a8bSRiver Riddle namespace { 79abfd1a8bSRiver Riddle enum OpCode : ByteCodeField { 80abfd1a8bSRiver Riddle /// Apply an externally registered constraint. 81abfd1a8bSRiver Riddle ApplyConstraint, 82abfd1a8bSRiver Riddle /// Apply an externally registered rewrite. 83abfd1a8bSRiver Riddle ApplyRewrite, 84abfd1a8bSRiver Riddle /// Check if two generic values are equal. 85abfd1a8bSRiver Riddle AreEqual, 86*85ab413bSRiver Riddle /// Check if two ranges are equal. 87*85ab413bSRiver Riddle AreRangesEqual, 88abfd1a8bSRiver Riddle /// Unconditional branch. 89abfd1a8bSRiver Riddle Branch, 90abfd1a8bSRiver Riddle /// Compare the operand count of an operation with a constant. 91abfd1a8bSRiver Riddle CheckOperandCount, 92abfd1a8bSRiver Riddle /// Compare the name of an operation with a constant. 93abfd1a8bSRiver Riddle CheckOperationName, 94abfd1a8bSRiver Riddle /// Compare the result count of an operation with a constant. 
95abfd1a8bSRiver Riddle CheckResultCount, 96*85ab413bSRiver Riddle /// Compare a range of types to a constant range of types. 97*85ab413bSRiver Riddle CheckTypes, 98abfd1a8bSRiver Riddle /// Create an operation. 99abfd1a8bSRiver Riddle CreateOperation, 100*85ab413bSRiver Riddle /// Create a range of types. 101*85ab413bSRiver Riddle CreateTypes, 102abfd1a8bSRiver Riddle /// Erase an operation. 103abfd1a8bSRiver Riddle EraseOp, 104abfd1a8bSRiver Riddle /// Terminate a matcher or rewrite sequence. 105abfd1a8bSRiver Riddle Finalize, 106abfd1a8bSRiver Riddle /// Get a specific attribute of an operation. 107abfd1a8bSRiver Riddle GetAttribute, 108abfd1a8bSRiver Riddle /// Get the type of an attribute. 109abfd1a8bSRiver Riddle GetAttributeType, 110abfd1a8bSRiver Riddle /// Get the defining operation of a value. 111abfd1a8bSRiver Riddle GetDefiningOp, 112abfd1a8bSRiver Riddle /// Get a specific operand of an operation. 113abfd1a8bSRiver Riddle GetOperand0, 114abfd1a8bSRiver Riddle GetOperand1, 115abfd1a8bSRiver Riddle GetOperand2, 116abfd1a8bSRiver Riddle GetOperand3, 117abfd1a8bSRiver Riddle GetOperandN, 118*85ab413bSRiver Riddle /// Get a specific operand group of an operation. 119*85ab413bSRiver Riddle GetOperands, 120abfd1a8bSRiver Riddle /// Get a specific result of an operation. 121abfd1a8bSRiver Riddle GetResult0, 122abfd1a8bSRiver Riddle GetResult1, 123abfd1a8bSRiver Riddle GetResult2, 124abfd1a8bSRiver Riddle GetResult3, 125abfd1a8bSRiver Riddle GetResultN, 126*85ab413bSRiver Riddle /// Get a specific result group of an operation. 127*85ab413bSRiver Riddle GetResults, 128abfd1a8bSRiver Riddle /// Get the type of a value. 129abfd1a8bSRiver Riddle GetValueType, 130*85ab413bSRiver Riddle /// Get the types of a value range. 131*85ab413bSRiver Riddle GetValueRangeTypes, 132abfd1a8bSRiver Riddle /// Check if a generic value is not null. 133abfd1a8bSRiver Riddle IsNotNull, 134abfd1a8bSRiver Riddle /// Record a successful pattern match. 
135abfd1a8bSRiver Riddle RecordMatch, 136abfd1a8bSRiver Riddle /// Replace an operation. 137abfd1a8bSRiver Riddle ReplaceOp, 138abfd1a8bSRiver Riddle /// Compare an attribute with a set of constants. 139abfd1a8bSRiver Riddle SwitchAttribute, 140abfd1a8bSRiver Riddle /// Compare the operand count of an operation with a set of constants. 141abfd1a8bSRiver Riddle SwitchOperandCount, 142abfd1a8bSRiver Riddle /// Compare the name of an operation with a set of constants. 143abfd1a8bSRiver Riddle SwitchOperationName, 144abfd1a8bSRiver Riddle /// Compare the result count of an operation with a set of constants. 145abfd1a8bSRiver Riddle SwitchResultCount, 146abfd1a8bSRiver Riddle /// Compare a type with a set of constants. 147abfd1a8bSRiver Riddle SwitchType, 148*85ab413bSRiver Riddle /// Compare a range of types with a set of constants. 149*85ab413bSRiver Riddle SwitchTypes, 150abfd1a8bSRiver Riddle }; 151abfd1a8bSRiver Riddle } // end anonymous namespace 152abfd1a8bSRiver Riddle 153abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 154abfd1a8bSRiver Riddle // ByteCode Generation 155abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 156abfd1a8bSRiver Riddle 157abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 158abfd1a8bSRiver Riddle // Generator 159abfd1a8bSRiver Riddle 160abfd1a8bSRiver Riddle namespace { 161abfd1a8bSRiver Riddle struct ByteCodeWriter; 162abfd1a8bSRiver Riddle 163abfd1a8bSRiver Riddle /// This class represents the main generator for the pattern bytecode. 
164abfd1a8bSRiver Riddle class Generator { 165abfd1a8bSRiver Riddle public: 166abfd1a8bSRiver Riddle Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData, 167abfd1a8bSRiver Riddle SmallVectorImpl<ByteCodeField> &matcherByteCode, 168abfd1a8bSRiver Riddle SmallVectorImpl<ByteCodeField> &rewriterByteCode, 169abfd1a8bSRiver Riddle SmallVectorImpl<PDLByteCodePattern> &patterns, 170abfd1a8bSRiver Riddle ByteCodeField &maxValueMemoryIndex, 171*85ab413bSRiver Riddle ByteCodeField &maxTypeRangeMemoryIndex, 172*85ab413bSRiver Riddle ByteCodeField &maxValueRangeMemoryIndex, 173abfd1a8bSRiver Riddle llvm::StringMap<PDLConstraintFunction> &constraintFns, 174abfd1a8bSRiver Riddle llvm::StringMap<PDLRewriteFunction> &rewriteFns) 175abfd1a8bSRiver Riddle : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), 176abfd1a8bSRiver Riddle rewriterByteCode(rewriterByteCode), patterns(patterns), 177*85ab413bSRiver Riddle maxValueMemoryIndex(maxValueMemoryIndex), 178*85ab413bSRiver Riddle maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex), 179*85ab413bSRiver Riddle maxValueRangeMemoryIndex(maxValueRangeMemoryIndex) { 180abfd1a8bSRiver Riddle for (auto it : llvm::enumerate(constraintFns)) 181abfd1a8bSRiver Riddle constraintToMemIndex.try_emplace(it.value().first(), it.index()); 182abfd1a8bSRiver Riddle for (auto it : llvm::enumerate(rewriteFns)) 183abfd1a8bSRiver Riddle externalRewriterToMemIndex.try_emplace(it.value().first(), it.index()); 184abfd1a8bSRiver Riddle } 185abfd1a8bSRiver Riddle 186abfd1a8bSRiver Riddle /// Generate the bytecode for the given PDL interpreter module. 187abfd1a8bSRiver Riddle void generate(ModuleOp module); 188abfd1a8bSRiver Riddle 189abfd1a8bSRiver Riddle /// Return the memory index to use for the given value. 
190abfd1a8bSRiver Riddle ByteCodeField &getMemIndex(Value value) { 191abfd1a8bSRiver Riddle assert(valueToMemIndex.count(value) && 192abfd1a8bSRiver Riddle "expected memory index to be assigned"); 193abfd1a8bSRiver Riddle return valueToMemIndex[value]; 194abfd1a8bSRiver Riddle } 195abfd1a8bSRiver Riddle 196*85ab413bSRiver Riddle /// Return the range memory index used to store the given range value. 197*85ab413bSRiver Riddle ByteCodeField &getRangeStorageIndex(Value value) { 198*85ab413bSRiver Riddle assert(valueToRangeIndex.count(value) && 199*85ab413bSRiver Riddle "expected range index to be assigned"); 200*85ab413bSRiver Riddle return valueToRangeIndex[value]; 201*85ab413bSRiver Riddle } 202*85ab413bSRiver Riddle 203abfd1a8bSRiver Riddle /// Return an index to use when referring to the given data that is uniqued in 204abfd1a8bSRiver Riddle /// the MLIR context. 205abfd1a8bSRiver Riddle template <typename T> 206abfd1a8bSRiver Riddle std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &> 207abfd1a8bSRiver Riddle getMemIndex(T val) { 208abfd1a8bSRiver Riddle const void *opaqueVal = val.getAsOpaquePointer(); 209abfd1a8bSRiver Riddle 210abfd1a8bSRiver Riddle // Get or insert a reference to this value. 211abfd1a8bSRiver Riddle auto it = uniquedDataToMemIndex.try_emplace( 212abfd1a8bSRiver Riddle opaqueVal, maxValueMemoryIndex + uniquedData.size()); 213abfd1a8bSRiver Riddle if (it.second) 214abfd1a8bSRiver Riddle uniquedData.push_back(opaqueVal); 215abfd1a8bSRiver Riddle return it.first->second; 216abfd1a8bSRiver Riddle } 217abfd1a8bSRiver Riddle 218abfd1a8bSRiver Riddle private: 219abfd1a8bSRiver Riddle /// Allocate memory indices for the results of operations within the matcher 220abfd1a8bSRiver Riddle /// and rewriters. 221abfd1a8bSRiver Riddle void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule); 222abfd1a8bSRiver Riddle 223abfd1a8bSRiver Riddle /// Generate the bytecode for the given operation. 
224abfd1a8bSRiver Riddle void generate(Operation *op, ByteCodeWriter &writer); 225abfd1a8bSRiver Riddle void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer); 226abfd1a8bSRiver Riddle void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer); 227abfd1a8bSRiver Riddle void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer); 228abfd1a8bSRiver Riddle void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer); 229abfd1a8bSRiver Riddle void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer); 230abfd1a8bSRiver Riddle void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer); 231abfd1a8bSRiver Riddle void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer); 232abfd1a8bSRiver Riddle void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); 233abfd1a8bSRiver Riddle void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); 234*85ab413bSRiver Riddle void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer); 235abfd1a8bSRiver Riddle void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); 236abfd1a8bSRiver Riddle void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); 237abfd1a8bSRiver Riddle void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); 238*85ab413bSRiver Riddle void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer); 239abfd1a8bSRiver Riddle void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); 240abfd1a8bSRiver Riddle void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); 241abfd1a8bSRiver Riddle void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); 242abfd1a8bSRiver Riddle void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); 243abfd1a8bSRiver Riddle void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); 244abfd1a8bSRiver Riddle void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer); 245*85ab413bSRiver 
Riddle void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer); 246abfd1a8bSRiver Riddle void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); 247*85ab413bSRiver Riddle void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer); 248abfd1a8bSRiver Riddle void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); 2493a833a0eSRiver Riddle void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer); 250abfd1a8bSRiver Riddle void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); 251abfd1a8bSRiver Riddle void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer); 252abfd1a8bSRiver Riddle void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); 253abfd1a8bSRiver Riddle void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer); 254abfd1a8bSRiver Riddle void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer); 255*85ab413bSRiver Riddle void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer); 256abfd1a8bSRiver Riddle void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer); 257abfd1a8bSRiver Riddle void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer); 258abfd1a8bSRiver Riddle void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer); 259abfd1a8bSRiver Riddle 260abfd1a8bSRiver Riddle /// Mapping from value to its corresponding memory index. 261abfd1a8bSRiver Riddle DenseMap<Value, ByteCodeField> valueToMemIndex; 262abfd1a8bSRiver Riddle 263*85ab413bSRiver Riddle /// Mapping from a range value to its corresponding range storage index. 264*85ab413bSRiver Riddle DenseMap<Value, ByteCodeField> valueToRangeIndex; 265*85ab413bSRiver Riddle 266abfd1a8bSRiver Riddle /// Mapping from the name of an externally registered rewrite to its index in 267abfd1a8bSRiver Riddle /// the bytecode registry. 
268abfd1a8bSRiver Riddle llvm::StringMap<ByteCodeField> externalRewriterToMemIndex; 269abfd1a8bSRiver Riddle 270abfd1a8bSRiver Riddle /// Mapping from the name of an externally registered constraint to its index 271abfd1a8bSRiver Riddle /// in the bytecode registry. 272abfd1a8bSRiver Riddle llvm::StringMap<ByteCodeField> constraintToMemIndex; 273abfd1a8bSRiver Riddle 274abfd1a8bSRiver Riddle /// Mapping from rewriter function name to the bytecode address of the 275abfd1a8bSRiver Riddle /// rewriter function in byte. 276abfd1a8bSRiver Riddle llvm::StringMap<ByteCodeAddr> rewriterToAddr; 277abfd1a8bSRiver Riddle 278abfd1a8bSRiver Riddle /// Mapping from a uniqued storage object to its memory index within 279abfd1a8bSRiver Riddle /// `uniquedData`. 280abfd1a8bSRiver Riddle DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex; 281abfd1a8bSRiver Riddle 282abfd1a8bSRiver Riddle /// The current MLIR context. 283abfd1a8bSRiver Riddle MLIRContext *ctx; 284abfd1a8bSRiver Riddle 285abfd1a8bSRiver Riddle /// Data of the ByteCode class to be populated. 286abfd1a8bSRiver Riddle std::vector<const void *> &uniquedData; 287abfd1a8bSRiver Riddle SmallVectorImpl<ByteCodeField> &matcherByteCode; 288abfd1a8bSRiver Riddle SmallVectorImpl<ByteCodeField> &rewriterByteCode; 289abfd1a8bSRiver Riddle SmallVectorImpl<PDLByteCodePattern> &patterns; 290abfd1a8bSRiver Riddle ByteCodeField &maxValueMemoryIndex; 291*85ab413bSRiver Riddle ByteCodeField &maxTypeRangeMemoryIndex; 292*85ab413bSRiver Riddle ByteCodeField &maxValueRangeMemoryIndex; 293abfd1a8bSRiver Riddle }; 294abfd1a8bSRiver Riddle 295abfd1a8bSRiver Riddle /// This class provides utilities for writing a bytecode stream. 296abfd1a8bSRiver Riddle struct ByteCodeWriter { 297abfd1a8bSRiver Riddle ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator) 298abfd1a8bSRiver Riddle : bytecode(bytecode), generator(generator) {} 299abfd1a8bSRiver Riddle 300abfd1a8bSRiver Riddle /// Append a field to the bytecode. 
301abfd1a8bSRiver Riddle void append(ByteCodeField field) { bytecode.push_back(field); } 302fa20ab7bSRiver Riddle void append(OpCode opCode) { bytecode.push_back(opCode); } 303abfd1a8bSRiver Riddle 304abfd1a8bSRiver Riddle /// Append an address to the bytecode. 305abfd1a8bSRiver Riddle void append(ByteCodeAddr field) { 306abfd1a8bSRiver Riddle static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 307abfd1a8bSRiver Riddle "unexpected ByteCode address size"); 308abfd1a8bSRiver Riddle 309abfd1a8bSRiver Riddle ByteCodeField fieldParts[2]; 310abfd1a8bSRiver Riddle std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr)); 311abfd1a8bSRiver Riddle bytecode.append({fieldParts[0], fieldParts[1]}); 312abfd1a8bSRiver Riddle } 313abfd1a8bSRiver Riddle 314abfd1a8bSRiver Riddle /// Append a successor range to the bytecode, the exact address will need to 315abfd1a8bSRiver Riddle /// be resolved later. 316abfd1a8bSRiver Riddle void append(SuccessorRange successors) { 317abfd1a8bSRiver Riddle // Add back references to the any successors so that the address can be 318abfd1a8bSRiver Riddle // resolved later. 319abfd1a8bSRiver Riddle for (Block *successor : successors) { 320abfd1a8bSRiver Riddle unresolvedSuccessorRefs[successor].push_back(bytecode.size()); 321abfd1a8bSRiver Riddle append(ByteCodeAddr(0)); 322abfd1a8bSRiver Riddle } 323abfd1a8bSRiver Riddle } 324abfd1a8bSRiver Riddle 325abfd1a8bSRiver Riddle /// Append a range of values that will be read as generic PDLValues. 326abfd1a8bSRiver Riddle void appendPDLValueList(OperandRange values) { 327abfd1a8bSRiver Riddle bytecode.push_back(values.size()); 328*85ab413bSRiver Riddle for (Value value : values) 329*85ab413bSRiver Riddle appendPDLValue(value); 330*85ab413bSRiver Riddle } 331*85ab413bSRiver Riddle 332*85ab413bSRiver Riddle /// Append a value as a PDLValue. 
333*85ab413bSRiver Riddle void appendPDLValue(Value value) { 334*85ab413bSRiver Riddle appendPDLValueKind(value); 335abfd1a8bSRiver Riddle append(value); 336abfd1a8bSRiver Riddle } 337*85ab413bSRiver Riddle 338*85ab413bSRiver Riddle /// Append the PDLValue::Kind of the given value. 339*85ab413bSRiver Riddle void appendPDLValueKind(Value value) { 340*85ab413bSRiver Riddle // Append the type of the value in addition to the value itself. 341*85ab413bSRiver Riddle PDLValue::Kind kind = 342*85ab413bSRiver Riddle TypeSwitch<Type, PDLValue::Kind>(value.getType()) 343*85ab413bSRiver Riddle .Case<pdl::AttributeType>( 344*85ab413bSRiver Riddle [](Type) { return PDLValue::Kind::Attribute; }) 345*85ab413bSRiver Riddle .Case<pdl::OperationType>( 346*85ab413bSRiver Riddle [](Type) { return PDLValue::Kind::Operation; }) 347*85ab413bSRiver Riddle .Case<pdl::RangeType>([](pdl::RangeType rangeTy) { 348*85ab413bSRiver Riddle if (rangeTy.getElementType().isa<pdl::TypeType>()) 349*85ab413bSRiver Riddle return PDLValue::Kind::TypeRange; 350*85ab413bSRiver Riddle return PDLValue::Kind::ValueRange; 351*85ab413bSRiver Riddle }) 352*85ab413bSRiver Riddle .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; }) 353*85ab413bSRiver Riddle .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; }); 354*85ab413bSRiver Riddle bytecode.push_back(static_cast<ByteCodeField>(kind)); 355abfd1a8bSRiver Riddle } 356abfd1a8bSRiver Riddle 357abfd1a8bSRiver Riddle /// Check if the given class `T` has an iterator type. 358abfd1a8bSRiver Riddle template <typename T, typename... Args> 359abfd1a8bSRiver Riddle using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer()); 360abfd1a8bSRiver Riddle 361abfd1a8bSRiver Riddle /// Append a value that will be stored in a memory slot and not inline within 362abfd1a8bSRiver Riddle /// the bytecode. 
363abfd1a8bSRiver Riddle template <typename T> 364abfd1a8bSRiver Riddle std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value || 365abfd1a8bSRiver Riddle std::is_pointer<T>::value> 366abfd1a8bSRiver Riddle append(T value) { 367abfd1a8bSRiver Riddle bytecode.push_back(generator.getMemIndex(value)); 368abfd1a8bSRiver Riddle } 369abfd1a8bSRiver Riddle 370abfd1a8bSRiver Riddle /// Append a range of values. 371abfd1a8bSRiver Riddle template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>> 372abfd1a8bSRiver Riddle std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value> 373abfd1a8bSRiver Riddle append(T range) { 374abfd1a8bSRiver Riddle bytecode.push_back(llvm::size(range)); 375abfd1a8bSRiver Riddle for (auto it : range) 376abfd1a8bSRiver Riddle append(it); 377abfd1a8bSRiver Riddle } 378abfd1a8bSRiver Riddle 379abfd1a8bSRiver Riddle /// Append a variadic number of fields to the bytecode. 380abfd1a8bSRiver Riddle template <typename FieldTy, typename Field2Ty, typename... FieldTys> 381abfd1a8bSRiver Riddle void append(FieldTy field, Field2Ty field2, FieldTys... fields) { 382abfd1a8bSRiver Riddle append(field); 383abfd1a8bSRiver Riddle append(field2, fields...); 384abfd1a8bSRiver Riddle } 385abfd1a8bSRiver Riddle 386abfd1a8bSRiver Riddle /// Successor references in the bytecode that have yet to be resolved. 387abfd1a8bSRiver Riddle DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs; 388abfd1a8bSRiver Riddle 389abfd1a8bSRiver Riddle /// The underlying bytecode buffer. 390abfd1a8bSRiver Riddle SmallVectorImpl<ByteCodeField> &bytecode; 391abfd1a8bSRiver Riddle 392abfd1a8bSRiver Riddle /// The main generator producing PDL. 393abfd1a8bSRiver Riddle Generator &generator; 394abfd1a8bSRiver Riddle }; 395*85ab413bSRiver Riddle 396*85ab413bSRiver Riddle /// This class represents a live range of PDL Interpreter values, containing 397*85ab413bSRiver Riddle /// information about when values are live within a match/rewrite. 
398*85ab413bSRiver Riddle struct ByteCodeLiveRange { 399*85ab413bSRiver Riddle using Set = llvm::IntervalMap<ByteCodeField, char, 16>; 400*85ab413bSRiver Riddle using Allocator = Set::Allocator; 401*85ab413bSRiver Riddle 402*85ab413bSRiver Riddle ByteCodeLiveRange(Allocator &alloc) : liveness(alloc) {} 403*85ab413bSRiver Riddle 404*85ab413bSRiver Riddle /// Union this live range with the one provided. 405*85ab413bSRiver Riddle void unionWith(const ByteCodeLiveRange &rhs) { 406*85ab413bSRiver Riddle for (auto it = rhs.liveness.begin(), e = rhs.liveness.end(); it != e; ++it) 407*85ab413bSRiver Riddle liveness.insert(it.start(), it.stop(), /*dummyValue*/ 0); 408*85ab413bSRiver Riddle } 409*85ab413bSRiver Riddle 410*85ab413bSRiver Riddle /// Returns true if this range overlaps with the one provided. 411*85ab413bSRiver Riddle bool overlaps(const ByteCodeLiveRange &rhs) const { 412*85ab413bSRiver Riddle return llvm::IntervalMapOverlaps<Set, Set>(liveness, rhs.liveness).valid(); 413*85ab413bSRiver Riddle } 414*85ab413bSRiver Riddle 415*85ab413bSRiver Riddle /// A map representing the ranges of the match/rewrite that a value is live in 416*85ab413bSRiver Riddle /// the interpreter. 417*85ab413bSRiver Riddle llvm::IntervalMap<ByteCodeField, char, 16> liveness; 418*85ab413bSRiver Riddle 419*85ab413bSRiver Riddle /// The type range storage index for this range. 420*85ab413bSRiver Riddle Optional<unsigned> typeRangeIndex; 421*85ab413bSRiver Riddle 422*85ab413bSRiver Riddle /// The value range storage index for this range. 
423*85ab413bSRiver Riddle Optional<unsigned> valueRangeIndex; 424*85ab413bSRiver Riddle }; 425abfd1a8bSRiver Riddle } // end anonymous namespace 426abfd1a8bSRiver Riddle 427abfd1a8bSRiver Riddle void Generator::generate(ModuleOp module) { 428abfd1a8bSRiver Riddle FuncOp matcherFunc = module.lookupSymbol<FuncOp>( 429abfd1a8bSRiver Riddle pdl_interp::PDLInterpDialect::getMatcherFunctionName()); 430abfd1a8bSRiver Riddle ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>( 431abfd1a8bSRiver Riddle pdl_interp::PDLInterpDialect::getRewriterModuleName()); 432abfd1a8bSRiver Riddle assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module"); 433abfd1a8bSRiver Riddle 434abfd1a8bSRiver Riddle // Allocate memory indices for the results of operations within the matcher 435abfd1a8bSRiver Riddle // and rewriters. 436abfd1a8bSRiver Riddle allocateMemoryIndices(matcherFunc, rewriterModule); 437abfd1a8bSRiver Riddle 438abfd1a8bSRiver Riddle // Generate code for the rewriter functions. 439abfd1a8bSRiver Riddle ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this); 440abfd1a8bSRiver Riddle for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) { 441abfd1a8bSRiver Riddle rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size()); 442abfd1a8bSRiver Riddle for (Operation &op : rewriterFunc.getOps()) 443abfd1a8bSRiver Riddle generate(&op, rewriterByteCodeWriter); 444abfd1a8bSRiver Riddle } 445abfd1a8bSRiver Riddle assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() && 446abfd1a8bSRiver Riddle "unexpected branches in rewriter function"); 447abfd1a8bSRiver Riddle 448abfd1a8bSRiver Riddle // Generate code for the matcher function. 
449abfd1a8bSRiver Riddle DenseMap<Block *, ByteCodeAddr> blockToAddr; 450abfd1a8bSRiver Riddle llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody()); 451abfd1a8bSRiver Riddle ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this); 452abfd1a8bSRiver Riddle for (Block *block : rpot) { 453abfd1a8bSRiver Riddle // Keep track of where this block begins within the matcher function. 454abfd1a8bSRiver Riddle blockToAddr.try_emplace(block, matcherByteCode.size()); 455abfd1a8bSRiver Riddle for (Operation &op : *block) 456abfd1a8bSRiver Riddle generate(&op, matcherByteCodeWriter); 457abfd1a8bSRiver Riddle } 458abfd1a8bSRiver Riddle 459abfd1a8bSRiver Riddle // Resolve successor references in the matcher. 460abfd1a8bSRiver Riddle for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) { 461abfd1a8bSRiver Riddle ByteCodeAddr addr = blockToAddr[it.first]; 462abfd1a8bSRiver Riddle for (unsigned offsetToFix : it.second) 463abfd1a8bSRiver Riddle std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr)); 464abfd1a8bSRiver Riddle } 465abfd1a8bSRiver Riddle } 466abfd1a8bSRiver Riddle 467abfd1a8bSRiver Riddle void Generator::allocateMemoryIndices(FuncOp matcherFunc, 468abfd1a8bSRiver Riddle ModuleOp rewriterModule) { 469abfd1a8bSRiver Riddle // Rewriters use simplistic allocation scheme that simply assigns an index to 470abfd1a8bSRiver Riddle // each result. 
471abfd1a8bSRiver Riddle for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) { 472*85ab413bSRiver Riddle ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0; 473*85ab413bSRiver Riddle auto processRewriterValue = [&](Value val) { 474*85ab413bSRiver Riddle valueToMemIndex.try_emplace(val, index++); 475*85ab413bSRiver Riddle if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) { 476*85ab413bSRiver Riddle Type elementTy = rangeType.getElementType(); 477*85ab413bSRiver Riddle if (elementTy.isa<pdl::TypeType>()) 478*85ab413bSRiver Riddle valueToRangeIndex.try_emplace(val, typeRangeIndex++); 479*85ab413bSRiver Riddle else if (elementTy.isa<pdl::ValueType>()) 480*85ab413bSRiver Riddle valueToRangeIndex.try_emplace(val, valueRangeIndex++); 481*85ab413bSRiver Riddle } 482*85ab413bSRiver Riddle }; 483*85ab413bSRiver Riddle 484abfd1a8bSRiver Riddle for (BlockArgument arg : rewriterFunc.getArguments()) 485*85ab413bSRiver Riddle processRewriterValue(arg); 486abfd1a8bSRiver Riddle rewriterFunc.getBody().walk([&](Operation *op) { 487abfd1a8bSRiver Riddle for (Value result : op->getResults()) 488*85ab413bSRiver Riddle processRewriterValue(result); 489abfd1a8bSRiver Riddle }); 490abfd1a8bSRiver Riddle if (index > maxValueMemoryIndex) 491abfd1a8bSRiver Riddle maxValueMemoryIndex = index; 492*85ab413bSRiver Riddle if (typeRangeIndex > maxTypeRangeMemoryIndex) 493*85ab413bSRiver Riddle maxTypeRangeMemoryIndex = typeRangeIndex; 494*85ab413bSRiver Riddle if (valueRangeIndex > maxValueRangeMemoryIndex) 495*85ab413bSRiver Riddle maxValueRangeMemoryIndex = valueRangeIndex; 496abfd1a8bSRiver Riddle } 497abfd1a8bSRiver Riddle 498abfd1a8bSRiver Riddle // The matcher function uses a more sophisticated numbering that tries to 499abfd1a8bSRiver Riddle // minimize the number of memory indices assigned. 
This is done by determining 500abfd1a8bSRiver Riddle // a live range of the values within the matcher, then the allocation is just 501abfd1a8bSRiver Riddle // finding the minimal number of overlapping live ranges. This is essentially 502abfd1a8bSRiver Riddle // a simplified form of register allocation where we don't necessarily have a 503abfd1a8bSRiver Riddle // limited number of registers, but we still want to minimize the number used. 504abfd1a8bSRiver Riddle DenseMap<Operation *, ByteCodeField> opToIndex; 505abfd1a8bSRiver Riddle matcherFunc.getBody().walk([&](Operation *op) { 506abfd1a8bSRiver Riddle opToIndex.insert(std::make_pair(op, opToIndex.size())); 507abfd1a8bSRiver Riddle }); 508abfd1a8bSRiver Riddle 509abfd1a8bSRiver Riddle // Liveness info for each of the defs within the matcher. 510*85ab413bSRiver Riddle ByteCodeLiveRange::Allocator allocator; 511*85ab413bSRiver Riddle DenseMap<Value, ByteCodeLiveRange> valueDefRanges; 512abfd1a8bSRiver Riddle 513abfd1a8bSRiver Riddle // Assign the root operation being matched to slot 0. 514abfd1a8bSRiver Riddle BlockArgument rootOpArg = matcherFunc.getArgument(0); 515abfd1a8bSRiver Riddle valueToMemIndex[rootOpArg] = 0; 516abfd1a8bSRiver Riddle 517abfd1a8bSRiver Riddle // Walk each of the blocks, computing the def interval that the value is used. 518abfd1a8bSRiver Riddle Liveness matcherLiveness(matcherFunc); 519abfd1a8bSRiver Riddle for (Block &block : matcherFunc.getBody()) { 520abfd1a8bSRiver Riddle const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block); 521abfd1a8bSRiver Riddle assert(info && "expected liveness info for block"); 522abfd1a8bSRiver Riddle auto processValue = [&](Value value, Operation *firstUseOrDef) { 523abfd1a8bSRiver Riddle // We don't need to process the root op argument, this value is always 524abfd1a8bSRiver Riddle // assigned to the first memory slot. 
525abfd1a8bSRiver Riddle if (value == rootOpArg) 526abfd1a8bSRiver Riddle return; 527abfd1a8bSRiver Riddle 528abfd1a8bSRiver Riddle // Set indices for the range of this block that the value is used. 529abfd1a8bSRiver Riddle auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; 530*85ab413bSRiver Riddle defRangeIt->second.liveness.insert( 531abfd1a8bSRiver Riddle opToIndex[firstUseOrDef], 532abfd1a8bSRiver Riddle opToIndex[info->getEndOperation(value, firstUseOrDef)], 533abfd1a8bSRiver Riddle /*dummyValue*/ 0); 534*85ab413bSRiver Riddle 535*85ab413bSRiver Riddle // Check to see if this value is a range type. 536*85ab413bSRiver Riddle if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) { 537*85ab413bSRiver Riddle Type eleType = rangeTy.getElementType(); 538*85ab413bSRiver Riddle if (eleType.isa<pdl::TypeType>()) 539*85ab413bSRiver Riddle defRangeIt->second.typeRangeIndex = 0; 540*85ab413bSRiver Riddle else if (eleType.isa<pdl::ValueType>()) 541*85ab413bSRiver Riddle defRangeIt->second.valueRangeIndex = 0; 542*85ab413bSRiver Riddle } 543abfd1a8bSRiver Riddle }; 544abfd1a8bSRiver Riddle 545abfd1a8bSRiver Riddle // Process the live-ins of this block. 546abfd1a8bSRiver Riddle for (Value liveIn : info->in()) 547abfd1a8bSRiver Riddle processValue(liveIn, &block.front()); 548abfd1a8bSRiver Riddle 549abfd1a8bSRiver Riddle // Process any new defs within this block. 550abfd1a8bSRiver Riddle for (Operation &op : block) 551abfd1a8bSRiver Riddle for (Value result : op.getResults()) 552abfd1a8bSRiver Riddle processValue(result, &op); 553abfd1a8bSRiver Riddle } 554abfd1a8bSRiver Riddle 555abfd1a8bSRiver Riddle // Greedily allocate memory slots using the computed def live ranges. 
556*85ab413bSRiver Riddle std::vector<ByteCodeLiveRange> allocatedIndices; 557*85ab413bSRiver Riddle ByteCodeField numIndices = 1, numTypeRanges = 0, numValueRanges = 0; 558abfd1a8bSRiver Riddle for (auto &defIt : valueDefRanges) { 559abfd1a8bSRiver Riddle ByteCodeField &memIndex = valueToMemIndex[defIt.first]; 560*85ab413bSRiver Riddle ByteCodeLiveRange &defRange = defIt.second; 561abfd1a8bSRiver Riddle 562abfd1a8bSRiver Riddle // Try to allocate to an existing index. 563abfd1a8bSRiver Riddle for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) { 564*85ab413bSRiver Riddle ByteCodeLiveRange &existingRange = existingIndexIt.value(); 565*85ab413bSRiver Riddle if (!defRange.overlaps(existingRange)) { 566*85ab413bSRiver Riddle existingRange.unionWith(defRange); 567abfd1a8bSRiver Riddle memIndex = existingIndexIt.index() + 1; 568*85ab413bSRiver Riddle 569*85ab413bSRiver Riddle if (defRange.typeRangeIndex) { 570*85ab413bSRiver Riddle if (!existingRange.typeRangeIndex) 571*85ab413bSRiver Riddle existingRange.typeRangeIndex = numTypeRanges++; 572*85ab413bSRiver Riddle valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex; 573*85ab413bSRiver Riddle } else if (defRange.valueRangeIndex) { 574*85ab413bSRiver Riddle if (!existingRange.valueRangeIndex) 575*85ab413bSRiver Riddle existingRange.valueRangeIndex = numValueRanges++; 576*85ab413bSRiver Riddle valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex; 577*85ab413bSRiver Riddle } 578*85ab413bSRiver Riddle break; 579*85ab413bSRiver Riddle } 580abfd1a8bSRiver Riddle } 581abfd1a8bSRiver Riddle 582abfd1a8bSRiver Riddle // If no existing index could be used, add a new one. 
583abfd1a8bSRiver Riddle if (memIndex == 0) { 584abfd1a8bSRiver Riddle allocatedIndices.emplace_back(allocator); 585*85ab413bSRiver Riddle ByteCodeLiveRange &newRange = allocatedIndices.back(); 586*85ab413bSRiver Riddle newRange.unionWith(defRange); 587*85ab413bSRiver Riddle 588*85ab413bSRiver Riddle // Allocate an index for type/value ranges. 589*85ab413bSRiver Riddle if (defRange.typeRangeIndex) { 590*85ab413bSRiver Riddle newRange.typeRangeIndex = numTypeRanges; 591*85ab413bSRiver Riddle valueToRangeIndex[defIt.first] = numTypeRanges++; 592*85ab413bSRiver Riddle } else if (defRange.valueRangeIndex) { 593*85ab413bSRiver Riddle newRange.valueRangeIndex = numValueRanges; 594*85ab413bSRiver Riddle valueToRangeIndex[defIt.first] = numValueRanges++; 595*85ab413bSRiver Riddle } 596*85ab413bSRiver Riddle 597abfd1a8bSRiver Riddle memIndex = allocatedIndices.size(); 598*85ab413bSRiver Riddle ++numIndices; 599abfd1a8bSRiver Riddle } 600abfd1a8bSRiver Riddle } 601abfd1a8bSRiver Riddle 602abfd1a8bSRiver Riddle // Update the max number of indices. 
  if (numIndices > maxValueMemoryIndex)
    maxValueMemoryIndex = numIndices;
  if (numTypeRanges > maxTypeRangeMemoryIndex)
    maxTypeRangeMemoryIndex = numTypeRanges;
  if (numValueRanges > maxValueRangeMemoryIndex)
    maxValueRangeMemoryIndex = numValueRanges;
}

/// Dispatch a `pdl_interp` operation to the typed `generate` overload below.
/// Every opcode emitter appends its fields in a fixed order; the executor's
/// `executeXXX` methods must read them back in exactly the same order, so the
/// sequence of `writer.append*` calls in each overload is part of the bytecode
/// format. Any operation not listed here is a hard error.
void Generator::generate(Operation *op, ByteCodeWriter &writer) {
  TypeSwitch<Operation *>(op)
      .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
            pdl_interp::AreEqualOp, pdl_interp::BranchOp,
            pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
            pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
            pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
            pdl_interp::CreateAttributeOp, pdl_interp::CreateOperationOp,
            pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
            pdl_interp::EraseOp, pdl_interp::FinalizeOp,
            pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp,
            pdl_interp::GetDefiningOpOp, pdl_interp::GetOperandOp,
            pdl_interp::GetOperandsOp, pdl_interp::GetResultOp,
            pdl_interp::GetResultsOp, pdl_interp::GetValueTypeOp,
            pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp,
            pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp,
            pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp,
            pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp,
            pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
          [&](auto interpOp) { this->generate(interpOp, writer); })
      .Default([](Operation *) {
        llvm_unreachable("unknown `pdl_interp` operation");
      });
}

/// Layout: ApplyConstraint, constraint index (from `constraintToMemIndex`),
/// constant params attribute, argument list, successors.
void Generator::generate(pdl_interp::ApplyConstraintOp op,
                         ByteCodeWriter &writer) {
  assert(constraintToMemIndex.count(op.name()) &&
         "expected index for constraint function");
  writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()],
                op.constParamsAttr());
  writer.appendPDLValueList(op.args());
  writer.append(op.getSuccessors());
}
/// Layout: ApplyRewrite, rewriter index (from `externalRewriterToMemIndex`),
/// constant params attribute, argument list, result count, then per result:
/// [debug builds only: PDL value kind], [range results only: range storage
/// index], the result value.
void Generator::generate(pdl_interp::ApplyRewriteOp op,
                         ByteCodeWriter &writer) {
  assert(externalRewriterToMemIndex.count(op.name()) &&
         "expected index for rewrite function");
  writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()],
                op.constParamsAttr());
  writer.appendPDLValueList(op.args());

  ResultRange results = op.results();
  writer.append(ByteCodeField(results.size()));
  for (Value result : results) {
    // In debug mode we also record the expected kind of the result, so that we
    // can provide extra verification of the native rewrite function.
    // NOTE: this makes the bytecode layout differ between NDEBUG and
    // non-NDEBUG builds; the reader must be compiled with the same setting.
#ifndef NDEBUG
    writer.appendPDLValueKind(result);
#endif

    // Range results also need to append the range storage index.
    if (result.getType().isa<pdl::RangeType>())
      writer.append(getRangeStorageIndex(result));
    writer.append(result);
  }
}
/// Range operands use AreRangesEqual, which additionally records the PDL value
/// kind of `lhs` (presumably so the executor knows which kind of range to
/// compare); all other operands use the plain AreEqual opcode.
void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
  Value lhs = op.lhs();
  if (lhs.getType().isa<pdl::RangeType>()) {
    writer.append(OpCode::AreRangesEqual);
    writer.appendPDLValueKind(lhs);
    writer.append(op.lhs(), op.rhs(), op.getSuccessors());
    return;
  }

  writer.append(OpCode::AreEqual, lhs, op.rhs(), op.getSuccessors());
}
/// Unconditional branch: the opcode followed by the op's successor list.
void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
}
/// Lowered to the generic AreEqual opcode, comparing the attribute against the
/// op's constant value.
void Generator::generate(pdl_interp::CheckAttributeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(),
                op.getSuccessors());
}
/// Layout: operation, expected count, `compareAtLeast` flag (as a field),
/// successors.
void Generator::generate(pdl_interp::CheckOperandCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CheckOperandCount, op.operation(), op.count(),
                static_cast<ByteCodeField>(op.compareAtLeast()),
                op.getSuccessors());
}
/// The expected name is emitted as an `OperationName` built in `ctx`.
void Generator::generate(pdl_interp::CheckOperationNameOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CheckOperationName, op.operation(),
                OperationName(op.name(), ctx), op.getSuccessors());
}
/// Layout: operation, expected count, `compareAtLeast` flag (as a field),
/// successors.
void Generator::generate(pdl_interp::CheckResultCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CheckResultCount, op.operation(), op.count(),
                static_cast<ByteCodeField>(op.compareAtLeast()),
                op.getSuccessors());
}
/// Lowered to the generic AreEqual opcode, comparing the value and the type.
void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors());
}
/// Layout: the type range value, the expected types, successors.
void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::CheckTypes, op.value(), op.types(), op.getSuccessors());
}
/// Emits no bytecode.
void Generator::generate(pdl_interp::CreateAttributeOp op,
                         ByteCodeWriter &writer) {
  // Simply repoint the memory index of the result to the constant.
  getMemIndex(op.attribute()) = getMemIndex(op.value());
}
/// Layout: result operation, name, operand list, attribute count,
/// (attribute name identifier, attribute value) pairs, result type list.
void Generator::generate(pdl_interp::CreateOperationOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CreateOperation, op.operation(),
                OperationName(op.name(), ctx));
  writer.appendPDLValueList(op.operands());

  // Add the attributes.
  OperandRange attributes = op.attributes();
  writer.append(static_cast<ByteCodeField>(attributes.size()));
  for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
    writer.append(
        Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx),
        std::get<1>(it));
  }
  writer.appendPDLValueList(op.types());
}
/// Emits no bytecode.
void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
  // Simply repoint the memory index of the result to the constant.
  getMemIndex(op.result()) = getMemIndex(op.value());
}
/// Layout: result, the result's range storage index, the constant value.
void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::CreateTypes, op.result(),
                getRangeStorageIndex(op.result()), op.value());
}
void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::EraseOp, op.operation());
}
void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::Finalize);
}
/// Layout: result attribute, source operation, attribute name identifier.
void Generator::generate(pdl_interp::GetAttributeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::GetAttribute, op.attribute(), op.operation(),
                Identifier::get(op.name(), ctx));
}
void Generator::generate(pdl_interp::GetAttributeTypeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::GetAttributeType, op.result(), op.value());
}
/// The queried value is appended via `appendPDLValue` rather than `append`,
/// since it may be either a single value or a range.
void Generator::generate(pdl_interp::GetDefiningOpOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::GetDefiningOp, op.operation());
  writer.appendPDLValue(op.value());
}
/// Indices 0-3 are folded into dedicated opcodes (GetOperand0 + index); larger
/// indices use GetOperandN followed by the explicit index. Both forms are then
/// followed by the operation and the result value.
void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
  uint32_t index = op.index();
  if (index < 4)
    writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
  else
    writer.append(OpCode::GetOperandN, index);
  writer.append(op.operation(), op.value());
}
/// Layout: operand group index (UINT32_MAX when the op has no index), the
/// operation, the range storage index (ByteCodeField max when the result is
/// not a range), the result value.
/// NOTE(review): the UINT32_MAX sentinel presumably means "all operands" -
/// confirm against executeGetOperands.
void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
  Value result = op.value();
  Optional<uint32_t> index = op.index();
  writer.append(OpCode::GetOperands,
                index.getValueOr(std::numeric_limits<uint32_t>::max()),
                op.operation());
  if (result.getType().isa<pdl::RangeType>())
    writer.append(getRangeStorageIndex(result));
  else
    writer.append(std::numeric_limits<ByteCodeField>::max());
  writer.append(result);
}
/// Mirrors GetOperandOp: indices 0-3 use GetResult0..3, others GetResultN.
void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
  uint32_t index = op.index();
  if (index < 4)
    writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
  else
    writer.append(OpCode::GetResultN, index);
  writer.append(op.operation(), op.value());
}
/// Mirrors GetOperandsOp, with the same two sentinel values.
void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
  Value result = op.value();
  Optional<uint32_t> index = op.index();
  writer.append(OpCode::GetResults,
                index.getValueOr(std::numeric_limits<uint32_t>::max()),
                op.operation());
  if (result.getType().isa<pdl::RangeType>())
    writer.append(getRangeStorageIndex(result));
  else
    writer.append(std::numeric_limits<ByteCodeField>::max());
  writer.append(result);
}
/// A range-typed result selects GetValueRangeTypes, which also carries the
/// result's range storage index; otherwise GetValueType is emitted.
void Generator::generate(pdl_interp::GetValueTypeOp op,
                         ByteCodeWriter &writer) {
  if (op.getType().isa<pdl::RangeType>()) {
    Value result = op.result();
    writer.append(OpCode::GetValueRangeTypes, result,
                  getRangeStorageIndex(result), op.value());
  } else {
    writer.append(OpCode::GetValueType, op.result(), op.value());
  }
}

/// Emits no bytecode.
void Generator::generate(pdl_interp::InferredTypesOp op,
                         ByteCodeWriter &writer) {
  // InferredTypes maps to a null type as a marker for inferring result types.
  getMemIndex(op.type()) = getMemIndex(Type());
}
void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors());
}
/// Registers a new PDLByteCodePattern, whose index is its position in
/// `patterns`, bound to the address recorded for the op's rewriter symbol.
/// Layout: the pattern index, successors, matched ops, input list.
void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
  ByteCodeField patternIndex = patterns.size();
  patterns.emplace_back(PDLByteCodePattern::create(
      op, rewriterToAddr[op.rewriter().getLeafReference()]));
  writer.append(OpCode::RecordMatch, patternIndex,
                SuccessorRange(op.getOperation()), op.matchedOps());
  writer.appendPDLValueList(op.inputs());
}
/// Layout: the operation being replaced, then the replacement value list.
void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::ReplaceOp, op.operation());
  writer.appendPDLValueList(op.replValues());
}
/// The Switch* emitters share one layout: the switched-on value, the case
/// values, successors.
void Generator::generate(pdl_interp::SwitchAttributeOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(),
                op.getSuccessors());
}
void Generator::generate(pdl_interp::SwitchOperandCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(),
                op.getSuccessors());
}
/// Case values are string attributes; they are converted lazily (via
/// `map_range`) to `OperationName`s in `ctx` as the writer consumes them.
void Generator::generate(pdl_interp::SwitchOperationNameOp op,
                         ByteCodeWriter &writer) {
  auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) {
    return OperationName(attr.cast<StringAttr>().getValue(), ctx);
  });
  writer.append(OpCode::SwitchOperationName, op.operation(), cases,
                op.getSuccessors());
}
void Generator::generate(pdl_interp::SwitchResultCountOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(),
                op.getSuccessors());
}
void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(),
                op.getSuccessors());
}
void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
  writer.append(OpCode::SwitchTypes, op.value(), op.caseValuesAttr(),
                op.getSuccessors());
}

//===----------------------------------------------------------------------===//
// PDLByteCode
//===----------------------------------------------------------------------===//

/// Generate the matcher and rewriter bytecode for `module`, then take
/// ownership of the external constraint and rewrite functions.
/// NOTE(review): the functions are appended in StringMap iteration order; the
/// indices the generator assigned (constraintToMemIndex /
/// externalRewriterToMemIndex) presumably follow the same order - confirm.
PDLByteCode::PDLByteCode(ModuleOp module,
                         llvm::StringMap<PDLConstraintFunction> constraintFns,
                         llvm::StringMap<PDLRewriteFunction> rewriteFns) {
  Generator generator(module.getContext(), uniquedData, matcherByteCode,
                      rewriterByteCode, patterns, maxValueMemoryIndex,
                      maxTypeRangeCount, maxValueRangeCount, constraintFns,
                      rewriteFns);
  generator.generate(module);

  // Initialize the external functions.
  for (auto &it : constraintFns)
    constraintFunctions.push_back(std::move(it.second));
  for (auto &it : rewriteFns)
    rewriteFunctions.push_back(std::move(it.second));
}

/// Initialize the given state such that it can be used to execute the current
/// bytecode. Memory is sized from the maxima computed during generation, and
/// each pattern's current benefit starts at its static benefit.
/// Assumes `state` is freshly constructed: `resize` leaves any existing memory
/// entries in place and the benefits are appended, not overwritten.
void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
  state.memory.resize(maxValueMemoryIndex, nullptr);
  state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
  state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
  state.currentPatternBenefits.reserve(patterns.size());
  for (const PDLByteCodePattern &pattern : patterns)
    state.currentPatternBenefits.push_back(pattern.getBenefit());
}

//===----------------------------------------------------------------------===//
// ByteCode Execution
//===----------------------------------------------------------------------===//

namespace {
/// This class provides support for executing a bytecode stream.
class ByteCodeExecutor {
public:
  /// All arguments are held as non-owning views (`ArrayRef`,
  /// `MutableArrayRef`) or references (`std::vector &`), so the referenced
  /// storage must outlive the executor. `curCodeIt` is the initial read
  /// position; `code` is the whole stream that jump addresses index into.
  ByteCodeExecutor(
      const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
      MutableArrayRef<TypeRange> typeRangeMemory,
      std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
      MutableArrayRef<ValueRange> valueRangeMemory,
      std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
      ArrayRef<const void *> uniquedMemory, ArrayRef<ByteCodeField> code,
      ArrayRef<PatternBenefit> currentPatternBenefits,
      ArrayRef<PDLByteCodePattern> patterns,
      ArrayRef<PDLConstraintFunction> constraintFunctions,
      ArrayRef<PDLRewriteFunction> rewriteFunctions)
      : curCodeIt(curCodeIt), memory(memory), typeRangeMemory(typeRangeMemory),
        allocatedTypeRangeMemory(allocatedTypeRangeMemory),
        valueRangeMemory(valueRangeMemory),
        allocatedValueRangeMemory(allocatedValueRangeMemory),
        uniquedMemory(uniquedMemory), code(code),
        currentPatternBenefits(currentPatternBenefits), patterns(patterns),
        constraintFunctions(constraintFunctions),
        rewriteFunctions(rewriteFunctions) {}

  /// Start executing the code at the current bytecode index. `matches` is an
  /// optional field provided when this function is executed in a matching
  /// context. `mainRewriteLoc` is presumably the location handed to
  /// `executeCreateOperation` for newly created operations - confirm.
  void execute(PatternRewriter &rewriter,
               SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
               Optional<Location> mainRewriteLoc = {});

private:
  /// Internal implementation of executing each of the bytecode commands.
  /// Each handler decodes its own operands from `curCodeIt`, in the order the
  /// corresponding `Generator::generate` overload emitted them.
  void executeApplyConstraint(PatternRewriter &rewriter);
  void executeApplyRewrite(PatternRewriter &rewriter);
  void executeAreEqual();
  void executeAreRangesEqual();
  void executeBranch();
  void executeCheckOperandCount();
  void executeCheckOperationName();
  void executeCheckResultCount();
  void executeCheckTypes();
  void executeCreateOperation(PatternRewriter &rewriter,
                              Location mainRewriteLoc);
  void executeCreateTypes();
  void executeEraseOp(PatternRewriter &rewriter);
  void executeGetAttribute();
  void executeGetAttributeType();
  void executeGetDefiningOp();
  void executeGetOperand(unsigned index);
  void executeGetOperands();
  void executeGetResult(unsigned index);
  void executeGetResults();
  void executeGetValueType();
  void executeGetValueRangeTypes();
  void executeIsNotNull();
  void executeRecordMatch(PatternRewriter &rewriter,
                          SmallVectorImpl<PDLByteCode::MatchResult> &matches);
  void executeReplaceOp(PatternRewriter &rewriter);
  void executeSwitchAttribute();
  void executeSwitchOperandCount();
  void executeSwitchOperationName();
  void executeSwitchResultCount();
  void executeSwitchType();
  void executeSwitchTypes();

  /// Read a value from the bytecode buffer, optionally skipping a certain
  /// number of prefix values. These methods always update the buffer to point
  /// to the next field after the read data.
  template <typename T = ByteCodeField>
  T read(size_t skipN = 0) {
    curCodeIt += skipN;
    return readImpl<T>();
  }
  /// Non-template overload, so that a plain `read()` call reads a single
  /// ByteCodeField (typically a count or an index).
  ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }

  /// Read a list of values from the bytecode buffer. The encoding is a
  /// ByteCodeField count followed by that many `ValueT` entries; `list` is
  /// cleared first.
  template <typename ValueT, typename T>
  void readList(SmallVectorImpl<T> &list) {
    list.clear();
    for (unsigned i = 0, e = read(); i != e; ++i)
      list.push_back(read<ValueT>());
  }

  /// Read a list of values from the bytecode buffer. The values may be encoded
  /// as either Value or ValueRange elements. Each element is prefixed with its
  /// PDLValue::Kind. A `Value` kind contributes one value; any other kind is
  /// read as a `ValueRange *` and all of its values are appended (flattened).
  /// Unlike `readList`, `list` is NOT cleared first: entries are appended.
  void readValueList(SmallVectorImpl<Value> &list) {
    for (unsigned i = 0, e = read(); i != e; ++i) {
      if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
        list.push_back(read<Value>());
      } else {
        ValueRange *values = read<ValueRange *>();
        list.append(values->begin(), values->end());
      }
    }
  }

  /// Jump to a specific successor based on a predicate value: `true` takes
  /// successor 0 and `false` takes successor 1.
  void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
  /// Jump to a specific successor based on a destination index. Successors
  /// are stored as consecutive ByteCodeAddr entries of two fields each, so
  /// `destIndex * 2` fields are skipped before reading the target address.
  void selectJump(size_t destIndex) {
    curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
  }

  /// Handle a switch operation with the provided value and cases. Successor 0
  /// is the default destination; a match against case `i` (compared with
  /// `cmp(case, value)`) jumps to successor `i + 1`. `RangeT`'s iterators must
  /// support subtraction, since the case position is computed by difference.
  template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
  void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
    LLVM_DEBUG({
      llvm::dbgs() << " * Value: " << value << "\n"
                   << " * Cases: ";
      llvm::interleaveComma(cases, llvm::dbgs());
      llvm::dbgs() << "\n";
    });

    // Check to see if the value is within the case list. Jump to the correct
    // successor index based on the result.
    for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
      if (cmp(*it, value))
        return selectJump(size_t((it - cases.begin()) + 1));
    selectJump(size_t(0));
  }

  /// Internal implementation of reading various data types from the bytecode
  /// stream.
  ///
  /// Reads one field as a slot index. Indices below `memory.size()` address
  /// the mutable execution memory; larger indices address `uniquedMemory`,
  /// offset by `memory.size()`.
  template <typename T>
  const void *readFromMemory() {
    size_t index = *curCodeIt++;

    // If this type is an SSA value, it can only be stored in non-const memory.
    // (Operation *, TypeRange *, ValueRange * and Value are never looked up in
    // the uniqued table, even without a bounds check.)
    if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
                        Value>::value ||
        index < memory.size())
      return memory[index];

    // Otherwise, if this index is not inbounds it is uniqued.
    return uniquedMemory[index - memory.size()];
  }
  /// Pointer types are stored directly in the memory slot.
  template <typename T>
  std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
    return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
  }
  /// Class types (other than PDLValue) are rebuilt from the slot's opaque
  /// pointer via `T::getFromOpaquePointer`.
  template <typename T>
  std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
                   T>
  readImpl() {
    return T(T::getFromOpaquePointer(readFromMemory<T>()));
  }
  /// A PDLValue is encoded as its kind followed by the kind-specific payload;
  /// range kinds are read back as pointers into memory.
  template <typename T>
  std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
    switch (read<PDLValue::Kind>()) {
    case PDLValue::Kind::Attribute:
      return read<Attribute>();
    case PDLValue::Kind::Operation:
      return read<Operation *>();
    case PDLValue::Kind::Type:
      return read<Type>();
    case PDLValue::Kind::Value:
      return read<Value>();
    case PDLValue::Kind::TypeRange:
      return read<TypeRange *>();
    case PDLValue::Kind::ValueRange:
      return read<ValueRange *>();
    }
    llvm_unreachable("unhandled PDLValue::Kind");
  }
  /// A ByteCodeAddr spans exactly two ByteCodeFields. It is copied out with
  /// memcpy rather than a pointer cast, avoiding alignment/aliasing issues.
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
    static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
                  "unexpected ByteCode address size");
    ByteCodeAddr result;
    std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
    curCodeIt += 2;
    return result;
  }
  /// A raw field, read inline from the stream (not through memory).
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
    return *curCodeIt++;
  }
  /// A PDLValue::Kind is stored inline as a single ByteCodeField.
  template <typename T>
  std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
    return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
  }

  /// The underlying bytecode buffer.
  const ByteCodeField *curCodeIt;

  /// The current execution memory.
  MutableArrayRef<const void *> memory;
  /// Storage for type/value range slots.
  /// NOTE(review): the `allocated*RangeMemory` vectors presumably own backing
  /// arrays for ranges created during execution, keeping the TypeRange /
  /// ValueRange views valid - confirm against their users.
  MutableArrayRef<TypeRange> typeRangeMemory;
  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
  MutableArrayRef<ValueRange> valueRangeMemory;
  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;

  /// References to ByteCode data necessary for execution.
  ArrayRef<const void *> uniquedMemory;
  ArrayRef<ByteCodeField> code;
  ArrayRef<PatternBenefit> currentPatternBenefits;
  ArrayRef<PDLByteCodePattern> patterns;
  ArrayRef<PDLConstraintFunction> constraintFunctions;
  ArrayRef<PDLRewriteFunction> rewriteFunctions;
};

/// This class is an instantiation of the PDLResultList that provides access to
/// the returned results. This API is not on `PDLResultList` to avoid
/// overexposing access to information specific solely to the ByteCode.
class ByteCodeRewriteResultList : public PDLResultList {
public:
  ByteCodeRewriteResultList(unsigned maxNumResults)
      : PDLResultList(maxNumResults) {}

  /// Return the list of PDL results.
  MutableArrayRef<PDLValue> getResults() { return results; }

  /// Return the type ranges allocated by this list.
  MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
    return allocatedTypeRanges;
  }

  /// Return the value ranges allocated by this list.
1115*85ab413bSRiver Riddle MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() { 1116*85ab413bSRiver Riddle return allocatedValueRanges; 1117*85ab413bSRiver Riddle } 111802c4c0d5SRiver Riddle }; 1119abfd1a8bSRiver Riddle } // end anonymous namespace 1120abfd1a8bSRiver Riddle 1121154cabe7SRiver Riddle void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { 1122abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); 1123abfd1a8bSRiver Riddle const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; 1124abfd1a8bSRiver Riddle ArrayAttr constParams = read<ArrayAttr>(); 1125abfd1a8bSRiver Riddle SmallVector<PDLValue, 16> args; 1126abfd1a8bSRiver Riddle readList<PDLValue>(args); 1127154cabe7SRiver Riddle 1128abfd1a8bSRiver Riddle LLVM_DEBUG({ 1129abfd1a8bSRiver Riddle llvm::dbgs() << " * Arguments: "; 1130abfd1a8bSRiver Riddle llvm::interleaveComma(args, llvm::dbgs()); 1131154cabe7SRiver Riddle llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; 1132abfd1a8bSRiver Riddle }); 1133abfd1a8bSRiver Riddle 1134abfd1a8bSRiver Riddle // Invoke the constraint and jump to the proper destination. 
1135abfd1a8bSRiver Riddle selectJump(succeeded(constraintFn(args, constParams, rewriter))); 1136abfd1a8bSRiver Riddle } 1137154cabe7SRiver Riddle 1138154cabe7SRiver Riddle void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { 1139abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); 1140abfd1a8bSRiver Riddle const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; 1141abfd1a8bSRiver Riddle ArrayAttr constParams = read<ArrayAttr>(); 1142abfd1a8bSRiver Riddle SmallVector<PDLValue, 16> args; 1143abfd1a8bSRiver Riddle readList<PDLValue>(args); 1144abfd1a8bSRiver Riddle 1145abfd1a8bSRiver Riddle LLVM_DEBUG({ 114602c4c0d5SRiver Riddle llvm::dbgs() << " * Arguments: "; 1147abfd1a8bSRiver Riddle llvm::interleaveComma(args, llvm::dbgs()); 1148154cabe7SRiver Riddle llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; 1149abfd1a8bSRiver Riddle }); 1150*85ab413bSRiver Riddle 1151*85ab413bSRiver Riddle // Execute the rewrite function. 1152*85ab413bSRiver Riddle ByteCodeField numResults = read(); 1153*85ab413bSRiver Riddle ByteCodeRewriteResultList results(numResults); 115402c4c0d5SRiver Riddle rewriteFn(args, constParams, rewriter, results); 1155154cabe7SRiver Riddle 1156*85ab413bSRiver Riddle assert(results.getResults().size() == numResults && 115702c4c0d5SRiver Riddle "native PDL rewrite function returned unexpected number of results"); 115802c4c0d5SRiver Riddle 115902c4c0d5SRiver Riddle // Store the results in the bytecode memory. 116002c4c0d5SRiver Riddle for (PDLValue &result : results.getResults()) { 116102c4c0d5SRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); 1162*85ab413bSRiver Riddle 1163*85ab413bSRiver Riddle // In debug mode we also verify the expected kind of the result. 
1164*85ab413bSRiver Riddle #ifndef NDEBUG 1165*85ab413bSRiver Riddle assert(result.getKind() == read<PDLValue::Kind>() && 1166*85ab413bSRiver Riddle "native PDL rewrite function returned an unexpected type of result"); 1167*85ab413bSRiver Riddle #endif 1168*85ab413bSRiver Riddle 1169*85ab413bSRiver Riddle // If the result is a range, we need to copy it over to the bytecodes 1170*85ab413bSRiver Riddle // range memory. 1171*85ab413bSRiver Riddle if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) { 1172*85ab413bSRiver Riddle unsigned rangeIndex = read(); 1173*85ab413bSRiver Riddle typeRangeMemory[rangeIndex] = *typeRange; 1174*85ab413bSRiver Riddle memory[read()] = &typeRangeMemory[rangeIndex]; 1175*85ab413bSRiver Riddle } else if (Optional<ValueRange> valueRange = 1176*85ab413bSRiver Riddle result.dyn_cast<ValueRange>()) { 1177*85ab413bSRiver Riddle unsigned rangeIndex = read(); 1178*85ab413bSRiver Riddle valueRangeMemory[rangeIndex] = *valueRange; 1179*85ab413bSRiver Riddle memory[read()] = &valueRangeMemory[rangeIndex]; 1180*85ab413bSRiver Riddle } else { 118102c4c0d5SRiver Riddle memory[read()] = result.getAsOpaquePointer(); 118202c4c0d5SRiver Riddle } 1183abfd1a8bSRiver Riddle } 1184154cabe7SRiver Riddle 1185*85ab413bSRiver Riddle // Copy over any underlying storage allocated for result ranges. 
1186*85ab413bSRiver Riddle for (auto &it : results.getAllocatedTypeRanges()) 1187*85ab413bSRiver Riddle allocatedTypeRangeMemory.push_back(std::move(it)); 1188*85ab413bSRiver Riddle for (auto &it : results.getAllocatedValueRanges()) 1189*85ab413bSRiver Riddle allocatedValueRangeMemory.push_back(std::move(it)); 1190*85ab413bSRiver Riddle } 1191*85ab413bSRiver Riddle 1192154cabe7SRiver Riddle void ByteCodeExecutor::executeAreEqual() { 1193abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 1194abfd1a8bSRiver Riddle const void *lhs = read<const void *>(); 1195abfd1a8bSRiver Riddle const void *rhs = read<const void *>(); 1196abfd1a8bSRiver Riddle 1197154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n"); 1198abfd1a8bSRiver Riddle selectJump(lhs == rhs); 1199abfd1a8bSRiver Riddle } 1200154cabe7SRiver Riddle 1201*85ab413bSRiver Riddle void ByteCodeExecutor::executeAreRangesEqual() { 1202*85ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n"); 1203*85ab413bSRiver Riddle PDLValue::Kind valueKind = read<PDLValue::Kind>(); 1204*85ab413bSRiver Riddle const void *lhs = read<const void *>(); 1205*85ab413bSRiver Riddle const void *rhs = read<const void *>(); 1206*85ab413bSRiver Riddle 1207*85ab413bSRiver Riddle switch (valueKind) { 1208*85ab413bSRiver Riddle case PDLValue::Kind::TypeRange: { 1209*85ab413bSRiver Riddle const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs); 1210*85ab413bSRiver Riddle const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs); 1211*85ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1212*85ab413bSRiver Riddle selectJump(*lhsRange == *rhsRange); 1213*85ab413bSRiver Riddle break; 1214*85ab413bSRiver Riddle } 1215*85ab413bSRiver Riddle case PDLValue::Kind::ValueRange: { 1216*85ab413bSRiver Riddle const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs); 1217*85ab413bSRiver Riddle const auto *rhsRange 
= reinterpret_cast<const ValueRange *>(rhs); 1218*85ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1219*85ab413bSRiver Riddle selectJump(*lhsRange == *rhsRange); 1220*85ab413bSRiver Riddle break; 1221*85ab413bSRiver Riddle } 1222*85ab413bSRiver Riddle default: 1223*85ab413bSRiver Riddle llvm_unreachable("unexpected `AreRangesEqual` value kind"); 1224*85ab413bSRiver Riddle } 1225*85ab413bSRiver Riddle } 1226*85ab413bSRiver Riddle 1227154cabe7SRiver Riddle void ByteCodeExecutor::executeBranch() { 1228154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n"); 1229abfd1a8bSRiver Riddle curCodeIt = &code[read<ByteCodeAddr>()]; 1230abfd1a8bSRiver Riddle } 1231154cabe7SRiver Riddle 1232154cabe7SRiver Riddle void ByteCodeExecutor::executeCheckOperandCount() { 1233abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); 1234abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1235abfd1a8bSRiver Riddle uint32_t expectedCount = read<uint32_t>(); 1236*85ab413bSRiver Riddle bool compareAtLeast = read(); 1237abfd1a8bSRiver Riddle 1238abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" 1239*85ab413bSRiver Riddle << " * Expected: " << expectedCount << "\n" 1240*85ab413bSRiver Riddle << " * Comparator: " 1241*85ab413bSRiver Riddle << (compareAtLeast ? 
">=" : "==") << "\n"); 1242*85ab413bSRiver Riddle if (compareAtLeast) 1243*85ab413bSRiver Riddle selectJump(op->getNumOperands() >= expectedCount); 1244*85ab413bSRiver Riddle else 1245abfd1a8bSRiver Riddle selectJump(op->getNumOperands() == expectedCount); 1246abfd1a8bSRiver Riddle } 1247154cabe7SRiver Riddle 1248154cabe7SRiver Riddle void ByteCodeExecutor::executeCheckOperationName() { 1249abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n"); 1250abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1251abfd1a8bSRiver Riddle OperationName expectedName = read<OperationName>(); 1252abfd1a8bSRiver Riddle 1253154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n" 1254154cabe7SRiver Riddle << " * Expected: \"" << expectedName << "\"\n"); 1255abfd1a8bSRiver Riddle selectJump(op->getName() == expectedName); 1256abfd1a8bSRiver Riddle } 1257154cabe7SRiver Riddle 1258154cabe7SRiver Riddle void ByteCodeExecutor::executeCheckResultCount() { 1259abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); 1260abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1261abfd1a8bSRiver Riddle uint32_t expectedCount = read<uint32_t>(); 1262*85ab413bSRiver Riddle bool compareAtLeast = read(); 1263abfd1a8bSRiver Riddle 1264abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" 1265*85ab413bSRiver Riddle << " * Expected: " << expectedCount << "\n" 1266*85ab413bSRiver Riddle << " * Comparator: " 1267*85ab413bSRiver Riddle << (compareAtLeast ? 
">=" : "==") << "\n"); 1268*85ab413bSRiver Riddle if (compareAtLeast) 1269*85ab413bSRiver Riddle selectJump(op->getNumResults() >= expectedCount); 1270*85ab413bSRiver Riddle else 1271abfd1a8bSRiver Riddle selectJump(op->getNumResults() == expectedCount); 1272abfd1a8bSRiver Riddle } 1273154cabe7SRiver Riddle 1274*85ab413bSRiver Riddle void ByteCodeExecutor::executeCheckTypes() { 1275*85ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 1276*85ab413bSRiver Riddle TypeRange *lhs = read<TypeRange *>(); 1277*85ab413bSRiver Riddle Attribute rhs = read<Attribute>(); 1278*85ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 1279*85ab413bSRiver Riddle 1280*85ab413bSRiver Riddle selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>()); 1281*85ab413bSRiver Riddle } 1282*85ab413bSRiver Riddle 1283*85ab413bSRiver Riddle void ByteCodeExecutor::executeCreateTypes() { 1284*85ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n"); 1285*85ab413bSRiver Riddle unsigned memIndex = read(); 1286*85ab413bSRiver Riddle unsigned rangeIndex = read(); 1287*85ab413bSRiver Riddle ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>(); 1288*85ab413bSRiver Riddle 1289*85ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n"); 1290*85ab413bSRiver Riddle 1291*85ab413bSRiver Riddle // Allocate a buffer for this type range. 1292*85ab413bSRiver Riddle llvm::OwningArrayRef<Type> storage(typesAttr.size()); 1293*85ab413bSRiver Riddle llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin()); 1294*85ab413bSRiver Riddle allocatedTypeRangeMemory.emplace_back(std::move(storage)); 1295*85ab413bSRiver Riddle 1296*85ab413bSRiver Riddle // Assign this to the range slot and use the range as the value for the 1297*85ab413bSRiver Riddle // memory index. 
1298*85ab413bSRiver Riddle typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back(); 1299*85ab413bSRiver Riddle memory[memIndex] = &typeRangeMemory[rangeIndex]; 1300*85ab413bSRiver Riddle } 1301*85ab413bSRiver Riddle 1302154cabe7SRiver Riddle void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, 1303154cabe7SRiver Riddle Location mainRewriteLoc) { 1304abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); 1305abfd1a8bSRiver Riddle 1306abfd1a8bSRiver Riddle unsigned memIndex = read(); 1307154cabe7SRiver Riddle OperationState state(mainRewriteLoc, read<OperationName>()); 1308*85ab413bSRiver Riddle readValueList(state.operands); 1309abfd1a8bSRiver Riddle for (unsigned i = 0, e = read(); i != e; ++i) { 1310abfd1a8bSRiver Riddle Identifier name = read<Identifier>(); 1311abfd1a8bSRiver Riddle if (Attribute attr = read<Attribute>()) 1312abfd1a8bSRiver Riddle state.addAttribute(name, attr); 1313abfd1a8bSRiver Riddle } 1314abfd1a8bSRiver Riddle 1315abfd1a8bSRiver Riddle for (unsigned i = 0, e = read(); i != e; ++i) { 1316*85ab413bSRiver Riddle if (read<PDLValue::Kind>() == PDLValue::Kind::Type) { 1317*85ab413bSRiver Riddle state.types.push_back(read<Type>()); 1318*85ab413bSRiver Riddle continue; 1319*85ab413bSRiver Riddle } 1320*85ab413bSRiver Riddle 1321*85ab413bSRiver Riddle // If we find a null range, this signals that the types are infered. 1322*85ab413bSRiver Riddle if (TypeRange *resultTypes = read<TypeRange *>()) { 1323*85ab413bSRiver Riddle state.types.append(resultTypes->begin(), resultTypes->end()); 1324*85ab413bSRiver Riddle continue; 1325abfd1a8bSRiver Riddle } 1326abfd1a8bSRiver Riddle 1327abfd1a8bSRiver Riddle // Handle the case where the operation has inferred types. 1328abfd1a8bSRiver Riddle InferTypeOpInterface::Concept *concept = 1329154cabe7SRiver Riddle state.name.getAbstractOperation()->getInterface<InferTypeOpInterface>(); 1330abfd1a8bSRiver Riddle 1331abfd1a8bSRiver Riddle // TODO: Handle failure. 
13323a833a0eSRiver Riddle state.types.clear(); 1333abfd1a8bSRiver Riddle if (failed(concept->inferReturnTypes( 1334abfd1a8bSRiver Riddle state.getContext(), state.location, state.operands, 1335154cabe7SRiver Riddle state.attributes.getDictionary(state.getContext()), state.regions, 13363a833a0eSRiver Riddle state.types))) 1337abfd1a8bSRiver Riddle return; 1338*85ab413bSRiver Riddle break; 1339abfd1a8bSRiver Riddle } 1340*85ab413bSRiver Riddle 1341abfd1a8bSRiver Riddle Operation *resultOp = rewriter.createOperation(state); 1342abfd1a8bSRiver Riddle memory[memIndex] = resultOp; 1343abfd1a8bSRiver Riddle 1344abfd1a8bSRiver Riddle LLVM_DEBUG({ 1345abfd1a8bSRiver Riddle llvm::dbgs() << " * Attributes: " 1346abfd1a8bSRiver Riddle << state.attributes.getDictionary(state.getContext()) 1347abfd1a8bSRiver Riddle << "\n * Operands: "; 1348abfd1a8bSRiver Riddle llvm::interleaveComma(state.operands, llvm::dbgs()); 1349abfd1a8bSRiver Riddle llvm::dbgs() << "\n * Result Types: "; 1350abfd1a8bSRiver Riddle llvm::interleaveComma(state.types, llvm::dbgs()); 1351154cabe7SRiver Riddle llvm::dbgs() << "\n * Result: " << *resultOp << "\n"; 1352abfd1a8bSRiver Riddle }); 1353abfd1a8bSRiver Riddle } 1354154cabe7SRiver Riddle 1355154cabe7SRiver Riddle void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) { 1356abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n"); 1357abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1358abfd1a8bSRiver Riddle 1359154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1360abfd1a8bSRiver Riddle rewriter.eraseOp(op); 1361abfd1a8bSRiver Riddle } 1362154cabe7SRiver Riddle 1363154cabe7SRiver Riddle void ByteCodeExecutor::executeGetAttribute() { 1364abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); 1365abfd1a8bSRiver Riddle unsigned memIndex = read(); 1366abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1367abfd1a8bSRiver Riddle Identifier attrName = 
read<Identifier>(); 1368abfd1a8bSRiver Riddle Attribute attr = op->getAttr(attrName); 1369abfd1a8bSRiver Riddle 1370abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1371abfd1a8bSRiver Riddle << " * Attribute: " << attrName << "\n" 1372154cabe7SRiver Riddle << " * Result: " << attr << "\n"); 1373abfd1a8bSRiver Riddle memory[memIndex] = attr.getAsOpaquePointer(); 1374abfd1a8bSRiver Riddle } 1375154cabe7SRiver Riddle 1376154cabe7SRiver Riddle void ByteCodeExecutor::executeGetAttributeType() { 1377abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n"); 1378abfd1a8bSRiver Riddle unsigned memIndex = read(); 1379abfd1a8bSRiver Riddle Attribute attr = read<Attribute>(); 1380154cabe7SRiver Riddle Type type = attr ? attr.getType() : Type(); 1381abfd1a8bSRiver Riddle 1382abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" 1383154cabe7SRiver Riddle << " * Result: " << type << "\n"); 1384154cabe7SRiver Riddle memory[memIndex] = type.getAsOpaquePointer(); 1385abfd1a8bSRiver Riddle } 1386154cabe7SRiver Riddle 1387154cabe7SRiver Riddle void ByteCodeExecutor::executeGetDefiningOp() { 1388abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); 1389abfd1a8bSRiver Riddle unsigned memIndex = read(); 1390*85ab413bSRiver Riddle Operation *op = nullptr; 1391*85ab413bSRiver Riddle if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { 1392abfd1a8bSRiver Riddle Value value = read<Value>(); 1393*85ab413bSRiver Riddle if (value) 1394*85ab413bSRiver Riddle op = value.getDefiningOp(); 1395*85ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1396*85ab413bSRiver Riddle } else { 1397*85ab413bSRiver Riddle ValueRange *values = read<ValueRange *>(); 1398*85ab413bSRiver Riddle if (values && !values->empty()) { 1399*85ab413bSRiver Riddle op = values->front().getDefiningOp(); 1400*85ab413bSRiver Riddle } 1401*85ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * 
Values: " << values << "\n"); 1402*85ab413bSRiver Riddle } 1403abfd1a8bSRiver Riddle 1404*85ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n"); 1405abfd1a8bSRiver Riddle memory[memIndex] = op; 1406abfd1a8bSRiver Riddle } 1407154cabe7SRiver Riddle 1408154cabe7SRiver Riddle void ByteCodeExecutor::executeGetOperand(unsigned index) { 1409abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1410abfd1a8bSRiver Riddle unsigned memIndex = read(); 1411abfd1a8bSRiver Riddle Value operand = 1412abfd1a8bSRiver Riddle index < op->getNumOperands() ? op->getOperand(index) : Value(); 1413abfd1a8bSRiver Riddle 1414abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1415abfd1a8bSRiver Riddle << " * Index: " << index << "\n" 1416154cabe7SRiver Riddle << " * Result: " << operand << "\n"); 1417abfd1a8bSRiver Riddle memory[memIndex] = operand.getAsOpaquePointer(); 1418abfd1a8bSRiver Riddle } 1419154cabe7SRiver Riddle 1420*85ab413bSRiver Riddle /// This function is the internal implementation of `GetResults` and 1421*85ab413bSRiver Riddle /// `GetOperands` that provides support for extracting a value range from the 1422*85ab413bSRiver Riddle /// given operation. 1423*85ab413bSRiver Riddle template <template <typename> class AttrSizedSegmentsT, typename RangeT> 1424*85ab413bSRiver Riddle static void * 1425*85ab413bSRiver Riddle executeGetOperandsResults(RangeT values, Operation *op, unsigned index, 1426*85ab413bSRiver Riddle ByteCodeField rangeIndex, StringRef attrSizedSegments, 1427*85ab413bSRiver Riddle MutableArrayRef<ValueRange> &valueRangeMemory) { 1428*85ab413bSRiver Riddle // Check for the sentinel index that signals that all values should be 1429*85ab413bSRiver Riddle // returned. 1430*85ab413bSRiver Riddle if (index == std::numeric_limits<uint32_t>::max()) { 1431*85ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n"); 1432*85ab413bSRiver Riddle // `values` is already the full value range. 
1433*85ab413bSRiver Riddle 1434*85ab413bSRiver Riddle // Otherwise, check to see if this operation uses AttrSizedSegments. 1435*85ab413bSRiver Riddle } else if (op->hasTrait<AttrSizedSegmentsT>()) { 1436*85ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() 1437*85ab413bSRiver Riddle << " * Extracting values from `" << attrSizedSegments << "`\n"); 1438*85ab413bSRiver Riddle 1439*85ab413bSRiver Riddle auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments); 1440*85ab413bSRiver Riddle if (!segmentAttr || segmentAttr.getNumElements() <= index) 1441*85ab413bSRiver Riddle return nullptr; 1442*85ab413bSRiver Riddle 1443*85ab413bSRiver Riddle auto segments = segmentAttr.getValues<int32_t>(); 1444*85ab413bSRiver Riddle unsigned startIndex = 1445*85ab413bSRiver Riddle std::accumulate(segments.begin(), segments.begin() + index, 0); 1446*85ab413bSRiver Riddle values = values.slice(startIndex, *std::next(segments.begin(), index)); 1447*85ab413bSRiver Riddle 1448*85ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", " 1449*85ab413bSRiver Riddle << *std::next(segments.begin(), index) << "]\n"); 1450*85ab413bSRiver Riddle 1451*85ab413bSRiver Riddle // Otherwise, assume this is the last operand group of the operation. 1452*85ab413bSRiver Riddle // FIXME: We currently don't support operations with 1453*85ab413bSRiver Riddle // SameVariadicOperandSize/SameVariadicResultSize here given that we don't 1454*85ab413bSRiver Riddle // have a way to detect it's presence. 1455*85ab413bSRiver Riddle } else if (values.size() >= index) { 1456*85ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() 1457*85ab413bSRiver Riddle << " * Treating values as trailing variadic range\n"); 1458*85ab413bSRiver Riddle values = values.drop_front(index); 1459*85ab413bSRiver Riddle 1460*85ab413bSRiver Riddle // If we couldn't detect a way to compute the values, bail out. 
1461*85ab413bSRiver Riddle } else { 1462*85ab413bSRiver Riddle return nullptr; 1463*85ab413bSRiver Riddle } 1464*85ab413bSRiver Riddle 1465*85ab413bSRiver Riddle // If the range index is valid, we are returning a range. 1466*85ab413bSRiver Riddle if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) { 1467*85ab413bSRiver Riddle valueRangeMemory[rangeIndex] = values; 1468*85ab413bSRiver Riddle return &valueRangeMemory[rangeIndex]; 1469*85ab413bSRiver Riddle } 1470*85ab413bSRiver Riddle 1471*85ab413bSRiver Riddle // If a range index wasn't provided, the range is required to be non-variadic. 1472*85ab413bSRiver Riddle return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer(); 1473*85ab413bSRiver Riddle } 1474*85ab413bSRiver Riddle 1475*85ab413bSRiver Riddle void ByteCodeExecutor::executeGetOperands() { 1476*85ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n"); 1477*85ab413bSRiver Riddle unsigned index = read<uint32_t>(); 1478*85ab413bSRiver Riddle Operation *op = read<Operation *>(); 1479*85ab413bSRiver Riddle ByteCodeField rangeIndex = read(); 1480*85ab413bSRiver Riddle 1481*85ab413bSRiver Riddle void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>( 1482*85ab413bSRiver Riddle op->getOperands(), op, index, rangeIndex, "operand_segment_sizes", 1483*85ab413bSRiver Riddle valueRangeMemory); 1484*85ab413bSRiver Riddle if (!result) 1485*85ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n"); 1486*85ab413bSRiver Riddle memory[read()] = result; 1487*85ab413bSRiver Riddle } 1488*85ab413bSRiver Riddle 1489154cabe7SRiver Riddle void ByteCodeExecutor::executeGetResult(unsigned index) { 1490abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1491abfd1a8bSRiver Riddle unsigned memIndex = read(); 1492abfd1a8bSRiver Riddle OpResult result = 1493abfd1a8bSRiver Riddle index < op->getNumResults() ? 
op->getResult(index) : OpResult(); 1494abfd1a8bSRiver Riddle 1495abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1496abfd1a8bSRiver Riddle << " * Index: " << index << "\n" 1497154cabe7SRiver Riddle << " * Result: " << result << "\n"); 1498abfd1a8bSRiver Riddle memory[memIndex] = result.getAsOpaquePointer(); 1499abfd1a8bSRiver Riddle } 1500154cabe7SRiver Riddle 1501*85ab413bSRiver Riddle void ByteCodeExecutor::executeGetResults() { 1502*85ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n"); 1503*85ab413bSRiver Riddle unsigned index = read<uint32_t>(); 1504*85ab413bSRiver Riddle Operation *op = read<Operation *>(); 1505*85ab413bSRiver Riddle ByteCodeField rangeIndex = read(); 1506*85ab413bSRiver Riddle 1507*85ab413bSRiver Riddle void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>( 1508*85ab413bSRiver Riddle op->getResults(), op, index, rangeIndex, "result_segment_sizes", 1509*85ab413bSRiver Riddle valueRangeMemory); 1510*85ab413bSRiver Riddle if (!result) 1511*85ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n"); 1512*85ab413bSRiver Riddle memory[read()] = result; 1513*85ab413bSRiver Riddle } 1514*85ab413bSRiver Riddle 1515154cabe7SRiver Riddle void ByteCodeExecutor::executeGetValueType() { 1516abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); 1517abfd1a8bSRiver Riddle unsigned memIndex = read(); 1518abfd1a8bSRiver Riddle Value value = read<Value>(); 1519154cabe7SRiver Riddle Type type = value ? 
value.getType() : Type(); 1520abfd1a8bSRiver Riddle 1521abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" 1522154cabe7SRiver Riddle << " * Result: " << type << "\n"); 1523154cabe7SRiver Riddle memory[memIndex] = type.getAsOpaquePointer(); 1524abfd1a8bSRiver Riddle } 1525154cabe7SRiver Riddle 1526*85ab413bSRiver Riddle void ByteCodeExecutor::executeGetValueRangeTypes() { 1527*85ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n"); 1528*85ab413bSRiver Riddle unsigned memIndex = read(); 1529*85ab413bSRiver Riddle unsigned rangeIndex = read(); 1530*85ab413bSRiver Riddle ValueRange *values = read<ValueRange *>(); 1531*85ab413bSRiver Riddle if (!values) { 1532*85ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n"); 1533*85ab413bSRiver Riddle memory[memIndex] = nullptr; 1534*85ab413bSRiver Riddle return; 1535*85ab413bSRiver Riddle } 1536*85ab413bSRiver Riddle 1537*85ab413bSRiver Riddle LLVM_DEBUG({ 1538*85ab413bSRiver Riddle llvm::dbgs() << " * Values (" << values->size() << "): "; 1539*85ab413bSRiver Riddle llvm::interleaveComma(*values, llvm::dbgs()); 1540*85ab413bSRiver Riddle llvm::dbgs() << "\n * Result: "; 1541*85ab413bSRiver Riddle llvm::interleaveComma(values->getType(), llvm::dbgs()); 1542*85ab413bSRiver Riddle llvm::dbgs() << "\n"; 1543*85ab413bSRiver Riddle }); 1544*85ab413bSRiver Riddle typeRangeMemory[rangeIndex] = values->getType(); 1545*85ab413bSRiver Riddle memory[memIndex] = &typeRangeMemory[rangeIndex]; 1546*85ab413bSRiver Riddle } 1547*85ab413bSRiver Riddle 1548154cabe7SRiver Riddle void ByteCodeExecutor::executeIsNotNull() { 1549abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); 1550abfd1a8bSRiver Riddle const void *value = read<const void *>(); 1551abfd1a8bSRiver Riddle 1552154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1553abfd1a8bSRiver Riddle selectJump(value != nullptr); 1554abfd1a8bSRiver Riddle } 
1555154cabe7SRiver Riddle 1556154cabe7SRiver Riddle void ByteCodeExecutor::executeRecordMatch( 1557154cabe7SRiver Riddle PatternRewriter &rewriter, 1558154cabe7SRiver Riddle SmallVectorImpl<PDLByteCode::MatchResult> &matches) { 1559abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n"); 1560abfd1a8bSRiver Riddle unsigned patternIndex = read(); 1561abfd1a8bSRiver Riddle PatternBenefit benefit = currentPatternBenefits[patternIndex]; 1562abfd1a8bSRiver Riddle const ByteCodeField *dest = &code[read<ByteCodeAddr>()]; 1563abfd1a8bSRiver Riddle 1564abfd1a8bSRiver Riddle // If the benefit of the pattern is impossible, skip the processing of the 1565abfd1a8bSRiver Riddle // rest of the pattern. 1566abfd1a8bSRiver Riddle if (benefit.isImpossibleToMatch()) { 1567154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n"); 1568abfd1a8bSRiver Riddle curCodeIt = dest; 1569154cabe7SRiver Riddle return; 1570abfd1a8bSRiver Riddle } 1571abfd1a8bSRiver Riddle 1572abfd1a8bSRiver Riddle // Create a fused location containing the locations of each of the 1573abfd1a8bSRiver Riddle // operations used in the match. This will be used as the location for 1574abfd1a8bSRiver Riddle // created operations during the rewrite that don't already have an 1575abfd1a8bSRiver Riddle // explicit location set. 
1576abfd1a8bSRiver Riddle unsigned numMatchLocs = read(); 1577abfd1a8bSRiver Riddle SmallVector<Location, 4> matchLocs; 1578abfd1a8bSRiver Riddle matchLocs.reserve(numMatchLocs); 1579abfd1a8bSRiver Riddle for (unsigned i = 0; i != numMatchLocs; ++i) 1580abfd1a8bSRiver Riddle matchLocs.push_back(read<Operation *>()->getLoc()); 1581abfd1a8bSRiver Riddle Location matchLoc = rewriter.getFusedLoc(matchLocs); 1582abfd1a8bSRiver Riddle 1583abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n" 1584154cabe7SRiver Riddle << " * Location: " << matchLoc << "\n"); 1585154cabe7SRiver Riddle matches.emplace_back(matchLoc, patterns[patternIndex], benefit); 1586*85ab413bSRiver Riddle PDLByteCode::MatchResult &match = matches.back(); 1587*85ab413bSRiver Riddle 1588*85ab413bSRiver Riddle // Record all of the inputs to the match. If any of the inputs are ranges, we 1589*85ab413bSRiver Riddle // will also need to remap the range pointer to memory stored in the match 1590*85ab413bSRiver Riddle // state. 
1591*85ab413bSRiver Riddle unsigned numInputs = read(); 1592*85ab413bSRiver Riddle match.values.reserve(numInputs); 1593*85ab413bSRiver Riddle match.typeRangeValues.reserve(numInputs); 1594*85ab413bSRiver Riddle match.valueRangeValues.reserve(numInputs); 1595*85ab413bSRiver Riddle for (unsigned i = 0; i < numInputs; ++i) { 1596*85ab413bSRiver Riddle switch (read<PDLValue::Kind>()) { 1597*85ab413bSRiver Riddle case PDLValue::Kind::TypeRange: 1598*85ab413bSRiver Riddle match.typeRangeValues.push_back(*read<TypeRange *>()); 1599*85ab413bSRiver Riddle match.values.push_back(&match.typeRangeValues.back()); 1600*85ab413bSRiver Riddle break; 1601*85ab413bSRiver Riddle case PDLValue::Kind::ValueRange: 1602*85ab413bSRiver Riddle match.valueRangeValues.push_back(*read<ValueRange *>()); 1603*85ab413bSRiver Riddle match.values.push_back(&match.valueRangeValues.back()); 1604*85ab413bSRiver Riddle break; 1605*85ab413bSRiver Riddle default: 1606*85ab413bSRiver Riddle match.values.push_back(read<const void *>()); 1607*85ab413bSRiver Riddle break; 1608*85ab413bSRiver Riddle } 1609*85ab413bSRiver Riddle } 1610abfd1a8bSRiver Riddle curCodeIt = dest; 1611abfd1a8bSRiver Riddle } 1612154cabe7SRiver Riddle 1613154cabe7SRiver Riddle void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) { 1614abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); 1615abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1616abfd1a8bSRiver Riddle SmallVector<Value, 16> args; 1617*85ab413bSRiver Riddle readValueList(args); 1618abfd1a8bSRiver Riddle 1619abfd1a8bSRiver Riddle LLVM_DEBUG({ 1620abfd1a8bSRiver Riddle llvm::dbgs() << " * Operation: " << *op << "\n" 1621abfd1a8bSRiver Riddle << " * Values: "; 1622abfd1a8bSRiver Riddle llvm::interleaveComma(args, llvm::dbgs()); 1623154cabe7SRiver Riddle llvm::dbgs() << "\n"; 1624abfd1a8bSRiver Riddle }); 1625abfd1a8bSRiver Riddle rewriter.replaceOp(op, args); 1626abfd1a8bSRiver Riddle } 1627154cabe7SRiver Riddle 
1628154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchAttribute() { 1629abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n"); 1630abfd1a8bSRiver Riddle Attribute value = read<Attribute>(); 1631abfd1a8bSRiver Riddle ArrayAttr cases = read<ArrayAttr>(); 1632abfd1a8bSRiver Riddle handleSwitch(value, cases); 1633abfd1a8bSRiver Riddle } 1634154cabe7SRiver Riddle 1635154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchOperandCount() { 1636abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n"); 1637abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1638abfd1a8bSRiver Riddle auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1639abfd1a8bSRiver Riddle 1640abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1641abfd1a8bSRiver Riddle handleSwitch(op->getNumOperands(), cases); 1642abfd1a8bSRiver Riddle } 1643154cabe7SRiver Riddle 1644154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchOperationName() { 1645abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n"); 1646abfd1a8bSRiver Riddle OperationName value = read<Operation *>()->getName(); 1647abfd1a8bSRiver Riddle size_t caseCount = read(); 1648abfd1a8bSRiver Riddle 1649abfd1a8bSRiver Riddle // The operation names are stored in-line, so to print them out for 1650abfd1a8bSRiver Riddle // debugging purposes we need to read the array before executing the 1651abfd1a8bSRiver Riddle // switch so that we can display all of the possible values. 
1652abfd1a8bSRiver Riddle LLVM_DEBUG({ 1653abfd1a8bSRiver Riddle const ByteCodeField *prevCodeIt = curCodeIt; 1654abfd1a8bSRiver Riddle llvm::dbgs() << " * Value: " << value << "\n" 1655abfd1a8bSRiver Riddle << " * Cases: "; 1656abfd1a8bSRiver Riddle llvm::interleaveComma( 1657abfd1a8bSRiver Riddle llvm::map_range(llvm::seq<size_t>(0, caseCount), 1658154cabe7SRiver Riddle [&](size_t) { return read<OperationName>(); }), 1659abfd1a8bSRiver Riddle llvm::dbgs()); 1660154cabe7SRiver Riddle llvm::dbgs() << "\n"; 1661abfd1a8bSRiver Riddle curCodeIt = prevCodeIt; 1662abfd1a8bSRiver Riddle }); 1663abfd1a8bSRiver Riddle 1664abfd1a8bSRiver Riddle // Try to find the switch value within any of the cases. 1665abfd1a8bSRiver Riddle for (size_t i = 0; i != caseCount; ++i) { 1666abfd1a8bSRiver Riddle if (read<OperationName>() == value) { 1667abfd1a8bSRiver Riddle curCodeIt += (caseCount - i - 1); 1668154cabe7SRiver Riddle return selectJump(i + 1); 1669abfd1a8bSRiver Riddle } 1670abfd1a8bSRiver Riddle } 1671154cabe7SRiver Riddle selectJump(size_t(0)); 1672abfd1a8bSRiver Riddle } 1673154cabe7SRiver Riddle 1674154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchResultCount() { 1675abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n"); 1676abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1677abfd1a8bSRiver Riddle auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1678abfd1a8bSRiver Riddle 1679abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1680abfd1a8bSRiver Riddle handleSwitch(op->getNumResults(), cases); 1681abfd1a8bSRiver Riddle } 1682154cabe7SRiver Riddle 1683154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchType() { 1684abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n"); 1685abfd1a8bSRiver Riddle Type value = read<Type>(); 1686abfd1a8bSRiver Riddle auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>(); 1687abfd1a8bSRiver Riddle handleSwitch(value, 
cases); 1688154cabe7SRiver Riddle } 1689154cabe7SRiver Riddle 1690*85ab413bSRiver Riddle void ByteCodeExecutor::executeSwitchTypes() { 1691*85ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n"); 1692*85ab413bSRiver Riddle TypeRange *value = read<TypeRange *>(); 1693*85ab413bSRiver Riddle auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>(); 1694*85ab413bSRiver Riddle if (!value) { 1695*85ab413bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n"); 1696*85ab413bSRiver Riddle return selectJump(size_t(0)); 1697*85ab413bSRiver Riddle } 1698*85ab413bSRiver Riddle handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) { 1699*85ab413bSRiver Riddle return value == caseValue.getAsValueRange<TypeAttr>(); 1700*85ab413bSRiver Riddle }); 1701*85ab413bSRiver Riddle } 1702*85ab413bSRiver Riddle 1703154cabe7SRiver Riddle void ByteCodeExecutor::execute( 1704154cabe7SRiver Riddle PatternRewriter &rewriter, 1705154cabe7SRiver Riddle SmallVectorImpl<PDLByteCode::MatchResult> *matches, 1706154cabe7SRiver Riddle Optional<Location> mainRewriteLoc) { 1707154cabe7SRiver Riddle while (true) { 1708154cabe7SRiver Riddle OpCode opCode = static_cast<OpCode>(read()); 1709154cabe7SRiver Riddle switch (opCode) { 1710154cabe7SRiver Riddle case ApplyConstraint: 1711154cabe7SRiver Riddle executeApplyConstraint(rewriter); 1712154cabe7SRiver Riddle break; 1713154cabe7SRiver Riddle case ApplyRewrite: 1714154cabe7SRiver Riddle executeApplyRewrite(rewriter); 1715154cabe7SRiver Riddle break; 1716154cabe7SRiver Riddle case AreEqual: 1717154cabe7SRiver Riddle executeAreEqual(); 1718154cabe7SRiver Riddle break; 1719*85ab413bSRiver Riddle case AreRangesEqual: 1720*85ab413bSRiver Riddle executeAreRangesEqual(); 1721*85ab413bSRiver Riddle break; 1722154cabe7SRiver Riddle case Branch: 1723154cabe7SRiver Riddle executeBranch(); 1724154cabe7SRiver Riddle break; 1725154cabe7SRiver Riddle case CheckOperandCount: 1726154cabe7SRiver Riddle executeCheckOperandCount(); 
1727154cabe7SRiver Riddle break; 1728154cabe7SRiver Riddle case CheckOperationName: 1729154cabe7SRiver Riddle executeCheckOperationName(); 1730154cabe7SRiver Riddle break; 1731154cabe7SRiver Riddle case CheckResultCount: 1732154cabe7SRiver Riddle executeCheckResultCount(); 1733154cabe7SRiver Riddle break; 1734*85ab413bSRiver Riddle case CheckTypes: 1735*85ab413bSRiver Riddle executeCheckTypes(); 1736*85ab413bSRiver Riddle break; 1737154cabe7SRiver Riddle case CreateOperation: 1738154cabe7SRiver Riddle executeCreateOperation(rewriter, *mainRewriteLoc); 1739154cabe7SRiver Riddle break; 1740*85ab413bSRiver Riddle case CreateTypes: 1741*85ab413bSRiver Riddle executeCreateTypes(); 1742*85ab413bSRiver Riddle break; 1743154cabe7SRiver Riddle case EraseOp: 1744154cabe7SRiver Riddle executeEraseOp(rewriter); 1745154cabe7SRiver Riddle break; 1746154cabe7SRiver Riddle case Finalize: 1747154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n"); 1748154cabe7SRiver Riddle return; 1749154cabe7SRiver Riddle case GetAttribute: 1750154cabe7SRiver Riddle executeGetAttribute(); 1751154cabe7SRiver Riddle break; 1752154cabe7SRiver Riddle case GetAttributeType: 1753154cabe7SRiver Riddle executeGetAttributeType(); 1754154cabe7SRiver Riddle break; 1755154cabe7SRiver Riddle case GetDefiningOp: 1756154cabe7SRiver Riddle executeGetDefiningOp(); 1757154cabe7SRiver Riddle break; 1758154cabe7SRiver Riddle case GetOperand0: 1759154cabe7SRiver Riddle case GetOperand1: 1760154cabe7SRiver Riddle case GetOperand2: 1761154cabe7SRiver Riddle case GetOperand3: { 1762154cabe7SRiver Riddle unsigned index = opCode - GetOperand0; 1763154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n"); 17641fff7c89SFrederik Gossen executeGetOperand(index); 1765abfd1a8bSRiver Riddle break; 1766abfd1a8bSRiver Riddle } 1767154cabe7SRiver Riddle case GetOperandN: 1768154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n"); 1769154cabe7SRiver Riddle 
executeGetOperand(read<uint32_t>()); 1770154cabe7SRiver Riddle break; 1771*85ab413bSRiver Riddle case GetOperands: 1772*85ab413bSRiver Riddle executeGetOperands(); 1773*85ab413bSRiver Riddle break; 1774154cabe7SRiver Riddle case GetResult0: 1775154cabe7SRiver Riddle case GetResult1: 1776154cabe7SRiver Riddle case GetResult2: 1777154cabe7SRiver Riddle case GetResult3: { 1778154cabe7SRiver Riddle unsigned index = opCode - GetResult0; 1779154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n"); 17801fff7c89SFrederik Gossen executeGetResult(index); 1781154cabe7SRiver Riddle break; 1782abfd1a8bSRiver Riddle } 1783154cabe7SRiver Riddle case GetResultN: 1784154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n"); 1785154cabe7SRiver Riddle executeGetResult(read<uint32_t>()); 1786154cabe7SRiver Riddle break; 1787*85ab413bSRiver Riddle case GetResults: 1788*85ab413bSRiver Riddle executeGetResults(); 1789*85ab413bSRiver Riddle break; 1790154cabe7SRiver Riddle case GetValueType: 1791154cabe7SRiver Riddle executeGetValueType(); 1792154cabe7SRiver Riddle break; 1793*85ab413bSRiver Riddle case GetValueRangeTypes: 1794*85ab413bSRiver Riddle executeGetValueRangeTypes(); 1795*85ab413bSRiver Riddle break; 1796154cabe7SRiver Riddle case IsNotNull: 1797154cabe7SRiver Riddle executeIsNotNull(); 1798154cabe7SRiver Riddle break; 1799154cabe7SRiver Riddle case RecordMatch: 1800154cabe7SRiver Riddle assert(matches && 1801154cabe7SRiver Riddle "expected matches to be provided when executing the matcher"); 1802154cabe7SRiver Riddle executeRecordMatch(rewriter, *matches); 1803154cabe7SRiver Riddle break; 1804154cabe7SRiver Riddle case ReplaceOp: 1805154cabe7SRiver Riddle executeReplaceOp(rewriter); 1806154cabe7SRiver Riddle break; 1807154cabe7SRiver Riddle case SwitchAttribute: 1808154cabe7SRiver Riddle executeSwitchAttribute(); 1809154cabe7SRiver Riddle break; 1810154cabe7SRiver Riddle case SwitchOperandCount: 1811154cabe7SRiver Riddle 
executeSwitchOperandCount(); 1812154cabe7SRiver Riddle break; 1813154cabe7SRiver Riddle case SwitchOperationName: 1814154cabe7SRiver Riddle executeSwitchOperationName(); 1815154cabe7SRiver Riddle break; 1816154cabe7SRiver Riddle case SwitchResultCount: 1817154cabe7SRiver Riddle executeSwitchResultCount(); 1818154cabe7SRiver Riddle break; 1819154cabe7SRiver Riddle case SwitchType: 1820154cabe7SRiver Riddle executeSwitchType(); 1821154cabe7SRiver Riddle break; 1822*85ab413bSRiver Riddle case SwitchTypes: 1823*85ab413bSRiver Riddle executeSwitchTypes(); 1824*85ab413bSRiver Riddle break; 1825154cabe7SRiver Riddle } 1826154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "\n"); 1827abfd1a8bSRiver Riddle } 1828abfd1a8bSRiver Riddle } 1829abfd1a8bSRiver Riddle 1830abfd1a8bSRiver Riddle /// Run the pattern matcher on the given root operation, collecting the matched 1831abfd1a8bSRiver Riddle /// patterns in `matches`. 1832abfd1a8bSRiver Riddle void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, 1833abfd1a8bSRiver Riddle SmallVectorImpl<MatchResult> &matches, 1834abfd1a8bSRiver Riddle PDLByteCodeMutableState &state) const { 1835abfd1a8bSRiver Riddle // The first memory slot is always the root operation. 1836abfd1a8bSRiver Riddle state.memory[0] = op; 1837abfd1a8bSRiver Riddle 1838abfd1a8bSRiver Riddle // The matcher function always starts at code address 0. 1839*85ab413bSRiver Riddle ByteCodeExecutor executor( 1840*85ab413bSRiver Riddle matcherByteCode.data(), state.memory, state.typeRangeMemory, 1841*85ab413bSRiver Riddle state.allocatedTypeRangeMemory, state.valueRangeMemory, 1842*85ab413bSRiver Riddle state.allocatedValueRangeMemory, uniquedData, matcherByteCode, 1843*85ab413bSRiver Riddle state.currentPatternBenefits, patterns, constraintFunctions, 1844*85ab413bSRiver Riddle rewriteFunctions); 1845abfd1a8bSRiver Riddle executor.execute(rewriter, &matches); 1846abfd1a8bSRiver Riddle 1847abfd1a8bSRiver Riddle // Order the found matches by benefit. 
1848abfd1a8bSRiver Riddle std::stable_sort(matches.begin(), matches.end(), 1849abfd1a8bSRiver Riddle [](const MatchResult &lhs, const MatchResult &rhs) { 1850abfd1a8bSRiver Riddle return lhs.benefit > rhs.benefit; 1851abfd1a8bSRiver Riddle }); 1852abfd1a8bSRiver Riddle } 1853abfd1a8bSRiver Riddle 1854abfd1a8bSRiver Riddle /// Run the rewriter of the given pattern on the root operation `op`. 1855abfd1a8bSRiver Riddle void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, 1856abfd1a8bSRiver Riddle PDLByteCodeMutableState &state) const { 1857abfd1a8bSRiver Riddle // The arguments of the rewrite function are stored at the start of the 1858abfd1a8bSRiver Riddle // memory buffer. 1859abfd1a8bSRiver Riddle llvm::copy(match.values, state.memory.begin()); 1860abfd1a8bSRiver Riddle 1861*85ab413bSRiver Riddle ByteCodeExecutor executor( 1862*85ab413bSRiver Riddle &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, 1863*85ab413bSRiver Riddle state.typeRangeMemory, state.allocatedTypeRangeMemory, 1864*85ab413bSRiver Riddle state.valueRangeMemory, state.allocatedValueRangeMemory, uniquedData, 1865*85ab413bSRiver Riddle rewriterByteCode, state.currentPatternBenefits, patterns, 186602c4c0d5SRiver Riddle constraintFunctions, rewriteFunctions); 1867abfd1a8bSRiver Riddle executor.execute(rewriter, /*matches=*/nullptr, match.location); 1868abfd1a8bSRiver Riddle } 1869