1 //===- SerializationTest.cpp - SPIR-V Serialization Tests -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains corner case tests for the SPIR-V serializer that are not 10 // covered by normal serialization and deserialization roundtripping. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Target/SPIRV/Serialization.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 19 #include "mlir/IR/Builders.h" 20 #include "mlir/IR/Location.h" 21 #include "mlir/IR/MLIRContext.h" 22 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" 23 #include "llvm/ADT/DenseSet.h" 24 #include "llvm/ADT/STLExtras.h" 25 #include "llvm/ADT/Sequence.h" 26 #include "llvm/ADT/SmallVector.h" 27 #include "llvm/ADT/StringRef.h" 28 #include "gmock/gmock.h" 29 30 using namespace mlir; 31 32 //===----------------------------------------------------------------------===// 33 // Test Fixture 34 //===----------------------------------------------------------------------===// 35 36 class SerializationTest : public ::testing::Test { 37 protected: 38 SerializationTest() { 39 context.getOrLoadDialect<mlir::spirv::SPIRVDialect>(); 40 createModuleOp(); 41 } 42 43 void createModuleOp() { 44 OpBuilder builder(&context); 45 OperationState state(UnknownLoc::get(&context), 46 spirv::ModuleOp::getOperationName()); 47 state.addAttribute("addressing_model", 48 builder.getI32IntegerAttr(static_cast<uint32_t>( 49 spirv::AddressingModel::Logical))); 50 state.addAttribute("memory_model", 51 builder.getI32IntegerAttr( 52 
static_cast<uint32_t>(spirv::MemoryModel::GLSL450))); 53 state.addAttribute("vce_triple", 54 spirv::VerCapExtAttr::get( 55 spirv::Version::V_1_0, ArrayRef<spirv::Capability>(), 56 ArrayRef<spirv::Extension>(), &context)); 57 spirv::ModuleOp::build(builder, state); 58 module = cast<spirv::ModuleOp>(Operation::create(state)); 59 } 60 61 Type getFloatStructType() { 62 OpBuilder opBuilder(module->getRegion()); 63 llvm::SmallVector<Type, 1> elementTypes{opBuilder.getF32Type()}; 64 llvm::SmallVector<spirv::StructType::OffsetInfo, 1> offsetInfo{0}; 65 auto structType = spirv::StructType::get(elementTypes, offsetInfo); 66 return structType; 67 } 68 69 void addGlobalVar(Type type, llvm::StringRef name) { 70 OpBuilder opBuilder(module->getRegion()); 71 auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform); 72 opBuilder.create<spirv::GlobalVariableOp>( 73 UnknownLoc::get(&context), TypeAttr::get(ptrType), 74 opBuilder.getStringAttr(name), nullptr); 75 } 76 77 bool findInstruction(llvm::function_ref<bool(spirv::Opcode opcode, 78 ArrayRef<uint32_t> operands)> 79 matchFn) { 80 auto binarySize = binary.size(); 81 auto begin = binary.begin(); 82 auto currOffset = spirv::kHeaderWordCount; 83 84 while (currOffset < binarySize) { 85 auto wordCount = binary[currOffset] >> 16; 86 if (!wordCount || (currOffset + wordCount > binarySize)) { 87 return false; 88 } 89 spirv::Opcode opcode = 90 static_cast<spirv::Opcode>(binary[currOffset] & 0xffff); 91 92 if (matchFn(opcode, 93 llvm::ArrayRef<uint32_t>(begin + currOffset + 1, 94 begin + currOffset + wordCount))) { 95 return true; 96 } 97 currOffset += wordCount; 98 } 99 return false; 100 } 101 102 protected: 103 MLIRContext context; 104 OwningOpRef<spirv::ModuleOp> module; 105 SmallVector<uint32_t, 0> binary; 106 }; 107 108 //===----------------------------------------------------------------------===// 109 // Block decoration 110 //===----------------------------------------------------------------------===// 111 112 
/// Serializes a module containing a Uniform-storage global whose pointee is a
/// single-`f32` struct. Expects the output to contain an `OpDecorate` that
/// applies the `Block` decoration.
TEST_F(SerializationTest, BlockDecorationTest) {
  Type floatStruct = getFloatStructType();
  addGlobalVar(floatStruct, "var0");
  ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));

  const auto blockWord = static_cast<uint32_t>(spirv::Decoration::Block);
  // Only the two-operand form (target id, decoration) is accepted.
  auto isBlockDecorate = [blockWord](spirv::Opcode opcode,
                                     ArrayRef<uint32_t> operands) {
    return opcode == spirv::Opcode::OpDecorate && operands.size() == 2 &&
           operands[1] == blockWord;
  };
  EXPECT_TRUE(findInstruction(isBlockDecorate));
}