//===----- SVEIntrinsicOpts - SVE ACLE Intrinsics Opts --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Performs general IR level optimizations on SVE intrinsics.
//
// This pass performs the following optimizations:
//
// - removes unnecessary ptrue intrinsics (llvm.aarch64.sve.ptrue), e.g.:
//     %1 = @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
//     %2 = @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
//     ; (%1 can be replaced with a reinterpret of %2)
//
// - optimizes ptest intrinsics where the operands are being needlessly
//   converted to and from svbool_t.
//
//===----------------------------------------------------------------------===//

#include "AArch64.h"
#include "Utils/AArch64BaseInfo.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/Debug.h"

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "aarch64-sve-intrinsic-opts"

namespace {
struct SVEIntrinsicOpts : public ModulePass {
  static char ID; // Pass identification, replacement for typeid
  SVEIntrinsicOpts() : ModulePass(ID) {
    initializeSVEIntrinsicOptsPass(*PassRegistry::getPassRegistry());
  }

  bool runOnModule(Module &M) override;
  void getAnalysisUsage(AnalysisUsage &AU) const override;

private:
  bool coalescePTrueIntrinsicCalls(BasicBlock &BB,
                                   SmallSetVector<IntrinsicInst *, 4> &PTrues);
  bool optimizePTrueIntrinsicCalls(SmallSetVector<Function *, 4> &Functions);
  bool optimizePredicateStore(Instruction *I);
  bool optimizePredicateLoad(Instruction *I);

  bool optimizeInstructions(SmallSetVector<Function *, 4> &Functions);

  /// Operates at function scope, i.e., optimizations are applied locally to
  /// each function.
  bool optimizeFunctions(SmallSetVector<Function *, 4> &Functions);
};
} // end anonymous namespace

void SVEIntrinsicOpts::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<DominatorTreeWrapperPass>();
  AU.setPreservesCFG();
}

char SVEIntrinsicOpts::ID = 0;
static const char *name = "SVE intrinsics optimizations";
INITIALIZE_PASS_BEGIN(SVEIntrinsicOpts, DEBUG_TYPE, name, false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass);
INITIALIZE_PASS_END(SVEIntrinsicOpts, DEBUG_TYPE, name, false, false)

ModulePass *llvm::createSVEIntrinsicOptsPass() {
  return new SVEIntrinsicOpts();
}

/// Checks if a ptrue intrinsic call is promoted. The act of promoting a
/// ptrue will introduce zeroing. For example:
///
///   %1 = <vscale x 4 x i1> call @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
///   %2 = <vscale x 16 x i1> call @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %1)
///   %3 = <vscale x 8 x i1> call @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %2)
///
/// %1 is promoted, because it is converted:
///
///   <vscale x 4 x i1> => <vscale x 16 x i1> => <vscale x 8 x i1>
///
/// via a sequence of the SVE reinterpret intrinsics convert.{to,from}.svbool.
static bool isPTruePromoted(IntrinsicInst *PTrue) {
  // Find all users of this intrinsic that are calls to convert-to-svbool
  // reinterpret intrinsics.
  SmallVector<IntrinsicInst *, 4> ConvertToUses;
  for (User *User : PTrue->users()) {
    if (match(User, m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>())) {
      ConvertToUses.push_back(cast<IntrinsicInst>(User));
    }
  }

  // If no such calls were found, this ptrue is not promoted.
  if (ConvertToUses.empty())
    return false;

  // Otherwise, try to find users of the convert-to-svbool intrinsics that are
  // calls to the convert-from-svbool intrinsic, and would result in some lanes
  // being zeroed.
  const auto *PTrueVTy = cast<ScalableVectorType>(PTrue->getType());
  for (IntrinsicInst *ConvertToUse : ConvertToUses) {
    for (User *User : ConvertToUse->users()) {
      auto *IntrUser = dyn_cast<IntrinsicInst>(User);
      if (IntrUser && IntrUser->getIntrinsicID() ==
                          Intrinsic::aarch64_sve_convert_from_svbool) {
        const auto *IntrUserVTy = cast<ScalableVectorType>(IntrUser->getType());

        // Would some lanes become zeroed by the conversion?
        if (IntrUserVTy->getElementCount().getKnownMinValue() >
            PTrueVTy->getElementCount().getKnownMinValue())
          // This is a promoted ptrue.
          return true;
      }
    }
  }

  // If no matching calls were found, this is not a promoted ptrue.
  return false;
}

/// Attempts to coalesce ptrues in a basic block.
bool SVEIntrinsicOpts::coalescePTrueIntrinsicCalls(
    BasicBlock &BB, SmallSetVector<IntrinsicInst *, 4> &PTrues) {
  if (PTrues.size() <= 1)
    return false;

  // Find the ptrue with the most lanes.
  auto *MostEncompassingPTrue = *std::max_element(
      PTrues.begin(), PTrues.end(), [](auto *PTrue1, auto *PTrue2) {
        auto *PTrue1VTy = cast<ScalableVectorType>(PTrue1->getType());
        auto *PTrue2VTy = cast<ScalableVectorType>(PTrue2->getType());
        return PTrue1VTy->getElementCount().getKnownMinValue() <
               PTrue2VTy->getElementCount().getKnownMinValue();
      });

  // Remove the most encompassing ptrue, as well as any promoted ptrues,
  // leaving behind only the ptrues to be coalesced.
  PTrues.remove(MostEncompassingPTrue);
  PTrues.remove_if(isPTruePromoted);

  // Hoist MostEncompassingPTrue to the start of the basic block. It is always
  // safe to do this, since ptrue intrinsic calls take no instruction operands,
  // so there are no prior definitions they depend on.
  MostEncompassingPTrue->moveBefore(BB, BB.getFirstInsertionPt());

  LLVMContext &Ctx = BB.getContext();
  IRBuilder<> Builder(Ctx);
  Builder.SetInsertPoint(&BB, ++MostEncompassingPTrue->getIterator());

  auto *MostEncompassingPTrueVTy =
      cast<VectorType>(MostEncompassingPTrue->getType());
  auto *ConvertToSVBool = Builder.CreateIntrinsic(
      Intrinsic::aarch64_sve_convert_to_svbool, {MostEncompassingPTrueVTy},
      {MostEncompassingPTrue});

  bool ConvertFromCreated = false;
  for (auto *PTrue : PTrues) {
    auto *PTrueVTy = cast<VectorType>(PTrue->getType());

    // Only create the converts if the types are not already the same,
    // otherwise just use the most encompassing ptrue.
    if (MostEncompassingPTrueVTy != PTrueVTy) {
      ConvertFromCreated = true;

      Builder.SetInsertPoint(&BB, ++ConvertToSVBool->getIterator());
      auto *ConvertFromSVBool =
          Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
                                  {PTrueVTy}, {ConvertToSVBool});
      PTrue->replaceAllUsesWith(ConvertFromSVBool);
    } else
      PTrue->replaceAllUsesWith(MostEncompassingPTrue);

    PTrue->eraseFromParent();
  }

  // We never used the ConvertTo so remove it
  if (!ConvertFromCreated)
    ConvertToSVBool->eraseFromParent();

  return true;
}

/// The goal of this function is to remove redundant calls to the SVE ptrue
/// intrinsic in each basic block within the given functions.
///
/// SVE ptrues have two representations in LLVM IR:
/// - a logical representation -- an arbitrary-width scalable vector of i1s,
///   i.e. <vscale x N x i1>.
/// - a physical representation (svbool, <vscale x 16 x i1>) -- a 16-element
///   scalable vector of i1s, i.e. <vscale x 16 x i1>.
///
/// The SVE ptrue intrinsic is used to create a logical representation of an
/// SVE predicate. Suppose that we have two SVE ptrue intrinsic calls: P1 and
/// P2. If P1 creates a logical SVE predicate that is at least as wide as the
/// logical SVE predicate created by P2, then all of the bits that are true in
/// the physical representation of P2 are necessarily also true in the physical
/// representation of P1. P1 'encompasses' P2; therefore, the intrinsic call to
/// P2 is redundant and can be replaced by an SVE reinterpret of P1 via
/// convert.{to,from}.svbool.
///
/// Currently, this pass only coalesces calls to SVE ptrue intrinsics
/// if they match the following conditions:
///
/// - the call to the intrinsic uses either the SV_ALL or SV_POW2 patterns.
///   SV_ALL indicates that all bits of the predicate vector are to be set to
///   true. SV_POW2 indicates that all bits of the predicate vector up to the
///   largest power-of-two are to be set to true.
/// - the result of the call to the intrinsic is not promoted to a wider
///   predicate. In this case, keeping the extra ptrue leads to better codegen
///   -- coalescing here would create an irreducible chain of SVE reinterprets
///   via convert.{to,from}.svbool.
///
/// EXAMPLE:
///
///     %1 = <vscale x 8 x i1> ptrue(i32 SV_ALL)
///     ; Logical:  <1, 1, 1, 1, 1, 1, 1, 1>
///     ; Physical: <1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0>
///     ...
///
///     %2 = <vscale x 4 x i1> ptrue(i32 SV_ALL)
///     ; Logical:  <1, 1, 1, 1>
///     ; Physical: <1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0>
///     ...
///
/// Here, %2 can be replaced by an SVE reinterpret of %1, giving, for instance:
///
///     %1 = <vscale x 8 x i1> ptrue(i32 31)
///     %2 = <vscale x 16 x i1> convert.to.svbool(<vscale x 8 x i1> %1)
///     %3 = <vscale x 4 x i1> convert.from.svbool(<vscale x 16 x i1> %2)
///
bool SVEIntrinsicOpts::optimizePTrueIntrinsicCalls(
    SmallSetVector<Function *, 4> &Functions) {
  bool Changed = false;

  for (auto *F : Functions) {
    for (auto &BB : *F) {
      SmallSetVector<IntrinsicInst *, 4> SVAllPTrues;
      SmallSetVector<IntrinsicInst *, 4> SVPow2PTrues;

      // For each basic block, collect the used ptrues and try to coalesce
      // them.
      for (Instruction &I : BB) {
        if (I.use_empty())
          continue;

        auto *IntrI = dyn_cast<IntrinsicInst>(&I);
        if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
          continue;

        const auto PTruePattern =
            cast<ConstantInt>(IntrI->getOperand(0))->getZExtValue();

        if (PTruePattern == AArch64SVEPredPattern::all)
          SVAllPTrues.insert(IntrI);
        if (PTruePattern == AArch64SVEPredPattern::pow2)
          SVPow2PTrues.insert(IntrI);
      }

      Changed |= coalescePTrueIntrinsicCalls(BB, SVAllPTrues);
      Changed |= coalescePTrueIntrinsicCalls(BB, SVPow2PTrues);
    }
  }

  return Changed;
}

// This is done in SVEIntrinsicOpts rather than InstCombine so that we
// introduce scalable stores as late as possible
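//
// As an illustrative sketch (the value names and exact types here are
// assumptions, not taken from a specific test), with a function attributed
// vscale_range(2,2), i.e. MinVScale == 2, a sequence such as:
//
//   %bc  = bitcast <vscale x 16 x i1> %pred to <vscale x 2 x i8>
//   %ext = call <4 x i8> @llvm.vector.extract.v4i8.nxv2i8(<vscale x 2 x i8> %bc, i64 0)
//   store <4 x i8> %ext, <4 x i8>* %addr
//
// is rewritten to store the predicate directly:
//
//   %addr.bc = bitcast <4 x i8>* %addr to <vscale x 16 x i1>*
//   store <vscale x 16 x i1> %pred, <vscale x 16 x i1>* %addr.bc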
bool SVEIntrinsicOpts::optimizePredicateStore(Instruction *I) {
  auto *F = I->getFunction();
  auto Attr = F->getFnAttribute(Attribute::VScaleRange);
  if (!Attr.isValid())
    return false;

  unsigned MinVScale = Attr.getVScaleRangeMin();
  Optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
  // The transform needs to know the exact runtime length of scalable vectors
  if (!MaxVScale || MinVScale != MaxVScale)
    return false;

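  // A <vscale x 16 x i1> predicate occupies 16 * vscale bits in memory; with a
  // known fixed vscale this is exactly MinVScale * 2 bytes, which is the size
  // described by the fixed-width i8 vector type below.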
  auto *PredType =
      ScalableVectorType::get(Type::getInt1Ty(I->getContext()), 16);
  auto *FixedPredType =
      FixedVectorType::get(Type::getInt8Ty(I->getContext()), MinVScale * 2);

  // If we have a store..
  auto *Store = dyn_cast<StoreInst>(I);
  if (!Store || !Store->isSimple())
    return false;

  // ..that is storing a predicate vector's worth of bits..
  if (Store->getOperand(0)->getType() != FixedPredType)
    return false;

  // ..where the value stored comes from a vector extract..
  auto *IntrI = dyn_cast<IntrinsicInst>(Store->getOperand(0));
  if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::vector_extract)
    return false;

  // ..that is extracting from index 0..
  if (!cast<ConstantInt>(IntrI->getOperand(1))->isZero())
    return false;

  // ..where the value being extracted from comes from a bitcast
  auto *BitCast = dyn_cast<BitCastInst>(IntrI->getOperand(0));
  if (!BitCast)
    return false;

  // ..and the bitcast is casting from predicate type
  if (BitCast->getOperand(0)->getType() != PredType)
    return false;

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);

  auto *PtrBitCast = Builder.CreateBitCast(
      Store->getPointerOperand(),
      PredType->getPointerTo(Store->getPointerAddressSpace()));
  Builder.CreateStore(BitCast->getOperand(0), PtrBitCast);

  Store->eraseFromParent();
  if (IntrI->getNumUses() == 0)
    IntrI->eraseFromParent();
  if (BitCast->getNumUses() == 0)
    BitCast->eraseFromParent();

  return true;
}

// This is done in SVEIntrinsicOpts rather than InstCombine so that we
// introduce scalable loads as late as possible
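//
// A minimal sketch of the pattern (again assuming vscale_range(2,2); the
// value names are illustrative only):
//
//   %load = load <4 x i8>, <4 x i8>* %addr
//   %ins  = call <vscale x 2 x i8> @llvm.vector.insert.nxv2i8.v4i8(<vscale x 2 x i8> undef, <4 x i8> %load, i64 0)
//   %pred = bitcast <vscale x 2 x i8> %ins to <vscale x 16 x i1>
//
// which becomes a direct predicate load:
//
//   %addr.bc = bitcast <4 x i8>* %addr to <vscale x 16 x i1>*
//   %pred    = load <vscale x 16 x i1>, <vscale x 16 x i1>* %addr.bc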
bool SVEIntrinsicOpts::optimizePredicateLoad(Instruction *I) {
  auto *F = I->getFunction();
  auto Attr = F->getFnAttribute(Attribute::VScaleRange);
  if (!Attr.isValid())
    return false;

  unsigned MinVScale = Attr.getVScaleRangeMin();
  Optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
  // The transform needs to know the exact runtime length of scalable vectors
  if (!MaxVScale || MinVScale != MaxVScale)
    return false;

  auto *PredType =
      ScalableVectorType::get(Type::getInt1Ty(I->getContext()), 16);
  auto *FixedPredType =
      FixedVectorType::get(Type::getInt8Ty(I->getContext()), MinVScale * 2);

  // If we have a bitcast..
  auto *BitCast = dyn_cast<BitCastInst>(I);
  if (!BitCast || BitCast->getType() != PredType)
    return false;

  // ..whose operand is a vector_insert..
  auto *IntrI = dyn_cast<IntrinsicInst>(BitCast->getOperand(0));
  if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::vector_insert)
    return false;

  // ..that is inserting into index zero of an undef vector..
  if (!isa<UndefValue>(IntrI->getOperand(0)) ||
      !cast<ConstantInt>(IntrI->getOperand(2))->isZero())
    return false;

  // ..where the value inserted comes from a load..
  auto *Load = dyn_cast<LoadInst>(IntrI->getOperand(1));
  if (!Load || !Load->isSimple())
    return false;

  // ..that is loading a predicate vector's worth of bits..
  if (Load->getType() != FixedPredType)
    return false;

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(Load);

  auto *PtrBitCast = Builder.CreateBitCast(
      Load->getPointerOperand(),
      PredType->getPointerTo(Load->getPointerAddressSpace()));
  auto *LoadPred = Builder.CreateLoad(PredType, PtrBitCast);

  BitCast->replaceAllUsesWith(LoadPred);
  BitCast->eraseFromParent();
  if (IntrI->getNumUses() == 0)
    IntrI->eraseFromParent();
  if (Load->getNumUses() == 0)
    Load->eraseFromParent();

  return true;
}

bool SVEIntrinsicOpts::optimizeInstructions(
    SmallSetVector<Function *, 4> &Functions) {
  bool Changed = false;

  for (auto *F : Functions) {
    DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>(*F).getDomTree();

    // Traverse the DT with an rpo walk so we see defs before uses, allowing
    // simplification to be done incrementally.
    BasicBlock *Root = DT->getRoot();
    ReversePostOrderTraversal<BasicBlock *> RPOT(Root);
    for (auto *BB : RPOT) {
      for (Instruction &I : make_early_inc_range(*BB)) {
        switch (I.getOpcode()) {
        case Instruction::Store:
          Changed |= optimizePredicateStore(&I);
          break;
        case Instruction::BitCast:
          Changed |= optimizePredicateLoad(&I);
          break;
        }
      }
    }
  }

  return Changed;
}

bool SVEIntrinsicOpts::optimizeFunctions(
    SmallSetVector<Function *, 4> &Functions) {
  bool Changed = false;

  Changed |= optimizePTrueIntrinsicCalls(Functions);
  Changed |= optimizeInstructions(Functions);

  return Changed;
}

bool SVEIntrinsicOpts::runOnModule(Module &M) {
  bool Changed = false;
  SmallSetVector<Function *, 4> Functions;

  // Check for SVE intrinsic declarations first so that we only iterate over
  // relevant functions. Where an appropriate declaration is found, store the
  // function(s) where it is used so we can target these only.
  for (auto &F : M.getFunctionList()) {
    if (!F.isDeclaration())
      continue;

    switch (F.getIntrinsicID()) {
    case Intrinsic::vector_extract:
    case Intrinsic::vector_insert:
    case Intrinsic::aarch64_sve_ptrue:
      for (User *U : F.users())
        Functions.insert(cast<Instruction>(U)->getFunction());
      break;
    default:
      break;
    }
  }

  if (!Functions.empty())
    Changed |= optimizeFunctions(Functions);

  return Changed;
}