1 //===- Diagnostics.cpp - MLIR Diagnostics ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/IR/Diagnostics.h"
10 #include "mlir/IR/Attributes.h"
11 #include "mlir/IR/Location.h"
12 #include "mlir/IR/MLIRContext.h"
13 #include "mlir/IR/Operation.h"
14 #include "mlir/IR/Types.h"
15 #include "llvm/ADT/MapVector.h"
16 #include "llvm/ADT/SmallString.h"
17 #include "llvm/ADT/StringMap.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 #include "llvm/Support/Mutex.h"
20 #include "llvm/Support/PrettyStackTrace.h"
21 #include "llvm/Support/Regex.h"
22 #include "llvm/Support/Signals.h"
23 #include "llvm/Support/SourceMgr.h"
24 #include "llvm/Support/raw_ostream.h"
25
26 using namespace mlir;
27 using namespace mlir::detail;
28
29 //===----------------------------------------------------------------------===//
30 // DiagnosticArgument
31 //===----------------------------------------------------------------------===//
32
33 /// Construct from an Attribute.
DiagnosticArgument(Attribute attr)34 DiagnosticArgument::DiagnosticArgument(Attribute attr)
35 : kind(DiagnosticArgumentKind::Attribute),
36 opaqueVal(reinterpret_cast<intptr_t>(attr.getAsOpaquePointer())) {}
37
38 /// Construct from a Type.
DiagnosticArgument(Type val)39 DiagnosticArgument::DiagnosticArgument(Type val)
40 : kind(DiagnosticArgumentKind::Type),
41 opaqueVal(reinterpret_cast<intptr_t>(val.getAsOpaquePointer())) {}
42
43 /// Returns this argument as an Attribute.
getAsAttribute() const44 Attribute DiagnosticArgument::getAsAttribute() const {
45 assert(getKind() == DiagnosticArgumentKind::Attribute);
46 return Attribute::getFromOpaquePointer(
47 reinterpret_cast<const void *>(opaqueVal));
48 }
49
50 /// Returns this argument as a Type.
getAsType() const51 Type DiagnosticArgument::getAsType() const {
52 assert(getKind() == DiagnosticArgumentKind::Type);
53 return Type::getFromOpaquePointer(reinterpret_cast<const void *>(opaqueVal));
54 }
55
56 /// Outputs this argument to a stream.
print(raw_ostream & os) const57 void DiagnosticArgument::print(raw_ostream &os) const {
58 switch (kind) {
59 case DiagnosticArgumentKind::Attribute:
60 os << getAsAttribute();
61 break;
62 case DiagnosticArgumentKind::Double:
63 os << getAsDouble();
64 break;
65 case DiagnosticArgumentKind::Integer:
66 os << getAsInteger();
67 break;
68 case DiagnosticArgumentKind::String:
69 os << getAsString();
70 break;
71 case DiagnosticArgumentKind::Type:
72 os << '\'' << getAsType() << '\'';
73 break;
74 case DiagnosticArgumentKind::Unsigned:
75 os << getAsUnsigned();
76 break;
77 }
78 }
79
80 //===----------------------------------------------------------------------===//
81 // Diagnostic
82 //===----------------------------------------------------------------------===//
83
84 /// Convert a Twine to a StringRef. Memory used for generating the StringRef is
85 /// stored in 'strings'.
twineToStrRef(const Twine & val,std::vector<std::unique_ptr<char[]>> & strings)86 static StringRef twineToStrRef(const Twine &val,
87 std::vector<std::unique_ptr<char[]>> &strings) {
88 // Allocate memory to hold this string.
89 SmallString<64> data;
90 auto strRef = val.toStringRef(data);
91 if (strRef.empty())
92 return strRef;
93
94 strings.push_back(std::unique_ptr<char[]>(new char[strRef.size()]));
95 memcpy(&strings.back()[0], strRef.data(), strRef.size());
96 // Return a reference to the new string.
97 return StringRef(&strings.back()[0], strRef.size());
98 }
99
100 /// Stream in a Twine argument.
operator <<(char val)101 Diagnostic &Diagnostic::operator<<(char val) { return *this << Twine(val); }
operator <<(const Twine & val)102 Diagnostic &Diagnostic::operator<<(const Twine &val) {
103 arguments.push_back(DiagnosticArgument(twineToStrRef(val, strings)));
104 return *this;
105 }
operator <<(Twine && val)106 Diagnostic &Diagnostic::operator<<(Twine &&val) {
107 arguments.push_back(DiagnosticArgument(twineToStrRef(val, strings)));
108 return *this;
109 }
110
operator <<(StringAttr val)111 Diagnostic &Diagnostic::operator<<(StringAttr val) {
112 arguments.push_back(DiagnosticArgument(val));
113 return *this;
114 }
115
116 /// Stream in an OperationName.
operator <<(OperationName val)117 Diagnostic &Diagnostic::operator<<(OperationName val) {
118 // An OperationName is stored in the context, so we don't need to worry about
119 // the lifetime of its data.
120 arguments.push_back(DiagnosticArgument(val.getStringRef()));
121 return *this;
122 }
123
124 /// Adjusts operation printing flags used in diagnostics for the given severity
125 /// level.
adjustPrintingFlags(OpPrintingFlags flags,DiagnosticSeverity severity)126 static OpPrintingFlags adjustPrintingFlags(OpPrintingFlags flags,
127 DiagnosticSeverity severity) {
128 flags.useLocalScope();
129 flags.elideLargeElementsAttrs();
130 if (severity == DiagnosticSeverity::Error)
131 flags.printGenericOpForm();
132 return flags;
133 }
134
135 /// Stream in an Operation.
operator <<(Operation & val)136 Diagnostic &Diagnostic::operator<<(Operation &val) {
137 return appendOp(val, OpPrintingFlags());
138 }
appendOp(Operation & val,const OpPrintingFlags & flags)139 Diagnostic &Diagnostic::appendOp(Operation &val, const OpPrintingFlags &flags) {
140 std::string str;
141 llvm::raw_string_ostream os(str);
142 val.print(os, adjustPrintingFlags(flags, severity));
143 return *this << os.str();
144 }
145
146 /// Stream in a Value.
operator <<(Value val)147 Diagnostic &Diagnostic::operator<<(Value val) {
148 std::string str;
149 llvm::raw_string_ostream os(str);
150 val.print(os, adjustPrintingFlags(OpPrintingFlags(), severity));
151 return *this << os.str();
152 }
153
154 /// Outputs this diagnostic to a stream.
print(raw_ostream & os) const155 void Diagnostic::print(raw_ostream &os) const {
156 for (auto &arg : getArguments())
157 arg.print(os);
158 }
159
160 /// Convert the diagnostic to a string.
str() const161 std::string Diagnostic::str() const {
162 std::string str;
163 llvm::raw_string_ostream os(str);
164 print(os);
165 return os.str();
166 }
167
168 /// Attaches a note to this diagnostic. A new location may be optionally
169 /// provided, if not, then the location defaults to the one specified for this
170 /// diagnostic. Notes may not be attached to other notes.
attachNote(Optional<Location> noteLoc)171 Diagnostic &Diagnostic::attachNote(Optional<Location> noteLoc) {
172 // We don't allow attaching notes to notes.
173 assert(severity != DiagnosticSeverity::Note &&
174 "cannot attach a note to a note");
175
176 // If a location wasn't provided then reuse our location.
177 if (!noteLoc)
178 noteLoc = loc;
179
180 /// Append and return a new note.
181 notes.push_back(
182 std::make_unique<Diagnostic>(*noteLoc, DiagnosticSeverity::Note));
183 return *notes.back();
184 }
185
186 /// Allow a diagnostic to be converted to 'failure'.
operator LogicalResult() const187 Diagnostic::operator LogicalResult() const { return failure(); }
188
189 //===----------------------------------------------------------------------===//
190 // InFlightDiagnostic
191 //===----------------------------------------------------------------------===//
192
193 /// Allow an inflight diagnostic to be converted to 'failure', otherwise
194 /// 'success' if this is an empty diagnostic.
operator LogicalResult() const195 InFlightDiagnostic::operator LogicalResult() const {
196 return failure(isActive());
197 }
198
199 /// Reports the diagnostic to the engine.
report()200 void InFlightDiagnostic::report() {
201 // If this diagnostic is still inflight and it hasn't been abandoned, then
202 // report it.
203 if (isInFlight()) {
204 owner->emit(std::move(*impl));
205 owner = nullptr;
206 }
207 impl.reset();
208 }
209
210 /// Abandons this diagnostic.
abandon()211 void InFlightDiagnostic::abandon() { owner = nullptr; }
212
213 //===----------------------------------------------------------------------===//
214 // DiagnosticEngineImpl
215 //===----------------------------------------------------------------------===//
216
217 namespace mlir {
218 namespace detail {
219 struct DiagnosticEngineImpl {
220 /// Emit a diagnostic using the registered issue handle if present, or with
221 /// the default behavior if not.
222 void emit(Diagnostic &&diag);
223
224 /// A mutex to ensure that diagnostics emission is thread-safe.
225 llvm::sys::SmartMutex<true> mutex;
226
227 /// These are the handlers used to report diagnostics.
228 llvm::SmallMapVector<DiagnosticEngine::HandlerID, DiagnosticEngine::HandlerTy,
229 2>
230 handlers;
231
232 /// This is a unique identifier counter for diagnostic handlers in the
233 /// context. This id starts at 1 to allow for 0 to be used as a sentinel.
234 DiagnosticEngine::HandlerID uniqueHandlerId = 1;
235 };
236 } // namespace detail
237 } // namespace mlir
238
239 /// Emit a diagnostic using the registered issue handle if present, or with
240 /// the default behavior if not.
emit(Diagnostic && diag)241 void DiagnosticEngineImpl::emit(Diagnostic &&diag) {
242 llvm::sys::SmartScopedLock<true> lock(mutex);
243
244 // Try to process the given diagnostic on one of the registered handlers.
245 // Handlers are walked in reverse order, so that the most recent handler is
246 // processed first.
247 for (auto &handlerIt : llvm::reverse(handlers))
248 if (succeeded(handlerIt.second(diag)))
249 return;
250
251 // Otherwise, if this is an error we emit it to stderr.
252 if (diag.getSeverity() != DiagnosticSeverity::Error)
253 return;
254
255 auto &os = llvm::errs();
256 if (!diag.getLocation().isa<UnknownLoc>())
257 os << diag.getLocation() << ": ";
258 os << "error: ";
259
260 // The default behavior for errors is to emit them to stderr.
261 os << diag << '\n';
262 os.flush();
263 }
264
265 //===----------------------------------------------------------------------===//
266 // DiagnosticEngine
267 //===----------------------------------------------------------------------===//
268
DiagnosticEngine()269 DiagnosticEngine::DiagnosticEngine() : impl(new DiagnosticEngineImpl()) {}
270 DiagnosticEngine::~DiagnosticEngine() = default;
271
272 /// Register a new handler for diagnostics to the engine. This function returns
273 /// a unique identifier for the registered handler, which can be used to
274 /// unregister this handler at a later time.
registerHandler(HandlerTy handler)275 auto DiagnosticEngine::registerHandler(HandlerTy handler) -> HandlerID {
276 llvm::sys::SmartScopedLock<true> lock(impl->mutex);
277 auto uniqueID = impl->uniqueHandlerId++;
278 impl->handlers.insert({uniqueID, std::move(handler)});
279 return uniqueID;
280 }
281
282 /// Erase the registered diagnostic handler with the given identifier.
eraseHandler(HandlerID handlerID)283 void DiagnosticEngine::eraseHandler(HandlerID handlerID) {
284 llvm::sys::SmartScopedLock<true> lock(impl->mutex);
285 impl->handlers.erase(handlerID);
286 }
287
288 /// Emit a diagnostic using the registered issue handler if present, or with
289 /// the default behavior if not.
emit(Diagnostic && diag)290 void DiagnosticEngine::emit(Diagnostic &&diag) {
291 assert(diag.getSeverity() != DiagnosticSeverity::Note &&
292 "notes should not be emitted directly");
293 impl->emit(std::move(diag));
294 }
295
296 /// Helper function used to emit a diagnostic with an optionally empty twine
297 /// message. If the message is empty, then it is not inserted into the
298 /// diagnostic.
299 static InFlightDiagnostic
emitDiag(Location location,DiagnosticSeverity severity,const Twine & message)300 emitDiag(Location location, DiagnosticSeverity severity, const Twine &message) {
301 MLIRContext *ctx = location->getContext();
302 auto &diagEngine = ctx->getDiagEngine();
303 auto diag = diagEngine.emit(location, severity);
304 if (!message.isTriviallyEmpty())
305 diag << message;
306
307 // Add the stack trace as a note if necessary.
308 if (ctx->shouldPrintStackTraceOnDiagnostic()) {
309 std::string bt;
310 {
311 llvm::raw_string_ostream stream(bt);
312 llvm::sys::PrintStackTrace(stream);
313 }
314 if (!bt.empty())
315 diag.attachNote() << "diagnostic emitted with trace:\n" << bt;
316 }
317
318 return diag;
319 }
320
321 /// Emit an error message using this location.
emitError(Location loc)322 InFlightDiagnostic mlir::emitError(Location loc) { return emitError(loc, {}); }
emitError(Location loc,const Twine & message)323 InFlightDiagnostic mlir::emitError(Location loc, const Twine &message) {
324 return emitDiag(loc, DiagnosticSeverity::Error, message);
325 }
326
327 /// Emit a warning message using this location.
emitWarning(Location loc)328 InFlightDiagnostic mlir::emitWarning(Location loc) {
329 return emitWarning(loc, {});
330 }
emitWarning(Location loc,const Twine & message)331 InFlightDiagnostic mlir::emitWarning(Location loc, const Twine &message) {
332 return emitDiag(loc, DiagnosticSeverity::Warning, message);
333 }
334
335 /// Emit a remark message using this location.
emitRemark(Location loc)336 InFlightDiagnostic mlir::emitRemark(Location loc) {
337 return emitRemark(loc, {});
338 }
emitRemark(Location loc,const Twine & message)339 InFlightDiagnostic mlir::emitRemark(Location loc, const Twine &message) {
340 return emitDiag(loc, DiagnosticSeverity::Remark, message);
341 }
342
343 //===----------------------------------------------------------------------===//
344 // ScopedDiagnosticHandler
345 //===----------------------------------------------------------------------===//
346
~ScopedDiagnosticHandler()347 ScopedDiagnosticHandler::~ScopedDiagnosticHandler() {
348 if (handlerID)
349 ctx->getDiagEngine().eraseHandler(handlerID);
350 }
351
352 //===----------------------------------------------------------------------===//
353 // SourceMgrDiagnosticHandler
354 //===----------------------------------------------------------------------===//
355 namespace mlir {
356 namespace detail {
357 struct SourceMgrDiagnosticHandlerImpl {
358 /// Return the SrcManager buffer id for the specified file, or zero if none
359 /// can be found.
getSourceMgrBufferIDForFilemlir::detail::SourceMgrDiagnosticHandlerImpl360 unsigned getSourceMgrBufferIDForFile(llvm::SourceMgr &mgr,
361 StringRef filename) {
362 // Check for an existing mapping to the buffer id for this file.
363 auto bufferIt = filenameToBufId.find(filename);
364 if (bufferIt != filenameToBufId.end())
365 return bufferIt->second;
366
367 // Look for a buffer in the manager that has this filename.
368 for (unsigned i = 1, e = mgr.getNumBuffers() + 1; i != e; ++i) {
369 auto *buf = mgr.getMemoryBuffer(i);
370 if (buf->getBufferIdentifier() == filename)
371 return filenameToBufId[filename] = i;
372 }
373
374 // Otherwise, try to load the source file.
375 std::string ignored;
376 unsigned id = mgr.AddIncludeFile(std::string(filename), SMLoc(), ignored);
377 filenameToBufId[filename] = id;
378 return id;
379 }
380
381 /// Mapping between file name and buffer ID's.
382 llvm::StringMap<unsigned> filenameToBufId;
383 };
384 } // namespace detail
385 } // namespace mlir
386
387 /// Return a processable FileLineColLoc from the given location.
getFileLineColLoc(Location loc)388 static Optional<FileLineColLoc> getFileLineColLoc(Location loc) {
389 Optional<FileLineColLoc> firstFileLoc;
390 loc->walk([&](Location loc) {
391 if (FileLineColLoc fileLoc = loc.dyn_cast<FileLineColLoc>()) {
392 firstFileLoc = fileLoc;
393 return WalkResult::interrupt();
394 }
395 return WalkResult::advance();
396 });
397 return firstFileLoc;
398 }
399
400 /// Return a processable CallSiteLoc from the given location.
getCallSiteLoc(Location loc)401 static Optional<CallSiteLoc> getCallSiteLoc(Location loc) {
402 if (auto nameLoc = loc.dyn_cast<NameLoc>())
403 return getCallSiteLoc(loc.cast<NameLoc>().getChildLoc());
404 if (auto callLoc = loc.dyn_cast<CallSiteLoc>())
405 return callLoc;
406 if (auto fusedLoc = loc.dyn_cast<FusedLoc>()) {
407 for (auto subLoc : loc.cast<FusedLoc>().getLocations()) {
408 if (auto callLoc = getCallSiteLoc(subLoc)) {
409 return callLoc;
410 }
411 }
412 return llvm::None;
413 }
414 return llvm::None;
415 }
416
417 /// Given a diagnostic kind, returns the LLVM DiagKind.
getDiagKind(DiagnosticSeverity kind)418 static llvm::SourceMgr::DiagKind getDiagKind(DiagnosticSeverity kind) {
419 switch (kind) {
420 case DiagnosticSeverity::Note:
421 return llvm::SourceMgr::DK_Note;
422 case DiagnosticSeverity::Warning:
423 return llvm::SourceMgr::DK_Warning;
424 case DiagnosticSeverity::Error:
425 return llvm::SourceMgr::DK_Error;
426 case DiagnosticSeverity::Remark:
427 return llvm::SourceMgr::DK_Remark;
428 }
429 llvm_unreachable("Unknown DiagnosticSeverity");
430 }
431
SourceMgrDiagnosticHandler(llvm::SourceMgr & mgr,MLIRContext * ctx,raw_ostream & os,ShouldShowLocFn && shouldShowLocFn)432 SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(
433 llvm::SourceMgr &mgr, MLIRContext *ctx, raw_ostream &os,
434 ShouldShowLocFn &&shouldShowLocFn)
435 : ScopedDiagnosticHandler(ctx), mgr(mgr), os(os),
436 shouldShowLocFn(std::move(shouldShowLocFn)),
437 impl(new SourceMgrDiagnosticHandlerImpl()) {
438 setHandler([this](Diagnostic &diag) { emitDiagnostic(diag); });
439 }
440
SourceMgrDiagnosticHandler(llvm::SourceMgr & mgr,MLIRContext * ctx,ShouldShowLocFn && shouldShowLocFn)441 SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(
442 llvm::SourceMgr &mgr, MLIRContext *ctx, ShouldShowLocFn &&shouldShowLocFn)
443 : SourceMgrDiagnosticHandler(mgr, ctx, llvm::errs(),
444 std::move(shouldShowLocFn)) {}
445
446 SourceMgrDiagnosticHandler::~SourceMgrDiagnosticHandler() = default;
447
emitDiagnostic(Location loc,Twine message,DiagnosticSeverity kind,bool displaySourceLine)448 void SourceMgrDiagnosticHandler::emitDiagnostic(Location loc, Twine message,
449 DiagnosticSeverity kind,
450 bool displaySourceLine) {
451 // Extract a file location from this loc.
452 auto fileLoc = getFileLineColLoc(loc);
453
454 // If one doesn't exist, then print the raw message without a source location.
455 if (!fileLoc) {
456 std::string str;
457 llvm::raw_string_ostream strOS(str);
458 if (!loc.isa<UnknownLoc>())
459 strOS << loc << ": ";
460 strOS << message;
461 return mgr.PrintMessage(os, SMLoc(), getDiagKind(kind), strOS.str());
462 }
463
464 // Otherwise if we are displaying the source line, try to convert the file
465 // location to an SMLoc.
466 if (displaySourceLine) {
467 auto smloc = convertLocToSMLoc(*fileLoc);
468 if (smloc.isValid())
469 return mgr.PrintMessage(os, smloc, getDiagKind(kind), message);
470 }
471
472 // If the conversion was unsuccessful, create a diagnostic with the file
473 // information. We manually combine the line and column to avoid asserts in
474 // the constructor of SMDiagnostic that takes a location.
475 std::string locStr;
476 llvm::raw_string_ostream locOS(locStr);
477 locOS << fileLoc->getFilename().getValue() << ":" << fileLoc->getLine() << ":"
478 << fileLoc->getColumn();
479 llvm::SMDiagnostic diag(locOS.str(), getDiagKind(kind), message.str());
480 diag.print(nullptr, os);
481 }
482
483 /// Emit the given diagnostic with the held source manager.
emitDiagnostic(Diagnostic & diag)484 void SourceMgrDiagnosticHandler::emitDiagnostic(Diagnostic &diag) {
485 SmallVector<std::pair<Location, StringRef>> locationStack;
486 auto addLocToStack = [&](Location loc, StringRef locContext) {
487 if (Optional<Location> showableLoc = findLocToShow(loc))
488 locationStack.emplace_back(*showableLoc, locContext);
489 };
490
491 // Add locations to display for this diagnostic.
492 Location loc = diag.getLocation();
493 addLocToStack(loc, /*locContext=*/{});
494
495 // If the diagnostic location was a call site location, add the call stack as
496 // well.
497 if (auto callLoc = getCallSiteLoc(loc)) {
498 // Print the call stack while valid, or until the limit is reached.
499 loc = callLoc->getCaller();
500 for (unsigned curDepth = 0; curDepth < callStackLimit; ++curDepth) {
501 addLocToStack(loc, "called from");
502 if ((callLoc = getCallSiteLoc(loc)))
503 loc = callLoc->getCaller();
504 else
505 break;
506 }
507 }
508
509 // If the location stack is empty, use the initial location.
510 if (locationStack.empty()) {
511 emitDiagnostic(diag.getLocation(), diag.str(), diag.getSeverity());
512
513 // Otherwise, use the location stack.
514 } else {
515 emitDiagnostic(locationStack.front().first, diag.str(), diag.getSeverity());
516 for (auto &it : llvm::drop_begin(locationStack))
517 emitDiagnostic(it.first, it.second, DiagnosticSeverity::Note);
518 }
519
520 // Emit each of the notes. Only display the source code if the location is
521 // different from the previous location.
522 for (auto ¬e : diag.getNotes()) {
523 emitDiagnostic(note.getLocation(), note.str(), note.getSeverity(),
524 /*displaySourceLine=*/loc != note.getLocation());
525 loc = note.getLocation();
526 }
527 }
528
529 /// Get a memory buffer for the given file, or nullptr if one is not found.
530 const llvm::MemoryBuffer *
getBufferForFile(StringRef filename)531 SourceMgrDiagnosticHandler::getBufferForFile(StringRef filename) {
532 if (unsigned id = impl->getSourceMgrBufferIDForFile(mgr, filename))
533 return mgr.getMemoryBuffer(id);
534 return nullptr;
535 }
536
findLocToShow(Location loc)537 Optional<Location> SourceMgrDiagnosticHandler::findLocToShow(Location loc) {
538 if (!shouldShowLocFn)
539 return loc;
540 if (!shouldShowLocFn(loc))
541 return llvm::None;
542
543 // Recurse into the child locations of some of location types.
544 return TypeSwitch<LocationAttr, Optional<Location>>(loc)
545 .Case([&](CallSiteLoc callLoc) -> Optional<Location> {
546 // We recurse into the callee of a call site, as the caller will be
547 // emitted in a different note on the main diagnostic.
548 return findLocToShow(callLoc.getCallee());
549 })
550 .Case([&](FileLineColLoc) -> Optional<Location> { return loc; })
551 .Case([&](FusedLoc fusedLoc) -> Optional<Location> {
552 // Fused location is unique in that we try to find a sub-location to
553 // show, rather than the top-level location itself.
554 for (Location childLoc : fusedLoc.getLocations())
555 if (Optional<Location> showableLoc = findLocToShow(childLoc))
556 return showableLoc;
557 return llvm::None;
558 })
559 .Case([&](NameLoc nameLoc) -> Optional<Location> {
560 return findLocToShow(nameLoc.getChildLoc());
561 })
562 .Case([&](OpaqueLoc opaqueLoc) -> Optional<Location> {
563 // OpaqueLoc always falls back to a different source location.
564 return findLocToShow(opaqueLoc.getFallbackLocation());
565 })
566 .Case([](UnknownLoc) -> Optional<Location> {
567 // Prefer not to show unknown locations.
568 return llvm::None;
569 });
570 }
571
572 /// Get a memory buffer for the given file, or the main file of the source
573 /// manager if one doesn't exist. This always returns non-null.
convertLocToSMLoc(FileLineColLoc loc)574 SMLoc SourceMgrDiagnosticHandler::convertLocToSMLoc(FileLineColLoc loc) {
575 // The column and line may be zero to represent unknown column and/or unknown
576 /// line/column information.
577 if (loc.getLine() == 0 || loc.getColumn() == 0)
578 return SMLoc();
579
580 unsigned bufferId = impl->getSourceMgrBufferIDForFile(mgr, loc.getFilename());
581 if (!bufferId)
582 return SMLoc();
583 return mgr.FindLocForLineAndColumn(bufferId, loc.getLine(), loc.getColumn());
584 }
585
586 //===----------------------------------------------------------------------===//
587 // SourceMgrDiagnosticVerifierHandler
588 //===----------------------------------------------------------------------===//
589
590 namespace mlir {
591 namespace detail {
592 /// This class represents an expected output diagnostic.
593 struct ExpectedDiag {
ExpectedDiagmlir::detail::ExpectedDiag594 ExpectedDiag(DiagnosticSeverity kind, unsigned lineNo, SMLoc fileLoc,
595 StringRef substring)
596 : kind(kind), lineNo(lineNo), fileLoc(fileLoc), substring(substring) {}
597
598 /// Emit an error at the location referenced by this diagnostic.
emitErrormlir::detail::ExpectedDiag599 LogicalResult emitError(raw_ostream &os, llvm::SourceMgr &mgr,
600 const Twine &msg) {
601 SMRange range(fileLoc, SMLoc::getFromPointer(fileLoc.getPointer() +
602 substring.size()));
603 mgr.PrintMessage(os, fileLoc, llvm::SourceMgr::DK_Error, msg, range);
604 return failure();
605 }
606
607 /// Returns true if this diagnostic matches the given string.
matchmlir::detail::ExpectedDiag608 bool match(StringRef str) const {
609 // If this isn't a regex diagnostic, we simply check if the string was
610 // contained.
611 if (substringRegex)
612 return substringRegex->match(str);
613 return str.contains(substring);
614 }
615
616 /// Compute the regex matcher for this diagnostic, using the provided stream
617 /// and manager to emit diagnostics as necessary.
computeRegexmlir::detail::ExpectedDiag618 LogicalResult computeRegex(raw_ostream &os, llvm::SourceMgr &mgr) {
619 std::string regexStr;
620 llvm::raw_string_ostream regexOS(regexStr);
621 StringRef strToProcess = substring;
622 while (!strToProcess.empty()) {
623 // Find the next regex block.
624 size_t regexIt = strToProcess.find("{{");
625 if (regexIt == StringRef::npos) {
626 regexOS << llvm::Regex::escape(strToProcess);
627 break;
628 }
629 regexOS << llvm::Regex::escape(strToProcess.take_front(regexIt));
630 strToProcess = strToProcess.drop_front(regexIt + 2);
631
632 // Find the end of the regex block.
633 size_t regexEndIt = strToProcess.find("}}");
634 if (regexEndIt == StringRef::npos)
635 return emitError(os, mgr, "found start of regex with no end '}}'");
636 StringRef regexStr = strToProcess.take_front(regexEndIt);
637
638 // Validate that the regex is actually valid.
639 std::string regexError;
640 if (!llvm::Regex(regexStr).isValid(regexError))
641 return emitError(os, mgr, "invalid regex: " + regexError);
642
643 regexOS << '(' << regexStr << ')';
644 strToProcess = strToProcess.drop_front(regexEndIt + 2);
645 }
646 substringRegex = llvm::Regex(regexOS.str());
647 return success();
648 }
649
650 /// The severity of the diagnosic expected.
651 DiagnosticSeverity kind;
652 /// The line number the expected diagnostic should be on.
653 unsigned lineNo;
654 /// The location of the expected diagnostic within the input file.
655 SMLoc fileLoc;
656 /// A flag indicating if the expected diagnostic has been matched yet.
657 bool matched = false;
658 /// The substring that is expected to be within the diagnostic.
659 StringRef substring;
660 /// An optional regex matcher, if the expected diagnostic sub-string was a
661 /// regex string.
662 Optional<llvm::Regex> substringRegex;
663 };
664
665 struct SourceMgrDiagnosticVerifierHandlerImpl {
SourceMgrDiagnosticVerifierHandlerImplmlir::detail::SourceMgrDiagnosticVerifierHandlerImpl666 SourceMgrDiagnosticVerifierHandlerImpl() : status(success()) {}
667
668 /// Returns the expected diagnostics for the given source file.
669 Optional<MutableArrayRef<ExpectedDiag>> getExpectedDiags(StringRef bufName);
670
671 /// Computes the expected diagnostics for the given source buffer.
672 MutableArrayRef<ExpectedDiag>
673 computeExpectedDiags(raw_ostream &os, llvm::SourceMgr &mgr,
674 const llvm::MemoryBuffer *buf);
675
676 /// The current status of the verifier.
677 LogicalResult status;
678
679 /// A list of expected diagnostics for each buffer of the source manager.
680 llvm::StringMap<SmallVector<ExpectedDiag, 2>> expectedDiagsPerFile;
681
682 /// Regex to match the expected diagnostics format.
683 llvm::Regex expected =
684 llvm::Regex("expected-(error|note|remark|warning)(-re)? "
685 "*(@([+-][0-9]+|above|below))? *{{(.*)}}$");
686 };
687 } // namespace detail
688 } // namespace mlir
689
690 /// Given a diagnostic kind, return a human readable string for it.
getDiagKindStr(DiagnosticSeverity kind)691 static StringRef getDiagKindStr(DiagnosticSeverity kind) {
692 switch (kind) {
693 case DiagnosticSeverity::Note:
694 return "note";
695 case DiagnosticSeverity::Warning:
696 return "warning";
697 case DiagnosticSeverity::Error:
698 return "error";
699 case DiagnosticSeverity::Remark:
700 return "remark";
701 }
702 llvm_unreachable("Unknown DiagnosticSeverity");
703 }
704
705 Optional<MutableArrayRef<ExpectedDiag>>
getExpectedDiags(StringRef bufName)706 SourceMgrDiagnosticVerifierHandlerImpl::getExpectedDiags(StringRef bufName) {
707 auto expectedDiags = expectedDiagsPerFile.find(bufName);
708 if (expectedDiags != expectedDiagsPerFile.end())
709 return MutableArrayRef<ExpectedDiag>(expectedDiags->second);
710 return llvm::None;
711 }
712
713 MutableArrayRef<ExpectedDiag>
computeExpectedDiags(raw_ostream & os,llvm::SourceMgr & mgr,const llvm::MemoryBuffer * buf)714 SourceMgrDiagnosticVerifierHandlerImpl::computeExpectedDiags(
715 raw_ostream &os, llvm::SourceMgr &mgr, const llvm::MemoryBuffer *buf) {
716 // If the buffer is invalid, return an empty list.
717 if (!buf)
718 return llvm::None;
719 auto &expectedDiags = expectedDiagsPerFile[buf->getBufferIdentifier()];
720
721 // The number of the last line that did not correlate to a designator.
722 unsigned lastNonDesignatorLine = 0;
723
724 // The indices of designators that apply to the next non designator line.
725 SmallVector<unsigned, 1> designatorsForNextLine;
726
727 // Scan the file for expected-* designators.
728 SmallVector<StringRef, 100> lines;
729 buf->getBuffer().split(lines, '\n');
730 for (unsigned lineNo = 0, e = lines.size(); lineNo < e; ++lineNo) {
731 SmallVector<StringRef, 4> matches;
732 if (!expected.match(lines[lineNo].rtrim(), &matches)) {
733 // Check for designators that apply to this line.
734 if (!designatorsForNextLine.empty()) {
735 for (unsigned diagIndex : designatorsForNextLine)
736 expectedDiags[diagIndex].lineNo = lineNo + 1;
737 designatorsForNextLine.clear();
738 }
739 lastNonDesignatorLine = lineNo;
740 continue;
741 }
742
743 // Point to the start of expected-*.
744 SMLoc expectedStart = SMLoc::getFromPointer(matches[0].data());
745
746 DiagnosticSeverity kind;
747 if (matches[1] == "error")
748 kind = DiagnosticSeverity::Error;
749 else if (matches[1] == "warning")
750 kind = DiagnosticSeverity::Warning;
751 else if (matches[1] == "remark")
752 kind = DiagnosticSeverity::Remark;
753 else {
754 assert(matches[1] == "note");
755 kind = DiagnosticSeverity::Note;
756 }
757 ExpectedDiag record(kind, lineNo + 1, expectedStart, matches[5]);
758
759 // Check to see if this is a regex match, i.e. it includes the `-re`.
760 if (!matches[2].empty() && failed(record.computeRegex(os, mgr))) {
761 status = failure();
762 continue;
763 }
764
765 StringRef offsetMatch = matches[3];
766 if (!offsetMatch.empty()) {
767 offsetMatch = offsetMatch.drop_front(1);
768
769 // Get the integer value without the @ and +/- prefix.
770 if (offsetMatch[0] == '+' || offsetMatch[0] == '-') {
771 int offset;
772 offsetMatch.drop_front().getAsInteger(0, offset);
773
774 if (offsetMatch.front() == '+')
775 record.lineNo += offset;
776 else
777 record.lineNo -= offset;
778 } else if (offsetMatch.consume_front("above")) {
779 // If the designator applies 'above' we add it to the last non
780 // designator line.
781 record.lineNo = lastNonDesignatorLine + 1;
782 } else {
783 // Otherwise, this is a 'below' designator and applies to the next
784 // non-designator line.
785 assert(offsetMatch.consume_front("below"));
786 designatorsForNextLine.push_back(expectedDiags.size());
787
788 // Set the line number to the last in the case that this designator ends
789 // up dangling.
790 record.lineNo = e;
791 }
792 }
793 expectedDiags.emplace_back(std::move(record));
794 }
795 return expectedDiags;
796 }
797
SourceMgrDiagnosticVerifierHandler(llvm::SourceMgr & srcMgr,MLIRContext * ctx,raw_ostream & out)798 SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
799 llvm::SourceMgr &srcMgr, MLIRContext *ctx, raw_ostream &out)
800 : SourceMgrDiagnosticHandler(srcMgr, ctx, out),
801 impl(new SourceMgrDiagnosticVerifierHandlerImpl()) {
802 // Compute the expected diagnostics for each of the current files in the
803 // source manager.
804 for (unsigned i = 0, e = mgr.getNumBuffers(); i != e; ++i)
805 (void)impl->computeExpectedDiags(out, mgr, mgr.getMemoryBuffer(i + 1));
806
807 // Register a handler to verify the diagnostics.
808 setHandler([&](Diagnostic &diag) {
809 // Process the main diagnostics.
810 process(diag);
811
812 // Process each of the notes.
813 for (auto ¬e : diag.getNotes())
814 process(note);
815 });
816 }
817
SourceMgrDiagnosticVerifierHandler(llvm::SourceMgr & srcMgr,MLIRContext * ctx)818 SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
819 llvm::SourceMgr &srcMgr, MLIRContext *ctx)
820 : SourceMgrDiagnosticVerifierHandler(srcMgr, ctx, llvm::errs()) {}
821
~SourceMgrDiagnosticVerifierHandler()822 SourceMgrDiagnosticVerifierHandler::~SourceMgrDiagnosticVerifierHandler() {
823 // Ensure that all expected diagnostics were handled.
824 (void)verify();
825 }
826
827 /// Returns the status of the verifier and verifies that all expected
828 /// diagnostics were emitted. This return success if all diagnostics were
829 /// verified correctly, failure otherwise.
verify()830 LogicalResult SourceMgrDiagnosticVerifierHandler::verify() {
831 // Verify that all expected errors were seen.
832 for (auto &expectedDiagsPair : impl->expectedDiagsPerFile) {
833 for (auto &err : expectedDiagsPair.second) {
834 if (err.matched)
835 continue;
836 impl->status =
837 err.emitError(os, mgr,
838 "expected " + getDiagKindStr(err.kind) + " \"" +
839 err.substring + "\" was not produced");
840 }
841 }
842 impl->expectedDiagsPerFile.clear();
843 return impl->status;
844 }
845
846 /// Process a single diagnostic.
process(Diagnostic & diag)847 void SourceMgrDiagnosticVerifierHandler::process(Diagnostic &diag) {
848 auto kind = diag.getSeverity();
849
850 // Process a FileLineColLoc.
851 if (auto fileLoc = getFileLineColLoc(diag.getLocation()))
852 return process(*fileLoc, diag.str(), kind);
853
854 emitDiagnostic(diag.getLocation(),
855 "unexpected " + getDiagKindStr(kind) + ": " + diag.str(),
856 DiagnosticSeverity::Error);
857 impl->status = failure();
858 }
859
860 /// Process a FileLineColLoc diagnostic.
process(FileLineColLoc loc,StringRef msg,DiagnosticSeverity kind)861 void SourceMgrDiagnosticVerifierHandler::process(FileLineColLoc loc,
862 StringRef msg,
863 DiagnosticSeverity kind) {
864 // Get the expected diagnostics for this file.
865 auto diags = impl->getExpectedDiags(loc.getFilename());
866 if (!diags) {
867 diags = impl->computeExpectedDiags(os, mgr,
868 getBufferForFile(loc.getFilename()));
869 }
870
871 // Search for a matching expected diagnostic.
872 // If we find something that is close then emit a more specific error.
873 ExpectedDiag *nearMiss = nullptr;
874
875 // If this was an expected error, remember that we saw it and return.
876 unsigned line = loc.getLine();
877 for (auto &e : *diags) {
878 if (line == e.lineNo && e.match(msg)) {
879 if (e.kind == kind) {
880 e.matched = true;
881 return;
882 }
883
884 // If this only differs based on the diagnostic kind, then consider it
885 // to be a near miss.
886 nearMiss = &e;
887 }
888 }
889
890 // Otherwise, emit an error for the near miss.
891 if (nearMiss)
892 mgr.PrintMessage(os, nearMiss->fileLoc, llvm::SourceMgr::DK_Error,
893 "'" + getDiagKindStr(kind) +
894 "' diagnostic emitted when expecting a '" +
895 getDiagKindStr(nearMiss->kind) + "'");
896 else
897 emitDiagnostic(loc, "unexpected " + getDiagKindStr(kind) + ": " + msg,
898 DiagnosticSeverity::Error);
899 impl->status = failure();
900 }
901
902 //===----------------------------------------------------------------------===//
903 // ParallelDiagnosticHandler
904 //===----------------------------------------------------------------------===//
905
906 namespace mlir {
907 namespace detail {
908 struct ParallelDiagnosticHandlerImpl : public llvm::PrettyStackTraceEntry {
909 struct ThreadDiagnostic {
ThreadDiagnosticmlir::detail::ParallelDiagnosticHandlerImpl::ThreadDiagnostic910 ThreadDiagnostic(size_t id, Diagnostic diag)
911 : id(id), diag(std::move(diag)) {}
operator <mlir::detail::ParallelDiagnosticHandlerImpl::ThreadDiagnostic912 bool operator<(const ThreadDiagnostic &rhs) const { return id < rhs.id; }
913
914 /// The id for this diagnostic, this is used for ordering.
915 /// Note: This id corresponds to the ordered position of the current element
916 /// being processed by a given thread.
917 size_t id;
918
919 /// The diagnostic.
920 Diagnostic diag;
921 };
922
ParallelDiagnosticHandlerImplmlir::detail::ParallelDiagnosticHandlerImpl923 ParallelDiagnosticHandlerImpl(MLIRContext *ctx) : context(ctx) {
924 handlerID = ctx->getDiagEngine().registerHandler([this](Diagnostic &diag) {
925 uint64_t tid = llvm::get_threadid();
926 llvm::sys::SmartScopedLock<true> lock(mutex);
927
928 // If this thread is not tracked, then return failure to let another
929 // handler process this diagnostic.
930 if (!threadToOrderID.count(tid))
931 return failure();
932
933 // Append a new diagnostic.
934 diagnostics.emplace_back(threadToOrderID[tid], std::move(diag));
935 return success();
936 });
937 }
938
~ParallelDiagnosticHandlerImplmlir::detail::ParallelDiagnosticHandlerImpl939 ~ParallelDiagnosticHandlerImpl() override {
940 // Erase this handler from the context.
941 context->getDiagEngine().eraseHandler(handlerID);
942
943 // Early exit if there are no diagnostics, this is the common case.
944 if (diagnostics.empty())
945 return;
946
947 // Emit the diagnostics back to the context.
948 emitDiagnostics([&](Diagnostic &diag) {
949 return context->getDiagEngine().emit(std::move(diag));
950 });
951 }
952
953 /// Utility method to emit any held diagnostics.
emitDiagnosticsmlir::detail::ParallelDiagnosticHandlerImpl954 void emitDiagnostics(llvm::function_ref<void(Diagnostic &)> emitFn) const {
955 // Stable sort all of the diagnostics that were emitted. This creates a
956 // deterministic ordering for the diagnostics based upon which order id they
957 // were emitted for.
958 std::stable_sort(diagnostics.begin(), diagnostics.end());
959
960 // Emit each diagnostic to the context again.
961 for (ThreadDiagnostic &diag : diagnostics)
962 emitFn(diag.diag);
963 }
964
965 /// Set the order id for the current thread.
setOrderIDForThreadmlir::detail::ParallelDiagnosticHandlerImpl966 void setOrderIDForThread(size_t orderID) {
967 uint64_t tid = llvm::get_threadid();
968 llvm::sys::SmartScopedLock<true> lock(mutex);
969 threadToOrderID[tid] = orderID;
970 }
971
972 /// Remove the order id for the current thread.
eraseOrderIDForThreadmlir::detail::ParallelDiagnosticHandlerImpl973 void eraseOrderIDForThread() {
974 uint64_t tid = llvm::get_threadid();
975 llvm::sys::SmartScopedLock<true> lock(mutex);
976 threadToOrderID.erase(tid);
977 }
978
979 /// Dump the current diagnostics that were inflight.
printmlir::detail::ParallelDiagnosticHandlerImpl980 void print(raw_ostream &os) const override {
981 // Early exit if there are no diagnostics, this is the common case.
982 if (diagnostics.empty())
983 return;
984
985 os << "In-Flight Diagnostics:\n";
986 emitDiagnostics([&](const Diagnostic &diag) {
987 os.indent(4);
988
989 // Print each diagnostic with the format:
990 // "<location>: <kind>: <msg>"
991 if (!diag.getLocation().isa<UnknownLoc>())
992 os << diag.getLocation() << ": ";
993 switch (diag.getSeverity()) {
994 case DiagnosticSeverity::Error:
995 os << "error: ";
996 break;
997 case DiagnosticSeverity::Warning:
998 os << "warning: ";
999 break;
1000 case DiagnosticSeverity::Note:
1001 os << "note: ";
1002 break;
1003 case DiagnosticSeverity::Remark:
1004 os << "remark: ";
1005 break;
1006 }
1007 os << diag << '\n';
1008 });
1009 }
1010
1011 /// A smart mutex to lock access to the internal state.
1012 llvm::sys::SmartMutex<true> mutex;
1013
1014 /// A mapping between the thread id and the current order id.
1015 DenseMap<uint64_t, size_t> threadToOrderID;
1016
1017 /// An unordered list of diagnostics that were emitted.
1018 mutable std::vector<ThreadDiagnostic> diagnostics;
1019
1020 /// The unique id for the parallel handler.
1021 DiagnosticEngine::HandlerID handlerID = 0;
1022
1023 /// The context to emit the diagnostics to.
1024 MLIRContext *context;
1025 };
1026 } // namespace detail
1027 } // namespace mlir
1028
ParallelDiagnosticHandler(MLIRContext * ctx)1029 ParallelDiagnosticHandler::ParallelDiagnosticHandler(MLIRContext *ctx)
1030 : impl(new ParallelDiagnosticHandlerImpl(ctx)) {}
1031 ParallelDiagnosticHandler::~ParallelDiagnosticHandler() = default;
1032
1033 /// Set the order id for the current thread.
setOrderIDForThread(size_t orderID)1034 void ParallelDiagnosticHandler::setOrderIDForThread(size_t orderID) {
1035 impl->setOrderIDForThread(orderID);
1036 }
1037
1038 /// Remove the order id for the current thread. This removes the thread from
1039 /// diagnostics tracking.
eraseOrderIDForThread()1040 void ParallelDiagnosticHandler::eraseOrderIDForThread() {
1041 impl->eraseOrderIDForThread();
1042 }
1043