1 //===- llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "DXILPointerType.h"
10 #include "PointerTypeAnalysis.h"
11 #include "llvm/AsmParser/Parser.h"
12 #include "llvm/IR/Instructions.h"
13 #include "llvm/IR/LLVMContext.h"
14 #include "llvm/IR/Type.h"
15 #include "llvm/Support/SourceMgr.h"
16
17 #include "gmock/gmock.h"
18 #include "gtest/gtest.h"
19
20 using ::testing::Contains;
21 using ::testing::Pair;
22
23 using namespace llvm;
24 using namespace llvm::dxil;
25
26 template <typename T> struct IsA {
operator ==(const Value * V,const IsA &)27 friend bool operator==(const Value *V, const IsA &) { return isa<T>(V); }
28 };
29
// A typed i8 pointer in address space 0 must print with the "dxil-ptr ("
// prefix.
TEST(DXILPointerType, PrintTest) {
  LLVMContext Ctx;
  std::string Printed;
  raw_string_ostream Stream(Printed);

  Type *BytePtrTy = TypedPointerType::get(Type::getInt8Ty(Ctx), 0);
  BytePtrTy->print(Stream);

  // str() returns the string the stream writes into (Printed).
  StringRef Text(Stream.str());
  EXPECT_TRUE(Text.startswith("dxil-ptr ("));
}
39
// %p is stored through as i32 and loaded through as i64. The analysis is
// expected to type the argument as i8* (and the function accordingly) rather
// than pick either access type.
TEST(PointerTypeAnalysis, DigressToi8) {
  StringRef Assembly = R"(
    define i64 @test(ptr %p) {
      store i32 0, ptr %p
      %v = load i64, ptr %p
      ret i64 %v
    }
  )";

  LLVMContext Context;
  SMDiagnostic Error;
  auto M = parseAssemblyString(Assembly, Error, Context);
  // Report the parser's own diagnostic so a malformed snippet is easy to fix.
  ASSERT_TRUE(M) << "Bad assembly? line " << Error.getLineNo() << ": "
                 << Error.getMessage().str();

  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
  // Entries expected: the function and its pointer argument.
  ASSERT_EQ(Map.size(), 2u);
  Type *I8Ptr = TypedPointerType::get(Type::getInt8Ty(Context), 0);
  Type *FnTy = FunctionType::get(Type::getInt64Ty(Context), {I8Ptr}, false);

  EXPECT_THAT(Map, Contains(Pair(IsA<Function>(), FnTy)));
  EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I8Ptr)));
}
62
// A single `store i32` through %p should make the analysis type the argument
// as i32*, and the function signature should reflect that.
TEST(PointerTypeAnalysis, DiscoverStore) {
  StringRef Assembly = R"(
    define i32 @test(ptr %p) {
      store i32 0, ptr %p
      ret i32 0
    }
  )";

  LLVMContext Context;
  SMDiagnostic Error;
  auto M = parseAssemblyString(Assembly, Error, Context);
  // Report the parser's own diagnostic so a malformed snippet is easy to fix.
  ASSERT_TRUE(M) << "Bad assembly? line " << Error.getLineNo() << ": "
                 << Error.getMessage().str();

  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
  // Entries expected: the function and its pointer argument.
  ASSERT_EQ(Map.size(), 2u);
  Type *I32Ptr = TypedPointerType::get(Type::getInt32Ty(Context), 0);
  Type *FnTy = FunctionType::get(Type::getInt32Ty(Context), {I32Ptr}, false);

  EXPECT_THAT(Map, Contains(Pair(IsA<Function>(), FnTy)));
  EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I32Ptr)));
}
84
// A single `load i32` through %p should make the analysis type the argument
// as i32*, and the function signature should reflect that.
TEST(PointerTypeAnalysis, DiscoverLoad) {
  StringRef Assembly = R"(
    define i32 @test(ptr %p) {
      %v = load i32, ptr %p
      ret i32 %v
    }
  )";

  LLVMContext Context;
  SMDiagnostic Error;
  auto M = parseAssemblyString(Assembly, Error, Context);
  // Report the parser's own diagnostic so a malformed snippet is easy to fix.
  ASSERT_TRUE(M) << "Bad assembly? line " << Error.getLineNo() << ": "
                 << Error.getMessage().str();

  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
  // Entries expected: the function and its pointer argument.
  ASSERT_EQ(Map.size(), 2u);
  Type *I32Ptr = TypedPointerType::get(Type::getInt32Ty(Context), 0);
  Type *FnTy = FunctionType::get(Type::getInt32Ty(Context), {I32Ptr}, false);

  EXPECT_THAT(Map, Contains(Pair(IsA<Function>(), FnTy)));
  EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I32Ptr)));
}
106
// A `getelementptr i64` on %p should type the argument, the GEP result and the
// function's return value as i64*.
TEST(PointerTypeAnalysis, DiscoverGEP) {
  StringRef Assembly = R"(
    define ptr @test(ptr %p) {
      %p2 = getelementptr i64, ptr %p, i64 1
      ret ptr %p2
    }
  )";

  LLVMContext Context;
  SMDiagnostic Error;
  auto M = parseAssemblyString(Assembly, Error, Context);
  // Report the parser's own diagnostic so a malformed snippet is easy to fix.
  ASSERT_TRUE(M) << "Bad assembly? line " << Error.getLineNo() << ": "
                 << Error.getMessage().str();

  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
  // Entries expected: the function, its argument and the GEP.
  ASSERT_EQ(Map.size(), 3u);

  Type *I64Ptr = TypedPointerType::get(Type::getInt64Ty(Context), 0);
  Type *FnTy = FunctionType::get(I64Ptr, {I64Ptr}, false);

  EXPECT_THAT(Map, Contains(Pair(IsA<Function>(), FnTy)));
  EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I64Ptr)));
  EXPECT_THAT(Map, Contains(Pair(IsA<GetElementPtrInst>(), I64Ptr)));
}
130
// %p is loaded to obtain another pointer, which is then loaded as i64. The
// analysis should type the inner load as i64* and the argument as i64**.
TEST(PointerTypeAnalysis, TraceIndirect) {
  StringRef Assembly = R"(
    define i64 @test(ptr %p) {
      %p2 = load ptr, ptr %p
      %v = load i64, ptr %p2
      ret i64 %v
    }
  )";

  LLVMContext Context;
  SMDiagnostic Error;
  auto M = parseAssemblyString(Assembly, Error, Context);
  // Report the parser's own diagnostic so a malformed snippet is easy to fix.
  ASSERT_TRUE(M) << "Bad assembly? line " << Error.getLineNo() << ": "
                 << Error.getMessage().str();

  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
  // Entries expected: the function, its argument and the pointer-valued load.
  ASSERT_EQ(Map.size(), 3u);

  Type *I64Ptr = TypedPointerType::get(Type::getInt64Ty(Context), 0);
  Type *I64PtrPtr = TypedPointerType::get(I64Ptr, 0);
  Type *FnTy = FunctionType::get(Type::getInt64Ty(Context), {I64PtrPtr}, false);

  EXPECT_THAT(Map, Contains(Pair(IsA<Function>(), FnTy)));
  EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I64PtrPtr)));
  EXPECT_THAT(Map, Contains(Pair(IsA<LoadInst>(), I64Ptr)));
}
156
// Two no-op bitcasts of %p are used as i32 (store) and i64 (load). Each cast is
// expected to get the element type of its own use. The argument feeding both is
// expected to be typed as i8*.
TEST(PointerTypeAnalysis, WithNoOpCasts) {
  StringRef Assembly = R"(
    define i64 @test(ptr %p) {
      %1 = bitcast ptr %p to ptr
      %2 = bitcast ptr %p to ptr
      store i32 0, ptr %1, align 4
      %3 = load i64, ptr %2, align 8
      ret i64 %3
    }
  )";

  LLVMContext Context;
  SMDiagnostic Error;
  auto M = parseAssemblyString(Assembly, Error, Context);
  // Report the parser's own diagnostic so a malformed snippet is easy to fix.
  ASSERT_TRUE(M) << "Bad assembly? line " << Error.getLineNo() << ": "
                 << Error.getMessage().str();

  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
  // Entries expected: the function, its argument and the two bitcasts.
  ASSERT_EQ(Map.size(), 4u);

  Type *I8Ptr = TypedPointerType::get(Type::getInt8Ty(Context), 0);
  Type *I32Ptr = TypedPointerType::get(Type::getInt32Ty(Context), 0);
  Type *I64Ptr = TypedPointerType::get(Type::getInt64Ty(Context), 0);
  Type *FnTy = FunctionType::get(Type::getInt64Ty(Context), {I8Ptr}, false);

  EXPECT_THAT(Map, Contains(Pair(IsA<Function>(), FnTy)));
  EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I8Ptr)));
  EXPECT_THAT(Map, Contains(Pair(IsA<BitCastInst>(), I64Ptr)));
  EXPECT_THAT(Map, Contains(Pair(IsA<BitCastInst>(), I32Ptr)));
}
186