1 //===- TranslateRegistration.cpp - hooks to mlir-translate ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation from SPIR-V binary module to MLIR SPIR-V
10 // ModuleOp.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Parser.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Target/SPIRV/Deserialization.h"
#include "mlir/Target/SPIRV/Serialization.h"
#include "mlir/Translation.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"

#include <cstdint>
#include <cstring>
29 
30 using namespace mlir;
31 
32 //===----------------------------------------------------------------------===//
33 // Deserialization registration
34 //===----------------------------------------------------------------------===//
35 
36 // Deserializes the SPIR-V binary module stored in the file named as
37 // `inputFilename` and returns a module containing the SPIR-V module.
38 static OwningModuleRef deserializeModule(const llvm::MemoryBuffer *input,
39                                          MLIRContext *context) {
40   context->loadDialect<spirv::SPIRVDialect>();
41 
42   // Make sure the input stream can be treated as a stream of SPIR-V words
43   auto *start = input->getBufferStart();
44   auto size = input->getBufferSize();
45   if (size % sizeof(uint32_t) != 0) {
46     emitError(UnknownLoc::get(context))
47         << "SPIR-V binary module must contain integral number of 32-bit words";
48     return {};
49   }
50 
51   auto binary = llvm::makeArrayRef(reinterpret_cast<const uint32_t *>(start),
52                                    size / sizeof(uint32_t));
53 
54   OwningOpRef<spirv::ModuleOp> spirvModule =
55       spirv::deserialize(binary, context);
56   if (!spirvModule)
57     return {};
58 
59   OwningModuleRef module(ModuleOp::create(FileLineColLoc::get(
60       context, input->getBufferIdentifier(), /*line=*/0, /*column=*/0)));
61   module->getBody()->push_front(spirvModule.release());
62 
63   return module;
64 }
65 
66 namespace mlir {
67 void registerFromSPIRVTranslation() {
68   TranslateToMLIRRegistration fromBinary(
69       "deserialize-spirv",
70       [](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
71         assert(sourceMgr.getNumBuffers() == 1 && "expected one buffer");
72         return deserializeModule(
73             sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), context);
74       });
75 }
76 } // namespace mlir
77 
78 //===----------------------------------------------------------------------===//
79 // Serialization registration
80 //===----------------------------------------------------------------------===//
81 
82 static LogicalResult serializeModule(ModuleOp module, raw_ostream &output) {
83   if (!module)
84     return failure();
85 
86   SmallVector<uint32_t, 0> binary;
87 
88   SmallVector<spirv::ModuleOp, 1> spirvModules;
89   module.walk([&](spirv::ModuleOp op) { spirvModules.push_back(op); });
90 
91   if (spirvModules.empty())
92     return module.emitError("found no 'spv.module' op");
93 
94   if (spirvModules.size() != 1)
95     return module.emitError("found more than one 'spv.module' op");
96 
97   if (failed(spirv::serialize(spirvModules[0], binary)))
98     return failure();
99 
100   output.write(reinterpret_cast<char *>(binary.data()),
101                binary.size() * sizeof(uint32_t));
102 
103   return mlir::success();
104 }
105 
106 namespace mlir {
107 void registerToSPIRVTranslation() {
108   TranslateFromMLIRRegistration toBinary(
109       "serialize-spirv",
110       [](ModuleOp module, raw_ostream &output) {
111         return serializeModule(module, output);
112       },
113       [](DialectRegistry &registry) {
114         registry.insert<spirv::SPIRVDialect>();
115       });
116 }
117 } // namespace mlir
118 
119 //===----------------------------------------------------------------------===//
120 // Round-trip registration
121 //===----------------------------------------------------------------------===//
122 
123 static LogicalResult roundTripModule(ModuleOp srcModule, bool emitDebugInfo,
124                                      raw_ostream &output) {
125   SmallVector<uint32_t, 0> binary;
126   MLIRContext *context = srcModule.getContext();
127   auto spirvModules = srcModule.getOps<spirv::ModuleOp>();
128 
129   if (spirvModules.begin() == spirvModules.end())
130     return srcModule.emitError("found no 'spv.module' op");
131 
132   if (std::next(spirvModules.begin()) != spirvModules.end())
133     return srcModule.emitError("found more than one 'spv.module' op");
134 
135   spirv::SerializationOptions options;
136   options.emitDebugInfo = emitDebugInfo;
137   if (failed(spirv::serialize(*spirvModules.begin(), binary, options)))
138     return failure();
139 
140   MLIRContext deserializationContext(context->getDialectRegistry());
141   // TODO: we should only load the required dialects instead of all dialects.
142   deserializationContext.loadAllAvailableDialects();
143   // Then deserialize to get back a SPIR-V module.
144   OwningOpRef<spirv::ModuleOp> spirvModule =
145       spirv::deserialize(binary, &deserializationContext);
146   if (!spirvModule)
147     return failure();
148 
149   // Wrap around in a new MLIR module.
150   OwningModuleRef dstModule(ModuleOp::create(
151       FileLineColLoc::get(&deserializationContext,
152                           /*filename=*/"", /*line=*/0, /*column=*/0)));
153   dstModule->getBody()->push_front(spirvModule.release());
154   dstModule->print(output);
155 
156   return mlir::success();
157 }
158 
159 namespace mlir {
160 void registerTestRoundtripSPIRV() {
161   TranslateFromMLIRRegistration roundtrip(
162       "test-spirv-roundtrip",
163       [](ModuleOp module, raw_ostream &output) {
164         return roundTripModule(module, /*emitDebugInfo=*/false, output);
165       },
166       [](DialectRegistry &registry) {
167         registry.insert<spirv::SPIRVDialect>();
168       });
169 }
170 
171 void registerTestRoundtripDebugSPIRV() {
172   TranslateFromMLIRRegistration roundtrip(
173       "test-spirv-roundtrip-debug",
174       [](ModuleOp module, raw_ostream &output) {
175         return roundTripModule(module, /*emitDebugInfo=*/true, output);
176       },
177       [](DialectRegistry &registry) {
178         registry.insert<spirv::SPIRVDialect>();
179       });
180 }
181 } // namespace mlir
182