1 //===- AsmParserState.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/AsmParser/AsmParserState.h"
10 #include "mlir/IR/Operation.h"
11 #include "mlir/IR/SymbolTable.h"
12 #include "llvm/ADT/StringExtras.h"
13
14 using namespace mlir;
15
16 //===----------------------------------------------------------------------===//
17 // AsmParserState::Impl
18 //===----------------------------------------------------------------------===//
19
20 struct AsmParserState::Impl {
21 /// A map from a SymbolRefAttr to a range of uses.
22 using SymbolUseMap =
23 DenseMap<Attribute, SmallVector<SmallVector<SMRange>, 0>>;
24
25 struct PartialOpDef {
PartialOpDefAsmParserState::Impl::PartialOpDef26 explicit PartialOpDef(const OperationName &opName) {
27 if (opName.hasTrait<OpTrait::SymbolTable>())
28 symbolTable = std::make_unique<SymbolUseMap>();
29 }
30
31 /// Return if this operation is a symbol table.
isSymbolTableAsmParserState::Impl::PartialOpDef32 bool isSymbolTable() const { return symbolTable.get(); }
33
34 /// If this operation is a symbol table, the following contains symbol uses
35 /// within this operation.
36 std::unique_ptr<SymbolUseMap> symbolTable;
37 };
38
39 /// Resolve any symbol table uses in the IR.
40 void resolveSymbolUses();
41
42 /// A mapping from operations in the input source file to their parser state.
43 SmallVector<std::unique_ptr<OperationDefinition>> operations;
44 DenseMap<Operation *, unsigned> operationToIdx;
45
46 /// A mapping from blocks in the input source file to their parser state.
47 SmallVector<std::unique_ptr<BlockDefinition>> blocks;
48 DenseMap<Block *, unsigned> blocksToIdx;
49
50 /// A set of value definitions that are placeholders for forward references.
51 /// This map should be empty if the parser finishes successfully.
52 DenseMap<Value, SmallVector<SMLoc>> placeholderValueUses;
53
54 /// The symbol table operations within the IR.
55 SmallVector<std::pair<Operation *, std::unique_ptr<SymbolUseMap>>>
56 symbolTableOperations;
57
58 /// A stack of partial operation definitions that have been started but not
59 /// yet finalized.
60 SmallVector<PartialOpDef> partialOperations;
61
62 /// A stack of symbol use scopes. This is used when collecting symbol table
63 /// uses during parsing.
64 SmallVector<SymbolUseMap *> symbolUseScopes;
65
66 /// A symbol table containing all of the symbol table operations in the IR.
67 SymbolTableCollection symbolTable;
68 };
69
resolveSymbolUses()70 void AsmParserState::Impl::resolveSymbolUses() {
71 SmallVector<Operation *> symbolOps;
72 for (auto &opAndUseMapIt : symbolTableOperations) {
73 for (auto &it : *opAndUseMapIt.second) {
74 symbolOps.clear();
75 if (failed(symbolTable.lookupSymbolIn(
76 opAndUseMapIt.first, it.first.cast<SymbolRefAttr>(), symbolOps)))
77 continue;
78
79 for (ArrayRef<SMRange> useRange : it.second) {
80 for (const auto &symIt : llvm::zip(symbolOps, useRange)) {
81 auto opIt = operationToIdx.find(std::get<0>(symIt));
82 if (opIt != operationToIdx.end())
83 operations[opIt->second]->symbolUses.push_back(std::get<1>(symIt));
84 }
85 }
86 }
87 }
88 }
89
90 //===----------------------------------------------------------------------===//
91 // AsmParserState
92 //===----------------------------------------------------------------------===//
93
AsmParserState()94 AsmParserState::AsmParserState() : impl(std::make_unique<Impl>()) {}
95 AsmParserState::~AsmParserState() = default;
operator =(AsmParserState && other)96 AsmParserState &AsmParserState::operator=(AsmParserState &&other) {
97 impl = std::move(other.impl);
98 return *this;
99 }
100
101 //===----------------------------------------------------------------------===//
102 // Access State
103
getBlockDefs() const104 auto AsmParserState::getBlockDefs() const -> iterator_range<BlockDefIterator> {
105 return llvm::make_pointee_range(llvm::makeArrayRef(impl->blocks));
106 }
107
getBlockDef(Block * block) const108 auto AsmParserState::getBlockDef(Block *block) const
109 -> const BlockDefinition * {
110 auto it = impl->blocksToIdx.find(block);
111 return it == impl->blocksToIdx.end() ? nullptr : &*impl->blocks[it->second];
112 }
113
getOpDefs() const114 auto AsmParserState::getOpDefs() const -> iterator_range<OperationDefIterator> {
115 return llvm::make_pointee_range(llvm::makeArrayRef(impl->operations));
116 }
117
getOpDef(Operation * op) const118 auto AsmParserState::getOpDef(Operation *op) const
119 -> const OperationDefinition * {
120 auto it = impl->operationToIdx.find(op);
121 return it == impl->operationToIdx.end() ? nullptr
122 : &*impl->operations[it->second];
123 }
124
125 /// Lex a string token whose contents start at the given `curPtr`. Returns the
126 /// position at the end of the string, after a terminal or invalid character
127 /// (e.g. `"` or `\0`).
lexLocStringTok(const char * curPtr)128 static const char *lexLocStringTok(const char *curPtr) {
129 while (char c = *curPtr++) {
130 // Check for various terminal characters.
131 if (StringRef("\"\n\v\f").contains(c))
132 return curPtr;
133
134 // Check for escape sequences.
135 if (c == '\\') {
136 // Check a few known escapes and \xx hex digits.
137 if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' || *curPtr == 't')
138 ++curPtr;
139 else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1]))
140 curPtr += 2;
141 else
142 return curPtr;
143 }
144 }
145
146 // If we hit this point, we've reached the end of the buffer. Update the end
147 // pointer to not point past the buffer.
148 return curPtr - 1;
149 }
150
convertIdLocToRange(SMLoc loc)151 SMRange AsmParserState::convertIdLocToRange(SMLoc loc) {
152 if (!loc.isValid())
153 return SMRange();
154 const char *curPtr = loc.getPointer();
155
156 // Check if this is a string token.
157 if (*curPtr == '"') {
158 curPtr = lexLocStringTok(curPtr + 1);
159
160 // Otherwise, default to handling an identifier.
161 } else {
162 // Return if the given character is a valid identifier character.
163 auto isIdentifierChar = [](char c) {
164 return isalnum(c) || c == '$' || c == '.' || c == '_' || c == '-';
165 };
166
167 while (*curPtr && isIdentifierChar(*(++curPtr)))
168 continue;
169 }
170
171 return SMRange(loc, SMLoc::getFromPointer(curPtr));
172 }
173
174 //===----------------------------------------------------------------------===//
175 // Populate State
176
initialize(Operation * topLevelOp)177 void AsmParserState::initialize(Operation *topLevelOp) {
178 startOperationDefinition(topLevelOp->getName());
179
180 // If the top-level operation is a symbol table, push a new symbol scope.
181 Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
182 if (partialOpDef.isSymbolTable())
183 impl->symbolUseScopes.push_back(partialOpDef.symbolTable.get());
184 }
185
finalize(Operation * topLevelOp)186 void AsmParserState::finalize(Operation *topLevelOp) {
187 assert(!impl->partialOperations.empty() &&
188 "expected valid partial operation definition");
189 Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val();
190
191 // If this operation is a symbol table, resolve any symbol uses.
192 if (partialOpDef.isSymbolTable()) {
193 impl->symbolTableOperations.emplace_back(
194 topLevelOp, std::move(partialOpDef.symbolTable));
195 }
196 impl->resolveSymbolUses();
197 }
198
startOperationDefinition(const OperationName & opName)199 void AsmParserState::startOperationDefinition(const OperationName &opName) {
200 impl->partialOperations.emplace_back(opName);
201 }
202
finalizeOperationDefinition(Operation * op,SMRange nameLoc,SMLoc endLoc,ArrayRef<std::pair<unsigned,SMLoc>> resultGroups)203 void AsmParserState::finalizeOperationDefinition(
204 Operation *op, SMRange nameLoc, SMLoc endLoc,
205 ArrayRef<std::pair<unsigned, SMLoc>> resultGroups) {
206 assert(!impl->partialOperations.empty() &&
207 "expected valid partial operation definition");
208 Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val();
209
210 // Build the full operation definition.
211 std::unique_ptr<OperationDefinition> def =
212 std::make_unique<OperationDefinition>(op, nameLoc, endLoc);
213 for (auto &resultGroup : resultGroups)
214 def->resultGroups.emplace_back(resultGroup.first,
215 convertIdLocToRange(resultGroup.second));
216 impl->operationToIdx.try_emplace(op, impl->operations.size());
217 impl->operations.emplace_back(std::move(def));
218
219 // If this operation is a symbol table, resolve any symbol uses.
220 if (partialOpDef.isSymbolTable()) {
221 impl->symbolTableOperations.emplace_back(
222 op, std::move(partialOpDef.symbolTable));
223 }
224 }
225
startRegionDefinition()226 void AsmParserState::startRegionDefinition() {
227 assert(!impl->partialOperations.empty() &&
228 "expected valid partial operation definition");
229
230 // If the parent operation of this region is a symbol table, we also push a
231 // new symbol scope.
232 Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
233 if (partialOpDef.isSymbolTable())
234 impl->symbolUseScopes.push_back(partialOpDef.symbolTable.get());
235 }
236
finalizeRegionDefinition()237 void AsmParserState::finalizeRegionDefinition() {
238 assert(!impl->partialOperations.empty() &&
239 "expected valid partial operation definition");
240
241 // If the parent operation of this region is a symbol table, pop the symbol
242 // scope for this region.
243 Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
244 if (partialOpDef.isSymbolTable())
245 impl->symbolUseScopes.pop_back();
246 }
247
addDefinition(Block * block,SMLoc location)248 void AsmParserState::addDefinition(Block *block, SMLoc location) {
249 auto it = impl->blocksToIdx.find(block);
250 if (it == impl->blocksToIdx.end()) {
251 impl->blocksToIdx.try_emplace(block, impl->blocks.size());
252 impl->blocks.emplace_back(std::make_unique<BlockDefinition>(
253 block, convertIdLocToRange(location)));
254 return;
255 }
256
257 // If an entry already exists, this was a forward declaration that now has a
258 // proper definition.
259 impl->blocks[it->second]->definition.loc = convertIdLocToRange(location);
260 }
261
addDefinition(BlockArgument blockArg,SMLoc location)262 void AsmParserState::addDefinition(BlockArgument blockArg, SMLoc location) {
263 auto it = impl->blocksToIdx.find(blockArg.getOwner());
264 assert(it != impl->blocksToIdx.end() &&
265 "expected owner block to have an entry");
266 BlockDefinition &def = *impl->blocks[it->second];
267 unsigned argIdx = blockArg.getArgNumber();
268
269 if (def.arguments.size() <= argIdx)
270 def.arguments.resize(argIdx + 1);
271 def.arguments[argIdx] = SMDefinition(convertIdLocToRange(location));
272 }
273
addUses(Value value,ArrayRef<SMLoc> locations)274 void AsmParserState::addUses(Value value, ArrayRef<SMLoc> locations) {
275 // Handle the case where the value is an operation result.
276 if (OpResult result = value.dyn_cast<OpResult>()) {
277 // Check to see if a definition for the parent operation has been recorded.
278 // If one hasn't, we treat the provided value as a placeholder value that
279 // will be refined further later.
280 Operation *parentOp = result.getOwner();
281 auto existingIt = impl->operationToIdx.find(parentOp);
282 if (existingIt == impl->operationToIdx.end()) {
283 impl->placeholderValueUses[value].append(locations.begin(),
284 locations.end());
285 return;
286 }
287
288 // If a definition does exist, locate the value's result group and add the
289 // use. The result groups are ordered by increasing start index, so we just
290 // need to find the last group that has a smaller/equal start index.
291 unsigned resultNo = result.getResultNumber();
292 OperationDefinition &def = *impl->operations[existingIt->second];
293 for (auto &resultGroup : llvm::reverse(def.resultGroups)) {
294 if (resultNo >= resultGroup.startIndex) {
295 for (SMLoc loc : locations)
296 resultGroup.definition.uses.push_back(convertIdLocToRange(loc));
297 return;
298 }
299 }
300 llvm_unreachable("expected valid result group for value use");
301 }
302
303 // Otherwise, this is a block argument.
304 BlockArgument arg = value.cast<BlockArgument>();
305 auto existingIt = impl->blocksToIdx.find(arg.getOwner());
306 assert(existingIt != impl->blocksToIdx.end() &&
307 "expected valid block definition for block argument");
308 BlockDefinition &blockDef = *impl->blocks[existingIt->second];
309 SMDefinition &argDef = blockDef.arguments[arg.getArgNumber()];
310 for (SMLoc loc : locations)
311 argDef.uses.emplace_back(convertIdLocToRange(loc));
312 }
313
addUses(Block * block,ArrayRef<SMLoc> locations)314 void AsmParserState::addUses(Block *block, ArrayRef<SMLoc> locations) {
315 auto it = impl->blocksToIdx.find(block);
316 if (it == impl->blocksToIdx.end()) {
317 it = impl->blocksToIdx.try_emplace(block, impl->blocks.size()).first;
318 impl->blocks.emplace_back(std::make_unique<BlockDefinition>(block));
319 }
320
321 BlockDefinition &def = *impl->blocks[it->second];
322 for (SMLoc loc : locations)
323 def.definition.uses.push_back(convertIdLocToRange(loc));
324 }
325
addUses(SymbolRefAttr refAttr,ArrayRef<SMRange> locations)326 void AsmParserState::addUses(SymbolRefAttr refAttr,
327 ArrayRef<SMRange> locations) {
328 // Ignore this symbol if no scopes are active.
329 if (impl->symbolUseScopes.empty())
330 return;
331
332 assert((refAttr.getNestedReferences().size() + 1) == locations.size() &&
333 "expected the same number of references as provided locations");
334 (*impl->symbolUseScopes.back())[refAttr].emplace_back(locations.begin(),
335 locations.end());
336 }
337
refineDefinition(Value oldValue,Value newValue)338 void AsmParserState::refineDefinition(Value oldValue, Value newValue) {
339 auto it = impl->placeholderValueUses.find(oldValue);
340 assert(it != impl->placeholderValueUses.end() &&
341 "expected `oldValue` to be a placeholder");
342 addUses(newValue, it->second);
343 impl->placeholderValueUses.erase(oldValue);
344 }
345