1 //===- llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "DXILPointerType.h"
10 #include "PointerTypeAnalysis.h"
11 #include "llvm/AsmParser/Parser.h"
12 #include "llvm/IR/Instructions.h"
13 #include "llvm/IR/LLVMContext.h"
14 #include "llvm/IR/Type.h"
15 #include "llvm/Support/SourceMgr.h"
16 
17 #include "gmock/gmock.h"
18 #include "gtest/gtest.h"
19 
20 using ::testing::Contains;
21 using ::testing::Pair;
22 
23 using namespace llvm;
24 using namespace llvm::dxil;
25 
26 template <typename T> struct IsA {
operator ==(const Value * V,const IsA &)27   friend bool operator==(const Value *V, const IsA &) { return isa<T>(V); }
28 };
29 
TEST(DXILPointerType, PrintTest) {
  // Print an i8 typed pointer and check that the textual form carries the
  // DXIL-specific "dxil-ptr (" prefix.
  std::string Printed;
  LLVMContext Ctx;
  raw_string_ostream Stream(Printed);

  Type *BytePtrTy = TypedPointerType::get(Type::getInt8Ty(Ctx), 0);
  BytePtrTy->print(Stream);

  const StringRef Text(Stream.str());
  EXPECT_TRUE(Text.startswith("dxil-ptr ("));
}
39 
TEST(PointerTypeAnalysis, DigressToi8) {
  // %p is stored to as i32 and loaded from as i64. Because the uses disagree,
  // the test expects the analysis to type %p as i8*.
  StringRef Assembly = R"(
    define i64 @test(ptr %p) {
      store i32 0, ptr %p
      %v = load i64, ptr %p
      ret i64 %v
    }
  )";

  LLVMContext Context;
  SMDiagnostic Diag;
  auto M = parseAssemblyString(Assembly, Diag, Context);
  // Surface the parser's diagnostic instead of discarding it.
  ASSERT_TRUE(M) << "Bad assembly?: " << Diag.getMessage().str();

  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
  // Expected entries: the function and its pointer argument.
  ASSERT_EQ(Map.size(), 2u);
  Type *I8Ptr = TypedPointerType::get(Type::getInt8Ty(Context), 0);
  Type *FnTy = FunctionType::get(Type::getInt64Ty(Context), {I8Ptr}, false);

  EXPECT_THAT(Map, Contains(Pair(IsA<Function>(), FnTy)));
  EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I8Ptr)));
}
62 
TEST(PointerTypeAnalysis, DiscoverStore) {
  // A single `store i32` through %p is expected to type %p as i32*.
  StringRef Assembly = R"(
    define i32 @test(ptr %p) {
      store i32 0, ptr %p
      ret i32 0
    }
  )";

  LLVMContext Context;
  SMDiagnostic Diag;
  auto M = parseAssemblyString(Assembly, Diag, Context);
  // Surface the parser's diagnostic instead of discarding it.
  ASSERT_TRUE(M) << "Bad assembly?: " << Diag.getMessage().str();

  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
  // Expected entries: the function and its pointer argument.
  ASSERT_EQ(Map.size(), 2u);
  Type *I32Ptr = TypedPointerType::get(Type::getInt32Ty(Context), 0);
  Type *FnTy = FunctionType::get(Type::getInt32Ty(Context), {I32Ptr}, false);

  EXPECT_THAT(Map, Contains(Pair(IsA<Function>(), FnTy)));
  EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I32Ptr)));
}
84 
TEST(PointerTypeAnalysis, DiscoverLoad) {
  // A single `load i32` through %p is expected to type %p as i32*.
  StringRef Assembly = R"(
    define i32 @test(ptr %p) {
      %v = load i32, ptr %p
      ret i32 %v
    }
  )";

  LLVMContext Context;
  SMDiagnostic Diag;
  auto M = parseAssemblyString(Assembly, Diag, Context);
  // Surface the parser's diagnostic instead of discarding it.
  ASSERT_TRUE(M) << "Bad assembly?: " << Diag.getMessage().str();

  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
  // Expected entries: the function and its pointer argument.
  ASSERT_EQ(Map.size(), 2u);
  Type *I32Ptr = TypedPointerType::get(Type::getInt32Ty(Context), 0);
  Type *FnTy = FunctionType::get(Type::getInt32Ty(Context), {I32Ptr}, false);

  EXPECT_THAT(Map, Contains(Pair(IsA<Function>(), FnTy)));
  EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I32Ptr)));
}
106 
TEST(PointerTypeAnalysis, DiscoverGEP) {
  // An i64 GEP over %p is expected to type %p, the GEP result and the
  // function's return value as i64*.
  StringRef Assembly = R"(
    define ptr @test(ptr %p) {
      %p2 = getelementptr i64, ptr %p, i64 1
      ret ptr %p2
    }
  )";

  LLVMContext Context;
  SMDiagnostic Diag;
  auto M = parseAssemblyString(Assembly, Diag, Context);
  // Surface the parser's diagnostic instead of discarding it.
  ASSERT_TRUE(M) << "Bad assembly?: " << Diag.getMessage().str();

  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
  // Expected entries: the function, its argument and the GEP.
  ASSERT_EQ(Map.size(), 3u);

  Type *I64Ptr = TypedPointerType::get(Type::getInt64Ty(Context), 0);
  Type *FnTy = FunctionType::get(I64Ptr, {I64Ptr}, false);

  EXPECT_THAT(Map, Contains(Pair(IsA<Function>(), FnTy)));
  EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I64Ptr)));
  EXPECT_THAT(Map, Contains(Pair(IsA<GetElementPtrInst>(), I64Ptr)));
}
130 
TEST(PointerTypeAnalysis, TraceIndirect) {
  // %p is loaded as a pointer (%p2), which is then loaded as i64. The test
  // expects %p2 to be typed i64* and, one level up, %p to be typed i64**.
  StringRef Assembly = R"(
    define i64 @test(ptr %p) {
      %p2 = load ptr, ptr %p
      %v = load i64, ptr %p2
      ret i64 %v
    }
  )";

  LLVMContext Context;
  SMDiagnostic Diag;
  auto M = parseAssemblyString(Assembly, Diag, Context);
  // Surface the parser's diagnostic instead of discarding it.
  ASSERT_TRUE(M) << "Bad assembly?: " << Diag.getMessage().str();

  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
  // Expected entries: the function, %p and the pointer-valued load %p2.
  ASSERT_EQ(Map.size(), 3u);

  Type *I64Ptr = TypedPointerType::get(Type::getInt64Ty(Context), 0);
  Type *I64PtrPtr = TypedPointerType::get(I64Ptr, 0);
  Type *FnTy = FunctionType::get(Type::getInt64Ty(Context), {I64PtrPtr}, false);

  EXPECT_THAT(Map, Contains(Pair(IsA<Function>(), FnTy)));
  EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I64PtrPtr)));
  EXPECT_THAT(Map, Contains(Pair(IsA<LoadInst>(), I64Ptr)));
}
156 
TEST(PointerTypeAnalysis, WithNoOpCasts) {
  // Two ptr-to-ptr bitcasts of %p are accessed with different types: %1 by an
  // i32 store and %2 by an i64 load. Each cast is expected to get its own
  // element type, while %p itself is expected to be typed i8*.
  StringRef Assembly = R"(
    define i64 @test(ptr %p) {
      %1 = bitcast ptr %p to ptr
      %2 = bitcast ptr %p to ptr
      store i32 0, ptr %1, align 4
      %3 = load i64, ptr %2, align 8
      ret i64 %3
    }
  )";

  LLVMContext Context;
  SMDiagnostic Diag;
  auto M = parseAssemblyString(Assembly, Diag, Context);
  // Surface the parser's diagnostic instead of discarding it.
  ASSERT_TRUE(M) << "Bad assembly?: " << Diag.getMessage().str();

  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
  // Expected entries: the function, its argument and both bitcasts.
  ASSERT_EQ(Map.size(), 4u);

  Type *I8Ptr = TypedPointerType::get(Type::getInt8Ty(Context), 0);
  Type *I32Ptr = TypedPointerType::get(Type::getInt32Ty(Context), 0);
  Type *I64Ptr = TypedPointerType::get(Type::getInt64Ty(Context), 0);
  Type *FnTy = FunctionType::get(Type::getInt64Ty(Context), {I8Ptr}, false);

  EXPECT_THAT(Map, Contains(Pair(IsA<Function>(), FnTy)));
  EXPECT_THAT(Map, Contains(Pair(IsA<Argument>(), I8Ptr)));
  EXPECT_THAT(Map, Contains(Pair(IsA<BitCastInst>(), I64Ptr)));
  EXPECT_THAT(Map, Contains(Pair(IsA<BitCastInst>(), I32Ptr)));
}
186