1 //===- AsmParserState.cpp -------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/AsmParser/AsmParserState.h" 10 #include "mlir/IR/Operation.h" 11 #include "mlir/IR/SymbolTable.h" 12 #include "llvm/ADT/StringExtras.h" 13 14 using namespace mlir; 15 16 //===----------------------------------------------------------------------===// 17 // AsmParserState::Impl 18 //===----------------------------------------------------------------------===// 19 20 struct AsmParserState::Impl { 21 /// A map from a SymbolRefAttr to a range of uses. 22 using SymbolUseMap = 23 DenseMap<Attribute, SmallVector<SmallVector<SMRange>, 0>>; 24 25 struct PartialOpDef { 26 explicit PartialOpDef(const OperationName &opName) { 27 if (opName.hasTrait<OpTrait::SymbolTable>()) 28 symbolTable = std::make_unique<SymbolUseMap>(); 29 } 30 31 /// Return if this operation is a symbol table. 32 bool isSymbolTable() const { return symbolTable.get(); } 33 34 /// If this operation is a symbol table, the following contains symbol uses 35 /// within this operation. 36 std::unique_ptr<SymbolUseMap> symbolTable; 37 }; 38 39 /// Resolve any symbol table uses in the IR. 40 void resolveSymbolUses(); 41 42 /// A mapping from operations in the input source file to their parser state. 43 SmallVector<std::unique_ptr<OperationDefinition>> operations; 44 DenseMap<Operation *, unsigned> operationToIdx; 45 46 /// A mapping from blocks in the input source file to their parser state. 47 SmallVector<std::unique_ptr<BlockDefinition>> blocks; 48 DenseMap<Block *, unsigned> blocksToIdx; 49 50 /// A set of value definitions that are placeholders for forward references. 51 /// This map should be empty if the parser finishes successfully. 52 DenseMap<Value, SmallVector<SMLoc>> placeholderValueUses; 53 54 /// The symbol table operations within the IR. 55 SmallVector<std::pair<Operation *, std::unique_ptr<SymbolUseMap>>> 56 symbolTableOperations; 57 58 /// A stack of partial operation definitions that have been started but not 59 /// yet finalized. 60 SmallVector<PartialOpDef> partialOperations; 61 62 /// A stack of symbol use scopes. This is used when collecting symbol table 63 /// uses during parsing. 64 SmallVector<SymbolUseMap *> symbolUseScopes; 65 66 /// A symbol table containing all of the symbol table operations in the IR. 67 SymbolTableCollection symbolTable; 68 }; 69 70 void AsmParserState::Impl::resolveSymbolUses() { 71 SmallVector<Operation *> symbolOps; 72 for (auto &opAndUseMapIt : symbolTableOperations) { 73 for (auto &it : *opAndUseMapIt.second) { 74 symbolOps.clear(); 75 if (failed(symbolTable.lookupSymbolIn( 76 opAndUseMapIt.first, it.first.cast<SymbolRefAttr>(), symbolOps))) 77 continue; 78 79 for (ArrayRef<SMRange> useRange : it.second) { 80 for (const auto &symIt : llvm::zip(symbolOps, useRange)) { 81 auto opIt = operationToIdx.find(std::get<0>(symIt)); 82 if (opIt != operationToIdx.end()) 83 operations[opIt->second]->symbolUses.push_back(std::get<1>(symIt)); 84 } 85 } 86 } 87 } 88 } 89 90 //===----------------------------------------------------------------------===// 91 // AsmParserState 92 //===----------------------------------------------------------------------===// 93 94 AsmParserState::AsmParserState() : impl(std::make_unique<Impl>()) {} 95 AsmParserState::~AsmParserState() = default; 96 AsmParserState &AsmParserState::operator=(AsmParserState &&other) { 97 impl = std::move(other.impl); 98 return *this; 99 } 100 101 //===----------------------------------------------------------------------===// 102 // Access State 103 104 auto AsmParserState::getBlockDefs() const -> iterator_range<BlockDefIterator> { 105 return llvm::make_pointee_range(llvm::makeArrayRef(impl->blocks)); 106 } 107 108 auto AsmParserState::getBlockDef(Block *block) const 109 -> const BlockDefinition * { 110 auto it = impl->blocksToIdx.find(block); 111 return it == impl->blocksToIdx.end() ? nullptr : &*impl->blocks[it->second]; 112 } 113 114 auto AsmParserState::getOpDefs() const -> iterator_range<OperationDefIterator> { 115 return llvm::make_pointee_range(llvm::makeArrayRef(impl->operations)); 116 } 117 118 auto AsmParserState::getOpDef(Operation *op) const 119 -> const OperationDefinition * { 120 auto it = impl->operationToIdx.find(op); 121 return it == impl->operationToIdx.end() ? nullptr 122 : &*impl->operations[it->second]; 123 } 124 125 /// Lex a string token whose contents start at the given `curPtr`. Returns the 126 /// position at the end of the string, after a terminal or invalid character 127 /// (e.g. `"` or `\0`). 128 static const char *lexLocStringTok(const char *curPtr) { 129 while (char c = *curPtr++) { 130 // Check for various terminal characters. 131 if (StringRef("\"\n\v\f").contains(c)) 132 return curPtr; 133 134 // Check for escape sequences. 135 if (c == '\\') { 136 // Check a few known escapes and \xx hex digits. 137 if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' || *curPtr == 't') 138 ++curPtr; 139 else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1])) 140 curPtr += 2; 141 else 142 return curPtr; 143 } 144 } 145 146 // If we hit this point, we've reached the end of the buffer. Update the end 147 // pointer to not point past the buffer. 148 return curPtr - 1; 149 } 150 151 SMRange AsmParserState::convertIdLocToRange(SMLoc loc) { 152 if (!loc.isValid()) 153 return SMRange(); 154 const char *curPtr = loc.getPointer(); 155 156 // Check if this is a string token. 157 if (*curPtr == '"') { 158 curPtr = lexLocStringTok(curPtr + 1); 159 160 // Otherwise, default to handling an identifier. 161 } else { 162 // Return if the given character is a valid identifier character. 163 auto isIdentifierChar = [](char c) { 164 return isalnum(c) || c == '$' || c == '.' || c == '_' || c == '-'; 165 }; 166 167 while (*curPtr && isIdentifierChar(*(++curPtr))) 168 continue; 169 } 170 171 return SMRange(loc, SMLoc::getFromPointer(curPtr)); 172 } 173 174 //===----------------------------------------------------------------------===// 175 // Populate State 176 177 void AsmParserState::initialize(Operation *topLevelOp) { 178 startOperationDefinition(topLevelOp->getName()); 179 180 // If the top-level operation is a symbol table, push a new symbol scope. 181 Impl::PartialOpDef &partialOpDef = impl->partialOperations.back(); 182 if (partialOpDef.isSymbolTable()) 183 impl->symbolUseScopes.push_back(partialOpDef.symbolTable.get()); 184 } 185 186 void AsmParserState::finalize(Operation *topLevelOp) { 187 assert(!impl->partialOperations.empty() && 188 "expected valid partial operation definition"); 189 Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val(); 190 191 // If this operation is a symbol table, resolve any symbol uses. 192 if (partialOpDef.isSymbolTable()) { 193 impl->symbolTableOperations.emplace_back( 194 topLevelOp, std::move(partialOpDef.symbolTable)); 195 } 196 impl->resolveSymbolUses(); 197 } 198 199 void AsmParserState::startOperationDefinition(const OperationName &opName) { 200 impl->partialOperations.emplace_back(opName); 201 } 202 203 void AsmParserState::finalizeOperationDefinition( 204 Operation *op, SMRange nameLoc, SMLoc endLoc, 205 ArrayRef<std::pair<unsigned, SMLoc>> resultGroups) { 206 assert(!impl->partialOperations.empty() && 207 "expected valid partial operation definition"); 208 Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val(); 209 210 // Build the full operation definition. 211 std::unique_ptr<OperationDefinition> def = 212 std::make_unique<OperationDefinition>(op, nameLoc, endLoc); 213 for (auto &resultGroup : resultGroups) 214 def->resultGroups.emplace_back(resultGroup.first, 215 convertIdLocToRange(resultGroup.second)); 216 impl->operationToIdx.try_emplace(op, impl->operations.size()); 217 impl->operations.emplace_back(std::move(def)); 218 219 // If this operation is a symbol table, resolve any symbol uses. 220 if (partialOpDef.isSymbolTable()) { 221 impl->symbolTableOperations.emplace_back( 222 op, std::move(partialOpDef.symbolTable)); 223 } 224 } 225 226 void AsmParserState::startRegionDefinition() { 227 assert(!impl->partialOperations.empty() && 228 "expected valid partial operation definition"); 229 230 // If the parent operation of this region is a symbol table, we also push a 231 // new symbol scope. 232 Impl::PartialOpDef &partialOpDef = impl->partialOperations.back(); 233 if (partialOpDef.isSymbolTable()) 234 impl->symbolUseScopes.push_back(partialOpDef.symbolTable.get()); 235 } 236 237 void AsmParserState::finalizeRegionDefinition() { 238 assert(!impl->partialOperations.empty() && 239 "expected valid partial operation definition"); 240 241 // If the parent operation of this region is a symbol table, pop the symbol 242 // scope for this region. 243 Impl::PartialOpDef &partialOpDef = impl->partialOperations.back(); 244 if (partialOpDef.isSymbolTable()) 245 impl->symbolUseScopes.pop_back(); 246 } 247 248 void AsmParserState::addDefinition(Block *block, SMLoc location) { 249 auto it = impl->blocksToIdx.find(block); 250 if (it == impl->blocksToIdx.end()) { 251 impl->blocksToIdx.try_emplace(block, impl->blocks.size()); 252 impl->blocks.emplace_back(std::make_unique<BlockDefinition>( 253 block, convertIdLocToRange(location))); 254 return; 255 } 256 257 // If an entry already exists, this was a forward declaration that now has a 258 // proper definition. 259 impl->blocks[it->second]->definition.loc = convertIdLocToRange(location); 260 } 261 262 void AsmParserState::addDefinition(BlockArgument blockArg, SMLoc location) { 263 auto it = impl->blocksToIdx.find(blockArg.getOwner()); 264 assert(it != impl->blocksToIdx.end() && 265 "expected owner block to have an entry"); 266 BlockDefinition &def = *impl->blocks[it->second]; 267 unsigned argIdx = blockArg.getArgNumber(); 268 269 if (def.arguments.size() <= argIdx) 270 def.arguments.resize(argIdx + 1); 271 def.arguments[argIdx] = SMDefinition(convertIdLocToRange(location)); 272 } 273 274 void AsmParserState::addUses(Value value, ArrayRef<SMLoc> locations) { 275 // Handle the case where the value is an operation result. 276 if (OpResult result = value.dyn_cast<OpResult>()) { 277 // Check to see if a definition for the parent operation has been recorded. 278 // If one hasn't, we treat the provided value as a placeholder value that 279 // will be refined further later. 280 Operation *parentOp = result.getOwner(); 281 auto existingIt = impl->operationToIdx.find(parentOp); 282 if (existingIt == impl->operationToIdx.end()) { 283 impl->placeholderValueUses[value].append(locations.begin(), 284 locations.end()); 285 return; 286 } 287 288 // If a definition does exist, locate the value's result group and add the 289 // use. The result groups are ordered by increasing start index, so we just 290 // need to find the last group that has a smaller/equal start index. 291 unsigned resultNo = result.getResultNumber(); 292 OperationDefinition &def = *impl->operations[existingIt->second]; 293 for (auto &resultGroup : llvm::reverse(def.resultGroups)) { 294 if (resultNo >= resultGroup.startIndex) { 295 for (SMLoc loc : locations) 296 resultGroup.definition.uses.push_back(convertIdLocToRange(loc)); 297 return; 298 } 299 } 300 llvm_unreachable("expected valid result group for value use"); 301 } 302 303 // Otherwise, this is a block argument. 304 BlockArgument arg = value.cast<BlockArgument>(); 305 auto existingIt = impl->blocksToIdx.find(arg.getOwner()); 306 assert(existingIt != impl->blocksToIdx.end() && 307 "expected valid block definition for block argument"); 308 BlockDefinition &blockDef = *impl->blocks[existingIt->second]; 309 SMDefinition &argDef = blockDef.arguments[arg.getArgNumber()]; 310 for (SMLoc loc : locations) 311 argDef.uses.emplace_back(convertIdLocToRange(loc)); 312 } 313 314 void AsmParserState::addUses(Block *block, ArrayRef<SMLoc> locations) { 315 auto it = impl->blocksToIdx.find(block); 316 if (it == impl->blocksToIdx.end()) { 317 it = impl->blocksToIdx.try_emplace(block, impl->blocks.size()).first; 318 impl->blocks.emplace_back(std::make_unique<BlockDefinition>(block)); 319 } 320 321 BlockDefinition &def = *impl->blocks[it->second]; 322 for (SMLoc loc : locations) 323 def.definition.uses.push_back(convertIdLocToRange(loc)); 324 } 325 326 void AsmParserState::addUses(SymbolRefAttr refAttr, 327 ArrayRef<SMRange> locations) { 328 // Ignore this symbol if no scopes are active. 329 if (impl->symbolUseScopes.empty()) 330 return; 331 332 assert((refAttr.getNestedReferences().size() + 1) == locations.size() && 333 "expected the same number of references as provided locations"); 334 (*impl->symbolUseScopes.back())[refAttr].emplace_back(locations.begin(), 335 locations.end()); 336 } 337 338 void AsmParserState::refineDefinition(Value oldValue, Value newValue) { 339 auto it = impl->placeholderValueUses.find(oldValue); 340 assert(it != impl->placeholderValueUses.end() && 341 "expected `oldValue` to be a placeholder"); 342 addUses(newValue, it->second); 343 impl->placeholderValueUses.erase(oldValue); 344 } 345