//===- TranslateRegistration.cpp - hooks to mlir-translate ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file registers the SPIR-V hooks for mlir-translate: deserialization of
// a SPIR-V binary module into an MLIR SPIR-V ModuleOp, serialization of an MLIR
// SPIR-V ModuleOp into a SPIR-V binary, and round-trip test translations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Target/SPIRV/Deserialization.h"
#include "mlir/Target/SPIRV/Serialization.h"
#include "mlir/Tools/mlir-translate/Translation.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Deserialization registration
//===----------------------------------------------------------------------===//

// Deserializes the SPIR-V binary module held in the memory buffer `input` and
// returns a builtin module wrapping the resulting SPIR-V module, or a null
// module on failure.
39 static OwningOpRef<ModuleOp> deserializeModule(const llvm::MemoryBuffer *input, 40 MLIRContext *context) { 41 context->loadDialect<spirv::SPIRVDialect>(); 42 43 // Make sure the input stream can be treated as a stream of SPIR-V words 44 auto *start = input->getBufferStart(); 45 auto size = input->getBufferSize(); 46 if (size % sizeof(uint32_t) != 0) { 47 emitError(UnknownLoc::get(context)) 48 << "SPIR-V binary module must contain integral number of 32-bit words"; 49 return {}; 50 } 51 52 auto binary = llvm::makeArrayRef(reinterpret_cast<const uint32_t *>(start), 53 size / sizeof(uint32_t)); 54 55 OwningOpRef<spirv::ModuleOp> spirvModule = 56 spirv::deserialize(binary, context); 57 if (!spirvModule) 58 return {}; 59 60 OwningOpRef<ModuleOp> module(ModuleOp::create(FileLineColLoc::get( 61 context, input->getBufferIdentifier(), /*line=*/0, /*column=*/0))); 62 module->getBody()->push_front(spirvModule.release()); 63 64 return module; 65 } 66 67 namespace mlir { 68 void registerFromSPIRVTranslation() { 69 TranslateToMLIRRegistration fromBinary( 70 "deserialize-spirv", 71 [](llvm::SourceMgr &sourceMgr, MLIRContext *context) { 72 assert(sourceMgr.getNumBuffers() == 1 && "expected one buffer"); 73 return deserializeModule( 74 sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), context); 75 }); 76 } 77 } // namespace mlir 78 79 //===----------------------------------------------------------------------===// 80 // Serialization registration 81 //===----------------------------------------------------------------------===// 82 83 static LogicalResult serializeModule(ModuleOp module, raw_ostream &output) { 84 if (!module) 85 return failure(); 86 87 SmallVector<uint32_t, 0> binary; 88 89 SmallVector<spirv::ModuleOp, 1> spirvModules; 90 module.walk([&](spirv::ModuleOp op) { spirvModules.push_back(op); }); 91 92 if (spirvModules.empty()) 93 return module.emitError("found no 'spv.module' op"); 94 95 if (spirvModules.size() != 1) 96 return module.emitError("found more than one 
'spv.module' op"); 97 98 if (failed(spirv::serialize(spirvModules[0], binary))) 99 return failure(); 100 101 output.write(reinterpret_cast<char *>(binary.data()), 102 binary.size() * sizeof(uint32_t)); 103 104 return mlir::success(); 105 } 106 107 namespace mlir { 108 void registerToSPIRVTranslation() { 109 TranslateFromMLIRRegistration toBinary( 110 "serialize-spirv", 111 [](ModuleOp module, raw_ostream &output) { 112 return serializeModule(module, output); 113 }, 114 [](DialectRegistry ®istry) { 115 registry.insert<spirv::SPIRVDialect>(); 116 }); 117 } 118 } // namespace mlir 119 120 //===----------------------------------------------------------------------===// 121 // Round-trip registration 122 //===----------------------------------------------------------------------===// 123 124 static LogicalResult roundTripModule(ModuleOp srcModule, bool emitDebugInfo, 125 raw_ostream &output) { 126 SmallVector<uint32_t, 0> binary; 127 MLIRContext *context = srcModule.getContext(); 128 auto spirvModules = srcModule.getOps<spirv::ModuleOp>(); 129 130 if (spirvModules.begin() == spirvModules.end()) 131 return srcModule.emitError("found no 'spv.module' op"); 132 133 if (std::next(spirvModules.begin()) != spirvModules.end()) 134 return srcModule.emitError("found more than one 'spv.module' op"); 135 136 spirv::SerializationOptions options; 137 options.emitDebugInfo = emitDebugInfo; 138 if (failed(spirv::serialize(*spirvModules.begin(), binary, options))) 139 return failure(); 140 141 MLIRContext deserializationContext(context->getDialectRegistry()); 142 // TODO: we should only load the required dialects instead of all dialects. 143 deserializationContext.loadAllAvailableDialects(); 144 // Then deserialize to get back a SPIR-V module. 145 OwningOpRef<spirv::ModuleOp> spirvModule = 146 spirv::deserialize(binary, &deserializationContext); 147 if (!spirvModule) 148 return failure(); 149 150 // Wrap around in a new MLIR module. 
151 OwningOpRef<ModuleOp> dstModule(ModuleOp::create( 152 FileLineColLoc::get(&deserializationContext, 153 /*filename=*/"", /*line=*/0, /*column=*/0))); 154 dstModule->getBody()->push_front(spirvModule.release()); 155 if (failed(verify(*dstModule))) 156 return failure(); 157 dstModule->print(output); 158 159 return mlir::success(); 160 } 161 162 namespace mlir { 163 void registerTestRoundtripSPIRV() { 164 TranslateFromMLIRRegistration roundtrip( 165 "test-spirv-roundtrip", 166 [](ModuleOp module, raw_ostream &output) { 167 return roundTripModule(module, /*emitDebugInfo=*/false, output); 168 }, 169 [](DialectRegistry ®istry) { 170 registry.insert<spirv::SPIRVDialect>(); 171 }); 172 } 173 174 void registerTestRoundtripDebugSPIRV() { 175 TranslateFromMLIRRegistration roundtrip( 176 "test-spirv-roundtrip-debug", 177 [](ModuleOp module, raw_ostream &output) { 178 return roundTripModule(module, /*emitDebugInfo=*/true, output); 179 }, 180 [](DialectRegistry ®istry) { 181 registry.insert<spirv::SPIRVDialect>(); 182 }); 183 } 184 } // namespace mlir 185