1bdd1243dSDimitry Andric //===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2bdd1243dSDimitry Andric //
3bdd1243dSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4bdd1243dSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5bdd1243dSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6bdd1243dSDimitry Andric //
7bdd1243dSDimitry Andric //===----------------------------------------------------------------------===//
8bdd1243dSDimitry Andric //
9bdd1243dSDimitry Andric // Identification:
10bdd1243dSDimitry Andric // This step is responsible for finding the patterns that can be lowered to
11bdd1243dSDimitry Andric // complex instructions, and building a graph to represent the complex
12bdd1243dSDimitry Andric // structures. Starting from the "Converging Shuffle" (a shuffle that
13bdd1243dSDimitry Andric // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14bdd1243dSDimitry Andric // operands are evaluated and identified as "Composite Nodes" (collections of
15bdd1243dSDimitry Andric // instructions that can potentially be lowered to a single complex
16bdd1243dSDimitry Andric // instruction). This is performed by checking the real and imaginary components
17bdd1243dSDimitry Andric // and tracking the data flow for each component while following the operand
// pairs. Each node is expected to be validated upon creation, and any
// validation error should halt traversal and prevent further graph
// construction.
// Instead of relying on shuffle operations, vector interleaving and
// deinterleaving can also be represented by the vector.interleave2 and
// vector.deinterleave2 intrinsics. Scalable vectors can only be represented by
// these intrinsics, whereas fixed-width vectors are recognized in both
// shufflevector and intrinsic form.
26bdd1243dSDimitry Andric //
27bdd1243dSDimitry Andric // Replacement:
28bdd1243dSDimitry Andric // This step traverses the graph built up by identification, delegating to the
29bdd1243dSDimitry Andric // target to validate and generate the correct intrinsics, and plumbs them
30bdd1243dSDimitry Andric // together connecting each end of the new intrinsics graph to the existing
31bdd1243dSDimitry Andric // use-def chain. This step is assumed to finish successfully, as all
32bdd1243dSDimitry Andric // information is expected to be correct by this point.
33bdd1243dSDimitry Andric //
34bdd1243dSDimitry Andric //
35bdd1243dSDimitry Andric // Internal data structure:
36bdd1243dSDimitry Andric // ComplexDeinterleavingGraph:
37bdd1243dSDimitry Andric // Keeps references to all the valid CompositeNodes formed as part of the
38bdd1243dSDimitry Andric // transformation, and every Instruction contained within said nodes. It also
39bdd1243dSDimitry Andric // holds onto a reference to the root Instruction, and the root node that should
40bdd1243dSDimitry Andric // replace it.
41bdd1243dSDimitry Andric //
42bdd1243dSDimitry Andric // ComplexDeinterleavingCompositeNode:
43bdd1243dSDimitry Andric // A CompositeNode represents a single transformation point; each node should
44bdd1243dSDimitry Andric // transform into a single complex instruction (ignoring vector splitting, which
45bdd1243dSDimitry Andric // would generate more instructions per node). They are identified in a
46bdd1243dSDimitry Andric // depth-first manner, traversing and identifying the operands of each
47bdd1243dSDimitry Andric // instruction in the order they appear in the IR.
48bdd1243dSDimitry Andric // Each node maintains a reference to its Real and Imaginary instructions,
49bdd1243dSDimitry Andric // as well as any additional instructions that make up the identified operation
50bdd1243dSDimitry Andric // (Internal instructions should only have uses within their containing node).
51bdd1243dSDimitry Andric // A Node also contains the rotation and operation type that it represents.
52bdd1243dSDimitry Andric // Operands contains pointers to other CompositeNodes, acting as the edges in
53bdd1243dSDimitry Andric // the graph. ReplacementValue is the transformed Value* that has been emitted
54bdd1243dSDimitry Andric // to the IR.
55bdd1243dSDimitry Andric //
56bdd1243dSDimitry Andric // Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
57bdd1243dSDimitry Andric // ReplacementValue fields of that Node are relevant, where the ReplacementValue
58bdd1243dSDimitry Andric // should be pre-populated.
59bdd1243dSDimitry Andric //
60bdd1243dSDimitry Andric //===----------------------------------------------------------------------===//
61bdd1243dSDimitry Andric
62bdd1243dSDimitry Andric #include "llvm/CodeGen/ComplexDeinterleavingPass.h"
63*c9157d92SDimitry Andric #include "llvm/ADT/MapVector.h"
64bdd1243dSDimitry Andric #include "llvm/ADT/Statistic.h"
65bdd1243dSDimitry Andric #include "llvm/Analysis/TargetLibraryInfo.h"
66bdd1243dSDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h"
67bdd1243dSDimitry Andric #include "llvm/CodeGen/TargetLowering.h"
68bdd1243dSDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h"
69bdd1243dSDimitry Andric #include "llvm/CodeGen/TargetSubtargetInfo.h"
70bdd1243dSDimitry Andric #include "llvm/IR/IRBuilder.h"
71fe013be4SDimitry Andric #include "llvm/IR/PatternMatch.h"
72bdd1243dSDimitry Andric #include "llvm/InitializePasses.h"
73bdd1243dSDimitry Andric #include "llvm/Target/TargetMachine.h"
74bdd1243dSDimitry Andric #include "llvm/Transforms/Utils/Local.h"
75bdd1243dSDimitry Andric #include <algorithm>
76bdd1243dSDimitry Andric
77bdd1243dSDimitry Andric using namespace llvm;
78bdd1243dSDimitry Andric using namespace PatternMatch;
79bdd1243dSDimitry Andric
80bdd1243dSDimitry Andric #define DEBUG_TYPE "complex-deinterleaving"
81bdd1243dSDimitry Andric
82bdd1243dSDimitry Andric STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
83bdd1243dSDimitry Andric
84bdd1243dSDimitry Andric static cl::opt<bool> ComplexDeinterleavingEnabled(
85bdd1243dSDimitry Andric "enable-complex-deinterleaving",
86bdd1243dSDimitry Andric cl::desc("Enable generation of complex instructions"), cl::init(true),
87bdd1243dSDimitry Andric cl::Hidden);
88bdd1243dSDimitry Andric
89bdd1243dSDimitry Andric /// Checks the given mask, and determines whether said mask is interleaving.
90bdd1243dSDimitry Andric ///
91bdd1243dSDimitry Andric /// To be interleaving, a mask must alternate between `i` and `i + (Length /
92bdd1243dSDimitry Andric /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
93bdd1243dSDimitry Andric /// 4x vector interleaving mask would be <0, 2, 1, 3>).
94bdd1243dSDimitry Andric static bool isInterleavingMask(ArrayRef<int> Mask);
95bdd1243dSDimitry Andric
96bdd1243dSDimitry Andric /// Checks the given mask, and determines whether said mask is deinterleaving.
97bdd1243dSDimitry Andric ///
98bdd1243dSDimitry Andric /// To be deinterleaving, a mask must increment in steps of 2, and either start
99bdd1243dSDimitry Andric /// with 0 or 1.
100bdd1243dSDimitry Andric /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
101bdd1243dSDimitry Andric /// <1, 3, 5, 7>).
102bdd1243dSDimitry Andric static bool isDeinterleavingMask(ArrayRef<int> Mask);
103bdd1243dSDimitry Andric
/// Returns true if \p V is itself a negation of some value. Both the integer
/// and the floating-point forms of negation are recognized.
106fe013be4SDimitry Andric static bool isNeg(Value *V);
107fe013be4SDimitry Andric
108fe013be4SDimitry Andric /// Returns the operand for negation operation.
109fe013be4SDimitry Andric static Value *getNegOperand(Value *V);
110fe013be4SDimitry Andric
111bdd1243dSDimitry Andric namespace {
112bdd1243dSDimitry Andric
113bdd1243dSDimitry Andric class ComplexDeinterleavingLegacyPass : public FunctionPass {
114bdd1243dSDimitry Andric public:
115bdd1243dSDimitry Andric static char ID;
116bdd1243dSDimitry Andric
ComplexDeinterleavingLegacyPass(const TargetMachine * TM=nullptr)117bdd1243dSDimitry Andric ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
118bdd1243dSDimitry Andric : FunctionPass(ID), TM(TM) {
119bdd1243dSDimitry Andric initializeComplexDeinterleavingLegacyPassPass(
120bdd1243dSDimitry Andric *PassRegistry::getPassRegistry());
121bdd1243dSDimitry Andric }
122bdd1243dSDimitry Andric
getPassName() const123bdd1243dSDimitry Andric StringRef getPassName() const override {
124bdd1243dSDimitry Andric return "Complex Deinterleaving Pass";
125bdd1243dSDimitry Andric }
126bdd1243dSDimitry Andric
127bdd1243dSDimitry Andric bool runOnFunction(Function &F) override;
getAnalysisUsage(AnalysisUsage & AU) const128bdd1243dSDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override {
129bdd1243dSDimitry Andric AU.addRequired<TargetLibraryInfoWrapperPass>();
130bdd1243dSDimitry Andric AU.setPreservesCFG();
131bdd1243dSDimitry Andric }
132bdd1243dSDimitry Andric
133bdd1243dSDimitry Andric private:
134bdd1243dSDimitry Andric const TargetMachine *TM;
135bdd1243dSDimitry Andric };
136bdd1243dSDimitry Andric
137bdd1243dSDimitry Andric class ComplexDeinterleavingGraph;
138bdd1243dSDimitry Andric struct ComplexDeinterleavingCompositeNode {
139bdd1243dSDimitry Andric
ComplexDeinterleavingCompositeNode__anon4b5acc0a0111::ComplexDeinterleavingCompositeNode140bdd1243dSDimitry Andric ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
141fe013be4SDimitry Andric Value *R, Value *I)
142bdd1243dSDimitry Andric : Operation(Op), Real(R), Imag(I) {}
143bdd1243dSDimitry Andric
144bdd1243dSDimitry Andric private:
145bdd1243dSDimitry Andric friend class ComplexDeinterleavingGraph;
146bdd1243dSDimitry Andric using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
147bdd1243dSDimitry Andric using RawNodePtr = ComplexDeinterleavingCompositeNode *;
148bdd1243dSDimitry Andric
149bdd1243dSDimitry Andric public:
150bdd1243dSDimitry Andric ComplexDeinterleavingOperation Operation;
151fe013be4SDimitry Andric Value *Real;
152fe013be4SDimitry Andric Value *Imag;
153bdd1243dSDimitry Andric
154fe013be4SDimitry Andric // This two members are required exclusively for generating
155fe013be4SDimitry Andric // ComplexDeinterleavingOperation::Symmetric operations.
156fe013be4SDimitry Andric unsigned Opcode;
157fe013be4SDimitry Andric std::optional<FastMathFlags> Flags;
158fe013be4SDimitry Andric
159fe013be4SDimitry Andric ComplexDeinterleavingRotation Rotation =
160fe013be4SDimitry Andric ComplexDeinterleavingRotation::Rotation_0;
161bdd1243dSDimitry Andric SmallVector<RawNodePtr> Operands;
162bdd1243dSDimitry Andric Value *ReplacementNode = nullptr;
163bdd1243dSDimitry Andric
addOperand__anon4b5acc0a0111::ComplexDeinterleavingCompositeNode164bdd1243dSDimitry Andric void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
165bdd1243dSDimitry Andric
dump__anon4b5acc0a0111::ComplexDeinterleavingCompositeNode166bdd1243dSDimitry Andric void dump() { dump(dbgs()); }
dump__anon4b5acc0a0111::ComplexDeinterleavingCompositeNode167bdd1243dSDimitry Andric void dump(raw_ostream &OS) {
168bdd1243dSDimitry Andric auto PrintValue = [&](Value *V) {
169bdd1243dSDimitry Andric if (V) {
170bdd1243dSDimitry Andric OS << "\"";
171bdd1243dSDimitry Andric V->print(OS, true);
172bdd1243dSDimitry Andric OS << "\"\n";
173bdd1243dSDimitry Andric } else
174bdd1243dSDimitry Andric OS << "nullptr\n";
175bdd1243dSDimitry Andric };
176bdd1243dSDimitry Andric auto PrintNodeRef = [&](RawNodePtr Ptr) {
177bdd1243dSDimitry Andric if (Ptr)
178bdd1243dSDimitry Andric OS << Ptr << "\n";
179bdd1243dSDimitry Andric else
180bdd1243dSDimitry Andric OS << "nullptr\n";
181bdd1243dSDimitry Andric };
182bdd1243dSDimitry Andric
183bdd1243dSDimitry Andric OS << "- CompositeNode: " << this << "\n";
184bdd1243dSDimitry Andric OS << " Real: ";
185bdd1243dSDimitry Andric PrintValue(Real);
186bdd1243dSDimitry Andric OS << " Imag: ";
187bdd1243dSDimitry Andric PrintValue(Imag);
188bdd1243dSDimitry Andric OS << " ReplacementNode: ";
189bdd1243dSDimitry Andric PrintValue(ReplacementNode);
190bdd1243dSDimitry Andric OS << " Operation: " << (int)Operation << "\n";
191bdd1243dSDimitry Andric OS << " Rotation: " << ((int)Rotation * 90) << "\n";
192bdd1243dSDimitry Andric OS << " Operands: \n";
193bdd1243dSDimitry Andric for (const auto &Op : Operands) {
194bdd1243dSDimitry Andric OS << " - ";
195bdd1243dSDimitry Andric PrintNodeRef(Op);
196bdd1243dSDimitry Andric }
197bdd1243dSDimitry Andric }
198bdd1243dSDimitry Andric };
199bdd1243dSDimitry Andric
200bdd1243dSDimitry Andric class ComplexDeinterleavingGraph {
201bdd1243dSDimitry Andric public:
202fe013be4SDimitry Andric struct Product {
203fe013be4SDimitry Andric Value *Multiplier;
204fe013be4SDimitry Andric Value *Multiplicand;
205fe013be4SDimitry Andric bool IsPositive;
206fe013be4SDimitry Andric };
207fe013be4SDimitry Andric
208fe013be4SDimitry Andric using Addend = std::pair<Value *, bool>;
209bdd1243dSDimitry Andric using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
210bdd1243dSDimitry Andric using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
211fe013be4SDimitry Andric
212fe013be4SDimitry Andric // Helper struct for holding info about potential partial multiplication
213fe013be4SDimitry Andric // candidates
214fe013be4SDimitry Andric struct PartialMulCandidate {
215fe013be4SDimitry Andric Value *Common;
216fe013be4SDimitry Andric NodePtr Node;
217fe013be4SDimitry Andric unsigned RealIdx;
218fe013be4SDimitry Andric unsigned ImagIdx;
219fe013be4SDimitry Andric bool IsNodeInverted;
220fe013be4SDimitry Andric };
221fe013be4SDimitry Andric
ComplexDeinterleavingGraph(const TargetLowering * TL,const TargetLibraryInfo * TLI)222fe013be4SDimitry Andric explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
223fe013be4SDimitry Andric const TargetLibraryInfo *TLI)
224fe013be4SDimitry Andric : TL(TL), TLI(TLI) {}
225bdd1243dSDimitry Andric
226bdd1243dSDimitry Andric private:
227fe013be4SDimitry Andric const TargetLowering *TL = nullptr;
228fe013be4SDimitry Andric const TargetLibraryInfo *TLI = nullptr;
229bdd1243dSDimitry Andric SmallVector<NodePtr> CompositeNodes;
230271697daSDimitry Andric DenseMap<std::pair<Value *, Value *>, NodePtr> CachedResult;
231fe013be4SDimitry Andric
232fe013be4SDimitry Andric SmallPtrSet<Instruction *, 16> FinalInstructions;
233fe013be4SDimitry Andric
234fe013be4SDimitry Andric /// Root instructions are instructions from which complex computation starts
235fe013be4SDimitry Andric std::map<Instruction *, NodePtr> RootToNode;
236fe013be4SDimitry Andric
237fe013be4SDimitry Andric /// Topologically sorted root instructions
238fe013be4SDimitry Andric SmallVector<Instruction *, 1> OrderedRoots;
239fe013be4SDimitry Andric
240fe013be4SDimitry Andric /// When examining a basic block for complex deinterleaving, if it is a simple
241fe013be4SDimitry Andric /// one-block loop, then the only incoming block is 'Incoming' and the
242fe013be4SDimitry Andric /// 'BackEdge' block is the block itself."
243fe013be4SDimitry Andric BasicBlock *BackEdge = nullptr;
244fe013be4SDimitry Andric BasicBlock *Incoming = nullptr;
245fe013be4SDimitry Andric
246fe013be4SDimitry Andric /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
247fe013be4SDimitry Andric /// %OutsideUser as it is shown in the IR:
248fe013be4SDimitry Andric ///
249fe013be4SDimitry Andric /// vector.body:
250fe013be4SDimitry Andric /// %PHInode = phi <vector type> [ zeroinitializer, %entry ],
251fe013be4SDimitry Andric /// [ %ReductionOp, %vector.body ]
252fe013be4SDimitry Andric /// ...
253fe013be4SDimitry Andric /// %ReductionOp = fadd i64 ...
254fe013be4SDimitry Andric /// ...
255fe013be4SDimitry Andric /// br i1 %condition, label %vector.body, %middle.block
256fe013be4SDimitry Andric ///
257fe013be4SDimitry Andric /// middle.block:
258fe013be4SDimitry Andric /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
259fe013be4SDimitry Andric ///
260fe013be4SDimitry Andric /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
261fe013be4SDimitry Andric /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
262*c9157d92SDimitry Andric MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
263fe013be4SDimitry Andric
264fe013be4SDimitry Andric /// In the process of detecting a reduction, we consider a pair of
265fe013be4SDimitry Andric /// %ReductionOP, which we refer to as real and imag (or vice versa), and
266fe013be4SDimitry Andric /// traverse the use-tree to detect complex operations. As this is a reduction
267fe013be4SDimitry Andric /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
268fe013be4SDimitry Andric /// to the %ReductionOPs that we suspect to be complex.
269fe013be4SDimitry Andric /// RealPHI and ImagPHI are used by the identifyPHINode method.
270fe013be4SDimitry Andric PHINode *RealPHI = nullptr;
271fe013be4SDimitry Andric PHINode *ImagPHI = nullptr;
272fe013be4SDimitry Andric
273fe013be4SDimitry Andric /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
274fe013be4SDimitry Andric /// detection.
275fe013be4SDimitry Andric bool PHIsFound = false;
276fe013be4SDimitry Andric
277fe013be4SDimitry Andric /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
278fe013be4SDimitry Andric /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
279fe013be4SDimitry Andric /// This mapping is populated during
280fe013be4SDimitry Andric /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
281fe013be4SDimitry Andric /// used in the ComplexDeinterleavingOperation::ReductionOperation node
282fe013be4SDimitry Andric /// replacement process.
283fe013be4SDimitry Andric std::map<PHINode *, PHINode *> OldToNewPHI;
284bdd1243dSDimitry Andric
prepareCompositeNode(ComplexDeinterleavingOperation Operation,Value * R,Value * I)285bdd1243dSDimitry Andric NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
286fe013be4SDimitry Andric Value *R, Value *I) {
287fe013be4SDimitry Andric assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
288fe013be4SDimitry Andric Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
289fe013be4SDimitry Andric (R && I)) &&
290fe013be4SDimitry Andric "Reduction related nodes must have Real and Imaginary parts");
291bdd1243dSDimitry Andric return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
292bdd1243dSDimitry Andric I);
293bdd1243dSDimitry Andric }
294bdd1243dSDimitry Andric
submitCompositeNode(NodePtr Node)295bdd1243dSDimitry Andric NodePtr submitCompositeNode(NodePtr Node) {
296bdd1243dSDimitry Andric CompositeNodes.push_back(Node);
297271697daSDimitry Andric if (Node->Real && Node->Imag)
298271697daSDimitry Andric CachedResult[{Node->Real, Node->Imag}] = Node;
299bdd1243dSDimitry Andric return Node;
300bdd1243dSDimitry Andric }
301bdd1243dSDimitry Andric
302bdd1243dSDimitry Andric /// Identifies a complex partial multiply pattern and its rotation, based on
303bdd1243dSDimitry Andric /// the following patterns
304bdd1243dSDimitry Andric ///
305bdd1243dSDimitry Andric /// 0: r: cr + ar * br
306bdd1243dSDimitry Andric /// i: ci + ar * bi
307bdd1243dSDimitry Andric /// 90: r: cr - ai * bi
308bdd1243dSDimitry Andric /// i: ci + ai * br
309bdd1243dSDimitry Andric /// 180: r: cr - ar * br
310bdd1243dSDimitry Andric /// i: ci - ar * bi
311bdd1243dSDimitry Andric /// 270: r: cr + ai * bi
312bdd1243dSDimitry Andric /// i: ci - ai * br
313bdd1243dSDimitry Andric NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);
314bdd1243dSDimitry Andric
315bdd1243dSDimitry Andric /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
316bdd1243dSDimitry Andric /// is partially known from identifyPartialMul, filling in the other half of
317bdd1243dSDimitry Andric /// the complex pair.
318fe013be4SDimitry Andric NodePtr
319fe013be4SDimitry Andric identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
320fe013be4SDimitry Andric std::pair<Value *, Value *> &CommonOperandI);
321bdd1243dSDimitry Andric
322bdd1243dSDimitry Andric /// Identifies a complex add pattern and its rotation, based on the following
323bdd1243dSDimitry Andric /// patterns.
324bdd1243dSDimitry Andric ///
325bdd1243dSDimitry Andric /// 90: r: ar - bi
326bdd1243dSDimitry Andric /// i: ai + br
327bdd1243dSDimitry Andric /// 270: r: ar + bi
328bdd1243dSDimitry Andric /// i: ai - br
329bdd1243dSDimitry Andric NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
330fe013be4SDimitry Andric NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);
331bdd1243dSDimitry Andric
332fe013be4SDimitry Andric NodePtr identifyNode(Value *R, Value *I);
333bdd1243dSDimitry Andric
334fe013be4SDimitry Andric /// Determine if a sum of complex numbers can be formed from \p RealAddends
335fe013be4SDimitry Andric /// and \p ImagAddens. If \p Accumulator is not null, add the result to it.
336fe013be4SDimitry Andric /// Return nullptr if it is not possible to construct a complex number.
337fe013be4SDimitry Andric /// \p Flags are needed to generate symmetric Add and Sub operations.
338fe013be4SDimitry Andric NodePtr identifyAdditions(std::list<Addend> &RealAddends,
339fe013be4SDimitry Andric std::list<Addend> &ImagAddends,
340fe013be4SDimitry Andric std::optional<FastMathFlags> Flags,
341fe013be4SDimitry Andric NodePtr Accumulator);
342fe013be4SDimitry Andric
343fe013be4SDimitry Andric /// Extract one addend that have both real and imaginary parts positive.
344fe013be4SDimitry Andric NodePtr extractPositiveAddend(std::list<Addend> &RealAddends,
345fe013be4SDimitry Andric std::list<Addend> &ImagAddends);
346fe013be4SDimitry Andric
347fe013be4SDimitry Andric /// Determine if sum of multiplications of complex numbers can be formed from
348fe013be4SDimitry Andric /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
349fe013be4SDimitry Andric /// to it. Return nullptr if it is not possible to construct a complex number.
350fe013be4SDimitry Andric NodePtr identifyMultiplications(std::vector<Product> &RealMuls,
351fe013be4SDimitry Andric std::vector<Product> &ImagMuls,
352fe013be4SDimitry Andric NodePtr Accumulator);
353fe013be4SDimitry Andric
354fe013be4SDimitry Andric /// Go through pairs of multiplication (one Real and one Imag) and find all
355fe013be4SDimitry Andric /// possible candidates for partial multiplication and put them into \p
356fe013be4SDimitry Andric /// Candidates. Returns true if all Product has pair with common operand
357fe013be4SDimitry Andric bool collectPartialMuls(const std::vector<Product> &RealMuls,
358fe013be4SDimitry Andric const std::vector<Product> &ImagMuls,
359fe013be4SDimitry Andric std::vector<PartialMulCandidate> &Candidates);
360fe013be4SDimitry Andric
361fe013be4SDimitry Andric /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
362fe013be4SDimitry Andric /// the order of complex computation operations may be significantly altered,
363fe013be4SDimitry Andric /// and the real and imaginary parts may not be executed in parallel. This
364fe013be4SDimitry Andric /// function takes this into consideration and employs a more general approach
365fe013be4SDimitry Andric /// to identify complex computations. Initially, it gathers all the addends
366fe013be4SDimitry Andric /// and multiplicands and then constructs a complex expression from them.
367fe013be4SDimitry Andric NodePtr identifyReassocNodes(Instruction *I, Instruction *J);
368fe013be4SDimitry Andric
369fe013be4SDimitry Andric NodePtr identifyRoot(Instruction *I);
370fe013be4SDimitry Andric
371fe013be4SDimitry Andric /// Identifies the Deinterleave operation applied to a vector containing
372fe013be4SDimitry Andric /// complex numbers. There are two ways to represent the Deinterleave
373fe013be4SDimitry Andric /// operation:
374fe013be4SDimitry Andric /// * Using two shufflevectors with even indices for /pReal instruction and
375fe013be4SDimitry Andric /// odd indices for /pImag instructions (only for fixed-width vectors)
376fe013be4SDimitry Andric /// * Using two extractvalue instructions applied to `vector.deinterleave2`
377fe013be4SDimitry Andric /// intrinsic (for both fixed and scalable vectors)
378fe013be4SDimitry Andric NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);
379fe013be4SDimitry Andric
380fe013be4SDimitry Andric /// identifying the operation that represents a complex number repeated in a
381fe013be4SDimitry Andric /// Splat vector. There are two possible types of splats: ConstantExpr with
382fe013be4SDimitry Andric /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
383fe013be4SDimitry Andric /// initialization mask with all values set to zero.
384fe013be4SDimitry Andric NodePtr identifySplat(Value *Real, Value *Imag);
385fe013be4SDimitry Andric
386fe013be4SDimitry Andric NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);
387fe013be4SDimitry Andric
388fe013be4SDimitry Andric /// Identifies SelectInsts in a loop that has reduction with predication masks
389fe013be4SDimitry Andric /// and/or predicated tail folding
390fe013be4SDimitry Andric NodePtr identifySelectNode(Instruction *Real, Instruction *Imag);
391fe013be4SDimitry Andric
392fe013be4SDimitry Andric Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);
393fe013be4SDimitry Andric
394fe013be4SDimitry Andric /// Complete IR modifications after producing new reduction operation:
395fe013be4SDimitry Andric /// * Populate the PHINode generated for
396fe013be4SDimitry Andric /// ComplexDeinterleavingOperation::ReductionPHI
397fe013be4SDimitry Andric /// * Deinterleave the final value outside of the loop and repurpose original
398fe013be4SDimitry Andric /// reduction users
399fe013be4SDimitry Andric void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);
400bdd1243dSDimitry Andric
401bdd1243dSDimitry Andric public:
dump()402bdd1243dSDimitry Andric void dump() { dump(dbgs()); }
dump(raw_ostream & OS)403bdd1243dSDimitry Andric void dump(raw_ostream &OS) {
404bdd1243dSDimitry Andric for (const auto &Node : CompositeNodes)
405bdd1243dSDimitry Andric Node->dump(OS);
406bdd1243dSDimitry Andric }
407bdd1243dSDimitry Andric
408bdd1243dSDimitry Andric /// Returns false if the deinterleaving operation should be cancelled for the
409bdd1243dSDimitry Andric /// current graph.
410bdd1243dSDimitry Andric bool identifyNodes(Instruction *RootI);
411bdd1243dSDimitry Andric
412fe013be4SDimitry Andric /// In case \pB is one-block loop, this function seeks potential reductions
413fe013be4SDimitry Andric /// and populates ReductionInfo. Returns true if any reductions were
414fe013be4SDimitry Andric /// identified.
415fe013be4SDimitry Andric bool collectPotentialReductions(BasicBlock *B);
416fe013be4SDimitry Andric
417fe013be4SDimitry Andric void identifyReductionNodes();
418fe013be4SDimitry Andric
419fe013be4SDimitry Andric /// Check that every instruction, from the roots to the leaves, has internal
420fe013be4SDimitry Andric /// uses.
421fe013be4SDimitry Andric bool checkNodes();
422fe013be4SDimitry Andric
423bdd1243dSDimitry Andric /// Perform the actual replacement of the underlying instruction graph.
424bdd1243dSDimitry Andric void replaceNodes();
425bdd1243dSDimitry Andric };
426bdd1243dSDimitry Andric
427bdd1243dSDimitry Andric class ComplexDeinterleaving {
428bdd1243dSDimitry Andric public:
ComplexDeinterleaving(const TargetLowering * tl,const TargetLibraryInfo * tli)429bdd1243dSDimitry Andric ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
430bdd1243dSDimitry Andric : TL(tl), TLI(tli) {}
431bdd1243dSDimitry Andric bool runOnFunction(Function &F);
432bdd1243dSDimitry Andric
433bdd1243dSDimitry Andric private:
434bdd1243dSDimitry Andric bool evaluateBasicBlock(BasicBlock *B);
435bdd1243dSDimitry Andric
436bdd1243dSDimitry Andric const TargetLowering *TL = nullptr;
437bdd1243dSDimitry Andric const TargetLibraryInfo *TLI = nullptr;
438bdd1243dSDimitry Andric };
439bdd1243dSDimitry Andric
440bdd1243dSDimitry Andric } // namespace
441bdd1243dSDimitry Andric
442bdd1243dSDimitry Andric char ComplexDeinterleavingLegacyPass::ID = 0;
443bdd1243dSDimitry Andric
444bdd1243dSDimitry Andric INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
445bdd1243dSDimitry Andric "Complex Deinterleaving", false, false)
446bdd1243dSDimitry Andric INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
447bdd1243dSDimitry Andric "Complex Deinterleaving", false, false)
448bdd1243dSDimitry Andric
run(Function & F,FunctionAnalysisManager & AM)449bdd1243dSDimitry Andric PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
450bdd1243dSDimitry Andric FunctionAnalysisManager &AM) {
451bdd1243dSDimitry Andric const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
452bdd1243dSDimitry Andric auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
453bdd1243dSDimitry Andric if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
454bdd1243dSDimitry Andric return PreservedAnalyses::all();
455bdd1243dSDimitry Andric
456bdd1243dSDimitry Andric PreservedAnalyses PA;
457bdd1243dSDimitry Andric PA.preserve<FunctionAnalysisManagerModuleProxy>();
458bdd1243dSDimitry Andric return PA;
459bdd1243dSDimitry Andric }
460bdd1243dSDimitry Andric
createComplexDeinterleavingPass(const TargetMachine * TM)461bdd1243dSDimitry Andric FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {
462bdd1243dSDimitry Andric return new ComplexDeinterleavingLegacyPass(TM);
463bdd1243dSDimitry Andric }
464bdd1243dSDimitry Andric
runOnFunction(Function & F)465bdd1243dSDimitry Andric bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
466bdd1243dSDimitry Andric const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
467bdd1243dSDimitry Andric auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
468bdd1243dSDimitry Andric return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
469bdd1243dSDimitry Andric }
470bdd1243dSDimitry Andric
runOnFunction(Function & F)471bdd1243dSDimitry Andric bool ComplexDeinterleaving::runOnFunction(Function &F) {
472bdd1243dSDimitry Andric if (!ComplexDeinterleavingEnabled) {
473bdd1243dSDimitry Andric LLVM_DEBUG(
474bdd1243dSDimitry Andric dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
475bdd1243dSDimitry Andric return false;
476bdd1243dSDimitry Andric }
477bdd1243dSDimitry Andric
478bdd1243dSDimitry Andric if (!TL->isComplexDeinterleavingSupported()) {
479bdd1243dSDimitry Andric LLVM_DEBUG(
480bdd1243dSDimitry Andric dbgs() << "Complex deinterleaving has been disabled, target does "
481bdd1243dSDimitry Andric "not support lowering of complex number operations.\n");
482bdd1243dSDimitry Andric return false;
483bdd1243dSDimitry Andric }
484bdd1243dSDimitry Andric
485bdd1243dSDimitry Andric bool Changed = false;
486bdd1243dSDimitry Andric for (auto &B : F)
487bdd1243dSDimitry Andric Changed |= evaluateBasicBlock(&B);
488bdd1243dSDimitry Andric
489bdd1243dSDimitry Andric return Changed;
490bdd1243dSDimitry Andric }
491bdd1243dSDimitry Andric
isInterleavingMask(ArrayRef<int> Mask)492bdd1243dSDimitry Andric static bool isInterleavingMask(ArrayRef<int> Mask) {
493bdd1243dSDimitry Andric // If the size is not even, it's not an interleaving mask
494bdd1243dSDimitry Andric if ((Mask.size() & 1))
495bdd1243dSDimitry Andric return false;
496bdd1243dSDimitry Andric
497bdd1243dSDimitry Andric int HalfNumElements = Mask.size() / 2;
498bdd1243dSDimitry Andric for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
499bdd1243dSDimitry Andric int MaskIdx = Idx * 2;
500bdd1243dSDimitry Andric if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
501bdd1243dSDimitry Andric return false;
502bdd1243dSDimitry Andric }
503bdd1243dSDimitry Andric
504bdd1243dSDimitry Andric return true;
505bdd1243dSDimitry Andric }
506bdd1243dSDimitry Andric
isDeinterleavingMask(ArrayRef<int> Mask)507bdd1243dSDimitry Andric static bool isDeinterleavingMask(ArrayRef<int> Mask) {
508bdd1243dSDimitry Andric int Offset = Mask[0];
509bdd1243dSDimitry Andric int HalfNumElements = Mask.size() / 2;
510bdd1243dSDimitry Andric
511bdd1243dSDimitry Andric for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
512bdd1243dSDimitry Andric if (Mask[Idx] != (Idx * 2) + Offset)
513bdd1243dSDimitry Andric return false;
514bdd1243dSDimitry Andric }
515bdd1243dSDimitry Andric
516bdd1243dSDimitry Andric return true;
517bdd1243dSDimitry Andric }
518bdd1243dSDimitry Andric
isNeg(Value * V)519fe013be4SDimitry Andric bool isNeg(Value *V) {
520fe013be4SDimitry Andric return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value()));
521fe013be4SDimitry Andric }
522fe013be4SDimitry Andric
getNegOperand(Value * V)523fe013be4SDimitry Andric Value *getNegOperand(Value *V) {
524fe013be4SDimitry Andric assert(isNeg(V));
525fe013be4SDimitry Andric auto *I = cast<Instruction>(V);
526fe013be4SDimitry Andric if (I->getOpcode() == Instruction::FNeg)
527fe013be4SDimitry Andric return I->getOperand(0);
528fe013be4SDimitry Andric
529fe013be4SDimitry Andric return I->getOperand(1);
530fe013be4SDimitry Andric }
531fe013be4SDimitry Andric
evaluateBasicBlock(BasicBlock * B)532bdd1243dSDimitry Andric bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
533fe013be4SDimitry Andric ComplexDeinterleavingGraph Graph(TL, TLI);
534fe013be4SDimitry Andric if (Graph.collectPotentialReductions(B))
535fe013be4SDimitry Andric Graph.identifyReductionNodes();
536bdd1243dSDimitry Andric
537fe013be4SDimitry Andric for (auto &I : *B)
538fe013be4SDimitry Andric Graph.identifyNodes(&I);
539bdd1243dSDimitry Andric
540fe013be4SDimitry Andric if (Graph.checkNodes()) {
541bdd1243dSDimitry Andric Graph.replaceNodes();
542fe013be4SDimitry Andric return true;
543bdd1243dSDimitry Andric }
544bdd1243dSDimitry Andric
545fe013be4SDimitry Andric return false;
546bdd1243dSDimitry Andric }
547bdd1243dSDimitry Andric
548bdd1243dSDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyNodeWithImplicitAdd(Instruction * Real,Instruction * Imag,std::pair<Value *,Value * > & PartialMatch)549bdd1243dSDimitry Andric ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
550bdd1243dSDimitry Andric Instruction *Real, Instruction *Imag,
551fe013be4SDimitry Andric std::pair<Value *, Value *> &PartialMatch) {
552bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
553bdd1243dSDimitry Andric << "\n");
554bdd1243dSDimitry Andric
555bdd1243dSDimitry Andric if (!Real->hasOneUse() || !Imag->hasOneUse()) {
556bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n");
557bdd1243dSDimitry Andric return nullptr;
558bdd1243dSDimitry Andric }
559bdd1243dSDimitry Andric
560fe013be4SDimitry Andric if ((Real->getOpcode() != Instruction::FMul &&
561fe013be4SDimitry Andric Real->getOpcode() != Instruction::Mul) ||
562fe013be4SDimitry Andric (Imag->getOpcode() != Instruction::FMul &&
563fe013be4SDimitry Andric Imag->getOpcode() != Instruction::Mul)) {
564fe013be4SDimitry Andric LLVM_DEBUG(
565fe013be4SDimitry Andric dbgs() << " - Real or imaginary instruction is not fmul or mul\n");
566bdd1243dSDimitry Andric return nullptr;
567bdd1243dSDimitry Andric }
568bdd1243dSDimitry Andric
569fe013be4SDimitry Andric Value *R0 = Real->getOperand(0);
570fe013be4SDimitry Andric Value *R1 = Real->getOperand(1);
571fe013be4SDimitry Andric Value *I0 = Imag->getOperand(0);
572fe013be4SDimitry Andric Value *I1 = Imag->getOperand(1);
573bdd1243dSDimitry Andric
574bdd1243dSDimitry Andric // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
575bdd1243dSDimitry Andric // rotations and use the operand.
576bdd1243dSDimitry Andric unsigned Negs = 0;
577fe013be4SDimitry Andric Value *Op;
578fe013be4SDimitry Andric if (match(R0, m_Neg(m_Value(Op)))) {
579bdd1243dSDimitry Andric Negs |= 1;
580fe013be4SDimitry Andric R0 = Op;
581fe013be4SDimitry Andric } else if (match(R1, m_Neg(m_Value(Op)))) {
582fe013be4SDimitry Andric Negs |= 1;
583fe013be4SDimitry Andric R1 = Op;
584bdd1243dSDimitry Andric }
585fe013be4SDimitry Andric
586fe013be4SDimitry Andric if (isNeg(I0)) {
587bdd1243dSDimitry Andric Negs |= 2;
588bdd1243dSDimitry Andric Negs ^= 1;
589fe013be4SDimitry Andric I0 = Op;
590fe013be4SDimitry Andric } else if (match(I1, m_Neg(m_Value(Op)))) {
591fe013be4SDimitry Andric Negs |= 2;
592fe013be4SDimitry Andric Negs ^= 1;
593fe013be4SDimitry Andric I1 = Op;
594bdd1243dSDimitry Andric }
595bdd1243dSDimitry Andric
596bdd1243dSDimitry Andric ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;
597bdd1243dSDimitry Andric
598fe013be4SDimitry Andric Value *CommonOperand;
599fe013be4SDimitry Andric Value *UncommonRealOp;
600fe013be4SDimitry Andric Value *UncommonImagOp;
601bdd1243dSDimitry Andric
602bdd1243dSDimitry Andric if (R0 == I0 || R0 == I1) {
603bdd1243dSDimitry Andric CommonOperand = R0;
604bdd1243dSDimitry Andric UncommonRealOp = R1;
605bdd1243dSDimitry Andric } else if (R1 == I0 || R1 == I1) {
606bdd1243dSDimitry Andric CommonOperand = R1;
607bdd1243dSDimitry Andric UncommonRealOp = R0;
608bdd1243dSDimitry Andric } else {
609bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - No equal operand\n");
610bdd1243dSDimitry Andric return nullptr;
611bdd1243dSDimitry Andric }
612bdd1243dSDimitry Andric
613bdd1243dSDimitry Andric UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
614bdd1243dSDimitry Andric if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
615bdd1243dSDimitry Andric Rotation == ComplexDeinterleavingRotation::Rotation_270)
616bdd1243dSDimitry Andric std::swap(UncommonRealOp, UncommonImagOp);
617bdd1243dSDimitry Andric
618bdd1243dSDimitry Andric // Between identifyPartialMul and here we need to have found a complete valid
619bdd1243dSDimitry Andric // pair from the CommonOperand of each part.
620bdd1243dSDimitry Andric if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
621bdd1243dSDimitry Andric Rotation == ComplexDeinterleavingRotation::Rotation_180)
622bdd1243dSDimitry Andric PartialMatch.first = CommonOperand;
623bdd1243dSDimitry Andric else
624bdd1243dSDimitry Andric PartialMatch.second = CommonOperand;
625bdd1243dSDimitry Andric
626bdd1243dSDimitry Andric if (!PartialMatch.first || !PartialMatch.second) {
627bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Incomplete partial match\n");
628bdd1243dSDimitry Andric return nullptr;
629bdd1243dSDimitry Andric }
630bdd1243dSDimitry Andric
631bdd1243dSDimitry Andric NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
632bdd1243dSDimitry Andric if (!CommonNode) {
633bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - No CommonNode identified\n");
634bdd1243dSDimitry Andric return nullptr;
635bdd1243dSDimitry Andric }
636bdd1243dSDimitry Andric
637bdd1243dSDimitry Andric NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
638bdd1243dSDimitry Andric if (!UncommonNode) {
639bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n");
640bdd1243dSDimitry Andric return nullptr;
641bdd1243dSDimitry Andric }
642bdd1243dSDimitry Andric
643bdd1243dSDimitry Andric NodePtr Node = prepareCompositeNode(
644bdd1243dSDimitry Andric ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
645bdd1243dSDimitry Andric Node->Rotation = Rotation;
646bdd1243dSDimitry Andric Node->addOperand(CommonNode);
647bdd1243dSDimitry Andric Node->addOperand(UncommonNode);
648bdd1243dSDimitry Andric return submitCompositeNode(Node);
649bdd1243dSDimitry Andric }
650bdd1243dSDimitry Andric
651bdd1243dSDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyPartialMul(Instruction * Real,Instruction * Imag)652bdd1243dSDimitry Andric ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
653bdd1243dSDimitry Andric Instruction *Imag) {
654bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
655bdd1243dSDimitry Andric << "\n");
656bdd1243dSDimitry Andric // Determine rotation
657fe013be4SDimitry Andric auto IsAdd = [](unsigned Op) {
658fe013be4SDimitry Andric return Op == Instruction::FAdd || Op == Instruction::Add;
659fe013be4SDimitry Andric };
660fe013be4SDimitry Andric auto IsSub = [](unsigned Op) {
661fe013be4SDimitry Andric return Op == Instruction::FSub || Op == Instruction::Sub;
662fe013be4SDimitry Andric };
663bdd1243dSDimitry Andric ComplexDeinterleavingRotation Rotation;
664fe013be4SDimitry Andric if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
665bdd1243dSDimitry Andric Rotation = ComplexDeinterleavingRotation::Rotation_0;
666fe013be4SDimitry Andric else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
667bdd1243dSDimitry Andric Rotation = ComplexDeinterleavingRotation::Rotation_90;
668fe013be4SDimitry Andric else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode()))
669bdd1243dSDimitry Andric Rotation = ComplexDeinterleavingRotation::Rotation_180;
670fe013be4SDimitry Andric else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))
671bdd1243dSDimitry Andric Rotation = ComplexDeinterleavingRotation::Rotation_270;
672bdd1243dSDimitry Andric else {
673bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n");
674bdd1243dSDimitry Andric return nullptr;
675bdd1243dSDimitry Andric }
676bdd1243dSDimitry Andric
677fe013be4SDimitry Andric if (isa<FPMathOperator>(Real) &&
678fe013be4SDimitry Andric (!Real->getFastMathFlags().allowContract() ||
679fe013be4SDimitry Andric !Imag->getFastMathFlags().allowContract())) {
680bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n");
681bdd1243dSDimitry Andric return nullptr;
682bdd1243dSDimitry Andric }
683bdd1243dSDimitry Andric
684bdd1243dSDimitry Andric Value *CR = Real->getOperand(0);
685bdd1243dSDimitry Andric Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
686bdd1243dSDimitry Andric if (!RealMulI)
687bdd1243dSDimitry Andric return nullptr;
688bdd1243dSDimitry Andric Value *CI = Imag->getOperand(0);
689bdd1243dSDimitry Andric Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
690bdd1243dSDimitry Andric if (!ImagMulI)
691bdd1243dSDimitry Andric return nullptr;
692bdd1243dSDimitry Andric
693bdd1243dSDimitry Andric if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
694bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n");
695bdd1243dSDimitry Andric return nullptr;
696bdd1243dSDimitry Andric }
697bdd1243dSDimitry Andric
698fe013be4SDimitry Andric Value *R0 = RealMulI->getOperand(0);
699fe013be4SDimitry Andric Value *R1 = RealMulI->getOperand(1);
700fe013be4SDimitry Andric Value *I0 = ImagMulI->getOperand(0);
701fe013be4SDimitry Andric Value *I1 = ImagMulI->getOperand(1);
702bdd1243dSDimitry Andric
703fe013be4SDimitry Andric Value *CommonOperand;
704fe013be4SDimitry Andric Value *UncommonRealOp;
705fe013be4SDimitry Andric Value *UncommonImagOp;
706bdd1243dSDimitry Andric
707bdd1243dSDimitry Andric if (R0 == I0 || R0 == I1) {
708bdd1243dSDimitry Andric CommonOperand = R0;
709bdd1243dSDimitry Andric UncommonRealOp = R1;
710bdd1243dSDimitry Andric } else if (R1 == I0 || R1 == I1) {
711bdd1243dSDimitry Andric CommonOperand = R1;
712bdd1243dSDimitry Andric UncommonRealOp = R0;
713bdd1243dSDimitry Andric } else {
714bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - No equal operand\n");
715bdd1243dSDimitry Andric return nullptr;
716bdd1243dSDimitry Andric }
717bdd1243dSDimitry Andric
718bdd1243dSDimitry Andric UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
719bdd1243dSDimitry Andric if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
720bdd1243dSDimitry Andric Rotation == ComplexDeinterleavingRotation::Rotation_270)
721bdd1243dSDimitry Andric std::swap(UncommonRealOp, UncommonImagOp);
722bdd1243dSDimitry Andric
723fe013be4SDimitry Andric std::pair<Value *, Value *> PartialMatch(
724bdd1243dSDimitry Andric (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
725bdd1243dSDimitry Andric Rotation == ComplexDeinterleavingRotation::Rotation_180)
726bdd1243dSDimitry Andric ? CommonOperand
727bdd1243dSDimitry Andric : nullptr,
728bdd1243dSDimitry Andric (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
729bdd1243dSDimitry Andric Rotation == ComplexDeinterleavingRotation::Rotation_270)
730bdd1243dSDimitry Andric ? CommonOperand
731bdd1243dSDimitry Andric : nullptr);
732fe013be4SDimitry Andric
733fe013be4SDimitry Andric auto *CRInst = dyn_cast<Instruction>(CR);
734fe013be4SDimitry Andric auto *CIInst = dyn_cast<Instruction>(CI);
735fe013be4SDimitry Andric
736fe013be4SDimitry Andric if (!CRInst || !CIInst) {
737fe013be4SDimitry Andric LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n");
738fe013be4SDimitry Andric return nullptr;
739fe013be4SDimitry Andric }
740fe013be4SDimitry Andric
741fe013be4SDimitry Andric NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
742bdd1243dSDimitry Andric if (!CNode) {
743bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - No cnode identified\n");
744bdd1243dSDimitry Andric return nullptr;
745bdd1243dSDimitry Andric }
746bdd1243dSDimitry Andric
747bdd1243dSDimitry Andric NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
748bdd1243dSDimitry Andric if (!UncommonRes) {
749bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n");
750bdd1243dSDimitry Andric return nullptr;
751bdd1243dSDimitry Andric }
752bdd1243dSDimitry Andric
753bdd1243dSDimitry Andric assert(PartialMatch.first && PartialMatch.second);
754bdd1243dSDimitry Andric NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
755bdd1243dSDimitry Andric if (!CommonRes) {
756bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - No CommonRes identified\n");
757bdd1243dSDimitry Andric return nullptr;
758bdd1243dSDimitry Andric }
759bdd1243dSDimitry Andric
760bdd1243dSDimitry Andric NodePtr Node = prepareCompositeNode(
761bdd1243dSDimitry Andric ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
762bdd1243dSDimitry Andric Node->Rotation = Rotation;
763bdd1243dSDimitry Andric Node->addOperand(CommonRes);
764bdd1243dSDimitry Andric Node->addOperand(UncommonRes);
765bdd1243dSDimitry Andric Node->addOperand(CNode);
766bdd1243dSDimitry Andric return submitCompositeNode(Node);
767bdd1243dSDimitry Andric }
768bdd1243dSDimitry Andric
769bdd1243dSDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyAdd(Instruction * Real,Instruction * Imag)770bdd1243dSDimitry Andric ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
771bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
772bdd1243dSDimitry Andric
773bdd1243dSDimitry Andric // Determine rotation
774bdd1243dSDimitry Andric ComplexDeinterleavingRotation Rotation;
775bdd1243dSDimitry Andric if ((Real->getOpcode() == Instruction::FSub &&
776bdd1243dSDimitry Andric Imag->getOpcode() == Instruction::FAdd) ||
777bdd1243dSDimitry Andric (Real->getOpcode() == Instruction::Sub &&
778bdd1243dSDimitry Andric Imag->getOpcode() == Instruction::Add))
779bdd1243dSDimitry Andric Rotation = ComplexDeinterleavingRotation::Rotation_90;
780bdd1243dSDimitry Andric else if ((Real->getOpcode() == Instruction::FAdd &&
781bdd1243dSDimitry Andric Imag->getOpcode() == Instruction::FSub) ||
782bdd1243dSDimitry Andric (Real->getOpcode() == Instruction::Add &&
783bdd1243dSDimitry Andric Imag->getOpcode() == Instruction::Sub))
784bdd1243dSDimitry Andric Rotation = ComplexDeinterleavingRotation::Rotation_270;
785bdd1243dSDimitry Andric else {
786bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
787bdd1243dSDimitry Andric return nullptr;
788bdd1243dSDimitry Andric }
789bdd1243dSDimitry Andric
790bdd1243dSDimitry Andric auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
791bdd1243dSDimitry Andric auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
792bdd1243dSDimitry Andric auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
793bdd1243dSDimitry Andric auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
794bdd1243dSDimitry Andric
795bdd1243dSDimitry Andric if (!AR || !AI || !BR || !BI) {
796bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
797bdd1243dSDimitry Andric return nullptr;
798bdd1243dSDimitry Andric }
799bdd1243dSDimitry Andric
800bdd1243dSDimitry Andric NodePtr ResA = identifyNode(AR, AI);
801bdd1243dSDimitry Andric if (!ResA) {
802bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
803bdd1243dSDimitry Andric return nullptr;
804bdd1243dSDimitry Andric }
805bdd1243dSDimitry Andric NodePtr ResB = identifyNode(BR, BI);
806bdd1243dSDimitry Andric if (!ResB) {
807bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
808bdd1243dSDimitry Andric return nullptr;
809bdd1243dSDimitry Andric }
810bdd1243dSDimitry Andric
811bdd1243dSDimitry Andric NodePtr Node =
812bdd1243dSDimitry Andric prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
813bdd1243dSDimitry Andric Node->Rotation = Rotation;
814bdd1243dSDimitry Andric Node->addOperand(ResA);
815bdd1243dSDimitry Andric Node->addOperand(ResB);
816bdd1243dSDimitry Andric return submitCompositeNode(Node);
817bdd1243dSDimitry Andric }
818bdd1243dSDimitry Andric
isInstructionPairAdd(Instruction * A,Instruction * B)819bdd1243dSDimitry Andric static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
820bdd1243dSDimitry Andric unsigned OpcA = A->getOpcode();
821bdd1243dSDimitry Andric unsigned OpcB = B->getOpcode();
822bdd1243dSDimitry Andric
823bdd1243dSDimitry Andric return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
824bdd1243dSDimitry Andric (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
825bdd1243dSDimitry Andric (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
826bdd1243dSDimitry Andric (OpcA == Instruction::Add && OpcB == Instruction::Sub);
827bdd1243dSDimitry Andric }
828bdd1243dSDimitry Andric
isInstructionPairMul(Instruction * A,Instruction * B)829bdd1243dSDimitry Andric static bool isInstructionPairMul(Instruction *A, Instruction *B) {
830bdd1243dSDimitry Andric auto Pattern =
831bdd1243dSDimitry Andric m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value()));
832bdd1243dSDimitry Andric
833bdd1243dSDimitry Andric return match(A, Pattern) && match(B, Pattern);
834bdd1243dSDimitry Andric }
835bdd1243dSDimitry Andric
isInstructionPotentiallySymmetric(Instruction * I)836fe013be4SDimitry Andric static bool isInstructionPotentiallySymmetric(Instruction *I) {
837fe013be4SDimitry Andric switch (I->getOpcode()) {
838fe013be4SDimitry Andric case Instruction::FAdd:
839fe013be4SDimitry Andric case Instruction::FSub:
840fe013be4SDimitry Andric case Instruction::FMul:
841fe013be4SDimitry Andric case Instruction::FNeg:
842fe013be4SDimitry Andric case Instruction::Add:
843fe013be4SDimitry Andric case Instruction::Sub:
844fe013be4SDimitry Andric case Instruction::Mul:
845fe013be4SDimitry Andric return true;
846fe013be4SDimitry Andric default:
847fe013be4SDimitry Andric return false;
848fe013be4SDimitry Andric }
849fe013be4SDimitry Andric }
850fe013be4SDimitry Andric
851bdd1243dSDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifySymmetricOperation(Instruction * Real,Instruction * Imag)852fe013be4SDimitry Andric ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
853fe013be4SDimitry Andric Instruction *Imag) {
854fe013be4SDimitry Andric if (Real->getOpcode() != Imag->getOpcode())
855fe013be4SDimitry Andric return nullptr;
856fe013be4SDimitry Andric
857fe013be4SDimitry Andric if (!isInstructionPotentiallySymmetric(Real) ||
858fe013be4SDimitry Andric !isInstructionPotentiallySymmetric(Imag))
859fe013be4SDimitry Andric return nullptr;
860fe013be4SDimitry Andric
861fe013be4SDimitry Andric auto *R0 = Real->getOperand(0);
862fe013be4SDimitry Andric auto *I0 = Imag->getOperand(0);
863fe013be4SDimitry Andric
864fe013be4SDimitry Andric NodePtr Op0 = identifyNode(R0, I0);
865fe013be4SDimitry Andric NodePtr Op1 = nullptr;
866fe013be4SDimitry Andric if (Op0 == nullptr)
867fe013be4SDimitry Andric return nullptr;
868fe013be4SDimitry Andric
869fe013be4SDimitry Andric if (Real->isBinaryOp()) {
870fe013be4SDimitry Andric auto *R1 = Real->getOperand(1);
871fe013be4SDimitry Andric auto *I1 = Imag->getOperand(1);
872fe013be4SDimitry Andric Op1 = identifyNode(R1, I1);
873fe013be4SDimitry Andric if (Op1 == nullptr)
874fe013be4SDimitry Andric return nullptr;
875fe013be4SDimitry Andric }
876fe013be4SDimitry Andric
877fe013be4SDimitry Andric if (isa<FPMathOperator>(Real) &&
878fe013be4SDimitry Andric Real->getFastMathFlags() != Imag->getFastMathFlags())
879fe013be4SDimitry Andric return nullptr;
880fe013be4SDimitry Andric
881fe013be4SDimitry Andric auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,
882fe013be4SDimitry Andric Real, Imag);
883fe013be4SDimitry Andric Node->Opcode = Real->getOpcode();
884fe013be4SDimitry Andric if (isa<FPMathOperator>(Real))
885fe013be4SDimitry Andric Node->Flags = Real->getFastMathFlags();
886fe013be4SDimitry Andric
887fe013be4SDimitry Andric Node->addOperand(Op0);
888fe013be4SDimitry Andric if (Real->isBinaryOp())
889fe013be4SDimitry Andric Node->addOperand(Op1);
890fe013be4SDimitry Andric
891fe013be4SDimitry Andric return submitCompositeNode(Node);
892fe013be4SDimitry Andric }
893fe013be4SDimitry Andric
894fe013be4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyNode(Value * R,Value * I)895fe013be4SDimitry Andric ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
896fe013be4SDimitry Andric LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n");
897fe013be4SDimitry Andric assert(R->getType() == I->getType() &&
898fe013be4SDimitry Andric "Real and imaginary parts should not have different types");
899271697daSDimitry Andric
900271697daSDimitry Andric auto It = CachedResult.find({R, I});
901271697daSDimitry Andric if (It != CachedResult.end()) {
902bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
903271697daSDimitry Andric return It->second;
904bdd1243dSDimitry Andric }
905bdd1243dSDimitry Andric
906fe013be4SDimitry Andric if (NodePtr CN = identifySplat(R, I))
907fe013be4SDimitry Andric return CN;
908fe013be4SDimitry Andric
909fe013be4SDimitry Andric auto *Real = dyn_cast<Instruction>(R);
910fe013be4SDimitry Andric auto *Imag = dyn_cast<Instruction>(I);
911fe013be4SDimitry Andric if (!Real || !Imag)
912fe013be4SDimitry Andric return nullptr;
913fe013be4SDimitry Andric
914fe013be4SDimitry Andric if (NodePtr CN = identifyDeinterleave(Real, Imag))
915fe013be4SDimitry Andric return CN;
916fe013be4SDimitry Andric
917fe013be4SDimitry Andric if (NodePtr CN = identifyPHINode(Real, Imag))
918fe013be4SDimitry Andric return CN;
919fe013be4SDimitry Andric
920fe013be4SDimitry Andric if (NodePtr CN = identifySelectNode(Real, Imag))
921fe013be4SDimitry Andric return CN;
922fe013be4SDimitry Andric
923fe013be4SDimitry Andric auto *VTy = cast<VectorType>(Real->getType());
924fe013be4SDimitry Andric auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
925fe013be4SDimitry Andric
926fe013be4SDimitry Andric bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
927fe013be4SDimitry Andric ComplexDeinterleavingOperation::CMulPartial, NewVTy);
928fe013be4SDimitry Andric bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
929fe013be4SDimitry Andric ComplexDeinterleavingOperation::CAdd, NewVTy);
930fe013be4SDimitry Andric
931fe013be4SDimitry Andric if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
932fe013be4SDimitry Andric if (NodePtr CN = identifyPartialMul(Real, Imag))
933fe013be4SDimitry Andric return CN;
934fe013be4SDimitry Andric }
935fe013be4SDimitry Andric
936fe013be4SDimitry Andric if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
937fe013be4SDimitry Andric if (NodePtr CN = identifyAdd(Real, Imag))
938fe013be4SDimitry Andric return CN;
939fe013be4SDimitry Andric }
940fe013be4SDimitry Andric
941fe013be4SDimitry Andric if (HasCMulSupport && HasCAddSupport) {
942fe013be4SDimitry Andric if (NodePtr CN = identifyReassocNodes(Real, Imag))
943fe013be4SDimitry Andric return CN;
944fe013be4SDimitry Andric }
945fe013be4SDimitry Andric
946fe013be4SDimitry Andric if (NodePtr CN = identifySymmetricOperation(Real, Imag))
947fe013be4SDimitry Andric return CN;
948fe013be4SDimitry Andric
949fe013be4SDimitry Andric LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n");
950271697daSDimitry Andric CachedResult[{R, I}] = nullptr;
951fe013be4SDimitry Andric return nullptr;
952fe013be4SDimitry Andric }
953fe013be4SDimitry Andric
954fe013be4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyReassocNodes(Instruction * Real,Instruction * Imag)955fe013be4SDimitry Andric ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
956fe013be4SDimitry Andric Instruction *Imag) {
957fe013be4SDimitry Andric auto IsOperationSupported = [](unsigned Opcode) -> bool {
958fe013be4SDimitry Andric return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
959fe013be4SDimitry Andric Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
960fe013be4SDimitry Andric Opcode == Instruction::Sub;
961fe013be4SDimitry Andric };
962fe013be4SDimitry Andric
963fe013be4SDimitry Andric if (!IsOperationSupported(Real->getOpcode()) ||
964fe013be4SDimitry Andric !IsOperationSupported(Imag->getOpcode()))
965fe013be4SDimitry Andric return nullptr;
966fe013be4SDimitry Andric
967fe013be4SDimitry Andric std::optional<FastMathFlags> Flags;
968fe013be4SDimitry Andric if (isa<FPMathOperator>(Real)) {
969fe013be4SDimitry Andric if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
970fe013be4SDimitry Andric LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
971fe013be4SDimitry Andric "not identical\n");
972fe013be4SDimitry Andric return nullptr;
973fe013be4SDimitry Andric }
974fe013be4SDimitry Andric
975fe013be4SDimitry Andric Flags = Real->getFastMathFlags();
976fe013be4SDimitry Andric if (!Flags->allowReassoc()) {
977fe013be4SDimitry Andric LLVM_DEBUG(
978fe013be4SDimitry Andric dbgs()
979fe013be4SDimitry Andric << "the 'Reassoc' attribute is missing in the FastMath flags\n");
980fe013be4SDimitry Andric return nullptr;
981fe013be4SDimitry Andric }
982fe013be4SDimitry Andric }
983fe013be4SDimitry Andric
984fe013be4SDimitry Andric // Collect multiplications and addend instructions from the given instruction
985fe013be4SDimitry Andric // while traversing it operands. Additionally, verify that all instructions
986fe013be4SDimitry Andric // have the same fast math flags.
987fe013be4SDimitry Andric auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls,
988fe013be4SDimitry Andric std::list<Addend> &Addends) -> bool {
989fe013be4SDimitry Andric SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
990fe013be4SDimitry Andric SmallPtrSet<Value *, 8> Visited;
991fe013be4SDimitry Andric while (!Worklist.empty()) {
992fe013be4SDimitry Andric auto [V, IsPositive] = Worklist.back();
993fe013be4SDimitry Andric Worklist.pop_back();
994fe013be4SDimitry Andric if (!Visited.insert(V).second)
995fe013be4SDimitry Andric continue;
996fe013be4SDimitry Andric
997fe013be4SDimitry Andric Instruction *I = dyn_cast<Instruction>(V);
998fe013be4SDimitry Andric if (!I) {
999fe013be4SDimitry Andric Addends.emplace_back(V, IsPositive);
1000fe013be4SDimitry Andric continue;
1001fe013be4SDimitry Andric }
1002fe013be4SDimitry Andric
1003fe013be4SDimitry Andric // If an instruction has more than one user, it indicates that it either
1004fe013be4SDimitry Andric // has an external user, which will be later checked by the checkNodes
1005fe013be4SDimitry Andric // function, or it is a subexpression utilized by multiple expressions. In
1006fe013be4SDimitry Andric // the latter case, we will attempt to separately identify the complex
1007fe013be4SDimitry Andric // operation from here in order to create a shared
1008fe013be4SDimitry Andric // ComplexDeinterleavingCompositeNode.
1009fe013be4SDimitry Andric if (I != Insn && I->getNumUses() > 1) {
1010fe013be4SDimitry Andric LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
1011fe013be4SDimitry Andric Addends.emplace_back(I, IsPositive);
1012fe013be4SDimitry Andric continue;
1013fe013be4SDimitry Andric }
1014fe013be4SDimitry Andric switch (I->getOpcode()) {
1015fe013be4SDimitry Andric case Instruction::FAdd:
1016fe013be4SDimitry Andric case Instruction::Add:
1017fe013be4SDimitry Andric Worklist.emplace_back(I->getOperand(1), IsPositive);
1018fe013be4SDimitry Andric Worklist.emplace_back(I->getOperand(0), IsPositive);
1019fe013be4SDimitry Andric break;
1020fe013be4SDimitry Andric case Instruction::FSub:
1021fe013be4SDimitry Andric Worklist.emplace_back(I->getOperand(1), !IsPositive);
1022fe013be4SDimitry Andric Worklist.emplace_back(I->getOperand(0), IsPositive);
1023fe013be4SDimitry Andric break;
1024fe013be4SDimitry Andric case Instruction::Sub:
1025fe013be4SDimitry Andric if (isNeg(I)) {
1026fe013be4SDimitry Andric Worklist.emplace_back(getNegOperand(I), !IsPositive);
1027fe013be4SDimitry Andric } else {
1028fe013be4SDimitry Andric Worklist.emplace_back(I->getOperand(1), !IsPositive);
1029fe013be4SDimitry Andric Worklist.emplace_back(I->getOperand(0), IsPositive);
1030fe013be4SDimitry Andric }
1031fe013be4SDimitry Andric break;
1032fe013be4SDimitry Andric case Instruction::FMul:
1033fe013be4SDimitry Andric case Instruction::Mul: {
1034fe013be4SDimitry Andric Value *A, *B;
1035fe013be4SDimitry Andric if (isNeg(I->getOperand(0))) {
1036fe013be4SDimitry Andric A = getNegOperand(I->getOperand(0));
1037fe013be4SDimitry Andric IsPositive = !IsPositive;
1038fe013be4SDimitry Andric } else {
1039fe013be4SDimitry Andric A = I->getOperand(0);
1040fe013be4SDimitry Andric }
1041fe013be4SDimitry Andric
1042fe013be4SDimitry Andric if (isNeg(I->getOperand(1))) {
1043fe013be4SDimitry Andric B = getNegOperand(I->getOperand(1));
1044fe013be4SDimitry Andric IsPositive = !IsPositive;
1045fe013be4SDimitry Andric } else {
1046fe013be4SDimitry Andric B = I->getOperand(1);
1047fe013be4SDimitry Andric }
1048fe013be4SDimitry Andric Muls.push_back(Product{A, B, IsPositive});
1049fe013be4SDimitry Andric break;
1050fe013be4SDimitry Andric }
1051fe013be4SDimitry Andric case Instruction::FNeg:
1052fe013be4SDimitry Andric Worklist.emplace_back(I->getOperand(0), !IsPositive);
1053fe013be4SDimitry Andric break;
1054fe013be4SDimitry Andric default:
1055fe013be4SDimitry Andric Addends.emplace_back(I, IsPositive);
1056fe013be4SDimitry Andric continue;
1057fe013be4SDimitry Andric }
1058fe013be4SDimitry Andric
1059fe013be4SDimitry Andric if (Flags && I->getFastMathFlags() != *Flags) {
1060fe013be4SDimitry Andric LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
1061fe013be4SDimitry Andric "inconsistent with the root instructions' flags: "
1062fe013be4SDimitry Andric << *I << "\n");
1063fe013be4SDimitry Andric return false;
1064fe013be4SDimitry Andric }
1065fe013be4SDimitry Andric }
1066fe013be4SDimitry Andric return true;
1067fe013be4SDimitry Andric };
1068fe013be4SDimitry Andric
1069fe013be4SDimitry Andric std::vector<Product> RealMuls, ImagMuls;
1070fe013be4SDimitry Andric std::list<Addend> RealAddends, ImagAddends;
1071fe013be4SDimitry Andric if (!Collect(Real, RealMuls, RealAddends) ||
1072fe013be4SDimitry Andric !Collect(Imag, ImagMuls, ImagAddends))
1073fe013be4SDimitry Andric return nullptr;
1074fe013be4SDimitry Andric
1075fe013be4SDimitry Andric if (RealAddends.size() != ImagAddends.size())
1076fe013be4SDimitry Andric return nullptr;
1077fe013be4SDimitry Andric
1078fe013be4SDimitry Andric NodePtr FinalNode;
1079fe013be4SDimitry Andric if (!RealMuls.empty() || !ImagMuls.empty()) {
1080fe013be4SDimitry Andric // If there are multiplicands, extract positive addend and use it as an
1081fe013be4SDimitry Andric // accumulator
1082fe013be4SDimitry Andric FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1083fe013be4SDimitry Andric FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
1084fe013be4SDimitry Andric if (!FinalNode)
1085fe013be4SDimitry Andric return nullptr;
1086fe013be4SDimitry Andric }
1087fe013be4SDimitry Andric
1088fe013be4SDimitry Andric // Identify and process remaining additions
1089fe013be4SDimitry Andric if (!RealAddends.empty() || !ImagAddends.empty()) {
1090fe013be4SDimitry Andric FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
1091fe013be4SDimitry Andric if (!FinalNode)
1092fe013be4SDimitry Andric return nullptr;
1093fe013be4SDimitry Andric }
1094fe013be4SDimitry Andric assert(FinalNode && "FinalNode can not be nullptr here");
1095fe013be4SDimitry Andric // Set the Real and Imag fields of the final node and submit it
1096fe013be4SDimitry Andric FinalNode->Real = Real;
1097fe013be4SDimitry Andric FinalNode->Imag = Imag;
1098fe013be4SDimitry Andric submitCompositeNode(FinalNode);
1099fe013be4SDimitry Andric return FinalNode;
1100fe013be4SDimitry Andric }
1101fe013be4SDimitry Andric
collectPartialMuls(const std::vector<Product> & RealMuls,const std::vector<Product> & ImagMuls,std::vector<PartialMulCandidate> & PartialMulCandidates)1102fe013be4SDimitry Andric bool ComplexDeinterleavingGraph::collectPartialMuls(
1103fe013be4SDimitry Andric const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls,
1104fe013be4SDimitry Andric std::vector<PartialMulCandidate> &PartialMulCandidates) {
1105fe013be4SDimitry Andric // Helper function to extract a common operand from two products
1106fe013be4SDimitry Andric auto FindCommonInstruction = [](const Product &Real,
1107fe013be4SDimitry Andric const Product &Imag) -> Value * {
1108fe013be4SDimitry Andric if (Real.Multiplicand == Imag.Multiplicand ||
1109fe013be4SDimitry Andric Real.Multiplicand == Imag.Multiplier)
1110fe013be4SDimitry Andric return Real.Multiplicand;
1111fe013be4SDimitry Andric
1112fe013be4SDimitry Andric if (Real.Multiplier == Imag.Multiplicand ||
1113fe013be4SDimitry Andric Real.Multiplier == Imag.Multiplier)
1114fe013be4SDimitry Andric return Real.Multiplier;
1115fe013be4SDimitry Andric
1116fe013be4SDimitry Andric return nullptr;
1117fe013be4SDimitry Andric };
1118fe013be4SDimitry Andric
1119fe013be4SDimitry Andric // Iterating over real and imaginary multiplications to find common operands
1120fe013be4SDimitry Andric // If a common operand is found, a partial multiplication candidate is created
1121fe013be4SDimitry Andric // and added to the candidates vector The function returns false if no common
1122fe013be4SDimitry Andric // operands are found for any product
1123fe013be4SDimitry Andric for (unsigned i = 0; i < RealMuls.size(); ++i) {
1124fe013be4SDimitry Andric bool FoundCommon = false;
1125fe013be4SDimitry Andric for (unsigned j = 0; j < ImagMuls.size(); ++j) {
1126fe013be4SDimitry Andric auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1127fe013be4SDimitry Andric if (!Common)
1128fe013be4SDimitry Andric continue;
1129fe013be4SDimitry Andric
1130fe013be4SDimitry Andric auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1131fe013be4SDimitry Andric : RealMuls[i].Multiplicand;
1132fe013be4SDimitry Andric auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
1133fe013be4SDimitry Andric : ImagMuls[j].Multiplicand;
1134fe013be4SDimitry Andric
1135fe013be4SDimitry Andric auto Node = identifyNode(A, B);
1136fe013be4SDimitry Andric if (Node) {
1137fe013be4SDimitry Andric FoundCommon = true;
1138fe013be4SDimitry Andric PartialMulCandidates.push_back({Common, Node, i, j, false});
1139fe013be4SDimitry Andric }
1140fe013be4SDimitry Andric
1141fe013be4SDimitry Andric Node = identifyNode(B, A);
1142fe013be4SDimitry Andric if (Node) {
1143fe013be4SDimitry Andric FoundCommon = true;
1144fe013be4SDimitry Andric PartialMulCandidates.push_back({Common, Node, i, j, true});
1145fe013be4SDimitry Andric }
1146fe013be4SDimitry Andric }
1147fe013be4SDimitry Andric if (!FoundCommon)
1148fe013be4SDimitry Andric return false;
1149fe013be4SDimitry Andric }
1150fe013be4SDimitry Andric return true;
1151fe013be4SDimitry Andric }
1152fe013be4SDimitry Andric
1153fe013be4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyMultiplications(std::vector<Product> & RealMuls,std::vector<Product> & ImagMuls,NodePtr Accumulator=nullptr)1154fe013be4SDimitry Andric ComplexDeinterleavingGraph::identifyMultiplications(
1155fe013be4SDimitry Andric std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls,
1156fe013be4SDimitry Andric NodePtr Accumulator = nullptr) {
1157fe013be4SDimitry Andric if (RealMuls.size() != ImagMuls.size())
1158fe013be4SDimitry Andric return nullptr;
1159fe013be4SDimitry Andric
1160fe013be4SDimitry Andric std::vector<PartialMulCandidate> Info;
1161fe013be4SDimitry Andric if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1162fe013be4SDimitry Andric return nullptr;
1163fe013be4SDimitry Andric
1164fe013be4SDimitry Andric // Map to store common instruction to node pointers
1165fe013be4SDimitry Andric std::map<Value *, NodePtr> CommonToNode;
1166fe013be4SDimitry Andric std::vector<bool> Processed(Info.size(), false);
1167fe013be4SDimitry Andric for (unsigned I = 0; I < Info.size(); ++I) {
1168fe013be4SDimitry Andric if (Processed[I])
1169fe013be4SDimitry Andric continue;
1170fe013be4SDimitry Andric
1171fe013be4SDimitry Andric PartialMulCandidate &InfoA = Info[I];
1172fe013be4SDimitry Andric for (unsigned J = I + 1; J < Info.size(); ++J) {
1173fe013be4SDimitry Andric if (Processed[J])
1174fe013be4SDimitry Andric continue;
1175fe013be4SDimitry Andric
1176fe013be4SDimitry Andric PartialMulCandidate &InfoB = Info[J];
1177fe013be4SDimitry Andric auto *InfoReal = &InfoA;
1178fe013be4SDimitry Andric auto *InfoImag = &InfoB;
1179fe013be4SDimitry Andric
1180fe013be4SDimitry Andric auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1181fe013be4SDimitry Andric if (!NodeFromCommon) {
1182fe013be4SDimitry Andric std::swap(InfoReal, InfoImag);
1183fe013be4SDimitry Andric NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1184fe013be4SDimitry Andric }
1185fe013be4SDimitry Andric if (!NodeFromCommon)
1186fe013be4SDimitry Andric continue;
1187fe013be4SDimitry Andric
1188fe013be4SDimitry Andric CommonToNode[InfoReal->Common] = NodeFromCommon;
1189fe013be4SDimitry Andric CommonToNode[InfoImag->Common] = NodeFromCommon;
1190fe013be4SDimitry Andric Processed[I] = true;
1191fe013be4SDimitry Andric Processed[J] = true;
1192fe013be4SDimitry Andric }
1193fe013be4SDimitry Andric }
1194fe013be4SDimitry Andric
1195fe013be4SDimitry Andric std::vector<bool> ProcessedReal(RealMuls.size(), false);
1196fe013be4SDimitry Andric std::vector<bool> ProcessedImag(ImagMuls.size(), false);
1197fe013be4SDimitry Andric NodePtr Result = Accumulator;
1198fe013be4SDimitry Andric for (auto &PMI : Info) {
1199fe013be4SDimitry Andric if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1200fe013be4SDimitry Andric continue;
1201fe013be4SDimitry Andric
1202fe013be4SDimitry Andric auto It = CommonToNode.find(PMI.Common);
1203fe013be4SDimitry Andric // TODO: Process independent complex multiplications. Cases like this:
1204fe013be4SDimitry Andric // A.real() * B where both A and B are complex numbers.
1205fe013be4SDimitry Andric if (It == CommonToNode.end()) {
1206fe013be4SDimitry Andric LLVM_DEBUG({
1207fe013be4SDimitry Andric dbgs() << "Unprocessed independent partial multiplication:\n";
1208fe013be4SDimitry Andric for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1209fe013be4SDimitry Andric dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
1210fe013be4SDimitry Andric << " multiplied by " << *Mul->Multiplicand << "\n";
1211fe013be4SDimitry Andric });
1212fe013be4SDimitry Andric return nullptr;
1213fe013be4SDimitry Andric }
1214fe013be4SDimitry Andric
1215fe013be4SDimitry Andric auto &RealMul = RealMuls[PMI.RealIdx];
1216fe013be4SDimitry Andric auto &ImagMul = ImagMuls[PMI.ImagIdx];
1217fe013be4SDimitry Andric
1218fe013be4SDimitry Andric auto NodeA = It->second;
1219fe013be4SDimitry Andric auto NodeB = PMI.Node;
1220fe013be4SDimitry Andric auto IsMultiplicandReal = PMI.Common == NodeA->Real;
1221fe013be4SDimitry Andric // The following table illustrates the relationship between multiplications
1222fe013be4SDimitry Andric // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1223fe013be4SDimitry Andric // can see:
1224fe013be4SDimitry Andric //
1225fe013be4SDimitry Andric // Rotation | Real | Imag |
1226fe013be4SDimitry Andric // ---------+--------+--------+
1227fe013be4SDimitry Andric // 0 | x * u | x * v |
1228fe013be4SDimitry Andric // 90 | -y * v | y * u |
1229fe013be4SDimitry Andric // 180 | -x * u | -x * v |
1230fe013be4SDimitry Andric // 270 | y * v | -y * u |
1231fe013be4SDimitry Andric //
1232fe013be4SDimitry Andric // Check if the candidate can indeed be represented by partial
1233fe013be4SDimitry Andric // multiplication
1234fe013be4SDimitry Andric // TODO: Add support for multiplication by complex one
1235fe013be4SDimitry Andric if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1236fe013be4SDimitry Andric (!IsMultiplicandReal && !PMI.IsNodeInverted))
1237fe013be4SDimitry Andric continue;
1238fe013be4SDimitry Andric
1239fe013be4SDimitry Andric // Determine the rotation based on the multiplications
1240fe013be4SDimitry Andric ComplexDeinterleavingRotation Rotation;
1241fe013be4SDimitry Andric if (IsMultiplicandReal) {
1242fe013be4SDimitry Andric // Detect 0 and 180 degrees rotation
1243fe013be4SDimitry Andric if (RealMul.IsPositive && ImagMul.IsPositive)
1244fe013be4SDimitry Andric Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0;
1245fe013be4SDimitry Andric else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1246fe013be4SDimitry Andric Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180;
1247fe013be4SDimitry Andric else
1248fe013be4SDimitry Andric continue;
1249fe013be4SDimitry Andric
1250fe013be4SDimitry Andric } else {
1251fe013be4SDimitry Andric // Detect 90 and 270 degrees rotation
1252fe013be4SDimitry Andric if (!RealMul.IsPositive && ImagMul.IsPositive)
1253fe013be4SDimitry Andric Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90;
1254fe013be4SDimitry Andric else if (RealMul.IsPositive && !ImagMul.IsPositive)
1255fe013be4SDimitry Andric Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270;
1256fe013be4SDimitry Andric else
1257fe013be4SDimitry Andric continue;
1258fe013be4SDimitry Andric }
1259fe013be4SDimitry Andric
1260fe013be4SDimitry Andric LLVM_DEBUG({
1261fe013be4SDimitry Andric dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1262fe013be4SDimitry Andric dbgs().indent(4) << "X: " << *NodeA->Real << "\n";
1263fe013be4SDimitry Andric dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n";
1264fe013be4SDimitry Andric dbgs().indent(4) << "U: " << *NodeB->Real << "\n";
1265fe013be4SDimitry Andric dbgs().indent(4) << "V: " << *NodeB->Imag << "\n";
1266fe013be4SDimitry Andric dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1267fe013be4SDimitry Andric });
1268fe013be4SDimitry Andric
1269fe013be4SDimitry Andric NodePtr NodeMul = prepareCompositeNode(
1270fe013be4SDimitry Andric ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
1271fe013be4SDimitry Andric NodeMul->Rotation = Rotation;
1272fe013be4SDimitry Andric NodeMul->addOperand(NodeA);
1273fe013be4SDimitry Andric NodeMul->addOperand(NodeB);
1274fe013be4SDimitry Andric if (Result)
1275fe013be4SDimitry Andric NodeMul->addOperand(Result);
1276fe013be4SDimitry Andric submitCompositeNode(NodeMul);
1277fe013be4SDimitry Andric Result = NodeMul;
1278fe013be4SDimitry Andric ProcessedReal[PMI.RealIdx] = true;
1279fe013be4SDimitry Andric ProcessedImag[PMI.ImagIdx] = true;
1280fe013be4SDimitry Andric }
1281fe013be4SDimitry Andric
1282fe013be4SDimitry Andric // Ensure all products have been processed, if not return nullptr.
1283fe013be4SDimitry Andric if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
1284fe013be4SDimitry Andric !all_of(ProcessedImag, [](bool V) { return V; })) {
1285fe013be4SDimitry Andric
1286fe013be4SDimitry Andric // Dump debug information about which partial multiplications are not
1287fe013be4SDimitry Andric // processed.
1288fe013be4SDimitry Andric LLVM_DEBUG({
1289fe013be4SDimitry Andric dbgs() << "Unprocessed products (Real):\n";
1290fe013be4SDimitry Andric for (size_t i = 0; i < ProcessedReal.size(); ++i) {
1291fe013be4SDimitry Andric if (!ProcessedReal[i])
1292fe013be4SDimitry Andric dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
1293fe013be4SDimitry Andric << *RealMuls[i].Multiplier << " multiplied by "
1294fe013be4SDimitry Andric << *RealMuls[i].Multiplicand << "\n";
1295fe013be4SDimitry Andric }
1296fe013be4SDimitry Andric dbgs() << "Unprocessed products (Imag):\n";
1297fe013be4SDimitry Andric for (size_t i = 0; i < ProcessedImag.size(); ++i) {
1298fe013be4SDimitry Andric if (!ProcessedImag[i])
1299fe013be4SDimitry Andric dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
1300fe013be4SDimitry Andric << *ImagMuls[i].Multiplier << " multiplied by "
1301fe013be4SDimitry Andric << *ImagMuls[i].Multiplicand << "\n";
1302fe013be4SDimitry Andric }
1303fe013be4SDimitry Andric });
1304fe013be4SDimitry Andric return nullptr;
1305fe013be4SDimitry Andric }
1306fe013be4SDimitry Andric
1307fe013be4SDimitry Andric return Result;
1308fe013be4SDimitry Andric }
1309fe013be4SDimitry Andric
1310fe013be4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyAdditions(std::list<Addend> & RealAddends,std::list<Addend> & ImagAddends,std::optional<FastMathFlags> Flags,NodePtr Accumulator=nullptr)1311fe013be4SDimitry Andric ComplexDeinterleavingGraph::identifyAdditions(
1312fe013be4SDimitry Andric std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends,
1313fe013be4SDimitry Andric std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) {
1314fe013be4SDimitry Andric if (RealAddends.size() != ImagAddends.size())
1315fe013be4SDimitry Andric return nullptr;
1316fe013be4SDimitry Andric
1317fe013be4SDimitry Andric NodePtr Result;
1318fe013be4SDimitry Andric // If we have accumulator use it as first addend
1319fe013be4SDimitry Andric if (Accumulator)
1320fe013be4SDimitry Andric Result = Accumulator;
1321fe013be4SDimitry Andric // Otherwise find an element with both positive real and imaginary parts.
1322fe013be4SDimitry Andric else
1323fe013be4SDimitry Andric Result = extractPositiveAddend(RealAddends, ImagAddends);
1324fe013be4SDimitry Andric
1325fe013be4SDimitry Andric if (!Result)
1326fe013be4SDimitry Andric return nullptr;
1327fe013be4SDimitry Andric
1328fe013be4SDimitry Andric while (!RealAddends.empty()) {
1329fe013be4SDimitry Andric auto ItR = RealAddends.begin();
1330fe013be4SDimitry Andric auto [R, IsPositiveR] = *ItR;
1331fe013be4SDimitry Andric
1332fe013be4SDimitry Andric bool FoundImag = false;
1333fe013be4SDimitry Andric for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1334fe013be4SDimitry Andric auto [I, IsPositiveI] = *ItI;
1335fe013be4SDimitry Andric ComplexDeinterleavingRotation Rotation;
1336fe013be4SDimitry Andric if (IsPositiveR && IsPositiveI)
1337fe013be4SDimitry Andric Rotation = ComplexDeinterleavingRotation::Rotation_0;
1338fe013be4SDimitry Andric else if (!IsPositiveR && IsPositiveI)
1339fe013be4SDimitry Andric Rotation = ComplexDeinterleavingRotation::Rotation_90;
1340fe013be4SDimitry Andric else if (!IsPositiveR && !IsPositiveI)
1341fe013be4SDimitry Andric Rotation = ComplexDeinterleavingRotation::Rotation_180;
1342fe013be4SDimitry Andric else
1343fe013be4SDimitry Andric Rotation = ComplexDeinterleavingRotation::Rotation_270;
1344fe013be4SDimitry Andric
1345fe013be4SDimitry Andric NodePtr AddNode;
1346fe013be4SDimitry Andric if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1347fe013be4SDimitry Andric Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1348fe013be4SDimitry Andric AddNode = identifyNode(R, I);
1349fe013be4SDimitry Andric } else {
1350fe013be4SDimitry Andric AddNode = identifyNode(I, R);
1351fe013be4SDimitry Andric }
1352fe013be4SDimitry Andric if (AddNode) {
1353fe013be4SDimitry Andric LLVM_DEBUG({
1354fe013be4SDimitry Andric dbgs() << "Identified addition:\n";
1355fe013be4SDimitry Andric dbgs().indent(4) << "X: " << *R << "\n";
1356fe013be4SDimitry Andric dbgs().indent(4) << "Y: " << *I << "\n";
1357fe013be4SDimitry Andric dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1358fe013be4SDimitry Andric });
1359fe013be4SDimitry Andric
1360fe013be4SDimitry Andric NodePtr TmpNode;
1361fe013be4SDimitry Andric if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) {
1362fe013be4SDimitry Andric TmpNode = prepareCompositeNode(
1363fe013be4SDimitry Andric ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1364fe013be4SDimitry Andric if (Flags) {
1365fe013be4SDimitry Andric TmpNode->Opcode = Instruction::FAdd;
1366fe013be4SDimitry Andric TmpNode->Flags = *Flags;
1367fe013be4SDimitry Andric } else {
1368fe013be4SDimitry Andric TmpNode->Opcode = Instruction::Add;
1369fe013be4SDimitry Andric }
1370fe013be4SDimitry Andric } else if (Rotation ==
1371fe013be4SDimitry Andric llvm::ComplexDeinterleavingRotation::Rotation_180) {
1372fe013be4SDimitry Andric TmpNode = prepareCompositeNode(
1373fe013be4SDimitry Andric ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1374fe013be4SDimitry Andric if (Flags) {
1375fe013be4SDimitry Andric TmpNode->Opcode = Instruction::FSub;
1376fe013be4SDimitry Andric TmpNode->Flags = *Flags;
1377fe013be4SDimitry Andric } else {
1378fe013be4SDimitry Andric TmpNode->Opcode = Instruction::Sub;
1379fe013be4SDimitry Andric }
1380fe013be4SDimitry Andric } else {
1381fe013be4SDimitry Andric TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1382fe013be4SDimitry Andric nullptr, nullptr);
1383fe013be4SDimitry Andric TmpNode->Rotation = Rotation;
1384fe013be4SDimitry Andric }
1385fe013be4SDimitry Andric
1386fe013be4SDimitry Andric TmpNode->addOperand(Result);
1387fe013be4SDimitry Andric TmpNode->addOperand(AddNode);
1388fe013be4SDimitry Andric submitCompositeNode(TmpNode);
1389fe013be4SDimitry Andric Result = TmpNode;
1390fe013be4SDimitry Andric RealAddends.erase(ItR);
1391fe013be4SDimitry Andric ImagAddends.erase(ItI);
1392fe013be4SDimitry Andric FoundImag = true;
1393fe013be4SDimitry Andric break;
1394fe013be4SDimitry Andric }
1395fe013be4SDimitry Andric }
1396fe013be4SDimitry Andric if (!FoundImag)
1397fe013be4SDimitry Andric return nullptr;
1398fe013be4SDimitry Andric }
1399fe013be4SDimitry Andric return Result;
1400fe013be4SDimitry Andric }
1401fe013be4SDimitry Andric
1402fe013be4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
extractPositiveAddend(std::list<Addend> & RealAddends,std::list<Addend> & ImagAddends)1403fe013be4SDimitry Andric ComplexDeinterleavingGraph::extractPositiveAddend(
1404fe013be4SDimitry Andric std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) {
1405fe013be4SDimitry Andric for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1406fe013be4SDimitry Andric for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1407fe013be4SDimitry Andric auto [R, IsPositiveR] = *ItR;
1408fe013be4SDimitry Andric auto [I, IsPositiveI] = *ItI;
1409fe013be4SDimitry Andric if (IsPositiveR && IsPositiveI) {
1410fe013be4SDimitry Andric auto Result = identifyNode(R, I);
1411fe013be4SDimitry Andric if (Result) {
1412fe013be4SDimitry Andric RealAddends.erase(ItR);
1413fe013be4SDimitry Andric ImagAddends.erase(ItI);
1414fe013be4SDimitry Andric return Result;
1415fe013be4SDimitry Andric }
1416fe013be4SDimitry Andric }
1417fe013be4SDimitry Andric }
1418fe013be4SDimitry Andric }
1419fe013be4SDimitry Andric return nullptr;
1420fe013be4SDimitry Andric }
1421fe013be4SDimitry Andric
identifyNodes(Instruction * RootI)1422fe013be4SDimitry Andric bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1423fe013be4SDimitry Andric // This potential root instruction might already have been recognized as
1424fe013be4SDimitry Andric // reduction. Because RootToNode maps both Real and Imaginary parts to
1425fe013be4SDimitry Andric // CompositeNode we should choose only one either Real or Imag instruction to
1426fe013be4SDimitry Andric // use as an anchor for generating complex instruction.
1427fe013be4SDimitry Andric auto It = RootToNode.find(RootI);
1428271697daSDimitry Andric if (It != RootToNode.end()) {
1429271697daSDimitry Andric auto RootNode = It->second;
1430271697daSDimitry Andric assert(RootNode->Operation ==
1431271697daSDimitry Andric ComplexDeinterleavingOperation::ReductionOperation);
1432271697daSDimitry Andric // Find out which part, Real or Imag, comes later, and only if we come to
1433271697daSDimitry Andric // the latest part, add it to OrderedRoots.
1434271697daSDimitry Andric auto *R = cast<Instruction>(RootNode->Real);
1435271697daSDimitry Andric auto *I = cast<Instruction>(RootNode->Imag);
1436271697daSDimitry Andric auto *ReplacementAnchor = R->comesBefore(I) ? I : R;
1437271697daSDimitry Andric if (ReplacementAnchor != RootI)
1438271697daSDimitry Andric return false;
1439fe013be4SDimitry Andric OrderedRoots.push_back(RootI);
1440fe013be4SDimitry Andric return true;
1441fe013be4SDimitry Andric }
1442fe013be4SDimitry Andric
1443fe013be4SDimitry Andric auto RootNode = identifyRoot(RootI);
1444fe013be4SDimitry Andric if (!RootNode)
1445fe013be4SDimitry Andric return false;
1446fe013be4SDimitry Andric
1447fe013be4SDimitry Andric LLVM_DEBUG({
1448fe013be4SDimitry Andric Function *F = RootI->getFunction();
1449fe013be4SDimitry Andric BasicBlock *B = RootI->getParent();
1450fe013be4SDimitry Andric dbgs() << "Complex deinterleaving graph for " << F->getName()
1451fe013be4SDimitry Andric << "::" << B->getName() << ".\n";
1452fe013be4SDimitry Andric dump(dbgs());
1453fe013be4SDimitry Andric dbgs() << "\n";
1454fe013be4SDimitry Andric });
1455fe013be4SDimitry Andric RootToNode[RootI] = RootNode;
1456fe013be4SDimitry Andric OrderedRoots.push_back(RootI);
1457fe013be4SDimitry Andric return true;
1458fe013be4SDimitry Andric }
1459fe013be4SDimitry Andric
collectPotentialReductions(BasicBlock * B)1460fe013be4SDimitry Andric bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
1461fe013be4SDimitry Andric bool FoundPotentialReduction = false;
1462fe013be4SDimitry Andric
1463fe013be4SDimitry Andric auto *Br = dyn_cast<BranchInst>(B->getTerminator());
1464fe013be4SDimitry Andric if (!Br || Br->getNumSuccessors() != 2)
1465fe013be4SDimitry Andric return false;
1466fe013be4SDimitry Andric
1467fe013be4SDimitry Andric // Identify simple one-block loop
1468fe013be4SDimitry Andric if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
1469fe013be4SDimitry Andric return false;
1470fe013be4SDimitry Andric
1471fe013be4SDimitry Andric SmallVector<PHINode *> PHIs;
1472fe013be4SDimitry Andric for (auto &PHI : B->phis()) {
1473fe013be4SDimitry Andric if (PHI.getNumIncomingValues() != 2)
1474fe013be4SDimitry Andric continue;
1475fe013be4SDimitry Andric
1476fe013be4SDimitry Andric if (!PHI.getType()->isVectorTy())
1477fe013be4SDimitry Andric continue;
1478fe013be4SDimitry Andric
1479fe013be4SDimitry Andric auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
1480fe013be4SDimitry Andric if (!ReductionOp)
1481fe013be4SDimitry Andric continue;
1482fe013be4SDimitry Andric
1483fe013be4SDimitry Andric // Check if final instruction is reduced outside of current block
1484fe013be4SDimitry Andric Instruction *FinalReduction = nullptr;
1485fe013be4SDimitry Andric auto NumUsers = 0u;
1486fe013be4SDimitry Andric for (auto *U : ReductionOp->users()) {
1487fe013be4SDimitry Andric ++NumUsers;
1488fe013be4SDimitry Andric if (U == &PHI)
1489fe013be4SDimitry Andric continue;
1490fe013be4SDimitry Andric FinalReduction = dyn_cast<Instruction>(U);
1491fe013be4SDimitry Andric }
1492fe013be4SDimitry Andric
1493fe013be4SDimitry Andric if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B ||
1494fe013be4SDimitry Andric isa<PHINode>(FinalReduction))
1495fe013be4SDimitry Andric continue;
1496fe013be4SDimitry Andric
1497fe013be4SDimitry Andric ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
1498fe013be4SDimitry Andric BackEdge = B;
1499fe013be4SDimitry Andric auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
1500fe013be4SDimitry Andric auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1501fe013be4SDimitry Andric Incoming = PHI.getIncomingBlock(IncomingIdx);
1502fe013be4SDimitry Andric FoundPotentialReduction = true;
1503fe013be4SDimitry Andric
1504fe013be4SDimitry Andric // If the initial value of PHINode is an Instruction, consider it a leaf
1505fe013be4SDimitry Andric // value of a complex deinterleaving graph.
1506fe013be4SDimitry Andric if (auto *InitPHI =
1507fe013be4SDimitry Andric dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
1508fe013be4SDimitry Andric FinalInstructions.insert(InitPHI);
1509fe013be4SDimitry Andric }
1510fe013be4SDimitry Andric return FoundPotentialReduction;
1511fe013be4SDimitry Andric }
1512fe013be4SDimitry Andric
identifyReductionNodes()1513fe013be4SDimitry Andric void ComplexDeinterleavingGraph::identifyReductionNodes() {
1514fe013be4SDimitry Andric SmallVector<bool> Processed(ReductionInfo.size(), false);
1515fe013be4SDimitry Andric SmallVector<Instruction *> OperationInstruction;
1516fe013be4SDimitry Andric for (auto &P : ReductionInfo)
1517fe013be4SDimitry Andric OperationInstruction.push_back(P.first);
1518fe013be4SDimitry Andric
1519fe013be4SDimitry Andric // Identify a complex computation by evaluating two reduction operations that
1520fe013be4SDimitry Andric // potentially could be involved
1521fe013be4SDimitry Andric for (size_t i = 0; i < OperationInstruction.size(); ++i) {
1522fe013be4SDimitry Andric if (Processed[i])
1523fe013be4SDimitry Andric continue;
1524fe013be4SDimitry Andric for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
1525fe013be4SDimitry Andric if (Processed[j])
1526fe013be4SDimitry Andric continue;
1527fe013be4SDimitry Andric
1528fe013be4SDimitry Andric auto *Real = OperationInstruction[i];
1529fe013be4SDimitry Andric auto *Imag = OperationInstruction[j];
1530fe013be4SDimitry Andric if (Real->getType() != Imag->getType())
1531fe013be4SDimitry Andric continue;
1532fe013be4SDimitry Andric
1533fe013be4SDimitry Andric RealPHI = ReductionInfo[Real].first;
1534fe013be4SDimitry Andric ImagPHI = ReductionInfo[Imag].first;
1535fe013be4SDimitry Andric PHIsFound = false;
1536fe013be4SDimitry Andric auto Node = identifyNode(Real, Imag);
1537fe013be4SDimitry Andric if (!Node) {
1538fe013be4SDimitry Andric std::swap(Real, Imag);
1539fe013be4SDimitry Andric std::swap(RealPHI, ImagPHI);
1540fe013be4SDimitry Andric Node = identifyNode(Real, Imag);
1541fe013be4SDimitry Andric }
1542fe013be4SDimitry Andric
1543fe013be4SDimitry Andric // If a node is identified and reduction PHINode is used in the chain of
1544fe013be4SDimitry Andric // operations, mark its operation instructions as used to prevent
1545fe013be4SDimitry Andric // re-identification and attach the node to the real part
1546fe013be4SDimitry Andric if (Node && PHIsFound) {
1547fe013be4SDimitry Andric LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
1548fe013be4SDimitry Andric << *Real << " / " << *Imag << "\n");
1549fe013be4SDimitry Andric Processed[i] = true;
1550fe013be4SDimitry Andric Processed[j] = true;
1551fe013be4SDimitry Andric auto RootNode = prepareCompositeNode(
1552fe013be4SDimitry Andric ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1553fe013be4SDimitry Andric RootNode->addOperand(Node);
1554fe013be4SDimitry Andric RootToNode[Real] = RootNode;
1555fe013be4SDimitry Andric RootToNode[Imag] = RootNode;
1556fe013be4SDimitry Andric submitCompositeNode(RootNode);
1557fe013be4SDimitry Andric break;
1558fe013be4SDimitry Andric }
1559fe013be4SDimitry Andric }
1560fe013be4SDimitry Andric }
1561fe013be4SDimitry Andric
1562fe013be4SDimitry Andric RealPHI = nullptr;
1563fe013be4SDimitry Andric ImagPHI = nullptr;
1564fe013be4SDimitry Andric }
1565fe013be4SDimitry Andric
checkNodes()1566fe013be4SDimitry Andric bool ComplexDeinterleavingGraph::checkNodes() {
1567fe013be4SDimitry Andric // Collect all instructions from roots to leaves
1568fe013be4SDimitry Andric SmallPtrSet<Instruction *, 16> AllInstructions;
1569fe013be4SDimitry Andric SmallVector<Instruction *, 8> Worklist;
1570fe013be4SDimitry Andric for (auto &Pair : RootToNode)
1571fe013be4SDimitry Andric Worklist.push_back(Pair.first);
1572fe013be4SDimitry Andric
1573fe013be4SDimitry Andric // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1574fe013be4SDimitry Andric // chains
1575fe013be4SDimitry Andric while (!Worklist.empty()) {
1576fe013be4SDimitry Andric auto *I = Worklist.back();
1577fe013be4SDimitry Andric Worklist.pop_back();
1578fe013be4SDimitry Andric
1579fe013be4SDimitry Andric if (!AllInstructions.insert(I).second)
1580fe013be4SDimitry Andric continue;
1581fe013be4SDimitry Andric
1582fe013be4SDimitry Andric for (Value *Op : I->operands()) {
1583fe013be4SDimitry Andric if (auto *OpI = dyn_cast<Instruction>(Op)) {
1584fe013be4SDimitry Andric if (!FinalInstructions.count(I))
1585fe013be4SDimitry Andric Worklist.emplace_back(OpI);
1586fe013be4SDimitry Andric }
1587fe013be4SDimitry Andric }
1588fe013be4SDimitry Andric }
1589fe013be4SDimitry Andric
1590fe013be4SDimitry Andric // Find instructions that have users outside of chain
1591fe013be4SDimitry Andric SmallVector<Instruction *, 2> OuterInstructions;
1592fe013be4SDimitry Andric for (auto *I : AllInstructions) {
1593fe013be4SDimitry Andric // Skip root nodes
1594fe013be4SDimitry Andric if (RootToNode.count(I))
1595fe013be4SDimitry Andric continue;
1596fe013be4SDimitry Andric
1597fe013be4SDimitry Andric for (User *U : I->users()) {
1598fe013be4SDimitry Andric if (AllInstructions.count(cast<Instruction>(U)))
1599fe013be4SDimitry Andric continue;
1600fe013be4SDimitry Andric
1601fe013be4SDimitry Andric // Found an instruction that is not used by XCMLA/XCADD chain
1602fe013be4SDimitry Andric Worklist.emplace_back(I);
1603fe013be4SDimitry Andric break;
1604fe013be4SDimitry Andric }
1605fe013be4SDimitry Andric }
1606fe013be4SDimitry Andric
1607fe013be4SDimitry Andric // If any instructions are found to be used outside, find and remove roots
1608fe013be4SDimitry Andric // that somehow connect to those instructions.
1609fe013be4SDimitry Andric SmallPtrSet<Instruction *, 16> Visited;
1610fe013be4SDimitry Andric while (!Worklist.empty()) {
1611fe013be4SDimitry Andric auto *I = Worklist.back();
1612fe013be4SDimitry Andric Worklist.pop_back();
1613fe013be4SDimitry Andric if (!Visited.insert(I).second)
1614fe013be4SDimitry Andric continue;
1615fe013be4SDimitry Andric
1616fe013be4SDimitry Andric // Found an impacted root node. Removing it from the nodes to be
1617fe013be4SDimitry Andric // deinterleaved
1618fe013be4SDimitry Andric if (RootToNode.count(I)) {
1619fe013be4SDimitry Andric LLVM_DEBUG(dbgs() << "Instruction " << *I
1620fe013be4SDimitry Andric << " could be deinterleaved but its chain of complex "
1621fe013be4SDimitry Andric "operations have an outside user\n");
1622fe013be4SDimitry Andric RootToNode.erase(I);
1623fe013be4SDimitry Andric }
1624fe013be4SDimitry Andric
1625fe013be4SDimitry Andric if (!AllInstructions.count(I) || FinalInstructions.count(I))
1626fe013be4SDimitry Andric continue;
1627fe013be4SDimitry Andric
1628fe013be4SDimitry Andric for (User *U : I->users())
1629fe013be4SDimitry Andric Worklist.emplace_back(cast<Instruction>(U));
1630fe013be4SDimitry Andric
1631fe013be4SDimitry Andric for (Value *Op : I->operands()) {
1632fe013be4SDimitry Andric if (auto *OpI = dyn_cast<Instruction>(Op))
1633fe013be4SDimitry Andric Worklist.emplace_back(OpI);
1634fe013be4SDimitry Andric }
1635fe013be4SDimitry Andric }
1636fe013be4SDimitry Andric return !RootToNode.empty();
1637fe013be4SDimitry Andric }
1638fe013be4SDimitry Andric
1639fe013be4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyRoot(Instruction * RootI)1640fe013be4SDimitry Andric ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1641fe013be4SDimitry Andric if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1642fe013be4SDimitry Andric if (Intrinsic->getIntrinsicID() !=
1643fe013be4SDimitry Andric Intrinsic::experimental_vector_interleave2)
1644fe013be4SDimitry Andric return nullptr;
1645fe013be4SDimitry Andric
1646fe013be4SDimitry Andric auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0));
1647fe013be4SDimitry Andric auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1));
1648fe013be4SDimitry Andric if (!Real || !Imag)
1649fe013be4SDimitry Andric return nullptr;
1650fe013be4SDimitry Andric
1651fe013be4SDimitry Andric return identifyNode(Real, Imag);
1652fe013be4SDimitry Andric }
1653fe013be4SDimitry Andric
1654fe013be4SDimitry Andric auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
1655fe013be4SDimitry Andric if (!SVI)
1656fe013be4SDimitry Andric return nullptr;
1657fe013be4SDimitry Andric
1658fe013be4SDimitry Andric // Look for a shufflevector that takes separate vectors of the real and
1659fe013be4SDimitry Andric // imaginary components and recombines them into a single vector.
1660fe013be4SDimitry Andric if (!isInterleavingMask(SVI->getShuffleMask()))
1661fe013be4SDimitry Andric return nullptr;
1662fe013be4SDimitry Andric
1663fe013be4SDimitry Andric Instruction *Real;
1664fe013be4SDimitry Andric Instruction *Imag;
1665fe013be4SDimitry Andric if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
1666fe013be4SDimitry Andric return nullptr;
1667fe013be4SDimitry Andric
1668fe013be4SDimitry Andric return identifyNode(Real, Imag);
1669fe013be4SDimitry Andric }
1670fe013be4SDimitry Andric
1671fe013be4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyDeinterleave(Instruction * Real,Instruction * Imag)1672fe013be4SDimitry Andric ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
1673fe013be4SDimitry Andric Instruction *Imag) {
1674fe013be4SDimitry Andric Instruction *I = nullptr;
1675fe013be4SDimitry Andric Value *FinalValue = nullptr;
1676fe013be4SDimitry Andric if (match(Real, m_ExtractValue<0>(m_Instruction(I))) &&
1677fe013be4SDimitry Andric match(Imag, m_ExtractValue<1>(m_Specific(I))) &&
1678fe013be4SDimitry Andric match(I, m_Intrinsic<Intrinsic::experimental_vector_deinterleave2>(
1679fe013be4SDimitry Andric m_Value(FinalValue)))) {
1680fe013be4SDimitry Andric NodePtr PlaceholderNode = prepareCompositeNode(
1681fe013be4SDimitry Andric llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag);
1682fe013be4SDimitry Andric PlaceholderNode->ReplacementNode = FinalValue;
1683fe013be4SDimitry Andric FinalInstructions.insert(Real);
1684fe013be4SDimitry Andric FinalInstructions.insert(Imag);
1685fe013be4SDimitry Andric return submitCompositeNode(PlaceholderNode);
1686fe013be4SDimitry Andric }
1687fe013be4SDimitry Andric
1688bdd1243dSDimitry Andric auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
1689bdd1243dSDimitry Andric auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
1690fe013be4SDimitry Andric if (!RealShuffle || !ImagShuffle) {
1691fe013be4SDimitry Andric if (RealShuffle || ImagShuffle)
1692fe013be4SDimitry Andric LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
1693fe013be4SDimitry Andric return nullptr;
1694fe013be4SDimitry Andric }
1695fe013be4SDimitry Andric
1696bdd1243dSDimitry Andric Value *RealOp1 = RealShuffle->getOperand(1);
1697bdd1243dSDimitry Andric if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
1698bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
1699bdd1243dSDimitry Andric return nullptr;
1700bdd1243dSDimitry Andric }
1701bdd1243dSDimitry Andric Value *ImagOp1 = ImagShuffle->getOperand(1);
1702bdd1243dSDimitry Andric if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
1703bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
1704bdd1243dSDimitry Andric return nullptr;
1705bdd1243dSDimitry Andric }
1706bdd1243dSDimitry Andric
1707bdd1243dSDimitry Andric Value *RealOp0 = RealShuffle->getOperand(0);
1708bdd1243dSDimitry Andric Value *ImagOp0 = ImagShuffle->getOperand(0);
1709bdd1243dSDimitry Andric
1710bdd1243dSDimitry Andric if (RealOp0 != ImagOp0) {
1711bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
1712bdd1243dSDimitry Andric return nullptr;
1713bdd1243dSDimitry Andric }
1714bdd1243dSDimitry Andric
1715bdd1243dSDimitry Andric ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
1716bdd1243dSDimitry Andric ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
1717bdd1243dSDimitry Andric if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
1718bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
1719bdd1243dSDimitry Andric return nullptr;
1720bdd1243dSDimitry Andric }
1721bdd1243dSDimitry Andric
1722bdd1243dSDimitry Andric if (RealMask[0] != 0 || ImagMask[0] != 1) {
1723bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
1724bdd1243dSDimitry Andric return nullptr;
1725bdd1243dSDimitry Andric }
1726bdd1243dSDimitry Andric
1727bdd1243dSDimitry Andric // Type checking, the shuffle type should be a vector type of the same
1728bdd1243dSDimitry Andric // scalar type, but half the size
1729bdd1243dSDimitry Andric auto CheckType = [&](ShuffleVectorInst *Shuffle) {
1730bdd1243dSDimitry Andric Value *Op = Shuffle->getOperand(0);
1731bdd1243dSDimitry Andric auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
1732bdd1243dSDimitry Andric auto *OpTy = cast<FixedVectorType>(Op->getType());
1733bdd1243dSDimitry Andric
1734bdd1243dSDimitry Andric if (OpTy->getScalarType() != ShuffleTy->getScalarType())
1735bdd1243dSDimitry Andric return false;
1736bdd1243dSDimitry Andric if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
1737bdd1243dSDimitry Andric return false;
1738bdd1243dSDimitry Andric
1739bdd1243dSDimitry Andric return true;
1740bdd1243dSDimitry Andric };
1741bdd1243dSDimitry Andric
1742bdd1243dSDimitry Andric auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
1743bdd1243dSDimitry Andric if (!CheckType(Shuffle))
1744bdd1243dSDimitry Andric return false;
1745bdd1243dSDimitry Andric
1746bdd1243dSDimitry Andric ArrayRef<int> Mask = Shuffle->getShuffleMask();
1747bdd1243dSDimitry Andric int Last = *Mask.rbegin();
1748bdd1243dSDimitry Andric
1749bdd1243dSDimitry Andric Value *Op = Shuffle->getOperand(0);
1750bdd1243dSDimitry Andric auto *OpTy = cast<FixedVectorType>(Op->getType());
1751bdd1243dSDimitry Andric int NumElements = OpTy->getNumElements();
1752bdd1243dSDimitry Andric
1753bdd1243dSDimitry Andric // Ensure that the deinterleaving shuffle only pulls from the first
1754bdd1243dSDimitry Andric // shuffle operand.
1755bdd1243dSDimitry Andric return Last < NumElements;
1756bdd1243dSDimitry Andric };
1757bdd1243dSDimitry Andric
1758bdd1243dSDimitry Andric if (RealShuffle->getType() != ImagShuffle->getType()) {
1759bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
1760bdd1243dSDimitry Andric return nullptr;
1761bdd1243dSDimitry Andric }
1762bdd1243dSDimitry Andric if (!CheckDeinterleavingShuffle(RealShuffle)) {
1763bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
1764bdd1243dSDimitry Andric return nullptr;
1765bdd1243dSDimitry Andric }
1766bdd1243dSDimitry Andric if (!CheckDeinterleavingShuffle(ImagShuffle)) {
1767bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
1768bdd1243dSDimitry Andric return nullptr;
1769bdd1243dSDimitry Andric }
1770bdd1243dSDimitry Andric
1771bdd1243dSDimitry Andric NodePtr PlaceholderNode =
1772fe013be4SDimitry Andric prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave,
1773bdd1243dSDimitry Andric RealShuffle, ImagShuffle);
1774bdd1243dSDimitry Andric PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
1775fe013be4SDimitry Andric FinalInstructions.insert(RealShuffle);
1776fe013be4SDimitry Andric FinalInstructions.insert(ImagShuffle);
1777bdd1243dSDimitry Andric return submitCompositeNode(PlaceholderNode);
1778bdd1243dSDimitry Andric }
1779bdd1243dSDimitry Andric
1780fe013be4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifySplat(Value * R,Value * I)1781fe013be4SDimitry Andric ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) {
1782fe013be4SDimitry Andric auto IsSplat = [](Value *V) -> bool {
1783fe013be4SDimitry Andric // Fixed-width vector with constants
1784fe013be4SDimitry Andric if (isa<ConstantDataVector>(V))
1785fe013be4SDimitry Andric return true;
1786bdd1243dSDimitry Andric
1787fe013be4SDimitry Andric VectorType *VTy;
1788fe013be4SDimitry Andric ArrayRef<int> Mask;
1789fe013be4SDimitry Andric // Splats are represented differently depending on whether the repeated
1790fe013be4SDimitry Andric // value is a constant or an Instruction
1791fe013be4SDimitry Andric if (auto *Const = dyn_cast<ConstantExpr>(V)) {
1792fe013be4SDimitry Andric if (Const->getOpcode() != Instruction::ShuffleVector)
1793bdd1243dSDimitry Andric return false;
1794fe013be4SDimitry Andric VTy = cast<VectorType>(Const->getType());
1795fe013be4SDimitry Andric Mask = Const->getShuffleMask();
1796fe013be4SDimitry Andric } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
1797fe013be4SDimitry Andric VTy = Shuf->getType();
1798fe013be4SDimitry Andric Mask = Shuf->getShuffleMask();
1799fe013be4SDimitry Andric } else {
1800bdd1243dSDimitry Andric return false;
1801bdd1243dSDimitry Andric }
1802fe013be4SDimitry Andric
1803fe013be4SDimitry Andric // When the data type is <1 x Type>, it's not possible to differentiate
1804fe013be4SDimitry Andric // between the ComplexDeinterleaving::Deinterleave and
1805fe013be4SDimitry Andric // ComplexDeinterleaving::Splat operations.
1806fe013be4SDimitry Andric if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
1807fe013be4SDimitry Andric return false;
1808fe013be4SDimitry Andric
1809fe013be4SDimitry Andric return all_equal(Mask) && Mask[0] == 0;
1810fe013be4SDimitry Andric };
1811fe013be4SDimitry Andric
1812fe013be4SDimitry Andric if (!IsSplat(R) || !IsSplat(I))
1813fe013be4SDimitry Andric return nullptr;
1814fe013be4SDimitry Andric
1815fe013be4SDimitry Andric auto *Real = dyn_cast<Instruction>(R);
1816fe013be4SDimitry Andric auto *Imag = dyn_cast<Instruction>(I);
1817fe013be4SDimitry Andric if ((!Real && Imag) || (Real && !Imag))
1818fe013be4SDimitry Andric return nullptr;
1819fe013be4SDimitry Andric
1820fe013be4SDimitry Andric if (Real && Imag) {
1821fe013be4SDimitry Andric // Non-constant splats should be in the same basic block
1822fe013be4SDimitry Andric if (Real->getParent() != Imag->getParent())
1823fe013be4SDimitry Andric return nullptr;
1824fe013be4SDimitry Andric
1825fe013be4SDimitry Andric FinalInstructions.insert(Real);
1826fe013be4SDimitry Andric FinalInstructions.insert(Imag);
1827bdd1243dSDimitry Andric }
1828fe013be4SDimitry Andric NodePtr PlaceholderNode =
1829fe013be4SDimitry Andric prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I);
1830fe013be4SDimitry Andric return submitCompositeNode(PlaceholderNode);
1831bdd1243dSDimitry Andric }
1832bdd1243dSDimitry Andric
1833fe013be4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyPHINode(Instruction * Real,Instruction * Imag)1834fe013be4SDimitry Andric ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
1835fe013be4SDimitry Andric Instruction *Imag) {
1836fe013be4SDimitry Andric if (Real != RealPHI || Imag != ImagPHI)
1837fe013be4SDimitry Andric return nullptr;
1838fe013be4SDimitry Andric
1839fe013be4SDimitry Andric PHIsFound = true;
1840fe013be4SDimitry Andric NodePtr PlaceholderNode = prepareCompositeNode(
1841fe013be4SDimitry Andric ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
1842fe013be4SDimitry Andric return submitCompositeNode(PlaceholderNode);
1843fe013be4SDimitry Andric }
1844fe013be4SDimitry Andric
1845fe013be4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifySelectNode(Instruction * Real,Instruction * Imag)1846fe013be4SDimitry Andric ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
1847fe013be4SDimitry Andric Instruction *Imag) {
1848fe013be4SDimitry Andric auto *SelectReal = dyn_cast<SelectInst>(Real);
1849fe013be4SDimitry Andric auto *SelectImag = dyn_cast<SelectInst>(Imag);
1850fe013be4SDimitry Andric if (!SelectReal || !SelectImag)
1851fe013be4SDimitry Andric return nullptr;
1852fe013be4SDimitry Andric
1853fe013be4SDimitry Andric Instruction *MaskA, *MaskB;
1854fe013be4SDimitry Andric Instruction *AR, *AI, *RA, *BI;
1855fe013be4SDimitry Andric if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
1856fe013be4SDimitry Andric m_Instruction(RA))) ||
1857fe013be4SDimitry Andric !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
1858fe013be4SDimitry Andric m_Instruction(BI))))
1859fe013be4SDimitry Andric return nullptr;
1860fe013be4SDimitry Andric
1861fe013be4SDimitry Andric if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
1862fe013be4SDimitry Andric return nullptr;
1863fe013be4SDimitry Andric
1864fe013be4SDimitry Andric if (!MaskA->getType()->isVectorTy())
1865fe013be4SDimitry Andric return nullptr;
1866fe013be4SDimitry Andric
1867fe013be4SDimitry Andric auto NodeA = identifyNode(AR, AI);
1868fe013be4SDimitry Andric if (!NodeA)
1869fe013be4SDimitry Andric return nullptr;
1870fe013be4SDimitry Andric
1871fe013be4SDimitry Andric auto NodeB = identifyNode(RA, BI);
1872fe013be4SDimitry Andric if (!NodeB)
1873fe013be4SDimitry Andric return nullptr;
1874fe013be4SDimitry Andric
1875fe013be4SDimitry Andric NodePtr PlaceholderNode = prepareCompositeNode(
1876fe013be4SDimitry Andric ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
1877fe013be4SDimitry Andric PlaceholderNode->addOperand(NodeA);
1878fe013be4SDimitry Andric PlaceholderNode->addOperand(NodeB);
1879fe013be4SDimitry Andric FinalInstructions.insert(MaskA);
1880fe013be4SDimitry Andric FinalInstructions.insert(MaskB);
1881fe013be4SDimitry Andric return submitCompositeNode(PlaceholderNode);
1882fe013be4SDimitry Andric }
1883fe013be4SDimitry Andric
replaceSymmetricNode(IRBuilderBase & B,unsigned Opcode,std::optional<FastMathFlags> Flags,Value * InputA,Value * InputB)1884fe013be4SDimitry Andric static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
1885fe013be4SDimitry Andric std::optional<FastMathFlags> Flags,
1886fe013be4SDimitry Andric Value *InputA, Value *InputB) {
1887fe013be4SDimitry Andric Value *I;
1888fe013be4SDimitry Andric switch (Opcode) {
1889fe013be4SDimitry Andric case Instruction::FNeg:
1890fe013be4SDimitry Andric I = B.CreateFNeg(InputA);
1891fe013be4SDimitry Andric break;
1892fe013be4SDimitry Andric case Instruction::FAdd:
1893fe013be4SDimitry Andric I = B.CreateFAdd(InputA, InputB);
1894fe013be4SDimitry Andric break;
1895fe013be4SDimitry Andric case Instruction::Add:
1896fe013be4SDimitry Andric I = B.CreateAdd(InputA, InputB);
1897fe013be4SDimitry Andric break;
1898fe013be4SDimitry Andric case Instruction::FSub:
1899fe013be4SDimitry Andric I = B.CreateFSub(InputA, InputB);
1900fe013be4SDimitry Andric break;
1901fe013be4SDimitry Andric case Instruction::Sub:
1902fe013be4SDimitry Andric I = B.CreateSub(InputA, InputB);
1903fe013be4SDimitry Andric break;
1904fe013be4SDimitry Andric case Instruction::FMul:
1905fe013be4SDimitry Andric I = B.CreateFMul(InputA, InputB);
1906fe013be4SDimitry Andric break;
1907fe013be4SDimitry Andric case Instruction::Mul:
1908fe013be4SDimitry Andric I = B.CreateMul(InputA, InputB);
1909fe013be4SDimitry Andric break;
1910fe013be4SDimitry Andric default:
1911fe013be4SDimitry Andric llvm_unreachable("Incorrect symmetric opcode");
1912fe013be4SDimitry Andric }
1913fe013be4SDimitry Andric if (Flags)
1914fe013be4SDimitry Andric cast<Instruction>(I)->setFastMathFlags(*Flags);
1915fe013be4SDimitry Andric return I;
1916fe013be4SDimitry Andric }
1917fe013be4SDimitry Andric
replaceNode(IRBuilderBase & Builder,RawNodePtr Node)1918fe013be4SDimitry Andric Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
1919fe013be4SDimitry Andric RawNodePtr Node) {
1920bdd1243dSDimitry Andric if (Node->ReplacementNode)
1921bdd1243dSDimitry Andric return Node->ReplacementNode;
1922bdd1243dSDimitry Andric
1923fe013be4SDimitry Andric auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * {
1924fe013be4SDimitry Andric return Node->Operands.size() > Idx
1925fe013be4SDimitry Andric ? replaceNode(Builder, Node->Operands[Idx])
1926fe013be4SDimitry Andric : nullptr;
1927fe013be4SDimitry Andric };
1928bdd1243dSDimitry Andric
1929fe013be4SDimitry Andric Value *ReplacementNode;
1930fe013be4SDimitry Andric switch (Node->Operation) {
1931fe013be4SDimitry Andric case ComplexDeinterleavingOperation::CAdd:
1932fe013be4SDimitry Andric case ComplexDeinterleavingOperation::CMulPartial:
1933fe013be4SDimitry Andric case ComplexDeinterleavingOperation::Symmetric: {
1934fe013be4SDimitry Andric Value *Input0 = ReplaceOperandIfExist(Node, 0);
1935fe013be4SDimitry Andric Value *Input1 = ReplaceOperandIfExist(Node, 1);
1936fe013be4SDimitry Andric Value *Accumulator = ReplaceOperandIfExist(Node, 2);
1937fe013be4SDimitry Andric assert(!Input1 || (Input0->getType() == Input1->getType() &&
1938fe013be4SDimitry Andric "Node inputs need to be of the same type"));
1939fe013be4SDimitry Andric assert(!Accumulator ||
1940fe013be4SDimitry Andric (Input0->getType() == Accumulator->getType() &&
1941fe013be4SDimitry Andric "Accumulator and input need to be of the same type"));
1942fe013be4SDimitry Andric if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
1943fe013be4SDimitry Andric ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
1944fe013be4SDimitry Andric Input0, Input1);
1945fe013be4SDimitry Andric else
1946fe013be4SDimitry Andric ReplacementNode = TL->createComplexDeinterleavingIR(
1947fe013be4SDimitry Andric Builder, Node->Operation, Node->Rotation, Input0, Input1,
1948fe013be4SDimitry Andric Accumulator);
1949fe013be4SDimitry Andric break;
1950fe013be4SDimitry Andric }
1951fe013be4SDimitry Andric case ComplexDeinterleavingOperation::Deinterleave:
1952fe013be4SDimitry Andric llvm_unreachable("Deinterleave node should already have ReplacementNode");
1953fe013be4SDimitry Andric break;
1954fe013be4SDimitry Andric case ComplexDeinterleavingOperation::Splat: {
1955fe013be4SDimitry Andric auto *NewTy = VectorType::getDoubleElementsVectorType(
1956fe013be4SDimitry Andric cast<VectorType>(Node->Real->getType()));
1957fe013be4SDimitry Andric auto *R = dyn_cast<Instruction>(Node->Real);
1958fe013be4SDimitry Andric auto *I = dyn_cast<Instruction>(Node->Imag);
1959fe013be4SDimitry Andric if (R && I) {
1960fe013be4SDimitry Andric // Splats that are not constant are interleaved where they are located
1961fe013be4SDimitry Andric Instruction *InsertPoint = (I->comesBefore(R) ? R : I)->getNextNode();
1962fe013be4SDimitry Andric IRBuilder<> IRB(InsertPoint);
1963fe013be4SDimitry Andric ReplacementNode =
1964fe013be4SDimitry Andric IRB.CreateIntrinsic(Intrinsic::experimental_vector_interleave2, NewTy,
1965fe013be4SDimitry Andric {Node->Real, Node->Imag});
1966fe013be4SDimitry Andric } else {
1967fe013be4SDimitry Andric ReplacementNode =
1968fe013be4SDimitry Andric Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2,
1969fe013be4SDimitry Andric NewTy, {Node->Real, Node->Imag});
1970fe013be4SDimitry Andric }
1971fe013be4SDimitry Andric break;
1972fe013be4SDimitry Andric }
1973fe013be4SDimitry Andric case ComplexDeinterleavingOperation::ReductionPHI: {
1974fe013be4SDimitry Andric // If Operation is ReductionPHI, a new empty PHINode is created.
1975fe013be4SDimitry Andric // It is filled later when the ReductionOperation is processed.
1976fe013be4SDimitry Andric auto *VTy = cast<VectorType>(Node->Real->getType());
1977fe013be4SDimitry Andric auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1978fe013be4SDimitry Andric auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHI());
1979fe013be4SDimitry Andric OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI;
1980fe013be4SDimitry Andric ReplacementNode = NewPHI;
1981fe013be4SDimitry Andric break;
1982fe013be4SDimitry Andric }
1983fe013be4SDimitry Andric case ComplexDeinterleavingOperation::ReductionOperation:
1984fe013be4SDimitry Andric ReplacementNode = replaceNode(Builder, Node->Operands[0]);
1985fe013be4SDimitry Andric processReductionOperation(ReplacementNode, Node);
1986fe013be4SDimitry Andric break;
1987fe013be4SDimitry Andric case ComplexDeinterleavingOperation::ReductionSelect: {
1988fe013be4SDimitry Andric auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0);
1989fe013be4SDimitry Andric auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0);
1990fe013be4SDimitry Andric auto *A = replaceNode(Builder, Node->Operands[0]);
1991fe013be4SDimitry Andric auto *B = replaceNode(Builder, Node->Operands[1]);
1992fe013be4SDimitry Andric auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
1993fe013be4SDimitry Andric cast<VectorType>(MaskReal->getType()));
1994fe013be4SDimitry Andric auto *NewMask =
1995fe013be4SDimitry Andric Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2,
1996fe013be4SDimitry Andric NewMaskTy, {MaskReal, MaskImag});
1997fe013be4SDimitry Andric ReplacementNode = Builder.CreateSelect(NewMask, A, B);
1998fe013be4SDimitry Andric break;
1999fe013be4SDimitry Andric }
2000fe013be4SDimitry Andric }
2001bdd1243dSDimitry Andric
2002fe013be4SDimitry Andric assert(ReplacementNode && "Target failed to create Intrinsic call.");
2003bdd1243dSDimitry Andric NumComplexTransformations += 1;
2004fe013be4SDimitry Andric Node->ReplacementNode = ReplacementNode;
2005fe013be4SDimitry Andric return ReplacementNode;
2006fe013be4SDimitry Andric }
2007fe013be4SDimitry Andric
processReductionOperation(Value * OperationReplacement,RawNodePtr Node)2008fe013be4SDimitry Andric void ComplexDeinterleavingGraph::processReductionOperation(
2009fe013be4SDimitry Andric Value *OperationReplacement, RawNodePtr Node) {
2010fe013be4SDimitry Andric auto *Real = cast<Instruction>(Node->Real);
2011fe013be4SDimitry Andric auto *Imag = cast<Instruction>(Node->Imag);
2012fe013be4SDimitry Andric auto *OldPHIReal = ReductionInfo[Real].first;
2013fe013be4SDimitry Andric auto *OldPHIImag = ReductionInfo[Imag].first;
2014fe013be4SDimitry Andric auto *NewPHI = OldToNewPHI[OldPHIReal];
2015fe013be4SDimitry Andric
2016fe013be4SDimitry Andric auto *VTy = cast<VectorType>(Real->getType());
2017fe013be4SDimitry Andric auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2018fe013be4SDimitry Andric
2019fe013be4SDimitry Andric // We have to interleave initial origin values coming from IncomingBlock
2020fe013be4SDimitry Andric Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
2021fe013be4SDimitry Andric Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
2022fe013be4SDimitry Andric
2023fe013be4SDimitry Andric IRBuilder<> Builder(Incoming->getTerminator());
2024fe013be4SDimitry Andric auto *NewInit = Builder.CreateIntrinsic(
2025fe013be4SDimitry Andric Intrinsic::experimental_vector_interleave2, NewVTy, {InitReal, InitImag});
2026fe013be4SDimitry Andric
2027fe013be4SDimitry Andric NewPHI->addIncoming(NewInit, Incoming);
2028fe013be4SDimitry Andric NewPHI->addIncoming(OperationReplacement, BackEdge);
2029fe013be4SDimitry Andric
2030fe013be4SDimitry Andric // Deinterleave complex vector outside of loop so that it can be finally
2031fe013be4SDimitry Andric // reduced
2032fe013be4SDimitry Andric auto *FinalReductionReal = ReductionInfo[Real].second;
2033fe013be4SDimitry Andric auto *FinalReductionImag = ReductionInfo[Imag].second;
2034fe013be4SDimitry Andric
2035fe013be4SDimitry Andric Builder.SetInsertPoint(
2036fe013be4SDimitry Andric &*FinalReductionReal->getParent()->getFirstInsertionPt());
2037fe013be4SDimitry Andric auto *Deinterleave = Builder.CreateIntrinsic(
2038fe013be4SDimitry Andric Intrinsic::experimental_vector_deinterleave2,
2039fe013be4SDimitry Andric OperationReplacement->getType(), OperationReplacement);
2040fe013be4SDimitry Andric
2041fe013be4SDimitry Andric auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
2042fe013be4SDimitry Andric FinalReductionReal->replaceUsesOfWith(Real, NewReal);
2043fe013be4SDimitry Andric
2044fe013be4SDimitry Andric Builder.SetInsertPoint(FinalReductionImag);
2045fe013be4SDimitry Andric auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
2046fe013be4SDimitry Andric FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2047bdd1243dSDimitry Andric }
2048bdd1243dSDimitry Andric
replaceNodes()2049bdd1243dSDimitry Andric void ComplexDeinterleavingGraph::replaceNodes() {
2050fe013be4SDimitry Andric SmallVector<Instruction *, 16> DeadInstrRoots;
2051fe013be4SDimitry Andric for (auto *RootInstruction : OrderedRoots) {
2052fe013be4SDimitry Andric // Check if this potential root went through check process and we can
2053fe013be4SDimitry Andric // deinterleave it
2054fe013be4SDimitry Andric if (!RootToNode.count(RootInstruction))
2055fe013be4SDimitry Andric continue;
2056fe013be4SDimitry Andric
2057fe013be4SDimitry Andric IRBuilder<> Builder(RootInstruction);
2058fe013be4SDimitry Andric auto RootNode = RootToNode[RootInstruction];
2059fe013be4SDimitry Andric Value *R = replaceNode(Builder, RootNode.get());
2060fe013be4SDimitry Andric
2061fe013be4SDimitry Andric if (RootNode->Operation ==
2062fe013be4SDimitry Andric ComplexDeinterleavingOperation::ReductionOperation) {
2063fe013be4SDimitry Andric auto *RootReal = cast<Instruction>(RootNode->Real);
2064fe013be4SDimitry Andric auto *RootImag = cast<Instruction>(RootNode->Imag);
2065fe013be4SDimitry Andric ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
2066fe013be4SDimitry Andric ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2067fe013be4SDimitry Andric DeadInstrRoots.push_back(cast<Instruction>(RootReal));
2068fe013be4SDimitry Andric DeadInstrRoots.push_back(cast<Instruction>(RootImag));
2069fe013be4SDimitry Andric } else {
2070fe013be4SDimitry Andric assert(R && "Unable to find replacement for RootInstruction");
2071fe013be4SDimitry Andric DeadInstrRoots.push_back(RootInstruction);
2072fe013be4SDimitry Andric RootInstruction->replaceAllUsesWith(R);
2073fe013be4SDimitry Andric }
2074bdd1243dSDimitry Andric }
2075bdd1243dSDimitry Andric
2076fe013be4SDimitry Andric for (auto *I : DeadInstrRoots)
2077fe013be4SDimitry Andric RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
2078bdd1243dSDimitry Andric }
2079