1 //===- SerializationTest.cpp - SPIR-V Serialization Tests -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains corner case tests for the SPIR-V serializer that are not 10 // covered by normal serialization and deserialization roundtripping. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/SPIRV/Serialization.h" 15 #include "mlir/Dialect/SPIRV/SPIRVAttributes.h" 16 #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h" 17 #include "mlir/Dialect/SPIRV/SPIRVDialect.h" 18 #include "mlir/Dialect/SPIRV/SPIRVModule.h" 19 #include "mlir/Dialect/SPIRV/SPIRVOps.h" 20 #include "mlir/Dialect/SPIRV/SPIRVTypes.h" 21 #include "mlir/IR/Builders.h" 22 #include "mlir/IR/Location.h" 23 #include "mlir/IR/MLIRContext.h" 24 #include "llvm/ADT/DenseSet.h" 25 #include "llvm/ADT/STLExtras.h" 26 #include "llvm/ADT/Sequence.h" 27 #include "llvm/ADT/SmallVector.h" 28 #include "llvm/ADT/StringRef.h" 29 #include "gmock/gmock.h" 30 31 using namespace mlir; 32 33 //===----------------------------------------------------------------------===// 34 // Test Fixture 35 //===----------------------------------------------------------------------===// 36 37 class SerializationTest : public ::testing::Test { 38 protected: 39 SerializationTest() { createModuleOp(); } 40 41 void createModuleOp() { 42 OpBuilder builder(&context); 43 OperationState state(UnknownLoc::get(&context), 44 spirv::ModuleOp::getOperationName()); 45 state.addAttribute("addressing_model", 46 builder.getI32IntegerAttr(static_cast<uint32_t>( 47 spirv::AddressingModel::Logical))); 48 state.addAttribute("memory_model", 49 builder.getI32IntegerAttr( 50 
static_cast<uint32_t>(spirv::MemoryModel::GLSL450))); 51 state.addAttribute("vce_triple", 52 spirv::VerCapExtAttr::get( 53 spirv::Version::V_1_0, ArrayRef<spirv::Capability>(), 54 ArrayRef<spirv::Extension>(), &context)); 55 spirv::ModuleOp::build(builder, state); 56 module = cast<spirv::ModuleOp>(Operation::create(state)); 57 } 58 59 Type getFloatStructType() { 60 OpBuilder opBuilder(module->body()); 61 llvm::SmallVector<Type, 1> elementTypes{opBuilder.getF32Type()}; 62 llvm::SmallVector<spirv::StructType::OffsetInfo, 1> offsetInfo{0}; 63 auto structType = spirv::StructType::get(elementTypes, offsetInfo); 64 return structType; 65 } 66 67 void addGlobalVar(Type type, llvm::StringRef name) { 68 OpBuilder opBuilder(module->body()); 69 auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform); 70 opBuilder.create<spirv::GlobalVariableOp>( 71 UnknownLoc::get(&context), TypeAttr::get(ptrType), 72 opBuilder.getStringAttr(name), nullptr); 73 } 74 75 bool findInstruction(llvm::function_ref<bool(spirv::Opcode opcode, 76 ArrayRef<uint32_t> operands)> 77 matchFn) { 78 auto binarySize = binary.size(); 79 auto begin = binary.begin(); 80 auto currOffset = spirv::kHeaderWordCount; 81 82 while (currOffset < binarySize) { 83 auto wordCount = binary[currOffset] >> 16; 84 if (!wordCount || (currOffset + wordCount > binarySize)) { 85 return false; 86 } 87 spirv::Opcode opcode = 88 static_cast<spirv::Opcode>(binary[currOffset] & 0xffff); 89 90 if (matchFn(opcode, 91 llvm::ArrayRef<uint32_t>(begin + currOffset + 1, 92 begin + currOffset + wordCount))) { 93 return true; 94 } 95 currOffset += wordCount; 96 } 97 return false; 98 } 99 100 protected: 101 MLIRContext context; 102 spirv::OwningSPIRVModuleRef module; 103 SmallVector<uint32_t, 0> binary; 104 }; 105 106 //===----------------------------------------------------------------------===// 107 // Block decoration 108 //===----------------------------------------------------------------------===// 109 110 
TEST_F(SerializationTest, BlockDecorationTest) {
  // Declare a Uniform global whose pointee is a single-f32 struct. No
  // decoration is added explicitly.
  addGlobalVar(getFloatStructType(), "var0");
  ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));

  // The serialized binary is expected to contain an `OpDecorate <id> Block`
  // instruction. That is an OpDecorate with exactly two operand words, the
  // second of which is the Block decoration enumerant.
  const auto blockEnumerant = static_cast<uint32_t>(spirv::Decoration::Block);
  auto isBlockDecorate = [blockEnumerant](spirv::Opcode op,
                                          ArrayRef<uint32_t> words) -> bool {
    return op == spirv::Opcode::OpDecorate && words.size() == 2 &&
           words[1] == blockEnumerant;
  };
  EXPECT_TRUE(findInstruction(isBlockDecorate));
}