1 //===- TranslateRegistration.cpp - hooks to mlir-translate ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
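// This file registers the SPIR-V hooks for mlir-translate: deserializing a
// SPIR-V binary module into an MLIR SPIR-V ModuleOp, serializing a SPIR-V
// ModuleOp into a binary module, and round-tripping through the binary form.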
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/Dialect.h"
19 #include "mlir/IR/Verifier.h"
20 #include "mlir/Parser/Parser.h"
21 #include "mlir/Support/FileUtilities.h"
22 #include "mlir/Target/SPIRV/Deserialization.h"
23 #include "mlir/Target/SPIRV/Serialization.h"
24 #include "mlir/Tools/mlir-translate/Translation.h"
25 #include "llvm/ADT/StringRef.h"
26 #include "llvm/Support/MemoryBuffer.h"
27 #include "llvm/Support/SMLoc.h"
28 #include "llvm/Support/SourceMgr.h"
29 #include "llvm/Support/ToolOutputFile.h"
30
31 using namespace mlir;
32
33 //===----------------------------------------------------------------------===//
34 // Deserialization registration
35 //===----------------------------------------------------------------------===//
36
37 // Deserializes the SPIR-V binary module stored in the file named as
38 // `inputFilename` and returns a module containing the SPIR-V module.
deserializeModule(const llvm::MemoryBuffer * input,MLIRContext * context)39 static OwningOpRef<ModuleOp> deserializeModule(const llvm::MemoryBuffer *input,
40 MLIRContext *context) {
41 context->loadDialect<spirv::SPIRVDialect>();
42
43 // Make sure the input stream can be treated as a stream of SPIR-V words
44 auto *start = input->getBufferStart();
45 auto size = input->getBufferSize();
46 if (size % sizeof(uint32_t) != 0) {
47 emitError(UnknownLoc::get(context))
48 << "SPIR-V binary module must contain integral number of 32-bit words";
49 return {};
50 }
51
52 auto binary = llvm::makeArrayRef(reinterpret_cast<const uint32_t *>(start),
53 size / sizeof(uint32_t));
54
55 OwningOpRef<spirv::ModuleOp> spirvModule =
56 spirv::deserialize(binary, context);
57 if (!spirvModule)
58 return {};
59
60 OwningOpRef<ModuleOp> module(ModuleOp::create(FileLineColLoc::get(
61 context, input->getBufferIdentifier(), /*line=*/0, /*column=*/0)));
62 module->getBody()->push_front(spirvModule.release());
63
64 return module;
65 }
66
67 namespace mlir {
registerFromSPIRVTranslation()68 void registerFromSPIRVTranslation() {
69 TranslateToMLIRRegistration fromBinary(
70 "deserialize-spirv",
71 [](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
72 assert(sourceMgr.getNumBuffers() == 1 && "expected one buffer");
73 return deserializeModule(
74 sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), context);
75 });
76 }
77 } // namespace mlir
78
79 //===----------------------------------------------------------------------===//
80 // Serialization registration
81 //===----------------------------------------------------------------------===//
82
serializeModule(ModuleOp module,raw_ostream & output)83 static LogicalResult serializeModule(ModuleOp module, raw_ostream &output) {
84 if (!module)
85 return failure();
86
87 SmallVector<uint32_t, 0> binary;
88
89 SmallVector<spirv::ModuleOp, 1> spirvModules;
90 module.walk([&](spirv::ModuleOp op) { spirvModules.push_back(op); });
91
92 if (spirvModules.empty())
93 return module.emitError("found no 'spv.module' op");
94
95 if (spirvModules.size() != 1)
96 return module.emitError("found more than one 'spv.module' op");
97
98 if (failed(spirv::serialize(spirvModules[0], binary)))
99 return failure();
100
101 output.write(reinterpret_cast<char *>(binary.data()),
102 binary.size() * sizeof(uint32_t));
103
104 return mlir::success();
105 }
106
107 namespace mlir {
registerToSPIRVTranslation()108 void registerToSPIRVTranslation() {
109 TranslateFromMLIRRegistration toBinary(
110 "serialize-spirv",
111 [](ModuleOp module, raw_ostream &output) {
112 return serializeModule(module, output);
113 },
114 [](DialectRegistry ®istry) {
115 registry.insert<spirv::SPIRVDialect>();
116 });
117 }
118 } // namespace mlir
119
120 //===----------------------------------------------------------------------===//
121 // Round-trip registration
122 //===----------------------------------------------------------------------===//
123
roundTripModule(ModuleOp srcModule,bool emitDebugInfo,raw_ostream & output)124 static LogicalResult roundTripModule(ModuleOp srcModule, bool emitDebugInfo,
125 raw_ostream &output) {
126 SmallVector<uint32_t, 0> binary;
127 MLIRContext *context = srcModule.getContext();
128 auto spirvModules = srcModule.getOps<spirv::ModuleOp>();
129
130 if (spirvModules.begin() == spirvModules.end())
131 return srcModule.emitError("found no 'spv.module' op");
132
133 if (std::next(spirvModules.begin()) != spirvModules.end())
134 return srcModule.emitError("found more than one 'spv.module' op");
135
136 spirv::SerializationOptions options;
137 options.emitDebugInfo = emitDebugInfo;
138 if (failed(spirv::serialize(*spirvModules.begin(), binary, options)))
139 return failure();
140
141 MLIRContext deserializationContext(context->getDialectRegistry());
142 // TODO: we should only load the required dialects instead of all dialects.
143 deserializationContext.loadAllAvailableDialects();
144 // Then deserialize to get back a SPIR-V module.
145 OwningOpRef<spirv::ModuleOp> spirvModule =
146 spirv::deserialize(binary, &deserializationContext);
147 if (!spirvModule)
148 return failure();
149
150 // Wrap around in a new MLIR module.
151 OwningOpRef<ModuleOp> dstModule(ModuleOp::create(
152 FileLineColLoc::get(&deserializationContext,
153 /*filename=*/"", /*line=*/0, /*column=*/0)));
154 dstModule->getBody()->push_front(spirvModule.release());
155 if (failed(verify(*dstModule)))
156 return failure();
157 dstModule->print(output);
158
159 return mlir::success();
160 }
161
162 namespace mlir {
registerTestRoundtripSPIRV()163 void registerTestRoundtripSPIRV() {
164 TranslateFromMLIRRegistration roundtrip(
165 "test-spirv-roundtrip",
166 [](ModuleOp module, raw_ostream &output) {
167 return roundTripModule(module, /*emitDebugInfo=*/false, output);
168 },
169 [](DialectRegistry ®istry) {
170 registry.insert<spirv::SPIRVDialect>();
171 });
172 }
173
registerTestRoundtripDebugSPIRV()174 void registerTestRoundtripDebugSPIRV() {
175 TranslateFromMLIRRegistration roundtrip(
176 "test-spirv-roundtrip-debug",
177 [](ModuleOp module, raw_ostream &output) {
178 return roundTripModule(module, /*emitDebugInfo=*/true, output);
179 },
180 [](DialectRegistry ®istry) {
181 registry.insert<spirv::SPIRVDialect>();
182 });
183 }
184 } // namespace mlir
185