1 //===- Diagnostics.cpp - MLIR Diagnostics ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Diagnostics.h"
10 #include "mlir/IR/Attributes.h"
11 #include "mlir/IR/Location.h"
12 #include "mlir/IR/MLIRContext.h"
13 #include "mlir/IR/Operation.h"
14 #include "mlir/IR/Types.h"
15 #include "llvm/ADT/MapVector.h"
16 #include "llvm/ADT/SmallString.h"
17 #include "llvm/ADT/StringMap.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 #include "llvm/Support/Mutex.h"
20 #include "llvm/Support/PrettyStackTrace.h"
21 #include "llvm/Support/Regex.h"
22 #include "llvm/Support/Signals.h"
23 #include "llvm/Support/SourceMgr.h"
24 #include "llvm/Support/raw_ostream.h"
25 
26 using namespace mlir;
27 using namespace mlir::detail;
28 
29 //===----------------------------------------------------------------------===//
30 // DiagnosticArgument
31 //===----------------------------------------------------------------------===//
32 
33 /// Construct from an Attribute.
34 DiagnosticArgument::DiagnosticArgument(Attribute attr)
35     : kind(DiagnosticArgumentKind::Attribute),
36       opaqueVal(reinterpret_cast<intptr_t>(attr.getAsOpaquePointer())) {}
37 
38 /// Construct from a Type.
39 DiagnosticArgument::DiagnosticArgument(Type val)
40     : kind(DiagnosticArgumentKind::Type),
41       opaqueVal(reinterpret_cast<intptr_t>(val.getAsOpaquePointer())) {}
42 
43 /// Returns this argument as an Attribute.
44 Attribute DiagnosticArgument::getAsAttribute() const {
45   assert(getKind() == DiagnosticArgumentKind::Attribute);
46   return Attribute::getFromOpaquePointer(
47       reinterpret_cast<const void *>(opaqueVal));
48 }
49 
50 /// Returns this argument as a Type.
51 Type DiagnosticArgument::getAsType() const {
52   assert(getKind() == DiagnosticArgumentKind::Type);
53   return Type::getFromOpaquePointer(reinterpret_cast<const void *>(opaqueVal));
54 }
55 
56 /// Outputs this argument to a stream.
57 void DiagnosticArgument::print(raw_ostream &os) const {
58   switch (kind) {
59   case DiagnosticArgumentKind::Attribute:
60     os << getAsAttribute();
61     break;
62   case DiagnosticArgumentKind::Double:
63     os << getAsDouble();
64     break;
65   case DiagnosticArgumentKind::Integer:
66     os << getAsInteger();
67     break;
68   case DiagnosticArgumentKind::String:
69     os << getAsString();
70     break;
71   case DiagnosticArgumentKind::Type:
72     os << '\'' << getAsType() << '\'';
73     break;
74   case DiagnosticArgumentKind::Unsigned:
75     os << getAsUnsigned();
76     break;
77   }
78 }
79 
80 //===----------------------------------------------------------------------===//
81 // Diagnostic
82 //===----------------------------------------------------------------------===//
83 
84 /// Convert a Twine to a StringRef. Memory used for generating the StringRef is
85 /// stored in 'strings'.
86 static StringRef twineToStrRef(const Twine &val,
87                                std::vector<std::unique_ptr<char[]>> &strings) {
88   // Allocate memory to hold this string.
89   SmallString<64> data;
90   auto strRef = val.toStringRef(data);
91   if (strRef.empty())
92     return strRef;
93 
94   strings.push_back(std::unique_ptr<char[]>(new char[strRef.size()]));
95   memcpy(&strings.back()[0], strRef.data(), strRef.size());
96   // Return a reference to the new string.
97   return StringRef(&strings.back()[0], strRef.size());
98 }
99 
100 /// Stream in a Twine argument.
101 Diagnostic &Diagnostic::operator<<(char val) { return *this << Twine(val); }
102 Diagnostic &Diagnostic::operator<<(const Twine &val) {
103   arguments.push_back(DiagnosticArgument(twineToStrRef(val, strings)));
104   return *this;
105 }
106 Diagnostic &Diagnostic::operator<<(Twine &&val) {
107   arguments.push_back(DiagnosticArgument(twineToStrRef(val, strings)));
108   return *this;
109 }
110 
111 Diagnostic &Diagnostic::operator<<(StringAttr val) {
112   arguments.push_back(DiagnosticArgument(val));
113   return *this;
114 }
115 
116 /// Stream in an OperationName.
117 Diagnostic &Diagnostic::operator<<(OperationName val) {
118   // An OperationName is stored in the context, so we don't need to worry about
119   // the lifetime of its data.
120   arguments.push_back(DiagnosticArgument(val.getStringRef()));
121   return *this;
122 }
123 
124 /// Adjusts operation printing flags used in diagnostics for the given severity
125 /// level.
126 static OpPrintingFlags adjustPrintingFlags(OpPrintingFlags flags,
127                                            DiagnosticSeverity severity) {
128   flags.useLocalScope();
129   flags.elideLargeElementsAttrs();
130   if (severity == DiagnosticSeverity::Error)
131     flags.printGenericOpForm();
132   return flags;
133 }
134 
135 /// Stream in an Operation.
136 Diagnostic &Diagnostic::operator<<(Operation &val) {
137   return appendOp(val, OpPrintingFlags());
138 }
139 Diagnostic &Diagnostic::appendOp(Operation &val, const OpPrintingFlags &flags) {
140   std::string str;
141   llvm::raw_string_ostream os(str);
142   val.print(os, adjustPrintingFlags(flags, severity));
143   return *this << os.str();
144 }
145 
146 /// Stream in a Value.
147 Diagnostic &Diagnostic::operator<<(Value val) {
148   std::string str;
149   llvm::raw_string_ostream os(str);
150   val.print(os, adjustPrintingFlags(OpPrintingFlags(), severity));
151   return *this << os.str();
152 }
153 
154 /// Outputs this diagnostic to a stream.
155 void Diagnostic::print(raw_ostream &os) const {
156   for (auto &arg : getArguments())
157     arg.print(os);
158 }
159 
160 /// Convert the diagnostic to a string.
161 std::string Diagnostic::str() const {
162   std::string str;
163   llvm::raw_string_ostream os(str);
164   print(os);
165   return os.str();
166 }
167 
168 /// Attaches a note to this diagnostic. A new location may be optionally
169 /// provided, if not, then the location defaults to the one specified for this
170 /// diagnostic. Notes may not be attached to other notes.
171 Diagnostic &Diagnostic::attachNote(Optional<Location> noteLoc) {
172   // We don't allow attaching notes to notes.
173   assert(severity != DiagnosticSeverity::Note &&
174          "cannot attach a note to a note");
175 
176   // If a location wasn't provided then reuse our location.
177   if (!noteLoc)
178     noteLoc = loc;
179 
180   /// Append and return a new note.
181   notes.push_back(
182       std::make_unique<Diagnostic>(*noteLoc, DiagnosticSeverity::Note));
183   return *notes.back();
184 }
185 
186 /// Allow a diagnostic to be converted to 'failure'.
187 Diagnostic::operator LogicalResult() const { return failure(); }
188 
189 //===----------------------------------------------------------------------===//
190 // InFlightDiagnostic
191 //===----------------------------------------------------------------------===//
192 
193 /// Allow an inflight diagnostic to be converted to 'failure', otherwise
194 /// 'success' if this is an empty diagnostic.
195 InFlightDiagnostic::operator LogicalResult() const {
196   return failure(isActive());
197 }
198 
199 /// Reports the diagnostic to the engine.
200 void InFlightDiagnostic::report() {
201   // If this diagnostic is still inflight and it hasn't been abandoned, then
202   // report it.
203   if (isInFlight()) {
204     owner->emit(std::move(*impl));
205     owner = nullptr;
206   }
207   impl.reset();
208 }
209 
210 /// Abandons this diagnostic.
211 void InFlightDiagnostic::abandon() { owner = nullptr; }
212 
213 //===----------------------------------------------------------------------===//
214 // DiagnosticEngineImpl
215 //===----------------------------------------------------------------------===//
216 
namespace mlir {
namespace detail {
/// Private state backing a DiagnosticEngine: the registered handlers, the
/// counter used to hand out handler ids, and the mutex guarding them.
struct DiagnosticEngineImpl {
  /// Emit a diagnostic using the registered issue handler if present, or with
  /// the default behavior if not.
  void emit(Diagnostic &&diag);

  /// A mutex to ensure that diagnostics emission is thread-safe. It is also
  /// held while handlers are registered or erased.
  llvm::sys::SmartMutex<true> mutex;

  /// These are the handlers used to report diagnostics, keyed by the id handed
  /// out at registration. A MapVector preserves insertion order, which emit()
  /// walks in reverse so the most recently registered handler runs first.
  llvm::SmallMapVector<DiagnosticEngine::HandlerID, DiagnosticEngine::HandlerTy,
                       2>
      handlers;

  /// This is a unique identifier counter for diagnostic handlers in the
  /// context. This id starts at 1 to allow for 0 to be used as a sentinel.
  DiagnosticEngine::HandlerID uniqueHandlerId = 1;
};
} // namespace detail
} // namespace mlir
238 
239 /// Emit a diagnostic using the registered issue handle if present, or with
240 /// the default behavior if not.
241 void DiagnosticEngineImpl::emit(Diagnostic &&diag) {
242   llvm::sys::SmartScopedLock<true> lock(mutex);
243 
244   // Try to process the given diagnostic on one of the registered handlers.
245   // Handlers are walked in reverse order, so that the most recent handler is
246   // processed first.
247   for (auto &handlerIt : llvm::reverse(handlers))
248     if (succeeded(handlerIt.second(diag)))
249       return;
250 
251   // Otherwise, if this is an error we emit it to stderr.
252   if (diag.getSeverity() != DiagnosticSeverity::Error)
253     return;
254 
255   auto &os = llvm::errs();
256   if (!diag.getLocation().isa<UnknownLoc>())
257     os << diag.getLocation() << ": ";
258   os << "error: ";
259 
260   // The default behavior for errors is to emit them to stderr.
261   os << diag << '\n';
262   os.flush();
263 }
264 
265 //===----------------------------------------------------------------------===//
266 // DiagnosticEngine
267 //===----------------------------------------------------------------------===//
268 
269 DiagnosticEngine::DiagnosticEngine() : impl(new DiagnosticEngineImpl()) {}
270 DiagnosticEngine::~DiagnosticEngine() = default;
271 
272 /// Register a new handler for diagnostics to the engine. This function returns
273 /// a unique identifier for the registered handler, which can be used to
274 /// unregister this handler at a later time.
275 auto DiagnosticEngine::registerHandler(HandlerTy handler) -> HandlerID {
276   llvm::sys::SmartScopedLock<true> lock(impl->mutex);
277   auto uniqueID = impl->uniqueHandlerId++;
278   impl->handlers.insert({uniqueID, std::move(handler)});
279   return uniqueID;
280 }
281 
282 /// Erase the registered diagnostic handler with the given identifier.
283 void DiagnosticEngine::eraseHandler(HandlerID handlerID) {
284   llvm::sys::SmartScopedLock<true> lock(impl->mutex);
285   impl->handlers.erase(handlerID);
286 }
287 
288 /// Emit a diagnostic using the registered issue handler if present, or with
289 /// the default behavior if not.
290 void DiagnosticEngine::emit(Diagnostic &&diag) {
291   assert(diag.getSeverity() != DiagnosticSeverity::Note &&
292          "notes should not be emitted directly");
293   impl->emit(std::move(diag));
294 }
295 
296 /// Helper function used to emit a diagnostic with an optionally empty twine
297 /// message. If the message is empty, then it is not inserted into the
298 /// diagnostic.
299 static InFlightDiagnostic
300 emitDiag(Location location, DiagnosticSeverity severity, const Twine &message) {
301   MLIRContext *ctx = location->getContext();
302   auto &diagEngine = ctx->getDiagEngine();
303   auto diag = diagEngine.emit(location, severity);
304   if (!message.isTriviallyEmpty())
305     diag << message;
306 
307   // Add the stack trace as a note if necessary.
308   if (ctx->shouldPrintStackTraceOnDiagnostic()) {
309     std::string bt;
310     {
311       llvm::raw_string_ostream stream(bt);
312       llvm::sys::PrintStackTrace(stream);
313     }
314     if (!bt.empty())
315       diag.attachNote() << "diagnostic emitted with trace:\n" << bt;
316   }
317 
318   return diag;
319 }
320 
321 /// Emit an error message using this location.
322 InFlightDiagnostic mlir::emitError(Location loc) { return emitError(loc, {}); }
323 InFlightDiagnostic mlir::emitError(Location loc, const Twine &message) {
324   return emitDiag(loc, DiagnosticSeverity::Error, message);
325 }
326 
327 /// Emit a warning message using this location.
328 InFlightDiagnostic mlir::emitWarning(Location loc) {
329   return emitWarning(loc, {});
330 }
331 InFlightDiagnostic mlir::emitWarning(Location loc, const Twine &message) {
332   return emitDiag(loc, DiagnosticSeverity::Warning, message);
333 }
334 
335 /// Emit a remark message using this location.
336 InFlightDiagnostic mlir::emitRemark(Location loc) {
337   return emitRemark(loc, {});
338 }
339 InFlightDiagnostic mlir::emitRemark(Location loc, const Twine &message) {
340   return emitDiag(loc, DiagnosticSeverity::Remark, message);
341 }
342 
343 //===----------------------------------------------------------------------===//
344 // ScopedDiagnosticHandler
345 //===----------------------------------------------------------------------===//
346 
347 ScopedDiagnosticHandler::~ScopedDiagnosticHandler() {
348   if (handlerID)
349     ctx->getDiagEngine().eraseHandler(handlerID);
350 }
351 
352 //===----------------------------------------------------------------------===//
353 // SourceMgrDiagnosticHandler
354 //===----------------------------------------------------------------------===//
355 namespace mlir {
356 namespace detail {
357 struct SourceMgrDiagnosticHandlerImpl {
358   /// Return the SrcManager buffer id for the specified file, or zero if none
359   /// can be found.
360   unsigned getSourceMgrBufferIDForFile(llvm::SourceMgr &mgr,
361                                        StringRef filename) {
362     // Check for an existing mapping to the buffer id for this file.
363     auto bufferIt = filenameToBufId.find(filename);
364     if (bufferIt != filenameToBufId.end())
365       return bufferIt->second;
366 
367     // Look for a buffer in the manager that has this filename.
368     for (unsigned i = 1, e = mgr.getNumBuffers() + 1; i != e; ++i) {
369       auto *buf = mgr.getMemoryBuffer(i);
370       if (buf->getBufferIdentifier() == filename)
371         return filenameToBufId[filename] = i;
372     }
373 
374     // Otherwise, try to load the source file.
375     std::string ignored;
376     unsigned id =
377         mgr.AddIncludeFile(std::string(filename), SMLoc(), ignored);
378     filenameToBufId[filename] = id;
379     return id;
380   }
381 
382   /// Mapping between file name and buffer ID's.
383   llvm::StringMap<unsigned> filenameToBufId;
384 };
385 } // namespace detail
386 } // namespace mlir
387 
388 /// Return a processable FileLineColLoc from the given location.
389 static Optional<FileLineColLoc> getFileLineColLoc(Location loc) {
390   Optional<FileLineColLoc> firstFileLoc;
391   loc->walk([&](Location loc) {
392     if (FileLineColLoc fileLoc = loc.dyn_cast<FileLineColLoc>()) {
393       firstFileLoc = fileLoc;
394       return WalkResult::interrupt();
395     }
396     return WalkResult::advance();
397   });
398   return firstFileLoc;
399 }
400 
401 /// Return a processable CallSiteLoc from the given location.
402 static Optional<CallSiteLoc> getCallSiteLoc(Location loc) {
403   if (auto nameLoc = loc.dyn_cast<NameLoc>())
404     return getCallSiteLoc(loc.cast<NameLoc>().getChildLoc());
405   if (auto callLoc = loc.dyn_cast<CallSiteLoc>())
406     return callLoc;
407   if (auto fusedLoc = loc.dyn_cast<FusedLoc>()) {
408     for (auto subLoc : loc.cast<FusedLoc>().getLocations()) {
409       if (auto callLoc = getCallSiteLoc(subLoc)) {
410         return callLoc;
411       }
412     }
413     return llvm::None;
414   }
415   return llvm::None;
416 }
417 
418 /// Given a diagnostic kind, returns the LLVM DiagKind.
419 static llvm::SourceMgr::DiagKind getDiagKind(DiagnosticSeverity kind) {
420   switch (kind) {
421   case DiagnosticSeverity::Note:
422     return llvm::SourceMgr::DK_Note;
423   case DiagnosticSeverity::Warning:
424     return llvm::SourceMgr::DK_Warning;
425   case DiagnosticSeverity::Error:
426     return llvm::SourceMgr::DK_Error;
427   case DiagnosticSeverity::Remark:
428     return llvm::SourceMgr::DK_Remark;
429   }
430   llvm_unreachable("Unknown DiagnosticSeverity");
431 }
432 
/// Construct a handler that renders diagnostics emitted in `ctx` to `os`,
/// using `mgr` to map file locations onto source buffers. `shouldShowLocFn`,
/// when set, filters which locations are displayed (see findLocToShow).
SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(
    llvm::SourceMgr &mgr, MLIRContext *ctx, raw_ostream &os,
    ShouldShowLocFn &&shouldShowLocFn)
    : ScopedDiagnosticHandler(ctx), mgr(mgr), os(os),
      shouldShowLocFn(std::move(shouldShowLocFn)),
      impl(new SourceMgrDiagnosticHandlerImpl()) {
  // Install emitDiagnostic as this scoped handler's callback.
  setHandler([this](Diagnostic &diag) { emitDiagnostic(diag); });
}

/// Construct a handler that writes its output to llvm::errs().
SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(
    llvm::SourceMgr &mgr, MLIRContext *ctx, ShouldShowLocFn &&shouldShowLocFn)
    : SourceMgrDiagnosticHandler(mgr, ctx, llvm::errs(),
                                 std::move(shouldShowLocFn)) {}

SourceMgrDiagnosticHandler::~SourceMgrDiagnosticHandler() = default;
448 
449 void SourceMgrDiagnosticHandler::emitDiagnostic(Location loc, Twine message,
450                                                 DiagnosticSeverity kind,
451                                                 bool displaySourceLine) {
452   // Extract a file location from this loc.
453   auto fileLoc = getFileLineColLoc(loc);
454 
455   // If one doesn't exist, then print the raw message without a source location.
456   if (!fileLoc) {
457     std::string str;
458     llvm::raw_string_ostream strOS(str);
459     if (!loc.isa<UnknownLoc>())
460       strOS << loc << ": ";
461     strOS << message;
462     return mgr.PrintMessage(os, SMLoc(), getDiagKind(kind), strOS.str());
463   }
464 
465   // Otherwise if we are displaying the source line, try to convert the file
466   // location to an SMLoc.
467   if (displaySourceLine) {
468     auto smloc = convertLocToSMLoc(*fileLoc);
469     if (smloc.isValid())
470       return mgr.PrintMessage(os, smloc, getDiagKind(kind), message);
471   }
472 
473   // If the conversion was unsuccessful, create a diagnostic with the file
474   // information. We manually combine the line and column to avoid asserts in
475   // the constructor of SMDiagnostic that takes a location.
476   std::string locStr;
477   llvm::raw_string_ostream locOS(locStr);
478   locOS << fileLoc->getFilename().getValue() << ":" << fileLoc->getLine() << ":"
479         << fileLoc->getColumn();
480   llvm::SMDiagnostic diag(locOS.str(), getDiagKind(kind), message.str());
481   diag.print(nullptr, os);
482 }
483 
/// Emit the given diagnostic, followed by its notes, with the held source
/// manager. The main message is printed at the first showable location. If
/// the diagnostic's location contains a call site, each showable caller (up to
/// `callStackLimit` frames) is then printed as a "called from" note. Attached
/// notes come last.
void SourceMgrDiagnosticHandler::emitDiagnostic(Diagnostic &diag) {
  // Each entry pairs a location to print with the text to print there. The
  // first entry's text is unused; the diagnostic's own text is printed instead.
  SmallVector<std::pair<Location, StringRef>> locationStack;
  // Push the displayable form of `loc` (per findLocToShow), if it has one.
  auto addLocToStack = [&](Location loc, StringRef locContext) {
    if (Optional<Location> showableLoc = findLocToShow(loc))
      locationStack.emplace_back(*showableLoc, locContext);
  };

  // Add locations to display for this diagnostic.
  Location loc = diag.getLocation();
  addLocToStack(loc, /*locContext=*/{});

  // If the diagnostic location was a call site location, add the call stack as
  // well.
  // NOTE(review): after this walk `loc` holds the last caller visited, not the
  // diagnostic's own location. When the depth limit is hit, that caller was
  // never pushed. The note loop below compares against this value; confirm
  // that is intended.
  if (auto callLoc = getCallSiteLoc(loc)) {
    // Print the call stack while valid, or until the limit is reached.
    loc = callLoc->getCaller();
    for (unsigned curDepth = 0; curDepth < callStackLimit; ++curDepth) {
      addLocToStack(loc, "called from");
      if ((callLoc = getCallSiteLoc(loc)))
        loc = callLoc->getCaller();
      else
        break;
    }
  }

  // If the location stack is empty, use the initial location.
  if (locationStack.empty()) {
    emitDiagnostic(diag.getLocation(), diag.str(), diag.getSeverity());

    // Otherwise, use the location stack.
  } else {
    // The primary message goes at the first entry; every further entry is a
    // "called from" frame printed as a note.
    emitDiagnostic(locationStack.front().first, diag.str(), diag.getSeverity());
    for (auto &it : llvm::drop_begin(locationStack))
      emitDiagnostic(it.first, it.second, DiagnosticSeverity::Note);
  }

  // Emit each of the notes. Only display the source code if the location is
  // different from the previous location.
  for (auto &note : diag.getNotes()) {
    emitDiagnostic(note.getLocation(), note.str(), note.getSeverity(),
                   /*displaySourceLine=*/loc != note.getLocation());
    loc = note.getLocation();
  }
}
529 
530 /// Get a memory buffer for the given file, or nullptr if one is not found.
531 const llvm::MemoryBuffer *
532 SourceMgrDiagnosticHandler::getBufferForFile(StringRef filename) {
533   if (unsigned id = impl->getSourceMgrBufferIDForFile(mgr, filename))
534     return mgr.getMemoryBuffer(id);
535   return nullptr;
536 }
537 
538 Optional<Location> SourceMgrDiagnosticHandler::findLocToShow(Location loc) {
539   if (!shouldShowLocFn)
540     return loc;
541   if (!shouldShowLocFn(loc))
542     return llvm::None;
543 
544   // Recurse into the child locations of some of location types.
545   return TypeSwitch<LocationAttr, Optional<Location>>(loc)
546       .Case([&](CallSiteLoc callLoc) -> Optional<Location> {
547         // We recurse into the callee of a call site, as the caller will be
548         // emitted in a different note on the main diagnostic.
549         return findLocToShow(callLoc.getCallee());
550       })
551       .Case([&](FileLineColLoc) -> Optional<Location> { return loc; })
552       .Case([&](FusedLoc fusedLoc) -> Optional<Location> {
553         // Fused location is unique in that we try to find a sub-location to
554         // show, rather than the top-level location itself.
555         for (Location childLoc : fusedLoc.getLocations())
556           if (Optional<Location> showableLoc = findLocToShow(childLoc))
557             return showableLoc;
558         return llvm::None;
559       })
560       .Case([&](NameLoc nameLoc) -> Optional<Location> {
561         return findLocToShow(nameLoc.getChildLoc());
562       })
563       .Case([&](OpaqueLoc opaqueLoc) -> Optional<Location> {
564         // OpaqueLoc always falls back to a different source location.
565         return findLocToShow(opaqueLoc.getFallbackLocation());
566       })
567       .Case([](UnknownLoc) -> Optional<Location> {
568         // Prefer not to show unknown locations.
569         return llvm::None;
570       });
571 }
572 
573 /// Get a memory buffer for the given file, or the main file of the source
574 /// manager if one doesn't exist. This always returns non-null.
575 SMLoc SourceMgrDiagnosticHandler::convertLocToSMLoc(FileLineColLoc loc) {
576   // The column and line may be zero to represent unknown column and/or unknown
577   /// line/column information.
578   if (loc.getLine() == 0 || loc.getColumn() == 0)
579     return SMLoc();
580 
581   unsigned bufferId = impl->getSourceMgrBufferIDForFile(mgr, loc.getFilename());
582   if (!bufferId)
583     return SMLoc();
584   return mgr.FindLocForLineAndColumn(bufferId, loc.getLine(), loc.getColumn());
585 }
586 
587 //===----------------------------------------------------------------------===//
588 // SourceMgrDiagnosticVerifierHandler
589 //===----------------------------------------------------------------------===//
590 
591 namespace mlir {
592 namespace detail {
593 /// This class represents an expected output diagnostic.
594 struct ExpectedDiag {
595   ExpectedDiag(DiagnosticSeverity kind, unsigned lineNo, SMLoc fileLoc,
596                StringRef substring)
597       : kind(kind), lineNo(lineNo), fileLoc(fileLoc), substring(substring) {}
598 
599   /// Emit an error at the location referenced by this diagnostic.
600   LogicalResult emitError(raw_ostream &os, llvm::SourceMgr &mgr,
601                           const Twine &msg) {
602     SMRange range(fileLoc, SMLoc::getFromPointer(fileLoc.getPointer() +
603                                                  substring.size()));
604     mgr.PrintMessage(os, fileLoc, llvm::SourceMgr::DK_Error, msg, range);
605     return failure();
606   }
607 
608   /// Returns true if this diagnostic matches the given string.
609   bool match(StringRef str) const {
610     // If this isn't a regex diagnostic, we simply check if the string was
611     // contained.
612     if (substringRegex)
613       return substringRegex->match(str);
614     return str.contains(substring);
615   }
616 
617   /// Compute the regex matcher for this diagnostic, using the provided stream
618   /// and manager to emit diagnostics as necessary.
619   LogicalResult computeRegex(raw_ostream &os, llvm::SourceMgr &mgr) {
620     std::string regexStr;
621     llvm::raw_string_ostream regexOS(regexStr);
622     StringRef strToProcess = substring;
623     while (!strToProcess.empty()) {
624       // Find the next regex block.
625       size_t regexIt = strToProcess.find("{{");
626       if (regexIt == StringRef::npos) {
627         regexOS << llvm::Regex::escape(strToProcess);
628         break;
629       }
630       regexOS << llvm::Regex::escape(strToProcess.take_front(regexIt));
631       strToProcess = strToProcess.drop_front(regexIt + 2);
632 
633       // Find the end of the regex block.
634       size_t regexEndIt = strToProcess.find("}}");
635       if (regexEndIt == StringRef::npos)
636         return emitError(os, mgr, "found start of regex with no end '}}'");
637       StringRef regexStr = strToProcess.take_front(regexEndIt);
638 
639       // Validate that the regex is actually valid.
640       std::string regexError;
641       if (!llvm::Regex(regexStr).isValid(regexError))
642         return emitError(os, mgr, "invalid regex: " + regexError);
643 
644       regexOS << '(' << regexStr << ')';
645       strToProcess = strToProcess.drop_front(regexEndIt + 2);
646     }
647     substringRegex = llvm::Regex(regexOS.str());
648     return success();
649   }
650 
651   /// The severity of the diagnosic expected.
652   DiagnosticSeverity kind;
653   /// The line number the expected diagnostic should be on.
654   unsigned lineNo;
655   /// The location of the expected diagnostic within the input file.
656   SMLoc fileLoc;
657   /// A flag indicating if the expected diagnostic has been matched yet.
658   bool matched = false;
659   /// The substring that is expected to be within the diagnostic.
660   StringRef substring;
661   /// An optional regex matcher, if the expected diagnostic sub-string was a
662   /// regex string.
663   Optional<llvm::Regex> substringRegex;
664 };
665 
666 struct SourceMgrDiagnosticVerifierHandlerImpl {
667   SourceMgrDiagnosticVerifierHandlerImpl() : status(success()) {}
668 
669   /// Returns the expected diagnostics for the given source file.
670   Optional<MutableArrayRef<ExpectedDiag>> getExpectedDiags(StringRef bufName);
671 
672   /// Computes the expected diagnostics for the given source buffer.
673   MutableArrayRef<ExpectedDiag>
674   computeExpectedDiags(raw_ostream &os, llvm::SourceMgr &mgr,
675                        const llvm::MemoryBuffer *buf);
676 
677   /// The current status of the verifier.
678   LogicalResult status;
679 
680   /// A list of expected diagnostics for each buffer of the source manager.
681   llvm::StringMap<SmallVector<ExpectedDiag, 2>> expectedDiagsPerFile;
682 
683   /// Regex to match the expected diagnostics format.
684   llvm::Regex expected =
685       llvm::Regex("expected-(error|note|remark|warning)(-re)? "
686                   "*(@([+-][0-9]+|above|below))? *{{(.*)}}$");
687 };
688 } // namespace detail
689 } // namespace mlir
690 
691 /// Given a diagnostic kind, return a human readable string for it.
692 static StringRef getDiagKindStr(DiagnosticSeverity kind) {
693   switch (kind) {
694   case DiagnosticSeverity::Note:
695     return "note";
696   case DiagnosticSeverity::Warning:
697     return "warning";
698   case DiagnosticSeverity::Error:
699     return "error";
700   case DiagnosticSeverity::Remark:
701     return "remark";
702   }
703   llvm_unreachable("Unknown DiagnosticSeverity");
704 }
705 
706 Optional<MutableArrayRef<ExpectedDiag>>
707 SourceMgrDiagnosticVerifierHandlerImpl::getExpectedDiags(StringRef bufName) {
708   auto expectedDiags = expectedDiagsPerFile.find(bufName);
709   if (expectedDiags != expectedDiagsPerFile.end())
710     return MutableArrayRef<ExpectedDiag>(expectedDiags->second);
711   return llvm::None;
712 }
713 
714 MutableArrayRef<ExpectedDiag>
715 SourceMgrDiagnosticVerifierHandlerImpl::computeExpectedDiags(
716     raw_ostream &os, llvm::SourceMgr &mgr, const llvm::MemoryBuffer *buf) {
717   // If the buffer is invalid, return an empty list.
718   if (!buf)
719     return llvm::None;
720   auto &expectedDiags = expectedDiagsPerFile[buf->getBufferIdentifier()];
721 
722   // The number of the last line that did not correlate to a designator.
723   unsigned lastNonDesignatorLine = 0;
724 
725   // The indices of designators that apply to the next non designator line.
726   SmallVector<unsigned, 1> designatorsForNextLine;
727 
728   // Scan the file for expected-* designators.
729   SmallVector<StringRef, 100> lines;
730   buf->getBuffer().split(lines, '\n');
731   for (unsigned lineNo = 0, e = lines.size(); lineNo < e; ++lineNo) {
732     SmallVector<StringRef, 4> matches;
733     if (!expected.match(lines[lineNo].rtrim(), &matches)) {
734       // Check for designators that apply to this line.
735       if (!designatorsForNextLine.empty()) {
736         for (unsigned diagIndex : designatorsForNextLine)
737           expectedDiags[diagIndex].lineNo = lineNo + 1;
738         designatorsForNextLine.clear();
739       }
740       lastNonDesignatorLine = lineNo;
741       continue;
742     }
743 
744     // Point to the start of expected-*.
745     SMLoc expectedStart = SMLoc::getFromPointer(matches[0].data());
746 
747     DiagnosticSeverity kind;
748     if (matches[1] == "error")
749       kind = DiagnosticSeverity::Error;
750     else if (matches[1] == "warning")
751       kind = DiagnosticSeverity::Warning;
752     else if (matches[1] == "remark")
753       kind = DiagnosticSeverity::Remark;
754     else {
755       assert(matches[1] == "note");
756       kind = DiagnosticSeverity::Note;
757     }
758     ExpectedDiag record(kind, lineNo + 1, expectedStart, matches[5]);
759 
760     // Check to see if this is a regex match, i.e. it includes the `-re`.
761     if (!matches[2].empty() && failed(record.computeRegex(os, mgr))) {
762       status = failure();
763       continue;
764     }
765 
766     StringRef offsetMatch = matches[3];
767     if (!offsetMatch.empty()) {
768       offsetMatch = offsetMatch.drop_front(1);
769 
770       // Get the integer value without the @ and +/- prefix.
771       if (offsetMatch[0] == '+' || offsetMatch[0] == '-') {
772         int offset;
773         offsetMatch.drop_front().getAsInteger(0, offset);
774 
775         if (offsetMatch.front() == '+')
776           record.lineNo += offset;
777         else
778           record.lineNo -= offset;
779       } else if (offsetMatch.consume_front("above")) {
780         // If the designator applies 'above' we add it to the last non
781         // designator line.
782         record.lineNo = lastNonDesignatorLine + 1;
783       } else {
784         // Otherwise, this is a 'below' designator and applies to the next
785         // non-designator line.
786         assert(offsetMatch.consume_front("below"));
787         designatorsForNextLine.push_back(expectedDiags.size());
788 
789         // Set the line number to the last in the case that this designator ends
790         // up dangling.
791         record.lineNo = e;
792       }
793     }
794     expectedDiags.emplace_back(std::move(record));
795   }
796   return expectedDiags;
797 }
798 
799 SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
800     llvm::SourceMgr &srcMgr, MLIRContext *ctx, raw_ostream &out)
801     : SourceMgrDiagnosticHandler(srcMgr, ctx, out),
802       impl(new SourceMgrDiagnosticVerifierHandlerImpl()) {
803   // Compute the expected diagnostics for each of the current files in the
804   // source manager.
805   for (unsigned i = 0, e = mgr.getNumBuffers(); i != e; ++i)
806     (void)impl->computeExpectedDiags(out, mgr, mgr.getMemoryBuffer(i + 1));
807 
808   // Register a handler to verify the diagnostics.
809   setHandler([&](Diagnostic &diag) {
810     // Process the main diagnostics.
811     process(diag);
812 
813     // Process each of the notes.
814     for (auto &note : diag.getNotes())
815       process(note);
816   });
817 }
818 
819 SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
820     llvm::SourceMgr &srcMgr, MLIRContext *ctx)
821     : SourceMgrDiagnosticVerifierHandler(srcMgr, ctx, llvm::errs()) {}
822 
823 SourceMgrDiagnosticVerifierHandler::~SourceMgrDiagnosticVerifierHandler() {
824   // Ensure that all expected diagnostics were handled.
825   (void)verify();
826 }
827 
828 /// Returns the status of the verifier and verifies that all expected
829 /// diagnostics were emitted. This return success if all diagnostics were
830 /// verified correctly, failure otherwise.
831 LogicalResult SourceMgrDiagnosticVerifierHandler::verify() {
832   // Verify that all expected errors were seen.
833   for (auto &expectedDiagsPair : impl->expectedDiagsPerFile) {
834     for (auto &err : expectedDiagsPair.second) {
835       if (err.matched)
836         continue;
837       impl->status =
838           err.emitError(os, mgr,
839                         "expected " + getDiagKindStr(err.kind) + " \"" +
840                             err.substring + "\" was not produced");
841     }
842   }
843   impl->expectedDiagsPerFile.clear();
844   return impl->status;
845 }
846 
847 /// Process a single diagnostic.
848 void SourceMgrDiagnosticVerifierHandler::process(Diagnostic &diag) {
849   auto kind = diag.getSeverity();
850 
851   // Process a FileLineColLoc.
852   if (auto fileLoc = getFileLineColLoc(diag.getLocation()))
853     return process(*fileLoc, diag.str(), kind);
854 
855   emitDiagnostic(diag.getLocation(),
856                  "unexpected " + getDiagKindStr(kind) + ": " + diag.str(),
857                  DiagnosticSeverity::Error);
858   impl->status = failure();
859 }
860 
861 /// Process a FileLineColLoc diagnostic.
862 void SourceMgrDiagnosticVerifierHandler::process(FileLineColLoc loc,
863                                                  StringRef msg,
864                                                  DiagnosticSeverity kind) {
865   // Get the expected diagnostics for this file.
866   auto diags = impl->getExpectedDiags(loc.getFilename());
867   if (!diags) {
868     diags = impl->computeExpectedDiags(os, mgr,
869                                        getBufferForFile(loc.getFilename()));
870   }
871 
872   // Search for a matching expected diagnostic.
873   // If we find something that is close then emit a more specific error.
874   ExpectedDiag *nearMiss = nullptr;
875 
876   // If this was an expected error, remember that we saw it and return.
877   unsigned line = loc.getLine();
878   for (auto &e : *diags) {
879     if (line == e.lineNo && e.match(msg)) {
880       if (e.kind == kind) {
881         e.matched = true;
882         return;
883       }
884 
885       // If this only differs based on the diagnostic kind, then consider it
886       // to be a near miss.
887       nearMiss = &e;
888     }
889   }
890 
891   // Otherwise, emit an error for the near miss.
892   if (nearMiss)
893     mgr.PrintMessage(os, nearMiss->fileLoc, llvm::SourceMgr::DK_Error,
894                      "'" + getDiagKindStr(kind) +
895                          "' diagnostic emitted when expecting a '" +
896                          getDiagKindStr(nearMiss->kind) + "'");
897   else
898     emitDiagnostic(loc, "unexpected " + getDiagKindStr(kind) + ": " + msg,
899                    DiagnosticSeverity::Error);
900   impl->status = failure();
901 }
902 
903 //===----------------------------------------------------------------------===//
904 // ParallelDiagnosticHandler
905 //===----------------------------------------------------------------------===//
906 
907 namespace mlir {
908 namespace detail {
909 struct ParallelDiagnosticHandlerImpl : public llvm::PrettyStackTraceEntry {
910   struct ThreadDiagnostic {
911     ThreadDiagnostic(size_t id, Diagnostic diag)
912         : id(id), diag(std::move(diag)) {}
913     bool operator<(const ThreadDiagnostic &rhs) const { return id < rhs.id; }
914 
915     /// The id for this diagnostic, this is used for ordering.
916     /// Note: This id corresponds to the ordered position of the current element
917     ///       being processed by a given thread.
918     size_t id;
919 
920     /// The diagnostic.
921     Diagnostic diag;
922   };
923 
924   ParallelDiagnosticHandlerImpl(MLIRContext *ctx) : context(ctx) {
925     handlerID = ctx->getDiagEngine().registerHandler([this](Diagnostic &diag) {
926       uint64_t tid = llvm::get_threadid();
927       llvm::sys::SmartScopedLock<true> lock(mutex);
928 
929       // If this thread is not tracked, then return failure to let another
930       // handler process this diagnostic.
931       if (!threadToOrderID.count(tid))
932         return failure();
933 
934       // Append a new diagnostic.
935       diagnostics.emplace_back(threadToOrderID[tid], std::move(diag));
936       return success();
937     });
938   }
939 
940   ~ParallelDiagnosticHandlerImpl() override {
941     // Erase this handler from the context.
942     context->getDiagEngine().eraseHandler(handlerID);
943 
944     // Early exit if there are no diagnostics, this is the common case.
945     if (diagnostics.empty())
946       return;
947 
948     // Emit the diagnostics back to the context.
949     emitDiagnostics([&](Diagnostic &diag) {
950       return context->getDiagEngine().emit(std::move(diag));
951     });
952   }
953 
954   /// Utility method to emit any held diagnostics.
955   void emitDiagnostics(llvm::function_ref<void(Diagnostic &)> emitFn) const {
956     // Stable sort all of the diagnostics that were emitted. This creates a
957     // deterministic ordering for the diagnostics based upon which order id they
958     // were emitted for.
959     std::stable_sort(diagnostics.begin(), diagnostics.end());
960 
961     // Emit each diagnostic to the context again.
962     for (ThreadDiagnostic &diag : diagnostics)
963       emitFn(diag.diag);
964   }
965 
966   /// Set the order id for the current thread.
967   void setOrderIDForThread(size_t orderID) {
968     uint64_t tid = llvm::get_threadid();
969     llvm::sys::SmartScopedLock<true> lock(mutex);
970     threadToOrderID[tid] = orderID;
971   }
972 
973   /// Remove the order id for the current thread.
974   void eraseOrderIDForThread() {
975     uint64_t tid = llvm::get_threadid();
976     llvm::sys::SmartScopedLock<true> lock(mutex);
977     threadToOrderID.erase(tid);
978   }
979 
980   /// Dump the current diagnostics that were inflight.
981   void print(raw_ostream &os) const override {
982     // Early exit if there are no diagnostics, this is the common case.
983     if (diagnostics.empty())
984       return;
985 
986     os << "In-Flight Diagnostics:\n";
987     emitDiagnostics([&](const Diagnostic &diag) {
988       os.indent(4);
989 
990       // Print each diagnostic with the format:
991       //   "<location>: <kind>: <msg>"
992       if (!diag.getLocation().isa<UnknownLoc>())
993         os << diag.getLocation() << ": ";
994       switch (diag.getSeverity()) {
995       case DiagnosticSeverity::Error:
996         os << "error: ";
997         break;
998       case DiagnosticSeverity::Warning:
999         os << "warning: ";
1000         break;
1001       case DiagnosticSeverity::Note:
1002         os << "note: ";
1003         break;
1004       case DiagnosticSeverity::Remark:
1005         os << "remark: ";
1006         break;
1007       }
1008       os << diag << '\n';
1009     });
1010   }
1011 
1012   /// A smart mutex to lock access to the internal state.
1013   llvm::sys::SmartMutex<true> mutex;
1014 
1015   /// A mapping between the thread id and the current order id.
1016   DenseMap<uint64_t, size_t> threadToOrderID;
1017 
1018   /// An unordered list of diagnostics that were emitted.
1019   mutable std::vector<ThreadDiagnostic> diagnostics;
1020 
1021   /// The unique id for the parallel handler.
1022   DiagnosticEngine::HandlerID handlerID = 0;
1023 
1024   /// The context to emit the diagnostics to.
1025   MLIRContext *context;
1026 };
1027 } // namespace detail
1028 } // namespace mlir
1029 
1030 ParallelDiagnosticHandler::ParallelDiagnosticHandler(MLIRContext *ctx)
1031     : impl(new ParallelDiagnosticHandlerImpl(ctx)) {}
1032 ParallelDiagnosticHandler::~ParallelDiagnosticHandler() = default;
1033 
1034 /// Set the order id for the current thread.
1035 void ParallelDiagnosticHandler::setOrderIDForThread(size_t orderID) {
1036   impl->setOrderIDForThread(orderID);
1037 }
1038 
1039 /// Remove the order id for the current thread. This removes the thread from
1040 /// diagnostics tracking.
1041 void ParallelDiagnosticHandler::eraseOrderIDForThread() {
1042   impl->eraseOrderIDForThread();
1043 }
1044