//===--- ExpandReductions.cpp - Expand experimental reduction intrinsics --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass implements IR expansion for reduction intrinsics, allowing targets
// to enable the intrinsics until just before codegen.
//
//===----------------------------------------------------------------------===//
13836b0f48SAmara Emerson 
14836b0f48SAmara Emerson #include "llvm/CodeGen/ExpandReductions.h"
156bda14b3SChandler Carruth #include "llvm/Analysis/TargetTransformInfo.h"
16836b0f48SAmara Emerson #include "llvm/CodeGen/Passes.h"
17836b0f48SAmara Emerson #include "llvm/IR/IRBuilder.h"
18836b0f48SAmara Emerson #include "llvm/IR/InstIterator.h"
19836b0f48SAmara Emerson #include "llvm/IR/IntrinsicInst.h"
206bda14b3SChandler Carruth #include "llvm/IR/Intrinsics.h"
2105da2fe5SReid Kleckner #include "llvm/InitializePasses.h"
22836b0f48SAmara Emerson #include "llvm/Pass.h"
236bda14b3SChandler Carruth #include "llvm/Transforms/Utils/LoopUtils.h"
24836b0f48SAmara Emerson 
25836b0f48SAmara Emerson using namespace llvm;
26836b0f48SAmara Emerson 
27836b0f48SAmara Emerson namespace {
28836b0f48SAmara Emerson 
getOpcode(Intrinsic::ID ID)29836b0f48SAmara Emerson unsigned getOpcode(Intrinsic::ID ID) {
30836b0f48SAmara Emerson   switch (ID) {
31322d0afdSAmara Emerson   case Intrinsic::vector_reduce_fadd:
32836b0f48SAmara Emerson     return Instruction::FAdd;
33322d0afdSAmara Emerson   case Intrinsic::vector_reduce_fmul:
34836b0f48SAmara Emerson     return Instruction::FMul;
35322d0afdSAmara Emerson   case Intrinsic::vector_reduce_add:
36836b0f48SAmara Emerson     return Instruction::Add;
37322d0afdSAmara Emerson   case Intrinsic::vector_reduce_mul:
38836b0f48SAmara Emerson     return Instruction::Mul;
39322d0afdSAmara Emerson   case Intrinsic::vector_reduce_and:
40836b0f48SAmara Emerson     return Instruction::And;
41322d0afdSAmara Emerson   case Intrinsic::vector_reduce_or:
42836b0f48SAmara Emerson     return Instruction::Or;
43322d0afdSAmara Emerson   case Intrinsic::vector_reduce_xor:
44836b0f48SAmara Emerson     return Instruction::Xor;
45322d0afdSAmara Emerson   case Intrinsic::vector_reduce_smax:
46322d0afdSAmara Emerson   case Intrinsic::vector_reduce_smin:
47322d0afdSAmara Emerson   case Intrinsic::vector_reduce_umax:
48322d0afdSAmara Emerson   case Intrinsic::vector_reduce_umin:
49836b0f48SAmara Emerson     return Instruction::ICmp;
50322d0afdSAmara Emerson   case Intrinsic::vector_reduce_fmax:
51322d0afdSAmara Emerson   case Intrinsic::vector_reduce_fmin:
52836b0f48SAmara Emerson     return Instruction::FCmp;
53836b0f48SAmara Emerson   default:
54836b0f48SAmara Emerson     llvm_unreachable("Unexpected ID");
55836b0f48SAmara Emerson   }
56836b0f48SAmara Emerson }
57836b0f48SAmara Emerson 
getRK(Intrinsic::ID ID)58c74e8539SSanjay Patel RecurKind getRK(Intrinsic::ID ID) {
59836b0f48SAmara Emerson   switch (ID) {
60322d0afdSAmara Emerson   case Intrinsic::vector_reduce_smax:
61c74e8539SSanjay Patel     return RecurKind::SMax;
62322d0afdSAmara Emerson   case Intrinsic::vector_reduce_smin:
63c74e8539SSanjay Patel     return RecurKind::SMin;
64322d0afdSAmara Emerson   case Intrinsic::vector_reduce_umax:
65c74e8539SSanjay Patel     return RecurKind::UMax;
66322d0afdSAmara Emerson   case Intrinsic::vector_reduce_umin:
67c74e8539SSanjay Patel     return RecurKind::UMin;
68322d0afdSAmara Emerson   case Intrinsic::vector_reduce_fmax:
69c74e8539SSanjay Patel     return RecurKind::FMax;
70322d0afdSAmara Emerson   case Intrinsic::vector_reduce_fmin:
71c74e8539SSanjay Patel     return RecurKind::FMin;
72836b0f48SAmara Emerson   default:
73c74e8539SSanjay Patel     return RecurKind::None;
74836b0f48SAmara Emerson   }
75836b0f48SAmara Emerson }
76836b0f48SAmara Emerson 
expandReductions(Function & F,const TargetTransformInfo * TTI)77836b0f48SAmara Emerson bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
78836b0f48SAmara Emerson   bool Changed = false;
79836b0f48SAmara Emerson   SmallVector<IntrinsicInst *, 4> Worklist;
8017bb2d7cSCraig Topper   for (auto &I : instructions(F)) {
8117bb2d7cSCraig Topper     if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
8217bb2d7cSCraig Topper       switch (II->getIntrinsicID()) {
8317bb2d7cSCraig Topper       default: break;
84322d0afdSAmara Emerson       case Intrinsic::vector_reduce_fadd:
85322d0afdSAmara Emerson       case Intrinsic::vector_reduce_fmul:
86322d0afdSAmara Emerson       case Intrinsic::vector_reduce_add:
87322d0afdSAmara Emerson       case Intrinsic::vector_reduce_mul:
88322d0afdSAmara Emerson       case Intrinsic::vector_reduce_and:
89322d0afdSAmara Emerson       case Intrinsic::vector_reduce_or:
90322d0afdSAmara Emerson       case Intrinsic::vector_reduce_xor:
91322d0afdSAmara Emerson       case Intrinsic::vector_reduce_smax:
92322d0afdSAmara Emerson       case Intrinsic::vector_reduce_smin:
93322d0afdSAmara Emerson       case Intrinsic::vector_reduce_umax:
94322d0afdSAmara Emerson       case Intrinsic::vector_reduce_umin:
95322d0afdSAmara Emerson       case Intrinsic::vector_reduce_fmax:
96322d0afdSAmara Emerson       case Intrinsic::vector_reduce_fmin:
9717bb2d7cSCraig Topper         if (TTI->shouldExpandReduction(II))
98836b0f48SAmara Emerson           Worklist.push_back(II);
99836b0f48SAmara Emerson 
10017bb2d7cSCraig Topper         break;
10117bb2d7cSCraig Topper       }
10217bb2d7cSCraig Topper     }
10317bb2d7cSCraig Topper   }
104cbeb563cSSander de Smalen 
10517bb2d7cSCraig Topper   for (auto *II : Worklist) {
106cbeb563cSSander de Smalen     FastMathFlags FMF =
107cbeb563cSSander de Smalen         isa<FPMathOperator>(II) ? II->getFastMathFlags() : FastMathFlags{};
108cbeb563cSSander de Smalen     Intrinsic::ID ID = II->getIntrinsicID();
109c74e8539SSanjay Patel     RecurKind RK = getRK(ID);
110cbeb563cSSander de Smalen 
111cbeb563cSSander de Smalen     Value *Rdx = nullptr;
112836b0f48SAmara Emerson     IRBuilder<> Builder(II);
113cbeb563cSSander de Smalen     IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
114cbeb563cSSander de Smalen     Builder.setFastMathFlags(FMF);
115836b0f48SAmara Emerson     switch (ID) {
11617bb2d7cSCraig Topper     default: llvm_unreachable("Unexpected intrinsic!");
117322d0afdSAmara Emerson     case Intrinsic::vector_reduce_fadd:
118322d0afdSAmara Emerson     case Intrinsic::vector_reduce_fmul: {
119836b0f48SAmara Emerson       // FMFs must be attached to the call, otherwise it's an ordered reduction
12023c2182cSSimon Pilgrim       // and it can't be handled by generating a shuffle sequence.
121cbeb563cSSander de Smalen       Value *Acc = II->getArgOperand(0);
122cbeb563cSSander de Smalen       Value *Vec = II->getArgOperand(1);
123cbeb563cSSander de Smalen       if (!FMF.allowReassoc())
124c74e8539SSanjay Patel         Rdx = getOrderedReduction(Builder, Acc, Vec, getOpcode(ID), RK);
125cbeb563cSSander de Smalen       else {
126ff5b9a7bSChristopher Tetreault         if (!isPowerOf2_32(
127ff5b9a7bSChristopher Tetreault                 cast<FixedVectorType>(Vec->getType())->getNumElements()))
1284e9778e3Sshkzhang           continue;
1294e9778e3Sshkzhang 
130c74e8539SSanjay Patel         Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
131cbeb563cSSander de Smalen         Rdx = Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(ID),
132cbeb563cSSander de Smalen                                   Acc, Rdx, "bin.rdx");
133cbeb563cSSander de Smalen       }
13417bb2d7cSCraig Topper       break;
13517bb2d7cSCraig Topper     }
136322d0afdSAmara Emerson     case Intrinsic::vector_reduce_add:
137322d0afdSAmara Emerson     case Intrinsic::vector_reduce_mul:
138322d0afdSAmara Emerson     case Intrinsic::vector_reduce_and:
139322d0afdSAmara Emerson     case Intrinsic::vector_reduce_or:
140322d0afdSAmara Emerson     case Intrinsic::vector_reduce_xor:
141322d0afdSAmara Emerson     case Intrinsic::vector_reduce_smax:
142322d0afdSAmara Emerson     case Intrinsic::vector_reduce_smin:
143322d0afdSAmara Emerson     case Intrinsic::vector_reduce_umax:
144322d0afdSAmara Emerson     case Intrinsic::vector_reduce_umin: {
145cbeb563cSSander de Smalen       Value *Vec = II->getArgOperand(0);
146ff5b9a7bSChristopher Tetreault       if (!isPowerOf2_32(
147ff5b9a7bSChristopher Tetreault               cast<FixedVectorType>(Vec->getType())->getNumElements()))
1484e9778e3Sshkzhang         continue;
1494e9778e3Sshkzhang 
150c74e8539SSanjay Patel       Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
15117bb2d7cSCraig Topper       break;
15217bb2d7cSCraig Topper     }
153322d0afdSAmara Emerson     case Intrinsic::vector_reduce_fmax:
154322d0afdSAmara Emerson     case Intrinsic::vector_reduce_fmin: {
155*056d31ddSSanjay Patel       // We require "nnan" to use a shuffle reduction; "nsz" is implied by the
156*056d31ddSSanjay Patel       // semantics of the reduction.
1573a8ea860SSanjay Patel       Value *Vec = II->getArgOperand(0);
1583a8ea860SSanjay Patel       if (!isPowerOf2_32(
1593a8ea860SSanjay Patel               cast<FixedVectorType>(Vec->getType())->getNumElements()) ||
160*056d31ddSSanjay Patel           !FMF.noNaNs())
1613a8ea860SSanjay Patel         continue;
1623a8ea860SSanjay Patel 
163c74e8539SSanjay Patel       Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
1643a8ea860SSanjay Patel       break;
1653a8ea860SSanjay Patel     }
166836b0f48SAmara Emerson     }
167836b0f48SAmara Emerson     II->replaceAllUsesWith(Rdx);
168836b0f48SAmara Emerson     II->eraseFromParent();
169836b0f48SAmara Emerson     Changed = true;
170836b0f48SAmara Emerson   }
171836b0f48SAmara Emerson   return Changed;
172836b0f48SAmara Emerson }
173836b0f48SAmara Emerson 
174836b0f48SAmara Emerson class ExpandReductions : public FunctionPass {
175836b0f48SAmara Emerson public:
176836b0f48SAmara Emerson   static char ID;
ExpandReductions()177836b0f48SAmara Emerson   ExpandReductions() : FunctionPass(ID) {
178836b0f48SAmara Emerson     initializeExpandReductionsPass(*PassRegistry::getPassRegistry());
179836b0f48SAmara Emerson   }
180836b0f48SAmara Emerson 
runOnFunction(Function & F)181836b0f48SAmara Emerson   bool runOnFunction(Function &F) override {
182836b0f48SAmara Emerson     const auto *TTI =&getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
183836b0f48SAmara Emerson     return expandReductions(F, TTI);
184836b0f48SAmara Emerson   }
185836b0f48SAmara Emerson 
getAnalysisUsage(AnalysisUsage & AU) const186836b0f48SAmara Emerson   void getAnalysisUsage(AnalysisUsage &AU) const override {
187836b0f48SAmara Emerson     AU.addRequired<TargetTransformInfoWrapperPass>();
188836b0f48SAmara Emerson     AU.setPreservesCFG();
189836b0f48SAmara Emerson   }
190836b0f48SAmara Emerson };
191836b0f48SAmara Emerson }
192836b0f48SAmara Emerson 
193836b0f48SAmara Emerson char ExpandReductions::ID;
194836b0f48SAmara Emerson INITIALIZE_PASS_BEGIN(ExpandReductions, "expand-reductions",
195836b0f48SAmara Emerson                       "Expand reduction intrinsics", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)196836b0f48SAmara Emerson INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
197836b0f48SAmara Emerson INITIALIZE_PASS_END(ExpandReductions, "expand-reductions",
198836b0f48SAmara Emerson                     "Expand reduction intrinsics", false, false)
199836b0f48SAmara Emerson 
200836b0f48SAmara Emerson FunctionPass *llvm::createExpandReductionsPass() {
201836b0f48SAmara Emerson   return new ExpandReductions();
202836b0f48SAmara Emerson }
203836b0f48SAmara Emerson 
run(Function & F,FunctionAnalysisManager & AM)204836b0f48SAmara Emerson PreservedAnalyses ExpandReductionsPass::run(Function &F,
205836b0f48SAmara Emerson                                             FunctionAnalysisManager &AM) {
206836b0f48SAmara Emerson   const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
207836b0f48SAmara Emerson   if (!expandReductions(F, &TTI))
208836b0f48SAmara Emerson     return PreservedAnalyses::all();
209836b0f48SAmara Emerson   PreservedAnalyses PA;
210836b0f48SAmara Emerson   PA.preserveSet<CFGAnalyses>();
211836b0f48SAmara Emerson   return PA;
212836b0f48SAmara Emerson }
213