1 //===- SerializationTest.cpp - SPIR-V Serialization Tests -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains corner case tests for the SPIR-V serializer that are not 10 // covered by normal serialization and deserialization roundtripping. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Target/SPIRV/Serialization.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 19 #include "mlir/IR/Builders.h" 20 #include "mlir/IR/Location.h" 21 #include "mlir/IR/MLIRContext.h" 22 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" 23 #include "llvm/ADT/DenseSet.h" 24 #include "llvm/ADT/STLExtras.h" 25 #include "llvm/ADT/Sequence.h" 26 #include "llvm/ADT/SmallVector.h" 27 #include "llvm/ADT/StringRef.h" 28 #include "gmock/gmock.h" 29 30 using namespace mlir; 31 32 //===----------------------------------------------------------------------===// 33 // Test Fixture 34 //===----------------------------------------------------------------------===// 35 36 class SerializationTest : public ::testing::Test { 37 protected: 38 SerializationTest() { 39 context.getOrLoadDialect<mlir::spirv::SPIRVDialect>(); 40 initModuleOp(); 41 } 42 43 /// Initializes an empty SPIR-V module op. 44 void initModuleOp() { 45 OpBuilder builder(&context); 46 OperationState state(UnknownLoc::get(&context), 47 spirv::ModuleOp::getOperationName()); 48 state.addAttribute("addressing_model", 49 builder.getI32IntegerAttr(static_cast<uint32_t>( 50 spirv::AddressingModel::Logical))); 51 state.addAttribute("memory_model", 52 builder.getI32IntegerAttr( 53 static_cast<uint32_t>(spirv::MemoryModel::GLSL450))); 54 state.addAttribute("vce_triple", 55 spirv::VerCapExtAttr::get( 56 spirv::Version::V_1_0, ArrayRef<spirv::Capability>(), 57 ArrayRef<spirv::Extension>(), &context)); 58 spirv::ModuleOp::build(builder, state); 59 module = cast<spirv::ModuleOp>(Operation::create(state)); 60 } 61 62 /// Gets the `struct { float }` type. 63 spirv::StructType getFloatStructType() { 64 OpBuilder builder(module->getRegion()); 65 llvm::SmallVector<Type, 1> elementTypes{builder.getF32Type()}; 66 llvm::SmallVector<spirv::StructType::OffsetInfo, 1> offsetInfo{0}; 67 return spirv::StructType::get(elementTypes, offsetInfo); 68 } 69 70 /// Inserts a global variable of the given `type` and `name`. 71 spirv::GlobalVariableOp addGlobalVar(Type type, llvm::StringRef name) { 72 OpBuilder builder(module->getRegion()); 73 auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform); 74 return builder.create<spirv::GlobalVariableOp>( 75 UnknownLoc::get(&context), TypeAttr::get(ptrType), 76 builder.getStringAttr(name), nullptr); 77 } 78 79 /// Handles a SPIR-V instruction with the given `opcode` and `operand`. 80 /// Returns true to interrupt. 81 using HandleFn = llvm::function_ref<bool(spirv::Opcode opcode, 82 ArrayRef<uint32_t> operands)>; 83 84 /// Returns true if we can find a matching instruction in the SPIR-V blob. 85 bool scanInstruction(HandleFn handleFn) { 86 auto binarySize = binary.size(); 87 auto *begin = binary.begin(); 88 auto currOffset = spirv::kHeaderWordCount; 89 90 while (currOffset < binarySize) { 91 auto wordCount = binary[currOffset] >> 16; 92 if (!wordCount || (currOffset + wordCount > binarySize)) 93 return false; 94 95 spirv::Opcode opcode = 96 static_cast<spirv::Opcode>(binary[currOffset] & 0xffff); 97 llvm::ArrayRef<uint32_t> operands(begin + currOffset + 1, 98 begin + currOffset + wordCount); 99 if (handleFn(opcode, operands)) 100 return true; 101 102 currOffset += wordCount; 103 } 104 return false; 105 } 106 107 protected: 108 MLIRContext context; 109 OwningOpRef<spirv::ModuleOp> module; 110 SmallVector<uint32_t, 0> binary; 111 }; 112 113 //===----------------------------------------------------------------------===// 114 // Block decoration 115 //===----------------------------------------------------------------------===// 116 117 TEST_F(SerializationTest, ContainsBlockDecoration) { 118 auto structType = getFloatStructType(); 119 addGlobalVar(structType, "var0"); 120 121 ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary))); 122 123 auto hasBlockDecoration = [](spirv::Opcode opcode, 124 ArrayRef<uint32_t> operands) { 125 return opcode == spirv::Opcode::OpDecorate && operands.size() == 2 && 126 operands[1] == static_cast<uint32_t>(spirv::Decoration::Block); 127 }; 128 EXPECT_TRUE(scanInstruction(hasBlockDecoration)); 129 } 130 131 TEST_F(SerializationTest, ContainsNoDuplicatedBlockDecoration) { 132 auto structType = getFloatStructType(); 133 // Two global variables using the same type should not decorate the type with 134 // duplicated `Block` decorations. 135 addGlobalVar(structType, "var0"); 136 addGlobalVar(structType, "var1"); 137 138 ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary))); 139 140 unsigned count = 0; 141 auto countBlockDecoration = [&count](spirv::Opcode opcode, 142 ArrayRef<uint32_t> operands) { 143 if (opcode == spirv::Opcode::OpDecorate && operands.size() == 2 && 144 operands[1] == static_cast<uint32_t>(spirv::Decoration::Block)) 145 ++count; 146 return false; 147 }; 148 ASSERT_FALSE(scanInstruction(countBlockDecoration)); 149 EXPECT_EQ(count, 1u); 150 } 151 152 TEST_F(SerializationTest, ContainsSymbolName) { 153 auto structType = getFloatStructType(); 154 addGlobalVar(structType, "var0"); 155 156 spirv::SerializationOptions options; 157 options.emitSymbolName = true; 158 ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary, options))); 159 160 auto hasVarName = [](spirv::Opcode opcode, ArrayRef<uint32_t> operands) { 161 unsigned index = 1; // Skip the result <id> 162 return opcode == spirv::Opcode::OpName && 163 spirv::decodeStringLiteral(operands, index) == "var0"; 164 }; 165 EXPECT_TRUE(scanInstruction(hasVarName)); 166 } 167 168 TEST_F(SerializationTest, DoesNotContainSymbolName) { 169 auto structType = getFloatStructType(); 170 addGlobalVar(structType, "var0"); 171 172 spirv::SerializationOptions options; 173 options.emitSymbolName = false; 174 ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary, options))); 175 176 auto hasVarName = [](spirv::Opcode opcode, ArrayRef<uint32_t> operands) { 177 unsigned index = 1; // Skip the result <id> 178 return opcode == spirv::Opcode::OpName && 179 spirv::decodeStringLiteral(operands, index) == "var0"; 180 }; 181 EXPECT_FALSE(scanInstruction(hasVarName)); 182 } 183