1 //===- Diagnostics.cpp - MLIR Diagnostics ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Diagnostics.h"
10 #include "mlir/IR/Attributes.h"
11 #include "mlir/IR/Location.h"
12 #include "mlir/IR/MLIRContext.h"
13 #include "mlir/IR/Operation.h"
14 #include "mlir/IR/Types.h"
15 #include "llvm/ADT/MapVector.h"
16 #include "llvm/ADT/SmallString.h"
17 #include "llvm/ADT/StringMap.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 #include "llvm/Support/Mutex.h"
20 #include "llvm/Support/PrettyStackTrace.h"
21 #include "llvm/Support/Regex.h"
22 #include "llvm/Support/Signals.h"
23 #include "llvm/Support/SourceMgr.h"
24 #include "llvm/Support/raw_ostream.h"
25 
26 using namespace mlir;
27 using namespace mlir::detail;
28 
29 //===----------------------------------------------------------------------===//
30 // DiagnosticArgument
31 //===----------------------------------------------------------------------===//
32 
33 /// Construct from an Attribute.
DiagnosticArgument(Attribute attr)34 DiagnosticArgument::DiagnosticArgument(Attribute attr)
35     : kind(DiagnosticArgumentKind::Attribute),
36       opaqueVal(reinterpret_cast<intptr_t>(attr.getAsOpaquePointer())) {}
37 
38 /// Construct from a Type.
DiagnosticArgument(Type val)39 DiagnosticArgument::DiagnosticArgument(Type val)
40     : kind(DiagnosticArgumentKind::Type),
41       opaqueVal(reinterpret_cast<intptr_t>(val.getAsOpaquePointer())) {}
42 
43 /// Returns this argument as an Attribute.
getAsAttribute() const44 Attribute DiagnosticArgument::getAsAttribute() const {
45   assert(getKind() == DiagnosticArgumentKind::Attribute);
46   return Attribute::getFromOpaquePointer(
47       reinterpret_cast<const void *>(opaqueVal));
48 }
49 
50 /// Returns this argument as a Type.
getAsType() const51 Type DiagnosticArgument::getAsType() const {
52   assert(getKind() == DiagnosticArgumentKind::Type);
53   return Type::getFromOpaquePointer(reinterpret_cast<const void *>(opaqueVal));
54 }
55 
56 /// Outputs this argument to a stream.
print(raw_ostream & os) const57 void DiagnosticArgument::print(raw_ostream &os) const {
58   switch (kind) {
59   case DiagnosticArgumentKind::Attribute:
60     os << getAsAttribute();
61     break;
62   case DiagnosticArgumentKind::Double:
63     os << getAsDouble();
64     break;
65   case DiagnosticArgumentKind::Integer:
66     os << getAsInteger();
67     break;
68   case DiagnosticArgumentKind::String:
69     os << getAsString();
70     break;
71   case DiagnosticArgumentKind::Type:
72     os << '\'' << getAsType() << '\'';
73     break;
74   case DiagnosticArgumentKind::Unsigned:
75     os << getAsUnsigned();
76     break;
77   }
78 }
79 
80 //===----------------------------------------------------------------------===//
81 // Diagnostic
82 //===----------------------------------------------------------------------===//
83 
84 /// Convert a Twine to a StringRef. Memory used for generating the StringRef is
85 /// stored in 'strings'.
twineToStrRef(const Twine & val,std::vector<std::unique_ptr<char[]>> & strings)86 static StringRef twineToStrRef(const Twine &val,
87                                std::vector<std::unique_ptr<char[]>> &strings) {
88   // Allocate memory to hold this string.
89   SmallString<64> data;
90   auto strRef = val.toStringRef(data);
91   if (strRef.empty())
92     return strRef;
93 
94   strings.push_back(std::unique_ptr<char[]>(new char[strRef.size()]));
95   memcpy(&strings.back()[0], strRef.data(), strRef.size());
96   // Return a reference to the new string.
97   return StringRef(&strings.back()[0], strRef.size());
98 }
99 
100 /// Stream in a Twine argument.
operator <<(char val)101 Diagnostic &Diagnostic::operator<<(char val) { return *this << Twine(val); }
operator <<(const Twine & val)102 Diagnostic &Diagnostic::operator<<(const Twine &val) {
103   arguments.push_back(DiagnosticArgument(twineToStrRef(val, strings)));
104   return *this;
105 }
operator <<(Twine && val)106 Diagnostic &Diagnostic::operator<<(Twine &&val) {
107   arguments.push_back(DiagnosticArgument(twineToStrRef(val, strings)));
108   return *this;
109 }
110 
operator <<(StringAttr val)111 Diagnostic &Diagnostic::operator<<(StringAttr val) {
112   arguments.push_back(DiagnosticArgument(val));
113   return *this;
114 }
115 
116 /// Stream in an OperationName.
operator <<(OperationName val)117 Diagnostic &Diagnostic::operator<<(OperationName val) {
118   // An OperationName is stored in the context, so we don't need to worry about
119   // the lifetime of its data.
120   arguments.push_back(DiagnosticArgument(val.getStringRef()));
121   return *this;
122 }
123 
124 /// Adjusts operation printing flags used in diagnostics for the given severity
125 /// level.
adjustPrintingFlags(OpPrintingFlags flags,DiagnosticSeverity severity)126 static OpPrintingFlags adjustPrintingFlags(OpPrintingFlags flags,
127                                            DiagnosticSeverity severity) {
128   flags.useLocalScope();
129   flags.elideLargeElementsAttrs();
130   if (severity == DiagnosticSeverity::Error)
131     flags.printGenericOpForm();
132   return flags;
133 }
134 
135 /// Stream in an Operation.
operator <<(Operation & val)136 Diagnostic &Diagnostic::operator<<(Operation &val) {
137   return appendOp(val, OpPrintingFlags());
138 }
appendOp(Operation & val,const OpPrintingFlags & flags)139 Diagnostic &Diagnostic::appendOp(Operation &val, const OpPrintingFlags &flags) {
140   std::string str;
141   llvm::raw_string_ostream os(str);
142   val.print(os, adjustPrintingFlags(flags, severity));
143   return *this << os.str();
144 }
145 
146 /// Stream in a Value.
operator <<(Value val)147 Diagnostic &Diagnostic::operator<<(Value val) {
148   std::string str;
149   llvm::raw_string_ostream os(str);
150   val.print(os, adjustPrintingFlags(OpPrintingFlags(), severity));
151   return *this << os.str();
152 }
153 
154 /// Outputs this diagnostic to a stream.
print(raw_ostream & os) const155 void Diagnostic::print(raw_ostream &os) const {
156   for (auto &arg : getArguments())
157     arg.print(os);
158 }
159 
160 /// Convert the diagnostic to a string.
str() const161 std::string Diagnostic::str() const {
162   std::string str;
163   llvm::raw_string_ostream os(str);
164   print(os);
165   return os.str();
166 }
167 
168 /// Attaches a note to this diagnostic. A new location may be optionally
169 /// provided, if not, then the location defaults to the one specified for this
170 /// diagnostic. Notes may not be attached to other notes.
attachNote(Optional<Location> noteLoc)171 Diagnostic &Diagnostic::attachNote(Optional<Location> noteLoc) {
172   // We don't allow attaching notes to notes.
173   assert(severity != DiagnosticSeverity::Note &&
174          "cannot attach a note to a note");
175 
176   // If a location wasn't provided then reuse our location.
177   if (!noteLoc)
178     noteLoc = loc;
179 
180   /// Append and return a new note.
181   notes.push_back(
182       std::make_unique<Diagnostic>(*noteLoc, DiagnosticSeverity::Note));
183   return *notes.back();
184 }
185 
186 /// Allow a diagnostic to be converted to 'failure'.
operator LogicalResult() const187 Diagnostic::operator LogicalResult() const { return failure(); }
188 
189 //===----------------------------------------------------------------------===//
190 // InFlightDiagnostic
191 //===----------------------------------------------------------------------===//
192 
193 /// Allow an inflight diagnostic to be converted to 'failure', otherwise
194 /// 'success' if this is an empty diagnostic.
operator LogicalResult() const195 InFlightDiagnostic::operator LogicalResult() const {
196   return failure(isActive());
197 }
198 
199 /// Reports the diagnostic to the engine.
report()200 void InFlightDiagnostic::report() {
201   // If this diagnostic is still inflight and it hasn't been abandoned, then
202   // report it.
203   if (isInFlight()) {
204     owner->emit(std::move(*impl));
205     owner = nullptr;
206   }
207   impl.reset();
208 }
209 
210 /// Abandons this diagnostic.
abandon()211 void InFlightDiagnostic::abandon() { owner = nullptr; }
212 
213 //===----------------------------------------------------------------------===//
214 // DiagnosticEngineImpl
215 //===----------------------------------------------------------------------===//
216 
217 namespace mlir {
218 namespace detail {
219 struct DiagnosticEngineImpl {
220   /// Emit a diagnostic using the registered issue handle if present, or with
221   /// the default behavior if not.
222   void emit(Diagnostic &&diag);
223 
224   /// A mutex to ensure that diagnostics emission is thread-safe.
225   llvm::sys::SmartMutex<true> mutex;
226 
227   /// These are the handlers used to report diagnostics.
228   llvm::SmallMapVector<DiagnosticEngine::HandlerID, DiagnosticEngine::HandlerTy,
229                        2>
230       handlers;
231 
232   /// This is a unique identifier counter for diagnostic handlers in the
233   /// context. This id starts at 1 to allow for 0 to be used as a sentinel.
234   DiagnosticEngine::HandlerID uniqueHandlerId = 1;
235 };
236 } // namespace detail
237 } // namespace mlir
238 
239 /// Emit a diagnostic using the registered issue handle if present, or with
240 /// the default behavior if not.
emit(Diagnostic && diag)241 void DiagnosticEngineImpl::emit(Diagnostic &&diag) {
242   llvm::sys::SmartScopedLock<true> lock(mutex);
243 
244   // Try to process the given diagnostic on one of the registered handlers.
245   // Handlers are walked in reverse order, so that the most recent handler is
246   // processed first.
247   for (auto &handlerIt : llvm::reverse(handlers))
248     if (succeeded(handlerIt.second(diag)))
249       return;
250 
251   // Otherwise, if this is an error we emit it to stderr.
252   if (diag.getSeverity() != DiagnosticSeverity::Error)
253     return;
254 
255   auto &os = llvm::errs();
256   if (!diag.getLocation().isa<UnknownLoc>())
257     os << diag.getLocation() << ": ";
258   os << "error: ";
259 
260   // The default behavior for errors is to emit them to stderr.
261   os << diag << '\n';
262   os.flush();
263 }
264 
265 //===----------------------------------------------------------------------===//
266 // DiagnosticEngine
267 //===----------------------------------------------------------------------===//
268 
DiagnosticEngine()269 DiagnosticEngine::DiagnosticEngine() : impl(new DiagnosticEngineImpl()) {}
270 DiagnosticEngine::~DiagnosticEngine() = default;
271 
272 /// Register a new handler for diagnostics to the engine. This function returns
273 /// a unique identifier for the registered handler, which can be used to
274 /// unregister this handler at a later time.
registerHandler(HandlerTy handler)275 auto DiagnosticEngine::registerHandler(HandlerTy handler) -> HandlerID {
276   llvm::sys::SmartScopedLock<true> lock(impl->mutex);
277   auto uniqueID = impl->uniqueHandlerId++;
278   impl->handlers.insert({uniqueID, std::move(handler)});
279   return uniqueID;
280 }
281 
282 /// Erase the registered diagnostic handler with the given identifier.
eraseHandler(HandlerID handlerID)283 void DiagnosticEngine::eraseHandler(HandlerID handlerID) {
284   llvm::sys::SmartScopedLock<true> lock(impl->mutex);
285   impl->handlers.erase(handlerID);
286 }
287 
288 /// Emit a diagnostic using the registered issue handler if present, or with
289 /// the default behavior if not.
emit(Diagnostic && diag)290 void DiagnosticEngine::emit(Diagnostic &&diag) {
291   assert(diag.getSeverity() != DiagnosticSeverity::Note &&
292          "notes should not be emitted directly");
293   impl->emit(std::move(diag));
294 }
295 
296 /// Helper function used to emit a diagnostic with an optionally empty twine
297 /// message. If the message is empty, then it is not inserted into the
298 /// diagnostic.
299 static InFlightDiagnostic
emitDiag(Location location,DiagnosticSeverity severity,const Twine & message)300 emitDiag(Location location, DiagnosticSeverity severity, const Twine &message) {
301   MLIRContext *ctx = location->getContext();
302   auto &diagEngine = ctx->getDiagEngine();
303   auto diag = diagEngine.emit(location, severity);
304   if (!message.isTriviallyEmpty())
305     diag << message;
306 
307   // Add the stack trace as a note if necessary.
308   if (ctx->shouldPrintStackTraceOnDiagnostic()) {
309     std::string bt;
310     {
311       llvm::raw_string_ostream stream(bt);
312       llvm::sys::PrintStackTrace(stream);
313     }
314     if (!bt.empty())
315       diag.attachNote() << "diagnostic emitted with trace:\n" << bt;
316   }
317 
318   return diag;
319 }
320 
321 /// Emit an error message using this location.
emitError(Location loc)322 InFlightDiagnostic mlir::emitError(Location loc) { return emitError(loc, {}); }
emitError(Location loc,const Twine & message)323 InFlightDiagnostic mlir::emitError(Location loc, const Twine &message) {
324   return emitDiag(loc, DiagnosticSeverity::Error, message);
325 }
326 
327 /// Emit a warning message using this location.
emitWarning(Location loc)328 InFlightDiagnostic mlir::emitWarning(Location loc) {
329   return emitWarning(loc, {});
330 }
emitWarning(Location loc,const Twine & message)331 InFlightDiagnostic mlir::emitWarning(Location loc, const Twine &message) {
332   return emitDiag(loc, DiagnosticSeverity::Warning, message);
333 }
334 
335 /// Emit a remark message using this location.
emitRemark(Location loc)336 InFlightDiagnostic mlir::emitRemark(Location loc) {
337   return emitRemark(loc, {});
338 }
emitRemark(Location loc,const Twine & message)339 InFlightDiagnostic mlir::emitRemark(Location loc, const Twine &message) {
340   return emitDiag(loc, DiagnosticSeverity::Remark, message);
341 }
342 
343 //===----------------------------------------------------------------------===//
344 // ScopedDiagnosticHandler
345 //===----------------------------------------------------------------------===//
346 
~ScopedDiagnosticHandler()347 ScopedDiagnosticHandler::~ScopedDiagnosticHandler() {
348   if (handlerID)
349     ctx->getDiagEngine().eraseHandler(handlerID);
350 }
351 
352 //===----------------------------------------------------------------------===//
353 // SourceMgrDiagnosticHandler
354 //===----------------------------------------------------------------------===//
355 namespace mlir {
356 namespace detail {
357 struct SourceMgrDiagnosticHandlerImpl {
358   /// Return the SrcManager buffer id for the specified file, or zero if none
359   /// can be found.
getSourceMgrBufferIDForFilemlir::detail::SourceMgrDiagnosticHandlerImpl360   unsigned getSourceMgrBufferIDForFile(llvm::SourceMgr &mgr,
361                                        StringRef filename) {
362     // Check for an existing mapping to the buffer id for this file.
363     auto bufferIt = filenameToBufId.find(filename);
364     if (bufferIt != filenameToBufId.end())
365       return bufferIt->second;
366 
367     // Look for a buffer in the manager that has this filename.
368     for (unsigned i = 1, e = mgr.getNumBuffers() + 1; i != e; ++i) {
369       auto *buf = mgr.getMemoryBuffer(i);
370       if (buf->getBufferIdentifier() == filename)
371         return filenameToBufId[filename] = i;
372     }
373 
374     // Otherwise, try to load the source file.
375     std::string ignored;
376     unsigned id = mgr.AddIncludeFile(std::string(filename), SMLoc(), ignored);
377     filenameToBufId[filename] = id;
378     return id;
379   }
380 
381   /// Mapping between file name and buffer ID's.
382   llvm::StringMap<unsigned> filenameToBufId;
383 };
384 } // namespace detail
385 } // namespace mlir
386 
387 /// Return a processable FileLineColLoc from the given location.
getFileLineColLoc(Location loc)388 static Optional<FileLineColLoc> getFileLineColLoc(Location loc) {
389   Optional<FileLineColLoc> firstFileLoc;
390   loc->walk([&](Location loc) {
391     if (FileLineColLoc fileLoc = loc.dyn_cast<FileLineColLoc>()) {
392       firstFileLoc = fileLoc;
393       return WalkResult::interrupt();
394     }
395     return WalkResult::advance();
396   });
397   return firstFileLoc;
398 }
399 
400 /// Return a processable CallSiteLoc from the given location.
getCallSiteLoc(Location loc)401 static Optional<CallSiteLoc> getCallSiteLoc(Location loc) {
402   if (auto nameLoc = loc.dyn_cast<NameLoc>())
403     return getCallSiteLoc(loc.cast<NameLoc>().getChildLoc());
404   if (auto callLoc = loc.dyn_cast<CallSiteLoc>())
405     return callLoc;
406   if (auto fusedLoc = loc.dyn_cast<FusedLoc>()) {
407     for (auto subLoc : loc.cast<FusedLoc>().getLocations()) {
408       if (auto callLoc = getCallSiteLoc(subLoc)) {
409         return callLoc;
410       }
411     }
412     return llvm::None;
413   }
414   return llvm::None;
415 }
416 
417 /// Given a diagnostic kind, returns the LLVM DiagKind.
getDiagKind(DiagnosticSeverity kind)418 static llvm::SourceMgr::DiagKind getDiagKind(DiagnosticSeverity kind) {
419   switch (kind) {
420   case DiagnosticSeverity::Note:
421     return llvm::SourceMgr::DK_Note;
422   case DiagnosticSeverity::Warning:
423     return llvm::SourceMgr::DK_Warning;
424   case DiagnosticSeverity::Error:
425     return llvm::SourceMgr::DK_Error;
426   case DiagnosticSeverity::Remark:
427     return llvm::SourceMgr::DK_Remark;
428   }
429   llvm_unreachable("Unknown DiagnosticSeverity");
430 }
431 
SourceMgrDiagnosticHandler(llvm::SourceMgr & mgr,MLIRContext * ctx,raw_ostream & os,ShouldShowLocFn && shouldShowLocFn)432 SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(
433     llvm::SourceMgr &mgr, MLIRContext *ctx, raw_ostream &os,
434     ShouldShowLocFn &&shouldShowLocFn)
435     : ScopedDiagnosticHandler(ctx), mgr(mgr), os(os),
436       shouldShowLocFn(std::move(shouldShowLocFn)),
437       impl(new SourceMgrDiagnosticHandlerImpl()) {
438   setHandler([this](Diagnostic &diag) { emitDiagnostic(diag); });
439 }
440 
SourceMgrDiagnosticHandler(llvm::SourceMgr & mgr,MLIRContext * ctx,ShouldShowLocFn && shouldShowLocFn)441 SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(
442     llvm::SourceMgr &mgr, MLIRContext *ctx, ShouldShowLocFn &&shouldShowLocFn)
443     : SourceMgrDiagnosticHandler(mgr, ctx, llvm::errs(),
444                                  std::move(shouldShowLocFn)) {}
445 
446 SourceMgrDiagnosticHandler::~SourceMgrDiagnosticHandler() = default;
447 
emitDiagnostic(Location loc,Twine message,DiagnosticSeverity kind,bool displaySourceLine)448 void SourceMgrDiagnosticHandler::emitDiagnostic(Location loc, Twine message,
449                                                 DiagnosticSeverity kind,
450                                                 bool displaySourceLine) {
451   // Extract a file location from this loc.
452   auto fileLoc = getFileLineColLoc(loc);
453 
454   // If one doesn't exist, then print the raw message without a source location.
455   if (!fileLoc) {
456     std::string str;
457     llvm::raw_string_ostream strOS(str);
458     if (!loc.isa<UnknownLoc>())
459       strOS << loc << ": ";
460     strOS << message;
461     return mgr.PrintMessage(os, SMLoc(), getDiagKind(kind), strOS.str());
462   }
463 
464   // Otherwise if we are displaying the source line, try to convert the file
465   // location to an SMLoc.
466   if (displaySourceLine) {
467     auto smloc = convertLocToSMLoc(*fileLoc);
468     if (smloc.isValid())
469       return mgr.PrintMessage(os, smloc, getDiagKind(kind), message);
470   }
471 
472   // If the conversion was unsuccessful, create a diagnostic with the file
473   // information. We manually combine the line and column to avoid asserts in
474   // the constructor of SMDiagnostic that takes a location.
475   std::string locStr;
476   llvm::raw_string_ostream locOS(locStr);
477   locOS << fileLoc->getFilename().getValue() << ":" << fileLoc->getLine() << ":"
478         << fileLoc->getColumn();
479   llvm::SMDiagnostic diag(locOS.str(), getDiagKind(kind), message.str());
480   diag.print(nullptr, os);
481 }
482 
483 /// Emit the given diagnostic with the held source manager.
emitDiagnostic(Diagnostic & diag)484 void SourceMgrDiagnosticHandler::emitDiagnostic(Diagnostic &diag) {
485   SmallVector<std::pair<Location, StringRef>> locationStack;
486   auto addLocToStack = [&](Location loc, StringRef locContext) {
487     if (Optional<Location> showableLoc = findLocToShow(loc))
488       locationStack.emplace_back(*showableLoc, locContext);
489   };
490 
491   // Add locations to display for this diagnostic.
492   Location loc = diag.getLocation();
493   addLocToStack(loc, /*locContext=*/{});
494 
495   // If the diagnostic location was a call site location, add the call stack as
496   // well.
497   if (auto callLoc = getCallSiteLoc(loc)) {
498     // Print the call stack while valid, or until the limit is reached.
499     loc = callLoc->getCaller();
500     for (unsigned curDepth = 0; curDepth < callStackLimit; ++curDepth) {
501       addLocToStack(loc, "called from");
502       if ((callLoc = getCallSiteLoc(loc)))
503         loc = callLoc->getCaller();
504       else
505         break;
506     }
507   }
508 
509   // If the location stack is empty, use the initial location.
510   if (locationStack.empty()) {
511     emitDiagnostic(diag.getLocation(), diag.str(), diag.getSeverity());
512 
513     // Otherwise, use the location stack.
514   } else {
515     emitDiagnostic(locationStack.front().first, diag.str(), diag.getSeverity());
516     for (auto &it : llvm::drop_begin(locationStack))
517       emitDiagnostic(it.first, it.second, DiagnosticSeverity::Note);
518   }
519 
520   // Emit each of the notes. Only display the source code if the location is
521   // different from the previous location.
522   for (auto &note : diag.getNotes()) {
523     emitDiagnostic(note.getLocation(), note.str(), note.getSeverity(),
524                    /*displaySourceLine=*/loc != note.getLocation());
525     loc = note.getLocation();
526   }
527 }
528 
529 /// Get a memory buffer for the given file, or nullptr if one is not found.
530 const llvm::MemoryBuffer *
getBufferForFile(StringRef filename)531 SourceMgrDiagnosticHandler::getBufferForFile(StringRef filename) {
532   if (unsigned id = impl->getSourceMgrBufferIDForFile(mgr, filename))
533     return mgr.getMemoryBuffer(id);
534   return nullptr;
535 }
536 
findLocToShow(Location loc)537 Optional<Location> SourceMgrDiagnosticHandler::findLocToShow(Location loc) {
538   if (!shouldShowLocFn)
539     return loc;
540   if (!shouldShowLocFn(loc))
541     return llvm::None;
542 
543   // Recurse into the child locations of some of location types.
544   return TypeSwitch<LocationAttr, Optional<Location>>(loc)
545       .Case([&](CallSiteLoc callLoc) -> Optional<Location> {
546         // We recurse into the callee of a call site, as the caller will be
547         // emitted in a different note on the main diagnostic.
548         return findLocToShow(callLoc.getCallee());
549       })
550       .Case([&](FileLineColLoc) -> Optional<Location> { return loc; })
551       .Case([&](FusedLoc fusedLoc) -> Optional<Location> {
552         // Fused location is unique in that we try to find a sub-location to
553         // show, rather than the top-level location itself.
554         for (Location childLoc : fusedLoc.getLocations())
555           if (Optional<Location> showableLoc = findLocToShow(childLoc))
556             return showableLoc;
557         return llvm::None;
558       })
559       .Case([&](NameLoc nameLoc) -> Optional<Location> {
560         return findLocToShow(nameLoc.getChildLoc());
561       })
562       .Case([&](OpaqueLoc opaqueLoc) -> Optional<Location> {
563         // OpaqueLoc always falls back to a different source location.
564         return findLocToShow(opaqueLoc.getFallbackLocation());
565       })
566       .Case([](UnknownLoc) -> Optional<Location> {
567         // Prefer not to show unknown locations.
568         return llvm::None;
569       });
570 }
571 
572 /// Get a memory buffer for the given file, or the main file of the source
573 /// manager if one doesn't exist. This always returns non-null.
convertLocToSMLoc(FileLineColLoc loc)574 SMLoc SourceMgrDiagnosticHandler::convertLocToSMLoc(FileLineColLoc loc) {
575   // The column and line may be zero to represent unknown column and/or unknown
576   /// line/column information.
577   if (loc.getLine() == 0 || loc.getColumn() == 0)
578     return SMLoc();
579 
580   unsigned bufferId = impl->getSourceMgrBufferIDForFile(mgr, loc.getFilename());
581   if (!bufferId)
582     return SMLoc();
583   return mgr.FindLocForLineAndColumn(bufferId, loc.getLine(), loc.getColumn());
584 }
585 
586 //===----------------------------------------------------------------------===//
587 // SourceMgrDiagnosticVerifierHandler
588 //===----------------------------------------------------------------------===//
589 
590 namespace mlir {
591 namespace detail {
592 /// This class represents an expected output diagnostic.
593 struct ExpectedDiag {
ExpectedDiagmlir::detail::ExpectedDiag594   ExpectedDiag(DiagnosticSeverity kind, unsigned lineNo, SMLoc fileLoc,
595                StringRef substring)
596       : kind(kind), lineNo(lineNo), fileLoc(fileLoc), substring(substring) {}
597 
598   /// Emit an error at the location referenced by this diagnostic.
emitErrormlir::detail::ExpectedDiag599   LogicalResult emitError(raw_ostream &os, llvm::SourceMgr &mgr,
600                           const Twine &msg) {
601     SMRange range(fileLoc, SMLoc::getFromPointer(fileLoc.getPointer() +
602                                                  substring.size()));
603     mgr.PrintMessage(os, fileLoc, llvm::SourceMgr::DK_Error, msg, range);
604     return failure();
605   }
606 
607   /// Returns true if this diagnostic matches the given string.
matchmlir::detail::ExpectedDiag608   bool match(StringRef str) const {
609     // If this isn't a regex diagnostic, we simply check if the string was
610     // contained.
611     if (substringRegex)
612       return substringRegex->match(str);
613     return str.contains(substring);
614   }
615 
616   /// Compute the regex matcher for this diagnostic, using the provided stream
617   /// and manager to emit diagnostics as necessary.
computeRegexmlir::detail::ExpectedDiag618   LogicalResult computeRegex(raw_ostream &os, llvm::SourceMgr &mgr) {
619     std::string regexStr;
620     llvm::raw_string_ostream regexOS(regexStr);
621     StringRef strToProcess = substring;
622     while (!strToProcess.empty()) {
623       // Find the next regex block.
624       size_t regexIt = strToProcess.find("{{");
625       if (regexIt == StringRef::npos) {
626         regexOS << llvm::Regex::escape(strToProcess);
627         break;
628       }
629       regexOS << llvm::Regex::escape(strToProcess.take_front(regexIt));
630       strToProcess = strToProcess.drop_front(regexIt + 2);
631 
632       // Find the end of the regex block.
633       size_t regexEndIt = strToProcess.find("}}");
634       if (regexEndIt == StringRef::npos)
635         return emitError(os, mgr, "found start of regex with no end '}}'");
636       StringRef regexStr = strToProcess.take_front(regexEndIt);
637 
638       // Validate that the regex is actually valid.
639       std::string regexError;
640       if (!llvm::Regex(regexStr).isValid(regexError))
641         return emitError(os, mgr, "invalid regex: " + regexError);
642 
643       regexOS << '(' << regexStr << ')';
644       strToProcess = strToProcess.drop_front(regexEndIt + 2);
645     }
646     substringRegex = llvm::Regex(regexOS.str());
647     return success();
648   }
649 
650   /// The severity of the diagnosic expected.
651   DiagnosticSeverity kind;
652   /// The line number the expected diagnostic should be on.
653   unsigned lineNo;
654   /// The location of the expected diagnostic within the input file.
655   SMLoc fileLoc;
656   /// A flag indicating if the expected diagnostic has been matched yet.
657   bool matched = false;
658   /// The substring that is expected to be within the diagnostic.
659   StringRef substring;
660   /// An optional regex matcher, if the expected diagnostic sub-string was a
661   /// regex string.
662   Optional<llvm::Regex> substringRegex;
663 };
664 
665 struct SourceMgrDiagnosticVerifierHandlerImpl {
SourceMgrDiagnosticVerifierHandlerImplmlir::detail::SourceMgrDiagnosticVerifierHandlerImpl666   SourceMgrDiagnosticVerifierHandlerImpl() : status(success()) {}
667 
668   /// Returns the expected diagnostics for the given source file.
669   Optional<MutableArrayRef<ExpectedDiag>> getExpectedDiags(StringRef bufName);
670 
671   /// Computes the expected diagnostics for the given source buffer.
672   MutableArrayRef<ExpectedDiag>
673   computeExpectedDiags(raw_ostream &os, llvm::SourceMgr &mgr,
674                        const llvm::MemoryBuffer *buf);
675 
676   /// The current status of the verifier.
677   LogicalResult status;
678 
679   /// A list of expected diagnostics for each buffer of the source manager.
680   llvm::StringMap<SmallVector<ExpectedDiag, 2>> expectedDiagsPerFile;
681 
682   /// Regex to match the expected diagnostics format.
683   llvm::Regex expected =
684       llvm::Regex("expected-(error|note|remark|warning)(-re)? "
685                   "*(@([+-][0-9]+|above|below))? *{{(.*)}}$");
686 };
687 } // namespace detail
688 } // namespace mlir
689 
690 /// Given a diagnostic kind, return a human readable string for it.
getDiagKindStr(DiagnosticSeverity kind)691 static StringRef getDiagKindStr(DiagnosticSeverity kind) {
692   switch (kind) {
693   case DiagnosticSeverity::Note:
694     return "note";
695   case DiagnosticSeverity::Warning:
696     return "warning";
697   case DiagnosticSeverity::Error:
698     return "error";
699   case DiagnosticSeverity::Remark:
700     return "remark";
701   }
702   llvm_unreachable("Unknown DiagnosticSeverity");
703 }
704 
705 Optional<MutableArrayRef<ExpectedDiag>>
getExpectedDiags(StringRef bufName)706 SourceMgrDiagnosticVerifierHandlerImpl::getExpectedDiags(StringRef bufName) {
707   auto expectedDiags = expectedDiagsPerFile.find(bufName);
708   if (expectedDiags != expectedDiagsPerFile.end())
709     return MutableArrayRef<ExpectedDiag>(expectedDiags->second);
710   return llvm::None;
711 }
712 
713 MutableArrayRef<ExpectedDiag>
computeExpectedDiags(raw_ostream & os,llvm::SourceMgr & mgr,const llvm::MemoryBuffer * buf)714 SourceMgrDiagnosticVerifierHandlerImpl::computeExpectedDiags(
715     raw_ostream &os, llvm::SourceMgr &mgr, const llvm::MemoryBuffer *buf) {
716   // If the buffer is invalid, return an empty list.
717   if (!buf)
718     return llvm::None;
719   auto &expectedDiags = expectedDiagsPerFile[buf->getBufferIdentifier()];
720 
721   // The number of the last line that did not correlate to a designator.
722   unsigned lastNonDesignatorLine = 0;
723 
724   // The indices of designators that apply to the next non designator line.
725   SmallVector<unsigned, 1> designatorsForNextLine;
726 
727   // Scan the file for expected-* designators.
728   SmallVector<StringRef, 100> lines;
729   buf->getBuffer().split(lines, '\n');
730   for (unsigned lineNo = 0, e = lines.size(); lineNo < e; ++lineNo) {
731     SmallVector<StringRef, 4> matches;
732     if (!expected.match(lines[lineNo].rtrim(), &matches)) {
733       // Check for designators that apply to this line.
734       if (!designatorsForNextLine.empty()) {
735         for (unsigned diagIndex : designatorsForNextLine)
736           expectedDiags[diagIndex].lineNo = lineNo + 1;
737         designatorsForNextLine.clear();
738       }
739       lastNonDesignatorLine = lineNo;
740       continue;
741     }
742 
743     // Point to the start of expected-*.
744     SMLoc expectedStart = SMLoc::getFromPointer(matches[0].data());
745 
746     DiagnosticSeverity kind;
747     if (matches[1] == "error")
748       kind = DiagnosticSeverity::Error;
749     else if (matches[1] == "warning")
750       kind = DiagnosticSeverity::Warning;
751     else if (matches[1] == "remark")
752       kind = DiagnosticSeverity::Remark;
753     else {
754       assert(matches[1] == "note");
755       kind = DiagnosticSeverity::Note;
756     }
757     ExpectedDiag record(kind, lineNo + 1, expectedStart, matches[5]);
758 
759     // Check to see if this is a regex match, i.e. it includes the `-re`.
760     if (!matches[2].empty() && failed(record.computeRegex(os, mgr))) {
761       status = failure();
762       continue;
763     }
764 
765     StringRef offsetMatch = matches[3];
766     if (!offsetMatch.empty()) {
767       offsetMatch = offsetMatch.drop_front(1);
768 
769       // Get the integer value without the @ and +/- prefix.
770       if (offsetMatch[0] == '+' || offsetMatch[0] == '-') {
771         int offset;
772         offsetMatch.drop_front().getAsInteger(0, offset);
773 
774         if (offsetMatch.front() == '+')
775           record.lineNo += offset;
776         else
777           record.lineNo -= offset;
778       } else if (offsetMatch.consume_front("above")) {
779         // If the designator applies 'above' we add it to the last non
780         // designator line.
781         record.lineNo = lastNonDesignatorLine + 1;
782       } else {
783         // Otherwise, this is a 'below' designator and applies to the next
784         // non-designator line.
785         assert(offsetMatch.consume_front("below"));
786         designatorsForNextLine.push_back(expectedDiags.size());
787 
788         // Set the line number to the last in the case that this designator ends
789         // up dangling.
790         record.lineNo = e;
791       }
792     }
793     expectedDiags.emplace_back(std::move(record));
794   }
795   return expectedDiags;
796 }
797 
SourceMgrDiagnosticVerifierHandler(llvm::SourceMgr & srcMgr,MLIRContext * ctx,raw_ostream & out)798 SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
799     llvm::SourceMgr &srcMgr, MLIRContext *ctx, raw_ostream &out)
800     : SourceMgrDiagnosticHandler(srcMgr, ctx, out),
801       impl(new SourceMgrDiagnosticVerifierHandlerImpl()) {
802   // Compute the expected diagnostics for each of the current files in the
803   // source manager.
804   for (unsigned i = 0, e = mgr.getNumBuffers(); i != e; ++i)
805     (void)impl->computeExpectedDiags(out, mgr, mgr.getMemoryBuffer(i + 1));
806 
807   // Register a handler to verify the diagnostics.
808   setHandler([&](Diagnostic &diag) {
809     // Process the main diagnostics.
810     process(diag);
811 
812     // Process each of the notes.
813     for (auto &note : diag.getNotes())
814       process(note);
815   });
816 }
817 
SourceMgrDiagnosticVerifierHandler(llvm::SourceMgr & srcMgr,MLIRContext * ctx)818 SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
819     llvm::SourceMgr &srcMgr, MLIRContext *ctx)
820     : SourceMgrDiagnosticVerifierHandler(srcMgr, ctx, llvm::errs()) {}
821 
~SourceMgrDiagnosticVerifierHandler()822 SourceMgrDiagnosticVerifierHandler::~SourceMgrDiagnosticVerifierHandler() {
823   // Ensure that all expected diagnostics were handled.
824   (void)verify();
825 }
826 
827 /// Returns the status of the verifier and verifies that all expected
828 /// diagnostics were emitted. This return success if all diagnostics were
829 /// verified correctly, failure otherwise.
verify()830 LogicalResult SourceMgrDiagnosticVerifierHandler::verify() {
831   // Verify that all expected errors were seen.
832   for (auto &expectedDiagsPair : impl->expectedDiagsPerFile) {
833     for (auto &err : expectedDiagsPair.second) {
834       if (err.matched)
835         continue;
836       impl->status =
837           err.emitError(os, mgr,
838                         "expected " + getDiagKindStr(err.kind) + " \"" +
839                             err.substring + "\" was not produced");
840     }
841   }
842   impl->expectedDiagsPerFile.clear();
843   return impl->status;
844 }
845 
846 /// Process a single diagnostic.
process(Diagnostic & diag)847 void SourceMgrDiagnosticVerifierHandler::process(Diagnostic &diag) {
848   auto kind = diag.getSeverity();
849 
850   // Process a FileLineColLoc.
851   if (auto fileLoc = getFileLineColLoc(diag.getLocation()))
852     return process(*fileLoc, diag.str(), kind);
853 
854   emitDiagnostic(diag.getLocation(),
855                  "unexpected " + getDiagKindStr(kind) + ": " + diag.str(),
856                  DiagnosticSeverity::Error);
857   impl->status = failure();
858 }
859 
860 /// Process a FileLineColLoc diagnostic.
process(FileLineColLoc loc,StringRef msg,DiagnosticSeverity kind)861 void SourceMgrDiagnosticVerifierHandler::process(FileLineColLoc loc,
862                                                  StringRef msg,
863                                                  DiagnosticSeverity kind) {
864   // Get the expected diagnostics for this file.
865   auto diags = impl->getExpectedDiags(loc.getFilename());
866   if (!diags) {
867     diags = impl->computeExpectedDiags(os, mgr,
868                                        getBufferForFile(loc.getFilename()));
869   }
870 
871   // Search for a matching expected diagnostic.
872   // If we find something that is close then emit a more specific error.
873   ExpectedDiag *nearMiss = nullptr;
874 
875   // If this was an expected error, remember that we saw it and return.
876   unsigned line = loc.getLine();
877   for (auto &e : *diags) {
878     if (line == e.lineNo && e.match(msg)) {
879       if (e.kind == kind) {
880         e.matched = true;
881         return;
882       }
883 
884       // If this only differs based on the diagnostic kind, then consider it
885       // to be a near miss.
886       nearMiss = &e;
887     }
888   }
889 
890   // Otherwise, emit an error for the near miss.
891   if (nearMiss)
892     mgr.PrintMessage(os, nearMiss->fileLoc, llvm::SourceMgr::DK_Error,
893                      "'" + getDiagKindStr(kind) +
894                          "' diagnostic emitted when expecting a '" +
895                          getDiagKindStr(nearMiss->kind) + "'");
896   else
897     emitDiagnostic(loc, "unexpected " + getDiagKindStr(kind) + ": " + msg,
898                    DiagnosticSeverity::Error);
899   impl->status = failure();
900 }
901 
902 //===----------------------------------------------------------------------===//
903 // ParallelDiagnosticHandler
904 //===----------------------------------------------------------------------===//
905 
906 namespace mlir {
907 namespace detail {
908 struct ParallelDiagnosticHandlerImpl : public llvm::PrettyStackTraceEntry {
909   struct ThreadDiagnostic {
ThreadDiagnosticmlir::detail::ParallelDiagnosticHandlerImpl::ThreadDiagnostic910     ThreadDiagnostic(size_t id, Diagnostic diag)
911         : id(id), diag(std::move(diag)) {}
operator <mlir::detail::ParallelDiagnosticHandlerImpl::ThreadDiagnostic912     bool operator<(const ThreadDiagnostic &rhs) const { return id < rhs.id; }
913 
914     /// The id for this diagnostic, this is used for ordering.
915     /// Note: This id corresponds to the ordered position of the current element
916     ///       being processed by a given thread.
917     size_t id;
918 
919     /// The diagnostic.
920     Diagnostic diag;
921   };
922 
ParallelDiagnosticHandlerImplmlir::detail::ParallelDiagnosticHandlerImpl923   ParallelDiagnosticHandlerImpl(MLIRContext *ctx) : context(ctx) {
924     handlerID = ctx->getDiagEngine().registerHandler([this](Diagnostic &diag) {
925       uint64_t tid = llvm::get_threadid();
926       llvm::sys::SmartScopedLock<true> lock(mutex);
927 
928       // If this thread is not tracked, then return failure to let another
929       // handler process this diagnostic.
930       if (!threadToOrderID.count(tid))
931         return failure();
932 
933       // Append a new diagnostic.
934       diagnostics.emplace_back(threadToOrderID[tid], std::move(diag));
935       return success();
936     });
937   }
938 
~ParallelDiagnosticHandlerImplmlir::detail::ParallelDiagnosticHandlerImpl939   ~ParallelDiagnosticHandlerImpl() override {
940     // Erase this handler from the context.
941     context->getDiagEngine().eraseHandler(handlerID);
942 
943     // Early exit if there are no diagnostics, this is the common case.
944     if (diagnostics.empty())
945       return;
946 
947     // Emit the diagnostics back to the context.
948     emitDiagnostics([&](Diagnostic &diag) {
949       return context->getDiagEngine().emit(std::move(diag));
950     });
951   }
952 
953   /// Utility method to emit any held diagnostics.
emitDiagnosticsmlir::detail::ParallelDiagnosticHandlerImpl954   void emitDiagnostics(llvm::function_ref<void(Diagnostic &)> emitFn) const {
955     // Stable sort all of the diagnostics that were emitted. This creates a
956     // deterministic ordering for the diagnostics based upon which order id they
957     // were emitted for.
958     std::stable_sort(diagnostics.begin(), diagnostics.end());
959 
960     // Emit each diagnostic to the context again.
961     for (ThreadDiagnostic &diag : diagnostics)
962       emitFn(diag.diag);
963   }
964 
965   /// Set the order id for the current thread.
setOrderIDForThreadmlir::detail::ParallelDiagnosticHandlerImpl966   void setOrderIDForThread(size_t orderID) {
967     uint64_t tid = llvm::get_threadid();
968     llvm::sys::SmartScopedLock<true> lock(mutex);
969     threadToOrderID[tid] = orderID;
970   }
971 
972   /// Remove the order id for the current thread.
eraseOrderIDForThreadmlir::detail::ParallelDiagnosticHandlerImpl973   void eraseOrderIDForThread() {
974     uint64_t tid = llvm::get_threadid();
975     llvm::sys::SmartScopedLock<true> lock(mutex);
976     threadToOrderID.erase(tid);
977   }
978 
979   /// Dump the current diagnostics that were inflight.
printmlir::detail::ParallelDiagnosticHandlerImpl980   void print(raw_ostream &os) const override {
981     // Early exit if there are no diagnostics, this is the common case.
982     if (diagnostics.empty())
983       return;
984 
985     os << "In-Flight Diagnostics:\n";
986     emitDiagnostics([&](const Diagnostic &diag) {
987       os.indent(4);
988 
989       // Print each diagnostic with the format:
990       //   "<location>: <kind>: <msg>"
991       if (!diag.getLocation().isa<UnknownLoc>())
992         os << diag.getLocation() << ": ";
993       switch (diag.getSeverity()) {
994       case DiagnosticSeverity::Error:
995         os << "error: ";
996         break;
997       case DiagnosticSeverity::Warning:
998         os << "warning: ";
999         break;
1000       case DiagnosticSeverity::Note:
1001         os << "note: ";
1002         break;
1003       case DiagnosticSeverity::Remark:
1004         os << "remark: ";
1005         break;
1006       }
1007       os << diag << '\n';
1008     });
1009   }
1010 
1011   /// A smart mutex to lock access to the internal state.
1012   llvm::sys::SmartMutex<true> mutex;
1013 
1014   /// A mapping between the thread id and the current order id.
1015   DenseMap<uint64_t, size_t> threadToOrderID;
1016 
1017   /// An unordered list of diagnostics that were emitted.
1018   mutable std::vector<ThreadDiagnostic> diagnostics;
1019 
1020   /// The unique id for the parallel handler.
1021   DiagnosticEngine::HandlerID handlerID = 0;
1022 
1023   /// The context to emit the diagnostics to.
1024   MLIRContext *context;
1025 };
1026 } // namespace detail
1027 } // namespace mlir
1028 
ParallelDiagnosticHandler(MLIRContext * ctx)1029 ParallelDiagnosticHandler::ParallelDiagnosticHandler(MLIRContext *ctx)
1030     : impl(new ParallelDiagnosticHandlerImpl(ctx)) {}
1031 ParallelDiagnosticHandler::~ParallelDiagnosticHandler() = default;
1032 
1033 /// Set the order id for the current thread.
setOrderIDForThread(size_t orderID)1034 void ParallelDiagnosticHandler::setOrderIDForThread(size_t orderID) {
1035   impl->setOrderIDForThread(orderID);
1036 }
1037 
1038 /// Remove the order id for the current thread. This removes the thread from
1039 /// diagnostics tracking.
eraseOrderIDForThread()1040 void ParallelDiagnosticHandler::eraseOrderIDForThread() {
1041   impl->eraseOrderIDForThread();
1042 }
1043