//===--- ExpandReductions.cpp - Expand experimental reduction intrinsics --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass implements IR expansion for reduction intrinsics, allowing targets
// to enable the intrinsics until just before codegen.
//
//===----------------------------------------------------------------------===//
13836b0f48SAmara Emerson
14836b0f48SAmara Emerson #include "llvm/CodeGen/ExpandReductions.h"
156bda14b3SChandler Carruth #include "llvm/Analysis/TargetTransformInfo.h"
16836b0f48SAmara Emerson #include "llvm/CodeGen/Passes.h"
17836b0f48SAmara Emerson #include "llvm/IR/IRBuilder.h"
18836b0f48SAmara Emerson #include "llvm/IR/InstIterator.h"
19836b0f48SAmara Emerson #include "llvm/IR/IntrinsicInst.h"
206bda14b3SChandler Carruth #include "llvm/IR/Intrinsics.h"
2105da2fe5SReid Kleckner #include "llvm/InitializePasses.h"
22836b0f48SAmara Emerson #include "llvm/Pass.h"
236bda14b3SChandler Carruth #include "llvm/Transforms/Utils/LoopUtils.h"
24836b0f48SAmara Emerson
25836b0f48SAmara Emerson using namespace llvm;
26836b0f48SAmara Emerson
27836b0f48SAmara Emerson namespace {
28836b0f48SAmara Emerson
getOpcode(Intrinsic::ID ID)29836b0f48SAmara Emerson unsigned getOpcode(Intrinsic::ID ID) {
30836b0f48SAmara Emerson switch (ID) {
31322d0afdSAmara Emerson case Intrinsic::vector_reduce_fadd:
32836b0f48SAmara Emerson return Instruction::FAdd;
33322d0afdSAmara Emerson case Intrinsic::vector_reduce_fmul:
34836b0f48SAmara Emerson return Instruction::FMul;
35322d0afdSAmara Emerson case Intrinsic::vector_reduce_add:
36836b0f48SAmara Emerson return Instruction::Add;
37322d0afdSAmara Emerson case Intrinsic::vector_reduce_mul:
38836b0f48SAmara Emerson return Instruction::Mul;
39322d0afdSAmara Emerson case Intrinsic::vector_reduce_and:
40836b0f48SAmara Emerson return Instruction::And;
41322d0afdSAmara Emerson case Intrinsic::vector_reduce_or:
42836b0f48SAmara Emerson return Instruction::Or;
43322d0afdSAmara Emerson case Intrinsic::vector_reduce_xor:
44836b0f48SAmara Emerson return Instruction::Xor;
45322d0afdSAmara Emerson case Intrinsic::vector_reduce_smax:
46322d0afdSAmara Emerson case Intrinsic::vector_reduce_smin:
47322d0afdSAmara Emerson case Intrinsic::vector_reduce_umax:
48322d0afdSAmara Emerson case Intrinsic::vector_reduce_umin:
49836b0f48SAmara Emerson return Instruction::ICmp;
50322d0afdSAmara Emerson case Intrinsic::vector_reduce_fmax:
51322d0afdSAmara Emerson case Intrinsic::vector_reduce_fmin:
52836b0f48SAmara Emerson return Instruction::FCmp;
53836b0f48SAmara Emerson default:
54836b0f48SAmara Emerson llvm_unreachable("Unexpected ID");
55836b0f48SAmara Emerson }
56836b0f48SAmara Emerson }
57836b0f48SAmara Emerson
getRK(Intrinsic::ID ID)58c74e8539SSanjay Patel RecurKind getRK(Intrinsic::ID ID) {
59836b0f48SAmara Emerson switch (ID) {
60322d0afdSAmara Emerson case Intrinsic::vector_reduce_smax:
61c74e8539SSanjay Patel return RecurKind::SMax;
62322d0afdSAmara Emerson case Intrinsic::vector_reduce_smin:
63c74e8539SSanjay Patel return RecurKind::SMin;
64322d0afdSAmara Emerson case Intrinsic::vector_reduce_umax:
65c74e8539SSanjay Patel return RecurKind::UMax;
66322d0afdSAmara Emerson case Intrinsic::vector_reduce_umin:
67c74e8539SSanjay Patel return RecurKind::UMin;
68322d0afdSAmara Emerson case Intrinsic::vector_reduce_fmax:
69c74e8539SSanjay Patel return RecurKind::FMax;
70322d0afdSAmara Emerson case Intrinsic::vector_reduce_fmin:
71c74e8539SSanjay Patel return RecurKind::FMin;
72836b0f48SAmara Emerson default:
73c74e8539SSanjay Patel return RecurKind::None;
74836b0f48SAmara Emerson }
75836b0f48SAmara Emerson }
76836b0f48SAmara Emerson
expandReductions(Function & F,const TargetTransformInfo * TTI)77836b0f48SAmara Emerson bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
78836b0f48SAmara Emerson bool Changed = false;
79836b0f48SAmara Emerson SmallVector<IntrinsicInst *, 4> Worklist;
8017bb2d7cSCraig Topper for (auto &I : instructions(F)) {
8117bb2d7cSCraig Topper if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
8217bb2d7cSCraig Topper switch (II->getIntrinsicID()) {
8317bb2d7cSCraig Topper default: break;
84322d0afdSAmara Emerson case Intrinsic::vector_reduce_fadd:
85322d0afdSAmara Emerson case Intrinsic::vector_reduce_fmul:
86322d0afdSAmara Emerson case Intrinsic::vector_reduce_add:
87322d0afdSAmara Emerson case Intrinsic::vector_reduce_mul:
88322d0afdSAmara Emerson case Intrinsic::vector_reduce_and:
89322d0afdSAmara Emerson case Intrinsic::vector_reduce_or:
90322d0afdSAmara Emerson case Intrinsic::vector_reduce_xor:
91322d0afdSAmara Emerson case Intrinsic::vector_reduce_smax:
92322d0afdSAmara Emerson case Intrinsic::vector_reduce_smin:
93322d0afdSAmara Emerson case Intrinsic::vector_reduce_umax:
94322d0afdSAmara Emerson case Intrinsic::vector_reduce_umin:
95322d0afdSAmara Emerson case Intrinsic::vector_reduce_fmax:
96322d0afdSAmara Emerson case Intrinsic::vector_reduce_fmin:
9717bb2d7cSCraig Topper if (TTI->shouldExpandReduction(II))
98836b0f48SAmara Emerson Worklist.push_back(II);
99836b0f48SAmara Emerson
10017bb2d7cSCraig Topper break;
10117bb2d7cSCraig Topper }
10217bb2d7cSCraig Topper }
10317bb2d7cSCraig Topper }
104cbeb563cSSander de Smalen
10517bb2d7cSCraig Topper for (auto *II : Worklist) {
106cbeb563cSSander de Smalen FastMathFlags FMF =
107cbeb563cSSander de Smalen isa<FPMathOperator>(II) ? II->getFastMathFlags() : FastMathFlags{};
108cbeb563cSSander de Smalen Intrinsic::ID ID = II->getIntrinsicID();
109c74e8539SSanjay Patel RecurKind RK = getRK(ID);
110cbeb563cSSander de Smalen
111cbeb563cSSander de Smalen Value *Rdx = nullptr;
112836b0f48SAmara Emerson IRBuilder<> Builder(II);
113cbeb563cSSander de Smalen IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
114cbeb563cSSander de Smalen Builder.setFastMathFlags(FMF);
115836b0f48SAmara Emerson switch (ID) {
11617bb2d7cSCraig Topper default: llvm_unreachable("Unexpected intrinsic!");
117322d0afdSAmara Emerson case Intrinsic::vector_reduce_fadd:
118322d0afdSAmara Emerson case Intrinsic::vector_reduce_fmul: {
119836b0f48SAmara Emerson // FMFs must be attached to the call, otherwise it's an ordered reduction
12023c2182cSSimon Pilgrim // and it can't be handled by generating a shuffle sequence.
121cbeb563cSSander de Smalen Value *Acc = II->getArgOperand(0);
122cbeb563cSSander de Smalen Value *Vec = II->getArgOperand(1);
123cbeb563cSSander de Smalen if (!FMF.allowReassoc())
124c74e8539SSanjay Patel Rdx = getOrderedReduction(Builder, Acc, Vec, getOpcode(ID), RK);
125cbeb563cSSander de Smalen else {
126ff5b9a7bSChristopher Tetreault if (!isPowerOf2_32(
127ff5b9a7bSChristopher Tetreault cast<FixedVectorType>(Vec->getType())->getNumElements()))
1284e9778e3Sshkzhang continue;
1294e9778e3Sshkzhang
130c74e8539SSanjay Patel Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
131cbeb563cSSander de Smalen Rdx = Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(ID),
132cbeb563cSSander de Smalen Acc, Rdx, "bin.rdx");
133cbeb563cSSander de Smalen }
13417bb2d7cSCraig Topper break;
13517bb2d7cSCraig Topper }
136322d0afdSAmara Emerson case Intrinsic::vector_reduce_add:
137322d0afdSAmara Emerson case Intrinsic::vector_reduce_mul:
138322d0afdSAmara Emerson case Intrinsic::vector_reduce_and:
139322d0afdSAmara Emerson case Intrinsic::vector_reduce_or:
140322d0afdSAmara Emerson case Intrinsic::vector_reduce_xor:
141322d0afdSAmara Emerson case Intrinsic::vector_reduce_smax:
142322d0afdSAmara Emerson case Intrinsic::vector_reduce_smin:
143322d0afdSAmara Emerson case Intrinsic::vector_reduce_umax:
144322d0afdSAmara Emerson case Intrinsic::vector_reduce_umin: {
145cbeb563cSSander de Smalen Value *Vec = II->getArgOperand(0);
146ff5b9a7bSChristopher Tetreault if (!isPowerOf2_32(
147ff5b9a7bSChristopher Tetreault cast<FixedVectorType>(Vec->getType())->getNumElements()))
1484e9778e3Sshkzhang continue;
1494e9778e3Sshkzhang
150c74e8539SSanjay Patel Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
15117bb2d7cSCraig Topper break;
15217bb2d7cSCraig Topper }
153322d0afdSAmara Emerson case Intrinsic::vector_reduce_fmax:
154322d0afdSAmara Emerson case Intrinsic::vector_reduce_fmin: {
155*056d31ddSSanjay Patel // We require "nnan" to use a shuffle reduction; "nsz" is implied by the
156*056d31ddSSanjay Patel // semantics of the reduction.
1573a8ea860SSanjay Patel Value *Vec = II->getArgOperand(0);
1583a8ea860SSanjay Patel if (!isPowerOf2_32(
1593a8ea860SSanjay Patel cast<FixedVectorType>(Vec->getType())->getNumElements()) ||
160*056d31ddSSanjay Patel !FMF.noNaNs())
1613a8ea860SSanjay Patel continue;
1623a8ea860SSanjay Patel
163c74e8539SSanjay Patel Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
1643a8ea860SSanjay Patel break;
1653a8ea860SSanjay Patel }
166836b0f48SAmara Emerson }
167836b0f48SAmara Emerson II->replaceAllUsesWith(Rdx);
168836b0f48SAmara Emerson II->eraseFromParent();
169836b0f48SAmara Emerson Changed = true;
170836b0f48SAmara Emerson }
171836b0f48SAmara Emerson return Changed;
172836b0f48SAmara Emerson }
173836b0f48SAmara Emerson
174836b0f48SAmara Emerson class ExpandReductions : public FunctionPass {
175836b0f48SAmara Emerson public:
176836b0f48SAmara Emerson static char ID;
ExpandReductions()177836b0f48SAmara Emerson ExpandReductions() : FunctionPass(ID) {
178836b0f48SAmara Emerson initializeExpandReductionsPass(*PassRegistry::getPassRegistry());
179836b0f48SAmara Emerson }
180836b0f48SAmara Emerson
runOnFunction(Function & F)181836b0f48SAmara Emerson bool runOnFunction(Function &F) override {
182836b0f48SAmara Emerson const auto *TTI =&getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
183836b0f48SAmara Emerson return expandReductions(F, TTI);
184836b0f48SAmara Emerson }
185836b0f48SAmara Emerson
getAnalysisUsage(AnalysisUsage & AU) const186836b0f48SAmara Emerson void getAnalysisUsage(AnalysisUsage &AU) const override {
187836b0f48SAmara Emerson AU.addRequired<TargetTransformInfoWrapperPass>();
188836b0f48SAmara Emerson AU.setPreservesCFG();
189836b0f48SAmara Emerson }
190836b0f48SAmara Emerson };
191836b0f48SAmara Emerson }
192836b0f48SAmara Emerson
193836b0f48SAmara Emerson char ExpandReductions::ID;
194836b0f48SAmara Emerson INITIALIZE_PASS_BEGIN(ExpandReductions, "expand-reductions",
195836b0f48SAmara Emerson "Expand reduction intrinsics", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)196836b0f48SAmara Emerson INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
197836b0f48SAmara Emerson INITIALIZE_PASS_END(ExpandReductions, "expand-reductions",
198836b0f48SAmara Emerson "Expand reduction intrinsics", false, false)
199836b0f48SAmara Emerson
200836b0f48SAmara Emerson FunctionPass *llvm::createExpandReductionsPass() {
201836b0f48SAmara Emerson return new ExpandReductions();
202836b0f48SAmara Emerson }
203836b0f48SAmara Emerson
run(Function & F,FunctionAnalysisManager & AM)204836b0f48SAmara Emerson PreservedAnalyses ExpandReductionsPass::run(Function &F,
205836b0f48SAmara Emerson FunctionAnalysisManager &AM) {
206836b0f48SAmara Emerson const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
207836b0f48SAmara Emerson if (!expandReductions(F, &TTI))
208836b0f48SAmara Emerson return PreservedAnalyses::all();
209836b0f48SAmara Emerson PreservedAnalyses PA;
210836b0f48SAmara Emerson PA.preserveSet<CFGAnalyses>();
211836b0f48SAmara Emerson return PA;
212836b0f48SAmara Emerson }
213