1 //===- Diagnostics.cpp - MLIR Diagnostics ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Diagnostics.h"
10 #include "mlir/IR/Attributes.h"
11 #include "mlir/IR/Location.h"
12 #include "mlir/IR/MLIRContext.h"
13 #include "mlir/IR/Operation.h"
14 #include "mlir/IR/Types.h"
15 #include "llvm/ADT/MapVector.h"
16 #include "llvm/ADT/SmallString.h"
17 #include "llvm/ADT/StringMap.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 #include "llvm/Support/Mutex.h"
20 #include "llvm/Support/PrettyStackTrace.h"
21 #include "llvm/Support/Regex.h"
22 #include "llvm/Support/Signals.h"
23 #include "llvm/Support/SourceMgr.h"
24 #include "llvm/Support/raw_ostream.h"
25 
26 using namespace mlir;
27 using namespace mlir::detail;
28 
29 //===----------------------------------------------------------------------===//
30 // DiagnosticArgument
31 //===----------------------------------------------------------------------===//
32 
33 /// Construct from an Attribute.
34 DiagnosticArgument::DiagnosticArgument(Attribute attr)
35     : kind(DiagnosticArgumentKind::Attribute),
36       opaqueVal(reinterpret_cast<intptr_t>(attr.getAsOpaquePointer())) {}
37 
38 /// Construct from a Type.
39 DiagnosticArgument::DiagnosticArgument(Type val)
40     : kind(DiagnosticArgumentKind::Type),
41       opaqueVal(reinterpret_cast<intptr_t>(val.getAsOpaquePointer())) {}
42 
43 /// Returns this argument as an Attribute.
44 Attribute DiagnosticArgument::getAsAttribute() const {
45   assert(getKind() == DiagnosticArgumentKind::Attribute);
46   return Attribute::getFromOpaquePointer(
47       reinterpret_cast<const void *>(opaqueVal));
48 }
49 
50 /// Returns this argument as a Type.
51 Type DiagnosticArgument::getAsType() const {
52   assert(getKind() == DiagnosticArgumentKind::Type);
53   return Type::getFromOpaquePointer(reinterpret_cast<const void *>(opaqueVal));
54 }
55 
56 /// Outputs this argument to a stream.
57 void DiagnosticArgument::print(raw_ostream &os) const {
58   switch (kind) {
59   case DiagnosticArgumentKind::Attribute:
60     os << getAsAttribute();
61     break;
62   case DiagnosticArgumentKind::Double:
63     os << getAsDouble();
64     break;
65   case DiagnosticArgumentKind::Integer:
66     os << getAsInteger();
67     break;
68   case DiagnosticArgumentKind::String:
69     os << getAsString();
70     break;
71   case DiagnosticArgumentKind::Type:
72     os << '\'' << getAsType() << '\'';
73     break;
74   case DiagnosticArgumentKind::Unsigned:
75     os << getAsUnsigned();
76     break;
77   }
78 }
79 
80 //===----------------------------------------------------------------------===//
81 // Diagnostic
82 //===----------------------------------------------------------------------===//
83 
84 /// Convert a Twine to a StringRef. Memory used for generating the StringRef is
85 /// stored in 'strings'.
86 static StringRef twineToStrRef(const Twine &val,
87                                std::vector<std::unique_ptr<char[]>> &strings) {
88   // Allocate memory to hold this string.
89   SmallString<64> data;
90   auto strRef = val.toStringRef(data);
91   if (strRef.empty())
92     return strRef;
93 
94   strings.push_back(std::unique_ptr<char[]>(new char[strRef.size()]));
95   memcpy(&strings.back()[0], strRef.data(), strRef.size());
96   // Return a reference to the new string.
97   return StringRef(&strings.back()[0], strRef.size());
98 }
99 
100 /// Stream in a Twine argument.
101 Diagnostic &Diagnostic::operator<<(char val) { return *this << Twine(val); }
102 Diagnostic &Diagnostic::operator<<(const Twine &val) {
103   arguments.push_back(DiagnosticArgument(twineToStrRef(val, strings)));
104   return *this;
105 }
106 Diagnostic &Diagnostic::operator<<(Twine &&val) {
107   arguments.push_back(DiagnosticArgument(twineToStrRef(val, strings)));
108   return *this;
109 }
110 
111 Diagnostic &Diagnostic::operator<<(StringAttr val) {
112   arguments.push_back(DiagnosticArgument(val));
113   return *this;
114 }
115 
116 /// Stream in an OperationName.
117 Diagnostic &Diagnostic::operator<<(OperationName val) {
118   // An OperationName is stored in the context, so we don't need to worry about
119   // the lifetime of its data.
120   arguments.push_back(DiagnosticArgument(val.getStringRef()));
121   return *this;
122 }
123 
124 /// Adjusts operation printing flags used in diagnostics for the given severity
125 /// level.
126 static OpPrintingFlags adjustPrintingFlags(OpPrintingFlags flags,
127                                            DiagnosticSeverity severity) {
128   flags.useLocalScope();
129   flags.elideLargeElementsAttrs();
130   if (severity == DiagnosticSeverity::Error)
131     flags.printGenericOpForm();
132   return flags;
133 }
134 
135 /// Stream in an Operation.
136 Diagnostic &Diagnostic::operator<<(Operation &val) {
137   return appendOp(val, OpPrintingFlags());
138 }
139 Diagnostic &Diagnostic::appendOp(Operation &val, const OpPrintingFlags &flags) {
140   std::string str;
141   llvm::raw_string_ostream os(str);
142   val.print(os, adjustPrintingFlags(flags, severity));
143   return *this << os.str();
144 }
145 
146 /// Stream in a Value.
147 Diagnostic &Diagnostic::operator<<(Value val) {
148   std::string str;
149   llvm::raw_string_ostream os(str);
150   val.print(os, adjustPrintingFlags(OpPrintingFlags(), severity));
151   return *this << os.str();
152 }
153 
154 /// Outputs this diagnostic to a stream.
155 void Diagnostic::print(raw_ostream &os) const {
156   for (auto &arg : getArguments())
157     arg.print(os);
158 }
159 
160 /// Convert the diagnostic to a string.
161 std::string Diagnostic::str() const {
162   std::string str;
163   llvm::raw_string_ostream os(str);
164   print(os);
165   return os.str();
166 }
167 
168 /// Attaches a note to this diagnostic. A new location may be optionally
169 /// provided, if not, then the location defaults to the one specified for this
170 /// diagnostic. Notes may not be attached to other notes.
171 Diagnostic &Diagnostic::attachNote(Optional<Location> noteLoc) {
172   // We don't allow attaching notes to notes.
173   assert(severity != DiagnosticSeverity::Note &&
174          "cannot attach a note to a note");
175 
176   // If a location wasn't provided then reuse our location.
177   if (!noteLoc)
178     noteLoc = loc;
179 
180   /// Append and return a new note.
181   notes.push_back(
182       std::make_unique<Diagnostic>(*noteLoc, DiagnosticSeverity::Note));
183   return *notes.back();
184 }
185 
186 /// Allow a diagnostic to be converted to 'failure'.
187 Diagnostic::operator LogicalResult() const { return failure(); }
188 
189 //===----------------------------------------------------------------------===//
190 // InFlightDiagnostic
191 //===----------------------------------------------------------------------===//
192 
193 /// Allow an inflight diagnostic to be converted to 'failure', otherwise
194 /// 'success' if this is an empty diagnostic.
195 InFlightDiagnostic::operator LogicalResult() const {
196   return failure(isActive());
197 }
198 
199 /// Reports the diagnostic to the engine.
200 void InFlightDiagnostic::report() {
201   // If this diagnostic is still inflight and it hasn't been abandoned, then
202   // report it.
203   if (isInFlight()) {
204     owner->emit(std::move(*impl));
205     owner = nullptr;
206   }
207   impl.reset();
208 }
209 
210 /// Abandons this diagnostic.
211 void InFlightDiagnostic::abandon() { owner = nullptr; }
212 
213 //===----------------------------------------------------------------------===//
214 // DiagnosticEngineImpl
215 //===----------------------------------------------------------------------===//
216 
217 namespace mlir {
218 namespace detail {
219 struct DiagnosticEngineImpl {
220   /// Emit a diagnostic using the registered issue handle if present, or with
221   /// the default behavior if not.
222   void emit(Diagnostic &&diag);
223 
224   /// A mutex to ensure that diagnostics emission is thread-safe.
225   llvm::sys::SmartMutex<true> mutex;
226 
227   /// These are the handlers used to report diagnostics.
228   llvm::SmallMapVector<DiagnosticEngine::HandlerID, DiagnosticEngine::HandlerTy,
229                        2>
230       handlers;
231 
232   /// This is a unique identifier counter for diagnostic handlers in the
233   /// context. This id starts at 1 to allow for 0 to be used as a sentinel.
234   DiagnosticEngine::HandlerID uniqueHandlerId = 1;
235 };
236 } // namespace detail
237 } // namespace mlir
238 
239 /// Emit a diagnostic using the registered issue handle if present, or with
240 /// the default behavior if not.
241 void DiagnosticEngineImpl::emit(Diagnostic &&diag) {
242   llvm::sys::SmartScopedLock<true> lock(mutex);
243 
244   // Try to process the given diagnostic on one of the registered handlers.
245   // Handlers are walked in reverse order, so that the most recent handler is
246   // processed first.
247   for (auto &handlerIt : llvm::reverse(handlers))
248     if (succeeded(handlerIt.second(diag)))
249       return;
250 
251   // Otherwise, if this is an error we emit it to stderr.
252   if (diag.getSeverity() != DiagnosticSeverity::Error)
253     return;
254 
255   auto &os = llvm::errs();
256   if (!diag.getLocation().isa<UnknownLoc>())
257     os << diag.getLocation() << ": ";
258   os << "error: ";
259 
260   // The default behavior for errors is to emit them to stderr.
261   os << diag << '\n';
262   os.flush();
263 }
264 
265 //===----------------------------------------------------------------------===//
266 // DiagnosticEngine
267 //===----------------------------------------------------------------------===//
268 
269 DiagnosticEngine::DiagnosticEngine() : impl(new DiagnosticEngineImpl()) {}
270 DiagnosticEngine::~DiagnosticEngine() = default;
271 
272 /// Register a new handler for diagnostics to the engine. This function returns
273 /// a unique identifier for the registered handler, which can be used to
274 /// unregister this handler at a later time.
275 auto DiagnosticEngine::registerHandler(HandlerTy handler) -> HandlerID {
276   llvm::sys::SmartScopedLock<true> lock(impl->mutex);
277   auto uniqueID = impl->uniqueHandlerId++;
278   impl->handlers.insert({uniqueID, std::move(handler)});
279   return uniqueID;
280 }
281 
282 /// Erase the registered diagnostic handler with the given identifier.
283 void DiagnosticEngine::eraseHandler(HandlerID handlerID) {
284   llvm::sys::SmartScopedLock<true> lock(impl->mutex);
285   impl->handlers.erase(handlerID);
286 }
287 
288 /// Emit a diagnostic using the registered issue handler if present, or with
289 /// the default behavior if not.
290 void DiagnosticEngine::emit(Diagnostic &&diag) {
291   assert(diag.getSeverity() != DiagnosticSeverity::Note &&
292          "notes should not be emitted directly");
293   impl->emit(std::move(diag));
294 }
295 
296 /// Helper function used to emit a diagnostic with an optionally empty twine
297 /// message. If the message is empty, then it is not inserted into the
298 /// diagnostic.
299 static InFlightDiagnostic
300 emitDiag(Location location, DiagnosticSeverity severity, const Twine &message) {
301   MLIRContext *ctx = location->getContext();
302   auto &diagEngine = ctx->getDiagEngine();
303   auto diag = diagEngine.emit(location, severity);
304   if (!message.isTriviallyEmpty())
305     diag << message;
306 
307   // Add the stack trace as a note if necessary.
308   if (ctx->shouldPrintStackTraceOnDiagnostic()) {
309     std::string bt;
310     {
311       llvm::raw_string_ostream stream(bt);
312       llvm::sys::PrintStackTrace(stream);
313     }
314     if (!bt.empty())
315       diag.attachNote() << "diagnostic emitted with trace:\n" << bt;
316   }
317 
318   return diag;
319 }
320 
321 /// Emit an error message using this location.
322 InFlightDiagnostic mlir::emitError(Location loc) { return emitError(loc, {}); }
323 InFlightDiagnostic mlir::emitError(Location loc, const Twine &message) {
324   return emitDiag(loc, DiagnosticSeverity::Error, message);
325 }
326 
327 /// Emit a warning message using this location.
328 InFlightDiagnostic mlir::emitWarning(Location loc) {
329   return emitWarning(loc, {});
330 }
331 InFlightDiagnostic mlir::emitWarning(Location loc, const Twine &message) {
332   return emitDiag(loc, DiagnosticSeverity::Warning, message);
333 }
334 
335 /// Emit a remark message using this location.
336 InFlightDiagnostic mlir::emitRemark(Location loc) {
337   return emitRemark(loc, {});
338 }
339 InFlightDiagnostic mlir::emitRemark(Location loc, const Twine &message) {
340   return emitDiag(loc, DiagnosticSeverity::Remark, message);
341 }
342 
343 //===----------------------------------------------------------------------===//
344 // ScopedDiagnosticHandler
345 //===----------------------------------------------------------------------===//
346 
347 ScopedDiagnosticHandler::~ScopedDiagnosticHandler() {
348   if (handlerID)
349     ctx->getDiagEngine().eraseHandler(handlerID);
350 }
351 
352 //===----------------------------------------------------------------------===//
353 // SourceMgrDiagnosticHandler
354 //===----------------------------------------------------------------------===//
355 namespace mlir {
356 namespace detail {
357 struct SourceMgrDiagnosticHandlerImpl {
358   /// Return the SrcManager buffer id for the specified file, or zero if none
359   /// can be found.
360   unsigned getSourceMgrBufferIDForFile(llvm::SourceMgr &mgr,
361                                        StringRef filename) {
362     // Check for an existing mapping to the buffer id for this file.
363     auto bufferIt = filenameToBufId.find(filename);
364     if (bufferIt != filenameToBufId.end())
365       return bufferIt->second;
366 
367     // Look for a buffer in the manager that has this filename.
368     for (unsigned i = 1, e = mgr.getNumBuffers() + 1; i != e; ++i) {
369       auto *buf = mgr.getMemoryBuffer(i);
370       if (buf->getBufferIdentifier() == filename)
371         return filenameToBufId[filename] = i;
372     }
373 
374     // Otherwise, try to load the source file.
375     std::string ignored;
376     unsigned id =
377         mgr.AddIncludeFile(std::string(filename), SMLoc(), ignored);
378     filenameToBufId[filename] = id;
379     return id;
380   }
381 
382   /// Mapping between file name and buffer ID's.
383   llvm::StringMap<unsigned> filenameToBufId;
384 };
385 } // namespace detail
386 } // namespace mlir
387 
388 /// Return a processable FileLineColLoc from the given location.
389 static Optional<FileLineColLoc> getFileLineColLoc(Location loc) {
390   Optional<FileLineColLoc> firstFileLoc;
391   loc->walk([&](Location loc) {
392     if (FileLineColLoc fileLoc = loc.dyn_cast<FileLineColLoc>()) {
393       firstFileLoc = fileLoc;
394       return WalkResult::interrupt();
395     }
396     return WalkResult::advance();
397   });
398   return firstFileLoc;
399 }
400 
401 /// Return a processable CallSiteLoc from the given location.
402 static Optional<CallSiteLoc> getCallSiteLoc(Location loc) {
403   if (auto nameLoc = loc.dyn_cast<NameLoc>())
404     return getCallSiteLoc(loc.cast<NameLoc>().getChildLoc());
405   if (auto callLoc = loc.dyn_cast<CallSiteLoc>())
406     return callLoc;
407   if (auto fusedLoc = loc.dyn_cast<FusedLoc>()) {
408     for (auto subLoc : loc.cast<FusedLoc>().getLocations()) {
409       if (auto callLoc = getCallSiteLoc(subLoc)) {
410         return callLoc;
411       }
412     }
413     return llvm::None;
414   }
415   return llvm::None;
416 }
417 
418 /// Given a diagnostic kind, returns the LLVM DiagKind.
419 static llvm::SourceMgr::DiagKind getDiagKind(DiagnosticSeverity kind) {
420   switch (kind) {
421   case DiagnosticSeverity::Note:
422     return llvm::SourceMgr::DK_Note;
423   case DiagnosticSeverity::Warning:
424     return llvm::SourceMgr::DK_Warning;
425   case DiagnosticSeverity::Error:
426     return llvm::SourceMgr::DK_Error;
427   case DiagnosticSeverity::Remark:
428     return llvm::SourceMgr::DK_Remark;
429   }
430   llvm_unreachable("Unknown DiagnosticSeverity");
431 }
432 
433 SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(
434     llvm::SourceMgr &mgr, MLIRContext *ctx, raw_ostream &os,
435     ShouldShowLocFn &&shouldShowLocFn)
436     : ScopedDiagnosticHandler(ctx), mgr(mgr), os(os),
437       shouldShowLocFn(std::move(shouldShowLocFn)),
438       impl(new SourceMgrDiagnosticHandlerImpl()) {
439   setHandler([this](Diagnostic &diag) { emitDiagnostic(diag); });
440 }
441 
442 SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(
443     llvm::SourceMgr &mgr, MLIRContext *ctx, ShouldShowLocFn &&shouldShowLocFn)
444     : SourceMgrDiagnosticHandler(mgr, ctx, llvm::errs(),
445                                  std::move(shouldShowLocFn)) {}
446 
447 SourceMgrDiagnosticHandler::~SourceMgrDiagnosticHandler() = default;
448 
449 void SourceMgrDiagnosticHandler::emitDiagnostic(Location loc, Twine message,
450                                                 DiagnosticSeverity kind,
451                                                 bool displaySourceLine) {
452   // Extract a file location from this loc.
453   auto fileLoc = getFileLineColLoc(loc);
454 
455   // If one doesn't exist, then print the raw message without a source location.
456   if (!fileLoc) {
457     std::string str;
458     llvm::raw_string_ostream strOS(str);
459     if (!loc.isa<UnknownLoc>())
460       strOS << loc << ": ";
461     strOS << message;
462     return mgr.PrintMessage(os, SMLoc(), getDiagKind(kind), strOS.str());
463   }
464 
465   // Otherwise if we are displaying the source line, try to convert the file
466   // location to an SMLoc.
467   if (displaySourceLine) {
468     auto smloc = convertLocToSMLoc(*fileLoc);
469     if (smloc.isValid())
470       return mgr.PrintMessage(os, smloc, getDiagKind(kind), message);
471   }
472 
473   // If the conversion was unsuccessful, create a diagnostic with the file
474   // information. We manually combine the line and column to avoid asserts in
475   // the constructor of SMDiagnostic that takes a location.
476   std::string locStr;
477   llvm::raw_string_ostream locOS(locStr);
478   locOS << fileLoc->getFilename().getValue() << ":" << fileLoc->getLine() << ":"
479         << fileLoc->getColumn();
480   llvm::SMDiagnostic diag(locOS.str(), getDiagKind(kind), message.str());
481   diag.print(nullptr, os);
482 }
483 
/// Emit the given diagnostic with the held source manager.
///
/// The message is printed at the first displayable location collected below
/// (see findLocToShow). If the diagnostic's location contains a CallSiteLoc,
/// up to `callStackLimit` callers are also collected and printed as
/// "called from" notes. Attached notes are printed last.
void SourceMgrDiagnosticHandler::emitDiagnostic(Diagnostic &diag) {
  // Pairs of (location to display, text used when it is emitted as a note).
  SmallVector<std::pair<Location, StringRef>> locationStack;
  // Record `loc` only when findLocToShow yields something displayable for it.
  auto addLocToStack = [&](Location loc, StringRef locContext) {
    if (Optional<Location> showableLoc = findLocToShow(loc))
      locationStack.emplace_back(*showableLoc, locContext);
  };

  // Add locations to display for this diagnostic.
  Location loc = diag.getLocation();
  addLocToStack(loc, /*locContext=*/{});

  // If the diagnostic location was a call site location, add the call stack as
  // well.
  if (auto callLoc = getCallSiteLoc(loc)) {
    // Print the call stack while valid, or until the limit is reached.
    loc = callLoc->getCaller();
    for (unsigned curDepth = 0; curDepth < callStackLimit; ++curDepth) {
      addLocToStack(loc, "called from");
      if ((callLoc = getCallSiteLoc(loc)))
        loc = callLoc->getCaller();
      else
        break;
    }
  }

  // If the location stack is empty, use the initial location.
  if (locationStack.empty()) {
    emitDiagnostic(diag.getLocation(), diag.str(), diag.getSeverity());

    // Otherwise, use the location stack.
  } else {
    // The first recorded location carries the message itself. Any remaining
    // entries are emitted as notes using their context text.
    emitDiagnostic(locationStack.front().first, diag.str(), diag.getSeverity());
    for (auto &it : llvm::drop_begin(locationStack))
      emitDiagnostic(it.first, it.second, DiagnosticSeverity::Note);
  }

  // Emit each of the notes. Only display the source code if the location is
  // different from the previous location.
  // NOTE(review): if a call stack was walked above, `loc` now holds the last
  // caller visited rather than the location that was printed, so the first
  // note is compared against that caller. Confirm this is intended.
  for (auto &note : diag.getNotes()) {
    emitDiagnostic(note.getLocation(), note.str(), note.getSeverity(),
                   /*displaySourceLine=*/loc != note.getLocation());
    loc = note.getLocation();
  }
}
529 
530 /// Get a memory buffer for the given file, or nullptr if one is not found.
531 const llvm::MemoryBuffer *
532 SourceMgrDiagnosticHandler::getBufferForFile(StringRef filename) {
533   if (unsigned id = impl->getSourceMgrBufferIDForFile(mgr, filename))
534     return mgr.getMemoryBuffer(id);
535   return nullptr;
536 }
537 
/// Returns the location that should actually be displayed for `loc`, or None
/// if nothing should be shown.
///
/// Without a `shouldShowLocFn`, `loc` is returned unchanged. Otherwise, a
/// location rejected by the filter yields None, and the filter is re-applied
/// at each level of the following unwrapping:
///   - call sites recurse into their callee;
///   - name locations recurse into their child;
///   - opaque locations recurse into their fallback;
///   - fused locations resolve to the first showable sub-location.
/// FileLineColLoc is returned as-is, and UnknownLoc is never shown.
Optional<Location> SourceMgrDiagnosticHandler::findLocToShow(Location loc) {
  if (!shouldShowLocFn)
    return loc;
  if (!shouldShowLocFn(loc))
    return llvm::None;

  // Recurse into the child locations of some location types.
  // NOTE(review): there is no `.Default` case, so a location kind that is not
  // listed below would fall off the end of the TypeSwitch. Confirm every
  // location kind is covered.
  return TypeSwitch<LocationAttr, Optional<Location>>(loc)
      .Case([&](CallSiteLoc callLoc) -> Optional<Location> {
        // We recurse into the callee of a call site, as the caller will be
        // emitted in a different note on the main diagnostic.
        return findLocToShow(callLoc.getCallee());
      })
      .Case([&](FileLineColLoc) -> Optional<Location> { return loc; })
      .Case([&](FusedLoc fusedLoc) -> Optional<Location> {
        // Fused location is unique in that we try to find a sub-location to
        // show, rather than the top-level location itself.
        for (Location childLoc : fusedLoc.getLocations())
          if (Optional<Location> showableLoc = findLocToShow(childLoc))
            return showableLoc;
        return llvm::None;
      })
      .Case([&](NameLoc nameLoc) -> Optional<Location> {
        return findLocToShow(nameLoc.getChildLoc());
      })
      .Case([&](OpaqueLoc opaqueLoc) -> Optional<Location> {
        // OpaqueLoc always falls back to a different source location.
        return findLocToShow(opaqueLoc.getFallbackLocation());
      })
      .Case([](UnknownLoc) -> Optional<Location> {
        // Prefer not to show unknown locations.
        return llvm::None;
      });
}
572 
573 /// Get a memory buffer for the given file, or the main file of the source
574 /// manager if one doesn't exist. This always returns non-null.
575 SMLoc SourceMgrDiagnosticHandler::convertLocToSMLoc(FileLineColLoc loc) {
576   // The column and line may be zero to represent unknown column and/or unknown
577   /// line/column information.
578   if (loc.getLine() == 0 || loc.getColumn() == 0)
579     return SMLoc();
580 
581   unsigned bufferId = impl->getSourceMgrBufferIDForFile(mgr, loc.getFilename());
582   if (!bufferId)
583     return SMLoc();
584   return mgr.FindLocForLineAndColumn(bufferId, loc.getLine(), loc.getColumn());
585 }
586 
587 //===----------------------------------------------------------------------===//
588 // SourceMgrDiagnosticVerifierHandler
589 //===----------------------------------------------------------------------===//
590 
591 namespace mlir {
592 namespace detail {
593 // Record the expected diagnostic's position, substring and whether it was
594 // seen.
595 struct ExpectedDiag {
596   DiagnosticSeverity kind;
597   unsigned lineNo;
598   StringRef substring;
599   SMLoc fileLoc;
600   bool matched;
601 };
602 
603 struct SourceMgrDiagnosticVerifierHandlerImpl {
604   SourceMgrDiagnosticVerifierHandlerImpl() : status(success()) {}
605 
606   /// Returns the expected diagnostics for the given source file.
607   Optional<MutableArrayRef<ExpectedDiag>> getExpectedDiags(StringRef bufName);
608 
609   /// Computes the expected diagnostics for the given source buffer.
610   MutableArrayRef<ExpectedDiag>
611   computeExpectedDiags(const llvm::MemoryBuffer *buf);
612 
613   /// The current status of the verifier.
614   LogicalResult status;
615 
616   /// A list of expected diagnostics for each buffer of the source manager.
617   llvm::StringMap<SmallVector<ExpectedDiag, 2>> expectedDiagsPerFile;
618 
619   /// Regex to match the expected diagnostics format.
620   llvm::Regex expected = llvm::Regex("expected-(error|note|remark|warning) "
621                                      "*(@([+-][0-9]+|above|below))? *{{(.*)}}");
622 };
623 } // namespace detail
624 } // namespace mlir
625 
626 /// Given a diagnostic kind, return a human readable string for it.
627 static StringRef getDiagKindStr(DiagnosticSeverity kind) {
628   switch (kind) {
629   case DiagnosticSeverity::Note:
630     return "note";
631   case DiagnosticSeverity::Warning:
632     return "warning";
633   case DiagnosticSeverity::Error:
634     return "error";
635   case DiagnosticSeverity::Remark:
636     return "remark";
637   }
638   llvm_unreachable("Unknown DiagnosticSeverity");
639 }
640 
641 /// Returns the expected diagnostics for the given source file.
642 Optional<MutableArrayRef<ExpectedDiag>>
643 SourceMgrDiagnosticVerifierHandlerImpl::getExpectedDiags(StringRef bufName) {
644   auto expectedDiags = expectedDiagsPerFile.find(bufName);
645   if (expectedDiags != expectedDiagsPerFile.end())
646     return MutableArrayRef<ExpectedDiag>(expectedDiags->second);
647   return llvm::None;
648 }
649 
650 /// Computes the expected diagnostics for the given source buffer.
651 MutableArrayRef<ExpectedDiag>
652 SourceMgrDiagnosticVerifierHandlerImpl::computeExpectedDiags(
653     const llvm::MemoryBuffer *buf) {
654   // If the buffer is invalid, return an empty list.
655   if (!buf)
656     return llvm::None;
657   auto &expectedDiags = expectedDiagsPerFile[buf->getBufferIdentifier()];
658 
659   // The number of the last line that did not correlate to a designator.
660   unsigned lastNonDesignatorLine = 0;
661 
662   // The indices of designators that apply to the next non designator line.
663   SmallVector<unsigned, 1> designatorsForNextLine;
664 
665   // Scan the file for expected-* designators.
666   SmallVector<StringRef, 100> lines;
667   buf->getBuffer().split(lines, '\n');
668   for (unsigned lineNo = 0, e = lines.size(); lineNo < e; ++lineNo) {
669     SmallVector<StringRef, 4> matches;
670     if (!expected.match(lines[lineNo], &matches)) {
671       // Check for designators that apply to this line.
672       if (!designatorsForNextLine.empty()) {
673         for (unsigned diagIndex : designatorsForNextLine)
674           expectedDiags[diagIndex].lineNo = lineNo + 1;
675         designatorsForNextLine.clear();
676       }
677       lastNonDesignatorLine = lineNo;
678       continue;
679     }
680 
681     // Point to the start of expected-*.
682     auto expectedStart = SMLoc::getFromPointer(matches[0].data());
683 
684     DiagnosticSeverity kind;
685     if (matches[1] == "error")
686       kind = DiagnosticSeverity::Error;
687     else if (matches[1] == "warning")
688       kind = DiagnosticSeverity::Warning;
689     else if (matches[1] == "remark")
690       kind = DiagnosticSeverity::Remark;
691     else {
692       assert(matches[1] == "note");
693       kind = DiagnosticSeverity::Note;
694     }
695 
696     ExpectedDiag record{kind, lineNo + 1, matches[4], expectedStart, false};
697     auto offsetMatch = matches[2];
698     if (!offsetMatch.empty()) {
699       offsetMatch = offsetMatch.drop_front(1);
700 
701       // Get the integer value without the @ and +/- prefix.
702       if (offsetMatch[0] == '+' || offsetMatch[0] == '-') {
703         int offset;
704         offsetMatch.drop_front().getAsInteger(0, offset);
705 
706         if (offsetMatch.front() == '+')
707           record.lineNo += offset;
708         else
709           record.lineNo -= offset;
710       } else if (offsetMatch.consume_front("above")) {
711         // If the designator applies 'above' we add it to the last non
712         // designator line.
713         record.lineNo = lastNonDesignatorLine + 1;
714       } else {
715         // Otherwise, this is a 'below' designator and applies to the next
716         // non-designator line.
717         assert(offsetMatch.consume_front("below"));
718         designatorsForNextLine.push_back(expectedDiags.size());
719 
720         // Set the line number to the last in the case that this designator ends
721         // up dangling.
722         record.lineNo = e;
723       }
724     }
725     expectedDiags.push_back(record);
726   }
727   return expectedDiags;
728 }
729 
730 SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
731     llvm::SourceMgr &srcMgr, MLIRContext *ctx, raw_ostream &out)
732     : SourceMgrDiagnosticHandler(srcMgr, ctx, out),
733       impl(new SourceMgrDiagnosticVerifierHandlerImpl()) {
734   // Compute the expected diagnostics for each of the current files in the
735   // source manager.
736   for (unsigned i = 0, e = mgr.getNumBuffers(); i != e; ++i)
737     (void)impl->computeExpectedDiags(mgr.getMemoryBuffer(i + 1));
738 
739   // Register a handler to verify the diagnostics.
740   setHandler([&](Diagnostic &diag) {
741     // Process the main diagnostics.
742     process(diag);
743 
744     // Process each of the notes.
745     for (auto &note : diag.getNotes())
746       process(note);
747   });
748 }
749 
750 SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
751     llvm::SourceMgr &srcMgr, MLIRContext *ctx)
752     : SourceMgrDiagnosticVerifierHandler(srcMgr, ctx, llvm::errs()) {}
753 
754 SourceMgrDiagnosticVerifierHandler::~SourceMgrDiagnosticVerifierHandler() {
755   // Ensure that all expected diagnostics were handled.
756   (void)verify();
757 }
758 
759 /// Returns the status of the verifier and verifies that all expected
760 /// diagnostics were emitted. This return success if all diagnostics were
761 /// verified correctly, failure otherwise.
762 LogicalResult SourceMgrDiagnosticVerifierHandler::verify() {
763   // Verify that all expected errors were seen.
764   for (auto &expectedDiagsPair : impl->expectedDiagsPerFile) {
765     for (auto &err : expectedDiagsPair.second) {
766       if (err.matched)
767         continue;
768       SMRange range(err.fileLoc,
769                           SMLoc::getFromPointer(err.fileLoc.getPointer() +
770                                                       err.substring.size()));
771       mgr.PrintMessage(os, err.fileLoc, llvm::SourceMgr::DK_Error,
772                        "expected " + getDiagKindStr(err.kind) + " \"" +
773                            err.substring + "\" was not produced",
774                        range);
775       impl->status = failure();
776     }
777   }
778   impl->expectedDiagsPerFile.clear();
779   return impl->status;
780 }
781 
782 /// Process a single diagnostic.
783 void SourceMgrDiagnosticVerifierHandler::process(Diagnostic &diag) {
784   auto kind = diag.getSeverity();
785 
786   // Process a FileLineColLoc.
787   if (auto fileLoc = getFileLineColLoc(diag.getLocation()))
788     return process(*fileLoc, diag.str(), kind);
789 
790   emitDiagnostic(diag.getLocation(),
791                  "unexpected " + getDiagKindStr(kind) + ": " + diag.str(),
792                  DiagnosticSeverity::Error);
793   impl->status = failure();
794 }
795 
796 /// Process a FileLineColLoc diagnostic.
797 void SourceMgrDiagnosticVerifierHandler::process(FileLineColLoc loc,
798                                                  StringRef msg,
799                                                  DiagnosticSeverity kind) {
800   // Get the expected diagnostics for this file.
801   auto diags = impl->getExpectedDiags(loc.getFilename());
802   if (!diags)
803     diags = impl->computeExpectedDiags(getBufferForFile(loc.getFilename()));
804 
805   // Search for a matching expected diagnostic.
806   // If we find something that is close then emit a more specific error.
807   ExpectedDiag *nearMiss = nullptr;
808 
809   // If this was an expected error, remember that we saw it and return.
810   unsigned line = loc.getLine();
811   for (auto &e : *diags) {
812     if (line == e.lineNo && msg.contains(e.substring)) {
813       if (e.kind == kind) {
814         e.matched = true;
815         return;
816       }
817 
818       // If this only differs based on the diagnostic kind, then consider it
819       // to be a near miss.
820       nearMiss = &e;
821     }
822   }
823 
824   // Otherwise, emit an error for the near miss.
825   if (nearMiss)
826     mgr.PrintMessage(os, nearMiss->fileLoc, llvm::SourceMgr::DK_Error,
827                      "'" + getDiagKindStr(kind) +
828                          "' diagnostic emitted when expecting a '" +
829                          getDiagKindStr(nearMiss->kind) + "'");
830   else
831     emitDiagnostic(loc, "unexpected " + getDiagKindStr(kind) + ": " + msg,
832                    DiagnosticSeverity::Error);
833   impl->status = failure();
834 }
835 
836 //===----------------------------------------------------------------------===//
837 // ParallelDiagnosticHandler
838 //===----------------------------------------------------------------------===//
839 
840 namespace mlir {
841 namespace detail {
842 struct ParallelDiagnosticHandlerImpl : public llvm::PrettyStackTraceEntry {
843   struct ThreadDiagnostic {
844     ThreadDiagnostic(size_t id, Diagnostic diag)
845         : id(id), diag(std::move(diag)) {}
846     bool operator<(const ThreadDiagnostic &rhs) const { return id < rhs.id; }
847 
848     /// The id for this diagnostic, this is used for ordering.
849     /// Note: This id corresponds to the ordered position of the current element
850     ///       being processed by a given thread.
851     size_t id;
852 
853     /// The diagnostic.
854     Diagnostic diag;
855   };
856 
857   ParallelDiagnosticHandlerImpl(MLIRContext *ctx) : context(ctx) {
858     handlerID = ctx->getDiagEngine().registerHandler([this](Diagnostic &diag) {
859       uint64_t tid = llvm::get_threadid();
860       llvm::sys::SmartScopedLock<true> lock(mutex);
861 
862       // If this thread is not tracked, then return failure to let another
863       // handler process this diagnostic.
864       if (!threadToOrderID.count(tid))
865         return failure();
866 
867       // Append a new diagnostic.
868       diagnostics.emplace_back(threadToOrderID[tid], std::move(diag));
869       return success();
870     });
871   }
872 
873   ~ParallelDiagnosticHandlerImpl() override {
874     // Erase this handler from the context.
875     context->getDiagEngine().eraseHandler(handlerID);
876 
877     // Early exit if there are no diagnostics, this is the common case.
878     if (diagnostics.empty())
879       return;
880 
881     // Emit the diagnostics back to the context.
882     emitDiagnostics([&](Diagnostic &diag) {
883       return context->getDiagEngine().emit(std::move(diag));
884     });
885   }
886 
887   /// Utility method to emit any held diagnostics.
888   void emitDiagnostics(llvm::function_ref<void(Diagnostic &)> emitFn) const {
889     // Stable sort all of the diagnostics that were emitted. This creates a
890     // deterministic ordering for the diagnostics based upon which order id they
891     // were emitted for.
892     std::stable_sort(diagnostics.begin(), diagnostics.end());
893 
894     // Emit each diagnostic to the context again.
895     for (ThreadDiagnostic &diag : diagnostics)
896       emitFn(diag.diag);
897   }
898 
899   /// Set the order id for the current thread.
900   void setOrderIDForThread(size_t orderID) {
901     uint64_t tid = llvm::get_threadid();
902     llvm::sys::SmartScopedLock<true> lock(mutex);
903     threadToOrderID[tid] = orderID;
904   }
905 
906   /// Remove the order id for the current thread.
907   void eraseOrderIDForThread() {
908     uint64_t tid = llvm::get_threadid();
909     llvm::sys::SmartScopedLock<true> lock(mutex);
910     threadToOrderID.erase(tid);
911   }
912 
913   /// Dump the current diagnostics that were inflight.
914   void print(raw_ostream &os) const override {
915     // Early exit if there are no diagnostics, this is the common case.
916     if (diagnostics.empty())
917       return;
918 
919     os << "In-Flight Diagnostics:\n";
920     emitDiagnostics([&](const Diagnostic &diag) {
921       os.indent(4);
922 
923       // Print each diagnostic with the format:
924       //   "<location>: <kind>: <msg>"
925       if (!diag.getLocation().isa<UnknownLoc>())
926         os << diag.getLocation() << ": ";
927       switch (diag.getSeverity()) {
928       case DiagnosticSeverity::Error:
929         os << "error: ";
930         break;
931       case DiagnosticSeverity::Warning:
932         os << "warning: ";
933         break;
934       case DiagnosticSeverity::Note:
935         os << "note: ";
936         break;
937       case DiagnosticSeverity::Remark:
938         os << "remark: ";
939         break;
940       }
941       os << diag << '\n';
942     });
943   }
944 
945   /// A smart mutex to lock access to the internal state.
946   llvm::sys::SmartMutex<true> mutex;
947 
948   /// A mapping between the thread id and the current order id.
949   DenseMap<uint64_t, size_t> threadToOrderID;
950 
951   /// An unordered list of diagnostics that were emitted.
952   mutable std::vector<ThreadDiagnostic> diagnostics;
953 
954   /// The unique id for the parallel handler.
955   DiagnosticEngine::HandlerID handlerID = 0;
956 
957   /// The context to emit the diagnostics to.
958   MLIRContext *context;
959 };
960 } // namespace detail
961 } // namespace mlir
962 
963 ParallelDiagnosticHandler::ParallelDiagnosticHandler(MLIRContext *ctx)
964     : impl(new ParallelDiagnosticHandlerImpl(ctx)) {}
965 ParallelDiagnosticHandler::~ParallelDiagnosticHandler() = default;
966 
967 /// Set the order id for the current thread.
968 void ParallelDiagnosticHandler::setOrderIDForThread(size_t orderID) {
969   impl->setOrderIDForThread(orderID);
970 }
971 
972 /// Remove the order id for the current thread. This removes the thread from
973 /// diagnostics tracking.
974 void ParallelDiagnosticHandler::eraseOrderIDForThread() {
975   impl->eraseOrderIDForThread();
976 }
977