//===- TranslateRegistration.cpp - hooks to mlir-translate ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file registers the SPIR-V hooks for mlir-translate: deserialization of
// a SPIR-V binary module into an MLIR SPIR-V ModuleOp, serialization of an MLIR
// SPIR-V ModuleOp into a SPIR-V binary, and serialize/deserialize round-trip
// translations used for testing.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVModule.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Parser.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Target/SPIRV/Deserialization.h"
#include "mlir/Target/SPIRV/Serialization.h"
#include "mlir/Translation.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Deserialization registration
//===----------------------------------------------------------------------===//

// Deserializes the SPIR-V binary module held in the memory buffer `input` and
// returns a builtin module containing the resulting SPIR-V module, or a null
// module on failure.
39 static OwningModuleRef deserializeModule(const llvm::MemoryBuffer *input, 40 MLIRContext *context) { 41 context->loadDialect<spirv::SPIRVDialect>(); 42 43 // Make sure the input stream can be treated as a stream of SPIR-V words 44 auto start = input->getBufferStart(); 45 auto size = input->getBufferSize(); 46 if (size % sizeof(uint32_t) != 0) { 47 emitError(UnknownLoc::get(context)) 48 << "SPIR-V binary module must contain integral number of 32-bit words"; 49 return {}; 50 } 51 52 auto binary = llvm::makeArrayRef(reinterpret_cast<const uint32_t *>(start), 53 size / sizeof(uint32_t)); 54 55 spirv::OwningSPIRVModuleRef spirvModule = spirv::deserialize(binary, context); 56 if (!spirvModule) 57 return {}; 58 59 OwningModuleRef module(ModuleOp::create(FileLineColLoc::get( 60 input->getBufferIdentifier(), /*line=*/0, /*column=*/0, context))); 61 module->getBody()->push_front(spirvModule.release()); 62 63 return module; 64 } 65 66 namespace mlir { 67 void registerFromSPIRVTranslation() { 68 TranslateToMLIRRegistration fromBinary( 69 "deserialize-spirv", 70 [](llvm::SourceMgr &sourceMgr, MLIRContext *context) { 71 assert(sourceMgr.getNumBuffers() == 1 && "expected one buffer"); 72 return deserializeModule( 73 sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), context); 74 }); 75 } 76 } // namespace mlir 77 78 //===----------------------------------------------------------------------===// 79 // Serialization registration 80 //===----------------------------------------------------------------------===// 81 82 static LogicalResult serializeModule(ModuleOp module, raw_ostream &output) { 83 if (!module) 84 return failure(); 85 86 SmallVector<uint32_t, 0> binary; 87 88 SmallVector<spirv::ModuleOp, 1> spirvModules; 89 module.walk([&](spirv::ModuleOp op) { spirvModules.push_back(op); }); 90 91 if (spirvModules.empty()) 92 return module.emitError("found no 'spv.module' op"); 93 94 if (spirvModules.size() != 1) 95 return module.emitError("found more than one 'spv.module' 
op"); 96 97 if (failed( 98 spirv::serialize(spirvModules[0], binary, /*emitDebuginfo=*/false))) 99 return failure(); 100 101 output.write(reinterpret_cast<char *>(binary.data()), 102 binary.size() * sizeof(uint32_t)); 103 104 return mlir::success(); 105 } 106 107 namespace mlir { 108 void registerToSPIRVTranslation() { 109 TranslateFromMLIRRegistration toBinary( 110 "serialize-spirv", 111 [](ModuleOp module, raw_ostream &output) { 112 return serializeModule(module, output); 113 }, 114 [](DialectRegistry ®istry) { 115 registry.insert<spirv::SPIRVDialect>(); 116 }); 117 } 118 } // namespace mlir 119 120 //===----------------------------------------------------------------------===// 121 // Round-trip registration 122 //===----------------------------------------------------------------------===// 123 124 static LogicalResult roundTripModule(ModuleOp srcModule, bool emitDebugInfo, 125 raw_ostream &output) { 126 SmallVector<uint32_t, 0> binary; 127 MLIRContext *context = srcModule.getContext(); 128 auto spirvModules = srcModule.getOps<spirv::ModuleOp>(); 129 130 if (spirvModules.begin() == spirvModules.end()) 131 return srcModule.emitError("found no 'spv.module' op"); 132 133 if (std::next(spirvModules.begin()) != spirvModules.end()) 134 return srcModule.emitError("found more than one 'spv.module' op"); 135 136 if (failed(spirv::serialize(*spirvModules.begin(), binary, emitDebugInfo))) 137 return failure(); 138 139 MLIRContext deserializationContext(context->getDialectRegistry()); 140 // TODO: we should only load the required dialects instead of all dialects. 141 deserializationContext.loadAllAvailableDialects(); 142 // Then deserialize to get back a SPIR-V module. 143 spirv::OwningSPIRVModuleRef spirvModule = 144 spirv::deserialize(binary, &deserializationContext); 145 if (!spirvModule) 146 return failure(); 147 148 // Wrap around in a new MLIR module. 
149 OwningModuleRef dstModule(ModuleOp::create(FileLineColLoc::get( 150 /*filename=*/"", /*line=*/0, /*column=*/0, &deserializationContext))); 151 dstModule->getBody()->push_front(spirvModule.release()); 152 dstModule->print(output); 153 154 return mlir::success(); 155 } 156 157 namespace mlir { 158 void registerTestRoundtripSPIRV() { 159 TranslateFromMLIRRegistration roundtrip( 160 "test-spirv-roundtrip", 161 [](ModuleOp module, raw_ostream &output) { 162 return roundTripModule(module, /*emitDebugInfo=*/false, output); 163 }, 164 [](DialectRegistry ®istry) { 165 registry.insert<spirv::SPIRVDialect>(); 166 }); 167 } 168 169 void registerTestRoundtripDebugSPIRV() { 170 TranslateFromMLIRRegistration roundtrip( 171 "test-spirv-roundtrip-debug", 172 [](ModuleOp module, raw_ostream &output) { 173 return roundTripModule(module, /*emitDebugInfo=*/true, output); 174 }, 175 [](DialectRegistry ®istry) { 176 registry.insert<spirv::SPIRVDialect>(); 177 }); 178 } 179 } // namespace mlir 180