1 //===- Translation.cpp - Translation registry -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Definitions of the translation registry.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Tools/mlir-translate/Translation.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Parser/Parser.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/SourceMgr.h"
20 
21 using namespace mlir;
22 
23 //===----------------------------------------------------------------------===//
24 // Translation Registry
25 //===----------------------------------------------------------------------===//
26 
27 /// Get the mutable static map between registered file-to-file MLIR translations
28 /// and the TranslateFunctions that perform those translations.
getTranslationRegistry()29 static llvm::StringMap<TranslateFunction> &getTranslationRegistry() {
30   static llvm::StringMap<TranslateFunction> translationRegistry;
31   return translationRegistry;
32 }
33 
34 /// Register the given translation.
registerTranslation(StringRef name,const TranslateFunction & function)35 static void registerTranslation(StringRef name,
36                                 const TranslateFunction &function) {
37   auto &translationRegistry = getTranslationRegistry();
38   if (translationRegistry.find(name) != translationRegistry.end())
39     llvm::report_fatal_error(
40         "Attempting to overwrite an existing <file-to-file> function");
41   assert(function &&
42          "Attempting to register an empty translate <file-to-file> function");
43   translationRegistry[name] = function;
44 }
45 
TranslateRegistration(StringRef name,const TranslateFunction & function)46 TranslateRegistration::TranslateRegistration(
47     StringRef name, const TranslateFunction &function) {
48   registerTranslation(name, function);
49 }
50 
51 //===----------------------------------------------------------------------===//
52 // Translation to MLIR
53 //===----------------------------------------------------------------------===//
54 
55 // Puts `function` into the to-MLIR translation registry unless there is already
56 // a function registered for the same name.
registerTranslateToMLIRFunction(StringRef name,const TranslateSourceMgrToMLIRFunction & function)57 static void registerTranslateToMLIRFunction(
58     StringRef name, const TranslateSourceMgrToMLIRFunction &function) {
59   auto wrappedFn = [function](llvm::SourceMgr &sourceMgr, raw_ostream &output,
60                               MLIRContext *context) {
61     OwningOpRef<ModuleOp> module = function(sourceMgr, context);
62     if (!module || failed(verify(*module)))
63       return failure();
64     module->print(output);
65     return success();
66   };
67   registerTranslation(name, wrappedFn);
68 }
69 
TranslateToMLIRRegistration(StringRef name,const TranslateSourceMgrToMLIRFunction & function)70 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
71     StringRef name, const TranslateSourceMgrToMLIRFunction &function) {
72   registerTranslateToMLIRFunction(name, function);
73 }
74 
75 /// Wraps `function` with a lambda that extracts a StringRef from a source
76 /// manager and registers the wrapper lambda as a to-MLIR conversion.
TranslateToMLIRRegistration(StringRef name,const TranslateStringRefToMLIRFunction & function)77 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
78     StringRef name, const TranslateStringRefToMLIRFunction &function) {
79   registerTranslateToMLIRFunction(
80       name, [function](llvm::SourceMgr &sourceMgr, MLIRContext *ctx) {
81         const llvm::MemoryBuffer *buffer =
82             sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
83         return function(buffer->getBuffer(), ctx);
84       });
85 }
86 
87 //===----------------------------------------------------------------------===//
88 // Translation from MLIR
89 //===----------------------------------------------------------------------===//
90 
TranslateFromMLIRRegistration(StringRef name,const TranslateFromMLIRFunction & function,const std::function<void (DialectRegistry &)> & dialectRegistration)91 TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
92     StringRef name, const TranslateFromMLIRFunction &function,
93     const std::function<void(DialectRegistry &)> &dialectRegistration) {
94   registerTranslation(name, [function, dialectRegistration](
95                                 llvm::SourceMgr &sourceMgr, raw_ostream &output,
96                                 MLIRContext *context) {
97     DialectRegistry registry;
98     dialectRegistration(registry);
99     context->appendDialectRegistry(registry);
100     auto module = parseSourceFile<ModuleOp>(sourceMgr, context);
101     if (!module || failed(verify(*module)))
102       return failure();
103     return function(module.get(), output);
104   });
105 }
106 
107 //===----------------------------------------------------------------------===//
108 // Translation Parser
109 //===----------------------------------------------------------------------===//
110 
TranslationParser(llvm::cl::Option & opt)111 TranslationParser::TranslationParser(llvm::cl::Option &opt)
112     : llvm::cl::parser<const TranslateFunction *>(opt) {
113   for (const auto &kv : getTranslationRegistry())
114     addLiteralOption(kv.first(), &kv.second, kv.first());
115 }
116 
printOptionInfo(const llvm::cl::Option & o,size_t globalWidth) const117 void TranslationParser::printOptionInfo(const llvm::cl::Option &o,
118                                         size_t globalWidth) const {
119   TranslationParser *tp = const_cast<TranslationParser *>(this);
120   llvm::array_pod_sort(tp->Values.begin(), tp->Values.end(),
121                        [](const TranslationParser::OptionInfo *lhs,
122                           const TranslationParser::OptionInfo *rhs) {
123                          return lhs->Name.compare(rhs->Name);
124                        });
125   llvm::cl::parser<const TranslateFunction *>::printOptionInfo(o, globalWidth);
126 }
127