1 //===- SerializationTest.cpp - SPIR-V Serialization Tests -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains corner case tests for the SPIR-V serializer that are not 10 // covered by normal serialization and deserialization roundtripping. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/SPIRV/Serialization.h" 15 #include "mlir/Dialect/SPIRV/SPIRVAttributes.h" 16 #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h" 17 #include "mlir/Dialect/SPIRV/SPIRVDialect.h" 18 #include "mlir/Dialect/SPIRV/SPIRVOps.h" 19 #include "mlir/Dialect/SPIRV/SPIRVTypes.h" 20 #include "mlir/IR/Builders.h" 21 #include "mlir/IR/Location.h" 22 #include "mlir/IR/MLIRContext.h" 23 #include "llvm/ADT/DenseSet.h" 24 #include "llvm/ADT/STLExtras.h" 25 #include "llvm/ADT/Sequence.h" 26 #include "llvm/ADT/SmallVector.h" 27 #include "llvm/ADT/StringRef.h" 28 #include "gmock/gmock.h" 29 30 using namespace mlir; 31 32 //===----------------------------------------------------------------------===// 33 // Test Fixture 34 //===----------------------------------------------------------------------===// 35 36 class SerializationTest : public ::testing::Test { 37 protected: 38 SerializationTest() { createModuleOp(); } 39 40 void createModuleOp() { 41 OpBuilder builder(&context); 42 OperationState state(UnknownLoc::get(&context), 43 spirv::ModuleOp::getOperationName()); 44 state.addAttribute("addressing_model", 45 builder.getI32IntegerAttr(static_cast<uint32_t>( 46 spirv::AddressingModel::Logical))); 47 state.addAttribute("memory_model", 48 builder.getI32IntegerAttr( 49 static_cast<uint32_t>(spirv::MemoryModel::GLSL450))); 50 state.addAttribute("vce_triple", 51 spirv::VerCapExtAttr::get( 52 spirv::Version::V_1_0, ArrayRef<spirv::Capability>(), 53 ArrayRef<spirv::Extension>(), &context)); 54 spirv::ModuleOp::build(builder, state); 55 module = cast<spirv::ModuleOp>(Operation::create(state)); 56 } 57 58 Type getFloatStructType() { 59 OpBuilder opBuilder(module.body()); 60 llvm::SmallVector<Type, 1> elementTypes{opBuilder.getF32Type()}; 61 llvm::SmallVector<spirv::StructType::LayoutInfo, 1> layoutInfo{0}; 62 auto structType = spirv::StructType::get(elementTypes, layoutInfo); 63 return structType; 64 } 65 66 void addGlobalVar(Type type, llvm::StringRef name) { 67 OpBuilder opBuilder(module.body()); 68 auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform); 69 opBuilder.create<spirv::GlobalVariableOp>( 70 UnknownLoc::get(&context), TypeAttr::get(ptrType), 71 opBuilder.getStringAttr(name), nullptr); 72 } 73 74 bool findInstruction(llvm::function_ref<bool(spirv::Opcode opcode, 75 ArrayRef<uint32_t> operands)> 76 matchFn) { 77 auto binarySize = binary.size(); 78 auto begin = binary.begin(); 79 auto currOffset = spirv::kHeaderWordCount; 80 81 while (currOffset < binarySize) { 82 auto wordCount = binary[currOffset] >> 16; 83 if (!wordCount || (currOffset + wordCount > binarySize)) { 84 return false; 85 } 86 spirv::Opcode opcode = 87 static_cast<spirv::Opcode>(binary[currOffset] & 0xffff); 88 89 if (matchFn(opcode, 90 llvm::ArrayRef<uint32_t>(begin + currOffset + 1, 91 begin + currOffset + wordCount))) { 92 return true; 93 } 94 currOffset += wordCount; 95 } 96 return false; 97 } 98 99 protected: 100 MLIRContext context; 101 spirv::ModuleOp module; 102 SmallVector<uint32_t, 0> binary; 103 }; 104 105 //===----------------------------------------------------------------------===// 106 // Block decoration 107 //===----------------------------------------------------------------------===// 108 109 TEST_F(SerializationTest, BlockDecorationTest) { 110 auto structType = getFloatStructType(); 111 addGlobalVar(structType, "var0"); 112 ASSERT_TRUE(succeeded(spirv::serialize(module, binary))); 113 auto hasBlockDecoration = [](spirv::Opcode opcode, 114 ArrayRef<uint32_t> operands) -> bool { 115 if (opcode != spirv::Opcode::OpDecorate || operands.size() != 2) 116 return false; 117 return operands[1] == static_cast<uint32_t>(spirv::Decoration::Block); 118 }; 119 EXPECT_TRUE(findInstruction(hasBlockDecoration)); 120 } 121