1 //===- SerializationTest.cpp - SPIR-V Serialization Tests -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains corner case tests for the SPIR-V serializer that are not
10 // covered by normal serialization and deserialization roundtripping.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Target/SPIRV/Serialization.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/Location.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
23 #include "llvm/ADT/DenseSet.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/StringRef.h"
28 #include "gmock/gmock.h"
29 
30 using namespace mlir;
31 
32 //===----------------------------------------------------------------------===//
33 // Test Fixture
34 //===----------------------------------------------------------------------===//
35 
36 class SerializationTest : public ::testing::Test {
37 protected:
SerializationTest()38   SerializationTest() {
39     context.getOrLoadDialect<mlir::spirv::SPIRVDialect>();
40     initModuleOp();
41   }
42 
43   /// Initializes an empty SPIR-V module op.
initModuleOp()44   void initModuleOp() {
45     OpBuilder builder(&context);
46     OperationState state(UnknownLoc::get(&context),
47                          spirv::ModuleOp::getOperationName());
48     state.addAttribute("addressing_model",
49                        builder.getI32IntegerAttr(static_cast<uint32_t>(
50                            spirv::AddressingModel::Logical)));
51     state.addAttribute("memory_model",
52                        builder.getI32IntegerAttr(
53                            static_cast<uint32_t>(spirv::MemoryModel::GLSL450)));
54     state.addAttribute("vce_triple",
55                        spirv::VerCapExtAttr::get(
56                            spirv::Version::V_1_0, ArrayRef<spirv::Capability>(),
57                            ArrayRef<spirv::Extension>(), &context));
58     spirv::ModuleOp::build(builder, state);
59     module = cast<spirv::ModuleOp>(Operation::create(state));
60   }
61 
62   /// Gets the `struct { float }` type.
getFloatStructType()63   spirv::StructType getFloatStructType() {
64     OpBuilder builder(module->getRegion());
65     llvm::SmallVector<Type, 1> elementTypes{builder.getF32Type()};
66     llvm::SmallVector<spirv::StructType::OffsetInfo, 1> offsetInfo{0};
67     return spirv::StructType::get(elementTypes, offsetInfo);
68   }
69 
70   /// Inserts a global variable of the given `type` and `name`.
addGlobalVar(Type type,llvm::StringRef name)71   spirv::GlobalVariableOp addGlobalVar(Type type, llvm::StringRef name) {
72     OpBuilder builder(module->getRegion());
73     auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform);
74     return builder.create<spirv::GlobalVariableOp>(
75         UnknownLoc::get(&context), TypeAttr::get(ptrType),
76         builder.getStringAttr(name), nullptr);
77   }
78 
79   /// Handles a SPIR-V instruction with the given `opcode` and `operand`.
80   /// Returns true to interrupt.
81   using HandleFn = llvm::function_ref<bool(spirv::Opcode opcode,
82                                            ArrayRef<uint32_t> operands)>;
83 
84   /// Returns true if we can find a matching instruction in the SPIR-V blob.
scanInstruction(HandleFn handleFn)85   bool scanInstruction(HandleFn handleFn) {
86     auto binarySize = binary.size();
87     auto *begin = binary.begin();
88     auto currOffset = spirv::kHeaderWordCount;
89 
90     while (currOffset < binarySize) {
91       auto wordCount = binary[currOffset] >> 16;
92       if (!wordCount || (currOffset + wordCount > binarySize))
93         return false;
94 
95       spirv::Opcode opcode =
96           static_cast<spirv::Opcode>(binary[currOffset] & 0xffff);
97       llvm::ArrayRef<uint32_t> operands(begin + currOffset + 1,
98                                         begin + currOffset + wordCount);
99       if (handleFn(opcode, operands))
100         return true;
101 
102       currOffset += wordCount;
103     }
104     return false;
105   }
106 
107 protected:
108   MLIRContext context;
109   OwningOpRef<spirv::ModuleOp> module;
110   SmallVector<uint32_t, 0> binary;
111 };
112 
113 //===----------------------------------------------------------------------===//
114 // Block decoration
115 //===----------------------------------------------------------------------===//
116 
TEST_F(SerializationTest, ContainsBlockDecoration) {
  // One global whose pointee is a struct: the serialized blob is expected to
  // contain an `OpDecorate <id> Block` instruction.
  spirv::StructType floatStruct = getFloatStructType();
  addGlobalVar(floatStruct, "var0");

  ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));

  const auto blockBits = static_cast<uint32_t>(spirv::Decoration::Block);
  auto isBlockDecorate = [blockBits](spirv::Opcode op,
                                     ArrayRef<uint32_t> words) {
    if (op != spirv::Opcode::OpDecorate)
      return false;
    // Expect exactly <target id, decoration> with no extra literals.
    if (words.size() != 2)
      return false;
    return words[1] == blockBits;
  };
  EXPECT_TRUE(scanInstruction(isBlockDecorate));
}
130 
TEST_F(SerializationTest, ContainsNoDuplicatedBlockDecoration) {
  // Both globals reference the same struct type. That type should carry the
  // `Block` decoration once, not once per variable.
  spirv::StructType sharedStruct = getFloatStructType();
  for (llvm::StringRef varName : {"var0", "var1"})
    addGlobalVar(sharedStruct, varName);

  ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));

  unsigned numBlockDecorations = 0;
  auto tally = [&numBlockDecorations](spirv::Opcode op,
                                      ArrayRef<uint32_t> words) {
    const bool isBlock =
        op == spirv::Opcode::OpDecorate && words.size() == 2 &&
        words[1] == static_cast<uint32_t>(spirv::Decoration::Block);
    numBlockDecorations += isBlock ? 1u : 0u;
    // Never stop early: every instruction must be visited for a full count.
    return false;
  };
  ASSERT_FALSE(scanInstruction(tally));
  EXPECT_EQ(numBlockDecorations, 1u);
}
151 
TEST_F(SerializationTest, ContainsSymbolName) {
  addGlobalVar(getFloatStructType(), "var0");

  // With symbol-name emission enabled, an `OpName` for "var0" must appear.
  spirv::SerializationOptions opts;
  opts.emitSymbolName = true;
  ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary, opts)));

  auto namesVar0 = [](spirv::Opcode op, ArrayRef<uint32_t> words) {
    if (op != spirv::Opcode::OpName)
      return false;
    // words[0] is the target <id>; the name literal begins at words[1].
    unsigned literalStart = 1;
    return spirv::decodeStringLiteral(words, literalStart) == "var0";
  };
  EXPECT_TRUE(scanInstruction(namesVar0));
}
167 
TEST_F(SerializationTest, DoesNotContainSymbolName) {
  addGlobalVar(getFloatStructType(), "var0");

  // With symbol-name emission disabled, no `OpName` for "var0" may appear.
  spirv::SerializationOptions opts;
  opts.emitSymbolName = false;
  ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary, opts)));

  auto namesVar0 = [](spirv::Opcode op, ArrayRef<uint32_t> words) {
    if (op != spirv::Opcode::OpName)
      return false;
    // words[0] is the target <id>; the name literal begins at words[1].
    unsigned literalStart = 1;
    return spirv::decodeStringLiteral(words, literalStart) == "var0";
  };
  EXPECT_FALSE(scanInstruction(namesVar0));
}
183