//===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass modifies the signatures of functions that take or return
// aggregate values. It also substitutes some LLVM intrinsic calls with calls
// to functions generated by the pass itself, in the same way the SPIR-V
// translator does.
//
// NOTE: this pass is a module-level one due to the necessity to modify
// GVs/functions.
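//
// For example, a call to llvm.fshl.i32 is redirected to a generated helper
// function named spirv.llvm_fshl_i32 whose body implements the funnel-shift
// semantics (see lowerFunnelShifts below).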
//
//===----------------------------------------------------------------------===//

#include "SPIRV.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"

using namespace llvm;

namespace llvm {
void initializeSPIRVPrepareFunctionsPass(PassRegistry &);
}

namespace {

class SPIRVPrepareFunctions : public ModulePass {
  Function *processFunctionSignature(Function *F);

public:
  static char ID;
  SPIRVPrepareFunctions() : ModulePass(ID) {
    initializeSPIRVPrepareFunctionsPass(*PassRegistry::getPassRegistry());
  }

  bool runOnModule(Module &M) override;

  StringRef getPassName() const override { return "SPIRV prepare functions"; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    ModulePass::getAnalysisUsage(AU);
  }
};

} // namespace

char SPIRVPrepareFunctions::ID = 0;

INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions",
                "SPIRV prepare functions", false, false)

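// Rewrite the signature of a function that takes or returns aggregates so
// that every aggregate is replaced by i32, and record the original types in
// module-level metadata. A rough sketch of the transformation (the types and
// names below are illustrative only):
//
//   %struct.S = type { i32, i32 }
//   define %struct.S @foo(%struct.S %a, i32 %b)
//
// becomes
//
//   define i32 @foo(i32 %a, i32 %b)
//
// and the original types are recorded in the !spv.cloned_funcs metadata as
//
//   !spv.cloned_funcs = !{!0}
//   !0 = !{!"foo", !1, !2}
//   !1 = !{i32 -1, %struct.S zeroinitializer} ; -1 denotes the return type
//   !2 = !{i32 0, %struct.S zeroinitializer}  ; 0 denotes the first argument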
Function *SPIRVPrepareFunctions::processFunctionSignature(Function *F) {
  IRBuilder<> B(F->getContext());

  bool IsRetAggr = F->getReturnType()->isAggregateType();
  bool HasAggrArg =
      std::any_of(F->arg_begin(), F->arg_end(), [](Argument &Arg) {
        return Arg.getType()->isAggregateType();
      });
  bool DoClone = IsRetAggr || HasAggrArg;
  if (!DoClone)
    return F;
  SmallVector<std::pair<int, Type *>, 4> ChangedTypes;
  Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType();
  if (IsRetAggr)
    ChangedTypes.push_back(std::pair<int, Type *>(-1, F->getReturnType()));
  SmallVector<Type *, 4> ArgTypes;
  for (const auto &Arg : F->args()) {
    if (Arg.getType()->isAggregateType()) {
      ArgTypes.push_back(B.getInt32Ty());
      ChangedTypes.push_back(
          std::pair<int, Type *>(Arg.getArgNo(), Arg.getType()));
    } else
      ArgTypes.push_back(Arg.getType());
  }
  FunctionType *NewFTy =
      FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg());
  Function *NewF =
      Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent());

  ValueToValueMapTy VMap;
  auto NewFArgIt = NewF->arg_begin();
  for (auto &Arg : F->args()) {
    StringRef ArgName = Arg.getName();
    NewFArgIt->setName(ArgName);
    VMap[&Arg] = &(*NewFArgIt++);
  }
  SmallVector<ReturnInst *, 8> Returns;

  CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
                    Returns);
  NewF->takeName(F);

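  // Record the original aggregate types in the !spv.cloned_funcs named
  // metadata so that later stages can recover them; index -1 denotes the
  // return type, non-negative indices denote argument positions.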
  NamedMDNode *FuncMD =
      F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs");
  SmallVector<Metadata *, 2> MDArgs;
  MDArgs.push_back(MDString::get(B.getContext(), NewF->getName()));
  for (auto &ChangedTyP : ChangedTypes)
    MDArgs.push_back(MDNode::get(
        B.getContext(),
        {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)),
         ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))}));
  MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs);
  FuncMD->addOperand(ThisFuncMD);

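  // Redirect all users of the old function to the new one and give existing
  // call sites the new function type.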
  for (auto *U : make_early_inc_range(F->users())) {
    if (auto *CI = dyn_cast<CallInst>(U))
      CI->mutateFunctionType(NewF->getFunctionType());
    U->replaceUsesOfWith(F, NewF);
  }
  return NewF;
}

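// Derive the name of the helper function generated for an intrinsic call by
// replacing '.' with '_' and adding the "spirv." prefix, e.g.
// llvm.umul.with.overflow.i32 becomes spirv.llvm_umul_with_overflow_i32.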
static std::string lowerLLVMIntrinsicName(IntrinsicInst *II) {
  Function *IntrinsicFunc = II->getCalledFunction();
  assert(IntrinsicFunc && "Missing function");
  std::string FuncName = IntrinsicFunc->getName().str();
  std::replace(FuncName.begin(), FuncName.end(), '.', '_');
  FuncName = "spirv." + FuncName;
  return FuncName;
}

static Function *getOrCreateFunction(Module *M, Type *RetTy,
                                     ArrayRef<Type *> ArgTypes,
                                     StringRef Name) {
  FunctionType *FT = FunctionType::get(RetTy, ArgTypes, false);
  Function *F = M->getFunction(Name);
  if (F && F->getFunctionType() == FT)
    return F;
  Function *NewF = Function::Create(FT, GlobalValue::ExternalLinkage, Name, M);
  if (F)
    NewF->setDSOLocal(F->isDSOLocal());
  NewF->setCallingConv(CallingConv::SPIR_FUNC);
  return NewF;
}

static void lowerFunnelShifts(Module *M, IntrinsicInst *FSHIntrinsic) {
  // Get a separate function - otherwise, we'd have to rework the CFG of the
  // current one. Then simply replace the intrinsic uses with a call to the new
  // function.
  // Generate LLVM IR for i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c)
  FunctionType *FSHFuncTy = FSHIntrinsic->getFunctionType();
  Type *FSHRetTy = FSHFuncTy->getReturnType();
  const std::string FuncName = lowerLLVMIntrinsicName(FSHIntrinsic);
  Function *FSHFunc =
      getOrCreateFunction(M, FSHRetTy, FSHFuncTy->params(), FuncName);

  if (!FSHFunc->empty()) {
    FSHIntrinsic->setCalledFunction(FSHFunc);
    return;
  }
  BasicBlock *RotateBB = BasicBlock::Create(M->getContext(), "rotate", FSHFunc);
  IRBuilder<> IRB(RotateBB);
  Type *Ty = FSHFunc->getReturnType();
  // Build the actual funnel shift rotate logic.
  // In the comments, "int" is used interchangeably with "vector of int
  // elements".
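  // For the scalar case with bit width BW and Rot = %c urem BW, the generated
  // body computes (assuming Rot != 0; the names are illustrative):
  //   fshl(%a, %b, %c) -> (%a shl Rot) or (%b lshr (BW - Rot))
  //   fshr(%a, %b, %c) -> (%a shl (BW - Rot)) or (%b lshr Rot)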
  FixedVectorType *VectorTy = dyn_cast<FixedVectorType>(Ty);
  Type *IntTy = VectorTy ? VectorTy->getElementType() : Ty;
  unsigned BitWidth = IntTy->getIntegerBitWidth();
  ConstantInt *BitWidthConstant = IRB.getInt({BitWidth, BitWidth});
  Value *BitWidthForInsts =
      VectorTy
          ? IRB.CreateVectorSplat(VectorTy->getNumElements(), BitWidthConstant)
          : BitWidthConstant;
  Value *RotateModVal =
      IRB.CreateURem(/*Rotate*/ FSHFunc->getArg(2), BitWidthForInsts);
  Value *FirstShift = nullptr, *SecShift = nullptr;
  if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
    // Shift the less significant number right; the "rotate" number of bits
    // will be 0-filled on the left as a result of this regular shift.
    FirstShift = IRB.CreateLShr(FSHFunc->getArg(1), RotateModVal);
  } else {
    // Shift the more significant number left; the "rotate" number of bits
    // will be 0-filled on the right as a result of this regular shift.
    FirstShift = IRB.CreateShl(FSHFunc->getArg(0), RotateModVal);
  }
  // We want the "rotate" number of the more significant int's LSBs (MSBs) to
  // occupy the leftmost (rightmost) "0 space" left by the previous operation.
  // Therefore, subtract the "rotate" number from the integer bitsize...
  Value *SubRotateVal = IRB.CreateSub(BitWidthForInsts, RotateModVal);
  if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
    // ...and left-shift the more significant int by this number, zero-filling
    // the LSBs.
    SecShift = IRB.CreateShl(FSHFunc->getArg(0), SubRotateVal);
  } else {
    // ...and right-shift the less significant int by this number, zero-filling
    // the MSBs.
    SecShift = IRB.CreateLShr(FSHFunc->getArg(1), SubRotateVal);
  }
  // A simple binary OR of the shifted ints yields the final result.
  IRB.CreateRet(IRB.CreateOr(FirstShift, SecShift));

  FSHIntrinsic->setCalledFunction(FSHFunc);
}

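// Emit the body of the generated spirv.llvm_umul_with_overflow_* helper. It
// returns the same { T, i1 } aggregate as the original intrinsic: the product
// and an overflow flag.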
static void buildUMulWithOverflowFunc(Module *M, Function *UMulFunc) {
  // Do nothing if the function body has already been created.
  if (!UMulFunc->empty())
    return;

  BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", UMulFunc);
  IRBuilder<> IRB(EntryBB);
  // Build the actual unsigned multiplication logic with the overflow
  // indication. Do unsigned multiplication Mul = A * B. Then check
  // if unsigned division Div = Mul / A is not equal to B. If so,
  // then overflow has happened.
  Value *Mul = IRB.CreateNUWMul(UMulFunc->getArg(0), UMulFunc->getArg(1));
  Value *Div = IRB.CreateUDiv(Mul, UMulFunc->getArg(0));
  Value *Overflow = IRB.CreateICmpNE(UMulFunc->getArg(1), Div);

  // The umul.with.overflow intrinsic returns a structure, where the first
  // element is the multiplication result, and the second is an overflow bit.
  Type *StructTy = UMulFunc->getReturnType();
  Value *Agg = IRB.CreateInsertValue(UndefValue::get(StructTy), Mul, {0});
  Value *Res = IRB.CreateInsertValue(Agg, Overflow, {1});
  IRB.CreateRet(Res);
}

static void lowerUMulWithOverflow(Module *M, IntrinsicInst *UMulIntrinsic) {
  // Get a separate function - otherwise, we'd have to rework the CFG of the
  // current one. Then simply replace the intrinsic uses with a call to the new
  // function.
  FunctionType *UMulFuncTy = UMulIntrinsic->getFunctionType();
  Type *UMulRetTy = UMulFuncTy->getReturnType();
  const std::string FuncName = lowerLLVMIntrinsicName(UMulIntrinsic);
  Function *UMulFunc =
      getOrCreateFunction(M, UMulRetTy, UMulFuncTy->params(), FuncName);
  buildUMulWithOverflowFunc(M, UMulFunc);
  UMulIntrinsic->setCalledFunction(UMulFunc);
}

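// Walk the instructions of F and replace calls to the intrinsics handled by
// this pass (llvm.fshl, llvm.fshr and llvm.umul.with.overflow) with calls to
// the generated helper functions.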
static void substituteIntrinsicCalls(Module *M, Function *F) {
  for (BasicBlock &BB : *F) {
    for (Instruction &I : BB) {
      auto *Call = dyn_cast<CallInst>(&I);
      if (!Call)
        continue;
      Call->setTailCall(false);
      Function *CF = Call->getCalledFunction();
      if (!CF || !CF->isIntrinsic())
        continue;
      auto *II = cast<IntrinsicInst>(Call);
      if (II->getIntrinsicID() == Intrinsic::fshl ||
          II->getIntrinsicID() == Intrinsic::fshr)
        lowerFunnelShifts(M, II);
      else if (II->getIntrinsicID() == Intrinsic::umul_with_overflow)
        lowerUMulWithOverflow(M, II);
    }
  }
}

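// Lower the supported intrinsic calls first, then rewrite aggregate
// signatures; a defined function whose signature was rewritten is erased once
// its clone has taken over its uses and name.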
bool SPIRVPrepareFunctions::runOnModule(Module &M) {
  for (Function &F : M)
    substituteIntrinsicCalls(&M, &F);

  std::vector<Function *> FuncsWorklist;
  bool Changed = false;
  for (auto &F : M)
    FuncsWorklist.push_back(&F);

  for (auto *Func : FuncsWorklist) {
    Function *F = processFunctionSignature(Func);

    bool CreatedNewF = F != Func;

    if (Func->isDeclaration()) {
      Changed |= CreatedNewF;
      continue;
    }

    if (CreatedNewF)
      Func->eraseFromParent();
  }

  return Changed;
}

ModulePass *llvm::createSPIRVPrepareFunctionsPass() {
  return new SPIRVPrepareFunctions();
}