1 //===- AsmParserState.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/AsmParser/AsmParserState.h"
10 #include "mlir/IR/Operation.h"
11 #include "mlir/IR/SymbolTable.h"
12 #include "llvm/ADT/StringExtras.h"
13 
14 using namespace mlir;
15 
16 //===----------------------------------------------------------------------===//
17 // AsmParserState::Impl
18 //===----------------------------------------------------------------------===//
19 
20 struct AsmParserState::Impl {
21   /// A map from a SymbolRefAttr to a range of uses.
22   using SymbolUseMap =
23       DenseMap<Attribute, SmallVector<SmallVector<SMRange>, 0>>;
24 
25   struct PartialOpDef {
PartialOpDefAsmParserState::Impl::PartialOpDef26     explicit PartialOpDef(const OperationName &opName) {
27       if (opName.hasTrait<OpTrait::SymbolTable>())
28         symbolTable = std::make_unique<SymbolUseMap>();
29     }
30 
31     /// Return if this operation is a symbol table.
isSymbolTableAsmParserState::Impl::PartialOpDef32     bool isSymbolTable() const { return symbolTable.get(); }
33 
34     /// If this operation is a symbol table, the following contains symbol uses
35     /// within this operation.
36     std::unique_ptr<SymbolUseMap> symbolTable;
37   };
38 
39   /// Resolve any symbol table uses in the IR.
40   void resolveSymbolUses();
41 
42   /// A mapping from operations in the input source file to their parser state.
43   SmallVector<std::unique_ptr<OperationDefinition>> operations;
44   DenseMap<Operation *, unsigned> operationToIdx;
45 
46   /// A mapping from blocks in the input source file to their parser state.
47   SmallVector<std::unique_ptr<BlockDefinition>> blocks;
48   DenseMap<Block *, unsigned> blocksToIdx;
49 
50   /// A set of value definitions that are placeholders for forward references.
51   /// This map should be empty if the parser finishes successfully.
52   DenseMap<Value, SmallVector<SMLoc>> placeholderValueUses;
53 
54   /// The symbol table operations within the IR.
55   SmallVector<std::pair<Operation *, std::unique_ptr<SymbolUseMap>>>
56       symbolTableOperations;
57 
58   /// A stack of partial operation definitions that have been started but not
59   /// yet finalized.
60   SmallVector<PartialOpDef> partialOperations;
61 
62   /// A stack of symbol use scopes. This is used when collecting symbol table
63   /// uses during parsing.
64   SmallVector<SymbolUseMap *> symbolUseScopes;
65 
66   /// A symbol table containing all of the symbol table operations in the IR.
67   SymbolTableCollection symbolTable;
68 };
69 
resolveSymbolUses()70 void AsmParserState::Impl::resolveSymbolUses() {
71   SmallVector<Operation *> symbolOps;
72   for (auto &opAndUseMapIt : symbolTableOperations) {
73     for (auto &it : *opAndUseMapIt.second) {
74       symbolOps.clear();
75       if (failed(symbolTable.lookupSymbolIn(
76               opAndUseMapIt.first, it.first.cast<SymbolRefAttr>(), symbolOps)))
77         continue;
78 
79       for (ArrayRef<SMRange> useRange : it.second) {
80         for (const auto &symIt : llvm::zip(symbolOps, useRange)) {
81           auto opIt = operationToIdx.find(std::get<0>(symIt));
82           if (opIt != operationToIdx.end())
83             operations[opIt->second]->symbolUses.push_back(std::get<1>(symIt));
84         }
85       }
86     }
87   }
88 }
89 
90 //===----------------------------------------------------------------------===//
91 // AsmParserState
92 //===----------------------------------------------------------------------===//
93 
AsmParserState()94 AsmParserState::AsmParserState() : impl(std::make_unique<Impl>()) {}
95 AsmParserState::~AsmParserState() = default;
operator =(AsmParserState && other)96 AsmParserState &AsmParserState::operator=(AsmParserState &&other) {
97   impl = std::move(other.impl);
98   return *this;
99 }
100 
101 //===----------------------------------------------------------------------===//
102 // Access State
103 
getBlockDefs() const104 auto AsmParserState::getBlockDefs() const -> iterator_range<BlockDefIterator> {
105   return llvm::make_pointee_range(llvm::makeArrayRef(impl->blocks));
106 }
107 
getBlockDef(Block * block) const108 auto AsmParserState::getBlockDef(Block *block) const
109     -> const BlockDefinition * {
110   auto it = impl->blocksToIdx.find(block);
111   return it == impl->blocksToIdx.end() ? nullptr : &*impl->blocks[it->second];
112 }
113 
getOpDefs() const114 auto AsmParserState::getOpDefs() const -> iterator_range<OperationDefIterator> {
115   return llvm::make_pointee_range(llvm::makeArrayRef(impl->operations));
116 }
117 
getOpDef(Operation * op) const118 auto AsmParserState::getOpDef(Operation *op) const
119     -> const OperationDefinition * {
120   auto it = impl->operationToIdx.find(op);
121   return it == impl->operationToIdx.end() ? nullptr
122                                           : &*impl->operations[it->second];
123 }
124 
125 /// Lex a string token whose contents start at the given `curPtr`. Returns the
126 /// position at the end of the string, after a terminal or invalid character
127 /// (e.g. `"` or `\0`).
lexLocStringTok(const char * curPtr)128 static const char *lexLocStringTok(const char *curPtr) {
129   while (char c = *curPtr++) {
130     // Check for various terminal characters.
131     if (StringRef("\"\n\v\f").contains(c))
132       return curPtr;
133 
134     // Check for escape sequences.
135     if (c == '\\') {
136       // Check a few known escapes and \xx hex digits.
137       if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' || *curPtr == 't')
138         ++curPtr;
139       else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1]))
140         curPtr += 2;
141       else
142         return curPtr;
143     }
144   }
145 
146   // If we hit this point, we've reached the end of the buffer. Update the end
147   // pointer to not point past the buffer.
148   return curPtr - 1;
149 }
150 
convertIdLocToRange(SMLoc loc)151 SMRange AsmParserState::convertIdLocToRange(SMLoc loc) {
152   if (!loc.isValid())
153     return SMRange();
154   const char *curPtr = loc.getPointer();
155 
156   // Check if this is a string token.
157   if (*curPtr == '"') {
158     curPtr = lexLocStringTok(curPtr + 1);
159 
160     // Otherwise, default to handling an identifier.
161   } else {
162     // Return if the given character is a valid identifier character.
163     auto isIdentifierChar = [](char c) {
164       return isalnum(c) || c == '$' || c == '.' || c == '_' || c == '-';
165     };
166 
167     while (*curPtr && isIdentifierChar(*(++curPtr)))
168       continue;
169   }
170 
171   return SMRange(loc, SMLoc::getFromPointer(curPtr));
172 }
173 
174 //===----------------------------------------------------------------------===//
175 // Populate State
176 
initialize(Operation * topLevelOp)177 void AsmParserState::initialize(Operation *topLevelOp) {
178   startOperationDefinition(topLevelOp->getName());
179 
180   // If the top-level operation is a symbol table, push a new symbol scope.
181   Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
182   if (partialOpDef.isSymbolTable())
183     impl->symbolUseScopes.push_back(partialOpDef.symbolTable.get());
184 }
185 
finalize(Operation * topLevelOp)186 void AsmParserState::finalize(Operation *topLevelOp) {
187   assert(!impl->partialOperations.empty() &&
188          "expected valid partial operation definition");
189   Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val();
190 
191   // If this operation is a symbol table, resolve any symbol uses.
192   if (partialOpDef.isSymbolTable()) {
193     impl->symbolTableOperations.emplace_back(
194         topLevelOp, std::move(partialOpDef.symbolTable));
195   }
196   impl->resolveSymbolUses();
197 }
198 
startOperationDefinition(const OperationName & opName)199 void AsmParserState::startOperationDefinition(const OperationName &opName) {
200   impl->partialOperations.emplace_back(opName);
201 }
202 
finalizeOperationDefinition(Operation * op,SMRange nameLoc,SMLoc endLoc,ArrayRef<std::pair<unsigned,SMLoc>> resultGroups)203 void AsmParserState::finalizeOperationDefinition(
204     Operation *op, SMRange nameLoc, SMLoc endLoc,
205     ArrayRef<std::pair<unsigned, SMLoc>> resultGroups) {
206   assert(!impl->partialOperations.empty() &&
207          "expected valid partial operation definition");
208   Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val();
209 
210   // Build the full operation definition.
211   std::unique_ptr<OperationDefinition> def =
212       std::make_unique<OperationDefinition>(op, nameLoc, endLoc);
213   for (auto &resultGroup : resultGroups)
214     def->resultGroups.emplace_back(resultGroup.first,
215                                    convertIdLocToRange(resultGroup.second));
216   impl->operationToIdx.try_emplace(op, impl->operations.size());
217   impl->operations.emplace_back(std::move(def));
218 
219   // If this operation is a symbol table, resolve any symbol uses.
220   if (partialOpDef.isSymbolTable()) {
221     impl->symbolTableOperations.emplace_back(
222         op, std::move(partialOpDef.symbolTable));
223   }
224 }
225 
startRegionDefinition()226 void AsmParserState::startRegionDefinition() {
227   assert(!impl->partialOperations.empty() &&
228          "expected valid partial operation definition");
229 
230   // If the parent operation of this region is a symbol table, we also push a
231   // new symbol scope.
232   Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
233   if (partialOpDef.isSymbolTable())
234     impl->symbolUseScopes.push_back(partialOpDef.symbolTable.get());
235 }
236 
finalizeRegionDefinition()237 void AsmParserState::finalizeRegionDefinition() {
238   assert(!impl->partialOperations.empty() &&
239          "expected valid partial operation definition");
240 
241   // If the parent operation of this region is a symbol table, pop the symbol
242   // scope for this region.
243   Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
244   if (partialOpDef.isSymbolTable())
245     impl->symbolUseScopes.pop_back();
246 }
247 
addDefinition(Block * block,SMLoc location)248 void AsmParserState::addDefinition(Block *block, SMLoc location) {
249   auto it = impl->blocksToIdx.find(block);
250   if (it == impl->blocksToIdx.end()) {
251     impl->blocksToIdx.try_emplace(block, impl->blocks.size());
252     impl->blocks.emplace_back(std::make_unique<BlockDefinition>(
253         block, convertIdLocToRange(location)));
254     return;
255   }
256 
257   // If an entry already exists, this was a forward declaration that now has a
258   // proper definition.
259   impl->blocks[it->second]->definition.loc = convertIdLocToRange(location);
260 }
261 
addDefinition(BlockArgument blockArg,SMLoc location)262 void AsmParserState::addDefinition(BlockArgument blockArg, SMLoc location) {
263   auto it = impl->blocksToIdx.find(blockArg.getOwner());
264   assert(it != impl->blocksToIdx.end() &&
265          "expected owner block to have an entry");
266   BlockDefinition &def = *impl->blocks[it->second];
267   unsigned argIdx = blockArg.getArgNumber();
268 
269   if (def.arguments.size() <= argIdx)
270     def.arguments.resize(argIdx + 1);
271   def.arguments[argIdx] = SMDefinition(convertIdLocToRange(location));
272 }
273 
addUses(Value value,ArrayRef<SMLoc> locations)274 void AsmParserState::addUses(Value value, ArrayRef<SMLoc> locations) {
275   // Handle the case where the value is an operation result.
276   if (OpResult result = value.dyn_cast<OpResult>()) {
277     // Check to see if a definition for the parent operation has been recorded.
278     // If one hasn't, we treat the provided value as a placeholder value that
279     // will be refined further later.
280     Operation *parentOp = result.getOwner();
281     auto existingIt = impl->operationToIdx.find(parentOp);
282     if (existingIt == impl->operationToIdx.end()) {
283       impl->placeholderValueUses[value].append(locations.begin(),
284                                                locations.end());
285       return;
286     }
287 
288     // If a definition does exist, locate the value's result group and add the
289     // use. The result groups are ordered by increasing start index, so we just
290     // need to find the last group that has a smaller/equal start index.
291     unsigned resultNo = result.getResultNumber();
292     OperationDefinition &def = *impl->operations[existingIt->second];
293     for (auto &resultGroup : llvm::reverse(def.resultGroups)) {
294       if (resultNo >= resultGroup.startIndex) {
295         for (SMLoc loc : locations)
296           resultGroup.definition.uses.push_back(convertIdLocToRange(loc));
297         return;
298       }
299     }
300     llvm_unreachable("expected valid result group for value use");
301   }
302 
303   // Otherwise, this is a block argument.
304   BlockArgument arg = value.cast<BlockArgument>();
305   auto existingIt = impl->blocksToIdx.find(arg.getOwner());
306   assert(existingIt != impl->blocksToIdx.end() &&
307          "expected valid block definition for block argument");
308   BlockDefinition &blockDef = *impl->blocks[existingIt->second];
309   SMDefinition &argDef = blockDef.arguments[arg.getArgNumber()];
310   for (SMLoc loc : locations)
311     argDef.uses.emplace_back(convertIdLocToRange(loc));
312 }
313 
addUses(Block * block,ArrayRef<SMLoc> locations)314 void AsmParserState::addUses(Block *block, ArrayRef<SMLoc> locations) {
315   auto it = impl->blocksToIdx.find(block);
316   if (it == impl->blocksToIdx.end()) {
317     it = impl->blocksToIdx.try_emplace(block, impl->blocks.size()).first;
318     impl->blocks.emplace_back(std::make_unique<BlockDefinition>(block));
319   }
320 
321   BlockDefinition &def = *impl->blocks[it->second];
322   for (SMLoc loc : locations)
323     def.definition.uses.push_back(convertIdLocToRange(loc));
324 }
325 
addUses(SymbolRefAttr refAttr,ArrayRef<SMRange> locations)326 void AsmParserState::addUses(SymbolRefAttr refAttr,
327                              ArrayRef<SMRange> locations) {
328   // Ignore this symbol if no scopes are active.
329   if (impl->symbolUseScopes.empty())
330     return;
331 
332   assert((refAttr.getNestedReferences().size() + 1) == locations.size() &&
333          "expected the same number of references as provided locations");
334   (*impl->symbolUseScopes.back())[refAttr].emplace_back(locations.begin(),
335                                                         locations.end());
336 }
337 
refineDefinition(Value oldValue,Value newValue)338 void AsmParserState::refineDefinition(Value oldValue, Value newValue) {
339   auto it = impl->placeholderValueUses.find(oldValue);
340   assert(it != impl->placeholderValueUses.end() &&
341          "expected `oldValue` to be a placeholder");
342   addUses(newValue, it->second);
343   impl->placeholderValueUses.erase(oldValue);
344 }
345