1 //===- TranslateRegistration.cpp - hooks to mlir-translate ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a translation from SPIR-V binary module to MLIR SPIR-V 10 // ModuleOp. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/Dialect.h" 19 #include "mlir/Parser.h" 20 #include "mlir/Support/FileUtilities.h" 21 #include "mlir/Target/SPIRV/Deserialization.h" 22 #include "mlir/Target/SPIRV/Serialization.h" 23 #include "mlir/Translation.h" 24 #include "llvm/ADT/StringRef.h" 25 #include "llvm/Support/MemoryBuffer.h" 26 #include "llvm/Support/SMLoc.h" 27 #include "llvm/Support/SourceMgr.h" 28 #include "llvm/Support/ToolOutputFile.h" 29 30 using namespace mlir; 31 32 //===----------------------------------------------------------------------===// 33 // Deserialization registration 34 //===----------------------------------------------------------------------===// 35 36 // Deserializes the SPIR-V binary module stored in the file named as 37 // `inputFilename` and returns a module containing the SPIR-V module. 38 static OwningModuleRef deserializeModule(const llvm::MemoryBuffer *input, 39 MLIRContext *context) { 40 context->loadDialect<spirv::SPIRVDialect>(); 41 42 // Make sure the input stream can be treated as a stream of SPIR-V words 43 auto *start = input->getBufferStart(); 44 auto size = input->getBufferSize(); 45 if (size % sizeof(uint32_t) != 0) { 46 emitError(UnknownLoc::get(context)) 47 << "SPIR-V binary module must contain integral number of 32-bit words"; 48 return {}; 49 } 50 51 auto binary = llvm::makeArrayRef(reinterpret_cast<const uint32_t *>(start), 52 size / sizeof(uint32_t)); 53 54 OwningOpRef<spirv::ModuleOp> spirvModule = 55 spirv::deserialize(binary, context); 56 if (!spirvModule) 57 return {}; 58 59 OwningModuleRef module(ModuleOp::create(FileLineColLoc::get( 60 context, input->getBufferIdentifier(), /*line=*/0, /*column=*/0))); 61 module->getBody()->push_front(spirvModule.release()); 62 63 return module; 64 } 65 66 namespace mlir { 67 void registerFromSPIRVTranslation() { 68 TranslateToMLIRRegistration fromBinary( 69 "deserialize-spirv", 70 [](llvm::SourceMgr &sourceMgr, MLIRContext *context) { 71 assert(sourceMgr.getNumBuffers() == 1 && "expected one buffer"); 72 return deserializeModule( 73 sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), context); 74 }); 75 } 76 } // namespace mlir 77 78 //===----------------------------------------------------------------------===// 79 // Serialization registration 80 //===----------------------------------------------------------------------===// 81 82 static LogicalResult serializeModule(ModuleOp module, raw_ostream &output) { 83 if (!module) 84 return failure(); 85 86 SmallVector<uint32_t, 0> binary; 87 88 SmallVector<spirv::ModuleOp, 1> spirvModules; 89 module.walk([&](spirv::ModuleOp op) { spirvModules.push_back(op); }); 90 91 if (spirvModules.empty()) 92 return module.emitError("found no 'spv.module' op"); 93 94 if (spirvModules.size() != 1) 95 return module.emitError("found more than one 'spv.module' op"); 96 97 if (failed(spirv::serialize(spirvModules[0], binary))) 98 return failure(); 99 100 output.write(reinterpret_cast<char *>(binary.data()), 101 binary.size() * sizeof(uint32_t)); 102 103 return mlir::success(); 104 } 105 106 namespace mlir { 107 void registerToSPIRVTranslation() { 108 TranslateFromMLIRRegistration toBinary( 109 "serialize-spirv", 110 [](ModuleOp module, raw_ostream &output) { 111 return serializeModule(module, output); 112 }, 113 [](DialectRegistry ®istry) { 114 registry.insert<spirv::SPIRVDialect>(); 115 }); 116 } 117 } // namespace mlir 118 119 //===----------------------------------------------------------------------===// 120 // Round-trip registration 121 //===----------------------------------------------------------------------===// 122 123 static LogicalResult roundTripModule(ModuleOp srcModule, bool emitDebugInfo, 124 raw_ostream &output) { 125 SmallVector<uint32_t, 0> binary; 126 MLIRContext *context = srcModule.getContext(); 127 auto spirvModules = srcModule.getOps<spirv::ModuleOp>(); 128 129 if (spirvModules.begin() == spirvModules.end()) 130 return srcModule.emitError("found no 'spv.module' op"); 131 132 if (std::next(spirvModules.begin()) != spirvModules.end()) 133 return srcModule.emitError("found more than one 'spv.module' op"); 134 135 spirv::SerializationOptions options; 136 options.emitDebugInfo = emitDebugInfo; 137 if (failed(spirv::serialize(*spirvModules.begin(), binary, options))) 138 return failure(); 139 140 MLIRContext deserializationContext(context->getDialectRegistry()); 141 // TODO: we should only load the required dialects instead of all dialects. 142 deserializationContext.loadAllAvailableDialects(); 143 // Then deserialize to get back a SPIR-V module. 144 OwningOpRef<spirv::ModuleOp> spirvModule = 145 spirv::deserialize(binary, &deserializationContext); 146 if (!spirvModule) 147 return failure(); 148 149 // Wrap around in a new MLIR module. 150 OwningModuleRef dstModule(ModuleOp::create( 151 FileLineColLoc::get(&deserializationContext, 152 /*filename=*/"", /*line=*/0, /*column=*/0))); 153 dstModule->getBody()->push_front(spirvModule.release()); 154 dstModule->print(output); 155 156 return mlir::success(); 157 } 158 159 namespace mlir { 160 void registerTestRoundtripSPIRV() { 161 TranslateFromMLIRRegistration roundtrip( 162 "test-spirv-roundtrip", 163 [](ModuleOp module, raw_ostream &output) { 164 return roundTripModule(module, /*emitDebugInfo=*/false, output); 165 }, 166 [](DialectRegistry ®istry) { 167 registry.insert<spirv::SPIRVDialect>(); 168 }); 169 } 170 171 void registerTestRoundtripDebugSPIRV() { 172 TranslateFromMLIRRegistration roundtrip( 173 "test-spirv-roundtrip-debug", 174 [](ModuleOp module, raw_ostream &output) { 175 return roundTripModule(module, /*emitDebugInfo=*/true, output); 176 }, 177 [](DialectRegistry ®istry) { 178 registry.insert<spirv::SPIRVDialect>(); 179 }); 180 } 181 } // namespace mlir 182