1 //===- SerializationTest.cpp - SPIR-V Serialization Tests -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains corner case tests for the SPIR-V serializer that are not 10 // covered by normal serialization and deserialization roundtripping. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Target/SPIRV/Serialization.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVModule.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 19 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 20 #include "mlir/IR/Builders.h" 21 #include "mlir/IR/Location.h" 22 #include "mlir/IR/MLIRContext.h" 23 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" 24 #include "llvm/ADT/DenseSet.h" 25 #include "llvm/ADT/STLExtras.h" 26 #include "llvm/ADT/Sequence.h" 27 #include "llvm/ADT/SmallVector.h" 28 #include "llvm/ADT/StringRef.h" 29 #include "gmock/gmock.h" 30 31 using namespace mlir; 32 33 //===----------------------------------------------------------------------===// 34 // Test Fixture 35 //===----------------------------------------------------------------------===// 36 37 class SerializationTest : public ::testing::Test { 38 protected: 39 SerializationTest() { 40 context.getOrLoadDialect<mlir::spirv::SPIRVDialect>(); 41 createModuleOp(); 42 } 43 44 void createModuleOp() { 45 OpBuilder builder(&context); 46 OperationState state(UnknownLoc::get(&context), 47 spirv::ModuleOp::getOperationName()); 48 state.addAttribute("addressing_model", 49 builder.getI32IntegerAttr(static_cast<uint32_t>( 50 spirv::AddressingModel::Logical))); 51 state.addAttribute("memory_model", 52 builder.getI32IntegerAttr( 53 static_cast<uint32_t>(spirv::MemoryModel::GLSL450))); 54 state.addAttribute("vce_triple", 55 spirv::VerCapExtAttr::get( 56 spirv::Version::V_1_0, ArrayRef<spirv::Capability>(), 57 ArrayRef<spirv::Extension>(), &context)); 58 spirv::ModuleOp::build(builder, state); 59 module = cast<spirv::ModuleOp>(Operation::create(state)); 60 } 61 62 Type getFloatStructType() { 63 OpBuilder opBuilder(module->body()); 64 llvm::SmallVector<Type, 1> elementTypes{opBuilder.getF32Type()}; 65 llvm::SmallVector<spirv::StructType::OffsetInfo, 1> offsetInfo{0}; 66 auto structType = spirv::StructType::get(elementTypes, offsetInfo); 67 return structType; 68 } 69 70 void addGlobalVar(Type type, llvm::StringRef name) { 71 OpBuilder opBuilder(module->body()); 72 auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform); 73 opBuilder.create<spirv::GlobalVariableOp>( 74 UnknownLoc::get(&context), TypeAttr::get(ptrType), 75 opBuilder.getStringAttr(name), nullptr); 76 } 77 78 bool findInstruction(llvm::function_ref<bool(spirv::Opcode opcode, 79 ArrayRef<uint32_t> operands)> 80 matchFn) { 81 auto binarySize = binary.size(); 82 auto begin = binary.begin(); 83 auto currOffset = spirv::kHeaderWordCount; 84 85 while (currOffset < binarySize) { 86 auto wordCount = binary[currOffset] >> 16; 87 if (!wordCount || (currOffset + wordCount > binarySize)) { 88 return false; 89 } 90 spirv::Opcode opcode = 91 static_cast<spirv::Opcode>(binary[currOffset] & 0xffff); 92 93 if (matchFn(opcode, 94 llvm::ArrayRef<uint32_t>(begin + currOffset + 1, 95 begin + currOffset + wordCount))) { 96 return true; 97 } 98 currOffset += wordCount; 99 } 100 return false; 101 } 102 103 protected: 104 MLIRContext context; 105 spirv::OwningSPIRVModuleRef module; 106 SmallVector<uint32_t, 0> binary; 107 }; 108 109 //===----------------------------------------------------------------------===// 110 // Block decoration 111 //===----------------------------------------------------------------------===// 112 113 TEST_F(SerializationTest, BlockDecorationTest) { 114 auto structType = getFloatStructType(); 115 addGlobalVar(structType, "var0"); 116 ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary))); 117 auto hasBlockDecoration = [](spirv::Opcode opcode, 118 ArrayRef<uint32_t> operands) -> bool { 119 if (opcode != spirv::Opcode::OpDecorate || operands.size() != 2) 120 return false; 121 return operands[1] == static_cast<uint32_t>(spirv::Decoration::Block); 122 }; 123 EXPECT_TRUE(findInstruction(hasBlockDecoration)); 124 } 125