1 //===- RISCVGatherScatterLowering.cpp - Gather/Scatter lowering -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass custom lowers llvm.gather and llvm.scatter instructions to
10 // RISCV intrinsics.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "RISCV.h"
15 #include "RISCVTargetMachine.h"
16 #include "llvm/Analysis/LoopInfo.h"
17 #include "llvm/Analysis/ValueTracking.h"
18 #include "llvm/Analysis/VectorUtils.h"
19 #include "llvm/CodeGen/TargetPassConfig.h"
20 #include "llvm/IR/GetElementPtrTypeIterator.h"
21 #include "llvm/IR/IRBuilder.h"
22 #include "llvm/IR/IntrinsicInst.h"
23 #include "llvm/IR/IntrinsicsRISCV.h"
24 #include "llvm/Transforms/Utils/Local.h"
25
26 using namespace llvm;
27
28 #define DEBUG_TYPE "riscv-gather-scatter-lowering"
29
30 namespace {
31
/// Pass that rewrites masked gathers/scatters whose vector index forms a
/// strided recurrence into RISC-V strided load/store intrinsics, scalarizing
/// the address computation in the process.
class RISCVGatherScatterLowering : public FunctionPass {
  const RISCVSubtarget *ST = nullptr;
  const RISCVTargetLowering *TLI = nullptr;
  LoopInfo *LI = nullptr;
  const DataLayout *DL = nullptr;

  // Vector recurrence PHIs we replaced with scalar equivalents; they may be
  // dead once rewriting finishes and are cleaned up at the end of
  // runOnFunction.
  SmallVector<WeakTrackingVH> MaybeDeadPHIs;

  // Cache of the BasePtr and Stride determined from this GEP. When a GEP is
  // used by multiple gathers/scatters, this allows us to reuse the scalar
  // instructions we created for the first gather/scatter for the others.
  DenseMap<GetElementPtrInst *, std::pair<Value *, Value *>> StridedAddrs;

public:
  static char ID; // Pass identification, replacement for typeid

  RISCVGatherScatterLowering() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    // Only instructions are rewritten; the CFG is never modified.
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<LoopInfoWrapperPass>();
  }

  StringRef getPassName() const override {
    return "RISCV gather/scatter lowering";
  }

private:
  /// Return true if the backend can handle a strided access with this element
  /// type and the given alignment operand (a ConstantInt).
  bool isLegalTypeAndAlignment(Type *DataType, Value *AlignOp);

  /// Try to rewrite the gather/scatter intrinsic \p II addressing through
  /// \p Ptr into a riscv_masked_strided_{load,store}. Returns true on success.
  bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr,
                                 Value *AlignOp);

  /// Compute (scalar base pointer, scalar byte stride) for \p GEP, or a pair
  /// of nulls if the GEP's vector index is not a recognizable recurrence.
  std::pair<Value *, Value *> determineBaseAndStride(GetElementPtrInst *GEP,
                                                     IRBuilder<> &Builder);

  /// Walk \p Index's use-def chain to a loop-header PHI and build the
  /// equivalent scalar recurrence, returning its Stride, PHI, and increment.
  bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,
                              PHINode *&BasePtr, BinaryOperator *&Inc,
                              IRBuilder<> &Builder);
};
75
76 } // end anonymous namespace
77
char RISCVGatherScatterLowering::ID = 0;

INITIALIZE_PASS(RISCVGatherScatterLowering, DEBUG_TYPE,
                "RISCV gather/scatter lowering pass", false, false)

// Factory used by the RISC-V target pass pipeline.
FunctionPass *llvm::createRISCVGatherScatterLoweringPass() {
  return new RISCVGatherScatterLowering();
}
86
isLegalTypeAndAlignment(Type * DataType,Value * AlignOp)87 bool RISCVGatherScatterLowering::isLegalTypeAndAlignment(Type *DataType,
88 Value *AlignOp) {
89 Type *ScalarType = DataType->getScalarType();
90 if (!TLI->isLegalElementTypeForRVV(ScalarType))
91 return false;
92
93 MaybeAlign MA = cast<ConstantInt>(AlignOp)->getMaybeAlignValue();
94 if (MA && MA->value() < DL->getTypeStoreSize(ScalarType).getFixedSize())
95 return false;
96
97 // FIXME: Let the backend type legalize by splitting/widening?
98 EVT DataVT = TLI->getValueType(*DL, DataType);
99 if (!TLI->isTypeLegal(DataVT))
100 return false;
101
102 return true;
103 }
104
105 // TODO: Should we consider the mask when looking for a stride?
matchStridedConstant(Constant * StartC)106 static std::pair<Value *, Value *> matchStridedConstant(Constant *StartC) {
107 unsigned NumElts = cast<FixedVectorType>(StartC->getType())->getNumElements();
108
109 // Check that the start value is a strided constant.
110 auto *StartVal =
111 dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement((unsigned)0));
112 if (!StartVal)
113 return std::make_pair(nullptr, nullptr);
114 APInt StrideVal(StartVal->getValue().getBitWidth(), 0);
115 ConstantInt *Prev = StartVal;
116 for (unsigned i = 1; i != NumElts; ++i) {
117 auto *C = dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement(i));
118 if (!C)
119 return std::make_pair(nullptr, nullptr);
120
121 APInt LocalStride = C->getValue() - Prev->getValue();
122 if (i == 1)
123 StrideVal = LocalStride;
124 else if (StrideVal != LocalStride)
125 return std::make_pair(nullptr, nullptr);
126
127 Prev = C;
128 }
129
130 Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);
131
132 return std::make_pair(StartVal, Stride);
133 }
134
matchStridedStart(Value * Start,IRBuilder<> & Builder)135 static std::pair<Value *, Value *> matchStridedStart(Value *Start,
136 IRBuilder<> &Builder) {
137 // Base case, start is a strided constant.
138 auto *StartC = dyn_cast<Constant>(Start);
139 if (StartC)
140 return matchStridedConstant(StartC);
141
142 // Not a constant, maybe it's a strided constant with a splat added to it.
143 auto *BO = dyn_cast<BinaryOperator>(Start);
144 if (!BO || BO->getOpcode() != Instruction::Add)
145 return std::make_pair(nullptr, nullptr);
146
147 // Look for an operand that is splatted.
148 unsigned OtherIndex = 1;
149 Value *Splat = getSplatValue(BO->getOperand(0));
150 if (!Splat) {
151 Splat = getSplatValue(BO->getOperand(1));
152 OtherIndex = 0;
153 }
154 if (!Splat)
155 return std::make_pair(nullptr, nullptr);
156
157 Value *Stride;
158 std::tie(Start, Stride) = matchStridedStart(BO->getOperand(OtherIndex),
159 Builder);
160 if (!Start)
161 return std::make_pair(nullptr, nullptr);
162
163 // Add the splat value to the start.
164 Builder.SetInsertPoint(BO);
165 Builder.SetCurrentDebugLocation(DebugLoc());
166 Start = Builder.CreateAdd(Start, Splat);
167 return std::make_pair(Start, Stride);
168 }
169
// Recursively walk up the use-def chain until we find a Phi with a strided
// start value. Build and update a scalar recurrence as we unwind the recursion.
// We also update the Stride as we unwind. Our goal is to move all of the
// arithmetic out of the loop.
bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
                                                        Value *&Stride,
                                                        PHINode *&BasePtr,
                                                        BinaryOperator *&Inc,
                                                        IRBuilder<> &Builder) {
  // Our base case is a Phi.
  if (auto *Phi = dyn_cast<PHINode>(Index)) {
    // A phi node we want to perform this function on should be from the
    // loop header.
    if (Phi->getParent() != L->getHeader())
      return false;

    // The phi must be a simple "phi + step" recurrence whose increment is
    // an add.
    Value *Step, *Start;
    if (!matchSimpleRecurrence(Phi, Inc, Start, Step) ||
        Inc->getOpcode() != Instruction::Add)
      return false;
    assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
    unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1;
    assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&
           "Expected one operand of phi to be Inc");

    // Only proceed if the step is loop invariant.
    if (!L->isLoopInvariant(Step))
      return false;

    // Step should be a splat.
    Step = getSplatValue(Step);
    if (!Step)
      return false;

    // The start value must itself be strided; this also emits any scalar
    // adds needed to fold splats into the start.
    std::tie(Start, Stride) = matchStridedStart(Start, Builder);
    if (!Start)
      return false;
    assert(Stride != nullptr);

    // Build scalar phi and increment mirroring the vector recurrence.
    BasePtr =
        PHINode::Create(Start->getType(), 2, Phi->getName() + ".scalar", Phi);
    Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar",
                                    Inc);
    BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock));
    BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock));

    // Note that this Phi might be eligible for removal.
    MaybeDeadPHIs.push_back(Phi);
    return true;
  }

  // Otherwise look for binary operator.
  auto *BO = dyn_cast<BinaryOperator>(Index);
  if (!BO)
    return false;

  // Only these opcodes can be folded into the scalar recurrence below.
  if (BO->getOpcode() != Instruction::Add &&
      BO->getOpcode() != Instruction::Or &&
      BO->getOpcode() != Instruction::Mul &&
      BO->getOpcode() != Instruction::Shl)
    return false;

  // Only support shift by constant.
  if (BO->getOpcode() == Instruction::Shl && !isa<Constant>(BO->getOperand(1)))
    return false;

  // We need to be able to treat Or as Add.
  if (BO->getOpcode() == Instruction::Or &&
      !haveNoCommonBitsSet(BO->getOperand(0), BO->getOperand(1), *DL))
    return false;

  // We should have one operand in the loop and one splat.
  Value *OtherOp;
  if (isa<Instruction>(BO->getOperand(0)) &&
      L->contains(cast<Instruction>(BO->getOperand(0)))) {
    Index = cast<Instruction>(BO->getOperand(0));
    OtherOp = BO->getOperand(1);
  } else if (isa<Instruction>(BO->getOperand(1)) &&
             L->contains(cast<Instruction>(BO->getOperand(1)))) {
    Index = cast<Instruction>(BO->getOperand(1));
    OtherOp = BO->getOperand(0);
  } else {
    return false;
  }

  // Make sure other op is loop invariant.
  if (!L->isLoopInvariant(OtherOp))
    return false;

  // Make sure we have a splat.
  Value *SplatOp = getSplatValue(OtherOp);
  if (!SplatOp)
    return false;

  // Recurse up the use-def chain. On success this fills in Stride, BasePtr
  // and Inc for the scalar recurrence we now adjust for this operator.
  if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))
    return false;

  // Locate the Step and Start values from the recurrence.
  unsigned StepIndex = Inc->getOperand(0) == BasePtr ? 1 : 0;
  unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0;
  Value *Step = Inc->getOperand(StepIndex);
  Value *Start = BasePtr->getOperand(StartBlock);

  // We need to adjust the start value in the preheader.
  Builder.SetInsertPoint(
      BasePtr->getIncomingBlock(StartBlock)->getTerminator());
  Builder.SetCurrentDebugLocation(DebugLoc());

  switch (BO->getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode!");
  case Instruction::Add:
  case Instruction::Or: {
    // An add only affects the start value. It's ok to do this for Or because
    // we already checked that there are no common set bits.

    // If the start value is Zero, just take the SplatOp.
    if (isa<ConstantInt>(Start) && cast<ConstantInt>(Start)->isZero())
      Start = SplatOp;
    else
      Start = Builder.CreateAdd(Start, SplatOp, "start");
    BasePtr->setIncomingValue(StartBlock, Start);
    break;
  }
  case Instruction::Mul: {
    // A multiply scales start, step, and stride alike.
    // If the start is zero we don't need to multiply.
    if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero())
      Start = Builder.CreateMul(Start, SplatOp, "start");

    Step = Builder.CreateMul(Step, SplatOp, "step");

    // If the Stride is 1 just take the SplatOp.
    if (isa<ConstantInt>(Stride) && cast<ConstantInt>(Stride)->isOne())
      Stride = SplatOp;
    else
      Stride = Builder.CreateMul(Stride, SplatOp, "stride");
    Inc->setOperand(StepIndex, Step);
    BasePtr->setIncomingValue(StartBlock, Start);
    break;
  }
  case Instruction::Shl: {
    // A shift scales start, step, and stride alike (shift amount is a
    // constant, checked above).
    // If the start is zero we don't need to shift.
    if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero())
      Start = Builder.CreateShl(Start, SplatOp, "start");
    Step = Builder.CreateShl(Step, SplatOp, "step");
    Stride = Builder.CreateShl(Stride, SplatOp, "stride");
    Inc->setOperand(StepIndex, Step);
    BasePtr->setIncomingValue(StartBlock, Start);
    break;
  }
  }

  return true;
}
326
// Determine a scalar (base pointer, byte stride) pair for \p GEP, building
// the scalar recurrence and scalar GEP as a side effect. Returns nulls if the
// GEP cannot be handled. Results are cached in StridedAddrs so a GEP shared
// by several gathers/scatters is only scalarized once.
std::pair<Value *, Value *>
RISCVGatherScatterLowering::determineBaseAndStride(GetElementPtrInst *GEP,
                                                   IRBuilder<> &Builder) {

  auto I = StridedAddrs.find(GEP);
  if (I != StridedAddrs.end())
    return I->second;

  SmallVector<Value *, 2> Ops(GEP->operands());

  // Base pointer needs to be a scalar.
  if (Ops[0]->getType()->isVectorTy())
    return std::make_pair(nullptr, nullptr);

  // Make sure we're in a loop and that has a pre-header and a single latch.
  Loop *L = LI->getLoopFor(GEP->getParent());
  if (!L || !L->getLoopPreheader() || !L->getLoopLatch())
    return std::make_pair(nullptr, nullptr);

  Optional<unsigned> VecOperand;
  unsigned TypeScale = 0;

  // Look for a vector operand and scale. Exactly one GEP index may be a
  // vector; its indexed type's alloc size becomes the byte scale.
  gep_type_iterator GTI = gep_type_begin(GEP);
  for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {
    if (!Ops[i]->getType()->isVectorTy())
      continue;

    if (VecOperand)
      return std::make_pair(nullptr, nullptr);

    VecOperand = i;

    TypeSize TS = DL->getTypeAllocSize(GTI.getIndexedType());
    if (TS.isScalable())
      return std::make_pair(nullptr, nullptr);

    TypeScale = TS.getFixedSize();
  }

  // We need to find a vector index to simplify.
  if (!VecOperand)
    return std::make_pair(nullptr, nullptr);

  // We can't extract the stride if the arithmetic is done at a different size
  // than the pointer type. Adding the stride later may not wrap correctly.
  // Technically we could handle wider indices, but I don't expect that in
  // practice.
  Value *VecIndex = Ops[*VecOperand];
  Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType());
  if (VecIndex->getType() != VecIntPtrTy)
    return std::make_pair(nullptr, nullptr);

  Value *Stride;
  BinaryOperator *Inc;
  PHINode *BasePhi;
  if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
    return std::make_pair(nullptr, nullptr);

  assert(BasePhi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
  unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1;
  assert(BasePhi->getIncomingValue(IncrementingBlock) == Inc &&
         "Expected one operand of phi to be Inc");

  Builder.SetInsertPoint(GEP);

  // Replace the vector index with the scalar phi and build a scalar GEP.
  Ops[*VecOperand] = BasePhi;
  Type *SourceTy = GEP->getSourceElementType();
  Value *BasePtr =
      Builder.CreateGEP(SourceTy, Ops[0], makeArrayRef(Ops).drop_front());

  // Final adjustments to stride should go in the start block.
  Builder.SetInsertPoint(
      BasePhi->getIncomingBlock(1 - IncrementingBlock)->getTerminator());

  // Convert stride to pointer size if needed.
  Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
  assert(Stride->getType() == IntPtrTy && "Unexpected type");

  // Scale the stride by the size of the indexed type.
  if (TypeScale != 1)
    Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));

  auto P = std::make_pair(BasePtr, Stride);
  StridedAddrs[GEP] = P;
  return P;
}
415
// Rewrite the masked gather/scatter \p II into a riscv_masked_strided_load or
// riscv_masked_strided_store when its pointer operand is a GEP whose vector
// index is a recognizable strided recurrence. Returns true if a rewrite
// happened.
bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
                                                           Type *DataType,
                                                           Value *Ptr,
                                                           Value *AlignOp) {
  // Make sure the operation will be supported by the backend.
  if (!isLegalTypeAndAlignment(DataType, AlignOp))
    return false;

  // Pointer should be a GEP.
  auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  if (!GEP)
    return false;

  IRBuilder<> Builder(GEP);

  Value *BasePtr, *Stride;
  std::tie(BasePtr, Stride) = determineBaseAndStride(GEP, Builder);
  if (!BasePtr)
    return false;
  assert(Stride != nullptr);

  Builder.SetInsertPoint(II);

  // Strided intrinsic operand order is (passthru-or-data, base, stride, mask).
  // For a gather the passthru is gather arg 3 and the mask arg 2; for a
  // scatter the stored data is scatter arg 0 and the mask arg 3.
  CallInst *Call;
  if (II->getIntrinsicID() == Intrinsic::masked_gather)
    Call = Builder.CreateIntrinsic(
        Intrinsic::riscv_masked_strided_load,
        {DataType, BasePtr->getType(), Stride->getType()},
        {II->getArgOperand(3), BasePtr, Stride, II->getArgOperand(2)});
  else
    Call = Builder.CreateIntrinsic(
        Intrinsic::riscv_masked_strided_store,
        {DataType, BasePtr->getType(), Stride->getType()},
        {II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(3)});

  Call->takeName(II);
  II->replaceAllUsesWith(Call);
  II->eraseFromParent();

  // The original vector GEP may now be dead; clean it (and any feeding
  // instructions) up.
  if (GEP->use_empty())
    RecursivelyDeleteTriviallyDeadInstructions(GEP);

  return true;
}
460
runOnFunction(Function & F)461 bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
462 if (skipFunction(F))
463 return false;
464
465 auto &TPC = getAnalysis<TargetPassConfig>();
466 auto &TM = TPC.getTM<RISCVTargetMachine>();
467 ST = &TM.getSubtarget<RISCVSubtarget>(F);
468 if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors())
469 return false;
470
471 TLI = ST->getTargetLowering();
472 DL = &F.getParent()->getDataLayout();
473 LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
474
475 StridedAddrs.clear();
476
477 SmallVector<IntrinsicInst *, 4> Gathers;
478 SmallVector<IntrinsicInst *, 4> Scatters;
479
480 bool Changed = false;
481
482 for (BasicBlock &BB : F) {
483 for (Instruction &I : BB) {
484 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
485 if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
486 isa<FixedVectorType>(II->getType())) {
487 Gathers.push_back(II);
488 } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
489 isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
490 Scatters.push_back(II);
491 }
492 }
493 }
494
495 // Rewrite gather/scatter to form strided load/store if possible.
496 for (auto *II : Gathers)
497 Changed |= tryCreateStridedLoadStore(
498 II, II->getType(), II->getArgOperand(0), II->getArgOperand(1));
499 for (auto *II : Scatters)
500 Changed |=
501 tryCreateStridedLoadStore(II, II->getArgOperand(0)->getType(),
502 II->getArgOperand(1), II->getArgOperand(2));
503
504 // Remove any dead phis.
505 while (!MaybeDeadPHIs.empty()) {
506 if (auto *Phi = dyn_cast_or_null<PHINode>(MaybeDeadPHIs.pop_back_val()))
507 RecursivelyDeleteDeadPHINode(Phi);
508 }
509
510 return Changed;
511 }
512