1 //===----- CodeGen/ExpandVectorPredication.cpp - Expand VP intrinsics -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass implements IR expansion for vector predication intrinsics, allowing
10 // targets to enable vector predication until just before codegen.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "llvm/CodeGen/ExpandVectorPredication.h"
15 #include "llvm/ADT/Statistic.h"
16 #include "llvm/Analysis/TargetTransformInfo.h"
17 #include "llvm/Analysis/ValueTracking.h"
18 #include "llvm/Analysis/VectorUtils.h"
19 #include "llvm/CodeGen/Passes.h"
20 #include "llvm/IR/Constants.h"
21 #include "llvm/IR/Function.h"
22 #include "llvm/IR/IRBuilder.h"
23 #include "llvm/IR/InstIterator.h"
24 #include "llvm/IR/Instructions.h"
25 #include "llvm/IR/IntrinsicInst.h"
26 #include "llvm/IR/Intrinsics.h"
27 #include "llvm/InitializePasses.h"
28 #include "llvm/Pass.h"
29 #include "llvm/Support/CommandLine.h"
30 #include "llvm/Support/Compiler.h"
31 #include "llvm/Support/Debug.h"
32
33 using namespace llvm;
34
35 using VPLegalization = TargetTransformInfo::VPLegalization;
36 using VPTransform = TargetTransformInfo::VPLegalization::VPTransform;
37
// Keep this in sync with TargetTransformInfo::VPLegalization.
// X-macro listing every VPTransform enumerator. It is expanded twice with
// different definitions of VPINTERNAL_CASE: first to build the option list in
// the cl::desc strings below, then to build the StringSwitch cases in
// parseOverrideOption.
#define VPINTERNAL_VPLEGAL_CASES                                               \
  VPINTERNAL_CASE(Legal)                                                       \
  VPINTERNAL_CASE(Discard)                                                     \
  VPINTERNAL_CASE(Convert)

// First expansion: each enumerator becomes the literal "|<Name>", so adjacent
// string literal concatenation yields "|Legal|Discard|Convert".
#define VPINTERNAL_CASE(X) "|" #X

// Override options.
// Both are hidden, testing-only options; the empty string means "no override".
static cl::opt<std::string> EVLTransformOverride(
    "expandvp-override-evl-transform", cl::init(""), cl::Hidden,
    cl::desc("Options: <empty>" VPINTERNAL_VPLEGAL_CASES
             ". If non-empty, ignore "
             "TargetTransformInfo and "
             "always use this transformation for the %evl parameter (Used in "
             "testing)."));

// NOTE(review): despite the "%mask parameter" wording of its description,
// CachingVPExpander::getVPLegalizationStrategy assigns this option to
// VPLegalization::OpStrategy (the strategy for the operation as a whole).
static cl::opt<std::string> MaskTransformOverride(
    "expandvp-override-mask-transform", cl::init(""), cl::Hidden,
    cl::desc("Options: <empty>" VPINTERNAL_VPLEGAL_CASES
             ". If non-empty, Ignore "
             "TargetTransformInfo and "
             "always use this transformation for the %mask parameter (Used in "
             "testing)."));

// Second expansion: each enumerator becomes a StringSwitch case mapping its
// spelling to the matching VPLegalization value.
#undef VPINTERNAL_CASE
#define VPINTERNAL_CASE(X) .Case(#X, VPLegalization::X)

/// Parse the text of an override option into a VPTransform.
/// NOTE(review): the StringSwitch has no .Default(), so only the exact
/// spellings "Legal", "Discard" and "Convert" are handled; callers must not
/// pass an empty or unrecognized string.
static VPTransform parseOverrideOption(const std::string &TextOpt) {
  return StringSwitch<VPTransform>(TextOpt) VPINTERNAL_VPLEGAL_CASES;
}

#undef VPINTERNAL_VPLEGAL_CASES
71
72 // Whether any override options are set.
anyExpandVPOverridesSet()73 static bool anyExpandVPOverridesSet() {
74 return !EVLTransformOverride.empty() || !MaskTransformOverride.empty();
75 }
76
77 #define DEBUG_TYPE "expandvp"
78
79 STATISTIC(NumFoldedVL, "Number of folded vector length params");
80 STATISTIC(NumLoweredVPOps, "Number of folded vector predication operations");
81
82 ///// Helpers {
83
84 /// \returns Whether the vector mask \p MaskVal has all lane bits set.
isAllTrueMask(Value * MaskVal)85 static bool isAllTrueMask(Value *MaskVal) {
86 if (Value *SplattedVal = getSplatValue(MaskVal))
87 if (auto *ConstValue = dyn_cast<Constant>(SplattedVal))
88 return ConstValue->isAllOnesValue();
89
90 return false;
91 }
92
93 /// \returns A non-excepting divisor constant for this type.
getSafeDivisor(Type * DivTy)94 static Constant *getSafeDivisor(Type *DivTy) {
95 assert(DivTy->isIntOrIntVectorTy() && "Unsupported divisor type");
96 return ConstantInt::get(DivTy, 1u, false);
97 }
98
99 /// Transfer operation properties from \p OldVPI to \p NewVal.
transferDecorations(Value & NewVal,VPIntrinsic & VPI)100 static void transferDecorations(Value &NewVal, VPIntrinsic &VPI) {
101 auto *NewInst = dyn_cast<Instruction>(&NewVal);
102 if (!NewInst || !isa<FPMathOperator>(NewVal))
103 return;
104
105 auto *OldFMOp = dyn_cast<FPMathOperator>(&VPI);
106 if (!OldFMOp)
107 return;
108
109 NewInst->setFastMathFlags(OldFMOp->getFastMathFlags());
110 }
111
112 /// Transfer all properties from \p OldOp to \p NewOp and replace all uses.
113 /// OldVP gets erased.
replaceOperation(Value & NewOp,VPIntrinsic & OldOp)114 static void replaceOperation(Value &NewOp, VPIntrinsic &OldOp) {
115 transferDecorations(NewOp, OldOp);
116 OldOp.replaceAllUsesWith(&NewOp);
117 OldOp.eraseFromParent();
118 }
119
maySpeculateLanes(VPIntrinsic & VPI)120 static bool maySpeculateLanes(VPIntrinsic &VPI) {
121 // The result of VP reductions depends on the mask and evl.
122 if (isa<VPReductionIntrinsic>(VPI))
123 return false;
124 // Fallback to whether the intrinsic is speculatable.
125 Optional<unsigned> OpcOpt = VPI.getFunctionalOpcode();
126 unsigned FunctionalOpc = OpcOpt.value_or((unsigned)Instruction::Call);
127 return isSafeToSpeculativelyExecuteWithOpcode(FunctionalOpc, &VPI);
128 }
129
130 //// } Helpers
131
132 namespace {
133
134 // Expansion pass state at function scope.
135 struct CachingVPExpander {
136 Function &F;
137 const TargetTransformInfo &TTI;
138
139 /// \returns A (fixed length) vector with ascending integer indices
140 /// (<0, 1, ..., NumElems-1>).
141 /// \p Builder
142 /// Used for instruction creation.
143 /// \p LaneTy
144 /// Integer element type of the result vector.
145 /// \p NumElems
146 /// Number of vector elements.
147 Value *createStepVector(IRBuilder<> &Builder, Type *LaneTy,
148 unsigned NumElems);
149
150 /// \returns A bitmask that is true where the lane position is less-than \p
151 /// EVLParam
152 ///
153 /// \p Builder
154 /// Used for instruction creation.
155 /// \p VLParam
156 /// The explicit vector length parameter to test against the lane
157 /// positions.
158 /// \p ElemCount
159 /// Static (potentially scalable) number of vector elements.
160 Value *convertEVLToMask(IRBuilder<> &Builder, Value *EVLParam,
161 ElementCount ElemCount);
162
163 Value *foldEVLIntoMask(VPIntrinsic &VPI);
164
165 /// "Remove" the %evl parameter of \p PI by setting it to the static vector
166 /// length of the operation.
167 void discardEVLParameter(VPIntrinsic &PI);
168
169 /// \brief Lower this VP binary operator to a unpredicated binary operator.
170 Value *expandPredicationInBinaryOperator(IRBuilder<> &Builder,
171 VPIntrinsic &PI);
172
173 /// \brief Lower this VP reduction to a call to an unpredicated reduction
174 /// intrinsic.
175 Value *expandPredicationInReduction(IRBuilder<> &Builder,
176 VPReductionIntrinsic &PI);
177
178 /// \brief Lower this VP memory operation to a non-VP intrinsic.
179 Value *expandPredicationInMemoryIntrinsic(IRBuilder<> &Builder,
180 VPIntrinsic &VPI);
181
182 /// \brief Query TTI and expand the vector predication in \p P accordingly.
183 Value *expandPredication(VPIntrinsic &PI);
184
185 /// \brief Determine how and whether the VPIntrinsic \p VPI shall be
186 /// expanded. This overrides TTI with the cl::opts listed at the top of this
187 /// file.
188 VPLegalization getVPLegalizationStrategy(const VPIntrinsic &VPI) const;
189 bool UsingTTIOverrides;
190
191 public:
CachingVPExpander__anona96194c90111::CachingVPExpander192 CachingVPExpander(Function &F, const TargetTransformInfo &TTI)
193 : F(F), TTI(TTI), UsingTTIOverrides(anyExpandVPOverridesSet()) {}
194
195 bool expandVectorPredication();
196 };
197
198 //// CachingVPExpander {
199
createStepVector(IRBuilder<> & Builder,Type * LaneTy,unsigned NumElems)200 Value *CachingVPExpander::createStepVector(IRBuilder<> &Builder, Type *LaneTy,
201 unsigned NumElems) {
202 // TODO add caching
203 SmallVector<Constant *, 16> ConstElems;
204
205 for (unsigned Idx = 0; Idx < NumElems; ++Idx)
206 ConstElems.push_back(ConstantInt::get(LaneTy, Idx, false));
207
208 return ConstantVector::get(ConstElems);
209 }
210
convertEVLToMask(IRBuilder<> & Builder,Value * EVLParam,ElementCount ElemCount)211 Value *CachingVPExpander::convertEVLToMask(IRBuilder<> &Builder,
212 Value *EVLParam,
213 ElementCount ElemCount) {
214 // TODO add caching
215 // Scalable vector %evl conversion.
216 if (ElemCount.isScalable()) {
217 auto *M = Builder.GetInsertBlock()->getModule();
218 Type *BoolVecTy = VectorType::get(Builder.getInt1Ty(), ElemCount);
219 Function *ActiveMaskFunc = Intrinsic::getDeclaration(
220 M, Intrinsic::get_active_lane_mask, {BoolVecTy, EVLParam->getType()});
221 // `get_active_lane_mask` performs an implicit less-than comparison.
222 Value *ConstZero = Builder.getInt32(0);
223 return Builder.CreateCall(ActiveMaskFunc, {ConstZero, EVLParam});
224 }
225
226 // Fixed vector %evl conversion.
227 Type *LaneTy = EVLParam->getType();
228 unsigned NumElems = ElemCount.getFixedValue();
229 Value *VLSplat = Builder.CreateVectorSplat(NumElems, EVLParam);
230 Value *IdxVec = createStepVector(Builder, LaneTy, NumElems);
231 return Builder.CreateICmp(CmpInst::ICMP_ULT, IdxVec, VLSplat);
232 }
233
234 Value *
expandPredicationInBinaryOperator(IRBuilder<> & Builder,VPIntrinsic & VPI)235 CachingVPExpander::expandPredicationInBinaryOperator(IRBuilder<> &Builder,
236 VPIntrinsic &VPI) {
237 assert((maySpeculateLanes(VPI) || VPI.canIgnoreVectorLengthParam()) &&
238 "Implicitly dropping %evl in non-speculatable operator!");
239
240 auto OC = static_cast<Instruction::BinaryOps>(*VPI.getFunctionalOpcode());
241 assert(Instruction::isBinaryOp(OC));
242
243 Value *Op0 = VPI.getOperand(0);
244 Value *Op1 = VPI.getOperand(1);
245 Value *Mask = VPI.getMaskParam();
246
247 // Blend in safe operands.
248 if (Mask && !isAllTrueMask(Mask)) {
249 switch (OC) {
250 default:
251 // Can safely ignore the predicate.
252 break;
253
254 // Division operators need a safe divisor on masked-off lanes (1).
255 case Instruction::UDiv:
256 case Instruction::SDiv:
257 case Instruction::URem:
258 case Instruction::SRem:
259 // 2nd operand must not be zero.
260 Value *SafeDivisor = getSafeDivisor(VPI.getType());
261 Op1 = Builder.CreateSelect(Mask, Op1, SafeDivisor);
262 }
263 }
264
265 Value *NewBinOp = Builder.CreateBinOp(OC, Op0, Op1, VPI.getName());
266
267 replaceOperation(*NewBinOp, VPI);
268 return NewBinOp;
269 }
270
getNeutralReductionElement(const VPReductionIntrinsic & VPI,Type * EltTy)271 static Value *getNeutralReductionElement(const VPReductionIntrinsic &VPI,
272 Type *EltTy) {
273 bool Negative = false;
274 unsigned EltBits = EltTy->getScalarSizeInBits();
275 switch (VPI.getIntrinsicID()) {
276 default:
277 llvm_unreachable("Expecting a VP reduction intrinsic");
278 case Intrinsic::vp_reduce_add:
279 case Intrinsic::vp_reduce_or:
280 case Intrinsic::vp_reduce_xor:
281 case Intrinsic::vp_reduce_umax:
282 return Constant::getNullValue(EltTy);
283 case Intrinsic::vp_reduce_mul:
284 return ConstantInt::get(EltTy, 1, /*IsSigned*/ false);
285 case Intrinsic::vp_reduce_and:
286 case Intrinsic::vp_reduce_umin:
287 return ConstantInt::getAllOnesValue(EltTy);
288 case Intrinsic::vp_reduce_smin:
289 return ConstantInt::get(EltTy->getContext(),
290 APInt::getSignedMaxValue(EltBits));
291 case Intrinsic::vp_reduce_smax:
292 return ConstantInt::get(EltTy->getContext(),
293 APInt::getSignedMinValue(EltBits));
294 case Intrinsic::vp_reduce_fmax:
295 Negative = true;
296 LLVM_FALLTHROUGH;
297 case Intrinsic::vp_reduce_fmin: {
298 FastMathFlags Flags = VPI.getFastMathFlags();
299 const fltSemantics &Semantics = EltTy->getFltSemantics();
300 return !Flags.noNaNs() ? ConstantFP::getQNaN(EltTy, Negative)
301 : !Flags.noInfs()
302 ? ConstantFP::getInfinity(EltTy, Negative)
303 : ConstantFP::get(EltTy,
304 APFloat::getLargest(Semantics, Negative));
305 }
306 case Intrinsic::vp_reduce_fadd:
307 return ConstantFP::getNegativeZero(EltTy);
308 case Intrinsic::vp_reduce_fmul:
309 return ConstantFP::get(EltTy, 1.0);
310 }
311 }
312
313 Value *
expandPredicationInReduction(IRBuilder<> & Builder,VPReductionIntrinsic & VPI)314 CachingVPExpander::expandPredicationInReduction(IRBuilder<> &Builder,
315 VPReductionIntrinsic &VPI) {
316 assert((maySpeculateLanes(VPI) || VPI.canIgnoreVectorLengthParam()) &&
317 "Implicitly dropping %evl in non-speculatable operator!");
318
319 Value *Mask = VPI.getMaskParam();
320 Value *RedOp = VPI.getOperand(VPI.getVectorParamPos());
321
322 // Insert neutral element in masked-out positions
323 if (Mask && !isAllTrueMask(Mask)) {
324 auto *NeutralElt = getNeutralReductionElement(VPI, VPI.getType());
325 auto *NeutralVector = Builder.CreateVectorSplat(
326 cast<VectorType>(RedOp->getType())->getElementCount(), NeutralElt);
327 RedOp = Builder.CreateSelect(Mask, RedOp, NeutralVector);
328 }
329
330 Value *Reduction;
331 Value *Start = VPI.getOperand(VPI.getStartParamPos());
332
333 switch (VPI.getIntrinsicID()) {
334 default:
335 llvm_unreachable("Impossible reduction kind");
336 case Intrinsic::vp_reduce_add:
337 Reduction = Builder.CreateAddReduce(RedOp);
338 Reduction = Builder.CreateAdd(Reduction, Start);
339 break;
340 case Intrinsic::vp_reduce_mul:
341 Reduction = Builder.CreateMulReduce(RedOp);
342 Reduction = Builder.CreateMul(Reduction, Start);
343 break;
344 case Intrinsic::vp_reduce_and:
345 Reduction = Builder.CreateAndReduce(RedOp);
346 Reduction = Builder.CreateAnd(Reduction, Start);
347 break;
348 case Intrinsic::vp_reduce_or:
349 Reduction = Builder.CreateOrReduce(RedOp);
350 Reduction = Builder.CreateOr(Reduction, Start);
351 break;
352 case Intrinsic::vp_reduce_xor:
353 Reduction = Builder.CreateXorReduce(RedOp);
354 Reduction = Builder.CreateXor(Reduction, Start);
355 break;
356 case Intrinsic::vp_reduce_smax:
357 Reduction = Builder.CreateIntMaxReduce(RedOp, /*IsSigned*/ true);
358 Reduction =
359 Builder.CreateBinaryIntrinsic(Intrinsic::smax, Reduction, Start);
360 break;
361 case Intrinsic::vp_reduce_smin:
362 Reduction = Builder.CreateIntMinReduce(RedOp, /*IsSigned*/ true);
363 Reduction =
364 Builder.CreateBinaryIntrinsic(Intrinsic::smin, Reduction, Start);
365 break;
366 case Intrinsic::vp_reduce_umax:
367 Reduction = Builder.CreateIntMaxReduce(RedOp, /*IsSigned*/ false);
368 Reduction =
369 Builder.CreateBinaryIntrinsic(Intrinsic::umax, Reduction, Start);
370 break;
371 case Intrinsic::vp_reduce_umin:
372 Reduction = Builder.CreateIntMinReduce(RedOp, /*IsSigned*/ false);
373 Reduction =
374 Builder.CreateBinaryIntrinsic(Intrinsic::umin, Reduction, Start);
375 break;
376 case Intrinsic::vp_reduce_fmax:
377 Reduction = Builder.CreateFPMaxReduce(RedOp);
378 transferDecorations(*Reduction, VPI);
379 Reduction =
380 Builder.CreateBinaryIntrinsic(Intrinsic::maxnum, Reduction, Start);
381 break;
382 case Intrinsic::vp_reduce_fmin:
383 Reduction = Builder.CreateFPMinReduce(RedOp);
384 transferDecorations(*Reduction, VPI);
385 Reduction =
386 Builder.CreateBinaryIntrinsic(Intrinsic::minnum, Reduction, Start);
387 break;
388 case Intrinsic::vp_reduce_fadd:
389 Reduction = Builder.CreateFAddReduce(Start, RedOp);
390 break;
391 case Intrinsic::vp_reduce_fmul:
392 Reduction = Builder.CreateFMulReduce(Start, RedOp);
393 break;
394 }
395
396 replaceOperation(*Reduction, VPI);
397 return Reduction;
398 }
399
400 Value *
expandPredicationInMemoryIntrinsic(IRBuilder<> & Builder,VPIntrinsic & VPI)401 CachingVPExpander::expandPredicationInMemoryIntrinsic(IRBuilder<> &Builder,
402 VPIntrinsic &VPI) {
403 assert(VPI.canIgnoreVectorLengthParam());
404
405 const auto &DL = F.getParent()->getDataLayout();
406
407 Value *MaskParam = VPI.getMaskParam();
408 Value *PtrParam = VPI.getMemoryPointerParam();
409 Value *DataParam = VPI.getMemoryDataParam();
410 bool IsUnmasked = isAllTrueMask(MaskParam);
411
412 MaybeAlign AlignOpt = VPI.getPointerAlignment();
413
414 Value *NewMemoryInst = nullptr;
415 switch (VPI.getIntrinsicID()) {
416 default:
417 llvm_unreachable("Not a VP memory intrinsic");
418 case Intrinsic::vp_store:
419 if (IsUnmasked) {
420 StoreInst *NewStore =
421 Builder.CreateStore(DataParam, PtrParam, /*IsVolatile*/ false);
422 if (AlignOpt.has_value())
423 NewStore->setAlignment(AlignOpt.value());
424 NewMemoryInst = NewStore;
425 } else
426 NewMemoryInst = Builder.CreateMaskedStore(
427 DataParam, PtrParam, AlignOpt.valueOrOne(), MaskParam);
428
429 break;
430 case Intrinsic::vp_load:
431 if (IsUnmasked) {
432 LoadInst *NewLoad =
433 Builder.CreateLoad(VPI.getType(), PtrParam, /*IsVolatile*/ false);
434 if (AlignOpt.has_value())
435 NewLoad->setAlignment(AlignOpt.value());
436 NewMemoryInst = NewLoad;
437 } else
438 NewMemoryInst = Builder.CreateMaskedLoad(
439 VPI.getType(), PtrParam, AlignOpt.valueOrOne(), MaskParam);
440
441 break;
442 case Intrinsic::vp_scatter: {
443 auto *ElementType =
444 cast<VectorType>(DataParam->getType())->getElementType();
445 NewMemoryInst = Builder.CreateMaskedScatter(
446 DataParam, PtrParam,
447 AlignOpt.value_or(DL.getPrefTypeAlign(ElementType)), MaskParam);
448 break;
449 }
450 case Intrinsic::vp_gather: {
451 auto *ElementType = cast<VectorType>(VPI.getType())->getElementType();
452 NewMemoryInst = Builder.CreateMaskedGather(
453 VPI.getType(), PtrParam,
454 AlignOpt.value_or(DL.getPrefTypeAlign(ElementType)), MaskParam, nullptr,
455 VPI.getName());
456 break;
457 }
458 }
459
460 assert(NewMemoryInst);
461 replaceOperation(*NewMemoryInst, VPI);
462 return NewMemoryInst;
463 }
464
discardEVLParameter(VPIntrinsic & VPI)465 void CachingVPExpander::discardEVLParameter(VPIntrinsic &VPI) {
466 LLVM_DEBUG(dbgs() << "Discard EVL parameter in " << VPI << "\n");
467
468 if (VPI.canIgnoreVectorLengthParam())
469 return;
470
471 Value *EVLParam = VPI.getVectorLengthParam();
472 if (!EVLParam)
473 return;
474
475 ElementCount StaticElemCount = VPI.getStaticVectorLength();
476 Value *MaxEVL = nullptr;
477 Type *Int32Ty = Type::getInt32Ty(VPI.getContext());
478 if (StaticElemCount.isScalable()) {
479 // TODO add caching
480 auto *M = VPI.getModule();
481 Function *VScaleFunc =
482 Intrinsic::getDeclaration(M, Intrinsic::vscale, Int32Ty);
483 IRBuilder<> Builder(VPI.getParent(), VPI.getIterator());
484 Value *FactorConst = Builder.getInt32(StaticElemCount.getKnownMinValue());
485 Value *VScale = Builder.CreateCall(VScaleFunc, {}, "vscale");
486 MaxEVL = Builder.CreateMul(VScale, FactorConst, "scalable_size",
487 /*NUW*/ true, /*NSW*/ false);
488 } else {
489 MaxEVL = ConstantInt::get(Int32Ty, StaticElemCount.getFixedValue(), false);
490 }
491 VPI.setVectorLengthParam(MaxEVL);
492 }
493
foldEVLIntoMask(VPIntrinsic & VPI)494 Value *CachingVPExpander::foldEVLIntoMask(VPIntrinsic &VPI) {
495 LLVM_DEBUG(dbgs() << "Folding vlen for " << VPI << '\n');
496
497 IRBuilder<> Builder(&VPI);
498
499 // Ineffective %evl parameter and so nothing to do here.
500 if (VPI.canIgnoreVectorLengthParam())
501 return &VPI;
502
503 // Only VP intrinsics can have an %evl parameter.
504 Value *OldMaskParam = VPI.getMaskParam();
505 Value *OldEVLParam = VPI.getVectorLengthParam();
506 assert(OldMaskParam && "no mask param to fold the vl param into");
507 assert(OldEVLParam && "no EVL param to fold away");
508
509 LLVM_DEBUG(dbgs() << "OLD evl: " << *OldEVLParam << '\n');
510 LLVM_DEBUG(dbgs() << "OLD mask: " << *OldMaskParam << '\n');
511
512 // Convert the %evl predication into vector mask predication.
513 ElementCount ElemCount = VPI.getStaticVectorLength();
514 Value *VLMask = convertEVLToMask(Builder, OldEVLParam, ElemCount);
515 Value *NewMaskParam = Builder.CreateAnd(VLMask, OldMaskParam);
516 VPI.setMaskParam(NewMaskParam);
517
518 // Drop the %evl parameter.
519 discardEVLParameter(VPI);
520 assert(VPI.canIgnoreVectorLengthParam() &&
521 "transformation did not render the evl param ineffective!");
522
523 // Reassess the modified instruction.
524 return &VPI;
525 }
526
expandPredication(VPIntrinsic & VPI)527 Value *CachingVPExpander::expandPredication(VPIntrinsic &VPI) {
528 LLVM_DEBUG(dbgs() << "Lowering to unpredicated op: " << VPI << '\n');
529
530 IRBuilder<> Builder(&VPI);
531
532 // Try lowering to a LLVM instruction first.
533 auto OC = VPI.getFunctionalOpcode();
534
535 if (OC && Instruction::isBinaryOp(*OC))
536 return expandPredicationInBinaryOperator(Builder, VPI);
537
538 if (auto *VPRI = dyn_cast<VPReductionIntrinsic>(&VPI))
539 return expandPredicationInReduction(Builder, *VPRI);
540
541 switch (VPI.getIntrinsicID()) {
542 default:
543 break;
544 case Intrinsic::vp_load:
545 case Intrinsic::vp_store:
546 case Intrinsic::vp_gather:
547 case Intrinsic::vp_scatter:
548 return expandPredicationInMemoryIntrinsic(Builder, VPI);
549 }
550
551 return &VPI;
552 }
553
554 //// } CachingVPExpander
555
556 struct TransformJob {
557 VPIntrinsic *PI;
558 TargetTransformInfo::VPLegalization Strategy;
TransformJob__anona96194c90111::TransformJob559 TransformJob(VPIntrinsic *PI, TargetTransformInfo::VPLegalization InitStrat)
560 : PI(PI), Strategy(InitStrat) {}
561
isDone__anona96194c90111::TransformJob562 bool isDone() const { return Strategy.shouldDoNothing(); }
563 };
564
sanitizeStrategy(VPIntrinsic & VPI,VPLegalization & LegalizeStrat)565 void sanitizeStrategy(VPIntrinsic &VPI, VPLegalization &LegalizeStrat) {
566 // Operations with speculatable lanes do not strictly need predication.
567 if (maySpeculateLanes(VPI)) {
568 // Converting a speculatable VP intrinsic means dropping %mask and %evl.
569 // No need to expand %evl into the %mask only to ignore that code.
570 if (LegalizeStrat.OpStrategy == VPLegalization::Convert)
571 LegalizeStrat.EVLParamStrategy = VPLegalization::Discard;
572 return;
573 }
574
575 // We have to preserve the predicating effect of %evl for this
576 // non-speculatable VP intrinsic.
577 // 1) Never discard %evl.
578 // 2) If this VP intrinsic will be expanded to non-VP code, make sure that
579 // %evl gets folded into %mask.
580 if ((LegalizeStrat.EVLParamStrategy == VPLegalization::Discard) ||
581 (LegalizeStrat.OpStrategy == VPLegalization::Convert)) {
582 LegalizeStrat.EVLParamStrategy = VPLegalization::Convert;
583 }
584 }
585
586 VPLegalization
getVPLegalizationStrategy(const VPIntrinsic & VPI) const587 CachingVPExpander::getVPLegalizationStrategy(const VPIntrinsic &VPI) const {
588 auto VPStrat = TTI.getVPLegalizationStrategy(VPI);
589 if (LLVM_LIKELY(!UsingTTIOverrides)) {
590 // No overrides - we are in production.
591 return VPStrat;
592 }
593
594 // Overrides set - we are in testing, the following does not need to be
595 // efficient.
596 VPStrat.EVLParamStrategy = parseOverrideOption(EVLTransformOverride);
597 VPStrat.OpStrategy = parseOverrideOption(MaskTransformOverride);
598 return VPStrat;
599 }
600
601 /// \brief Expand llvm.vp.* intrinsics as requested by \p TTI.
expandVectorPredication()602 bool CachingVPExpander::expandVectorPredication() {
603 SmallVector<TransformJob, 16> Worklist;
604
605 // Collect all VPIntrinsics that need expansion and determine their expansion
606 // strategy.
607 for (auto &I : instructions(F)) {
608 auto *VPI = dyn_cast<VPIntrinsic>(&I);
609 if (!VPI)
610 continue;
611 auto VPStrat = getVPLegalizationStrategy(*VPI);
612 sanitizeStrategy(*VPI, VPStrat);
613 if (!VPStrat.shouldDoNothing())
614 Worklist.emplace_back(VPI, VPStrat);
615 }
616 if (Worklist.empty())
617 return false;
618
619 // Transform all VPIntrinsics on the worklist.
620 LLVM_DEBUG(dbgs() << "\n:::: Transforming " << Worklist.size()
621 << " instructions ::::\n");
622 for (TransformJob Job : Worklist) {
623 // Transform the EVL parameter.
624 switch (Job.Strategy.EVLParamStrategy) {
625 case VPLegalization::Legal:
626 break;
627 case VPLegalization::Discard:
628 discardEVLParameter(*Job.PI);
629 break;
630 case VPLegalization::Convert:
631 if (foldEVLIntoMask(*Job.PI))
632 ++NumFoldedVL;
633 break;
634 }
635 Job.Strategy.EVLParamStrategy = VPLegalization::Legal;
636
637 // Replace with a non-predicated operation.
638 switch (Job.Strategy.OpStrategy) {
639 case VPLegalization::Legal:
640 break;
641 case VPLegalization::Discard:
642 llvm_unreachable("Invalid strategy for operators.");
643 case VPLegalization::Convert:
644 expandPredication(*Job.PI);
645 ++NumLoweredVPOps;
646 break;
647 }
648 Job.Strategy.OpStrategy = VPLegalization::Legal;
649
650 assert(Job.isDone() && "incomplete transformation");
651 }
652
653 return true;
654 }
655 class ExpandVectorPredication : public FunctionPass {
656 public:
657 static char ID;
ExpandVectorPredication()658 ExpandVectorPredication() : FunctionPass(ID) {
659 initializeExpandVectorPredicationPass(*PassRegistry::getPassRegistry());
660 }
661
runOnFunction(Function & F)662 bool runOnFunction(Function &F) override {
663 const auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
664 CachingVPExpander VPExpander(F, *TTI);
665 return VPExpander.expandVectorPredication();
666 }
667
getAnalysisUsage(AnalysisUsage & AU) const668 void getAnalysisUsage(AnalysisUsage &AU) const override {
669 AU.addRequired<TargetTransformInfoWrapperPass>();
670 AU.setPreservesCFG();
671 }
672 };
673 } // namespace
674
char ExpandVectorPredication::ID;
// Register the legacy pass under the command-line name "expandvp".
// NOTE(review): DominatorTreeWrapperPass is declared as a dependency here, but
// ExpandVectorPredication::getAnalysisUsage only requires
// TargetTransformInfoWrapperPass; the dominator-tree dependency looks stale --
// confirm before removing.
INITIALIZE_PASS_BEGIN(ExpandVectorPredication, "expandvp",
                      "Expand vector predication intrinsics", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ExpandVectorPredication, "expandvp",
                    "Expand vector predication intrinsics", false, false)
682
683 FunctionPass *llvm::createExpandVectorPredicationPass() {
684 return new ExpandVectorPredication();
685 }
686
687 PreservedAnalyses
run(Function & F,FunctionAnalysisManager & AM)688 ExpandVectorPredicationPass::run(Function &F, FunctionAnalysisManager &AM) {
689 const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
690 CachingVPExpander VPExpander(F, TTI);
691 if (!VPExpander.expandVectorPredication())
692 return PreservedAnalyses::all();
693 PreservedAnalyses PA;
694 PA.preserveSet<CFGAnalyses>();
695 return PA;
696 }
697