1 //===- TranslateRegistration.cpp - hooks to mlir-translate ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements translations between SPIR-V binary modules and MLIR 10 // SPIR-V ModuleOps, and registers them as mlir-translate hooks. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/Dialect.h" 19 #include "mlir/Parser.h" 20 #include "mlir/Support/FileUtilities.h" 21 #include "mlir/Target/SPIRV/Deserialization.h" 22 #include "mlir/Target/SPIRV/Serialization.h" 23 #include "mlir/Translation.h" 24 #include "llvm/ADT/StringRef.h" 25 #include "llvm/Support/MemoryBuffer.h" 26 #include "llvm/Support/SMLoc.h" 27 #include "llvm/Support/SourceMgr.h" 28 #include "llvm/Support/ToolOutputFile.h" 29 30 using namespace mlir; 31 32 //===----------------------------------------------------------------------===// 33 // Deserialization registration 34 //===----------------------------------------------------------------------===// 35 36 // Deserializes the SPIR-V binary module held in the memory buffer `input` 37 // and returns a module containing the SPIR-V module. 
38 static OwningModuleRef deserializeModule(const llvm::MemoryBuffer *input, 39 MLIRContext *context) { 40 context->loadDialect<spirv::SPIRVDialect>(); 41 42 // Make sure the input stream can be treated as a stream of SPIR-V words 43 auto start = input->getBufferStart(); 44 auto size = input->getBufferSize(); 45 if (size % sizeof(uint32_t) != 0) { 46 emitError(UnknownLoc::get(context)) 47 << "SPIR-V binary module must contain integral number of 32-bit words"; 48 return {}; 49 } 50 51 auto binary = llvm::makeArrayRef(reinterpret_cast<const uint32_t *>(start), 52 size / sizeof(uint32_t)); 53 54 OwningOpRef<spirv::ModuleOp> spirvModule = 55 spirv::deserialize(binary, context); 56 if (!spirvModule) 57 return {}; 58 59 OwningModuleRef module(ModuleOp::create(FileLineColLoc::get( 60 context, input->getBufferIdentifier(), /*line=*/0, /*column=*/0))); 61 module->getBody()->push_front(spirvModule.release()); 62 63 return module; 64 } 65 66 namespace mlir { 67 void registerFromSPIRVTranslation() { 68 TranslateToMLIRRegistration fromBinary( 69 "deserialize-spirv", 70 [](llvm::SourceMgr &sourceMgr, MLIRContext *context) { 71 assert(sourceMgr.getNumBuffers() == 1 && "expected one buffer"); 72 return deserializeModule( 73 sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), context); 74 }); 75 } 76 } // namespace mlir 77 78 //===----------------------------------------------------------------------===// 79 // Serialization registration 80 //===----------------------------------------------------------------------===// 81 82 static LogicalResult serializeModule(ModuleOp module, raw_ostream &output) { 83 if (!module) 84 return failure(); 85 86 SmallVector<uint32_t, 0> binary; 87 88 SmallVector<spirv::ModuleOp, 1> spirvModules; 89 module.walk([&](spirv::ModuleOp op) { spirvModules.push_back(op); }); 90 91 if (spirvModules.empty()) 92 return module.emitError("found no 'spv.module' op"); 93 94 if (spirvModules.size() != 1) 95 return module.emitError("found more than one 'spv.module' 
op"); 96 97 if (failed( 98 spirv::serialize(spirvModules[0], binary, /*emitDebuginfo=*/false))) 99 return failure(); 100 101 output.write(reinterpret_cast<char *>(binary.data()), 102 binary.size() * sizeof(uint32_t)); 103 104 return mlir::success(); 105 } 106 107 namespace mlir { 108 void registerToSPIRVTranslation() { 109 TranslateFromMLIRRegistration toBinary( 110 "serialize-spirv", 111 [](ModuleOp module, raw_ostream &output) { 112 return serializeModule(module, output); 113 }, 114 [](DialectRegistry ®istry) { 115 registry.insert<spirv::SPIRVDialect>(); 116 }); 117 } 118 } // namespace mlir 119 120 //===----------------------------------------------------------------------===// 121 // Round-trip registration 122 //===----------------------------------------------------------------------===// 123 124 static LogicalResult roundTripModule(ModuleOp srcModule, bool emitDebugInfo, 125 raw_ostream &output) { 126 SmallVector<uint32_t, 0> binary; 127 MLIRContext *context = srcModule.getContext(); 128 auto spirvModules = srcModule.getOps<spirv::ModuleOp>(); 129 130 if (spirvModules.begin() == spirvModules.end()) 131 return srcModule.emitError("found no 'spv.module' op"); 132 133 if (std::next(spirvModules.begin()) != spirvModules.end()) 134 return srcModule.emitError("found more than one 'spv.module' op"); 135 136 if (failed(spirv::serialize(*spirvModules.begin(), binary, emitDebugInfo))) 137 return failure(); 138 139 MLIRContext deserializationContext(context->getDialectRegistry()); 140 // TODO: we should only load the required dialects instead of all dialects. 141 deserializationContext.loadAllAvailableDialects(); 142 // Then deserialize to get back a SPIR-V module. 143 OwningOpRef<spirv::ModuleOp> spirvModule = 144 spirv::deserialize(binary, &deserializationContext); 145 if (!spirvModule) 146 return failure(); 147 148 // Wrap around in a new MLIR module. 
149 OwningModuleRef dstModule(ModuleOp::create( 150 FileLineColLoc::get(&deserializationContext, 151 /*filename=*/"", /*line=*/0, /*column=*/0))); 152 dstModule->getBody()->push_front(spirvModule.release()); 153 dstModule->print(output); 154 155 return mlir::success(); 156 } 157 158 namespace mlir { 159 void registerTestRoundtripSPIRV() { 160 TranslateFromMLIRRegistration roundtrip( 161 "test-spirv-roundtrip", 162 [](ModuleOp module, raw_ostream &output) { 163 return roundTripModule(module, /*emitDebugInfo=*/false, output); 164 }, 165 [](DialectRegistry ®istry) { 166 registry.insert<spirv::SPIRVDialect>(); 167 }); 168 } 169 170 void registerTestRoundtripDebugSPIRV() { 171 TranslateFromMLIRRegistration roundtrip( 172 "test-spirv-roundtrip-debug", 173 [](ModuleOp module, raw_ostream &output) { 174 return roundTripModule(module, /*emitDebugInfo=*/true, output); 175 }, 176 [](DialectRegistry ®istry) { 177 registry.insert<spirv::SPIRVDialect>(); 178 }); 179 } 180 } // namespace mlir 181