//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Identification:
// This step is responsible for finding the patterns that can be lowered to
// complex instructions, and building a graph to represent the complex
// structures. Starting from the "Converging Shuffle" (a shuffle that
// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
// operands are evaluated and identified as "Composite Nodes" (collections of
// instructions that can potentially be lowered to a single complex
// instruction). This is performed by checking the real and imaginary components
// and tracking the data flow for each component while following the operand
// pairs. Validity of each node is expected to be done upon creation, and any
// validation errors should halt traversal and prevent further graph
// construction.
// Instead of relying on Shuffle operations, vector interleaving and
// deinterleaving can be represented by vector.interleave2 and
// vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
// these intrinsics, whereas fixed-width vectors are recognized for both
// shufflevector instruction and intrinsics.
//
// Replacement:
// This step traverses the graph built up by identification, delegating to the
// target to validate and generate the correct intrinsics, and plumbs them
// together connecting each end of the new intrinsics graph to the existing
// use-def chain. This step is assumed to finish successfully, as all
// information is expected to be correct by this point.
//
//
// Internal data structure:
// ComplexDeinterleavingGraph:
// Keeps references to all the valid CompositeNodes formed as part of the
// transformation, and every Instruction contained within said nodes. It also
// holds onto a reference to the root Instruction, and the root node that should
// replace it.
//
// ComplexDeinterleavingCompositeNode:
// A CompositeNode represents a single transformation point; each node should
// transform into a single complex instruction (ignoring vector splitting, which
// would generate more instructions per node). They are identified in a
// depth-first manner, traversing and identifying the operands of each
// instruction in the order they appear in the IR.
// Each node maintains a reference to its Real and Imaginary instructions,
// as well as any additional instructions that make up the identified operation
// (Internal instructions should only have uses within their containing node).
// A Node also contains the rotation and operation type that it represents.
// Operands contains pointers to other CompositeNodes, acting as the edges in
// the graph. ReplacementNode is the transformed Value* that has been emitted
// to the IR.
//
// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
// ReplacementNode fields of that Node are relevant, where the ReplacementNode
// should be pre-populated.
//
//===----------------------------------------------------------------------===//
61bdd1243dSDimitry Andric 
62bdd1243dSDimitry Andric #include "llvm/CodeGen/ComplexDeinterleavingPass.h"
63*c9157d92SDimitry Andric #include "llvm/ADT/MapVector.h"
64bdd1243dSDimitry Andric #include "llvm/ADT/Statistic.h"
65bdd1243dSDimitry Andric #include "llvm/Analysis/TargetLibraryInfo.h"
66bdd1243dSDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h"
67bdd1243dSDimitry Andric #include "llvm/CodeGen/TargetLowering.h"
68bdd1243dSDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h"
69bdd1243dSDimitry Andric #include "llvm/CodeGen/TargetSubtargetInfo.h"
70bdd1243dSDimitry Andric #include "llvm/IR/IRBuilder.h"
71fe013be4SDimitry Andric #include "llvm/IR/PatternMatch.h"
72bdd1243dSDimitry Andric #include "llvm/InitializePasses.h"
73bdd1243dSDimitry Andric #include "llvm/Target/TargetMachine.h"
74bdd1243dSDimitry Andric #include "llvm/Transforms/Utils/Local.h"
75bdd1243dSDimitry Andric #include <algorithm>
76bdd1243dSDimitry Andric 
77bdd1243dSDimitry Andric using namespace llvm;
78bdd1243dSDimitry Andric using namespace PatternMatch;
79bdd1243dSDimitry Andric 
80bdd1243dSDimitry Andric #define DEBUG_TYPE "complex-deinterleaving"
81bdd1243dSDimitry Andric 
82bdd1243dSDimitry Andric STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
83bdd1243dSDimitry Andric 
84bdd1243dSDimitry Andric static cl::opt<bool> ComplexDeinterleavingEnabled(
85bdd1243dSDimitry Andric     "enable-complex-deinterleaving",
86bdd1243dSDimitry Andric     cl::desc("Enable generation of complex instructions"), cl::init(true),
87bdd1243dSDimitry Andric     cl::Hidden);
88bdd1243dSDimitry Andric 
89bdd1243dSDimitry Andric /// Checks the given mask, and determines whether said mask is interleaving.
90bdd1243dSDimitry Andric ///
91bdd1243dSDimitry Andric /// To be interleaving, a mask must alternate between `i` and `i + (Length /
92bdd1243dSDimitry Andric /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
93bdd1243dSDimitry Andric /// 4x vector interleaving mask would be <0, 2, 1, 3>).
94bdd1243dSDimitry Andric static bool isInterleavingMask(ArrayRef<int> Mask);
95bdd1243dSDimitry Andric 
96bdd1243dSDimitry Andric /// Checks the given mask, and determines whether said mask is deinterleaving.
97bdd1243dSDimitry Andric ///
98bdd1243dSDimitry Andric /// To be deinterleaving, a mask must increment in steps of 2, and either start
99bdd1243dSDimitry Andric /// with 0 or 1.
100bdd1243dSDimitry Andric /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
101bdd1243dSDimitry Andric /// <1, 3, 5, 7>).
102bdd1243dSDimitry Andric static bool isDeinterleavingMask(ArrayRef<int> Mask);
103bdd1243dSDimitry Andric 
104fe013be4SDimitry Andric /// Returns true if the operation is a negation of V, and it works for both
105fe013be4SDimitry Andric /// integers and floats.
106fe013be4SDimitry Andric static bool isNeg(Value *V);
107fe013be4SDimitry Andric 
108fe013be4SDimitry Andric /// Returns the operand for negation operation.
109fe013be4SDimitry Andric static Value *getNegOperand(Value *V);
110fe013be4SDimitry Andric 
111bdd1243dSDimitry Andric namespace {
112bdd1243dSDimitry Andric 
113bdd1243dSDimitry Andric class ComplexDeinterleavingLegacyPass : public FunctionPass {
114bdd1243dSDimitry Andric public:
115bdd1243dSDimitry Andric   static char ID;
116bdd1243dSDimitry Andric 
ComplexDeinterleavingLegacyPass(const TargetMachine * TM=nullptr)117bdd1243dSDimitry Andric   ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
118bdd1243dSDimitry Andric       : FunctionPass(ID), TM(TM) {
119bdd1243dSDimitry Andric     initializeComplexDeinterleavingLegacyPassPass(
120bdd1243dSDimitry Andric         *PassRegistry::getPassRegistry());
121bdd1243dSDimitry Andric   }
122bdd1243dSDimitry Andric 
getPassName() const123bdd1243dSDimitry Andric   StringRef getPassName() const override {
124bdd1243dSDimitry Andric     return "Complex Deinterleaving Pass";
125bdd1243dSDimitry Andric   }
126bdd1243dSDimitry Andric 
127bdd1243dSDimitry Andric   bool runOnFunction(Function &F) override;
getAnalysisUsage(AnalysisUsage & AU) const128bdd1243dSDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
129bdd1243dSDimitry Andric     AU.addRequired<TargetLibraryInfoWrapperPass>();
130bdd1243dSDimitry Andric     AU.setPreservesCFG();
131bdd1243dSDimitry Andric   }
132bdd1243dSDimitry Andric 
133bdd1243dSDimitry Andric private:
134bdd1243dSDimitry Andric   const TargetMachine *TM;
135bdd1243dSDimitry Andric };
136bdd1243dSDimitry Andric 
137bdd1243dSDimitry Andric class ComplexDeinterleavingGraph;
138bdd1243dSDimitry Andric struct ComplexDeinterleavingCompositeNode {
139bdd1243dSDimitry Andric 
ComplexDeinterleavingCompositeNode__anon4b5acc0a0111::ComplexDeinterleavingCompositeNode140bdd1243dSDimitry Andric   ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
141fe013be4SDimitry Andric                                      Value *R, Value *I)
142bdd1243dSDimitry Andric       : Operation(Op), Real(R), Imag(I) {}
143bdd1243dSDimitry Andric 
144bdd1243dSDimitry Andric private:
145bdd1243dSDimitry Andric   friend class ComplexDeinterleavingGraph;
146bdd1243dSDimitry Andric   using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
147bdd1243dSDimitry Andric   using RawNodePtr = ComplexDeinterleavingCompositeNode *;
148bdd1243dSDimitry Andric 
149bdd1243dSDimitry Andric public:
150bdd1243dSDimitry Andric   ComplexDeinterleavingOperation Operation;
151fe013be4SDimitry Andric   Value *Real;
152fe013be4SDimitry Andric   Value *Imag;
153bdd1243dSDimitry Andric 
154fe013be4SDimitry Andric   // This two members are required exclusively for generating
155fe013be4SDimitry Andric   // ComplexDeinterleavingOperation::Symmetric operations.
156fe013be4SDimitry Andric   unsigned Opcode;
157fe013be4SDimitry Andric   std::optional<FastMathFlags> Flags;
158fe013be4SDimitry Andric 
159fe013be4SDimitry Andric   ComplexDeinterleavingRotation Rotation =
160fe013be4SDimitry Andric       ComplexDeinterleavingRotation::Rotation_0;
161bdd1243dSDimitry Andric   SmallVector<RawNodePtr> Operands;
162bdd1243dSDimitry Andric   Value *ReplacementNode = nullptr;
163bdd1243dSDimitry Andric 
addOperand__anon4b5acc0a0111::ComplexDeinterleavingCompositeNode164bdd1243dSDimitry Andric   void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
165bdd1243dSDimitry Andric 
dump__anon4b5acc0a0111::ComplexDeinterleavingCompositeNode166bdd1243dSDimitry Andric   void dump() { dump(dbgs()); }
dump__anon4b5acc0a0111::ComplexDeinterleavingCompositeNode167bdd1243dSDimitry Andric   void dump(raw_ostream &OS) {
168bdd1243dSDimitry Andric     auto PrintValue = [&](Value *V) {
169bdd1243dSDimitry Andric       if (V) {
170bdd1243dSDimitry Andric         OS << "\"";
171bdd1243dSDimitry Andric         V->print(OS, true);
172bdd1243dSDimitry Andric         OS << "\"\n";
173bdd1243dSDimitry Andric       } else
174bdd1243dSDimitry Andric         OS << "nullptr\n";
175bdd1243dSDimitry Andric     };
176bdd1243dSDimitry Andric     auto PrintNodeRef = [&](RawNodePtr Ptr) {
177bdd1243dSDimitry Andric       if (Ptr)
178bdd1243dSDimitry Andric         OS << Ptr << "\n";
179bdd1243dSDimitry Andric       else
180bdd1243dSDimitry Andric         OS << "nullptr\n";
181bdd1243dSDimitry Andric     };
182bdd1243dSDimitry Andric 
183bdd1243dSDimitry Andric     OS << "- CompositeNode: " << this << "\n";
184bdd1243dSDimitry Andric     OS << "  Real: ";
185bdd1243dSDimitry Andric     PrintValue(Real);
186bdd1243dSDimitry Andric     OS << "  Imag: ";
187bdd1243dSDimitry Andric     PrintValue(Imag);
188bdd1243dSDimitry Andric     OS << "  ReplacementNode: ";
189bdd1243dSDimitry Andric     PrintValue(ReplacementNode);
190bdd1243dSDimitry Andric     OS << "  Operation: " << (int)Operation << "\n";
191bdd1243dSDimitry Andric     OS << "  Rotation: " << ((int)Rotation * 90) << "\n";
192bdd1243dSDimitry Andric     OS << "  Operands: \n";
193bdd1243dSDimitry Andric     for (const auto &Op : Operands) {
194bdd1243dSDimitry Andric       OS << "    - ";
195bdd1243dSDimitry Andric       PrintNodeRef(Op);
196bdd1243dSDimitry Andric     }
197bdd1243dSDimitry Andric   }
198bdd1243dSDimitry Andric };
199bdd1243dSDimitry Andric 
200bdd1243dSDimitry Andric class ComplexDeinterleavingGraph {
201bdd1243dSDimitry Andric public:
202fe013be4SDimitry Andric   struct Product {
203fe013be4SDimitry Andric     Value *Multiplier;
204fe013be4SDimitry Andric     Value *Multiplicand;
205fe013be4SDimitry Andric     bool IsPositive;
206fe013be4SDimitry Andric   };
207fe013be4SDimitry Andric 
208fe013be4SDimitry Andric   using Addend = std::pair<Value *, bool>;
209bdd1243dSDimitry Andric   using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
210bdd1243dSDimitry Andric   using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
211fe013be4SDimitry Andric 
212fe013be4SDimitry Andric   // Helper struct for holding info about potential partial multiplication
213fe013be4SDimitry Andric   // candidates
214fe013be4SDimitry Andric   struct PartialMulCandidate {
215fe013be4SDimitry Andric     Value *Common;
216fe013be4SDimitry Andric     NodePtr Node;
217fe013be4SDimitry Andric     unsigned RealIdx;
218fe013be4SDimitry Andric     unsigned ImagIdx;
219fe013be4SDimitry Andric     bool IsNodeInverted;
220fe013be4SDimitry Andric   };
221fe013be4SDimitry Andric 
ComplexDeinterleavingGraph(const TargetLowering * TL,const TargetLibraryInfo * TLI)222fe013be4SDimitry Andric   explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
223fe013be4SDimitry Andric                                       const TargetLibraryInfo *TLI)
224fe013be4SDimitry Andric       : TL(TL), TLI(TLI) {}
225bdd1243dSDimitry Andric 
226bdd1243dSDimitry Andric private:
227fe013be4SDimitry Andric   const TargetLowering *TL = nullptr;
228fe013be4SDimitry Andric   const TargetLibraryInfo *TLI = nullptr;
229bdd1243dSDimitry Andric   SmallVector<NodePtr> CompositeNodes;
230271697daSDimitry Andric   DenseMap<std::pair<Value *, Value *>, NodePtr> CachedResult;
231fe013be4SDimitry Andric 
232fe013be4SDimitry Andric   SmallPtrSet<Instruction *, 16> FinalInstructions;
233fe013be4SDimitry Andric 
234fe013be4SDimitry Andric   /// Root instructions are instructions from which complex computation starts
235fe013be4SDimitry Andric   std::map<Instruction *, NodePtr> RootToNode;
236fe013be4SDimitry Andric 
237fe013be4SDimitry Andric   /// Topologically sorted root instructions
238fe013be4SDimitry Andric   SmallVector<Instruction *, 1> OrderedRoots;
239fe013be4SDimitry Andric 
240fe013be4SDimitry Andric   /// When examining a basic block for complex deinterleaving, if it is a simple
241fe013be4SDimitry Andric   /// one-block loop, then the only incoming block is 'Incoming' and the
242fe013be4SDimitry Andric   /// 'BackEdge' block is the block itself."
243fe013be4SDimitry Andric   BasicBlock *BackEdge = nullptr;
244fe013be4SDimitry Andric   BasicBlock *Incoming = nullptr;
245fe013be4SDimitry Andric 
246fe013be4SDimitry Andric   /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
247fe013be4SDimitry Andric   /// %OutsideUser as it is shown in the IR:
248fe013be4SDimitry Andric   ///
249fe013be4SDimitry Andric   /// vector.body:
250fe013be4SDimitry Andric   ///   %PHInode = phi <vector type> [ zeroinitializer, %entry ],
251fe013be4SDimitry Andric   ///                                [ %ReductionOp, %vector.body ]
252fe013be4SDimitry Andric   ///   ...
253fe013be4SDimitry Andric   ///   %ReductionOp = fadd i64 ...
254fe013be4SDimitry Andric   ///   ...
255fe013be4SDimitry Andric   ///   br i1 %condition, label %vector.body, %middle.block
256fe013be4SDimitry Andric   ///
257fe013be4SDimitry Andric   /// middle.block:
258fe013be4SDimitry Andric   ///   %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
259fe013be4SDimitry Andric   ///
260fe013be4SDimitry Andric   /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
261fe013be4SDimitry Andric   /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
262*c9157d92SDimitry Andric   MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
263fe013be4SDimitry Andric 
264fe013be4SDimitry Andric   /// In the process of detecting a reduction, we consider a pair of
265fe013be4SDimitry Andric   /// %ReductionOP, which we refer to as real and imag (or vice versa), and
266fe013be4SDimitry Andric   /// traverse the use-tree to detect complex operations. As this is a reduction
267fe013be4SDimitry Andric   /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
268fe013be4SDimitry Andric   /// to the %ReductionOPs that we suspect to be complex.
269fe013be4SDimitry Andric   /// RealPHI and ImagPHI are used by the identifyPHINode method.
270fe013be4SDimitry Andric   PHINode *RealPHI = nullptr;
271fe013be4SDimitry Andric   PHINode *ImagPHI = nullptr;
272fe013be4SDimitry Andric 
273fe013be4SDimitry Andric   /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
274fe013be4SDimitry Andric   /// detection.
275fe013be4SDimitry Andric   bool PHIsFound = false;
276fe013be4SDimitry Andric 
277fe013be4SDimitry Andric   /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
278fe013be4SDimitry Andric   /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
279fe013be4SDimitry Andric   /// This mapping is populated during
280fe013be4SDimitry Andric   /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
281fe013be4SDimitry Andric   /// used in the ComplexDeinterleavingOperation::ReductionOperation node
282fe013be4SDimitry Andric   /// replacement process.
283fe013be4SDimitry Andric   std::map<PHINode *, PHINode *> OldToNewPHI;
284bdd1243dSDimitry Andric 
prepareCompositeNode(ComplexDeinterleavingOperation Operation,Value * R,Value * I)285bdd1243dSDimitry Andric   NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
286fe013be4SDimitry Andric                                Value *R, Value *I) {
287fe013be4SDimitry Andric     assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
288fe013be4SDimitry Andric              Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
289fe013be4SDimitry Andric             (R && I)) &&
290fe013be4SDimitry Andric            "Reduction related nodes must have Real and Imaginary parts");
291bdd1243dSDimitry Andric     return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
292bdd1243dSDimitry Andric                                                                 I);
293bdd1243dSDimitry Andric   }
294bdd1243dSDimitry Andric 
submitCompositeNode(NodePtr Node)295bdd1243dSDimitry Andric   NodePtr submitCompositeNode(NodePtr Node) {
296bdd1243dSDimitry Andric     CompositeNodes.push_back(Node);
297271697daSDimitry Andric     if (Node->Real && Node->Imag)
298271697daSDimitry Andric       CachedResult[{Node->Real, Node->Imag}] = Node;
299bdd1243dSDimitry Andric     return Node;
300bdd1243dSDimitry Andric   }
301bdd1243dSDimitry Andric 
302bdd1243dSDimitry Andric   /// Identifies a complex partial multiply pattern and its rotation, based on
303bdd1243dSDimitry Andric   /// the following patterns
304bdd1243dSDimitry Andric   ///
305bdd1243dSDimitry Andric   ///  0:  r: cr + ar * br
306bdd1243dSDimitry Andric   ///      i: ci + ar * bi
307bdd1243dSDimitry Andric   /// 90:  r: cr - ai * bi
308bdd1243dSDimitry Andric   ///      i: ci + ai * br
309bdd1243dSDimitry Andric   /// 180: r: cr - ar * br
310bdd1243dSDimitry Andric   ///      i: ci - ar * bi
311bdd1243dSDimitry Andric   /// 270: r: cr + ai * bi
312bdd1243dSDimitry Andric   ///      i: ci - ai * br
313bdd1243dSDimitry Andric   NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);
314bdd1243dSDimitry Andric 
315bdd1243dSDimitry Andric   /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
316bdd1243dSDimitry Andric   /// is partially known from identifyPartialMul, filling in the other half of
317bdd1243dSDimitry Andric   /// the complex pair.
318fe013be4SDimitry Andric   NodePtr
319fe013be4SDimitry Andric   identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
320fe013be4SDimitry Andric                               std::pair<Value *, Value *> &CommonOperandI);
321bdd1243dSDimitry Andric 
322bdd1243dSDimitry Andric   /// Identifies a complex add pattern and its rotation, based on the following
323bdd1243dSDimitry Andric   /// patterns.
324bdd1243dSDimitry Andric   ///
325bdd1243dSDimitry Andric   /// 90:  r: ar - bi
326bdd1243dSDimitry Andric   ///      i: ai + br
327bdd1243dSDimitry Andric   /// 270: r: ar + bi
328bdd1243dSDimitry Andric   ///      i: ai - br
329bdd1243dSDimitry Andric   NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
330fe013be4SDimitry Andric   NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);
331bdd1243dSDimitry Andric 
332fe013be4SDimitry Andric   NodePtr identifyNode(Value *R, Value *I);
333bdd1243dSDimitry Andric 
334fe013be4SDimitry Andric   /// Determine if a sum of complex numbers can be formed from \p RealAddends
335fe013be4SDimitry Andric   /// and \p ImagAddens. If \p Accumulator is not null, add the result to it.
336fe013be4SDimitry Andric   /// Return nullptr if it is not possible to construct a complex number.
337fe013be4SDimitry Andric   /// \p Flags are needed to generate symmetric Add and Sub operations.
338fe013be4SDimitry Andric   NodePtr identifyAdditions(std::list<Addend> &RealAddends,
339fe013be4SDimitry Andric                             std::list<Addend> &ImagAddends,
340fe013be4SDimitry Andric                             std::optional<FastMathFlags> Flags,
341fe013be4SDimitry Andric                             NodePtr Accumulator);
342fe013be4SDimitry Andric 
343fe013be4SDimitry Andric   /// Extract one addend that have both real and imaginary parts positive.
344fe013be4SDimitry Andric   NodePtr extractPositiveAddend(std::list<Addend> &RealAddends,
345fe013be4SDimitry Andric                                 std::list<Addend> &ImagAddends);
346fe013be4SDimitry Andric 
347fe013be4SDimitry Andric   /// Determine if sum of multiplications of complex numbers can be formed from
348fe013be4SDimitry Andric   /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
349fe013be4SDimitry Andric   /// to it. Return nullptr if it is not possible to construct a complex number.
350fe013be4SDimitry Andric   NodePtr identifyMultiplications(std::vector<Product> &RealMuls,
351fe013be4SDimitry Andric                                   std::vector<Product> &ImagMuls,
352fe013be4SDimitry Andric                                   NodePtr Accumulator);
353fe013be4SDimitry Andric 
354fe013be4SDimitry Andric   /// Go through pairs of multiplication (one Real and one Imag) and find all
355fe013be4SDimitry Andric   /// possible candidates for partial multiplication and put them into \p
356fe013be4SDimitry Andric   /// Candidates. Returns true if all Product has pair with common operand
357fe013be4SDimitry Andric   bool collectPartialMuls(const std::vector<Product> &RealMuls,
358fe013be4SDimitry Andric                           const std::vector<Product> &ImagMuls,
359fe013be4SDimitry Andric                           std::vector<PartialMulCandidate> &Candidates);
360fe013be4SDimitry Andric 
361fe013be4SDimitry Andric   /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
362fe013be4SDimitry Andric   /// the order of complex computation operations may be significantly altered,
363fe013be4SDimitry Andric   /// and the real and imaginary parts may not be executed in parallel. This
364fe013be4SDimitry Andric   /// function takes this into consideration and employs a more general approach
365fe013be4SDimitry Andric   /// to identify complex computations. Initially, it gathers all the addends
366fe013be4SDimitry Andric   /// and multiplicands and then constructs a complex expression from them.
367fe013be4SDimitry Andric   NodePtr identifyReassocNodes(Instruction *I, Instruction *J);
368fe013be4SDimitry Andric 
369fe013be4SDimitry Andric   NodePtr identifyRoot(Instruction *I);
370fe013be4SDimitry Andric 
371fe013be4SDimitry Andric   /// Identifies the Deinterleave operation applied to a vector containing
372fe013be4SDimitry Andric   /// complex numbers. There are two ways to represent the Deinterleave
373fe013be4SDimitry Andric   /// operation:
374fe013be4SDimitry Andric   /// * Using two shufflevectors with even indices for /pReal instruction and
375fe013be4SDimitry Andric   /// odd indices for /pImag instructions (only for fixed-width vectors)
376fe013be4SDimitry Andric   /// * Using two extractvalue instructions applied to `vector.deinterleave2`
377fe013be4SDimitry Andric   /// intrinsic (for both fixed and scalable vectors)
378fe013be4SDimitry Andric   NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);
379fe013be4SDimitry Andric 
380fe013be4SDimitry Andric   /// identifying the operation that represents a complex number repeated in a
381fe013be4SDimitry Andric   /// Splat vector. There are two possible types of splats: ConstantExpr with
382fe013be4SDimitry Andric   /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
383fe013be4SDimitry Andric   /// initialization mask with all values set to zero.
384fe013be4SDimitry Andric   NodePtr identifySplat(Value *Real, Value *Imag);
385fe013be4SDimitry Andric 
386fe013be4SDimitry Andric   NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);
387fe013be4SDimitry Andric 
388fe013be4SDimitry Andric   /// Identifies SelectInsts in a loop that has reduction with predication masks
389fe013be4SDimitry Andric   /// and/or predicated tail folding
390fe013be4SDimitry Andric   NodePtr identifySelectNode(Instruction *Real, Instruction *Imag);
391fe013be4SDimitry Andric 
392fe013be4SDimitry Andric   Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);
393fe013be4SDimitry Andric 
394fe013be4SDimitry Andric   /// Complete IR modifications after producing new reduction operation:
395fe013be4SDimitry Andric   /// * Populate the PHINode generated for
396fe013be4SDimitry Andric   /// ComplexDeinterleavingOperation::ReductionPHI
397fe013be4SDimitry Andric   /// * Deinterleave the final value outside of the loop and repurpose original
398fe013be4SDimitry Andric   /// reduction users
399fe013be4SDimitry Andric   void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);
400bdd1243dSDimitry Andric 
401bdd1243dSDimitry Andric public:
dump()402bdd1243dSDimitry Andric   void dump() { dump(dbgs()); }
dump(raw_ostream & OS)403bdd1243dSDimitry Andric   void dump(raw_ostream &OS) {
404bdd1243dSDimitry Andric     for (const auto &Node : CompositeNodes)
405bdd1243dSDimitry Andric       Node->dump(OS);
406bdd1243dSDimitry Andric   }
407bdd1243dSDimitry Andric 
408bdd1243dSDimitry Andric   /// Returns false if the deinterleaving operation should be cancelled for the
409bdd1243dSDimitry Andric   /// current graph.
410bdd1243dSDimitry Andric   bool identifyNodes(Instruction *RootI);
411bdd1243dSDimitry Andric 
412fe013be4SDimitry Andric   /// In case \pB is one-block loop, this function seeks potential reductions
413fe013be4SDimitry Andric   /// and populates ReductionInfo. Returns true if any reductions were
414fe013be4SDimitry Andric   /// identified.
415fe013be4SDimitry Andric   bool collectPotentialReductions(BasicBlock *B);
416fe013be4SDimitry Andric 
417fe013be4SDimitry Andric   void identifyReductionNodes();
418fe013be4SDimitry Andric 
419fe013be4SDimitry Andric   /// Check that every instruction, from the roots to the leaves, has internal
420fe013be4SDimitry Andric   /// uses.
421fe013be4SDimitry Andric   bool checkNodes();
422fe013be4SDimitry Andric 
423bdd1243dSDimitry Andric   /// Perform the actual replacement of the underlying instruction graph.
424bdd1243dSDimitry Andric   void replaceNodes();
425bdd1243dSDimitry Andric };
426bdd1243dSDimitry Andric 
427bdd1243dSDimitry Andric class ComplexDeinterleaving {
428bdd1243dSDimitry Andric public:
ComplexDeinterleaving(const TargetLowering * tl,const TargetLibraryInfo * tli)429bdd1243dSDimitry Andric   ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
430bdd1243dSDimitry Andric       : TL(tl), TLI(tli) {}
431bdd1243dSDimitry Andric   bool runOnFunction(Function &F);
432bdd1243dSDimitry Andric 
433bdd1243dSDimitry Andric private:
434bdd1243dSDimitry Andric   bool evaluateBasicBlock(BasicBlock *B);
435bdd1243dSDimitry Andric 
436bdd1243dSDimitry Andric   const TargetLowering *TL = nullptr;
437bdd1243dSDimitry Andric   const TargetLibraryInfo *TLI = nullptr;
438bdd1243dSDimitry Andric };
439bdd1243dSDimitry Andric 
440bdd1243dSDimitry Andric } // namespace
441bdd1243dSDimitry Andric 
442bdd1243dSDimitry Andric char ComplexDeinterleavingLegacyPass::ID = 0;
443bdd1243dSDimitry Andric 
444bdd1243dSDimitry Andric INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
445bdd1243dSDimitry Andric                       "Complex Deinterleaving", false, false)
446bdd1243dSDimitry Andric INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
447bdd1243dSDimitry Andric                     "Complex Deinterleaving", false, false)
448bdd1243dSDimitry Andric 
run(Function & F,FunctionAnalysisManager & AM)449bdd1243dSDimitry Andric PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
450bdd1243dSDimitry Andric                                                  FunctionAnalysisManager &AM) {
451bdd1243dSDimitry Andric   const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
452bdd1243dSDimitry Andric   auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
453bdd1243dSDimitry Andric   if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
454bdd1243dSDimitry Andric     return PreservedAnalyses::all();
455bdd1243dSDimitry Andric 
456bdd1243dSDimitry Andric   PreservedAnalyses PA;
457bdd1243dSDimitry Andric   PA.preserve<FunctionAnalysisManagerModuleProxy>();
458bdd1243dSDimitry Andric   return PA;
459bdd1243dSDimitry Andric }
460bdd1243dSDimitry Andric 
createComplexDeinterleavingPass(const TargetMachine * TM)461bdd1243dSDimitry Andric FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {
462bdd1243dSDimitry Andric   return new ComplexDeinterleavingLegacyPass(TM);
463bdd1243dSDimitry Andric }
464bdd1243dSDimitry Andric 
runOnFunction(Function & F)465bdd1243dSDimitry Andric bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
466bdd1243dSDimitry Andric   const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
467bdd1243dSDimitry Andric   auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
468bdd1243dSDimitry Andric   return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
469bdd1243dSDimitry Andric }
470bdd1243dSDimitry Andric 
runOnFunction(Function & F)471bdd1243dSDimitry Andric bool ComplexDeinterleaving::runOnFunction(Function &F) {
472bdd1243dSDimitry Andric   if (!ComplexDeinterleavingEnabled) {
473bdd1243dSDimitry Andric     LLVM_DEBUG(
474bdd1243dSDimitry Andric         dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
475bdd1243dSDimitry Andric     return false;
476bdd1243dSDimitry Andric   }
477bdd1243dSDimitry Andric 
478bdd1243dSDimitry Andric   if (!TL->isComplexDeinterleavingSupported()) {
479bdd1243dSDimitry Andric     LLVM_DEBUG(
480bdd1243dSDimitry Andric         dbgs() << "Complex deinterleaving has been disabled, target does "
481bdd1243dSDimitry Andric                   "not support lowering of complex number operations.\n");
482bdd1243dSDimitry Andric     return false;
483bdd1243dSDimitry Andric   }
484bdd1243dSDimitry Andric 
485bdd1243dSDimitry Andric   bool Changed = false;
486bdd1243dSDimitry Andric   for (auto &B : F)
487bdd1243dSDimitry Andric     Changed |= evaluateBasicBlock(&B);
488bdd1243dSDimitry Andric 
489bdd1243dSDimitry Andric   return Changed;
490bdd1243dSDimitry Andric }
491bdd1243dSDimitry Andric 
isInterleavingMask(ArrayRef<int> Mask)492bdd1243dSDimitry Andric static bool isInterleavingMask(ArrayRef<int> Mask) {
493bdd1243dSDimitry Andric   // If the size is not even, it's not an interleaving mask
494bdd1243dSDimitry Andric   if ((Mask.size() & 1))
495bdd1243dSDimitry Andric     return false;
496bdd1243dSDimitry Andric 
497bdd1243dSDimitry Andric   int HalfNumElements = Mask.size() / 2;
498bdd1243dSDimitry Andric   for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
499bdd1243dSDimitry Andric     int MaskIdx = Idx * 2;
500bdd1243dSDimitry Andric     if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
501bdd1243dSDimitry Andric       return false;
502bdd1243dSDimitry Andric   }
503bdd1243dSDimitry Andric 
504bdd1243dSDimitry Andric   return true;
505bdd1243dSDimitry Andric }
506bdd1243dSDimitry Andric 
isDeinterleavingMask(ArrayRef<int> Mask)507bdd1243dSDimitry Andric static bool isDeinterleavingMask(ArrayRef<int> Mask) {
508bdd1243dSDimitry Andric   int Offset = Mask[0];
509bdd1243dSDimitry Andric   int HalfNumElements = Mask.size() / 2;
510bdd1243dSDimitry Andric 
511bdd1243dSDimitry Andric   for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
512bdd1243dSDimitry Andric     if (Mask[Idx] != (Idx * 2) + Offset)
513bdd1243dSDimitry Andric       return false;
514bdd1243dSDimitry Andric   }
515bdd1243dSDimitry Andric 
516bdd1243dSDimitry Andric   return true;
517bdd1243dSDimitry Andric }
518bdd1243dSDimitry Andric 
isNeg(Value * V)519fe013be4SDimitry Andric bool isNeg(Value *V) {
520fe013be4SDimitry Andric   return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value()));
521fe013be4SDimitry Andric }
522fe013be4SDimitry Andric 
getNegOperand(Value * V)523fe013be4SDimitry Andric Value *getNegOperand(Value *V) {
524fe013be4SDimitry Andric   assert(isNeg(V));
525fe013be4SDimitry Andric   auto *I = cast<Instruction>(V);
526fe013be4SDimitry Andric   if (I->getOpcode() == Instruction::FNeg)
527fe013be4SDimitry Andric     return I->getOperand(0);
528fe013be4SDimitry Andric 
529fe013be4SDimitry Andric   return I->getOperand(1);
530fe013be4SDimitry Andric }
531fe013be4SDimitry Andric 
evaluateBasicBlock(BasicBlock * B)532bdd1243dSDimitry Andric bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
533fe013be4SDimitry Andric   ComplexDeinterleavingGraph Graph(TL, TLI);
534fe013be4SDimitry Andric   if (Graph.collectPotentialReductions(B))
535fe013be4SDimitry Andric     Graph.identifyReductionNodes();
536bdd1243dSDimitry Andric 
537fe013be4SDimitry Andric   for (auto &I : *B)
538fe013be4SDimitry Andric     Graph.identifyNodes(&I);
539bdd1243dSDimitry Andric 
540fe013be4SDimitry Andric   if (Graph.checkNodes()) {
541bdd1243dSDimitry Andric     Graph.replaceNodes();
542fe013be4SDimitry Andric     return true;
543bdd1243dSDimitry Andric   }
544bdd1243dSDimitry Andric 
545fe013be4SDimitry Andric   return false;
546bdd1243dSDimitry Andric }
547bdd1243dSDimitry Andric 
548bdd1243dSDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyNodeWithImplicitAdd(Instruction * Real,Instruction * Imag,std::pair<Value *,Value * > & PartialMatch)549bdd1243dSDimitry Andric ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
550bdd1243dSDimitry Andric     Instruction *Real, Instruction *Imag,
551fe013be4SDimitry Andric     std::pair<Value *, Value *> &PartialMatch) {
552bdd1243dSDimitry Andric   LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
553bdd1243dSDimitry Andric                     << "\n");
554bdd1243dSDimitry Andric 
555bdd1243dSDimitry Andric   if (!Real->hasOneUse() || !Imag->hasOneUse()) {
556bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - Mul operand has multiple uses.\n");
557bdd1243dSDimitry Andric     return nullptr;
558bdd1243dSDimitry Andric   }
559bdd1243dSDimitry Andric 
560fe013be4SDimitry Andric   if ((Real->getOpcode() != Instruction::FMul &&
561fe013be4SDimitry Andric        Real->getOpcode() != Instruction::Mul) ||
562fe013be4SDimitry Andric       (Imag->getOpcode() != Instruction::FMul &&
563fe013be4SDimitry Andric        Imag->getOpcode() != Instruction::Mul)) {
564fe013be4SDimitry Andric     LLVM_DEBUG(
565fe013be4SDimitry Andric         dbgs() << "  - Real or imaginary instruction is not fmul or mul\n");
566bdd1243dSDimitry Andric     return nullptr;
567bdd1243dSDimitry Andric   }
568bdd1243dSDimitry Andric 
569fe013be4SDimitry Andric   Value *R0 = Real->getOperand(0);
570fe013be4SDimitry Andric   Value *R1 = Real->getOperand(1);
571fe013be4SDimitry Andric   Value *I0 = Imag->getOperand(0);
572fe013be4SDimitry Andric   Value *I1 = Imag->getOperand(1);
573bdd1243dSDimitry Andric 
574bdd1243dSDimitry Andric   // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
575bdd1243dSDimitry Andric   // rotations and use the operand.
576bdd1243dSDimitry Andric   unsigned Negs = 0;
577fe013be4SDimitry Andric   Value *Op;
578fe013be4SDimitry Andric   if (match(R0, m_Neg(m_Value(Op)))) {
579bdd1243dSDimitry Andric     Negs |= 1;
580fe013be4SDimitry Andric     R0 = Op;
581fe013be4SDimitry Andric   } else if (match(R1, m_Neg(m_Value(Op)))) {
582fe013be4SDimitry Andric     Negs |= 1;
583fe013be4SDimitry Andric     R1 = Op;
584bdd1243dSDimitry Andric   }
585fe013be4SDimitry Andric 
586fe013be4SDimitry Andric   if (isNeg(I0)) {
587bdd1243dSDimitry Andric     Negs |= 2;
588bdd1243dSDimitry Andric     Negs ^= 1;
589fe013be4SDimitry Andric     I0 = Op;
590fe013be4SDimitry Andric   } else if (match(I1, m_Neg(m_Value(Op)))) {
591fe013be4SDimitry Andric     Negs |= 2;
592fe013be4SDimitry Andric     Negs ^= 1;
593fe013be4SDimitry Andric     I1 = Op;
594bdd1243dSDimitry Andric   }
595bdd1243dSDimitry Andric 
596bdd1243dSDimitry Andric   ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;
597bdd1243dSDimitry Andric 
598fe013be4SDimitry Andric   Value *CommonOperand;
599fe013be4SDimitry Andric   Value *UncommonRealOp;
600fe013be4SDimitry Andric   Value *UncommonImagOp;
601bdd1243dSDimitry Andric 
602bdd1243dSDimitry Andric   if (R0 == I0 || R0 == I1) {
603bdd1243dSDimitry Andric     CommonOperand = R0;
604bdd1243dSDimitry Andric     UncommonRealOp = R1;
605bdd1243dSDimitry Andric   } else if (R1 == I0 || R1 == I1) {
606bdd1243dSDimitry Andric     CommonOperand = R1;
607bdd1243dSDimitry Andric     UncommonRealOp = R0;
608bdd1243dSDimitry Andric   } else {
609bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - No equal operand\n");
610bdd1243dSDimitry Andric     return nullptr;
611bdd1243dSDimitry Andric   }
612bdd1243dSDimitry Andric 
613bdd1243dSDimitry Andric   UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
614bdd1243dSDimitry Andric   if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
615bdd1243dSDimitry Andric       Rotation == ComplexDeinterleavingRotation::Rotation_270)
616bdd1243dSDimitry Andric     std::swap(UncommonRealOp, UncommonImagOp);
617bdd1243dSDimitry Andric 
618bdd1243dSDimitry Andric   // Between identifyPartialMul and here we need to have found a complete valid
619bdd1243dSDimitry Andric   // pair from the CommonOperand of each part.
620bdd1243dSDimitry Andric   if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
621bdd1243dSDimitry Andric       Rotation == ComplexDeinterleavingRotation::Rotation_180)
622bdd1243dSDimitry Andric     PartialMatch.first = CommonOperand;
623bdd1243dSDimitry Andric   else
624bdd1243dSDimitry Andric     PartialMatch.second = CommonOperand;
625bdd1243dSDimitry Andric 
626bdd1243dSDimitry Andric   if (!PartialMatch.first || !PartialMatch.second) {
627bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - Incomplete partial match\n");
628bdd1243dSDimitry Andric     return nullptr;
629bdd1243dSDimitry Andric   }
630bdd1243dSDimitry Andric 
631bdd1243dSDimitry Andric   NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
632bdd1243dSDimitry Andric   if (!CommonNode) {
633bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - No CommonNode identified\n");
634bdd1243dSDimitry Andric     return nullptr;
635bdd1243dSDimitry Andric   }
636bdd1243dSDimitry Andric 
637bdd1243dSDimitry Andric   NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
638bdd1243dSDimitry Andric   if (!UncommonNode) {
639bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - No UncommonNode identified\n");
640bdd1243dSDimitry Andric     return nullptr;
641bdd1243dSDimitry Andric   }
642bdd1243dSDimitry Andric 
643bdd1243dSDimitry Andric   NodePtr Node = prepareCompositeNode(
644bdd1243dSDimitry Andric       ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
645bdd1243dSDimitry Andric   Node->Rotation = Rotation;
646bdd1243dSDimitry Andric   Node->addOperand(CommonNode);
647bdd1243dSDimitry Andric   Node->addOperand(UncommonNode);
648bdd1243dSDimitry Andric   return submitCompositeNode(Node);
649bdd1243dSDimitry Andric }
650bdd1243dSDimitry Andric 
651bdd1243dSDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyPartialMul(Instruction * Real,Instruction * Imag)652bdd1243dSDimitry Andric ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
653bdd1243dSDimitry Andric                                                Instruction *Imag) {
654bdd1243dSDimitry Andric   LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
655bdd1243dSDimitry Andric                     << "\n");
656bdd1243dSDimitry Andric   // Determine rotation
657fe013be4SDimitry Andric   auto IsAdd = [](unsigned Op) {
658fe013be4SDimitry Andric     return Op == Instruction::FAdd || Op == Instruction::Add;
659fe013be4SDimitry Andric   };
660fe013be4SDimitry Andric   auto IsSub = [](unsigned Op) {
661fe013be4SDimitry Andric     return Op == Instruction::FSub || Op == Instruction::Sub;
662fe013be4SDimitry Andric   };
663bdd1243dSDimitry Andric   ComplexDeinterleavingRotation Rotation;
664fe013be4SDimitry Andric   if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
665bdd1243dSDimitry Andric     Rotation = ComplexDeinterleavingRotation::Rotation_0;
666fe013be4SDimitry Andric   else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
667bdd1243dSDimitry Andric     Rotation = ComplexDeinterleavingRotation::Rotation_90;
668fe013be4SDimitry Andric   else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode()))
669bdd1243dSDimitry Andric     Rotation = ComplexDeinterleavingRotation::Rotation_180;
670fe013be4SDimitry Andric   else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))
671bdd1243dSDimitry Andric     Rotation = ComplexDeinterleavingRotation::Rotation_270;
672bdd1243dSDimitry Andric   else {
673bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - Unhandled rotation.\n");
674bdd1243dSDimitry Andric     return nullptr;
675bdd1243dSDimitry Andric   }
676bdd1243dSDimitry Andric 
677fe013be4SDimitry Andric   if (isa<FPMathOperator>(Real) &&
678fe013be4SDimitry Andric       (!Real->getFastMathFlags().allowContract() ||
679fe013be4SDimitry Andric        !Imag->getFastMathFlags().allowContract())) {
680bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - Contract is missing from the FastMath flags.\n");
681bdd1243dSDimitry Andric     return nullptr;
682bdd1243dSDimitry Andric   }
683bdd1243dSDimitry Andric 
684bdd1243dSDimitry Andric   Value *CR = Real->getOperand(0);
685bdd1243dSDimitry Andric   Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
686bdd1243dSDimitry Andric   if (!RealMulI)
687bdd1243dSDimitry Andric     return nullptr;
688bdd1243dSDimitry Andric   Value *CI = Imag->getOperand(0);
689bdd1243dSDimitry Andric   Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
690bdd1243dSDimitry Andric   if (!ImagMulI)
691bdd1243dSDimitry Andric     return nullptr;
692bdd1243dSDimitry Andric 
693bdd1243dSDimitry Andric   if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
694bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - Mul instruction has multiple uses\n");
695bdd1243dSDimitry Andric     return nullptr;
696bdd1243dSDimitry Andric   }
697bdd1243dSDimitry Andric 
698fe013be4SDimitry Andric   Value *R0 = RealMulI->getOperand(0);
699fe013be4SDimitry Andric   Value *R1 = RealMulI->getOperand(1);
700fe013be4SDimitry Andric   Value *I0 = ImagMulI->getOperand(0);
701fe013be4SDimitry Andric   Value *I1 = ImagMulI->getOperand(1);
702bdd1243dSDimitry Andric 
703fe013be4SDimitry Andric   Value *CommonOperand;
704fe013be4SDimitry Andric   Value *UncommonRealOp;
705fe013be4SDimitry Andric   Value *UncommonImagOp;
706bdd1243dSDimitry Andric 
707bdd1243dSDimitry Andric   if (R0 == I0 || R0 == I1) {
708bdd1243dSDimitry Andric     CommonOperand = R0;
709bdd1243dSDimitry Andric     UncommonRealOp = R1;
710bdd1243dSDimitry Andric   } else if (R1 == I0 || R1 == I1) {
711bdd1243dSDimitry Andric     CommonOperand = R1;
712bdd1243dSDimitry Andric     UncommonRealOp = R0;
713bdd1243dSDimitry Andric   } else {
714bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - No equal operand\n");
715bdd1243dSDimitry Andric     return nullptr;
716bdd1243dSDimitry Andric   }
717bdd1243dSDimitry Andric 
718bdd1243dSDimitry Andric   UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
719bdd1243dSDimitry Andric   if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
720bdd1243dSDimitry Andric       Rotation == ComplexDeinterleavingRotation::Rotation_270)
721bdd1243dSDimitry Andric     std::swap(UncommonRealOp, UncommonImagOp);
722bdd1243dSDimitry Andric 
723fe013be4SDimitry Andric   std::pair<Value *, Value *> PartialMatch(
724bdd1243dSDimitry Andric       (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
725bdd1243dSDimitry Andric        Rotation == ComplexDeinterleavingRotation::Rotation_180)
726bdd1243dSDimitry Andric           ? CommonOperand
727bdd1243dSDimitry Andric           : nullptr,
728bdd1243dSDimitry Andric       (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
729bdd1243dSDimitry Andric        Rotation == ComplexDeinterleavingRotation::Rotation_270)
730bdd1243dSDimitry Andric           ? CommonOperand
731bdd1243dSDimitry Andric           : nullptr);
732fe013be4SDimitry Andric 
733fe013be4SDimitry Andric   auto *CRInst = dyn_cast<Instruction>(CR);
734fe013be4SDimitry Andric   auto *CIInst = dyn_cast<Instruction>(CI);
735fe013be4SDimitry Andric 
736fe013be4SDimitry Andric   if (!CRInst || !CIInst) {
737fe013be4SDimitry Andric     LLVM_DEBUG(dbgs() << "  - Common operands are not instructions.\n");
738fe013be4SDimitry Andric     return nullptr;
739fe013be4SDimitry Andric   }
740fe013be4SDimitry Andric 
741fe013be4SDimitry Andric   NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
742bdd1243dSDimitry Andric   if (!CNode) {
743bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - No cnode identified\n");
744bdd1243dSDimitry Andric     return nullptr;
745bdd1243dSDimitry Andric   }
746bdd1243dSDimitry Andric 
747bdd1243dSDimitry Andric   NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
748bdd1243dSDimitry Andric   if (!UncommonRes) {
749bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - No UncommonRes identified\n");
750bdd1243dSDimitry Andric     return nullptr;
751bdd1243dSDimitry Andric   }
752bdd1243dSDimitry Andric 
753bdd1243dSDimitry Andric   assert(PartialMatch.first && PartialMatch.second);
754bdd1243dSDimitry Andric   NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
755bdd1243dSDimitry Andric   if (!CommonRes) {
756bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - No CommonRes identified\n");
757bdd1243dSDimitry Andric     return nullptr;
758bdd1243dSDimitry Andric   }
759bdd1243dSDimitry Andric 
760bdd1243dSDimitry Andric   NodePtr Node = prepareCompositeNode(
761bdd1243dSDimitry Andric       ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
762bdd1243dSDimitry Andric   Node->Rotation = Rotation;
763bdd1243dSDimitry Andric   Node->addOperand(CommonRes);
764bdd1243dSDimitry Andric   Node->addOperand(UncommonRes);
765bdd1243dSDimitry Andric   Node->addOperand(CNode);
766bdd1243dSDimitry Andric   return submitCompositeNode(Node);
767bdd1243dSDimitry Andric }
768bdd1243dSDimitry Andric 
769bdd1243dSDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyAdd(Instruction * Real,Instruction * Imag)770bdd1243dSDimitry Andric ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
771bdd1243dSDimitry Andric   LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
772bdd1243dSDimitry Andric 
773bdd1243dSDimitry Andric   // Determine rotation
774bdd1243dSDimitry Andric   ComplexDeinterleavingRotation Rotation;
775bdd1243dSDimitry Andric   if ((Real->getOpcode() == Instruction::FSub &&
776bdd1243dSDimitry Andric        Imag->getOpcode() == Instruction::FAdd) ||
777bdd1243dSDimitry Andric       (Real->getOpcode() == Instruction::Sub &&
778bdd1243dSDimitry Andric        Imag->getOpcode() == Instruction::Add))
779bdd1243dSDimitry Andric     Rotation = ComplexDeinterleavingRotation::Rotation_90;
780bdd1243dSDimitry Andric   else if ((Real->getOpcode() == Instruction::FAdd &&
781bdd1243dSDimitry Andric             Imag->getOpcode() == Instruction::FSub) ||
782bdd1243dSDimitry Andric            (Real->getOpcode() == Instruction::Add &&
783bdd1243dSDimitry Andric             Imag->getOpcode() == Instruction::Sub))
784bdd1243dSDimitry Andric     Rotation = ComplexDeinterleavingRotation::Rotation_270;
785bdd1243dSDimitry Andric   else {
786bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
787bdd1243dSDimitry Andric     return nullptr;
788bdd1243dSDimitry Andric   }
789bdd1243dSDimitry Andric 
790bdd1243dSDimitry Andric   auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
791bdd1243dSDimitry Andric   auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
792bdd1243dSDimitry Andric   auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
793bdd1243dSDimitry Andric   auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
794bdd1243dSDimitry Andric 
795bdd1243dSDimitry Andric   if (!AR || !AI || !BR || !BI) {
796bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
797bdd1243dSDimitry Andric     return nullptr;
798bdd1243dSDimitry Andric   }
799bdd1243dSDimitry Andric 
800bdd1243dSDimitry Andric   NodePtr ResA = identifyNode(AR, AI);
801bdd1243dSDimitry Andric   if (!ResA) {
802bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
803bdd1243dSDimitry Andric     return nullptr;
804bdd1243dSDimitry Andric   }
805bdd1243dSDimitry Andric   NodePtr ResB = identifyNode(BR, BI);
806bdd1243dSDimitry Andric   if (!ResB) {
807bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
808bdd1243dSDimitry Andric     return nullptr;
809bdd1243dSDimitry Andric   }
810bdd1243dSDimitry Andric 
811bdd1243dSDimitry Andric   NodePtr Node =
812bdd1243dSDimitry Andric       prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
813bdd1243dSDimitry Andric   Node->Rotation = Rotation;
814bdd1243dSDimitry Andric   Node->addOperand(ResA);
815bdd1243dSDimitry Andric   Node->addOperand(ResB);
816bdd1243dSDimitry Andric   return submitCompositeNode(Node);
817bdd1243dSDimitry Andric }
818bdd1243dSDimitry Andric 
isInstructionPairAdd(Instruction * A,Instruction * B)819bdd1243dSDimitry Andric static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
820bdd1243dSDimitry Andric   unsigned OpcA = A->getOpcode();
821bdd1243dSDimitry Andric   unsigned OpcB = B->getOpcode();
822bdd1243dSDimitry Andric 
823bdd1243dSDimitry Andric   return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
824bdd1243dSDimitry Andric          (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
825bdd1243dSDimitry Andric          (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
826bdd1243dSDimitry Andric          (OpcA == Instruction::Add && OpcB == Instruction::Sub);
827bdd1243dSDimitry Andric }
828bdd1243dSDimitry Andric 
isInstructionPairMul(Instruction * A,Instruction * B)829bdd1243dSDimitry Andric static bool isInstructionPairMul(Instruction *A, Instruction *B) {
830bdd1243dSDimitry Andric   auto Pattern =
831bdd1243dSDimitry Andric       m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value()));
832bdd1243dSDimitry Andric 
833bdd1243dSDimitry Andric   return match(A, Pattern) && match(B, Pattern);
834bdd1243dSDimitry Andric }
835bdd1243dSDimitry Andric 
isInstructionPotentiallySymmetric(Instruction * I)836fe013be4SDimitry Andric static bool isInstructionPotentiallySymmetric(Instruction *I) {
837fe013be4SDimitry Andric   switch (I->getOpcode()) {
838fe013be4SDimitry Andric   case Instruction::FAdd:
839fe013be4SDimitry Andric   case Instruction::FSub:
840fe013be4SDimitry Andric   case Instruction::FMul:
841fe013be4SDimitry Andric   case Instruction::FNeg:
842fe013be4SDimitry Andric   case Instruction::Add:
843fe013be4SDimitry Andric   case Instruction::Sub:
844fe013be4SDimitry Andric   case Instruction::Mul:
845fe013be4SDimitry Andric     return true;
846fe013be4SDimitry Andric   default:
847fe013be4SDimitry Andric     return false;
848fe013be4SDimitry Andric   }
849fe013be4SDimitry Andric }
850fe013be4SDimitry Andric 
851bdd1243dSDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifySymmetricOperation(Instruction * Real,Instruction * Imag)852fe013be4SDimitry Andric ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
853fe013be4SDimitry Andric                                                        Instruction *Imag) {
854fe013be4SDimitry Andric   if (Real->getOpcode() != Imag->getOpcode())
855fe013be4SDimitry Andric     return nullptr;
856fe013be4SDimitry Andric 
857fe013be4SDimitry Andric   if (!isInstructionPotentiallySymmetric(Real) ||
858fe013be4SDimitry Andric       !isInstructionPotentiallySymmetric(Imag))
859fe013be4SDimitry Andric     return nullptr;
860fe013be4SDimitry Andric 
861fe013be4SDimitry Andric   auto *R0 = Real->getOperand(0);
862fe013be4SDimitry Andric   auto *I0 = Imag->getOperand(0);
863fe013be4SDimitry Andric 
864fe013be4SDimitry Andric   NodePtr Op0 = identifyNode(R0, I0);
865fe013be4SDimitry Andric   NodePtr Op1 = nullptr;
866fe013be4SDimitry Andric   if (Op0 == nullptr)
867fe013be4SDimitry Andric     return nullptr;
868fe013be4SDimitry Andric 
869fe013be4SDimitry Andric   if (Real->isBinaryOp()) {
870fe013be4SDimitry Andric     auto *R1 = Real->getOperand(1);
871fe013be4SDimitry Andric     auto *I1 = Imag->getOperand(1);
872fe013be4SDimitry Andric     Op1 = identifyNode(R1, I1);
873fe013be4SDimitry Andric     if (Op1 == nullptr)
874fe013be4SDimitry Andric       return nullptr;
875fe013be4SDimitry Andric   }
876fe013be4SDimitry Andric 
877fe013be4SDimitry Andric   if (isa<FPMathOperator>(Real) &&
878fe013be4SDimitry Andric       Real->getFastMathFlags() != Imag->getFastMathFlags())
879fe013be4SDimitry Andric     return nullptr;
880fe013be4SDimitry Andric 
881fe013be4SDimitry Andric   auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,
882fe013be4SDimitry Andric                                    Real, Imag);
883fe013be4SDimitry Andric   Node->Opcode = Real->getOpcode();
884fe013be4SDimitry Andric   if (isa<FPMathOperator>(Real))
885fe013be4SDimitry Andric     Node->Flags = Real->getFastMathFlags();
886fe013be4SDimitry Andric 
887fe013be4SDimitry Andric   Node->addOperand(Op0);
888fe013be4SDimitry Andric   if (Real->isBinaryOp())
889fe013be4SDimitry Andric     Node->addOperand(Op1);
890fe013be4SDimitry Andric 
891fe013be4SDimitry Andric   return submitCompositeNode(Node);
892fe013be4SDimitry Andric }
893fe013be4SDimitry Andric 
894fe013be4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyNode(Value * R,Value * I)895fe013be4SDimitry Andric ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
896fe013be4SDimitry Andric   LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n");
897fe013be4SDimitry Andric   assert(R->getType() == I->getType() &&
898fe013be4SDimitry Andric          "Real and imaginary parts should not have different types");
899271697daSDimitry Andric 
900271697daSDimitry Andric   auto It = CachedResult.find({R, I});
901271697daSDimitry Andric   if (It != CachedResult.end()) {
902bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
903271697daSDimitry Andric     return It->second;
904bdd1243dSDimitry Andric   }
905bdd1243dSDimitry Andric 
906fe013be4SDimitry Andric   if (NodePtr CN = identifySplat(R, I))
907fe013be4SDimitry Andric     return CN;
908fe013be4SDimitry Andric 
909fe013be4SDimitry Andric   auto *Real = dyn_cast<Instruction>(R);
910fe013be4SDimitry Andric   auto *Imag = dyn_cast<Instruction>(I);
911fe013be4SDimitry Andric   if (!Real || !Imag)
912fe013be4SDimitry Andric     return nullptr;
913fe013be4SDimitry Andric 
914fe013be4SDimitry Andric   if (NodePtr CN = identifyDeinterleave(Real, Imag))
915fe013be4SDimitry Andric     return CN;
916fe013be4SDimitry Andric 
917fe013be4SDimitry Andric   if (NodePtr CN = identifyPHINode(Real, Imag))
918fe013be4SDimitry Andric     return CN;
919fe013be4SDimitry Andric 
920fe013be4SDimitry Andric   if (NodePtr CN = identifySelectNode(Real, Imag))
921fe013be4SDimitry Andric     return CN;
922fe013be4SDimitry Andric 
923fe013be4SDimitry Andric   auto *VTy = cast<VectorType>(Real->getType());
924fe013be4SDimitry Andric   auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
925fe013be4SDimitry Andric 
926fe013be4SDimitry Andric   bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
927fe013be4SDimitry Andric       ComplexDeinterleavingOperation::CMulPartial, NewVTy);
928fe013be4SDimitry Andric   bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
929fe013be4SDimitry Andric       ComplexDeinterleavingOperation::CAdd, NewVTy);
930fe013be4SDimitry Andric 
931fe013be4SDimitry Andric   if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
932fe013be4SDimitry Andric     if (NodePtr CN = identifyPartialMul(Real, Imag))
933fe013be4SDimitry Andric       return CN;
934fe013be4SDimitry Andric   }
935fe013be4SDimitry Andric 
936fe013be4SDimitry Andric   if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
937fe013be4SDimitry Andric     if (NodePtr CN = identifyAdd(Real, Imag))
938fe013be4SDimitry Andric       return CN;
939fe013be4SDimitry Andric   }
940fe013be4SDimitry Andric 
941fe013be4SDimitry Andric   if (HasCMulSupport && HasCAddSupport) {
942fe013be4SDimitry Andric     if (NodePtr CN = identifyReassocNodes(Real, Imag))
943fe013be4SDimitry Andric       return CN;
944fe013be4SDimitry Andric   }
945fe013be4SDimitry Andric 
946fe013be4SDimitry Andric   if (NodePtr CN = identifySymmetricOperation(Real, Imag))
947fe013be4SDimitry Andric     return CN;
948fe013be4SDimitry Andric 
949fe013be4SDimitry Andric   LLVM_DEBUG(dbgs() << "  - Not recognised as a valid pattern.\n");
950271697daSDimitry Andric   CachedResult[{R, I}] = nullptr;
951fe013be4SDimitry Andric   return nullptr;
952fe013be4SDimitry Andric }
953fe013be4SDimitry Andric 
954fe013be4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyReassocNodes(Instruction * Real,Instruction * Imag)955fe013be4SDimitry Andric ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
956fe013be4SDimitry Andric                                                  Instruction *Imag) {
957fe013be4SDimitry Andric   auto IsOperationSupported = [](unsigned Opcode) -> bool {
958fe013be4SDimitry Andric     return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
959fe013be4SDimitry Andric            Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
960fe013be4SDimitry Andric            Opcode == Instruction::Sub;
961fe013be4SDimitry Andric   };
962fe013be4SDimitry Andric 
963fe013be4SDimitry Andric   if (!IsOperationSupported(Real->getOpcode()) ||
964fe013be4SDimitry Andric       !IsOperationSupported(Imag->getOpcode()))
965fe013be4SDimitry Andric     return nullptr;
966fe013be4SDimitry Andric 
967fe013be4SDimitry Andric   std::optional<FastMathFlags> Flags;
968fe013be4SDimitry Andric   if (isa<FPMathOperator>(Real)) {
969fe013be4SDimitry Andric     if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
970fe013be4SDimitry Andric       LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
971fe013be4SDimitry Andric                            "not identical\n");
972fe013be4SDimitry Andric       return nullptr;
973fe013be4SDimitry Andric     }
974fe013be4SDimitry Andric 
975fe013be4SDimitry Andric     Flags = Real->getFastMathFlags();
976fe013be4SDimitry Andric     if (!Flags->allowReassoc()) {
977fe013be4SDimitry Andric       LLVM_DEBUG(
978fe013be4SDimitry Andric           dbgs()
979fe013be4SDimitry Andric           << "the 'Reassoc' attribute is missing in the FastMath flags\n");
980fe013be4SDimitry Andric       return nullptr;
981fe013be4SDimitry Andric     }
982fe013be4SDimitry Andric   }
983fe013be4SDimitry Andric 
984fe013be4SDimitry Andric   // Collect multiplications and addend instructions from the given instruction
985fe013be4SDimitry Andric   // while traversing it operands. Additionally, verify that all instructions
986fe013be4SDimitry Andric   // have the same fast math flags.
987fe013be4SDimitry Andric   auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls,
988fe013be4SDimitry Andric                           std::list<Addend> &Addends) -> bool {
989fe013be4SDimitry Andric     SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
990fe013be4SDimitry Andric     SmallPtrSet<Value *, 8> Visited;
991fe013be4SDimitry Andric     while (!Worklist.empty()) {
992fe013be4SDimitry Andric       auto [V, IsPositive] = Worklist.back();
993fe013be4SDimitry Andric       Worklist.pop_back();
994fe013be4SDimitry Andric       if (!Visited.insert(V).second)
995fe013be4SDimitry Andric         continue;
996fe013be4SDimitry Andric 
997fe013be4SDimitry Andric       Instruction *I = dyn_cast<Instruction>(V);
998fe013be4SDimitry Andric       if (!I) {
999fe013be4SDimitry Andric         Addends.emplace_back(V, IsPositive);
1000fe013be4SDimitry Andric         continue;
1001fe013be4SDimitry Andric       }
1002fe013be4SDimitry Andric 
1003fe013be4SDimitry Andric       // If an instruction has more than one user, it indicates that it either
1004fe013be4SDimitry Andric       // has an external user, which will be later checked by the checkNodes
1005fe013be4SDimitry Andric       // function, or it is a subexpression utilized by multiple expressions. In
1006fe013be4SDimitry Andric       // the latter case, we will attempt to separately identify the complex
1007fe013be4SDimitry Andric       // operation from here in order to create a shared
1008fe013be4SDimitry Andric       // ComplexDeinterleavingCompositeNode.
1009fe013be4SDimitry Andric       if (I != Insn && I->getNumUses() > 1) {
1010fe013be4SDimitry Andric         LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
1011fe013be4SDimitry Andric         Addends.emplace_back(I, IsPositive);
1012fe013be4SDimitry Andric         continue;
1013fe013be4SDimitry Andric       }
1014fe013be4SDimitry Andric       switch (I->getOpcode()) {
1015fe013be4SDimitry Andric       case Instruction::FAdd:
1016fe013be4SDimitry Andric       case Instruction::Add:
1017fe013be4SDimitry Andric         Worklist.emplace_back(I->getOperand(1), IsPositive);
1018fe013be4SDimitry Andric         Worklist.emplace_back(I->getOperand(0), IsPositive);
1019fe013be4SDimitry Andric         break;
1020fe013be4SDimitry Andric       case Instruction::FSub:
1021fe013be4SDimitry Andric         Worklist.emplace_back(I->getOperand(1), !IsPositive);
1022fe013be4SDimitry Andric         Worklist.emplace_back(I->getOperand(0), IsPositive);
1023fe013be4SDimitry Andric         break;
1024fe013be4SDimitry Andric       case Instruction::Sub:
1025fe013be4SDimitry Andric         if (isNeg(I)) {
1026fe013be4SDimitry Andric           Worklist.emplace_back(getNegOperand(I), !IsPositive);
1027fe013be4SDimitry Andric         } else {
1028fe013be4SDimitry Andric           Worklist.emplace_back(I->getOperand(1), !IsPositive);
1029fe013be4SDimitry Andric           Worklist.emplace_back(I->getOperand(0), IsPositive);
1030fe013be4SDimitry Andric         }
1031fe013be4SDimitry Andric         break;
1032fe013be4SDimitry Andric       case Instruction::FMul:
1033fe013be4SDimitry Andric       case Instruction::Mul: {
1034fe013be4SDimitry Andric         Value *A, *B;
1035fe013be4SDimitry Andric         if (isNeg(I->getOperand(0))) {
1036fe013be4SDimitry Andric           A = getNegOperand(I->getOperand(0));
1037fe013be4SDimitry Andric           IsPositive = !IsPositive;
1038fe013be4SDimitry Andric         } else {
1039fe013be4SDimitry Andric           A = I->getOperand(0);
1040fe013be4SDimitry Andric         }
1041fe013be4SDimitry Andric 
1042fe013be4SDimitry Andric         if (isNeg(I->getOperand(1))) {
1043fe013be4SDimitry Andric           B = getNegOperand(I->getOperand(1));
1044fe013be4SDimitry Andric           IsPositive = !IsPositive;
1045fe013be4SDimitry Andric         } else {
1046fe013be4SDimitry Andric           B = I->getOperand(1);
1047fe013be4SDimitry Andric         }
1048fe013be4SDimitry Andric         Muls.push_back(Product{A, B, IsPositive});
1049fe013be4SDimitry Andric         break;
1050fe013be4SDimitry Andric       }
1051fe013be4SDimitry Andric       case Instruction::FNeg:
1052fe013be4SDimitry Andric         Worklist.emplace_back(I->getOperand(0), !IsPositive);
1053fe013be4SDimitry Andric         break;
1054fe013be4SDimitry Andric       default:
1055fe013be4SDimitry Andric         Addends.emplace_back(I, IsPositive);
1056fe013be4SDimitry Andric         continue;
1057fe013be4SDimitry Andric       }
1058fe013be4SDimitry Andric 
1059fe013be4SDimitry Andric       if (Flags && I->getFastMathFlags() != *Flags) {
1060fe013be4SDimitry Andric         LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
1061fe013be4SDimitry Andric                              "inconsistent with the root instructions' flags: "
1062fe013be4SDimitry Andric                           << *I << "\n");
1063fe013be4SDimitry Andric         return false;
1064fe013be4SDimitry Andric       }
1065fe013be4SDimitry Andric     }
1066fe013be4SDimitry Andric     return true;
1067fe013be4SDimitry Andric   };
1068fe013be4SDimitry Andric 
1069fe013be4SDimitry Andric   std::vector<Product> RealMuls, ImagMuls;
1070fe013be4SDimitry Andric   std::list<Addend> RealAddends, ImagAddends;
1071fe013be4SDimitry Andric   if (!Collect(Real, RealMuls, RealAddends) ||
1072fe013be4SDimitry Andric       !Collect(Imag, ImagMuls, ImagAddends))
1073fe013be4SDimitry Andric     return nullptr;
1074fe013be4SDimitry Andric 
1075fe013be4SDimitry Andric   if (RealAddends.size() != ImagAddends.size())
1076fe013be4SDimitry Andric     return nullptr;
1077fe013be4SDimitry Andric 
1078fe013be4SDimitry Andric   NodePtr FinalNode;
1079fe013be4SDimitry Andric   if (!RealMuls.empty() || !ImagMuls.empty()) {
1080fe013be4SDimitry Andric     // If there are multiplicands, extract positive addend and use it as an
1081fe013be4SDimitry Andric     // accumulator
1082fe013be4SDimitry Andric     FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1083fe013be4SDimitry Andric     FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
1084fe013be4SDimitry Andric     if (!FinalNode)
1085fe013be4SDimitry Andric       return nullptr;
1086fe013be4SDimitry Andric   }
1087fe013be4SDimitry Andric 
1088fe013be4SDimitry Andric   // Identify and process remaining additions
1089fe013be4SDimitry Andric   if (!RealAddends.empty() || !ImagAddends.empty()) {
1090fe013be4SDimitry Andric     FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
1091fe013be4SDimitry Andric     if (!FinalNode)
1092fe013be4SDimitry Andric       return nullptr;
1093fe013be4SDimitry Andric   }
1094fe013be4SDimitry Andric   assert(FinalNode && "FinalNode can not be nullptr here");
1095fe013be4SDimitry Andric   // Set the Real and Imag fields of the final node and submit it
1096fe013be4SDimitry Andric   FinalNode->Real = Real;
1097fe013be4SDimitry Andric   FinalNode->Imag = Imag;
1098fe013be4SDimitry Andric   submitCompositeNode(FinalNode);
1099fe013be4SDimitry Andric   return FinalNode;
1100fe013be4SDimitry Andric }
1101fe013be4SDimitry Andric 
collectPartialMuls(const std::vector<Product> & RealMuls,const std::vector<Product> & ImagMuls,std::vector<PartialMulCandidate> & PartialMulCandidates)1102fe013be4SDimitry Andric bool ComplexDeinterleavingGraph::collectPartialMuls(
1103fe013be4SDimitry Andric     const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls,
1104fe013be4SDimitry Andric     std::vector<PartialMulCandidate> &PartialMulCandidates) {
1105fe013be4SDimitry Andric   // Helper function to extract a common operand from two products
1106fe013be4SDimitry Andric   auto FindCommonInstruction = [](const Product &Real,
1107fe013be4SDimitry Andric                                   const Product &Imag) -> Value * {
1108fe013be4SDimitry Andric     if (Real.Multiplicand == Imag.Multiplicand ||
1109fe013be4SDimitry Andric         Real.Multiplicand == Imag.Multiplier)
1110fe013be4SDimitry Andric       return Real.Multiplicand;
1111fe013be4SDimitry Andric 
1112fe013be4SDimitry Andric     if (Real.Multiplier == Imag.Multiplicand ||
1113fe013be4SDimitry Andric         Real.Multiplier == Imag.Multiplier)
1114fe013be4SDimitry Andric       return Real.Multiplier;
1115fe013be4SDimitry Andric 
1116fe013be4SDimitry Andric     return nullptr;
1117fe013be4SDimitry Andric   };
1118fe013be4SDimitry Andric 
1119fe013be4SDimitry Andric   // Iterating over real and imaginary multiplications to find common operands
1120fe013be4SDimitry Andric   // If a common operand is found, a partial multiplication candidate is created
1121fe013be4SDimitry Andric   // and added to the candidates vector The function returns false if no common
1122fe013be4SDimitry Andric   // operands are found for any product
1123fe013be4SDimitry Andric   for (unsigned i = 0; i < RealMuls.size(); ++i) {
1124fe013be4SDimitry Andric     bool FoundCommon = false;
1125fe013be4SDimitry Andric     for (unsigned j = 0; j < ImagMuls.size(); ++j) {
1126fe013be4SDimitry Andric       auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1127fe013be4SDimitry Andric       if (!Common)
1128fe013be4SDimitry Andric         continue;
1129fe013be4SDimitry Andric 
1130fe013be4SDimitry Andric       auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1131fe013be4SDimitry Andric                                                    : RealMuls[i].Multiplicand;
1132fe013be4SDimitry Andric       auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
1133fe013be4SDimitry Andric                                                    : ImagMuls[j].Multiplicand;
1134fe013be4SDimitry Andric 
1135fe013be4SDimitry Andric       auto Node = identifyNode(A, B);
1136fe013be4SDimitry Andric       if (Node) {
1137fe013be4SDimitry Andric         FoundCommon = true;
1138fe013be4SDimitry Andric         PartialMulCandidates.push_back({Common, Node, i, j, false});
1139fe013be4SDimitry Andric       }
1140fe013be4SDimitry Andric 
1141fe013be4SDimitry Andric       Node = identifyNode(B, A);
1142fe013be4SDimitry Andric       if (Node) {
1143fe013be4SDimitry Andric         FoundCommon = true;
1144fe013be4SDimitry Andric         PartialMulCandidates.push_back({Common, Node, i, j, true});
1145fe013be4SDimitry Andric       }
1146fe013be4SDimitry Andric     }
1147fe013be4SDimitry Andric     if (!FoundCommon)
1148fe013be4SDimitry Andric       return false;
1149fe013be4SDimitry Andric   }
1150fe013be4SDimitry Andric   return true;
1151fe013be4SDimitry Andric }
1152fe013be4SDimitry Andric 
1153fe013be4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyMultiplications(std::vector<Product> & RealMuls,std::vector<Product> & ImagMuls,NodePtr Accumulator=nullptr)1154fe013be4SDimitry Andric ComplexDeinterleavingGraph::identifyMultiplications(
1155fe013be4SDimitry Andric     std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls,
1156fe013be4SDimitry Andric     NodePtr Accumulator = nullptr) {
1157fe013be4SDimitry Andric   if (RealMuls.size() != ImagMuls.size())
1158fe013be4SDimitry Andric     return nullptr;
1159fe013be4SDimitry Andric 
1160fe013be4SDimitry Andric   std::vector<PartialMulCandidate> Info;
1161fe013be4SDimitry Andric   if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1162fe013be4SDimitry Andric     return nullptr;
1163fe013be4SDimitry Andric 
1164fe013be4SDimitry Andric   // Map to store common instruction to node pointers
1165fe013be4SDimitry Andric   std::map<Value *, NodePtr> CommonToNode;
1166fe013be4SDimitry Andric   std::vector<bool> Processed(Info.size(), false);
1167fe013be4SDimitry Andric   for (unsigned I = 0; I < Info.size(); ++I) {
1168fe013be4SDimitry Andric     if (Processed[I])
1169fe013be4SDimitry Andric       continue;
1170fe013be4SDimitry Andric 
1171fe013be4SDimitry Andric     PartialMulCandidate &InfoA = Info[I];
1172fe013be4SDimitry Andric     for (unsigned J = I + 1; J < Info.size(); ++J) {
1173fe013be4SDimitry Andric       if (Processed[J])
1174fe013be4SDimitry Andric         continue;
1175fe013be4SDimitry Andric 
1176fe013be4SDimitry Andric       PartialMulCandidate &InfoB = Info[J];
1177fe013be4SDimitry Andric       auto *InfoReal = &InfoA;
1178fe013be4SDimitry Andric       auto *InfoImag = &InfoB;
1179fe013be4SDimitry Andric 
1180fe013be4SDimitry Andric       auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1181fe013be4SDimitry Andric       if (!NodeFromCommon) {
1182fe013be4SDimitry Andric         std::swap(InfoReal, InfoImag);
1183fe013be4SDimitry Andric         NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1184fe013be4SDimitry Andric       }
1185fe013be4SDimitry Andric       if (!NodeFromCommon)
1186fe013be4SDimitry Andric         continue;
1187fe013be4SDimitry Andric 
1188fe013be4SDimitry Andric       CommonToNode[InfoReal->Common] = NodeFromCommon;
1189fe013be4SDimitry Andric       CommonToNode[InfoImag->Common] = NodeFromCommon;
1190fe013be4SDimitry Andric       Processed[I] = true;
1191fe013be4SDimitry Andric       Processed[J] = true;
1192fe013be4SDimitry Andric     }
1193fe013be4SDimitry Andric   }
1194fe013be4SDimitry Andric 
1195fe013be4SDimitry Andric   std::vector<bool> ProcessedReal(RealMuls.size(), false);
1196fe013be4SDimitry Andric   std::vector<bool> ProcessedImag(ImagMuls.size(), false);
1197fe013be4SDimitry Andric   NodePtr Result = Accumulator;
1198fe013be4SDimitry Andric   for (auto &PMI : Info) {
1199fe013be4SDimitry Andric     if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1200fe013be4SDimitry Andric       continue;
1201fe013be4SDimitry Andric 
1202fe013be4SDimitry Andric     auto It = CommonToNode.find(PMI.Common);
1203fe013be4SDimitry Andric     // TODO: Process independent complex multiplications. Cases like this:
1204fe013be4SDimitry Andric     //  A.real() * B where both A and B are complex numbers.
1205fe013be4SDimitry Andric     if (It == CommonToNode.end()) {
1206fe013be4SDimitry Andric       LLVM_DEBUG({
1207fe013be4SDimitry Andric         dbgs() << "Unprocessed independent partial multiplication:\n";
1208fe013be4SDimitry Andric         for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1209fe013be4SDimitry Andric           dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
1210fe013be4SDimitry Andric                            << " multiplied by " << *Mul->Multiplicand << "\n";
1211fe013be4SDimitry Andric       });
1212fe013be4SDimitry Andric       return nullptr;
1213fe013be4SDimitry Andric     }
1214fe013be4SDimitry Andric 
1215fe013be4SDimitry Andric     auto &RealMul = RealMuls[PMI.RealIdx];
1216fe013be4SDimitry Andric     auto &ImagMul = ImagMuls[PMI.ImagIdx];
1217fe013be4SDimitry Andric 
1218fe013be4SDimitry Andric     auto NodeA = It->second;
1219fe013be4SDimitry Andric     auto NodeB = PMI.Node;
1220fe013be4SDimitry Andric     auto IsMultiplicandReal = PMI.Common == NodeA->Real;
1221fe013be4SDimitry Andric     // The following table illustrates the relationship between multiplications
1222fe013be4SDimitry Andric     // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1223fe013be4SDimitry Andric     // can see:
1224fe013be4SDimitry Andric     //
1225fe013be4SDimitry Andric     // Rotation |   Real |   Imag |
1226fe013be4SDimitry Andric     // ---------+--------+--------+
1227fe013be4SDimitry Andric     //        0 |  x * u |  x * v |
1228fe013be4SDimitry Andric     //       90 | -y * v |  y * u |
1229fe013be4SDimitry Andric     //      180 | -x * u | -x * v |
1230fe013be4SDimitry Andric     //      270 |  y * v | -y * u |
1231fe013be4SDimitry Andric     //
1232fe013be4SDimitry Andric     // Check if the candidate can indeed be represented by partial
1233fe013be4SDimitry Andric     // multiplication
1234fe013be4SDimitry Andric     // TODO: Add support for multiplication by complex one
1235fe013be4SDimitry Andric     if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1236fe013be4SDimitry Andric         (!IsMultiplicandReal && !PMI.IsNodeInverted))
1237fe013be4SDimitry Andric       continue;
1238fe013be4SDimitry Andric 
1239fe013be4SDimitry Andric     // Determine the rotation based on the multiplications
1240fe013be4SDimitry Andric     ComplexDeinterleavingRotation Rotation;
1241fe013be4SDimitry Andric     if (IsMultiplicandReal) {
1242fe013be4SDimitry Andric       // Detect 0 and 180 degrees rotation
1243fe013be4SDimitry Andric       if (RealMul.IsPositive && ImagMul.IsPositive)
1244fe013be4SDimitry Andric         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0;
1245fe013be4SDimitry Andric       else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1246fe013be4SDimitry Andric         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180;
1247fe013be4SDimitry Andric       else
1248fe013be4SDimitry Andric         continue;
1249fe013be4SDimitry Andric 
1250fe013be4SDimitry Andric     } else {
1251fe013be4SDimitry Andric       // Detect 90 and 270 degrees rotation
1252fe013be4SDimitry Andric       if (!RealMul.IsPositive && ImagMul.IsPositive)
1253fe013be4SDimitry Andric         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90;
1254fe013be4SDimitry Andric       else if (RealMul.IsPositive && !ImagMul.IsPositive)
1255fe013be4SDimitry Andric         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270;
1256fe013be4SDimitry Andric       else
1257fe013be4SDimitry Andric         continue;
1258fe013be4SDimitry Andric     }
1259fe013be4SDimitry Andric 
1260fe013be4SDimitry Andric     LLVM_DEBUG({
1261fe013be4SDimitry Andric       dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1262fe013be4SDimitry Andric       dbgs().indent(4) << "X: " << *NodeA->Real << "\n";
1263fe013be4SDimitry Andric       dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n";
1264fe013be4SDimitry Andric       dbgs().indent(4) << "U: " << *NodeB->Real << "\n";
1265fe013be4SDimitry Andric       dbgs().indent(4) << "V: " << *NodeB->Imag << "\n";
1266fe013be4SDimitry Andric       dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1267fe013be4SDimitry Andric     });
1268fe013be4SDimitry Andric 
1269fe013be4SDimitry Andric     NodePtr NodeMul = prepareCompositeNode(
1270fe013be4SDimitry Andric         ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
1271fe013be4SDimitry Andric     NodeMul->Rotation = Rotation;
1272fe013be4SDimitry Andric     NodeMul->addOperand(NodeA);
1273fe013be4SDimitry Andric     NodeMul->addOperand(NodeB);
1274fe013be4SDimitry Andric     if (Result)
1275fe013be4SDimitry Andric       NodeMul->addOperand(Result);
1276fe013be4SDimitry Andric     submitCompositeNode(NodeMul);
1277fe013be4SDimitry Andric     Result = NodeMul;
1278fe013be4SDimitry Andric     ProcessedReal[PMI.RealIdx] = true;
1279fe013be4SDimitry Andric     ProcessedImag[PMI.ImagIdx] = true;
1280fe013be4SDimitry Andric   }
1281fe013be4SDimitry Andric 
1282fe013be4SDimitry Andric   // Ensure all products have been processed, if not return nullptr.
1283fe013be4SDimitry Andric   if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
1284fe013be4SDimitry Andric       !all_of(ProcessedImag, [](bool V) { return V; })) {
1285fe013be4SDimitry Andric 
1286fe013be4SDimitry Andric     // Dump debug information about which partial multiplications are not
1287fe013be4SDimitry Andric     // processed.
1288fe013be4SDimitry Andric     LLVM_DEBUG({
1289fe013be4SDimitry Andric       dbgs() << "Unprocessed products (Real):\n";
1290fe013be4SDimitry Andric       for (size_t i = 0; i < ProcessedReal.size(); ++i) {
1291fe013be4SDimitry Andric         if (!ProcessedReal[i])
1292fe013be4SDimitry Andric           dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
1293fe013be4SDimitry Andric                            << *RealMuls[i].Multiplier << " multiplied by "
1294fe013be4SDimitry Andric                            << *RealMuls[i].Multiplicand << "\n";
1295fe013be4SDimitry Andric       }
1296fe013be4SDimitry Andric       dbgs() << "Unprocessed products (Imag):\n";
1297fe013be4SDimitry Andric       for (size_t i = 0; i < ProcessedImag.size(); ++i) {
1298fe013be4SDimitry Andric         if (!ProcessedImag[i])
1299fe013be4SDimitry Andric           dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
1300fe013be4SDimitry Andric                            << *ImagMuls[i].Multiplier << " multiplied by "
1301fe013be4SDimitry Andric                            << *ImagMuls[i].Multiplicand << "\n";
1302fe013be4SDimitry Andric       }
1303fe013be4SDimitry Andric     });
1304fe013be4SDimitry Andric     return nullptr;
1305fe013be4SDimitry Andric   }
1306fe013be4SDimitry Andric 
1307fe013be4SDimitry Andric   return Result;
1308fe013be4SDimitry Andric }
1309fe013be4SDimitry Andric 
1310fe013be4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyAdditions(std::list<Addend> & RealAddends,std::list<Addend> & ImagAddends,std::optional<FastMathFlags> Flags,NodePtr Accumulator=nullptr)1311fe013be4SDimitry Andric ComplexDeinterleavingGraph::identifyAdditions(
1312fe013be4SDimitry Andric     std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends,
1313fe013be4SDimitry Andric     std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) {
1314fe013be4SDimitry Andric   if (RealAddends.size() != ImagAddends.size())
1315fe013be4SDimitry Andric     return nullptr;
1316fe013be4SDimitry Andric 
1317fe013be4SDimitry Andric   NodePtr Result;
1318fe013be4SDimitry Andric   // If we have accumulator use it as first addend
1319fe013be4SDimitry Andric   if (Accumulator)
1320fe013be4SDimitry Andric     Result = Accumulator;
1321fe013be4SDimitry Andric   // Otherwise find an element with both positive real and imaginary parts.
1322fe013be4SDimitry Andric   else
1323fe013be4SDimitry Andric     Result = extractPositiveAddend(RealAddends, ImagAddends);
1324fe013be4SDimitry Andric 
1325fe013be4SDimitry Andric   if (!Result)
1326fe013be4SDimitry Andric     return nullptr;
1327fe013be4SDimitry Andric 
1328fe013be4SDimitry Andric   while (!RealAddends.empty()) {
1329fe013be4SDimitry Andric     auto ItR = RealAddends.begin();
1330fe013be4SDimitry Andric     auto [R, IsPositiveR] = *ItR;
1331fe013be4SDimitry Andric 
1332fe013be4SDimitry Andric     bool FoundImag = false;
1333fe013be4SDimitry Andric     for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1334fe013be4SDimitry Andric       auto [I, IsPositiveI] = *ItI;
1335fe013be4SDimitry Andric       ComplexDeinterleavingRotation Rotation;
1336fe013be4SDimitry Andric       if (IsPositiveR && IsPositiveI)
1337fe013be4SDimitry Andric         Rotation = ComplexDeinterleavingRotation::Rotation_0;
1338fe013be4SDimitry Andric       else if (!IsPositiveR && IsPositiveI)
1339fe013be4SDimitry Andric         Rotation = ComplexDeinterleavingRotation::Rotation_90;
1340fe013be4SDimitry Andric       else if (!IsPositiveR && !IsPositiveI)
1341fe013be4SDimitry Andric         Rotation = ComplexDeinterleavingRotation::Rotation_180;
1342fe013be4SDimitry Andric       else
1343fe013be4SDimitry Andric         Rotation = ComplexDeinterleavingRotation::Rotation_270;
1344fe013be4SDimitry Andric 
1345fe013be4SDimitry Andric       NodePtr AddNode;
1346fe013be4SDimitry Andric       if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1347fe013be4SDimitry Andric           Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1348fe013be4SDimitry Andric         AddNode = identifyNode(R, I);
1349fe013be4SDimitry Andric       } else {
1350fe013be4SDimitry Andric         AddNode = identifyNode(I, R);
1351fe013be4SDimitry Andric       }
1352fe013be4SDimitry Andric       if (AddNode) {
1353fe013be4SDimitry Andric         LLVM_DEBUG({
1354fe013be4SDimitry Andric           dbgs() << "Identified addition:\n";
1355fe013be4SDimitry Andric           dbgs().indent(4) << "X: " << *R << "\n";
1356fe013be4SDimitry Andric           dbgs().indent(4) << "Y: " << *I << "\n";
1357fe013be4SDimitry Andric           dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1358fe013be4SDimitry Andric         });
1359fe013be4SDimitry Andric 
1360fe013be4SDimitry Andric         NodePtr TmpNode;
1361fe013be4SDimitry Andric         if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) {
1362fe013be4SDimitry Andric           TmpNode = prepareCompositeNode(
1363fe013be4SDimitry Andric               ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1364fe013be4SDimitry Andric           if (Flags) {
1365fe013be4SDimitry Andric             TmpNode->Opcode = Instruction::FAdd;
1366fe013be4SDimitry Andric             TmpNode->Flags = *Flags;
1367fe013be4SDimitry Andric           } else {
1368fe013be4SDimitry Andric             TmpNode->Opcode = Instruction::Add;
1369fe013be4SDimitry Andric           }
1370fe013be4SDimitry Andric         } else if (Rotation ==
1371fe013be4SDimitry Andric                    llvm::ComplexDeinterleavingRotation::Rotation_180) {
1372fe013be4SDimitry Andric           TmpNode = prepareCompositeNode(
1373fe013be4SDimitry Andric               ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1374fe013be4SDimitry Andric           if (Flags) {
1375fe013be4SDimitry Andric             TmpNode->Opcode = Instruction::FSub;
1376fe013be4SDimitry Andric             TmpNode->Flags = *Flags;
1377fe013be4SDimitry Andric           } else {
1378fe013be4SDimitry Andric             TmpNode->Opcode = Instruction::Sub;
1379fe013be4SDimitry Andric           }
1380fe013be4SDimitry Andric         } else {
1381fe013be4SDimitry Andric           TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1382fe013be4SDimitry Andric                                          nullptr, nullptr);
1383fe013be4SDimitry Andric           TmpNode->Rotation = Rotation;
1384fe013be4SDimitry Andric         }
1385fe013be4SDimitry Andric 
1386fe013be4SDimitry Andric         TmpNode->addOperand(Result);
1387fe013be4SDimitry Andric         TmpNode->addOperand(AddNode);
1388fe013be4SDimitry Andric         submitCompositeNode(TmpNode);
1389fe013be4SDimitry Andric         Result = TmpNode;
1390fe013be4SDimitry Andric         RealAddends.erase(ItR);
1391fe013be4SDimitry Andric         ImagAddends.erase(ItI);
1392fe013be4SDimitry Andric         FoundImag = true;
1393fe013be4SDimitry Andric         break;
1394fe013be4SDimitry Andric       }
1395fe013be4SDimitry Andric     }
1396fe013be4SDimitry Andric     if (!FoundImag)
1397fe013be4SDimitry Andric       return nullptr;
1398fe013be4SDimitry Andric   }
1399fe013be4SDimitry Andric   return Result;
1400fe013be4SDimitry Andric }
1401fe013be4SDimitry Andric 
1402fe013be4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
extractPositiveAddend(std::list<Addend> & RealAddends,std::list<Addend> & ImagAddends)1403fe013be4SDimitry Andric ComplexDeinterleavingGraph::extractPositiveAddend(
1404fe013be4SDimitry Andric     std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) {
1405fe013be4SDimitry Andric   for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1406fe013be4SDimitry Andric     for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1407fe013be4SDimitry Andric       auto [R, IsPositiveR] = *ItR;
1408fe013be4SDimitry Andric       auto [I, IsPositiveI] = *ItI;
1409fe013be4SDimitry Andric       if (IsPositiveR && IsPositiveI) {
1410fe013be4SDimitry Andric         auto Result = identifyNode(R, I);
1411fe013be4SDimitry Andric         if (Result) {
1412fe013be4SDimitry Andric           RealAddends.erase(ItR);
1413fe013be4SDimitry Andric           ImagAddends.erase(ItI);
1414fe013be4SDimitry Andric           return Result;
1415fe013be4SDimitry Andric         }
1416fe013be4SDimitry Andric       }
1417fe013be4SDimitry Andric     }
1418fe013be4SDimitry Andric   }
1419fe013be4SDimitry Andric   return nullptr;
1420fe013be4SDimitry Andric }
1421fe013be4SDimitry Andric 
identifyNodes(Instruction * RootI)1422fe013be4SDimitry Andric bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1423fe013be4SDimitry Andric   // This potential root instruction might already have been recognized as
1424fe013be4SDimitry Andric   // reduction. Because RootToNode maps both Real and Imaginary parts to
1425fe013be4SDimitry Andric   // CompositeNode we should choose only one either Real or Imag instruction to
1426fe013be4SDimitry Andric   // use as an anchor for generating complex instruction.
1427fe013be4SDimitry Andric   auto It = RootToNode.find(RootI);
1428271697daSDimitry Andric   if (It != RootToNode.end()) {
1429271697daSDimitry Andric     auto RootNode = It->second;
1430271697daSDimitry Andric     assert(RootNode->Operation ==
1431271697daSDimitry Andric            ComplexDeinterleavingOperation::ReductionOperation);
1432271697daSDimitry Andric     // Find out which part, Real or Imag, comes later, and only if we come to
1433271697daSDimitry Andric     // the latest part, add it to OrderedRoots.
1434271697daSDimitry Andric     auto *R = cast<Instruction>(RootNode->Real);
1435271697daSDimitry Andric     auto *I = cast<Instruction>(RootNode->Imag);
1436271697daSDimitry Andric     auto *ReplacementAnchor = R->comesBefore(I) ? I : R;
1437271697daSDimitry Andric     if (ReplacementAnchor != RootI)
1438271697daSDimitry Andric       return false;
1439fe013be4SDimitry Andric     OrderedRoots.push_back(RootI);
1440fe013be4SDimitry Andric     return true;
1441fe013be4SDimitry Andric   }
1442fe013be4SDimitry Andric 
1443fe013be4SDimitry Andric   auto RootNode = identifyRoot(RootI);
1444fe013be4SDimitry Andric   if (!RootNode)
1445fe013be4SDimitry Andric     return false;
1446fe013be4SDimitry Andric 
1447fe013be4SDimitry Andric   LLVM_DEBUG({
1448fe013be4SDimitry Andric     Function *F = RootI->getFunction();
1449fe013be4SDimitry Andric     BasicBlock *B = RootI->getParent();
1450fe013be4SDimitry Andric     dbgs() << "Complex deinterleaving graph for " << F->getName()
1451fe013be4SDimitry Andric            << "::" << B->getName() << ".\n";
1452fe013be4SDimitry Andric     dump(dbgs());
1453fe013be4SDimitry Andric     dbgs() << "\n";
1454fe013be4SDimitry Andric   });
1455fe013be4SDimitry Andric   RootToNode[RootI] = RootNode;
1456fe013be4SDimitry Andric   OrderedRoots.push_back(RootI);
1457fe013be4SDimitry Andric   return true;
1458fe013be4SDimitry Andric }
1459fe013be4SDimitry Andric 
collectPotentialReductions(BasicBlock * B)1460fe013be4SDimitry Andric bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
1461fe013be4SDimitry Andric   bool FoundPotentialReduction = false;
1462fe013be4SDimitry Andric 
1463fe013be4SDimitry Andric   auto *Br = dyn_cast<BranchInst>(B->getTerminator());
1464fe013be4SDimitry Andric   if (!Br || Br->getNumSuccessors() != 2)
1465fe013be4SDimitry Andric     return false;
1466fe013be4SDimitry Andric 
1467fe013be4SDimitry Andric   // Identify simple one-block loop
1468fe013be4SDimitry Andric   if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
1469fe013be4SDimitry Andric     return false;
1470fe013be4SDimitry Andric 
1471fe013be4SDimitry Andric   SmallVector<PHINode *> PHIs;
1472fe013be4SDimitry Andric   for (auto &PHI : B->phis()) {
1473fe013be4SDimitry Andric     if (PHI.getNumIncomingValues() != 2)
1474fe013be4SDimitry Andric       continue;
1475fe013be4SDimitry Andric 
1476fe013be4SDimitry Andric     if (!PHI.getType()->isVectorTy())
1477fe013be4SDimitry Andric       continue;
1478fe013be4SDimitry Andric 
1479fe013be4SDimitry Andric     auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
1480fe013be4SDimitry Andric     if (!ReductionOp)
1481fe013be4SDimitry Andric       continue;
1482fe013be4SDimitry Andric 
1483fe013be4SDimitry Andric     // Check if final instruction is reduced outside of current block
1484fe013be4SDimitry Andric     Instruction *FinalReduction = nullptr;
1485fe013be4SDimitry Andric     auto NumUsers = 0u;
1486fe013be4SDimitry Andric     for (auto *U : ReductionOp->users()) {
1487fe013be4SDimitry Andric       ++NumUsers;
1488fe013be4SDimitry Andric       if (U == &PHI)
1489fe013be4SDimitry Andric         continue;
1490fe013be4SDimitry Andric       FinalReduction = dyn_cast<Instruction>(U);
1491fe013be4SDimitry Andric     }
1492fe013be4SDimitry Andric 
1493fe013be4SDimitry Andric     if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B ||
1494fe013be4SDimitry Andric         isa<PHINode>(FinalReduction))
1495fe013be4SDimitry Andric       continue;
1496fe013be4SDimitry Andric 
1497fe013be4SDimitry Andric     ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
1498fe013be4SDimitry Andric     BackEdge = B;
1499fe013be4SDimitry Andric     auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
1500fe013be4SDimitry Andric     auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1501fe013be4SDimitry Andric     Incoming = PHI.getIncomingBlock(IncomingIdx);
1502fe013be4SDimitry Andric     FoundPotentialReduction = true;
1503fe013be4SDimitry Andric 
1504fe013be4SDimitry Andric     // If the initial value of PHINode is an Instruction, consider it a leaf
1505fe013be4SDimitry Andric     // value of a complex deinterleaving graph.
1506fe013be4SDimitry Andric     if (auto *InitPHI =
1507fe013be4SDimitry Andric             dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
1508fe013be4SDimitry Andric       FinalInstructions.insert(InitPHI);
1509fe013be4SDimitry Andric   }
1510fe013be4SDimitry Andric   return FoundPotentialReduction;
1511fe013be4SDimitry Andric }
1512fe013be4SDimitry Andric 
identifyReductionNodes()1513fe013be4SDimitry Andric void ComplexDeinterleavingGraph::identifyReductionNodes() {
1514fe013be4SDimitry Andric   SmallVector<bool> Processed(ReductionInfo.size(), false);
1515fe013be4SDimitry Andric   SmallVector<Instruction *> OperationInstruction;
1516fe013be4SDimitry Andric   for (auto &P : ReductionInfo)
1517fe013be4SDimitry Andric     OperationInstruction.push_back(P.first);
1518fe013be4SDimitry Andric 
1519fe013be4SDimitry Andric   // Identify a complex computation by evaluating two reduction operations that
1520fe013be4SDimitry Andric   // potentially could be involved
1521fe013be4SDimitry Andric   for (size_t i = 0; i < OperationInstruction.size(); ++i) {
1522fe013be4SDimitry Andric     if (Processed[i])
1523fe013be4SDimitry Andric       continue;
1524fe013be4SDimitry Andric     for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
1525fe013be4SDimitry Andric       if (Processed[j])
1526fe013be4SDimitry Andric         continue;
1527fe013be4SDimitry Andric 
1528fe013be4SDimitry Andric       auto *Real = OperationInstruction[i];
1529fe013be4SDimitry Andric       auto *Imag = OperationInstruction[j];
1530fe013be4SDimitry Andric       if (Real->getType() != Imag->getType())
1531fe013be4SDimitry Andric         continue;
1532fe013be4SDimitry Andric 
1533fe013be4SDimitry Andric       RealPHI = ReductionInfo[Real].first;
1534fe013be4SDimitry Andric       ImagPHI = ReductionInfo[Imag].first;
1535fe013be4SDimitry Andric       PHIsFound = false;
1536fe013be4SDimitry Andric       auto Node = identifyNode(Real, Imag);
1537fe013be4SDimitry Andric       if (!Node) {
1538fe013be4SDimitry Andric         std::swap(Real, Imag);
1539fe013be4SDimitry Andric         std::swap(RealPHI, ImagPHI);
1540fe013be4SDimitry Andric         Node = identifyNode(Real, Imag);
1541fe013be4SDimitry Andric       }
1542fe013be4SDimitry Andric 
1543fe013be4SDimitry Andric       // If a node is identified and reduction PHINode is used in the chain of
1544fe013be4SDimitry Andric       // operations, mark its operation instructions as used to prevent
1545fe013be4SDimitry Andric       // re-identification and attach the node to the real part
1546fe013be4SDimitry Andric       if (Node && PHIsFound) {
1547fe013be4SDimitry Andric         LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
1548fe013be4SDimitry Andric                           << *Real << " / " << *Imag << "\n");
1549fe013be4SDimitry Andric         Processed[i] = true;
1550fe013be4SDimitry Andric         Processed[j] = true;
1551fe013be4SDimitry Andric         auto RootNode = prepareCompositeNode(
1552fe013be4SDimitry Andric             ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1553fe013be4SDimitry Andric         RootNode->addOperand(Node);
1554fe013be4SDimitry Andric         RootToNode[Real] = RootNode;
1555fe013be4SDimitry Andric         RootToNode[Imag] = RootNode;
1556fe013be4SDimitry Andric         submitCompositeNode(RootNode);
1557fe013be4SDimitry Andric         break;
1558fe013be4SDimitry Andric       }
1559fe013be4SDimitry Andric     }
1560fe013be4SDimitry Andric   }
1561fe013be4SDimitry Andric 
1562fe013be4SDimitry Andric   RealPHI = nullptr;
1563fe013be4SDimitry Andric   ImagPHI = nullptr;
1564fe013be4SDimitry Andric }
1565fe013be4SDimitry Andric 
checkNodes()1566fe013be4SDimitry Andric bool ComplexDeinterleavingGraph::checkNodes() {
1567fe013be4SDimitry Andric   // Collect all instructions from roots to leaves
1568fe013be4SDimitry Andric   SmallPtrSet<Instruction *, 16> AllInstructions;
1569fe013be4SDimitry Andric   SmallVector<Instruction *, 8> Worklist;
1570fe013be4SDimitry Andric   for (auto &Pair : RootToNode)
1571fe013be4SDimitry Andric     Worklist.push_back(Pair.first);
1572fe013be4SDimitry Andric 
1573fe013be4SDimitry Andric   // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1574fe013be4SDimitry Andric   // chains
1575fe013be4SDimitry Andric   while (!Worklist.empty()) {
1576fe013be4SDimitry Andric     auto *I = Worklist.back();
1577fe013be4SDimitry Andric     Worklist.pop_back();
1578fe013be4SDimitry Andric 
1579fe013be4SDimitry Andric     if (!AllInstructions.insert(I).second)
1580fe013be4SDimitry Andric       continue;
1581fe013be4SDimitry Andric 
1582fe013be4SDimitry Andric     for (Value *Op : I->operands()) {
1583fe013be4SDimitry Andric       if (auto *OpI = dyn_cast<Instruction>(Op)) {
1584fe013be4SDimitry Andric         if (!FinalInstructions.count(I))
1585fe013be4SDimitry Andric           Worklist.emplace_back(OpI);
1586fe013be4SDimitry Andric       }
1587fe013be4SDimitry Andric     }
1588fe013be4SDimitry Andric   }
1589fe013be4SDimitry Andric 
1590fe013be4SDimitry Andric   // Find instructions that have users outside of chain
1591fe013be4SDimitry Andric   SmallVector<Instruction *, 2> OuterInstructions;
1592fe013be4SDimitry Andric   for (auto *I : AllInstructions) {
1593fe013be4SDimitry Andric     // Skip root nodes
1594fe013be4SDimitry Andric     if (RootToNode.count(I))
1595fe013be4SDimitry Andric       continue;
1596fe013be4SDimitry Andric 
1597fe013be4SDimitry Andric     for (User *U : I->users()) {
1598fe013be4SDimitry Andric       if (AllInstructions.count(cast<Instruction>(U)))
1599fe013be4SDimitry Andric         continue;
1600fe013be4SDimitry Andric 
1601fe013be4SDimitry Andric       // Found an instruction that is not used by XCMLA/XCADD chain
1602fe013be4SDimitry Andric       Worklist.emplace_back(I);
1603fe013be4SDimitry Andric       break;
1604fe013be4SDimitry Andric     }
1605fe013be4SDimitry Andric   }
1606fe013be4SDimitry Andric 
1607fe013be4SDimitry Andric   // If any instructions are found to be used outside, find and remove roots
1608fe013be4SDimitry Andric   // that somehow connect to those instructions.
1609fe013be4SDimitry Andric   SmallPtrSet<Instruction *, 16> Visited;
1610fe013be4SDimitry Andric   while (!Worklist.empty()) {
1611fe013be4SDimitry Andric     auto *I = Worklist.back();
1612fe013be4SDimitry Andric     Worklist.pop_back();
1613fe013be4SDimitry Andric     if (!Visited.insert(I).second)
1614fe013be4SDimitry Andric       continue;
1615fe013be4SDimitry Andric 
1616fe013be4SDimitry Andric     // Found an impacted root node. Removing it from the nodes to be
1617fe013be4SDimitry Andric     // deinterleaved
1618fe013be4SDimitry Andric     if (RootToNode.count(I)) {
1619fe013be4SDimitry Andric       LLVM_DEBUG(dbgs() << "Instruction " << *I
1620fe013be4SDimitry Andric                         << " could be deinterleaved but its chain of complex "
1621fe013be4SDimitry Andric                            "operations have an outside user\n");
1622fe013be4SDimitry Andric       RootToNode.erase(I);
1623fe013be4SDimitry Andric     }
1624fe013be4SDimitry Andric 
1625fe013be4SDimitry Andric     if (!AllInstructions.count(I) || FinalInstructions.count(I))
1626fe013be4SDimitry Andric       continue;
1627fe013be4SDimitry Andric 
1628fe013be4SDimitry Andric     for (User *U : I->users())
1629fe013be4SDimitry Andric       Worklist.emplace_back(cast<Instruction>(U));
1630fe013be4SDimitry Andric 
1631fe013be4SDimitry Andric     for (Value *Op : I->operands()) {
1632fe013be4SDimitry Andric       if (auto *OpI = dyn_cast<Instruction>(Op))
1633fe013be4SDimitry Andric         Worklist.emplace_back(OpI);
1634fe013be4SDimitry Andric     }
1635fe013be4SDimitry Andric   }
1636fe013be4SDimitry Andric   return !RootToNode.empty();
1637fe013be4SDimitry Andric }
1638fe013be4SDimitry Andric 
1639fe013be4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyRoot(Instruction * RootI)1640fe013be4SDimitry Andric ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1641fe013be4SDimitry Andric   if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1642fe013be4SDimitry Andric     if (Intrinsic->getIntrinsicID() !=
1643fe013be4SDimitry Andric         Intrinsic::experimental_vector_interleave2)
1644fe013be4SDimitry Andric       return nullptr;
1645fe013be4SDimitry Andric 
1646fe013be4SDimitry Andric     auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0));
1647fe013be4SDimitry Andric     auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1));
1648fe013be4SDimitry Andric     if (!Real || !Imag)
1649fe013be4SDimitry Andric       return nullptr;
1650fe013be4SDimitry Andric 
1651fe013be4SDimitry Andric     return identifyNode(Real, Imag);
1652fe013be4SDimitry Andric   }
1653fe013be4SDimitry Andric 
1654fe013be4SDimitry Andric   auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
1655fe013be4SDimitry Andric   if (!SVI)
1656fe013be4SDimitry Andric     return nullptr;
1657fe013be4SDimitry Andric 
1658fe013be4SDimitry Andric   // Look for a shufflevector that takes separate vectors of the real and
1659fe013be4SDimitry Andric   // imaginary components and recombines them into a single vector.
1660fe013be4SDimitry Andric   if (!isInterleavingMask(SVI->getShuffleMask()))
1661fe013be4SDimitry Andric     return nullptr;
1662fe013be4SDimitry Andric 
1663fe013be4SDimitry Andric   Instruction *Real;
1664fe013be4SDimitry Andric   Instruction *Imag;
1665fe013be4SDimitry Andric   if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
1666fe013be4SDimitry Andric     return nullptr;
1667fe013be4SDimitry Andric 
1668fe013be4SDimitry Andric   return identifyNode(Real, Imag);
1669fe013be4SDimitry Andric }
1670fe013be4SDimitry Andric 
1671fe013be4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyDeinterleave(Instruction * Real,Instruction * Imag)1672fe013be4SDimitry Andric ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
1673fe013be4SDimitry Andric                                                  Instruction *Imag) {
1674fe013be4SDimitry Andric   Instruction *I = nullptr;
1675fe013be4SDimitry Andric   Value *FinalValue = nullptr;
1676fe013be4SDimitry Andric   if (match(Real, m_ExtractValue<0>(m_Instruction(I))) &&
1677fe013be4SDimitry Andric       match(Imag, m_ExtractValue<1>(m_Specific(I))) &&
1678fe013be4SDimitry Andric       match(I, m_Intrinsic<Intrinsic::experimental_vector_deinterleave2>(
1679fe013be4SDimitry Andric                    m_Value(FinalValue)))) {
1680fe013be4SDimitry Andric     NodePtr PlaceholderNode = prepareCompositeNode(
1681fe013be4SDimitry Andric         llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag);
1682fe013be4SDimitry Andric     PlaceholderNode->ReplacementNode = FinalValue;
1683fe013be4SDimitry Andric     FinalInstructions.insert(Real);
1684fe013be4SDimitry Andric     FinalInstructions.insert(Imag);
1685fe013be4SDimitry Andric     return submitCompositeNode(PlaceholderNode);
1686fe013be4SDimitry Andric   }
1687fe013be4SDimitry Andric 
1688bdd1243dSDimitry Andric   auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
1689bdd1243dSDimitry Andric   auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
1690fe013be4SDimitry Andric   if (!RealShuffle || !ImagShuffle) {
1691fe013be4SDimitry Andric     if (RealShuffle || ImagShuffle)
1692fe013be4SDimitry Andric       LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
1693fe013be4SDimitry Andric     return nullptr;
1694fe013be4SDimitry Andric   }
1695fe013be4SDimitry Andric 
1696bdd1243dSDimitry Andric   Value *RealOp1 = RealShuffle->getOperand(1);
1697bdd1243dSDimitry Andric   if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
1698bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
1699bdd1243dSDimitry Andric     return nullptr;
1700bdd1243dSDimitry Andric   }
1701bdd1243dSDimitry Andric   Value *ImagOp1 = ImagShuffle->getOperand(1);
1702bdd1243dSDimitry Andric   if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
1703bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
1704bdd1243dSDimitry Andric     return nullptr;
1705bdd1243dSDimitry Andric   }
1706bdd1243dSDimitry Andric 
1707bdd1243dSDimitry Andric   Value *RealOp0 = RealShuffle->getOperand(0);
1708bdd1243dSDimitry Andric   Value *ImagOp0 = ImagShuffle->getOperand(0);
1709bdd1243dSDimitry Andric 
1710bdd1243dSDimitry Andric   if (RealOp0 != ImagOp0) {
1711bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
1712bdd1243dSDimitry Andric     return nullptr;
1713bdd1243dSDimitry Andric   }
1714bdd1243dSDimitry Andric 
1715bdd1243dSDimitry Andric   ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
1716bdd1243dSDimitry Andric   ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
1717bdd1243dSDimitry Andric   if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
1718bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
1719bdd1243dSDimitry Andric     return nullptr;
1720bdd1243dSDimitry Andric   }
1721bdd1243dSDimitry Andric 
1722bdd1243dSDimitry Andric   if (RealMask[0] != 0 || ImagMask[0] != 1) {
1723bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
1724bdd1243dSDimitry Andric     return nullptr;
1725bdd1243dSDimitry Andric   }
1726bdd1243dSDimitry Andric 
1727bdd1243dSDimitry Andric   // Type checking, the shuffle type should be a vector type of the same
1728bdd1243dSDimitry Andric   // scalar type, but half the size
1729bdd1243dSDimitry Andric   auto CheckType = [&](ShuffleVectorInst *Shuffle) {
1730bdd1243dSDimitry Andric     Value *Op = Shuffle->getOperand(0);
1731bdd1243dSDimitry Andric     auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
1732bdd1243dSDimitry Andric     auto *OpTy = cast<FixedVectorType>(Op->getType());
1733bdd1243dSDimitry Andric 
1734bdd1243dSDimitry Andric     if (OpTy->getScalarType() != ShuffleTy->getScalarType())
1735bdd1243dSDimitry Andric       return false;
1736bdd1243dSDimitry Andric     if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
1737bdd1243dSDimitry Andric       return false;
1738bdd1243dSDimitry Andric 
1739bdd1243dSDimitry Andric     return true;
1740bdd1243dSDimitry Andric   };
1741bdd1243dSDimitry Andric 
1742bdd1243dSDimitry Andric   auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
1743bdd1243dSDimitry Andric     if (!CheckType(Shuffle))
1744bdd1243dSDimitry Andric       return false;
1745bdd1243dSDimitry Andric 
1746bdd1243dSDimitry Andric     ArrayRef<int> Mask = Shuffle->getShuffleMask();
1747bdd1243dSDimitry Andric     int Last = *Mask.rbegin();
1748bdd1243dSDimitry Andric 
1749bdd1243dSDimitry Andric     Value *Op = Shuffle->getOperand(0);
1750bdd1243dSDimitry Andric     auto *OpTy = cast<FixedVectorType>(Op->getType());
1751bdd1243dSDimitry Andric     int NumElements = OpTy->getNumElements();
1752bdd1243dSDimitry Andric 
1753bdd1243dSDimitry Andric     // Ensure that the deinterleaving shuffle only pulls from the first
1754bdd1243dSDimitry Andric     // shuffle operand.
1755bdd1243dSDimitry Andric     return Last < NumElements;
1756bdd1243dSDimitry Andric   };
1757bdd1243dSDimitry Andric 
1758bdd1243dSDimitry Andric   if (RealShuffle->getType() != ImagShuffle->getType()) {
1759bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
1760bdd1243dSDimitry Andric     return nullptr;
1761bdd1243dSDimitry Andric   }
1762bdd1243dSDimitry Andric   if (!CheckDeinterleavingShuffle(RealShuffle)) {
1763bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
1764bdd1243dSDimitry Andric     return nullptr;
1765bdd1243dSDimitry Andric   }
1766bdd1243dSDimitry Andric   if (!CheckDeinterleavingShuffle(ImagShuffle)) {
1767bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
1768bdd1243dSDimitry Andric     return nullptr;
1769bdd1243dSDimitry Andric   }
1770bdd1243dSDimitry Andric 
1771bdd1243dSDimitry Andric   NodePtr PlaceholderNode =
1772fe013be4SDimitry Andric       prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave,
1773bdd1243dSDimitry Andric                            RealShuffle, ImagShuffle);
1774bdd1243dSDimitry Andric   PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
1775fe013be4SDimitry Andric   FinalInstructions.insert(RealShuffle);
1776fe013be4SDimitry Andric   FinalInstructions.insert(ImagShuffle);
1777bdd1243dSDimitry Andric   return submitCompositeNode(PlaceholderNode);
1778bdd1243dSDimitry Andric }
1779bdd1243dSDimitry Andric 
1780fe013be4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifySplat(Value * R,Value * I)1781fe013be4SDimitry Andric ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) {
1782fe013be4SDimitry Andric   auto IsSplat = [](Value *V) -> bool {
1783fe013be4SDimitry Andric     // Fixed-width vector with constants
1784fe013be4SDimitry Andric     if (isa<ConstantDataVector>(V))
1785fe013be4SDimitry Andric       return true;
1786bdd1243dSDimitry Andric 
1787fe013be4SDimitry Andric     VectorType *VTy;
1788fe013be4SDimitry Andric     ArrayRef<int> Mask;
1789fe013be4SDimitry Andric     // Splats are represented differently depending on whether the repeated
1790fe013be4SDimitry Andric     // value is a constant or an Instruction
1791fe013be4SDimitry Andric     if (auto *Const = dyn_cast<ConstantExpr>(V)) {
1792fe013be4SDimitry Andric       if (Const->getOpcode() != Instruction::ShuffleVector)
1793bdd1243dSDimitry Andric         return false;
1794fe013be4SDimitry Andric       VTy = cast<VectorType>(Const->getType());
1795fe013be4SDimitry Andric       Mask = Const->getShuffleMask();
1796fe013be4SDimitry Andric     } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
1797fe013be4SDimitry Andric       VTy = Shuf->getType();
1798fe013be4SDimitry Andric       Mask = Shuf->getShuffleMask();
1799fe013be4SDimitry Andric     } else {
1800bdd1243dSDimitry Andric       return false;
1801bdd1243dSDimitry Andric     }
1802fe013be4SDimitry Andric 
1803fe013be4SDimitry Andric     // When the data type is <1 x Type>, it's not possible to differentiate
1804fe013be4SDimitry Andric     // between the ComplexDeinterleaving::Deinterleave and
1805fe013be4SDimitry Andric     // ComplexDeinterleaving::Splat operations.
1806fe013be4SDimitry Andric     if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
1807fe013be4SDimitry Andric       return false;
1808fe013be4SDimitry Andric 
1809fe013be4SDimitry Andric     return all_equal(Mask) && Mask[0] == 0;
1810fe013be4SDimitry Andric   };
1811fe013be4SDimitry Andric 
1812fe013be4SDimitry Andric   if (!IsSplat(R) || !IsSplat(I))
1813fe013be4SDimitry Andric     return nullptr;
1814fe013be4SDimitry Andric 
1815fe013be4SDimitry Andric   auto *Real = dyn_cast<Instruction>(R);
1816fe013be4SDimitry Andric   auto *Imag = dyn_cast<Instruction>(I);
1817fe013be4SDimitry Andric   if ((!Real && Imag) || (Real && !Imag))
1818fe013be4SDimitry Andric     return nullptr;
1819fe013be4SDimitry Andric 
1820fe013be4SDimitry Andric   if (Real && Imag) {
1821fe013be4SDimitry Andric     // Non-constant splats should be in the same basic block
1822fe013be4SDimitry Andric     if (Real->getParent() != Imag->getParent())
1823fe013be4SDimitry Andric       return nullptr;
1824fe013be4SDimitry Andric 
1825fe013be4SDimitry Andric     FinalInstructions.insert(Real);
1826fe013be4SDimitry Andric     FinalInstructions.insert(Imag);
1827bdd1243dSDimitry Andric   }
1828fe013be4SDimitry Andric   NodePtr PlaceholderNode =
1829fe013be4SDimitry Andric       prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I);
1830fe013be4SDimitry Andric   return submitCompositeNode(PlaceholderNode);
1831bdd1243dSDimitry Andric }
1832bdd1243dSDimitry Andric 
1833fe013be4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifyPHINode(Instruction * Real,Instruction * Imag)1834fe013be4SDimitry Andric ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
1835fe013be4SDimitry Andric                                             Instruction *Imag) {
1836fe013be4SDimitry Andric   if (Real != RealPHI || Imag != ImagPHI)
1837fe013be4SDimitry Andric     return nullptr;
1838fe013be4SDimitry Andric 
1839fe013be4SDimitry Andric   PHIsFound = true;
1840fe013be4SDimitry Andric   NodePtr PlaceholderNode = prepareCompositeNode(
1841fe013be4SDimitry Andric       ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
1842fe013be4SDimitry Andric   return submitCompositeNode(PlaceholderNode);
1843fe013be4SDimitry Andric }
1844fe013be4SDimitry Andric 
1845fe013be4SDimitry Andric ComplexDeinterleavingGraph::NodePtr
identifySelectNode(Instruction * Real,Instruction * Imag)1846fe013be4SDimitry Andric ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
1847fe013be4SDimitry Andric                                                Instruction *Imag) {
1848fe013be4SDimitry Andric   auto *SelectReal = dyn_cast<SelectInst>(Real);
1849fe013be4SDimitry Andric   auto *SelectImag = dyn_cast<SelectInst>(Imag);
1850fe013be4SDimitry Andric   if (!SelectReal || !SelectImag)
1851fe013be4SDimitry Andric     return nullptr;
1852fe013be4SDimitry Andric 
1853fe013be4SDimitry Andric   Instruction *MaskA, *MaskB;
1854fe013be4SDimitry Andric   Instruction *AR, *AI, *RA, *BI;
1855fe013be4SDimitry Andric   if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
1856fe013be4SDimitry Andric                             m_Instruction(RA))) ||
1857fe013be4SDimitry Andric       !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
1858fe013be4SDimitry Andric                             m_Instruction(BI))))
1859fe013be4SDimitry Andric     return nullptr;
1860fe013be4SDimitry Andric 
1861fe013be4SDimitry Andric   if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
1862fe013be4SDimitry Andric     return nullptr;
1863fe013be4SDimitry Andric 
1864fe013be4SDimitry Andric   if (!MaskA->getType()->isVectorTy())
1865fe013be4SDimitry Andric     return nullptr;
1866fe013be4SDimitry Andric 
1867fe013be4SDimitry Andric   auto NodeA = identifyNode(AR, AI);
1868fe013be4SDimitry Andric   if (!NodeA)
1869fe013be4SDimitry Andric     return nullptr;
1870fe013be4SDimitry Andric 
1871fe013be4SDimitry Andric   auto NodeB = identifyNode(RA, BI);
1872fe013be4SDimitry Andric   if (!NodeB)
1873fe013be4SDimitry Andric     return nullptr;
1874fe013be4SDimitry Andric 
1875fe013be4SDimitry Andric   NodePtr PlaceholderNode = prepareCompositeNode(
1876fe013be4SDimitry Andric       ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
1877fe013be4SDimitry Andric   PlaceholderNode->addOperand(NodeA);
1878fe013be4SDimitry Andric   PlaceholderNode->addOperand(NodeB);
1879fe013be4SDimitry Andric   FinalInstructions.insert(MaskA);
1880fe013be4SDimitry Andric   FinalInstructions.insert(MaskB);
1881fe013be4SDimitry Andric   return submitCompositeNode(PlaceholderNode);
1882fe013be4SDimitry Andric }
1883fe013be4SDimitry Andric 
replaceSymmetricNode(IRBuilderBase & B,unsigned Opcode,std::optional<FastMathFlags> Flags,Value * InputA,Value * InputB)1884fe013be4SDimitry Andric static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
1885fe013be4SDimitry Andric                                    std::optional<FastMathFlags> Flags,
1886fe013be4SDimitry Andric                                    Value *InputA, Value *InputB) {
1887fe013be4SDimitry Andric   Value *I;
1888fe013be4SDimitry Andric   switch (Opcode) {
1889fe013be4SDimitry Andric   case Instruction::FNeg:
1890fe013be4SDimitry Andric     I = B.CreateFNeg(InputA);
1891fe013be4SDimitry Andric     break;
1892fe013be4SDimitry Andric   case Instruction::FAdd:
1893fe013be4SDimitry Andric     I = B.CreateFAdd(InputA, InputB);
1894fe013be4SDimitry Andric     break;
1895fe013be4SDimitry Andric   case Instruction::Add:
1896fe013be4SDimitry Andric     I = B.CreateAdd(InputA, InputB);
1897fe013be4SDimitry Andric     break;
1898fe013be4SDimitry Andric   case Instruction::FSub:
1899fe013be4SDimitry Andric     I = B.CreateFSub(InputA, InputB);
1900fe013be4SDimitry Andric     break;
1901fe013be4SDimitry Andric   case Instruction::Sub:
1902fe013be4SDimitry Andric     I = B.CreateSub(InputA, InputB);
1903fe013be4SDimitry Andric     break;
1904fe013be4SDimitry Andric   case Instruction::FMul:
1905fe013be4SDimitry Andric     I = B.CreateFMul(InputA, InputB);
1906fe013be4SDimitry Andric     break;
1907fe013be4SDimitry Andric   case Instruction::Mul:
1908fe013be4SDimitry Andric     I = B.CreateMul(InputA, InputB);
1909fe013be4SDimitry Andric     break;
1910fe013be4SDimitry Andric   default:
1911fe013be4SDimitry Andric     llvm_unreachable("Incorrect symmetric opcode");
1912fe013be4SDimitry Andric   }
1913fe013be4SDimitry Andric   if (Flags)
1914fe013be4SDimitry Andric     cast<Instruction>(I)->setFastMathFlags(*Flags);
1915fe013be4SDimitry Andric   return I;
1916fe013be4SDimitry Andric }
1917fe013be4SDimitry Andric 
replaceNode(IRBuilderBase & Builder,RawNodePtr Node)1918fe013be4SDimitry Andric Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
1919fe013be4SDimitry Andric                                                RawNodePtr Node) {
1920bdd1243dSDimitry Andric   if (Node->ReplacementNode)
1921bdd1243dSDimitry Andric     return Node->ReplacementNode;
1922bdd1243dSDimitry Andric 
1923fe013be4SDimitry Andric   auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * {
1924fe013be4SDimitry Andric     return Node->Operands.size() > Idx
1925fe013be4SDimitry Andric                ? replaceNode(Builder, Node->Operands[Idx])
1926fe013be4SDimitry Andric                : nullptr;
1927fe013be4SDimitry Andric   };
1928bdd1243dSDimitry Andric 
1929fe013be4SDimitry Andric   Value *ReplacementNode;
1930fe013be4SDimitry Andric   switch (Node->Operation) {
1931fe013be4SDimitry Andric   case ComplexDeinterleavingOperation::CAdd:
1932fe013be4SDimitry Andric   case ComplexDeinterleavingOperation::CMulPartial:
1933fe013be4SDimitry Andric   case ComplexDeinterleavingOperation::Symmetric: {
1934fe013be4SDimitry Andric     Value *Input0 = ReplaceOperandIfExist(Node, 0);
1935fe013be4SDimitry Andric     Value *Input1 = ReplaceOperandIfExist(Node, 1);
1936fe013be4SDimitry Andric     Value *Accumulator = ReplaceOperandIfExist(Node, 2);
1937fe013be4SDimitry Andric     assert(!Input1 || (Input0->getType() == Input1->getType() &&
1938fe013be4SDimitry Andric                        "Node inputs need to be of the same type"));
1939fe013be4SDimitry Andric     assert(!Accumulator ||
1940fe013be4SDimitry Andric            (Input0->getType() == Accumulator->getType() &&
1941fe013be4SDimitry Andric             "Accumulator and input need to be of the same type"));
1942fe013be4SDimitry Andric     if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
1943fe013be4SDimitry Andric       ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
1944fe013be4SDimitry Andric                                              Input0, Input1);
1945fe013be4SDimitry Andric     else
1946fe013be4SDimitry Andric       ReplacementNode = TL->createComplexDeinterleavingIR(
1947fe013be4SDimitry Andric           Builder, Node->Operation, Node->Rotation, Input0, Input1,
1948fe013be4SDimitry Andric           Accumulator);
1949fe013be4SDimitry Andric     break;
1950fe013be4SDimitry Andric   }
1951fe013be4SDimitry Andric   case ComplexDeinterleavingOperation::Deinterleave:
1952fe013be4SDimitry Andric     llvm_unreachable("Deinterleave node should already have ReplacementNode");
1953fe013be4SDimitry Andric     break;
1954fe013be4SDimitry Andric   case ComplexDeinterleavingOperation::Splat: {
1955fe013be4SDimitry Andric     auto *NewTy = VectorType::getDoubleElementsVectorType(
1956fe013be4SDimitry Andric         cast<VectorType>(Node->Real->getType()));
1957fe013be4SDimitry Andric     auto *R = dyn_cast<Instruction>(Node->Real);
1958fe013be4SDimitry Andric     auto *I = dyn_cast<Instruction>(Node->Imag);
1959fe013be4SDimitry Andric     if (R && I) {
1960fe013be4SDimitry Andric       // Splats that are not constant are interleaved where they are located
1961fe013be4SDimitry Andric       Instruction *InsertPoint = (I->comesBefore(R) ? R : I)->getNextNode();
1962fe013be4SDimitry Andric       IRBuilder<> IRB(InsertPoint);
1963fe013be4SDimitry Andric       ReplacementNode =
1964fe013be4SDimitry Andric           IRB.CreateIntrinsic(Intrinsic::experimental_vector_interleave2, NewTy,
1965fe013be4SDimitry Andric                               {Node->Real, Node->Imag});
1966fe013be4SDimitry Andric     } else {
1967fe013be4SDimitry Andric       ReplacementNode =
1968fe013be4SDimitry Andric           Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2,
1969fe013be4SDimitry Andric                                   NewTy, {Node->Real, Node->Imag});
1970fe013be4SDimitry Andric     }
1971fe013be4SDimitry Andric     break;
1972fe013be4SDimitry Andric   }
1973fe013be4SDimitry Andric   case ComplexDeinterleavingOperation::ReductionPHI: {
1974fe013be4SDimitry Andric     // If Operation is ReductionPHI, a new empty PHINode is created.
1975fe013be4SDimitry Andric     // It is filled later when the ReductionOperation is processed.
1976fe013be4SDimitry Andric     auto *VTy = cast<VectorType>(Node->Real->getType());
1977fe013be4SDimitry Andric     auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1978fe013be4SDimitry Andric     auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHI());
1979fe013be4SDimitry Andric     OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI;
1980fe013be4SDimitry Andric     ReplacementNode = NewPHI;
1981fe013be4SDimitry Andric     break;
1982fe013be4SDimitry Andric   }
1983fe013be4SDimitry Andric   case ComplexDeinterleavingOperation::ReductionOperation:
1984fe013be4SDimitry Andric     ReplacementNode = replaceNode(Builder, Node->Operands[0]);
1985fe013be4SDimitry Andric     processReductionOperation(ReplacementNode, Node);
1986fe013be4SDimitry Andric     break;
1987fe013be4SDimitry Andric   case ComplexDeinterleavingOperation::ReductionSelect: {
1988fe013be4SDimitry Andric     auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0);
1989fe013be4SDimitry Andric     auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0);
1990fe013be4SDimitry Andric     auto *A = replaceNode(Builder, Node->Operands[0]);
1991fe013be4SDimitry Andric     auto *B = replaceNode(Builder, Node->Operands[1]);
1992fe013be4SDimitry Andric     auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
1993fe013be4SDimitry Andric         cast<VectorType>(MaskReal->getType()));
1994fe013be4SDimitry Andric     auto *NewMask =
1995fe013be4SDimitry Andric         Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2,
1996fe013be4SDimitry Andric                                 NewMaskTy, {MaskReal, MaskImag});
1997fe013be4SDimitry Andric     ReplacementNode = Builder.CreateSelect(NewMask, A, B);
1998fe013be4SDimitry Andric     break;
1999fe013be4SDimitry Andric   }
2000fe013be4SDimitry Andric   }
2001bdd1243dSDimitry Andric 
2002fe013be4SDimitry Andric   assert(ReplacementNode && "Target failed to create Intrinsic call.");
2003bdd1243dSDimitry Andric   NumComplexTransformations += 1;
2004fe013be4SDimitry Andric   Node->ReplacementNode = ReplacementNode;
2005fe013be4SDimitry Andric   return ReplacementNode;
2006fe013be4SDimitry Andric }
2007fe013be4SDimitry Andric 
processReductionOperation(Value * OperationReplacement,RawNodePtr Node)2008fe013be4SDimitry Andric void ComplexDeinterleavingGraph::processReductionOperation(
2009fe013be4SDimitry Andric     Value *OperationReplacement, RawNodePtr Node) {
2010fe013be4SDimitry Andric   auto *Real = cast<Instruction>(Node->Real);
2011fe013be4SDimitry Andric   auto *Imag = cast<Instruction>(Node->Imag);
2012fe013be4SDimitry Andric   auto *OldPHIReal = ReductionInfo[Real].first;
2013fe013be4SDimitry Andric   auto *OldPHIImag = ReductionInfo[Imag].first;
2014fe013be4SDimitry Andric   auto *NewPHI = OldToNewPHI[OldPHIReal];
2015fe013be4SDimitry Andric 
2016fe013be4SDimitry Andric   auto *VTy = cast<VectorType>(Real->getType());
2017fe013be4SDimitry Andric   auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2018fe013be4SDimitry Andric 
2019fe013be4SDimitry Andric   // We have to interleave initial origin values coming from IncomingBlock
2020fe013be4SDimitry Andric   Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
2021fe013be4SDimitry Andric   Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
2022fe013be4SDimitry Andric 
2023fe013be4SDimitry Andric   IRBuilder<> Builder(Incoming->getTerminator());
2024fe013be4SDimitry Andric   auto *NewInit = Builder.CreateIntrinsic(
2025fe013be4SDimitry Andric       Intrinsic::experimental_vector_interleave2, NewVTy, {InitReal, InitImag});
2026fe013be4SDimitry Andric 
2027fe013be4SDimitry Andric   NewPHI->addIncoming(NewInit, Incoming);
2028fe013be4SDimitry Andric   NewPHI->addIncoming(OperationReplacement, BackEdge);
2029fe013be4SDimitry Andric 
2030fe013be4SDimitry Andric   // Deinterleave complex vector outside of loop so that it can be finally
2031fe013be4SDimitry Andric   // reduced
2032fe013be4SDimitry Andric   auto *FinalReductionReal = ReductionInfo[Real].second;
2033fe013be4SDimitry Andric   auto *FinalReductionImag = ReductionInfo[Imag].second;
2034fe013be4SDimitry Andric 
2035fe013be4SDimitry Andric   Builder.SetInsertPoint(
2036fe013be4SDimitry Andric       &*FinalReductionReal->getParent()->getFirstInsertionPt());
2037fe013be4SDimitry Andric   auto *Deinterleave = Builder.CreateIntrinsic(
2038fe013be4SDimitry Andric       Intrinsic::experimental_vector_deinterleave2,
2039fe013be4SDimitry Andric       OperationReplacement->getType(), OperationReplacement);
2040fe013be4SDimitry Andric 
2041fe013be4SDimitry Andric   auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
2042fe013be4SDimitry Andric   FinalReductionReal->replaceUsesOfWith(Real, NewReal);
2043fe013be4SDimitry Andric 
2044fe013be4SDimitry Andric   Builder.SetInsertPoint(FinalReductionImag);
2045fe013be4SDimitry Andric   auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
2046fe013be4SDimitry Andric   FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2047bdd1243dSDimitry Andric }
2048bdd1243dSDimitry Andric 
replaceNodes()2049bdd1243dSDimitry Andric void ComplexDeinterleavingGraph::replaceNodes() {
2050fe013be4SDimitry Andric   SmallVector<Instruction *, 16> DeadInstrRoots;
2051fe013be4SDimitry Andric   for (auto *RootInstruction : OrderedRoots) {
2052fe013be4SDimitry Andric     // Check if this potential root went through check process and we can
2053fe013be4SDimitry Andric     // deinterleave it
2054fe013be4SDimitry Andric     if (!RootToNode.count(RootInstruction))
2055fe013be4SDimitry Andric       continue;
2056fe013be4SDimitry Andric 
2057fe013be4SDimitry Andric     IRBuilder<> Builder(RootInstruction);
2058fe013be4SDimitry Andric     auto RootNode = RootToNode[RootInstruction];
2059fe013be4SDimitry Andric     Value *R = replaceNode(Builder, RootNode.get());
2060fe013be4SDimitry Andric 
2061fe013be4SDimitry Andric     if (RootNode->Operation ==
2062fe013be4SDimitry Andric         ComplexDeinterleavingOperation::ReductionOperation) {
2063fe013be4SDimitry Andric       auto *RootReal = cast<Instruction>(RootNode->Real);
2064fe013be4SDimitry Andric       auto *RootImag = cast<Instruction>(RootNode->Imag);
2065fe013be4SDimitry Andric       ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
2066fe013be4SDimitry Andric       ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2067fe013be4SDimitry Andric       DeadInstrRoots.push_back(cast<Instruction>(RootReal));
2068fe013be4SDimitry Andric       DeadInstrRoots.push_back(cast<Instruction>(RootImag));
2069fe013be4SDimitry Andric     } else {
2070fe013be4SDimitry Andric       assert(R && "Unable to find replacement for RootInstruction");
2071fe013be4SDimitry Andric       DeadInstrRoots.push_back(RootInstruction);
2072fe013be4SDimitry Andric       RootInstruction->replaceAllUsesWith(R);
2073fe013be4SDimitry Andric     }
2074bdd1243dSDimitry Andric   }
2075bdd1243dSDimitry Andric 
2076fe013be4SDimitry Andric   for (auto *I : DeadInstrRoots)
2077fe013be4SDimitry Andric     RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
2078bdd1243dSDimitry Andric }
2079