1 //===- SerializationTest.cpp - SPIR-V Serialization Tests -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains corner case tests for the SPIR-V serializer that are not
10 // covered by normal serialization and deserialization roundtripping.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/SPIRV/Serialization.h"
15 #include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
16 #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
17 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
18 #include "mlir/Dialect/SPIRV/SPIRVModule.h"
19 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
20 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/Location.h"
23 #include "mlir/IR/MLIRContext.h"
24 #include "llvm/ADT/DenseSet.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/Sequence.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/StringRef.h"
29 #include "gmock/gmock.h"
30 
31 using namespace mlir;
32 
33 //===----------------------------------------------------------------------===//
34 // Test Fixture
35 //===----------------------------------------------------------------------===//
36 
37 class SerializationTest : public ::testing::Test {
38 protected:
39   SerializationTest() { createModuleOp(); }
40 
41   void createModuleOp() {
42     OpBuilder builder(&context);
43     OperationState state(UnknownLoc::get(&context),
44                          spirv::ModuleOp::getOperationName());
45     state.addAttribute("addressing_model",
46                        builder.getI32IntegerAttr(static_cast<uint32_t>(
47                            spirv::AddressingModel::Logical)));
48     state.addAttribute("memory_model",
49                        builder.getI32IntegerAttr(
50                            static_cast<uint32_t>(spirv::MemoryModel::GLSL450)));
51     state.addAttribute("vce_triple",
52                        spirv::VerCapExtAttr::get(
53                            spirv::Version::V_1_0, ArrayRef<spirv::Capability>(),
54                            ArrayRef<spirv::Extension>(), &context));
55     spirv::ModuleOp::build(builder, state);
56     module = cast<spirv::ModuleOp>(Operation::create(state));
57   }
58 
59   Type getFloatStructType() {
60     OpBuilder opBuilder(module->body());
61     llvm::SmallVector<Type, 1> elementTypes{opBuilder.getF32Type()};
62     llvm::SmallVector<spirv::StructType::OffsetInfo, 1> offsetInfo{0};
63     auto structType = spirv::StructType::get(elementTypes, offsetInfo);
64     return structType;
65   }
66 
67   void addGlobalVar(Type type, llvm::StringRef name) {
68     OpBuilder opBuilder(module->body());
69     auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform);
70     opBuilder.create<spirv::GlobalVariableOp>(
71         UnknownLoc::get(&context), TypeAttr::get(ptrType),
72         opBuilder.getStringAttr(name), nullptr);
73   }
74 
75   bool findInstruction(llvm::function_ref<bool(spirv::Opcode opcode,
76                                                ArrayRef<uint32_t> operands)>
77                            matchFn) {
78     auto binarySize = binary.size();
79     auto begin = binary.begin();
80     auto currOffset = spirv::kHeaderWordCount;
81 
82     while (currOffset < binarySize) {
83       auto wordCount = binary[currOffset] >> 16;
84       if (!wordCount || (currOffset + wordCount > binarySize)) {
85         return false;
86       }
87       spirv::Opcode opcode =
88           static_cast<spirv::Opcode>(binary[currOffset] & 0xffff);
89 
90       if (matchFn(opcode,
91                   llvm::ArrayRef<uint32_t>(begin + currOffset + 1,
92                                            begin + currOffset + wordCount))) {
93         return true;
94       }
95       currOffset += wordCount;
96     }
97     return false;
98   }
99 
100 protected:
101   MLIRContext context;
102   spirv::OwningSPIRVModuleRef module;
103   SmallVector<uint32_t, 0> binary;
104 };
105 
106 //===----------------------------------------------------------------------===//
107 // Block decoration
108 //===----------------------------------------------------------------------===//
109 
110 TEST_F(SerializationTest, BlockDecorationTest) {
111   auto structType = getFloatStructType();
112   addGlobalVar(structType, "var0");
113   ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
114   auto hasBlockDecoration = [](spirv::Opcode opcode,
115                                ArrayRef<uint32_t> operands) -> bool {
116     if (opcode != spirv::Opcode::OpDecorate || operands.size() != 2)
117       return false;
118     return operands[1] == static_cast<uint32_t>(spirv::Decoration::Block);
119   };
120   EXPECT_TRUE(findInstruction(hasBlockDecoration));
121 }
122