1 //===- DAGCombiner.cpp - Implement a DAG node combiner --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass combines dag nodes to form fewer, simpler DAG nodes.  It can be run
10 // both before and after the DAG is legalized.
11 //
12 // This pass is not a substitute for the LLVM IR instcombine pass. This pass is
13 // primarily intended to handle simplification opportunities that are implicit
14 // in the LLVM IR and exposed by the various codegen lowering phases.
15 //
16 //===----------------------------------------------------------------------===//
17 
18 #include "llvm/ADT/APFloat.h"
19 #include "llvm/ADT/APInt.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/DenseMap.h"
22 #include "llvm/ADT/IntervalMap.h"
23 #include "llvm/ADT/None.h"
24 #include "llvm/ADT/Optional.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/ADT/SmallBitVector.h"
28 #include "llvm/ADT/SmallPtrSet.h"
29 #include "llvm/ADT/SmallSet.h"
30 #include "llvm/ADT/SmallVector.h"
31 #include "llvm/ADT/Statistic.h"
32 #include "llvm/Analysis/AliasAnalysis.h"
33 #include "llvm/Analysis/MemoryLocation.h"
34 #include "llvm/Analysis/TargetLibraryInfo.h"
35 #include "llvm/Analysis/VectorUtils.h"
36 #include "llvm/CodeGen/DAGCombine.h"
37 #include "llvm/CodeGen/ISDOpcodes.h"
38 #include "llvm/CodeGen/MachineFunction.h"
39 #include "llvm/CodeGen/MachineMemOperand.h"
40 #include "llvm/CodeGen/RuntimeLibcalls.h"
41 #include "llvm/CodeGen/SelectionDAG.h"
42 #include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
43 #include "llvm/CodeGen/SelectionDAGNodes.h"
44 #include "llvm/CodeGen/SelectionDAGTargetInfo.h"
45 #include "llvm/CodeGen/TargetLowering.h"
46 #include "llvm/CodeGen/TargetRegisterInfo.h"
47 #include "llvm/CodeGen/TargetSubtargetInfo.h"
48 #include "llvm/CodeGen/ValueTypes.h"
49 #include "llvm/IR/Attributes.h"
50 #include "llvm/IR/Constant.h"
51 #include "llvm/IR/DataLayout.h"
52 #include "llvm/IR/DerivedTypes.h"
53 #include "llvm/IR/Function.h"
54 #include "llvm/IR/Metadata.h"
55 #include "llvm/Support/Casting.h"
56 #include "llvm/Support/CodeGen.h"
57 #include "llvm/Support/CommandLine.h"
58 #include "llvm/Support/Compiler.h"
59 #include "llvm/Support/Debug.h"
60 #include "llvm/Support/ErrorHandling.h"
61 #include "llvm/Support/KnownBits.h"
62 #include "llvm/Support/MachineValueType.h"
63 #include "llvm/Support/MathExtras.h"
64 #include "llvm/Support/raw_ostream.h"
65 #include "llvm/Target/TargetMachine.h"
66 #include "llvm/Target/TargetOptions.h"
67 #include <algorithm>
68 #include <cassert>
69 #include <cstdint>
70 #include <functional>
71 #include <iterator>
72 #include <string>
73 #include <tuple>
74 #include <utility>
75 
76 using namespace llvm;
77 
78 #define DEBUG_TYPE "dagcombine"
79 
// Statistics describing how often each kind of combine fired; visible with
// -stats.
STATISTIC(NodesCombined   , "Number of dag nodes combined");
STATISTIC(PreIndexedNodes , "Number of pre-indexed nodes created");
STATISTIC(PostIndexedNodes, "Number of post-indexed nodes created");
STATISTIC(OpsNarrowed     , "Number of load/op/store narrowed");
STATISTIC(LdStFP2Int      , "Number of fp load/store pairs transformed to int");
STATISTIC(SlicedLoads, "Number of load sliced");
STATISTIC(NumFPLogicOpsConv, "Number of logic ops converted to fp ops");

// When set, memory alias queries may be answered with IR-level alias analysis
// instead of only the DAG-level address analysis.
static cl::opt<bool>
CombinerGlobalAA("combiner-global-alias-analysis", cl::Hidden,
                 cl::desc("Enable DAG combiner's use of IR alias analysis"));

// Allow type-based alias analysis metadata to refine alias queries.
static cl::opt<bool>
UseTBAA("combiner-use-tbaa", cl::Hidden, cl::init(true),
        cl::desc("Enable DAG combiner's use of TBAA"));

// Debug-only knob: restrict combiner alias analysis to one named function,
// which is useful when bisecting an alias-analysis-related miscompile.
#ifndef NDEBUG
static cl::opt<std::string>
CombinerAAOnlyFunc("combiner-aa-only-func", cl::Hidden,
                   cl::desc("Only use DAG-combiner alias analysis in this"
                            " function"));
#endif

/// Hidden option to stress test load slicing, i.e., when this option
/// is enabled, load slicing bypasses most of its profitability guards.
static cl::opt<bool>
StressLoadSlicing("combiner-stress-load-slicing", cl::Hidden,
                  cl::desc("Bypass the profitability model of load slicing"),
                  cl::init(false));

static cl::opt<bool>
  MaySplitLoadIndex("combiner-split-load-index", cl::Hidden, cl::init(true),
                    cl::desc("DAG combiner may split indexing from loads"));

static cl::opt<bool>
    EnableStoreMerging("combiner-store-merging", cl::Hidden, cl::init(true),
                       cl::desc("DAG combiner enable merging multiple stores "
                                "into a wider store"));

// Compile-time guard: bounds how many operands a token factor may have before
// the combiner stops inlining other token factors into it.
static cl::opt<unsigned> TokenFactorInlineLimit(
    "combiner-tokenfactor-inline-limit", cl::Hidden, cl::init(2048),
    cl::desc("Limit the number of operands to inline for Token Factors"));

// Compile-time guard for store merging; see StoreRootCountMap below.
static cl::opt<unsigned> StoreMergeDependenceLimit(
    "combiner-store-merge-dependence-limit", cl::Hidden, cl::init(10),
    cl::desc("Limit the number of times for the same StoreNode and RootNode "
             "to bail out in store merging dependence check"));

static cl::opt<bool> EnableReduceLoadOpStoreWidth(
    "combiner-reduce-load-op-store-width", cl::Hidden, cl::init(true),
    cl::desc("DAG combiner enable reducing the width of load/op/store "
             "sequence"));

static cl::opt<bool> EnableShrinkLoadReplaceStoreWithStore(
    "combiner-shrink-load-replace-store-with-store", cl::Hidden, cl::init(true),
    cl::desc("DAG combiner enable load/<replace bytes>/store with "
             "a narrower store"));
137 
138 namespace {
139 
140   class DAGCombiner {
141     SelectionDAG &DAG;
142     const TargetLowering &TLI;
143     const SelectionDAGTargetInfo *STI;
144     CombineLevel Level = BeforeLegalizeTypes;
145     CodeGenOpt::Level OptLevel;
146     bool LegalDAG = false;
147     bool LegalOperations = false;
148     bool LegalTypes = false;
149     bool ForCodeSize;
150     bool DisableGenericCombines;
151 
152     /// Worklist of all of the nodes that need to be simplified.
153     ///
154     /// This must behave as a stack -- new nodes to process are pushed onto the
155     /// back and when processing we pop off of the back.
156     ///
157     /// The worklist will not contain duplicates but may contain null entries
158     /// due to nodes being deleted from the underlying DAG.
159     SmallVector<SDNode *, 64> Worklist;
160 
161     /// Mapping from an SDNode to its position on the worklist.
162     ///
163     /// This is used to find and remove nodes from the worklist (by nulling
164     /// them) when they are deleted from the underlying DAG. It relies on
165     /// stable indices of nodes within the worklist.
166     DenseMap<SDNode *, unsigned> WorklistMap;
167     /// This records all nodes attempted to add to the worklist since we
168     /// considered a new worklist entry. As we keep do not add duplicate nodes
169     /// in the worklist, this is different from the tail of the worklist.
170     SmallSetVector<SDNode *, 32> PruningList;
171 
172     /// Set of nodes which have been combined (at least once).
173     ///
174     /// This is used to allow us to reliably add any operands of a DAG node
175     /// which have not yet been combined to the worklist.
176     SmallPtrSet<SDNode *, 32> CombinedNodes;
177 
178     /// Map from candidate StoreNode to the pair of RootNode and count.
179     /// The count is used to track how many times we have seen the StoreNode
180     /// with the same RootNode bail out in dependence check. If we have seen
181     /// the bail out for the same pair many times over a limit, we won't
182     /// consider the StoreNode with the same RootNode as store merging
183     /// candidate again.
184     DenseMap<SDNode *, std::pair<SDNode *, unsigned>> StoreRootCountMap;
185 
186     // AA - Used for DAG load/store alias analysis.
187     AliasAnalysis *AA;
188 
189     /// When an instruction is simplified, add all users of the instruction to
190     /// the work lists because they might get more simplified now.
AddUsersToWorklist(SDNode * N)191     void AddUsersToWorklist(SDNode *N) {
192       for (SDNode *Node : N->uses())
193         AddToWorklist(Node);
194     }
195 
196     /// Convenient shorthand to add a node and all of its user to the worklist.
AddToWorklistWithUsers(SDNode * N)197     void AddToWorklistWithUsers(SDNode *N) {
198       AddUsersToWorklist(N);
199       AddToWorklist(N);
200     }
201 
202     // Prune potentially dangling nodes. This is called after
203     // any visit to a node, but should also be called during a visit after any
204     // failed combine which may have created a DAG node.
clearAddedDanglingWorklistEntries()205     void clearAddedDanglingWorklistEntries() {
206       // Check any nodes added to the worklist to see if they are prunable.
207       while (!PruningList.empty()) {
208         auto *N = PruningList.pop_back_val();
209         if (N->use_empty())
210           recursivelyDeleteUnusedNodes(N);
211       }
212     }
213 
getNextWorklistEntry()214     SDNode *getNextWorklistEntry() {
215       // Before we do any work, remove nodes that are not in use.
216       clearAddedDanglingWorklistEntries();
217       SDNode *N = nullptr;
218       // The Worklist holds the SDNodes in order, but it may contain null
219       // entries.
220       while (!N && !Worklist.empty()) {
221         N = Worklist.pop_back_val();
222       }
223 
224       if (N) {
225         bool GoodWorklistEntry = WorklistMap.erase(N);
226         (void)GoodWorklistEntry;
227         assert(GoodWorklistEntry &&
228                "Found a worklist entry without a corresponding map entry!");
229       }
230       return N;
231     }
232 
233     /// Call the node-specific routine that folds each particular type of node.
234     SDValue visit(SDNode *N);
235 
236   public:
DAGCombiner(SelectionDAG & D,AliasAnalysis * AA,CodeGenOpt::Level OL)237     DAGCombiner(SelectionDAG &D, AliasAnalysis *AA, CodeGenOpt::Level OL)
238         : DAG(D), TLI(D.getTargetLoweringInfo()),
239           STI(D.getSubtarget().getSelectionDAGInfo()), OptLevel(OL), AA(AA) {
240       ForCodeSize = DAG.shouldOptForSize();
241       DisableGenericCombines = STI && STI->disableGenericCombines(OptLevel);
242 
243       MaximumLegalStoreInBits = 0;
244       // We use the minimum store size here, since that's all we can guarantee
245       // for the scalable vector types.
246       for (MVT VT : MVT::all_valuetypes())
247         if (EVT(VT).isSimple() && VT != MVT::Other &&
248             TLI.isTypeLegal(EVT(VT)) &&
249             VT.getSizeInBits().getKnownMinSize() >= MaximumLegalStoreInBits)
250           MaximumLegalStoreInBits = VT.getSizeInBits().getKnownMinSize();
251     }
252 
    /// Record \p N as a pruning candidate: if it is still unused at the next
    /// pruning sweep (clearAddedDanglingWorklistEntries) it will be deleted.
    void ConsiderForPruning(SDNode *N) {
      // Mark this for potential pruning.
      PruningList.insert(N);
    }
257 
258     /// Add to the worklist making sure its instance is at the back (next to be
259     /// processed.)
AddToWorklist(SDNode * N)260     void AddToWorklist(SDNode *N) {
261       assert(N->getOpcode() != ISD::DELETED_NODE &&
262              "Deleted Node added to Worklist");
263 
264       // Skip handle nodes as they can't usefully be combined and confuse the
265       // zero-use deletion strategy.
266       if (N->getOpcode() == ISD::HANDLENODE)
267         return;
268 
269       ConsiderForPruning(N);
270 
271       if (WorklistMap.insert(std::make_pair(N, Worklist.size())).second)
272         Worklist.push_back(N);
273     }
274 
275     /// Remove all instances of N from the worklist.
removeFromWorklist(SDNode * N)276     void removeFromWorklist(SDNode *N) {
277       CombinedNodes.erase(N);
278       PruningList.remove(N);
279       StoreRootCountMap.erase(N);
280 
281       auto It = WorklistMap.find(N);
282       if (It == WorklistMap.end())
283         return; // Not in the worklist.
284 
285       // Null out the entry rather than erasing it to avoid a linear operation.
286       Worklist[It->second] = nullptr;
287       WorklistMap.erase(It);
288     }
289 
290     void deleteAndRecombine(SDNode *N);
291     bool recursivelyDeleteUnusedNodes(SDNode *N);
292 
293     /// Replaces all uses of the results of one DAG node with new values.
294     SDValue CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
295                       bool AddTo = true);
296 
297     /// Replaces all uses of the results of one DAG node with new values.
    SDValue CombineTo(SDNode *N, SDValue Res, bool AddTo = true) {
      // Single-value convenience wrapper around the (To, NumTo) overload.
      return CombineTo(N, &Res, 1, AddTo);
    }
301 
302     /// Replaces all uses of the results of one DAG node with new values.
CombineTo(SDNode * N,SDValue Res0,SDValue Res1,bool AddTo=true)303     SDValue CombineTo(SDNode *N, SDValue Res0, SDValue Res1,
304                       bool AddTo = true) {
305       SDValue To[] = { Res0, Res1 };
306       return CombineTo(N, To, 2, AddTo);
307     }
308 
309     void CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO);
310 
311   private:
312     unsigned MaximumLegalStoreInBits;
313 
314     /// Check the specified integer node value to see if it can be simplified or
315     /// if things it uses can be simplified by bit propagation.
316     /// If so, return true.
SimplifyDemandedBits(SDValue Op)317     bool SimplifyDemandedBits(SDValue Op) {
318       unsigned BitWidth = Op.getScalarValueSizeInBits();
319       APInt DemandedBits = APInt::getAllOnes(BitWidth);
320       return SimplifyDemandedBits(Op, DemandedBits);
321     }
322 
SimplifyDemandedBits(SDValue Op,const APInt & DemandedBits)323     bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits) {
324       TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
325       KnownBits Known;
326       if (!TLI.SimplifyDemandedBits(Op, DemandedBits, Known, TLO, 0, false))
327         return false;
328 
329       // Revisit the node.
330       AddToWorklist(Op.getNode());
331 
332       CommitTargetLoweringOpt(TLO);
333       return true;
334     }
335 
336     /// Check the specified vector node value to see if it can be simplified or
337     /// if things it uses can be simplified as it only uses some of the
338     /// elements. If so, return true.
SimplifyDemandedVectorElts(SDValue Op)339     bool SimplifyDemandedVectorElts(SDValue Op) {
340       // TODO: For now just pretend it cannot be simplified.
341       if (Op.getValueType().isScalableVector())
342         return false;
343 
344       unsigned NumElts = Op.getValueType().getVectorNumElements();
345       APInt DemandedElts = APInt::getAllOnes(NumElts);
346       return SimplifyDemandedVectorElts(Op, DemandedElts);
347     }
348 
349     bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
350                               const APInt &DemandedElts,
351                               bool AssumeSingleUse = false);
352     bool SimplifyDemandedVectorElts(SDValue Op, const APInt &DemandedElts,
353                                     bool AssumeSingleUse = false);
354 
355     bool CombineToPreIndexedLoadStore(SDNode *N);
356     bool CombineToPostIndexedLoadStore(SDNode *N);
357     SDValue SplitIndexingFromLoad(LoadSDNode *LD);
358     bool SliceUpLoad(SDNode *N);
359 
360     // Scalars have size 0 to distinguish from singleton vectors.
361     SDValue ForwardStoreValueToDirectLoad(LoadSDNode *LD);
362     bool getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val);
363     bool extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val);
364 
365     /// Replace an ISD::EXTRACT_VECTOR_ELT of a load with a narrowed
366     ///   load.
367     ///
368     /// \param EVE ISD::EXTRACT_VECTOR_ELT to be replaced.
369     /// \param InVecVT type of the input vector to EVE with bitcasts resolved.
370     /// \param EltNo index of the vector element to load.
371     /// \param OriginalLoad load that EVE came from to be replaced.
372     /// \returns EVE on success SDValue() on failure.
373     SDValue scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
374                                          SDValue EltNo,
375                                          LoadSDNode *OriginalLoad);
376     void ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad);
377     SDValue PromoteOperand(SDValue Op, EVT PVT, bool &Replace);
378     SDValue SExtPromoteOperand(SDValue Op, EVT PVT);
379     SDValue ZExtPromoteOperand(SDValue Op, EVT PVT);
380     SDValue PromoteIntBinOp(SDValue Op);
381     SDValue PromoteIntShiftOp(SDValue Op);
382     SDValue PromoteExtend(SDValue Op);
383     bool PromoteLoad(SDValue Op);
384 
385     /// Call the node-specific routine that knows how to fold each
386     /// particular type of node. If that doesn't do anything, try the
387     /// target-specific DAG combines.
388     SDValue combine(SDNode *N);
389 
390     // Visitation implementation - Implement dag node combining for different
391     // node types.  The semantics are as follows:
392     // Return Value:
393     //   SDValue.getNode() == 0 - No change was made
394     //   SDValue.getNode() == N - N was replaced, is dead and has been handled.
395     //   otherwise              - N should be replaced by the returned Operand.
396     //
397     SDValue visitTokenFactor(SDNode *N);
398     SDValue visitMERGE_VALUES(SDNode *N);
399     SDValue visitADD(SDNode *N);
400     SDValue visitADDLike(SDNode *N);
401     SDValue visitADDLikeCommutative(SDValue N0, SDValue N1, SDNode *LocReference);
402     SDValue visitSUB(SDNode *N);
403     SDValue visitADDSAT(SDNode *N);
404     SDValue visitSUBSAT(SDNode *N);
405     SDValue visitADDC(SDNode *N);
406     SDValue visitADDO(SDNode *N);
407     SDValue visitUADDOLike(SDValue N0, SDValue N1, SDNode *N);
408     SDValue visitSUBC(SDNode *N);
409     SDValue visitSUBO(SDNode *N);
410     SDValue visitADDE(SDNode *N);
411     SDValue visitADDCARRY(SDNode *N);
412     SDValue visitSADDO_CARRY(SDNode *N);
413     SDValue visitADDCARRYLike(SDValue N0, SDValue N1, SDValue CarryIn, SDNode *N);
414     SDValue visitSUBE(SDNode *N);
415     SDValue visitSUBCARRY(SDNode *N);
416     SDValue visitSSUBO_CARRY(SDNode *N);
417     SDValue visitMUL(SDNode *N);
418     SDValue visitMULFIX(SDNode *N);
419     SDValue useDivRem(SDNode *N);
420     SDValue visitSDIV(SDNode *N);
421     SDValue visitSDIVLike(SDValue N0, SDValue N1, SDNode *N);
422     SDValue visitUDIV(SDNode *N);
423     SDValue visitUDIVLike(SDValue N0, SDValue N1, SDNode *N);
424     SDValue visitREM(SDNode *N);
425     SDValue visitMULHU(SDNode *N);
426     SDValue visitMULHS(SDNode *N);
427     SDValue visitAVG(SDNode *N);
428     SDValue visitSMUL_LOHI(SDNode *N);
429     SDValue visitUMUL_LOHI(SDNode *N);
430     SDValue visitMULO(SDNode *N);
431     SDValue visitIMINMAX(SDNode *N);
432     SDValue visitAND(SDNode *N);
433     SDValue visitANDLike(SDValue N0, SDValue N1, SDNode *N);
434     SDValue visitOR(SDNode *N);
435     SDValue visitORLike(SDValue N0, SDValue N1, SDNode *N);
436     SDValue visitXOR(SDNode *N);
437     SDValue SimplifyVBinOp(SDNode *N, const SDLoc &DL);
438     SDValue visitSHL(SDNode *N);
439     SDValue visitSRA(SDNode *N);
440     SDValue visitSRL(SDNode *N);
441     SDValue visitFunnelShift(SDNode *N);
442     SDValue visitSHLSAT(SDNode *N);
443     SDValue visitRotate(SDNode *N);
444     SDValue visitABS(SDNode *N);
445     SDValue visitBSWAP(SDNode *N);
446     SDValue visitBITREVERSE(SDNode *N);
447     SDValue visitCTLZ(SDNode *N);
448     SDValue visitCTLZ_ZERO_UNDEF(SDNode *N);
449     SDValue visitCTTZ(SDNode *N);
450     SDValue visitCTTZ_ZERO_UNDEF(SDNode *N);
451     SDValue visitCTPOP(SDNode *N);
452     SDValue visitSELECT(SDNode *N);
453     SDValue visitVSELECT(SDNode *N);
454     SDValue visitSELECT_CC(SDNode *N);
455     SDValue visitSETCC(SDNode *N);
456     SDValue visitSETCCCARRY(SDNode *N);
457     SDValue visitSIGN_EXTEND(SDNode *N);
458     SDValue visitZERO_EXTEND(SDNode *N);
459     SDValue visitANY_EXTEND(SDNode *N);
460     SDValue visitAssertExt(SDNode *N);
461     SDValue visitAssertAlign(SDNode *N);
462     SDValue visitSIGN_EXTEND_INREG(SDNode *N);
463     SDValue visitEXTEND_VECTOR_INREG(SDNode *N);
464     SDValue visitTRUNCATE(SDNode *N);
465     SDValue visitBITCAST(SDNode *N);
466     SDValue visitFREEZE(SDNode *N);
467     SDValue visitBUILD_PAIR(SDNode *N);
468     SDValue visitFADD(SDNode *N);
469     SDValue visitSTRICT_FADD(SDNode *N);
470     SDValue visitFSUB(SDNode *N);
471     SDValue visitFMUL(SDNode *N);
472     SDValue visitFMA(SDNode *N);
473     SDValue visitFDIV(SDNode *N);
474     SDValue visitFREM(SDNode *N);
475     SDValue visitFSQRT(SDNode *N);
476     SDValue visitFCOPYSIGN(SDNode *N);
477     SDValue visitFPOW(SDNode *N);
478     SDValue visitSINT_TO_FP(SDNode *N);
479     SDValue visitUINT_TO_FP(SDNode *N);
480     SDValue visitFP_TO_SINT(SDNode *N);
481     SDValue visitFP_TO_UINT(SDNode *N);
482     SDValue visitFP_ROUND(SDNode *N);
483     SDValue visitFP_EXTEND(SDNode *N);
484     SDValue visitFNEG(SDNode *N);
485     SDValue visitFABS(SDNode *N);
486     SDValue visitFCEIL(SDNode *N);
487     SDValue visitFTRUNC(SDNode *N);
488     SDValue visitFFLOOR(SDNode *N);
489     SDValue visitFMinMax(SDNode *N);
490     SDValue visitBRCOND(SDNode *N);
491     SDValue visitBR_CC(SDNode *N);
492     SDValue visitLOAD(SDNode *N);
493 
494     SDValue replaceStoreChain(StoreSDNode *ST, SDValue BetterChain);
495     SDValue replaceStoreOfFPConstant(StoreSDNode *ST);
496 
497     SDValue visitSTORE(SDNode *N);
498     SDValue visitLIFETIME_END(SDNode *N);
499     SDValue visitINSERT_VECTOR_ELT(SDNode *N);
500     SDValue visitEXTRACT_VECTOR_ELT(SDNode *N);
501     SDValue visitBUILD_VECTOR(SDNode *N);
502     SDValue visitCONCAT_VECTORS(SDNode *N);
503     SDValue visitEXTRACT_SUBVECTOR(SDNode *N);
504     SDValue visitVECTOR_SHUFFLE(SDNode *N);
505     SDValue visitSCALAR_TO_VECTOR(SDNode *N);
506     SDValue visitINSERT_SUBVECTOR(SDNode *N);
507     SDValue visitMLOAD(SDNode *N);
508     SDValue visitMSTORE(SDNode *N);
509     SDValue visitMGATHER(SDNode *N);
510     SDValue visitMSCATTER(SDNode *N);
511     SDValue visitFP_TO_FP16(SDNode *N);
512     SDValue visitFP16_TO_FP(SDNode *N);
513     SDValue visitFP_TO_BF16(SDNode *N);
514     SDValue visitVECREDUCE(SDNode *N);
515     SDValue visitVPOp(SDNode *N);
516 
517     SDValue visitFADDForFMACombine(SDNode *N);
518     SDValue visitFSUBForFMACombine(SDNode *N);
519     SDValue visitFMULForFMADistributiveCombine(SDNode *N);
520 
521     SDValue XformToShuffleWithZero(SDNode *N);
522     bool reassociationCanBreakAddressingModePattern(unsigned Opc,
523                                                     const SDLoc &DL,
524                                                     SDNode *N,
525                                                     SDValue N0,
526                                                     SDValue N1);
527     SDValue reassociateOpsCommutative(unsigned Opc, const SDLoc &DL, SDValue N0,
528                                       SDValue N1);
529     SDValue reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
530                            SDValue N1, SDNodeFlags Flags);
531 
532     SDValue visitShiftByConstant(SDNode *N);
533 
534     SDValue foldSelectOfConstants(SDNode *N);
535     SDValue foldVSelectOfConstants(SDNode *N);
536     SDValue foldBinOpIntoSelect(SDNode *BO);
537     bool SimplifySelectOps(SDNode *SELECT, SDValue LHS, SDValue RHS);
538     SDValue hoistLogicOpWithSameOpcodeHands(SDNode *N);
539     SDValue SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2);
540     SDValue SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
541                              SDValue N2, SDValue N3, ISD::CondCode CC,
542                              bool NotExtCompare = false);
543     SDValue convertSelectOfFPConstantsToLoadOffset(
544         const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
545         ISD::CondCode CC);
546     SDValue foldSignChangeInBitcast(SDNode *N);
547     SDValue foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0, SDValue N1,
548                                    SDValue N2, SDValue N3, ISD::CondCode CC);
549     SDValue foldSelectOfBinops(SDNode *N);
550     SDValue foldSextSetcc(SDNode *N);
551     SDValue foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
552                               const SDLoc &DL);
553     SDValue foldSubToUSubSat(EVT DstVT, SDNode *N);
554     SDValue unfoldMaskedMerge(SDNode *N);
555     SDValue unfoldExtremeBitClearingToShifts(SDNode *N);
556     SDValue SimplifySetCC(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
557                           const SDLoc &DL, bool foldBooleans);
558     SDValue rebuildSetCC(SDValue N);
559 
560     bool isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
561                            SDValue &CC, bool MatchStrict = false) const;
562     bool isOneUseSetCC(SDValue N) const;
563 
564     SDValue SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
565                                          unsigned HiOp);
566     SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
567     SDValue CombineExtLoad(SDNode *N);
568     SDValue CombineZExtLogicopShiftLoad(SDNode *N);
569     SDValue combineRepeatedFPDivisors(SDNode *N);
570     SDValue combineInsertEltToShuffle(SDNode *N, unsigned InsIndex);
571     SDValue ConstantFoldBITCASTofBUILD_VECTOR(SDNode *, EVT);
572     SDValue BuildSDIV(SDNode *N);
573     SDValue BuildSDIVPow2(SDNode *N);
574     SDValue BuildUDIV(SDNode *N);
575     SDValue BuildSREMPow2(SDNode *N);
576     SDValue buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N);
577     SDValue BuildLogBase2(SDValue V, const SDLoc &DL);
578     SDValue BuildDivEstimate(SDValue N, SDValue Op, SDNodeFlags Flags);
579     SDValue buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags);
580     SDValue buildSqrtEstimate(SDValue Op, SDNodeFlags Flags);
581     SDValue buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags, bool Recip);
582     SDValue buildSqrtNROneConst(SDValue Arg, SDValue Est, unsigned Iterations,
583                                 SDNodeFlags Flags, bool Reciprocal);
584     SDValue buildSqrtNRTwoConst(SDValue Arg, SDValue Est, unsigned Iterations,
585                                 SDNodeFlags Flags, bool Reciprocal);
586     SDValue MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
587                                bool DemandHighBits = true);
588     SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
589     SDValue MatchRotatePosNeg(SDValue Shifted, SDValue Pos, SDValue Neg,
590                               SDValue InnerPos, SDValue InnerNeg, bool HasPos,
591                               unsigned PosOpcode, unsigned NegOpcode,
592                               const SDLoc &DL);
593     SDValue MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos, SDValue Neg,
594                               SDValue InnerPos, SDValue InnerNeg, bool HasPos,
595                               unsigned PosOpcode, unsigned NegOpcode,
596                               const SDLoc &DL);
597     SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
598     SDValue MatchLoadCombine(SDNode *N);
599     SDValue mergeTruncStores(StoreSDNode *N);
600     SDValue reduceLoadWidth(SDNode *N);
601     SDValue ReduceLoadOpStoreWidth(SDNode *N);
602     SDValue splitMergedValStore(StoreSDNode *ST);
603     SDValue TransformFPLoadStorePair(SDNode *N);
604     SDValue convertBuildVecZextToZext(SDNode *N);
605     SDValue reduceBuildVecExtToExtBuildVec(SDNode *N);
606     SDValue reduceBuildVecTruncToBitCast(SDNode *N);
607     SDValue reduceBuildVecToShuffle(SDNode *N);
608     SDValue createBuildVecShuffle(const SDLoc &DL, SDNode *N,
609                                   ArrayRef<int> VectorMask, SDValue VecIn1,
610                                   SDValue VecIn2, unsigned LeftIdx,
611                                   bool DidSplitVec);
612     SDValue matchVSelectOpSizesWithSetCC(SDNode *Cast);
613 
614     /// Walk up chain skipping non-aliasing memory nodes,
615     /// looking for aliasing nodes and adding them to the Aliases vector.
616     void GatherAllAliases(SDNode *N, SDValue OriginalChain,
617                           SmallVectorImpl<SDValue> &Aliases);
618 
619     /// Return true if there is any possibility that the two addresses overlap.
620     bool mayAlias(SDNode *Op0, SDNode *Op1) const;
621 
622     /// Walk up chain skipping non-aliasing memory nodes, looking for a better
623     /// chain (aliasing node.)
624     SDValue FindBetterChain(SDNode *N, SDValue Chain);
625 
626     /// Try to replace a store and any possibly adjacent stores on
627     /// consecutive chains with better chains. Return true only if St is
628     /// replaced.
629     ///
630     /// Notice that other chains may still be replaced even if the function
631     /// returns false.
632     bool findBetterNeighborChains(StoreSDNode *St);
633 
634     // Helper for findBetterNeighborChains. Walk up store chain add additional
635     // chained stores that do not overlap and can be parallelized.
636     bool parallelizeChainedStores(StoreSDNode *St);
637 
638     /// Holds a pointer to an LSBaseSDNode as well as information on where it
639     /// is located in a sequence of memory operations connected by a chain.
    struct MemOpLink {
      // Ptr to the mem node.
      LSBaseSDNode *MemNode;

      // Offset from the base ptr, in bytes, used to order the ops in a chain.
      int64_t OffsetFromBase;

      MemOpLink(LSBaseSDNode *N, int64_t Offset)
          : MemNode(N), OffsetFromBase(Offset) {}
    };
650 
651     // Classify the origin of a stored value.
652     enum class StoreSource { Unknown, Constant, Extract, Load };
getStoreSource(SDValue StoreVal)653     StoreSource getStoreSource(SDValue StoreVal) {
654       switch (StoreVal.getOpcode()) {
655       case ISD::Constant:
656       case ISD::ConstantFP:
657         return StoreSource::Constant;
658       case ISD::EXTRACT_VECTOR_ELT:
659       case ISD::EXTRACT_SUBVECTOR:
660         return StoreSource::Extract;
661       case ISD::LOAD:
662         return StoreSource::Load;
663       default:
664         return StoreSource::Unknown;
665       }
666     }
667 
668     /// This is a helper function for visitMUL to check the profitability
669     /// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
670     /// MulNode is the original multiply, AddNode is (add x, c1),
671     /// and ConstNode is c2.
672     bool isMulAddWithConstProfitable(SDNode *MulNode, SDValue AddNode,
673                                      SDValue ConstNode);
674 
675     /// This is a helper function for visitAND and visitZERO_EXTEND.  Returns
676     /// true if the (and (load x) c) pattern matches an extload.  ExtVT returns
677     /// the type of the loaded value to be extended.
678     bool isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
679                           EVT LoadResultTy, EVT &ExtVT);
680 
681     /// Helper function to calculate whether the given Load/Store can have its
682     /// width reduced to ExtVT.
683     bool isLegalNarrowLdSt(LSBaseSDNode *LDSTN, ISD::LoadExtType ExtType,
684                            EVT &MemVT, unsigned ShAmt = 0);
685 
686     /// Used by BackwardsPropagateMask to find suitable loads.
687     bool SearchForAndLoads(SDNode *N, SmallVectorImpl<LoadSDNode*> &Loads,
688                            SmallPtrSetImpl<SDNode*> &NodesWithConsts,
689                            ConstantSDNode *Mask, SDNode *&NodeToMask);
690     /// Attempt to propagate a given AND node back to load leaves so that they
691     /// can be combined into narrow loads.
692     bool BackwardsPropagateMask(SDNode *N);
693 
694     /// Helper function for mergeConsecutiveStores which merges the component
695     /// store chains.
696     SDValue getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
697                                 unsigned NumStores);
698 
699     /// This is a helper function for mergeConsecutiveStores. When the source
700     /// elements of the consecutive stores are all constants or all extracted
701     /// vector elements, try to merge them into one larger store introducing
702     /// bitcasts if necessary.  \return True if a merged store was created.
703     bool mergeStoresOfConstantsOrVecElts(SmallVectorImpl<MemOpLink> &StoreNodes,
704                                          EVT MemVT, unsigned NumStores,
705                                          bool IsConstantSrc, bool UseVector,
706                                          bool UseTrunc);
707 
708     /// This is a helper function for mergeConsecutiveStores. Stores that
709     /// potentially may be merged with St are placed in StoreNodes. RootNode is
710     /// a chain predecessor to all store candidates.
711     void getStoreMergeCandidates(StoreSDNode *St,
712                                  SmallVectorImpl<MemOpLink> &StoreNodes,
713                                  SDNode *&Root);
714 
715     /// Helper function for mergeConsecutiveStores. Checks if candidate stores
716     /// have indirect dependency through their operands. RootNode is the
717     /// predecessor to all stores calculated by getStoreMergeCandidates and is
718     /// used to prune the dependency check. \return True if safe to merge.
719     bool checkMergeStoreCandidatesForDependencies(
720         SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
721         SDNode *RootNode);
722 
723     /// This is a helper function for mergeConsecutiveStores. Given a list of
724     /// store candidates, find the first N that are consecutive in memory.
725     /// Returns 0 if there are not at least 2 consecutive stores to try merging.
726     unsigned getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
727                                   int64_t ElementSizeBytes) const;
728 
729     /// This is a helper function for mergeConsecutiveStores. It is used for
730     /// store chains that are composed entirely of constant values.
731     bool tryStoreMergeOfConstants(SmallVectorImpl<MemOpLink> &StoreNodes,
732                                   unsigned NumConsecutiveStores,
733                                   EVT MemVT, SDNode *Root, bool AllowVectors);
734 
735     /// This is a helper function for mergeConsecutiveStores. It is used for
736     /// store chains that are composed entirely of extracted vector elements.
737     /// When extracting multiple vector elements, try to store them in one
738     /// vector store rather than a sequence of scalar stores.
739     bool tryStoreMergeOfExtracts(SmallVectorImpl<MemOpLink> &StoreNodes,
740                                  unsigned NumConsecutiveStores, EVT MemVT,
741                                  SDNode *Root);
742 
743     /// This is a helper function for mergeConsecutiveStores. It is used for
744     /// store chains that are composed entirely of loaded values.
745     bool tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
746                               unsigned NumConsecutiveStores, EVT MemVT,
747                               SDNode *Root, bool AllowVectors,
748                               bool IsNonTemporalStore, bool IsNonTemporalLoad);
749 
750     /// Merge consecutive store operations into a wide store.
751     /// This optimization uses wide integers or vectors when possible.
752     /// \return true if stores were merged.
753     bool mergeConsecutiveStores(StoreSDNode *St);
754 
755     /// Try to transform a truncation where C is a constant:
756     ///     (trunc (and X, C)) -> (and (trunc X), (trunc C))
757     ///
758     /// \p N needs to be a truncation and its first operand an AND. Other
759     /// requirements are checked by the function (e.g. that trunc is
760     /// single-use) and if missed an empty SDValue is returned.
761     SDValue distributeTruncateThroughAnd(SDNode *N);
762 
763     /// Helper function to determine whether the target supports operation
764     /// given by \p Opcode for type \p VT, that is, whether the operation
765     /// is legal or custom before legalizing operations, and whether is
766     /// legal (but not custom) after legalization.
hasOperation(unsigned Opcode,EVT VT)767     bool hasOperation(unsigned Opcode, EVT VT) {
768       return TLI.isOperationLegalOrCustom(Opcode, VT, LegalOperations);
769     }
770 
771   public:
772     /// Runs the dag combiner on all nodes in the work list
773     void Run(CombineLevel AtLevel);
774 
    /// Expose the DAG being combined; the DAGUpdateListener helpers below
    /// use this to attach themselves to it.
    SelectionDAG &getDAG() const { return DAG; }
776 
777     /// Returns a type large enough to hold any valid shift amount - before type
778     /// legalization these can be huge.
getShiftAmountTy(EVT LHSTy)779     EVT getShiftAmountTy(EVT LHSTy) {
780       assert(LHSTy.isInteger() && "Shift amount is not an integer type!");
781       return TLI.getShiftAmountTy(LHSTy, DAG.getDataLayout(), LegalTypes);
782     }
783 
784     /// This method returns true if we are running before type legalization or
785     /// if the specified VT is legal.
isTypeLegal(const EVT & VT)786     bool isTypeLegal(const EVT &VT) {
787       if (!LegalTypes) return true;
788       return TLI.isTypeLegal(VT);
789     }
790 
791     /// Convenience wrapper around TargetLowering::getSetCCResultType
getSetCCResultType(EVT VT) const792     EVT getSetCCResultType(EVT VT) const {
793       return TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
794     }
795 
796     void ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
797                          SDValue OrigLoad, SDValue ExtLoad,
798                          ISD::NodeType ExtType);
799   };
800 
801 /// This class is a DAGUpdateListener that removes any deleted
802 /// nodes from the worklist.
803 class WorklistRemover : public SelectionDAG::DAGUpdateListener {
804   DAGCombiner &DC;
805 
806 public:
WorklistRemover(DAGCombiner & dc)807   explicit WorklistRemover(DAGCombiner &dc)
808     : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}
809 
NodeDeleted(SDNode * N,SDNode * E)810   void NodeDeleted(SDNode *N, SDNode *E) override {
811     DC.removeFromWorklist(N);
812   }
813 };
814 
815 class WorklistInserter : public SelectionDAG::DAGUpdateListener {
816   DAGCombiner &DC;
817 
818 public:
WorklistInserter(DAGCombiner & dc)819   explicit WorklistInserter(DAGCombiner &dc)
820       : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}
821 
822   // FIXME: Ideally we could add N to the worklist, but this causes exponential
823   //        compile time costs in large DAGs, e.g. Halide.
NodeInserted(SDNode * N)824   void NodeInserted(SDNode *N) override { DC.ConsiderForPruning(N); }
825 };
826 
827 } // end anonymous namespace
828 
829 //===----------------------------------------------------------------------===//
830 //  TargetLowering::DAGCombinerInfo implementation
831 //===----------------------------------------------------------------------===//
832 
AddToWorklist(SDNode * N)833 void TargetLowering::DAGCombinerInfo::AddToWorklist(SDNode *N) {
834   ((DAGCombiner*)DC)->AddToWorklist(N);
835 }
836 
837 SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode * N,ArrayRef<SDValue> To,bool AddTo)838 CombineTo(SDNode *N, ArrayRef<SDValue> To, bool AddTo) {
839   return ((DAGCombiner*)DC)->CombineTo(N, &To[0], To.size(), AddTo);
840 }
841 
842 SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode * N,SDValue Res,bool AddTo)843 CombineTo(SDNode *N, SDValue Res, bool AddTo) {
844   return ((DAGCombiner*)DC)->CombineTo(N, Res, AddTo);
845 }
846 
847 SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode * N,SDValue Res0,SDValue Res1,bool AddTo)848 CombineTo(SDNode *N, SDValue Res0, SDValue Res1, bool AddTo) {
849   return ((DAGCombiner*)DC)->CombineTo(N, Res0, Res1, AddTo);
850 }
851 
852 bool TargetLowering::DAGCombinerInfo::
recursivelyDeleteUnusedNodes(SDNode * N)853 recursivelyDeleteUnusedNodes(SDNode *N) {
854   return ((DAGCombiner*)DC)->recursivelyDeleteUnusedNodes(N);
855 }
856 
857 void TargetLowering::DAGCombinerInfo::
CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt & TLO)858 CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
859   return ((DAGCombiner*)DC)->CommitTargetLoweringOpt(TLO);
860 }
861 
862 //===----------------------------------------------------------------------===//
863 // Helper Functions
864 //===----------------------------------------------------------------------===//
865 
deleteAndRecombine(SDNode * N)866 void DAGCombiner::deleteAndRecombine(SDNode *N) {
867   removeFromWorklist(N);
868 
869   // If the operands of this node are only used by the node, they will now be
870   // dead. Make sure to re-visit them and recursively delete dead nodes.
871   for (const SDValue &Op : N->ops())
872     // For an operand generating multiple values, one of the values may
873     // become dead allowing further simplification (e.g. split index
874     // arithmetic from an indexed load).
875     if (Op->hasOneUse() || Op->getNumValues() > 1)
876       AddToWorklist(Op.getNode());
877 
878   DAG.DeleteNode(N);
879 }
880 
881 // APInts must be the same size for most operations, this helper
882 // function zero extends the shorter of the pair so that they match.
883 // We provide an Offset so that we can create bitwidths that won't overflow.
zeroExtendToMatch(APInt & LHS,APInt & RHS,unsigned Offset=0)884 static void zeroExtendToMatch(APInt &LHS, APInt &RHS, unsigned Offset = 0) {
885   unsigned Bits = Offset + std::max(LHS.getBitWidth(), RHS.getBitWidth());
886   LHS = LHS.zext(Bits);
887   RHS = RHS.zext(Bits);
888 }
889 
890 // Return true if this node is a setcc, or is a select_cc
891 // that selects between the target values used for true and false, making it
892 // equivalent to a setcc. Also, set the incoming LHS, RHS, and CC references to
893 // the appropriate nodes based on the type of node we are checking. This
894 // simplifies life a bit for the callers.
isSetCCEquivalent(SDValue N,SDValue & LHS,SDValue & RHS,SDValue & CC,bool MatchStrict) const895 bool DAGCombiner::isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
896                                     SDValue &CC, bool MatchStrict) const {
897   if (N.getOpcode() == ISD::SETCC) {
898     LHS = N.getOperand(0);
899     RHS = N.getOperand(1);
900     CC  = N.getOperand(2);
901     return true;
902   }
903 
904   if (MatchStrict &&
905       (N.getOpcode() == ISD::STRICT_FSETCC ||
906        N.getOpcode() == ISD::STRICT_FSETCCS)) {
907     LHS = N.getOperand(1);
908     RHS = N.getOperand(2);
909     CC  = N.getOperand(3);
910     return true;
911   }
912 
913   if (N.getOpcode() != ISD::SELECT_CC || !TLI.isConstTrueVal(N.getOperand(2)) ||
914       !TLI.isConstFalseVal(N.getOperand(3)))
915     return false;
916 
917   if (TLI.getBooleanContents(N.getValueType()) ==
918       TargetLowering::UndefinedBooleanContent)
919     return false;
920 
921   LHS = N.getOperand(0);
922   RHS = N.getOperand(1);
923   CC  = N.getOperand(4);
924   return true;
925 }
926 
927 /// Return true if this is a SetCC-equivalent operation with only one use.
928 /// If this is true, it allows the users to invert the operation for free when
929 /// it is profitable to do so.
isOneUseSetCC(SDValue N) const930 bool DAGCombiner::isOneUseSetCC(SDValue N) const {
931   SDValue N0, N1, N2;
932   if (isSetCCEquivalent(N, N0, N1, N2) && N->hasOneUse())
933     return true;
934   return false;
935 }
936 
isConstantSplatVectorMaskForType(SDNode * N,EVT ScalarTy)937 static bool isConstantSplatVectorMaskForType(SDNode *N, EVT ScalarTy) {
938   if (!ScalarTy.isSimple())
939     return false;
940 
941   uint64_t MaskForTy = 0ULL;
942   switch (ScalarTy.getSimpleVT().SimpleTy) {
943   case MVT::i8:
944     MaskForTy = 0xFFULL;
945     break;
946   case MVT::i16:
947     MaskForTy = 0xFFFFULL;
948     break;
949   case MVT::i32:
950     MaskForTy = 0xFFFFFFFFULL;
951     break;
952   default:
953     return false;
954     break;
955   }
956 
957   APInt Val;
958   if (ISD::isConstantSplatVector(N, Val))
959     return Val.getLimitedValue() == MaskForTy;
960 
961   return false;
962 }
963 
964 // Determines if it is a constant integer or a splat/build vector of constant
965 // integers (and undefs).
966 // Do not permit build vector implicit truncation.
isConstantOrConstantVector(SDValue N,bool NoOpaques=false)967 static bool isConstantOrConstantVector(SDValue N, bool NoOpaques = false) {
968   if (ConstantSDNode *Const = dyn_cast<ConstantSDNode>(N))
969     return !(Const->isOpaque() && NoOpaques);
970   if (N.getOpcode() != ISD::BUILD_VECTOR && N.getOpcode() != ISD::SPLAT_VECTOR)
971     return false;
972   unsigned BitWidth = N.getScalarValueSizeInBits();
973   for (const SDValue &Op : N->op_values()) {
974     if (Op.isUndef())
975       continue;
976     ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Op);
977     if (!Const || Const->getAPIntValue().getBitWidth() != BitWidth ||
978         (Const->isOpaque() && NoOpaques))
979       return false;
980   }
981   return true;
982 }
983 
984 // Determines if a BUILD_VECTOR is composed of all-constants possibly mixed with
985 // undef's.
isAnyConstantBuildVector(SDValue V,bool NoOpaques=false)986 static bool isAnyConstantBuildVector(SDValue V, bool NoOpaques = false) {
987   if (V.getOpcode() != ISD::BUILD_VECTOR)
988     return false;
989   return isConstantOrConstantVector(V, NoOpaques) ||
990          ISD::isBuildVectorOfConstantFPSDNodes(V.getNode());
991 }
992 
993 // Determine if this an indexed load with an opaque target constant index.
canSplitIdx(LoadSDNode * LD)994 static bool canSplitIdx(LoadSDNode *LD) {
995   return MaySplitLoadIndex &&
996          (LD->getOperand(2).getOpcode() != ISD::TargetConstant ||
997           !cast<ConstantSDNode>(LD->getOperand(2))->isOpaque());
998 }
999 
bool DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc,
                                                             const SDLoc &DL,
                                                             SDNode *N,
                                                             SDValue N0,
                                                             SDValue N1) {
  // Currently this only tries to ensure we don't undo the GEP splits done by
  // CodeGenPrepare when shouldConsiderGEPOffsetSplit is true. To ensure this,
  // we check if the following transformation would be problematic:
  // (load/store (add, (add, x, offset1), offset2)) ->
  // (load/store (add, x, offset1+offset2)).

  // (load/store (add, (add, x, y), offset2)) ->
  // (load/store (add, (add, x, offset2), y)).

  // Only (add (add ...), C) shapes can match the patterns above.
  if (Opc != ISD::ADD || N0.getOpcode() != ISD::ADD)
    return false;

  auto *C2 = dyn_cast<ConstantSDNode>(N1);
  if (!C2)
    return false;

  // Offsets wider than 64 bits cannot be fed to the addressing-mode query.
  const APInt &C2APIntVal = C2->getAPIntValue();
  if (C2APIntVal.getSignificantBits() > 64)
    return false;

  if (auto *C1 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
    // Pattern: (add (add x, c1), c2). If the inner add has a single use,
    // folding the constants removes a node, so it cannot make things worse.
    if (N0.hasOneUse())
      return false;

    const APInt &C1APIntVal = C1->getAPIntValue();
    const APInt CombinedValueIntVal = C1APIntVal + C2APIntVal;
    if (CombinedValueIntVal.getSignificantBits() > 64)
      return false;
    const int64_t CombinedValue = CombinedValueIntVal.getSExtValue();

    for (SDNode *Node : N->uses()) {
      if (auto *LoadStore = dyn_cast<MemSDNode>(Node)) {
        // Is x[offset2] already not a legal addressing mode? If so then
        // reassociating the constants breaks nothing (we test offset2 because
        // that's the one we hope to fold into the load or store).
        TargetLoweringBase::AddrMode AM;
        AM.HasBaseReg = true;
        AM.BaseOffs = C2APIntVal.getSExtValue();
        EVT VT = LoadStore->getMemoryVT();
        unsigned AS = LoadStore->getAddressSpace();
        Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
        if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
          continue;

        // Would x[offset1+offset2] still be a legal addressing mode?
        AM.BaseOffs = CombinedValue;
        if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
          return true;
      }
    }
  } else {
    // Pattern: (add (add x, y), c2) with non-constant y. A global address
    // whose offset folding is legal is as good as a constant, so allow it.
    if (auto *GA = dyn_cast<GlobalAddressSDNode>(N0.getOperand(1)))
      if (GA->getOpcode() == ISD::GlobalAddress && TLI.isOffsetFoldingLegal(GA))
        return false;

    for (SDNode *Node : N->uses()) {
      auto *LoadStore = dyn_cast<MemSDNode>(Node);
      if (!LoadStore)
        return false;

      // Is x[offset2] a legal addressing mode? If so then
      // reassociating the constants breaks address pattern
      TargetLoweringBase::AddrMode AM;
      AM.HasBaseReg = true;
      AM.BaseOffs = C2APIntVal.getSExtValue();
      EVT VT = LoadStore->getMemoryVT();
      unsigned AS = LoadStore->getAddressSpace();
      Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
      if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
        return false;
    }
    // Every user is a memory op that can fold offset2 directly; rotating y
    // outward would break those addressing patterns.
    return true;
  }

  return false;
}
1081 
1082 // Helper for DAGCombiner::reassociateOps. Try to reassociate an expression
1083 // such as (Opc N0, N1), if \p N0 is the same kind of operation as \p Opc.
SDValue DAGCombiner::reassociateOpsCommutative(unsigned Opc, const SDLoc &DL,
                                               SDValue N0, SDValue N1) {
  EVT VT = N0.getValueType();

  // This helper only fires when the LHS is the same operation:
  // (Opc (Opc N00, N01), N1).
  if (N0.getOpcode() != Opc)
    return SDValue();

  SDValue N00 = N0.getOperand(0);
  SDValue N01 = N0.getOperand(1);

  if (DAG.isConstantIntBuildVectorOrConstantInt(peekThroughBitcasts(N01))) {
    if (DAG.isConstantIntBuildVectorOrConstantInt(peekThroughBitcasts(N1))) {
      // Reassociate: (op (op x, c1), c2) -> (op x, (op c1, c2))
      if (SDValue OpNode = DAG.FoldConstantArithmetic(Opc, DL, VT, {N01, N1}))
        return DAG.getNode(Opc, DL, VT, N00, OpNode);
      // If the constants didn't fold, do not reassociate.
      return SDValue();
    }
    if (TLI.isReassocProfitable(DAG, N0, N1)) {
      // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1)
      //              iff (op x, c1) has one use
      SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N00, N1);
      return DAG.getNode(Opc, DL, VT, OpNode, N01);
    }
  }

  // Check for repeated operand logic simplifications.
  if (Opc == ISD::AND || Opc == ISD::OR) {
    // AND/OR are idempotent, so the repeated operand folds away:
    // (N00 & N01) & N00 --> N00 & N01
    // (N00 & N01) & N01 --> N00 & N01
    // (N00 | N01) | N00 --> N00 | N01
    // (N00 | N01) | N01 --> N00 | N01
    if (N1 == N00 || N1 == N01)
      return N0;
  }
  if (Opc == ISD::XOR) {
    // XOR is self-inverse, so the repeated operand cancels:
    // (N00 ^ N01) ^ N00 --> N01
    if (N1 == N00)
      return N01;
    // (N00 ^ N01) ^ N01 --> N00
    if (N1 == N01)
      return N00;
  }

  if (TLI.isReassocProfitable(DAG, N0, N1)) {
    if (N1 != N01) {
      // Reassociate if (op N00, N1) already exist
      if (SDNode *NE = DAG.getNodeIfExists(Opc, DAG.getVTList(VT), {N00, N1})) {
        // if Op (Op N00, N1), N01 already exist
        // we need to stop reassciate to avoid dead loop
        if (!DAG.doesNodeExist(Opc, DAG.getVTList(VT), {SDValue(NE, 0), N01}))
          return DAG.getNode(Opc, DL, VT, SDValue(NE, 0), N01);
      }
    }

    if (N1 != N00) {
      // Reassociate if (op N01, N1) already exist
      if (SDNode *NE = DAG.getNodeIfExists(Opc, DAG.getVTList(VT), {N01, N1})) {
        // if Op (Op N01, N1), N00 already exist
        // we need to stop reassciate to avoid dead loop
        if (!DAG.doesNodeExist(Opc, DAG.getVTList(VT), {SDValue(NE, 0), N00}))
          return DAG.getNode(Opc, DL, VT, SDValue(NE, 0), N00);
      }
    }
  }

  return SDValue();
}
1151 
1152 // Try to reassociate commutative binops.
reassociateOps(unsigned Opc,const SDLoc & DL,SDValue N0,SDValue N1,SDNodeFlags Flags)1153 SDValue DAGCombiner::reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
1154                                     SDValue N1, SDNodeFlags Flags) {
1155   assert(TLI.isCommutativeBinOp(Opc) && "Operation not commutative.");
1156 
1157   // Floating-point reassociation is not allowed without loose FP math.
1158   if (N0.getValueType().isFloatingPoint() ||
1159       N1.getValueType().isFloatingPoint())
1160     if (!Flags.hasAllowReassociation() || !Flags.hasNoSignedZeros())
1161       return SDValue();
1162 
1163   if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N0, N1))
1164     return Combined;
1165   if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N1, N0))
1166     return Combined;
1167   return SDValue();
1168 }
1169 
SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
                               bool AddTo) {
  // Replace all NumTo results of N with the corresponding entries of To.
  // When AddTo is set, the replacement values and their users are queued for
  // another combine round. Returns N's first value (the node may be dead).
  assert(N->getNumValues() == NumTo && "Broken CombineTo call!");
  ++NodesCombined;
  LLVM_DEBUG(dbgs() << "\nReplacing.1 "; N->dump(&DAG); dbgs() << "\nWith: ";
             To[0].dump(&DAG);
             dbgs() << " and " << NumTo - 1 << " other values\n");
  for (unsigned i = 0, e = NumTo; i != e; ++i)
    assert((!To[i].getNode() ||
            N->getValueType(i) == To[i].getValueType()) &&
           "Cannot combine value to value of different type!");

  // Nodes deleted during the RAUW below must leave our worklist too.
  WorklistRemover DeadNodes(*this);
  DAG.ReplaceAllUsesWith(N, To);
  if (AddTo) {
    // Push the new nodes and any users onto the worklist
    for (unsigned i = 0, e = NumTo; i != e; ++i) {
      if (To[i].getNode())
        AddToWorklistWithUsers(To[i].getNode());
    }
  }

  // Finally, if the node is now dead, remove it from the graph.  The node
  // may not be dead if the replacement process recursively simplified to
  // something else needing this node.
  if (N->use_empty())
    deleteAndRecombine(N);
  return SDValue(N, 0);
}
1199 
void DAGCombiner::
CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
  // Apply a replacement (TLO.Old -> TLO.New) discovered by one of the
  // TargetLowering::SimplifyDemanded* queries.
  // Replace the old value with the new one.
  ++NodesCombined;
  LLVM_DEBUG(dbgs() << "\nReplacing.2 "; TLO.Old.dump(&DAG);
             dbgs() << "\nWith: "; TLO.New.dump(&DAG); dbgs() << '\n');

  // Replace all uses.  If any nodes become isomorphic to other nodes and
  // are deleted, make sure to remove them from our worklist.
  WorklistRemover DeadNodes(*this);
  DAG.ReplaceAllUsesOfValueWith(TLO.Old, TLO.New);

  // Push the new node and any (possibly new) users onto the worklist.
  AddToWorklistWithUsers(TLO.New.getNode());

  // Finally, if the node is now dead, remove it from the graph.  The node
  // may not be dead if the replacement process recursively simplified to
  // something else needing this node.
  if (TLO.Old->use_empty())
    deleteAndRecombine(TLO.Old.getNode());
}
1221 
1222 /// Check the specified integer node value to see if it can be simplified or if
1223 /// things it uses can be simplified by bit propagation. If so, return true.
SimplifyDemandedBits(SDValue Op,const APInt & DemandedBits,const APInt & DemandedElts,bool AssumeSingleUse)1224 bool DAGCombiner::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
1225                                        const APInt &DemandedElts,
1226                                        bool AssumeSingleUse) {
1227   TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
1228   KnownBits Known;
1229   if (!TLI.SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO, 0,
1230                                 AssumeSingleUse))
1231     return false;
1232 
1233   // Revisit the node.
1234   AddToWorklist(Op.getNode());
1235 
1236   CommitTargetLoweringOpt(TLO);
1237   return true;
1238 }
1239 
1240 /// Check the specified vector node value to see if it can be simplified or
1241 /// if things it uses can be simplified as it only uses some of the elements.
1242 /// If so, return true.
SimplifyDemandedVectorElts(SDValue Op,const APInt & DemandedElts,bool AssumeSingleUse)1243 bool DAGCombiner::SimplifyDemandedVectorElts(SDValue Op,
1244                                              const APInt &DemandedElts,
1245                                              bool AssumeSingleUse) {
1246   TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
1247   APInt KnownUndef, KnownZero;
1248   if (!TLI.SimplifyDemandedVectorElts(Op, DemandedElts, KnownUndef, KnownZero,
1249                                       TLO, 0, AssumeSingleUse))
1250     return false;
1251 
1252   // Revisit the node.
1253   AddToWorklist(Op.getNode());
1254 
1255   CommitTargetLoweringOpt(TLO);
1256   return true;
1257 }
1258 
void DAGCombiner::ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad) {
  SDLoc DL(Load);
  EVT VT = Load->getValueType(0);
  // Truncate the promoted (wider) loaded value back to the original type so
  // existing users of Load keep seeing a value of the type they expect.
  SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, VT, SDValue(ExtLoad, 0));

  LLVM_DEBUG(dbgs() << "\nReplacing.9 "; Load->dump(&DAG); dbgs() << "\nWith: ";
             Trunc.dump(&DAG); dbgs() << '\n');
  WorklistRemover DeadNodes(*this);
  // Rewire both results: value 0 (the data) goes through the truncate,
  // value 1 (the chain) connects directly to the new load's chain.
  DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), Trunc);
  DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), SDValue(ExtLoad, 1));
  deleteAndRecombine(Load);
  AddToWorklist(Trunc.getNode());
}
1272 
SDValue DAGCombiner::PromoteOperand(SDValue Op, EVT PVT, bool &Replace) {
  // Produce a PVT-typed version of Op. \p Replace is set when the result is
  // a new load whose chain users must be migrated from the original load
  // (see ReplaceLoadWithPromotedLoad).
  Replace = false;
  SDLoc DL(Op);
  if (ISD::isUNINDEXEDLoad(Op.getNode())) {
    LoadSDNode *LD = cast<LoadSDNode>(Op);
    EVT MemVT = LD->getMemoryVT();
    // Re-issue a non-extending load as an any-extending load of the same
    // memory type; an extending load keeps its extension kind.
    ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) ? ISD::EXTLOAD
                                                      : LD->getExtensionType();
    Replace = true;
    return DAG.getExtLoad(ExtType, DL, PVT,
                          LD->getChain(), LD->getBasePtr(),
                          MemVT, LD->getMemOperand());
  }

  unsigned Opc = Op.getOpcode();
  switch (Opc) {
  default: break;
  case ISD::AssertSext:
    // Promote the asserted value, then re-assert the narrower bound in PVT.
    if (SDValue Op0 = SExtPromoteOperand(Op.getOperand(0), PVT))
      return DAG.getNode(ISD::AssertSext, DL, PVT, Op0, Op.getOperand(1));
    break;
  case ISD::AssertZext:
    if (SDValue Op0 = ZExtPromoteOperand(Op.getOperand(0), PVT))
      return DAG.getNode(ISD::AssertZext, DL, PVT, Op0, Op.getOperand(1));
    break;
  case ISD::Constant: {
    // NOTE(review): byte-sized constants are sign-extended, other widths
    // zero-extended — presumably matching how such constants were formed;
    // confirm against the promotion callers.
    unsigned ExtOpc =
      Op.getValueType().isByteSized() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
    return DAG.getNode(ExtOpc, DL, PVT, Op);
  }
  }

  // Generic fallback: any-extend, when the target supports it for PVT.
  if (!TLI.isOperationLegal(ISD::ANY_EXTEND, PVT))
    return SDValue();
  return DAG.getNode(ISD::ANY_EXTEND, DL, PVT, Op);
}
1309 
SDValue DAGCombiner::SExtPromoteOperand(SDValue Op, EVT PVT) {
  // Sign-preserving promotion needs SIGN_EXTEND_INREG on the wider type.
  if (!TLI.isOperationLegal(ISD::SIGN_EXTEND_INREG, PVT))
    return SDValue();
  EVT OldVT = Op.getValueType();
  SDLoc DL(Op);
  bool Replace = false;
  SDValue NewOp = PromoteOperand(Op, PVT, Replace);
  if (!NewOp.getNode())
    return SDValue();
  AddToWorklist(NewOp.getNode());

  // If PromoteOperand created a new load, migrate the old load's users to it.
  if (Replace)
    ReplaceLoadWithPromotedLoad(Op.getNode(), NewOp.getNode());
  // Re-establish the original value's sign within the wider type.
  return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, NewOp.getValueType(), NewOp,
                     DAG.getValueType(OldVT));
}
1326 
SDValue DAGCombiner::ZExtPromoteOperand(SDValue Op, EVT PVT) {
  EVT OldVT = Op.getValueType();
  SDLoc DL(Op);
  bool Replace = false;
  SDValue NewOp = PromoteOperand(Op, PVT, Replace);
  if (!NewOp.getNode())
    return SDValue();
  AddToWorklist(NewOp.getNode());

  // If PromoteOperand created a new load, migrate the old load's users to it.
  if (Replace)
    ReplaceLoadWithPromotedLoad(Op.getNode(), NewOp.getNode());
  // Mask off the bits above OldVT so the value is zero-extended within PVT.
  return DAG.getZeroExtendInReg(NewOp, DL, OldVT);
}
1340 
1341 /// Promote the specified integer binary operation if the target indicates it is
1342 /// beneficial. e.g. On x86, it's usually better to promote i16 operations to
1343 /// i32 since i16 instructions are longer.
SDValue DAGCombiner::PromoteIntBinOp(SDValue Op) {
  // Promotion is only meaningful once operations have been legalized.
  if (!LegalOperations)
    return SDValue();

  // Only scalar integer binops are handled here.
  EVT VT = Op.getValueType();
  if (VT.isVector() || !VT.isInteger())
    return SDValue();

  // If operation type is 'undesirable', e.g. i16 on x86, consider
  // promoting it.
  unsigned Opc = Op.getOpcode();
  if (TLI.isTypeDesirableForOp(Opc, VT))
    return SDValue();

  EVT PVT = VT;
  // Consult target whether it is a good idea to promote this operation and
  // what's the right type to promote it to.
  if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
    assert(PVT != VT && "Don't know what type to promote to!");

    LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));

    // Promote both operands to PVT; ReplaceN tracks whether a new load was
    // created whose other users must be migrated as well.
    bool Replace0 = false;
    SDValue N0 = Op.getOperand(0);
    SDValue NN0 = PromoteOperand(N0, PVT, Replace0);

    bool Replace1 = false;
    SDValue N1 = Op.getOperand(1);
    SDValue NN1 = PromoteOperand(N1, PVT, Replace1);
    SDLoc DL(Op);

    // Perform the operation in the wider type, then truncate back to VT.
    SDValue RV =
        DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, NN0, NN1));

    // We are always replacing N0/N1's use in N and only need additional
    // replacements if there are additional uses.
    // Note: We are checking uses of the *nodes* (SDNode) rather than values
    //       (SDValue) here because the node may reference multiple values
    //       (for example, the chain value of a load node).
    Replace0 &= !N0->hasOneUse();
    Replace1 &= (N0 != N1) && !N1->hasOneUse();

    // Combine Op here so it is preserved past replacements.
    CombineTo(Op.getNode(), RV);

    // If operands have a use ordering, make sure we deal with
    // predecessor first.
    if (Replace0 && Replace1 && N0->isPredecessorOf(N1.getNode())) {
      std::swap(N0, N1);
      std::swap(NN0, NN1);
    }

    if (Replace0) {
      AddToWorklist(NN0.getNode());
      ReplaceLoadWithPromotedLoad(N0.getNode(), NN0.getNode());
    }
    if (Replace1) {
      AddToWorklist(NN1.getNode());
      ReplaceLoadWithPromotedLoad(N1.getNode(), NN1.getNode());
    }
    return Op;
  }
  return SDValue();
}
1408 
/// Promote the specified integer shift operation if the target indicates it is
/// beneficial. e.g. On x86, it's usually better to promote i16 operations to
/// i32 since i16 instructions are longer.
SDValue DAGCombiner::PromoteIntShiftOp(SDValue Op) {
  if (!LegalOperations)
    return SDValue();

  // Only scalar integer shifts are candidates for promotion.
  EVT VT = Op.getValueType();
  if (VT.isVector() || !VT.isInteger())
    return SDValue();

  // If operation type is 'undesirable', e.g. i16 on x86, consider
  // promoting it.
  unsigned Opc = Op.getOpcode();
  if (TLI.isTypeDesirableForOp(Opc, VT))
    return SDValue();

  EVT PVT = VT;
  // Consult target whether it is a good idea to promote this operation and
  // what's the right type to promote it to.
  if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
    assert(PVT != VT && "Don't know what type to promote to!");

    LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));

    // The shifted value must be extended to match the shift kind: SRA via
    // sign-extension, SRL via zero-extension; other shifts use the generic
    // promotion.
    bool Replace = false;
    SDValue N0 = Op.getOperand(0);
    if (Opc == ISD::SRA)
      N0 = SExtPromoteOperand(N0, PVT);
    else if (Opc == ISD::SRL)
      N0 = ZExtPromoteOperand(N0, PVT);
    else
      N0 = PromoteOperand(N0, PVT, Replace);

    // Operand promotion may fail; give up in that case.
    if (!N0.getNode())
      return SDValue();

    SDLoc DL(Op);
    SDValue N1 = Op.getOperand(1);
    // Shift in the wider type, then truncate back to the original type.
    SDValue RV =
        DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, N0, N1));

    if (Replace)
      ReplaceLoadWithPromotedLoad(Op.getOperand(0).getNode(), N0.getNode());

    // Deal with Op being deleted.
    if (Op && Op.getOpcode() != ISD::DELETED_NODE)
      return RV;
  }
  return SDValue();
}
1460 
PromoteExtend(SDValue Op)1461 SDValue DAGCombiner::PromoteExtend(SDValue Op) {
1462   if (!LegalOperations)
1463     return SDValue();
1464 
1465   EVT VT = Op.getValueType();
1466   if (VT.isVector() || !VT.isInteger())
1467     return SDValue();
1468 
1469   // If operation type is 'undesirable', e.g. i16 on x86, consider
1470   // promoting it.
1471   unsigned Opc = Op.getOpcode();
1472   if (TLI.isTypeDesirableForOp(Opc, VT))
1473     return SDValue();
1474 
1475   EVT PVT = VT;
1476   // Consult target whether it is a good idea to promote this operation and
1477   // what's the right type to promote it to.
1478   if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1479     assert(PVT != VT && "Don't know what type to promote to!");
1480     // fold (aext (aext x)) -> (aext x)
1481     // fold (aext (zext x)) -> (zext x)
1482     // fold (aext (sext x)) -> (sext x)
1483     LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1484     return DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, Op.getOperand(0));
1485   }
1486   return SDValue();
1487 }
1488 
/// Promote an unindexed scalar-integer load to the target's preferred wider
/// type by rewriting it as an extending load plus a truncate. Returns true if
/// the load was replaced.
bool DAGCombiner::PromoteLoad(SDValue Op) {
  if (!LegalOperations)
    return false;

  // Pre/post-indexed loads produce an extra pointer result and are not
  // handled here.
  if (!ISD::isUNINDEXEDLoad(Op.getNode()))
    return false;

  EVT VT = Op.getValueType();
  if (VT.isVector() || !VT.isInteger())
    return false;

  // If operation type is 'undesirable', e.g. i16 on x86, consider
  // promoting it.
  unsigned Opc = Op.getOpcode();
  if (TLI.isTypeDesirableForOp(Opc, VT))
    return false;

  EVT PVT = VT;
  // Consult target whether it is a good idea to promote this operation and
  // what's the right type to promote it to.
  if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
    assert(PVT != VT && "Don't know what type to promote to!");

    SDLoc DL(Op);
    SDNode *N = Op.getNode();
    LoadSDNode *LD = cast<LoadSDNode>(N);
    EVT MemVT = LD->getMemoryVT();
    // A non-extending load becomes an any-extending load of the promoted
    // type; an extending load keeps its original extension kind.
    ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) ? ISD::EXTLOAD
                                                      : LD->getExtensionType();
    SDValue NewLD = DAG.getExtLoad(ExtType, DL, PVT,
                                   LD->getChain(), LD->getBasePtr(),
                                   MemVT, LD->getMemOperand());
    // Truncate back to the original type so existing users are unaffected.
    SDValue Result = DAG.getNode(ISD::TRUNCATE, DL, VT, NewLD);

    LLVM_DEBUG(dbgs() << "\nPromoting "; N->dump(&DAG); dbgs() << "\nTo: ";
               Result.dump(&DAG); dbgs() << '\n');
    WorklistRemover DeadNodes(*this);
    // Rewire both results of the old load: the loaded value (result 0) and
    // the chain (result 1).
    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result);
    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), NewLD.getValue(1));
    deleteAndRecombine(N);
    AddToWorklist(Result.getNode());
    return true;
  }
  return false;
}
1534 
1535 /// Recursively delete a node which has no uses and any operands for
1536 /// which it is the only use.
1537 ///
1538 /// Note that this both deletes the nodes and removes them from the worklist.
1539 /// It also adds any nodes who have had a user deleted to the worklist as they
1540 /// may now have only one use and subject to other combines.
recursivelyDeleteUnusedNodes(SDNode * N)1541 bool DAGCombiner::recursivelyDeleteUnusedNodes(SDNode *N) {
1542   if (!N->use_empty())
1543     return false;
1544 
1545   SmallSetVector<SDNode *, 16> Nodes;
1546   Nodes.insert(N);
1547   do {
1548     N = Nodes.pop_back_val();
1549     if (!N)
1550       continue;
1551 
1552     if (N->use_empty()) {
1553       for (const SDValue &ChildN : N->op_values())
1554         Nodes.insert(ChildN.getNode());
1555 
1556       removeFromWorklist(N);
1557       DAG.DeleteNode(N);
1558     } else {
1559       AddToWorklist(N);
1560     }
1561   } while (!Nodes.empty());
1562   return true;
1563 }
1564 
1565 //===----------------------------------------------------------------------===//
1566 //  Main DAG Combiner implementation
1567 //===----------------------------------------------------------------------===//
1568 
/// Main driver: repeatedly pull nodes off the combiner worklist and try to
/// simplify each one until the worklist is exhausted.
void DAGCombiner::Run(CombineLevel AtLevel) {
  // set the instance variables, so that the various visit routines may use it.
  Level = AtLevel;
  LegalDAG = Level >= AfterLegalizeDAG;
  LegalOperations = Level >= AfterLegalizeVectorOps;
  LegalTypes = Level >= AfterLegalizeTypes;

  WorklistInserter AddNodes(*this);

  // Add all the dag nodes to the worklist.
  for (SDNode &Node : DAG.allnodes())
    AddToWorklist(&Node);

  // Create a dummy node (which is not added to allnodes), that adds a reference
  // to the root node, preventing it from being deleted, and tracking any
  // changes of the root.
  HandleSDNode Dummy(DAG.getRoot());

  // While we have a valid worklist entry node, try to combine it.
  while (SDNode *N = getNextWorklistEntry()) {
    // If N has no uses, it is dead.  Make sure to revisit all N's operands once
    // N is deleted from the DAG, since they too may now be dead or may have a
    // reduced number of uses, allowing other xforms.
    if (recursivelyDeleteUnusedNodes(N))
      continue;

    WorklistRemover DeadNodes(*this);

    // If this combine is running after legalizing the DAG, re-legalize any
    // nodes pulled off the worklist.
    if (LegalDAG) {
      SmallSetVector<SDNode *, 16> UpdatedNodes;
      bool NIsValid = DAG.LegalizeOp(N, UpdatedNodes);

      for (SDNode *LN : UpdatedNodes)
        AddToWorklistWithUsers(LN);

      // If legalization replaced or deleted N there is nothing left to
      // combine here.
      if (!NIsValid)
        continue;
    }

    LLVM_DEBUG(dbgs() << "\nCombining: "; N->dump(&DAG));

    // Add any operands of the new node which have not yet been combined to the
    // worklist as well. Because the worklist uniques things already, this
    // won't repeatedly process the same operand.
    CombinedNodes.insert(N);
    for (const SDValue &ChildN : N->op_values())
      if (!CombinedNodes.count(ChildN.getNode()))
        AddToWorklist(ChildN.getNode());

    SDValue RV = combine(N);

    if (!RV.getNode())
      continue;

    ++NodesCombined;

    // If we get back the same node we passed in, rather than a new node or
    // zero, we know that the node must have defined multiple values and
    // CombineTo was used.  Since CombineTo takes care of the worklist
    // mechanics for us, we have no work to do in this case.
    if (RV.getNode() == N)
      continue;

    assert(N->getOpcode() != ISD::DELETED_NODE &&
           RV.getOpcode() != ISD::DELETED_NODE &&
           "Node was deleted but visit returned new node!");

    LLVM_DEBUG(dbgs() << " ... into: "; RV.dump(&DAG));

    // Replace all values of N with the replacement; when N produces a single
    // value, use the single-value form so multi-result shape mismatches are
    // caught by the assert.
    if (N->getNumValues() == RV->getNumValues())
      DAG.ReplaceAllUsesWith(N, RV.getNode());
    else {
      assert(N->getValueType(0) == RV.getValueType() &&
             N->getNumValues() == 1 && "Type mismatch");
      DAG.ReplaceAllUsesWith(N, &RV);
    }

    // Push the new node and any users onto the worklist.  Omit this if the
    // new node is the EntryToken (e.g. if a store managed to get optimized
    // out), because re-visiting the EntryToken and its users will not uncover
    // any additional opportunities, but there may be a large number of such
    // users, potentially causing compile time explosion.
    if (RV.getOpcode() != ISD::EntryToken) {
      AddToWorklist(RV.getNode());
      AddUsersToWorklist(RV.getNode());
    }

    // Finally, if the node is now dead, remove it from the graph.  The node
    // may not be dead if the replacement process recursively simplified to
    // something else needing this node. This will also take care of adding any
    // operands which have lost a user to the worklist.
    recursivelyDeleteUnusedNodes(N);
  }

  // If the root changed (e.g. it was a dead load, update the root).
  DAG.setRoot(Dummy.getValue());
  DAG.RemoveDeadNodes();
}
1669 
/// Dispatch \p N to the opcode-specific visit function. Returns the
/// replacement value produced by that visitor, or a null SDValue when the
/// opcode has no generic combine.
SDValue DAGCombiner::visit(SDNode *N) {
  switch (N->getOpcode()) {
  default: break;
  case ISD::TokenFactor:        return visitTokenFactor(N);
  case ISD::MERGE_VALUES:       return visitMERGE_VALUES(N);
  case ISD::ADD:                return visitADD(N);
  case ISD::SUB:                return visitSUB(N);
  case ISD::SADDSAT:
  case ISD::UADDSAT:            return visitADDSAT(N);
  case ISD::SSUBSAT:
  case ISD::USUBSAT:            return visitSUBSAT(N);
  case ISD::ADDC:               return visitADDC(N);
  case ISD::SADDO:
  case ISD::UADDO:              return visitADDO(N);
  case ISD::SUBC:               return visitSUBC(N);
  case ISD::SSUBO:
  case ISD::USUBO:              return visitSUBO(N);
  case ISD::ADDE:               return visitADDE(N);
  case ISD::ADDCARRY:           return visitADDCARRY(N);
  case ISD::SADDO_CARRY:        return visitSADDO_CARRY(N);
  case ISD::SUBE:               return visitSUBE(N);
  case ISD::SUBCARRY:           return visitSUBCARRY(N);
  case ISD::SSUBO_CARRY:        return visitSSUBO_CARRY(N);
  case ISD::SMULFIX:
  case ISD::SMULFIXSAT:
  case ISD::UMULFIX:
  case ISD::UMULFIXSAT:         return visitMULFIX(N);
  case ISD::MUL:                return visitMUL(N);
  case ISD::SDIV:               return visitSDIV(N);
  case ISD::UDIV:               return visitUDIV(N);
  case ISD::SREM:
  case ISD::UREM:               return visitREM(N);
  case ISD::MULHU:              return visitMULHU(N);
  case ISD::MULHS:              return visitMULHS(N);
  case ISD::AVGFLOORS:
  case ISD::AVGFLOORU:
  case ISD::AVGCEILS:
  case ISD::AVGCEILU:           return visitAVG(N);
  case ISD::SMUL_LOHI:          return visitSMUL_LOHI(N);
  case ISD::UMUL_LOHI:          return visitUMUL_LOHI(N);
  case ISD::SMULO:
  case ISD::UMULO:              return visitMULO(N);
  case ISD::SMIN:
  case ISD::SMAX:
  case ISD::UMIN:
  case ISD::UMAX:               return visitIMINMAX(N);
  case ISD::AND:                return visitAND(N);
  case ISD::OR:                 return visitOR(N);
  case ISD::XOR:                return visitXOR(N);
  case ISD::SHL:                return visitSHL(N);
  case ISD::SRA:                return visitSRA(N);
  case ISD::SRL:                return visitSRL(N);
  case ISD::ROTR:
  case ISD::ROTL:               return visitRotate(N);
  case ISD::FSHL:
  case ISD::FSHR:               return visitFunnelShift(N);
  case ISD::SSHLSAT:
  case ISD::USHLSAT:            return visitSHLSAT(N);
  case ISD::ABS:                return visitABS(N);
  case ISD::BSWAP:              return visitBSWAP(N);
  case ISD::BITREVERSE:         return visitBITREVERSE(N);
  case ISD::CTLZ:               return visitCTLZ(N);
  case ISD::CTLZ_ZERO_UNDEF:    return visitCTLZ_ZERO_UNDEF(N);
  case ISD::CTTZ:               return visitCTTZ(N);
  case ISD::CTTZ_ZERO_UNDEF:    return visitCTTZ_ZERO_UNDEF(N);
  case ISD::CTPOP:              return visitCTPOP(N);
  case ISD::SELECT:             return visitSELECT(N);
  case ISD::VSELECT:            return visitVSELECT(N);
  case ISD::SELECT_CC:          return visitSELECT_CC(N);
  case ISD::SETCC:              return visitSETCC(N);
  case ISD::SETCCCARRY:         return visitSETCCCARRY(N);
  case ISD::SIGN_EXTEND:        return visitSIGN_EXTEND(N);
  case ISD::ZERO_EXTEND:        return visitZERO_EXTEND(N);
  case ISD::ANY_EXTEND:         return visitANY_EXTEND(N);
  case ISD::AssertSext:
  case ISD::AssertZext:         return visitAssertExt(N);
  case ISD::AssertAlign:        return visitAssertAlign(N);
  case ISD::SIGN_EXTEND_INREG:  return visitSIGN_EXTEND_INREG(N);
  case ISD::SIGN_EXTEND_VECTOR_INREG:
  case ISD::ZERO_EXTEND_VECTOR_INREG: return visitEXTEND_VECTOR_INREG(N);
  case ISD::TRUNCATE:           return visitTRUNCATE(N);
  case ISD::BITCAST:            return visitBITCAST(N);
  case ISD::BUILD_PAIR:         return visitBUILD_PAIR(N);
  case ISD::FADD:               return visitFADD(N);
  case ISD::STRICT_FADD:        return visitSTRICT_FADD(N);
  case ISD::FSUB:               return visitFSUB(N);
  case ISD::FMUL:               return visitFMUL(N);
  case ISD::FMA:                return visitFMA(N);
  case ISD::FDIV:               return visitFDIV(N);
  case ISD::FREM:               return visitFREM(N);
  case ISD::FSQRT:              return visitFSQRT(N);
  case ISD::FCOPYSIGN:          return visitFCOPYSIGN(N);
  case ISD::FPOW:               return visitFPOW(N);
  case ISD::SINT_TO_FP:         return visitSINT_TO_FP(N);
  case ISD::UINT_TO_FP:         return visitUINT_TO_FP(N);
  case ISD::FP_TO_SINT:         return visitFP_TO_SINT(N);
  case ISD::FP_TO_UINT:         return visitFP_TO_UINT(N);
  case ISD::FP_ROUND:           return visitFP_ROUND(N);
  case ISD::FP_EXTEND:          return visitFP_EXTEND(N);
  case ISD::FNEG:               return visitFNEG(N);
  case ISD::FABS:               return visitFABS(N);
  case ISD::FFLOOR:             return visitFFLOOR(N);
  case ISD::FMINNUM:
  case ISD::FMAXNUM:
  case ISD::FMINIMUM:
  case ISD::FMAXIMUM:           return visitFMinMax(N);
  case ISD::FCEIL:              return visitFCEIL(N);
  case ISD::FTRUNC:             return visitFTRUNC(N);
  case ISD::BRCOND:             return visitBRCOND(N);
  case ISD::BR_CC:              return visitBR_CC(N);
  case ISD::LOAD:               return visitLOAD(N);
  case ISD::STORE:              return visitSTORE(N);
  case ISD::INSERT_VECTOR_ELT:  return visitINSERT_VECTOR_ELT(N);
  case ISD::EXTRACT_VECTOR_ELT: return visitEXTRACT_VECTOR_ELT(N);
  case ISD::BUILD_VECTOR:       return visitBUILD_VECTOR(N);
  case ISD::CONCAT_VECTORS:     return visitCONCAT_VECTORS(N);
  case ISD::EXTRACT_SUBVECTOR:  return visitEXTRACT_SUBVECTOR(N);
  case ISD::VECTOR_SHUFFLE:     return visitVECTOR_SHUFFLE(N);
  case ISD::SCALAR_TO_VECTOR:   return visitSCALAR_TO_VECTOR(N);
  case ISD::INSERT_SUBVECTOR:   return visitINSERT_SUBVECTOR(N);
  case ISD::MGATHER:            return visitMGATHER(N);
  case ISD::MLOAD:              return visitMLOAD(N);
  case ISD::MSCATTER:           return visitMSCATTER(N);
  case ISD::MSTORE:             return visitMSTORE(N);
  case ISD::LIFETIME_END:       return visitLIFETIME_END(N);
  case ISD::FP_TO_FP16:         return visitFP_TO_FP16(N);
  case ISD::FP16_TO_FP:         return visitFP16_TO_FP(N);
  case ISD::FP_TO_BF16:         return visitFP_TO_BF16(N);
  case ISD::FREEZE:             return visitFREEZE(N);
  case ISD::VECREDUCE_FADD:
  case ISD::VECREDUCE_FMUL:
  case ISD::VECREDUCE_ADD:
  case ISD::VECREDUCE_MUL:
  case ISD::VECREDUCE_AND:
  case ISD::VECREDUCE_OR:
  case ISD::VECREDUCE_XOR:
  case ISD::VECREDUCE_SMAX:
  case ISD::VECREDUCE_SMIN:
  case ISD::VECREDUCE_UMAX:
  case ISD::VECREDUCE_UMIN:
  case ISD::VECREDUCE_FMAX:
  case ISD::VECREDUCE_FMIN:     return visitVECREDUCE(N);
  // Every VP (vector-predicated) opcode is routed to a common visitor.
#define BEGIN_REGISTER_VP_SDNODE(SDOPC, ...) case ISD::SDOPC:
#include "llvm/IR/VPIntrinsics.def"
    return visitVPOp(N);
  }
  return SDValue();
}
1818 
/// Try to simplify \p N, in order: generic visit routines, target-specific
/// combines, integer-promotion transforms, and finally commuted-operand CSE.
/// Returns a null SDValue if no simplification was found.
SDValue DAGCombiner::combine(SDNode *N) {
  SDValue RV;
  if (!DisableGenericCombines)
    RV = visit(N);

  // If nothing happened, try a target-specific DAG combine.
  if (!RV.getNode()) {
    assert(N->getOpcode() != ISD::DELETED_NODE &&
           "Node was deleted but visit returned NULL!");

    // Target-specific opcodes always go to the target; standard opcodes only
    // if the target registered a combine for them.
    if (N->getOpcode() >= ISD::BUILTIN_OP_END ||
        TLI.hasTargetDAGCombine((ISD::NodeType)N->getOpcode())) {

      // Expose the DAG combiner to the target combiner impls.
      TargetLowering::DAGCombinerInfo
        DagCombineInfo(DAG, Level, false, this);

      RV = TLI.PerformDAGCombine(N, DagCombineInfo);
    }
  }

  // If nothing happened still, try promoting the operation.
  if (!RV.getNode()) {
    switch (N->getOpcode()) {
    default: break;
    case ISD::ADD:
    case ISD::SUB:
    case ISD::MUL:
    case ISD::AND:
    case ISD::OR:
    case ISD::XOR:
      RV = PromoteIntBinOp(SDValue(N, 0));
      break;
    case ISD::SHL:
    case ISD::SRA:
    case ISD::SRL:
      RV = PromoteIntShiftOp(SDValue(N, 0));
      break;
    case ISD::SIGN_EXTEND:
    case ISD::ZERO_EXTEND:
    case ISD::ANY_EXTEND:
      RV = PromoteExtend(SDValue(N, 0));
      break;
    case ISD::LOAD:
      // PromoteLoad replaces N in place; report N itself as the result so
      // the caller knows a combine happened.
      if (PromoteLoad(SDValue(N, 0)))
        RV = SDValue(N, 0);
      break;
    }
  }

  // If N is a commutative binary node, try to eliminate it if the commuted
  // version is already present in the DAG.
  if (!RV.getNode() && TLI.isCommutativeBinOp(N->getOpcode())) {
    SDValue N0 = N->getOperand(0);
    SDValue N1 = N->getOperand(1);

    // Constant operands are canonicalized to RHS.
    if (N0 != N1 && (isa<ConstantSDNode>(N0) || !isa<ConstantSDNode>(N1))) {
      SDValue Ops[] = {N1, N0};
      SDNode *CSENode = DAG.getNodeIfExists(N->getOpcode(), N->getVTList(), Ops,
                                            N->getFlags());
      if (CSENode)
        return SDValue(CSENode, 0);
    }
  }

  return RV;
}
1887 
1888 /// Given a node, return its input chain if it has one, otherwise return a null
1889 /// sd operand.
getInputChainForNode(SDNode * N)1890 static SDValue getInputChainForNode(SDNode *N) {
1891   if (unsigned NumOps = N->getNumOperands()) {
1892     if (N->getOperand(0).getValueType() == MVT::Other)
1893       return N->getOperand(0);
1894     if (N->getOperand(NumOps-1).getValueType() == MVT::Other)
1895       return N->getOperand(NumOps-1);
1896     for (unsigned i = 1; i < NumOps-1; ++i)
1897       if (N->getOperand(i).getValueType() == MVT::Other)
1898         return N->getOperand(i);
1899   }
1900   return SDValue();
1901 }
1902 
/// Simplify a TokenFactor: drop redundant chains, inline single-use child
/// TokenFactors, and prune operands already ordered by another operand's
/// chain.
SDValue DAGCombiner::visitTokenFactor(SDNode *N) {
  // If N has two operands, where one has an input chain equal to the other,
  // the 'other' chain is redundant.
  if (N->getNumOperands() == 2) {
    if (getInputChainForNode(N->getOperand(0).getNode()) == N->getOperand(1))
      return N->getOperand(0);
    if (getInputChainForNode(N->getOperand(1).getNode()) == N->getOperand(0))
      return N->getOperand(1);
  }

  // Don't simplify token factors if optnone.
  if (OptLevel == CodeGenOpt::None)
    return SDValue();

  // Don't simplify the token factor if the node itself has too many operands.
  if (N->getNumOperands() > TokenFactorInlineLimit)
    return SDValue();

  // If the sole user is a token factor, we should make sure we have a
  // chance to merge them together. This prevents TF chains from inhibiting
  // optimizations.
  if (N->hasOneUse() && N->use_begin()->getOpcode() == ISD::TokenFactor)
    AddToWorklist(*(N->use_begin()));

  SmallVector<SDNode *, 8> TFs;     // List of token factors to visit.
  SmallVector<SDValue, 8> Ops;      // Ops for replacing token factor.
  SmallPtrSet<SDNode*, 16> SeenOps;
  bool Changed = false;             // If we should replace this token factor.

  // Start out with this token factor.
  TFs.push_back(N);

  // Iterate through token factors.  The TFs grows when new token factors are
  // encountered.
  for (unsigned i = 0; i < TFs.size(); ++i) {
    // Limit number of nodes to inline, to avoid quadratic compile times.
    // We have to add the outstanding Token Factors to Ops, otherwise we might
    // drop Ops from the resulting Token Factors.
    if (Ops.size() > TokenFactorInlineLimit) {
      for (unsigned j = i; j < TFs.size(); j++)
        Ops.emplace_back(TFs[j], 0);
      // Drop unprocessed Token Factors from TFs, so we do not add them to the
      // combiner worklist later.
      TFs.resize(i);
      break;
    }

    SDNode *TF = TFs[i];
    // Check each of the operands.
    for (const SDValue &Op : TF->op_values()) {
      switch (Op.getOpcode()) {
      case ISD::EntryToken:
        // Entry tokens don't need to be added to the list. They are
        // redundant.
        Changed = true;
        break;

      case ISD::TokenFactor:
        if (Op.hasOneUse() && !is_contained(TFs, Op.getNode())) {
          // Queue up for processing.
          TFs.push_back(Op.getNode());
          Changed = true;
          break;
        }
        LLVM_FALLTHROUGH;

      default:
        // Only add if it isn't already in the list.
        if (SeenOps.insert(Op.getNode()).second)
          Ops.push_back(Op);
        else
          Changed = true;
        break;
      }
    }
  }

  // Re-visit inlined Token Factors, to clean them up in case they have been
  // removed. Skip the first Token Factor, as this is the current node.
  for (unsigned i = 1, e = TFs.size(); i < e; i++)
    AddToWorklist(TFs[i]);

  // Remove Nodes that are chained to another node in the list. Do so
  // by walking up chains breadth-first stopping when we've seen
  // another operand. In general we must climb to the EntryNode, but we can exit
  // early if we find all remaining work is associated with just one operand as
  // no further pruning is possible.

  // List of nodes to search through and original Ops from which they originate.
  SmallVector<std::pair<SDNode *, unsigned>, 8> Worklist;
  SmallVector<unsigned, 8> OpWorkCount; // Count of work for each Op.
  SmallPtrSet<SDNode *, 16> SeenChains;
  bool DidPruneOps = false;

  unsigned NumLeftToConsider = 0;
  for (const SDValue &Op : Ops) {
    Worklist.push_back(std::make_pair(Op.getNode(), NumLeftToConsider++));
    OpWorkCount.push_back(1);
  }

  // NOTE: this lambda intentionally shadows DAGCombiner::AddToWorklist inside
  // the search loop below; it queues chain nodes for the breadth-first walk,
  // not for the combiner worklist.
  auto AddToWorklist = [&](unsigned CurIdx, SDNode *Op, unsigned OpNumber) {
    // If this is an Op, we can remove the op from the list. Remark any
    // search associated with it as from the current OpNumber.
    if (SeenOps.contains(Op)) {
      Changed = true;
      DidPruneOps = true;
      unsigned OrigOpNumber = 0;
      while (OrigOpNumber < Ops.size() && Ops[OrigOpNumber].getNode() != Op)
        OrigOpNumber++;
      assert((OrigOpNumber != Ops.size()) &&
             "expected to find TokenFactor Operand");
      // Re-mark worklist from OrigOpNumber to OpNumber
      for (unsigned i = CurIdx + 1; i < Worklist.size(); ++i) {
        if (Worklist[i].second == OrigOpNumber) {
          Worklist[i].second = OpNumber;
        }
      }
      OpWorkCount[OpNumber] += OpWorkCount[OrigOpNumber];
      OpWorkCount[OrigOpNumber] = 0;
      NumLeftToConsider--;
    }
    // Add if it's a new chain
    if (SeenChains.insert(Op).second) {
      OpWorkCount[OpNumber]++;
      Worklist.push_back(std::make_pair(Op, OpNumber));
    }
  };

  // Cap the breadth-first search at 1024 visited nodes to bound compile time.
  for (unsigned i = 0; i < Worklist.size() && i < 1024; ++i) {
    // We need to consider at least 2 Ops for pruning to be possible.
    if (NumLeftToConsider <= 1)
      break;
    auto CurNode = Worklist[i].first;
    auto CurOpNumber = Worklist[i].second;
    assert((OpWorkCount[CurOpNumber] > 0) &&
           "Node should not appear in worklist");
    switch (CurNode->getOpcode()) {
    case ISD::EntryToken:
      // Hitting EntryToken is the only way for the search to terminate
      // without hitting another operand's search. Prevent us from marking
      // this operand considered.
      NumLeftToConsider++;
      break;
    case ISD::TokenFactor:
      for (const SDValue &Op : CurNode->op_values())
        AddToWorklist(i, Op.getNode(), CurOpNumber);
      break;
    case ISD::LIFETIME_START:
    case ISD::LIFETIME_END:
    case ISD::CopyFromReg:
    case ISD::CopyToReg:
      // These forward their chain through operand 0.
      AddToWorklist(i, CurNode->getOperand(0).getNode(), CurOpNumber);
      break;
    default:
      if (auto *MemNode = dyn_cast<MemSDNode>(CurNode))
        AddToWorklist(i, MemNode->getChain().getNode(), CurOpNumber);
      break;
    }
    OpWorkCount[CurOpNumber]--;
    if (OpWorkCount[CurOpNumber] == 0)
      NumLeftToConsider--;
  }

  // If we've changed things around then replace token factor.
  if (Changed) {
    SDValue Result;
    if (Ops.empty()) {
      // The entry token is the only possible outcome.
      Result = DAG.getEntryNode();
    } else {
      if (DidPruneOps) {
        // Keep only the operands whose chains were never visited from
        // another operand during the search above.
        SmallVector<SDValue, 8> PrunedOps;
        for (const SDValue &Op : Ops) {
          if (SeenChains.count(Op.getNode()) == 0)
            PrunedOps.push_back(Op);
        }
        Result = DAG.getTokenFactor(SDLoc(N), PrunedOps);
      } else {
        Result = DAG.getTokenFactor(SDLoc(N), Ops);
      }
    }
    return Result;
  }
  return SDValue();
}
2090 
2091 /// MERGE_VALUES can always be eliminated.
visitMERGE_VALUES(SDNode * N)2092 SDValue DAGCombiner::visitMERGE_VALUES(SDNode *N) {
2093   WorklistRemover DeadNodes(*this);
2094   // Replacing results may cause a different MERGE_VALUES to suddenly
2095   // be CSE'd with N, and carry its uses with it. Iterate until no
2096   // uses remain, to ensure that the node can be safely deleted.
2097   // First add the users of this node to the work list so that they
2098   // can be tried again once they have new operands.
2099   AddUsersToWorklist(N);
2100   do {
2101     // Do as a single replacement to avoid rewalking use lists.
2102     SmallVector<SDValue, 8> Ops;
2103     for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
2104       Ops.push_back(N->getOperand(i));
2105     DAG.ReplaceAllUsesWith(N, Ops.data());
2106   } while (!N->use_empty());
2107   deleteAndRecombine(N);
2108   return SDValue(N, 0);   // Return N so it doesn't get rechecked!
2109 }
2110 
2111 /// If \p N is a ConstantSDNode with isOpaque() == false return it casted to a
2112 /// ConstantSDNode pointer else nullptr.
getAsNonOpaqueConstant(SDValue N)2113 static ConstantSDNode *getAsNonOpaqueConstant(SDValue N) {
2114   ConstantSDNode *Const = dyn_cast<ConstantSDNode>(N);
2115   return Const != nullptr && !Const->isOpaque() ? Const : nullptr;
2116 }
2117 
2118 /// Return true if 'Use' is a load or a store that uses N as its base pointer
2119 /// and that N may be folded in the load / store addressing mode.
canFoldInAddressingMode(SDNode * N,SDNode * Use,SelectionDAG & DAG,const TargetLowering & TLI)2120 static bool canFoldInAddressingMode(SDNode *N, SDNode *Use, SelectionDAG &DAG,
2121                                     const TargetLowering &TLI) {
2122   EVT VT;
2123   unsigned AS;
2124 
2125   if (LoadSDNode *LD = dyn_cast<LoadSDNode>(Use)) {
2126     if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2127       return false;
2128     VT = LD->getMemoryVT();
2129     AS = LD->getAddressSpace();
2130   } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(Use)) {
2131     if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2132       return false;
2133     VT = ST->getMemoryVT();
2134     AS = ST->getAddressSpace();
2135   } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(Use)) {
2136     if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2137       return false;
2138     VT = LD->getMemoryVT();
2139     AS = LD->getAddressSpace();
2140   } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(Use)) {
2141     if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2142       return false;
2143     VT = ST->getMemoryVT();
2144     AS = ST->getAddressSpace();
2145   } else {
2146     return false;
2147   }
2148 
2149   TargetLowering::AddrMode AM;
2150   if (N->getOpcode() == ISD::ADD) {
2151     AM.HasBaseReg = true;
2152     ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(N->getOperand(1));
2153     if (Offset)
2154       // [reg +/- imm]
2155       AM.BaseOffs = Offset->getSExtValue();
2156     else
2157       // [reg +/- reg]
2158       AM.Scale = 1;
2159   } else if (N->getOpcode() == ISD::SUB) {
2160     AM.HasBaseReg = true;
2161     ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(N->getOperand(1));
2162     if (Offset)
2163       // [reg +/- imm]
2164       AM.BaseOffs = -Offset->getSExtValue();
2165     else
2166       // [reg +/- reg]
2167       AM.Scale = 1;
2168   } else {
2169     return false;
2170   }
2171 
2172   return TLI.isLegalAddressingMode(DAG.getDataLayout(), AM,
2173                                    VT.getTypeForEVT(*DAG.getContext()), AS);
2174 }
2175 
2176 /// This inverts a canonicalization in IR that replaces a variable select arm
2177 /// with an identity constant. Codegen improves if we re-use the variable
2178 /// operand rather than load a constant. This can also be converted into a
2179 /// masked vector operation if the target supports it.
foldSelectWithIdentityConstant(SDNode * N,SelectionDAG & DAG,bool ShouldCommuteOperands)2180 static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
2181                                               bool ShouldCommuteOperands) {
2182   // Match a select as operand 1. The identity constant that we are looking for
2183   // is only valid as operand 1 of a non-commutative binop.
2184   SDValue N0 = N->getOperand(0);
2185   SDValue N1 = N->getOperand(1);
2186   if (ShouldCommuteOperands)
2187     std::swap(N0, N1);
2188 
2189   // TODO: Should this apply to scalar select too?
2190   if (!N1.hasOneUse() || N1.getOpcode() != ISD::VSELECT)
2191     return SDValue();
2192 
2193   unsigned Opcode = N->getOpcode();
2194   EVT VT = N->getValueType(0);
2195   SDValue Cond = N1.getOperand(0);
2196   SDValue TVal = N1.getOperand(1);
2197   SDValue FVal = N1.getOperand(2);
2198 
2199   // TODO: The cases should match with IR's ConstantExpr::getBinOpIdentity().
2200   // TODO: Target-specific opcodes could be added. Ex: "isCommutativeBinOp()".
2201   // TODO: With fast-math (NSZ), allow the opposite-sign form of zero?
2202   auto isIdentityConstantForOpcode = [](unsigned Opcode, SDValue V) {
2203     if (ConstantFPSDNode *C = isConstOrConstSplatFP(V)) {
2204       switch (Opcode) {
2205       case ISD::FADD: // X + -0.0 --> X
2206         return C->isZero() && C->isNegative();
2207       case ISD::FSUB: // X - 0.0 --> X
2208         return C->isZero() && !C->isNegative();
2209       case ISD::FMUL: // X * 1.0 --> X
2210       case ISD::FDIV: // X / 1.0 --> X
2211         return C->isExactlyValue(1.0);
2212       }
2213     }
2214     if (ConstantSDNode *C = isConstOrConstSplat(V)) {
2215       switch (Opcode) {
2216       case ISD::ADD: // X + 0 --> X
2217       case ISD::SUB: // X - 0 --> X
2218       case ISD::SHL: // X << 0 --> X
2219       case ISD::SRA: // X s>> 0 --> X
2220       case ISD::SRL: // X u>> 0 --> X
2221         return C->isZero();
2222       case ISD::MUL: // X * 1 --> X
2223         return C->isOne();
2224       }
2225     }
2226     return false;
2227   };
2228 
2229   // This transform increases uses of N0, so freeze it to be safe.
2230   // binop N0, (vselect Cond, IDC, FVal) --> vselect Cond, N0, (binop N0, FVal)
2231   if (isIdentityConstantForOpcode(Opcode, TVal)) {
2232     SDValue F0 = DAG.getFreeze(N0);
2233     SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, FVal, N->getFlags());
2234     return DAG.getSelect(SDLoc(N), VT, Cond, F0, NewBO);
2235   }
2236   // binop N0, (vselect Cond, TVal, IDC) --> vselect Cond, (binop N0, TVal), N0
2237   if (isIdentityConstantForOpcode(Opcode, FVal)) {
2238     SDValue F0 = DAG.getFreeze(N0);
2239     SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, TVal, N->getFlags());
2240     return DAG.getSelect(SDLoc(N), VT, Cond, NewBO, F0);
2241   }
2242 
2243   return SDValue();
2244 }
2245 
/// Try to eliminate a binary operator whose operand is a select of constants
/// by pulling the binop into both select arms:
///   binop (select Cond, CT, CF), CBO --> select Cond, CT*CBO, CF*CBO
/// Also tries the vselect identity-constant inversion first (see
/// foldSelectWithIdentityConstant).
SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {
  assert(TLI.isBinOp(BO->getOpcode()) && BO->getNumValues() == 1 &&
         "Unexpected binary operator");

  // NOTE: this local TLI intentionally shadows the member of the same name.
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  auto BinOpcode = BO->getOpcode();
  EVT VT = BO->getValueType(0);
  if (TLI.shouldFoldSelectWithIdentityConstant(BinOpcode, VT)) {
    if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, false))
      return Sel;

    // For commutative binops, also try with the operands swapped.
    if (TLI.isCommutativeBinOp(BO->getOpcode()))
      if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, true))
        return Sel;
  }

  // Don't do this unless the old select is going away. We want to eliminate the
  // binary operator, not replace a binop with a select.
  // TODO: Handle ISD::SELECT_CC.
  unsigned SelOpNo = 0;
  SDValue Sel = BO->getOperand(0);
  if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse()) {
    SelOpNo = 1;
    Sel = BO->getOperand(1);
  }

  if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse())
    return SDValue();

  // Both select arms must be constants (int or FP).
  SDValue CT = Sel.getOperand(1);
  if (!isConstantOrConstantVector(CT, true) &&
      !DAG.isConstantFPBuildVectorOrConstantFP(CT))
    return SDValue();

  SDValue CF = Sel.getOperand(2);
  if (!isConstantOrConstantVector(CF, true) &&
      !DAG.isConstantFPBuildVectorOrConstantFP(CF))
    return SDValue();

  // Bail out if any constants are opaque because we can't constant fold those.
  // The exception is "and" and "or" with either 0 or -1 in which case we can
  // propagate non constant operands into select. I.e.:
  // and (select Cond, 0, -1), X --> select Cond, 0, X
  // or X, (select Cond, -1, 0) --> select Cond, -1, X
  bool CanFoldNonConst =
      (BinOpcode == ISD::AND || BinOpcode == ISD::OR) &&
      (isNullOrNullSplat(CT) || isAllOnesOrAllOnesSplat(CT)) &&
      (isNullOrNullSplat(CF) || isAllOnesOrAllOnesSplat(CF));

  // The binop's other operand; must be constant too unless the AND/OR
  // exception above applies.
  SDValue CBO = BO->getOperand(SelOpNo ^ 1);
  if (!CanFoldNonConst &&
      !isConstantOrConstantVector(CBO, true) &&
      !DAG.isConstantFPBuildVectorOrConstantFP(CBO))
    return SDValue();

  // We have a select-of-constants followed by a binary operator with a
  // constant. Eliminate the binop by pulling the constant math into the select.
  // Example: add (select Cond, CT, CF), CBO --> select Cond, CT + CBO, CF + CBO
  // SelOpNo determines operand order so non-commutative binops stay correct.
  SDLoc DL(Sel);
  SDValue NewCT = SelOpNo ? DAG.getNode(BinOpcode, DL, VT, CBO, CT)
                          : DAG.getNode(BinOpcode, DL, VT, CT, CBO);
  // If constant folding did not produce a constant (or undef), give up.
  if (!CanFoldNonConst && !NewCT.isUndef() &&
      !isConstantOrConstantVector(NewCT, true) &&
      !DAG.isConstantFPBuildVectorOrConstantFP(NewCT))
    return SDValue();

  SDValue NewCF = SelOpNo ? DAG.getNode(BinOpcode, DL, VT, CBO, CF)
                          : DAG.getNode(BinOpcode, DL, VT, CF, CBO);
  if (!CanFoldNonConst && !NewCF.isUndef() &&
      !isConstantOrConstantVector(NewCF, true) &&
      !DAG.isConstantFPBuildVectorOrConstantFP(NewCF))
    return SDValue();

  // Preserve the original binop's flags (e.g. nsw/nuw/fast-math) on the select.
  SDValue SelectOp = DAG.getSelect(DL, VT, Sel.getOperand(0), NewCT, NewCF);
  SelectOp->setFlags(BO->getFlags());
  return SelectOp;
}
2323 
foldAddSubBoolOfMaskedVal(SDNode * N,SelectionDAG & DAG)2324 static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, SelectionDAG &DAG) {
2325   assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
2326          "Expecting add or sub");
2327 
2328   // Match a constant operand and a zext operand for the math instruction:
2329   // add Z, C
2330   // sub C, Z
2331   bool IsAdd = N->getOpcode() == ISD::ADD;
2332   SDValue C = IsAdd ? N->getOperand(1) : N->getOperand(0);
2333   SDValue Z = IsAdd ? N->getOperand(0) : N->getOperand(1);
2334   auto *CN = dyn_cast<ConstantSDNode>(C);
2335   if (!CN || Z.getOpcode() != ISD::ZERO_EXTEND)
2336     return SDValue();
2337 
2338   // Match the zext operand as a setcc of a boolean.
2339   if (Z.getOperand(0).getOpcode() != ISD::SETCC ||
2340       Z.getOperand(0).getValueType() != MVT::i1)
2341     return SDValue();
2342 
2343   // Match the compare as: setcc (X & 1), 0, eq.
2344   SDValue SetCC = Z.getOperand(0);
2345   ISD::CondCode CC = cast<CondCodeSDNode>(SetCC->getOperand(2))->get();
2346   if (CC != ISD::SETEQ || !isNullConstant(SetCC.getOperand(1)) ||
2347       SetCC.getOperand(0).getOpcode() != ISD::AND ||
2348       !isOneConstant(SetCC.getOperand(0).getOperand(1)))
2349     return SDValue();
2350 
2351   // We are adding/subtracting a constant and an inverted low bit. Turn that
2352   // into a subtract/add of the low bit with incremented/decremented constant:
2353   // add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1))
2354   // sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1))
2355   EVT VT = C.getValueType();
2356   SDLoc DL(N);
2357   SDValue LowBit = DAG.getZExtOrTrunc(SetCC.getOperand(0), DL, VT);
2358   SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT) :
2359                        DAG.getConstant(CN->getAPIntValue() - 1, DL, VT);
2360   return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, C1, LowBit);
2361 }
2362 
/// Try to fold a 'not' shifted sign-bit with add/sub with constant operand into
/// a shift and add with a different constant.
static SDValue foldAddSubOfSignBit(SDNode *N, SelectionDAG &DAG) {
  assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
         "Expecting add or sub");

  // We need a constant operand for the add/sub, and the other operand is a
  // logical shift right: add (srl), C or sub C, (srl).
  bool IsAdd = N->getOpcode() == ISD::ADD;
  SDValue ConstantOp = IsAdd ? N->getOperand(1) : N->getOperand(0);
  SDValue ShiftOp = IsAdd ? N->getOperand(0) : N->getOperand(1);
  if (!DAG.isConstantIntBuildVectorOrConstantInt(ConstantOp) ||
      ShiftOp.getOpcode() != ISD::SRL)
    return SDValue();

  // The shift must be of a 'not' value. One-use so the 'not' disappears.
  SDValue Not = ShiftOp.getOperand(0);
  if (!Not.hasOneUse() || !isBitwiseNot(Not))
    return SDValue();

  // The shift must be moving the sign bit to the least-significant-bit.
  EVT VT = ShiftOp.getValueType();
  SDValue ShAmt = ShiftOp.getOperand(1);
  ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt);
  if (!ShAmtC || ShAmtC->getAPIntValue() != (VT.getScalarSizeInBits() - 1))
    return SDValue();

  // Eliminate the 'not' by adjusting the shift and add/sub constant:
  // add (srl (not X), 31), C --> add (sra X, 31), (C + 1)
  // sub C, (srl (not X), 31) --> add (srl X, 31), (C - 1)
  // The ADD/SUB opcode selection below computes C + 1 (add case) or
  // C - 1 (sub case); bail out if it does not fold to a constant.
  SDLoc DL(N);
  if (SDValue NewC = DAG.FoldConstantArithmetic(
          IsAdd ? ISD::ADD : ISD::SUB, DL, VT,
          {ConstantOp, DAG.getConstant(1, DL, VT)})) {
    SDValue NewShift = DAG.getNode(IsAdd ? ISD::SRA : ISD::SRL, DL, VT,
                                   Not.getOperand(0), ShAmt);
    return DAG.getNode(ISD::ADD, DL, VT, NewShift, NewC);
  }

  return SDValue();
}
2404 
isADDLike(SDValue V,const SelectionDAG & DAG)2405 static bool isADDLike(SDValue V, const SelectionDAG &DAG) {
2406   unsigned Opcode = V.getOpcode();
2407   if (Opcode == ISD::OR)
2408     return DAG.haveNoCommonBitsSet(V.getOperand(0), V.getOperand(1));
2409   if (Opcode == ISD::XOR)
2410     return isMinSignedConstant(V.getOperand(1));
2411   return false;
2412 }
2413 
/// Try to fold a node that behaves like an ADD (note that N isn't necessarily
/// an ISD::ADD here, it could for example be an ISD::OR if we know that there
/// are no common bits set in the operands).
/// NOTE: the folds below are tried in order; earlier ones take priority.
SDValue DAGCombiner::visitADDLike(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  SDLoc DL(N);

  // fold (add x, undef) -> undef
  if (N0.isUndef())
    return N0;
  if (N1.isUndef())
    return N1;

  // fold (add c1, c2) -> c1+c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(ISD::ADD, DL, VT, N1, N0);

  // fold vector ops
  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

    // fold (add x, 0) -> x, vector edition
    if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
      return N0;
  }

  // fold (add x, 0) -> x
  if (isNullConstant(N1))
    return N0;

  if (N0.getOpcode() == ISD::SUB) {
    SDValue N00 = N0.getOperand(0);
    SDValue N01 = N0.getOperand(1);

    // fold ((A-c1)+c2) -> (A+(c2-c1))
    if (SDValue Sub = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N1, N01}))
      return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Sub);

    // fold ((c1-A)+c2) -> (c1+c2)-A
    if (SDValue Add = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N1, N00}))
      return DAG.getNode(ISD::SUB, DL, VT, Add, N0.getOperand(1));
  }

  // add (sext i1 X), 1 -> zext (not i1 X)
  // We don't transform this pattern:
  //   add (zext i1 X), -1 -> sext (not i1 X)
  // because most (?) targets generate better code for the zext form.
  if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
      isOneOrOneSplat(N1)) {
    SDValue X = N0.getOperand(0);
    if ((!LegalOperations ||
         (TLI.isOperationLegal(ISD::XOR, X.getValueType()) &&
          TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) &&
        X.getScalarValueSizeInBits() == 1) {
      SDValue Not = DAG.getNOT(DL, X, X.getValueType());
      return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Not);
    }
  }

  // Fold (add (or x, c0), c1) -> (add x, (c0 + c1))
  // iff (or x, c0) is equivalent to (add x, c0).
  // Fold (add (xor x, c0), c1) -> (add x, (c0 + c1))
  // iff (xor x, c0) is equivalent to (add x, c0).
  if (isADDLike(N0, DAG)) {
    SDValue N01 = N0.getOperand(1);
    if (SDValue Add = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N1, N01}))
      return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Add);
  }

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // reassociate add, but only if it will not defeat a load/store
  // addressing-mode fold (see reassociationCanBreakAddressingModePattern).
  if (!reassociationCanBreakAddressingModePattern(ISD::ADD, DL, N, N0, N1)) {
    if (SDValue RADD = reassociateOps(ISD::ADD, DL, N0, N1, N->getFlags()))
      return RADD;

    // Reassociate (add (or x, c), y) -> (add add(x, y), c)) if (or x, c) is
    // equivalent to (add x, c).
    // Reassociate (add (xor x, c), y) -> (add add(x, y), c)) if (xor x, c) is
    // equivalent to (add x, c).
    auto ReassociateAddOr = [&](SDValue N0, SDValue N1) {
      if (isADDLike(N0, DAG) && N0.hasOneUse() &&
          isConstantOrConstantVector(N0.getOperand(1), /* NoOpaque */ true)) {
        return DAG.getNode(ISD::ADD, DL, VT,
                           DAG.getNode(ISD::ADD, DL, VT, N1, N0.getOperand(0)),
                           N0.getOperand(1));
      }
      return SDValue();
    };
    if (SDValue Add = ReassociateAddOr(N0, N1))
      return Add;
    if (SDValue Add = ReassociateAddOr(N1, N0))
      return Add;
  }
  // fold ((0-A) + B) -> B-A
  if (N0.getOpcode() == ISD::SUB && isNullOrNullSplat(N0.getOperand(0)))
    return DAG.getNode(ISD::SUB, DL, VT, N1, N0.getOperand(1));

  // fold (A + (0-B)) -> A-B
  if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0)))
    return DAG.getNode(ISD::SUB, DL, VT, N0, N1.getOperand(1));

  // fold (A+(B-A)) -> B
  if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(1))
    return N1.getOperand(0);

  // fold ((B-A)+A) -> B
  if (N0.getOpcode() == ISD::SUB && N1 == N0.getOperand(1))
    return N0.getOperand(0);

  // fold ((A-B)+(C-A)) -> (C-B)
  if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
      N0.getOperand(0) == N1.getOperand(1))
    return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
                       N0.getOperand(1));

  // fold ((A-B)+(B-C)) -> (A-C)
  if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
      N0.getOperand(1) == N1.getOperand(0))
    return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0),
                       N1.getOperand(1));

  // fold (A+(B-(A+C))) to (B-C)
  if (N1.getOpcode() == ISD::SUB && N1.getOperand(1).getOpcode() == ISD::ADD &&
      N0 == N1.getOperand(1).getOperand(0))
    return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
                       N1.getOperand(1).getOperand(1));

  // fold (A+(B-(C+A))) to (B-C)
  if (N1.getOpcode() == ISD::SUB && N1.getOperand(1).getOpcode() == ISD::ADD &&
      N0 == N1.getOperand(1).getOperand(1))
    return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
                       N1.getOperand(1).getOperand(0));

  // fold (A+((B-A)+or-C)) to (B+or-C)
  if ((N1.getOpcode() == ISD::SUB || N1.getOpcode() == ISD::ADD) &&
      N1.getOperand(0).getOpcode() == ISD::SUB &&
      N0 == N1.getOperand(0).getOperand(1))
    return DAG.getNode(N1.getOpcode(), DL, VT, N1.getOperand(0).getOperand(0),
                       N1.getOperand(1));

  // fold (A-B)+(C-D) to (A+C)-(B+D) when A or C is constant
  if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
      N0->hasOneUse() && N1->hasOneUse()) {
    SDValue N00 = N0.getOperand(0);
    SDValue N01 = N0.getOperand(1);
    SDValue N10 = N1.getOperand(0);
    SDValue N11 = N1.getOperand(1);

    if (isConstantOrConstantVector(N00) || isConstantOrConstantVector(N10))
      return DAG.getNode(ISD::SUB, DL, VT,
                         DAG.getNode(ISD::ADD, SDLoc(N0), VT, N00, N10),
                         DAG.getNode(ISD::ADD, SDLoc(N1), VT, N01, N11));
  }

  // fold (add (umax X, C), -C) --> (usubsat X, C)
  if (N0.getOpcode() == ISD::UMAX && hasOperation(ISD::USUBSAT, VT)) {
    // Matching element-wise: Max == -Op (undef lanes allowed).
    auto MatchUSUBSAT = [](ConstantSDNode *Max, ConstantSDNode *Op) {
      return (!Max && !Op) ||
             (Max && Op && Max->getAPIntValue() == (-Op->getAPIntValue()));
    };
    if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchUSUBSAT,
                                  /*AllowUndefs*/ true))
      return DAG.getNode(ISD::USUBSAT, DL, VT, N0.getOperand(0),
                         N0.getOperand(1));
  }

  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  if (isOneOrOneSplat(N1)) {
    // fold (add (xor a, -1), 1) -> (sub 0, a)
    if (isBitwiseNot(N0))
      return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
                         N0.getOperand(0));

    // fold (add (add (xor a, -1), b), 1) -> (sub b, a)
    if (N0.getOpcode() == ISD::ADD) {
      SDValue A, Xor;

      if (isBitwiseNot(N0.getOperand(0))) {
        A = N0.getOperand(1);
        Xor = N0.getOperand(0);
      } else if (isBitwiseNot(N0.getOperand(1))) {
        A = N0.getOperand(0);
        Xor = N0.getOperand(1);
      }

      if (Xor)
        return DAG.getNode(ISD::SUB, DL, VT, A, Xor.getOperand(0));
    }

    // Look for:
    //   add (add x, y), 1
    // And if the target does not like this form then turn into:
    //   sub y, (xor x, -1)
    if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD &&
        N0.hasOneUse()) {
      SDValue Not = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(0),
                                DAG.getAllOnesConstant(DL, VT));
      return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(1), Not);
    }
  }

  // (x - y) + -1  ->  add (xor y, -1), x
  if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
      isAllOnesOrAllOnesSplat(N1)) {
    SDValue Xor = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(1), N1);
    return DAG.getNode(ISD::ADD, DL, VT, Xor, N0.getOperand(0));
  }

  // Try commutative helper folds in both operand orders.
  if (SDValue Combined = visitADDLikeCommutative(N0, N1, N))
    return Combined;

  if (SDValue Combined = visitADDLikeCommutative(N1, N0, N))
    return Combined;

  return SDValue();
}
2642 
/// Combine an ISD::ADD node. Generic ADD-like folds are tried first via
/// visitADDLike; the folds below are ADD-specific (boolean-mask, sign-bit,
/// disjoint-bits OR, vscale and step_vector arithmetic).
SDValue DAGCombiner::visitADD(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  SDLoc DL(N);

  if (SDValue Combined = visitADDLike(N))
    return Combined;

  if (SDValue V = foldAddSubBoolOfMaskedVal(N, DAG))
    return V;

  if (SDValue V = foldAddSubOfSignBit(N, DAG))
    return V;

  // fold (a+b) -> (a|b) iff a and b share no bits.
  if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
      DAG.haveNoCommonBitsSet(N0, N1))
    return DAG.getNode(ISD::OR, DL, VT, N0, N1);

  // Fold (add (vscale * C0), (vscale * C1)) to (vscale * (C0 + C1)).
  if (N0.getOpcode() == ISD::VSCALE && N1.getOpcode() == ISD::VSCALE) {
    const APInt &C0 = N0->getConstantOperandAPInt(0);
    const APInt &C1 = N1->getConstantOperandAPInt(0);
    return DAG.getVScale(DL, VT, C0 + C1);
  }

  // fold a+vscale(c1)+vscale(c2) -> a+vscale(c1+c2)
  if ((N0.getOpcode() == ISD::ADD) &&
      (N0.getOperand(1).getOpcode() == ISD::VSCALE) &&
      (N1.getOpcode() == ISD::VSCALE)) {
    const APInt &VS0 = N0.getOperand(1)->getConstantOperandAPInt(0);
    const APInt &VS1 = N1->getConstantOperandAPInt(0);
    SDValue VS = DAG.getVScale(DL, VT, VS0 + VS1);
    return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), VS);
  }

  // Fold (add step_vector(c1), step_vector(c2)  to step_vector(c1+c2))
  if (N0.getOpcode() == ISD::STEP_VECTOR &&
      N1.getOpcode() == ISD::STEP_VECTOR) {
    const APInt &C0 = N0->getConstantOperandAPInt(0);
    const APInt &C1 = N1->getConstantOperandAPInt(0);
    APInt NewStep = C0 + C1;
    return DAG.getStepVector(DL, VT, NewStep);
  }

  // Fold a + step_vector(c1) + step_vector(c2) to a + step_vector(c1+c2)
  if ((N0.getOpcode() == ISD::ADD) &&
      (N0.getOperand(1).getOpcode() == ISD::STEP_VECTOR) &&
      (N1.getOpcode() == ISD::STEP_VECTOR)) {
    const APInt &SV0 = N0.getOperand(1)->getConstantOperandAPInt(0);
    const APInt &SV1 = N1->getConstantOperandAPInt(0);
    APInt NewStep = SV0 + SV1;
    SDValue SV = DAG.getStepVector(DL, VT, NewStep);
    return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), SV);
  }

  return SDValue();
}
2702 
visitADDSAT(SDNode * N)2703 SDValue DAGCombiner::visitADDSAT(SDNode *N) {
2704   unsigned Opcode = N->getOpcode();
2705   SDValue N0 = N->getOperand(0);
2706   SDValue N1 = N->getOperand(1);
2707   EVT VT = N0.getValueType();
2708   SDLoc DL(N);
2709 
2710   // fold (add_sat x, undef) -> -1
2711   if (N0.isUndef() || N1.isUndef())
2712     return DAG.getAllOnesConstant(DL, VT);
2713 
2714   // fold (add_sat c1, c2) -> c3
2715   if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
2716     return C;
2717 
2718   // canonicalize constant to RHS
2719   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
2720       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
2721     return DAG.getNode(Opcode, DL, VT, N1, N0);
2722 
2723   // fold vector ops
2724   if (VT.isVector()) {
2725     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
2726       return FoldedVOp;
2727 
2728     // fold (add_sat x, 0) -> x, vector edition
2729     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
2730       return N0;
2731   }
2732 
2733   // fold (add_sat x, 0) -> x
2734   if (isNullConstant(N1))
2735     return N0;
2736 
2737   // If it cannot overflow, transform into an add.
2738   if (Opcode == ISD::UADDSAT)
2739     if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never)
2740       return DAG.getNode(ISD::ADD, DL, VT, N0, N1);
2741 
2742   return SDValue();
2743 }
2744 
getAsCarry(const TargetLowering & TLI,SDValue V)2745 static SDValue getAsCarry(const TargetLowering &TLI, SDValue V) {
2746   bool Masked = false;
2747 
2748   // First, peel away TRUNCATE/ZERO_EXTEND/AND nodes due to legalization.
2749   while (true) {
2750     if (V.getOpcode() == ISD::TRUNCATE || V.getOpcode() == ISD::ZERO_EXTEND) {
2751       V = V.getOperand(0);
2752       continue;
2753     }
2754 
2755     if (V.getOpcode() == ISD::AND && isOneConstant(V.getOperand(1))) {
2756       Masked = true;
2757       V = V.getOperand(0);
2758       continue;
2759     }
2760 
2761     break;
2762   }
2763 
2764   // If this is not a carry, return.
2765   if (V.getResNo() != 1)
2766     return SDValue();
2767 
2768   if (V.getOpcode() != ISD::ADDCARRY && V.getOpcode() != ISD::SUBCARRY &&
2769       V.getOpcode() != ISD::UADDO && V.getOpcode() != ISD::USUBO)
2770     return SDValue();
2771 
2772   EVT VT = V->getValueType(0);
2773   if (!TLI.isOperationLegalOrCustom(V.getOpcode(), VT))
2774     return SDValue();
2775 
2776   // If the result is masked, then no matter what kind of bool it is we can
2777   // return. If it isn't, then we need to make sure the bool type is either 0 or
2778   // 1 and not other values.
2779   if (Masked ||
2780       TLI.getBooleanContents(V.getValueType()) ==
2781           TargetLoweringBase::ZeroOrOneBooleanContent)
2782     return V;
2783 
2784   return SDValue();
2785 }
2786 
2787 /// Given the operands of an add/sub operation, see if the 2nd operand is a
2788 /// masked 0/1 whose source operand is actually known to be 0/-1. If so, invert
2789 /// the opcode and bypass the mask operation.
foldAddSubMasked1(bool IsAdd,SDValue N0,SDValue N1,SelectionDAG & DAG,const SDLoc & DL)2790 static SDValue foldAddSubMasked1(bool IsAdd, SDValue N0, SDValue N1,
2791                                  SelectionDAG &DAG, const SDLoc &DL) {
2792   if (N1.getOpcode() != ISD::AND || !isOneOrOneSplat(N1->getOperand(1)))
2793     return SDValue();
2794 
2795   EVT VT = N0.getValueType();
2796   if (DAG.ComputeNumSignBits(N1.getOperand(0)) != VT.getScalarSizeInBits())
2797     return SDValue();
2798 
2799   // add N0, (and (AssertSext X, i1), 1) --> sub N0, X
2800   // sub N0, (and (AssertSext X, i1), 1) --> add N0, X
2801   return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, N0, N1.getOperand(0));
2802 }
2803 
/// Helper for doing combines based on N0 and N1 being added to each other.
/// Called twice by visitADDLike (once per operand order), so the folds here
/// only need to match one orientation.
SDValue DAGCombiner::visitADDLikeCommutative(SDValue N0, SDValue N1,
                                          SDNode *LocReference) {
  EVT VT = N0.getValueType();
  SDLoc DL(LocReference);

  // fold (add x, shl(0 - y, n)) -> sub(x, shl(y, n))
  if (N1.getOpcode() == ISD::SHL && N1.getOperand(0).getOpcode() == ISD::SUB &&
      isNullOrNullSplat(N1.getOperand(0).getOperand(0)))
    return DAG.getNode(ISD::SUB, DL, VT, N0,
                       DAG.getNode(ISD::SHL, DL, VT,
                                   N1.getOperand(0).getOperand(1),
                                   N1.getOperand(1)));

  if (SDValue V = foldAddSubMasked1(true, N0, N1, DAG, DL))
    return V;

  // Look for:
  //   add (add x, 1), y
  // And if the target does not like this form then turn into:
  //   sub y, (xor x, -1)
  if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD &&
      N0.hasOneUse() && isOneOrOneSplat(N0.getOperand(1))) {
    SDValue Not = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(0),
                              DAG.getAllOnesConstant(DL, VT));
    return DAG.getNode(ISD::SUB, DL, VT, N1, Not);
  }

  if (N0.getOpcode() == ISD::SUB && N0.hasOneUse()) {
    // Hoist one-use subtraction by non-opaque constant:
    //   (x - C) + y  ->  (x + y) - C
    // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
    if (isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
      SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), N1);
      return DAG.getNode(ISD::SUB, DL, VT, Add, N0.getOperand(1));
    }
    // Hoist one-use subtraction from non-opaque constant:
    //   (C - x) + y  ->  (y - x) + C
    if (isConstantOrConstantVector(N0.getOperand(0), /*NoOpaques=*/true)) {
      SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N1, N0.getOperand(1));
      return DAG.getNode(ISD::ADD, DL, VT, Sub, N0.getOperand(0));
    }
  }

  // If the target's bool is represented as 0/1, prefer to make this 'sub 0/1'
  // rather than 'add 0/-1' (the zext should get folded).
  // add (sext i1 Y), X --> sub X, (zext i1 Y)
  if (N0.getOpcode() == ISD::SIGN_EXTEND &&
      N0.getOperand(0).getScalarValueSizeInBits() == 1 &&
      TLI.getBooleanContents(VT) == TargetLowering::ZeroOrOneBooleanContent) {
    SDValue ZExt = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0));
    return DAG.getNode(ISD::SUB, DL, VT, N1, ZExt);
  }

  // add X, (sextinreg Y i1) -> sub X, (and Y 1)
  if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
    VTSDNode *TN = cast<VTSDNode>(N1.getOperand(1));
    if (TN->getVT() == MVT::i1) {
      SDValue ZExt = DAG.getNode(ISD::AND, DL, VT, N1.getOperand(0),
                                 DAG.getConstant(1, DL, VT));
      return DAG.getNode(ISD::SUB, DL, VT, N0, ZExt);
    }
  }

  // (add X, (addcarry Y, 0, Carry)) -> (addcarry X, Y, Carry)
  if (N1.getOpcode() == ISD::ADDCARRY && isNullConstant(N1.getOperand(1)) &&
      N1.getResNo() == 0)
    return DAG.getNode(ISD::ADDCARRY, DL, N1->getVTList(),
                       N0, N1.getOperand(0), N1.getOperand(2));

  // (add X, Carry) -> (addcarry X, 0, Carry)
  if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT))
    if (SDValue Carry = getAsCarry(TLI, N1))
      return DAG.getNode(ISD::ADDCARRY, DL,
                         DAG.getVTList(VT, Carry.getValueType()), N0,
                         DAG.getConstant(0, DL, VT), Carry);

  return SDValue();
}
2883 
// Combine an ISD::ADDC node (add producing a glue carry-out). The folds
// either prove the carry result is unused/always-false or canonicalize the
// operand order; each early-returns on the first match.
SDValue DAGCombiner::visitADDC(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  SDLoc DL(N);

  // If the flag result is dead, turn this into an ADD.
  if (!N->hasAnyUseOfValue(1))
    return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
                     DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));

  // canonicalize constant to RHS.
  ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
  if (N0C && !N1C)
    return DAG.getNode(ISD::ADDC, DL, N->getVTList(), N1, N0);

  // fold (addc x, 0) -> x + no carry out
  if (isNullConstant(N1))
    return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE,
                                        DL, MVT::Glue));

  // If it cannot overflow, transform into an add.
  if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never)
    return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
                     DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));

  return SDValue();
}
2913 
/**
 * Flips a boolean if it is cheaper to compute. If the Force parameter is set,
 * then the flip also occurs if computing the inverse is the same cost.
 * This function returns an empty SDValue in case it cannot flip the boolean
 * without increasing the cost of the computation. If you want to flip a boolean
 * no matter what, use DAG.getLogicalNOT.
 */
static SDValue extractBooleanFlip(SDValue V, SelectionDAG &DAG,
                                  const TargetLowering &TLI,
                                  bool Force) {
  // Flipping a constant boolean is always free.
  if (Force && isa<ConstantSDNode>(V))
    return DAG.getLogicalNOT(SDLoc(V), V, V.getValueType());

  // Otherwise the only free flip is peeling an existing xor-with-constant.
  if (V.getOpcode() != ISD::XOR)
    return SDValue();

  ConstantSDNode *Const = isConstOrConstSplat(V.getOperand(1), false);
  if (!Const)
    return SDValue();

  EVT VT = V.getValueType();

  // The xor inverts the boolean only if its mask flips the bit(s) that carry
  // the boolean value under the target's boolean representation.
  bool IsFlip = false;
  switch(TLI.getBooleanContents(VT)) {
    case TargetLowering::ZeroOrOneBooleanContent:
      IsFlip = Const->isOne();
      break;
    case TargetLowering::ZeroOrNegativeOneBooleanContent:
      IsFlip = Const->isAllOnes();
      break;
    case TargetLowering::UndefinedBooleanContent:
      // Only bit 0 is meaningful; the xor flips the boolean iff it flips
      // that bit.
      IsFlip = (Const->getAPIntValue() & 0x01) == 1;
      break;
  }

  if (IsFlip)
    return V.getOperand(0);
  // Not a recognizable flip; fall back to an explicit NOT when forced.
  if (Force)
    return DAG.getLogicalNOT(SDLoc(V), V, V.getValueType());
  return SDValue();
}
2955 
// Combine an ISD::SADDO / ISD::UADDO node (add with a boolean overflow
// result in value #1).
SDValue DAGCombiner::visitADDO(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  bool IsSigned = (ISD::SADDO == N->getOpcode());

  EVT CarryVT = N->getValueType(1);
  SDLoc DL(N);

  // If the flag result is dead, turn this into an ADD.
  if (!N->hasAnyUseOfValue(1))
    return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
                     DAG.getUNDEF(CarryVT));

  // canonicalize constant to RHS.
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(N->getOpcode(), DL, N->getVTList(), N1, N0);

  // fold (addo x, 0) -> x + no carry out
  if (isNullOrNullSplat(N1))
    return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT));

  // The remaining folds are only valid for unsigned overflow.
  if (!IsSigned) {
    // If it cannot overflow, transform into an add.
    if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never)
      return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
                       DAG.getConstant(0, DL, CarryVT));

    // fold (uaddo (xor a, -1), 1) -> (usub 0, a) and flip carry.
    // (~a + 1 == -a == 0 - a; the add's carry-out is the inverse of the
    // sub's borrow-out, hence the getLogicalNOT on the flag result.)
    if (isBitwiseNot(N0) && isOneOrOneSplat(N1)) {
      SDValue Sub = DAG.getNode(ISD::USUBO, DL, N->getVTList(),
                                DAG.getConstant(0, DL, VT), N0.getOperand(0));
      return CombineTo(
          N, Sub, DAG.getLogicalNOT(DL, Sub.getValue(1), Sub->getValueType(1)));
    }

    // Try the asymmetric carry folds with the operands in either order.
    if (SDValue Combined = visitUADDOLike(N0, N1, N))
      return Combined;

    if (SDValue Combined = visitUADDOLike(N1, N0, N))
      return Combined;
  }

  return SDValue();
}
3002 
visitUADDOLike(SDValue N0,SDValue N1,SDNode * N)3003 SDValue DAGCombiner::visitUADDOLike(SDValue N0, SDValue N1, SDNode *N) {
3004   EVT VT = N0.getValueType();
3005   if (VT.isVector())
3006     return SDValue();
3007 
3008   // (uaddo X, (addcarry Y, 0, Carry)) -> (addcarry X, Y, Carry)
3009   // If Y + 1 cannot overflow.
3010   if (N1.getOpcode() == ISD::ADDCARRY && isNullConstant(N1.getOperand(1))) {
3011     SDValue Y = N1.getOperand(0);
3012     SDValue One = DAG.getConstant(1, SDLoc(N), Y.getValueType());
3013     if (DAG.computeOverflowKind(Y, One) == SelectionDAG::OFK_Never)
3014       return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(), N0, Y,
3015                          N1.getOperand(2));
3016   }
3017 
3018   // (uaddo X, Carry) -> (addcarry X, 0, Carry)
3019   if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT))
3020     if (SDValue Carry = getAsCarry(TLI, N1))
3021       return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(), N0,
3022                          DAG.getConstant(0, SDLoc(N), VT), Carry);
3023 
3024   return SDValue();
3025 }
3026 
visitADDE(SDNode * N)3027 SDValue DAGCombiner::visitADDE(SDNode *N) {
3028   SDValue N0 = N->getOperand(0);
3029   SDValue N1 = N->getOperand(1);
3030   SDValue CarryIn = N->getOperand(2);
3031 
3032   // canonicalize constant to RHS
3033   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
3034   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
3035   if (N0C && !N1C)
3036     return DAG.getNode(ISD::ADDE, SDLoc(N), N->getVTList(),
3037                        N1, N0, CarryIn);
3038 
3039   // fold (adde x, y, false) -> (addc x, y)
3040   if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
3041     return DAG.getNode(ISD::ADDC, SDLoc(N), N->getVTList(), N0, N1);
3042 
3043   return SDValue();
3044 }
3045 
// Combine an ISD::ADDCARRY node (add with boolean carry-in and carry-out).
SDValue DAGCombiner::visitADDCARRY(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue CarryIn = N->getOperand(2);
  SDLoc DL(N);

  // canonicalize constant to RHS
  ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
  if (N0C && !N1C)
    return DAG.getNode(ISD::ADDCARRY, DL, N->getVTList(), N1, N0, CarryIn);

  // fold (addcarry x, y, false) -> (uaddo x, y)
  if (isNullConstant(CarryIn)) {
    if (!LegalOperations ||
        TLI.isOperationLegalOrCustom(ISD::UADDO, N->getValueType(0)))
      return DAG.getNode(ISD::UADDO, DL, N->getVTList(), N0, N1);
  }

  // fold (addcarry 0, 0, X) -> (and (ext/trunc X), 1) and no carry.
  // The sum is just the carry-in widened to VT, and adding a single 0/1 bit
  // to zero can never carry out.
  if (isNullConstant(N0) && isNullConstant(N1)) {
    EVT VT = N0.getValueType();
    EVT CarryVT = CarryIn.getValueType();
    SDValue CarryExt = DAG.getBoolExtOrTrunc(CarryIn, DL, VT, CarryVT);
    AddToWorklist(CarryExt.getNode());
    return CombineTo(N, DAG.getNode(ISD::AND, DL, VT, CarryExt,
                                    DAG.getConstant(1, DL, VT)),
                     DAG.getConstant(0, DL, CarryVT));
  }

  // Try the asymmetric folds with the operands in either order.
  if (SDValue Combined = visitADDCARRYLike(N0, N1, CarryIn, N))
    return Combined;

  if (SDValue Combined = visitADDCARRYLike(N1, N0, CarryIn, N))
    return Combined;

  return SDValue();
}
3084 
visitSADDO_CARRY(SDNode * N)3085 SDValue DAGCombiner::visitSADDO_CARRY(SDNode *N) {
3086   SDValue N0 = N->getOperand(0);
3087   SDValue N1 = N->getOperand(1);
3088   SDValue CarryIn = N->getOperand(2);
3089   SDLoc DL(N);
3090 
3091   // canonicalize constant to RHS
3092   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
3093   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
3094   if (N0C && !N1C)
3095     return DAG.getNode(ISD::SADDO_CARRY, DL, N->getVTList(), N1, N0, CarryIn);
3096 
3097   // fold (saddo_carry x, y, false) -> (saddo x, y)
3098   if (isNullConstant(CarryIn)) {
3099     if (!LegalOperations ||
3100         TLI.isOperationLegalOrCustom(ISD::SADDO, N->getValueType(0)))
3101       return DAG.getNode(ISD::SADDO, DL, N->getVTList(), N0, N1);
3102   }
3103 
3104   return SDValue();
3105 }
3106 
/**
 * If we are facing some sort of diamond carry propagation pattern try to
 * break it up to generate something like:
 *   (addcarry X, 0, (addcarry A, B, Z):Carry)
 *
 * The end result is usually an increase in operations required, but because
 * the carry is now linearized, other transforms can kick in and optimize the
 * DAG.
 *
 * Patterns typically look something like
 *            (uaddo A, B)
 *             /       \
 *          Carry      Sum
 *            |          \
 *            | (addcarry *, 0, Z)
 *            |       /
 *             \   Carry
 *              |   /
 * (addcarry X, *, *)
 *
 * But numerous variations exist. Our goal is to identify A, B, X and Z and
 * produce a combine with a single path for carry propagation.
 */
static SDValue combineADDCARRYDiamond(DAGCombiner &Combiner, SelectionDAG &DAG,
                                      SDValue X, SDValue Carry0, SDValue Carry1,
                                      SDNode *N) {
  // Both inputs must be carry-out results (value #1), and the node feeding
  // the second carry must be a plain uaddo.
  if (Carry1.getResNo() != 1 || Carry0.getResNo() != 1)
    return SDValue();
  if (Carry1.getOpcode() != ISD::UADDO)
    return SDValue();

  SDValue Z;

  /**
   * First look for a suitable Z. It will present itself in the form of
   * (addcarry Y, 0, Z) or its equivalent (uaddo Y, 1) for Z=true
   */
  if (Carry0.getOpcode() == ISD::ADDCARRY &&
      isNullConstant(Carry0.getOperand(1))) {
    Z = Carry0.getOperand(2);
  } else if (Carry0.getOpcode() == ISD::UADDO &&
             isOneConstant(Carry0.getOperand(1))) {
    // (uaddo Y, 1) is (addcarry Y, 0, true): materialize Z = true in the
    // setcc result type.
    EVT VT = Combiner.getSetCCResultType(Carry0.getValueType());
    Z = DAG.getConstant(1, SDLoc(Carry0.getOperand(1)), VT);
  } else {
    // We couldn't find a suitable Z.
    return SDValue();
  }


  // Rebuild the diamond as the linear chain
  // (addcarry X, 0, (addcarry A, B, Z):Carry).
  auto cancelDiamond = [&](SDValue A,SDValue B) {
    SDLoc DL(N);
    SDValue NewY = DAG.getNode(ISD::ADDCARRY, DL, Carry0->getVTList(), A, B, Z);
    Combiner.AddToWorklist(NewY.getNode());
    return DAG.getNode(ISD::ADDCARRY, DL, N->getVTList(), X,
                       DAG.getConstant(0, DL, X.getValueType()),
                       NewY.getValue(1));
  };

  /**
   *      (uaddo A, B)
   *           |
   *          Sum
   *           |
   * (addcarry *, 0, Z)
   */
  if (Carry0.getOperand(0) == Carry1.getValue(0)) {
    return cancelDiamond(Carry1.getOperand(0), Carry1.getOperand(1));
  }

  /**
   * (addcarry A, 0, Z)
   *         |
   *        Sum
   *         |
   *  (uaddo *, B)
   */
  if (Carry1.getOperand(0) == Carry0.getValue(0)) {
    return cancelDiamond(Carry0.getOperand(0), Carry1.getOperand(1));
  }

  // Commuted form of the previous pattern: (uaddo B, Sum) where Sum is the
  // addcarry's result.
  if (Carry1.getOperand(1) == Carry0.getValue(0)) {
    return cancelDiamond(Carry1.getOperand(0), Carry0.getOperand(0));
  }

  return SDValue();
}
3193 
// If we are facing some sort of diamond carry/borrow in/out pattern try to
// match patterns like:
//
//          (uaddo A, B)            CarryIn
//            |  \                     |
//            |   \                    |
//    PartialSum   PartialCarryOutX   /
//            |        |             /
//            |    ____|____________/
//            |   /    |
//     (uaddo *, *)    \________
//       |  \                   \
//       |   \                   |
//       |    PartialCarryOutY   |
//       |        \              |
//       |         \            /
//   AddCarrySum    |    ______/
//                  |   /
//   CarryOut = (or *, *)
//
// And generate ADDCARRY (or SUBCARRY) with two result values:
//
//    {AddCarrySum, CarryOut} = (addcarry A, B, CarryIn)
//
// Our goal is to identify A, B, and CarryIn and produce ADDCARRY/SUBCARRY with
// a single path for carry/borrow out propagation:
static SDValue combineCarryDiamond(SelectionDAG &DAG, const TargetLowering &TLI,
                                   SDValue N0, SDValue N1, SDNode *N) {
  // Both operands of the merging node (N) must themselves be carry bits.
  SDValue Carry0 = getAsCarry(TLI, N0);
  if (!Carry0)
    return SDValue();
  SDValue Carry1 = getAsCarry(TLI, N1);
  if (!Carry1)
    return SDValue();

  // Both partial carries must come from the same kind of unsigned op.
  unsigned Opcode = Carry0.getOpcode();
  if (Opcode != Carry1.getOpcode())
    return SDValue();
  if (Opcode != ISD::UADDO && Opcode != ISD::USUBO)
    return SDValue();

  // Canonicalize the add/sub of A and B (the top node in the above ASCII art)
  // as Carry0 and the add/sub of the carry in as Carry1 (the middle node).
  if (Carry1.getNode()->isOperandOf(Carry0.getNode()))
    std::swap(Carry0, Carry1);

  // Check if nodes are connected in expected way.
  if (Carry1.getOperand(0) != Carry0.getValue(0) &&
      Carry1.getOperand(1) != Carry0.getValue(0))
    return SDValue();

  // The carry in value must be on the righthand side for subtraction.
  unsigned CarryInOperandNum =
      Carry1.getOperand(0) == Carry0.getValue(0) ? 1 : 0;
  if (Opcode == ISD::USUBO && CarryInOperandNum != 1)
    return SDValue();
  SDValue CarryIn = Carry1.getOperand(CarryInOperandNum);

  unsigned NewOp = Opcode == ISD::UADDO ? ISD::ADDCARRY : ISD::SUBCARRY;
  if (!TLI.isOperationLegalOrCustom(NewOp, Carry0.getValue(0).getValueType()))
    return SDValue();

  // Verify that the carry/borrow in is plausibly a carry/borrow bit.
  // TODO: make getAsCarry() aware of how partial carries are merged.
  if (CarryIn.getOpcode() != ISD::ZERO_EXTEND)
    return SDValue();
  CarryIn = CarryIn.getOperand(0);
  if (CarryIn.getValueType() != MVT::i1)
    return SDValue();

  SDLoc DL(N);
  // Build the fused op; it reuses Carry1's value list (sum + carry bit).
  SDValue Merged =
      DAG.getNode(NewOp, DL, Carry1->getVTList(), Carry0.getOperand(0),
                  Carry0.getOperand(1), CarryIn);

  // Please note that because we have proven that the result of the UADDO/USUBO
  // of A and B feeds into the UADDO/USUBO that does the carry/borrow in, we can
  // therefore prove that if the first UADDO/USUBO overflows, the second
  // UADDO/USUBO cannot. For example consider 8-bit numbers where 0xFF is the
  // maximum value.
  //
  //   0xFF + 0xFF == 0xFE with carry but 0xFE + 1 does not carry
  //   0x00 - 0xFF == 1 with a carry/borrow but 1 - 1 == 0 (no carry/borrow)
  //
  // This is important because it means that OR and XOR can be used to merge
  // carry flags; and that AND can return a constant zero.
  //
  // TODO: match other operations that can merge flags (ADD, etc)
  DAG.ReplaceAllUsesOfValueWith(Carry1.getValue(0), Merged.getValue(0));
  if (N->getOpcode() == ISD::AND)
    return DAG.getConstant(0, DL, MVT::i1);
  return Merged.getValue(1);
}
3287 
// Asymmetric ADDCARRY folds; visitADDCARRY calls this with both operand
// orders.
SDValue DAGCombiner::visitADDCARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
                                       SDNode *N) {
  // fold (addcarry (xor a, -1), b, c) -> (subcarry b, a, !c) and flip carry.
  // (~a + b + c == b - a - (1 - c), and the add's carry-out is the inverse
  // of the sub's borrow-out, hence the getLogicalNOT on the flag result.)
  if (isBitwiseNot(N0))
    if (SDValue NotC = extractBooleanFlip(CarryIn, DAG, TLI, true)) {
      SDLoc DL(N);
      SDValue Sub = DAG.getNode(ISD::SUBCARRY, DL, N->getVTList(), N1,
                                N0.getOperand(0), NotC);
      return CombineTo(
          N, Sub, DAG.getLogicalNOT(DL, Sub.getValue(1), Sub->getValueType(1)));
    }

  // Iff the flag result is dead:
  // (addcarry (add|uaddo X, Y), 0, Carry) -> (addcarry X, Y, Carry)
  // Don't do this if the Carry comes from the uaddo. It won't remove the uaddo
  // or the dependency between the instructions.
  if ((N0.getOpcode() == ISD::ADD ||
       (N0.getOpcode() == ISD::UADDO && N0.getResNo() == 0 &&
        N0.getValue(1) != CarryIn)) &&
      isNullConstant(N1) && !N->hasAnyUseOfValue(1))
    return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(),
                       N0.getOperand(0), N0.getOperand(1), CarryIn);

  /**
   * When one of the addcarry argument is itself a carry, we may be facing
   * a diamond carry propagation. In which case we try to transform the DAG
   * to ensure linear carry propagation if that is possible.
   */
  if (auto Y = getAsCarry(TLI, N1)) {
    // Because both are carries, Y and Z can be swapped.
    if (auto R = combineADDCARRYDiamond(*this, DAG, N0, Y, CarryIn, N))
      return R;
    if (auto R = combineADDCARRYDiamond(*this, DAG, N0, CarryIn, Y, N))
      return R;
  }

  return SDValue();
}
3326 
3327 // Attempt to create a USUBSAT(LHS, RHS) node with DstVT, performing a
3328 // clamp/truncation if necessary.
getTruncatedUSUBSAT(EVT DstVT,EVT SrcVT,SDValue LHS,SDValue RHS,SelectionDAG & DAG,const SDLoc & DL)3329 static SDValue getTruncatedUSUBSAT(EVT DstVT, EVT SrcVT, SDValue LHS,
3330                                    SDValue RHS, SelectionDAG &DAG,
3331                                    const SDLoc &DL) {
3332   assert(DstVT.getScalarSizeInBits() <= SrcVT.getScalarSizeInBits() &&
3333          "Illegal truncation");
3334 
3335   if (DstVT == SrcVT)
3336     return DAG.getNode(ISD::USUBSAT, DL, DstVT, LHS, RHS);
3337 
3338   // If the LHS is zero-extended then we can perform the USUBSAT as DstVT by
3339   // clamping RHS.
3340   APInt UpperBits = APInt::getBitsSetFrom(SrcVT.getScalarSizeInBits(),
3341                                           DstVT.getScalarSizeInBits());
3342   if (!DAG.MaskedValueIsZero(LHS, UpperBits))
3343     return SDValue();
3344 
3345   SDValue SatLimit =
3346       DAG.getConstant(APInt::getLowBitsSet(SrcVT.getScalarSizeInBits(),
3347                                            DstVT.getScalarSizeInBits()),
3348                       DL, SrcVT);
3349   RHS = DAG.getNode(ISD::UMIN, DL, SrcVT, RHS, SatLimit);
3350   RHS = DAG.getNode(ISD::TRUNCATE, DL, DstVT, RHS);
3351   LHS = DAG.getNode(ISD::TRUNCATE, DL, DstVT, LHS);
3352   return DAG.getNode(ISD::USUBSAT, DL, DstVT, LHS, RHS);
3353 }
3354 
// Try to find umax(a,b) - b or a - umin(a,b) patterns that may be converted to
// usubsat(a,b), optionally as a truncated type.
SDValue DAGCombiner::foldSubToUSubSat(EVT DstVT, SDNode *N) {
  // Only applies to SUB, and only when USUBSAT is available for DstVT.
  if (N->getOpcode() != ISD::SUB ||
      !(!LegalOperations || hasOperation(ISD::USUBSAT, DstVT)))
    return SDValue();

  EVT SubVT = N->getValueType(0);
  SDValue Op0 = N->getOperand(0);
  SDValue Op1 = N->getOperand(1);

  // Try to find umax(a,b) - b or a - umin(a,b) patterns
  // they may be converted to usubsat(a,b).
  // (umax(a,b) - b is a-b when a > b and 0 otherwise, which is exactly the
  // unsigned saturating subtract; likewise for a - umin(a,b).)
  if (Op0.getOpcode() == ISD::UMAX && Op0.hasOneUse()) {
    SDValue MaxLHS = Op0.getOperand(0);
    SDValue MaxRHS = Op0.getOperand(1);
    if (MaxLHS == Op1)
      return getTruncatedUSUBSAT(DstVT, SubVT, MaxRHS, Op1, DAG, SDLoc(N));
    if (MaxRHS == Op1)
      return getTruncatedUSUBSAT(DstVT, SubVT, MaxLHS, Op1, DAG, SDLoc(N));
  }

  if (Op1.getOpcode() == ISD::UMIN && Op1.hasOneUse()) {
    SDValue MinLHS = Op1.getOperand(0);
    SDValue MinRHS = Op1.getOperand(1);
    if (MinLHS == Op0)
      return getTruncatedUSUBSAT(DstVT, SubVT, Op0, MinRHS, DAG, SDLoc(N));
    if (MinRHS == Op0)
      return getTruncatedUSUBSAT(DstVT, SubVT, Op0, MinLHS, DAG, SDLoc(N));
  }

  // sub(a,trunc(umin(zext(a),b))) -> usubsat(a,trunc(umin(b,SatLimit)))
  // Here the umin is computed in the wider zext type; getTruncatedUSUBSAT
  // handles clamping and narrowing back to DstVT.
  if (Op1.getOpcode() == ISD::TRUNCATE &&
      Op1.getOperand(0).getOpcode() == ISD::UMIN &&
      Op1.getOperand(0).hasOneUse()) {
    SDValue MinLHS = Op1.getOperand(0).getOperand(0);
    SDValue MinRHS = Op1.getOperand(0).getOperand(1);
    if (MinLHS.getOpcode() == ISD::ZERO_EXTEND && MinLHS.getOperand(0) == Op0)
      return getTruncatedUSUBSAT(DstVT, MinLHS.getValueType(), MinLHS, MinRHS,
                                 DAG, SDLoc(N));
    if (MinRHS.getOpcode() == ISD::ZERO_EXTEND && MinRHS.getOperand(0) == Op0)
      return getTruncatedUSUBSAT(DstVT, MinLHS.getValueType(), MinRHS, MinLHS,
                                 DAG, SDLoc(N));
  }

  return SDValue();
}
3402 
3403 // Since it may not be valid to emit a fold to zero for vector initializers
3404 // check if we can before folding.
tryFoldToZero(const SDLoc & DL,const TargetLowering & TLI,EVT VT,SelectionDAG & DAG,bool LegalOperations)3405 static SDValue tryFoldToZero(const SDLoc &DL, const TargetLowering &TLI, EVT VT,
3406                              SelectionDAG &DAG, bool LegalOperations) {
3407   if (!VT.isVector())
3408     return DAG.getConstant(0, DL, VT);
3409   if (!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT))
3410     return DAG.getConstant(0, DL, VT);
3411   return SDValue();
3412 }
3413 
visitSUB(SDNode * N)3414 SDValue DAGCombiner::visitSUB(SDNode *N) {
3415   SDValue N0 = N->getOperand(0);
3416   SDValue N1 = N->getOperand(1);
3417   EVT VT = N0.getValueType();
3418   SDLoc DL(N);
3419 
3420   auto PeekThroughFreeze = [](SDValue N) {
3421     if (N->getOpcode() == ISD::FREEZE && N.hasOneUse())
3422       return N->getOperand(0);
3423     return N;
3424   };
3425 
3426   // fold (sub x, x) -> 0
3427   // FIXME: Refactor this and xor and other similar operations together.
3428   if (PeekThroughFreeze(N0) == PeekThroughFreeze(N1))
3429     return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
3430 
3431   // fold (sub c1, c2) -> c3
3432   if (SDValue C = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0, N1}))
3433     return C;
3434 
3435   // fold vector ops
3436   if (VT.isVector()) {
3437     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
3438       return FoldedVOp;
3439 
3440     // fold (sub x, 0) -> x, vector edition
3441     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
3442       return N0;
3443   }
3444 
3445   if (SDValue NewSel = foldBinOpIntoSelect(N))
3446     return NewSel;
3447 
3448   ConstantSDNode *N1C = getAsNonOpaqueConstant(N1);
3449 
3450   // fold (sub x, c) -> (add x, -c)
3451   if (N1C) {
3452     return DAG.getNode(ISD::ADD, DL, VT, N0,
3453                        DAG.getConstant(-N1C->getAPIntValue(), DL, VT));
3454   }
3455 
3456   if (isNullOrNullSplat(N0)) {
3457     unsigned BitWidth = VT.getScalarSizeInBits();
3458     // Right-shifting everything out but the sign bit followed by negation is
3459     // the same as flipping arithmetic/logical shift type without the negation:
3460     // -(X >>u 31) -> (X >>s 31)
3461     // -(X >>s 31) -> (X >>u 31)
3462     if (N1->getOpcode() == ISD::SRA || N1->getOpcode() == ISD::SRL) {
3463       ConstantSDNode *ShiftAmt = isConstOrConstSplat(N1.getOperand(1));
3464       if (ShiftAmt && ShiftAmt->getAPIntValue() == (BitWidth - 1)) {
3465         auto NewSh = N1->getOpcode() == ISD::SRA ? ISD::SRL : ISD::SRA;
3466         if (!LegalOperations || TLI.isOperationLegal(NewSh, VT))
3467           return DAG.getNode(NewSh, DL, VT, N1.getOperand(0), N1.getOperand(1));
3468       }
3469     }
3470 
3471     // 0 - X --> 0 if the sub is NUW.
3472     if (N->getFlags().hasNoUnsignedWrap())
3473       return N0;
3474 
3475     if (DAG.MaskedValueIsZero(N1, ~APInt::getSignMask(BitWidth))) {
3476       // N1 is either 0 or the minimum signed value. If the sub is NSW, then
3477       // N1 must be 0 because negating the minimum signed value is undefined.
3478       if (N->getFlags().hasNoSignedWrap())
3479         return N0;
3480 
3481       // 0 - X --> X if X is 0 or the minimum signed value.
3482       return N1;
3483     }
3484 
3485     // Convert 0 - abs(x).
3486     if (N1.getOpcode() == ISD::ABS && N1.hasOneUse() &&
3487         !TLI.isOperationLegalOrCustom(ISD::ABS, VT))
3488       if (SDValue Result = TLI.expandABS(N1.getNode(), DAG, true))
3489         return Result;
3490 
3491     // Fold neg(splat(neg(x)) -> splat(x)
3492     if (VT.isVector()) {
3493       SDValue N1S = DAG.getSplatValue(N1, true);
3494       if (N1S && N1S.getOpcode() == ISD::SUB &&
3495           isNullConstant(N1S.getOperand(0))) {
3496         if (VT.isScalableVector())
3497           return DAG.getSplatVector(VT, DL, N1S.getOperand(1));
3498         return DAG.getSplatBuildVector(VT, DL, N1S.getOperand(1));
3499       }
3500     }
3501   }
3502 
3503   // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1)
3504   if (isAllOnesOrAllOnesSplat(N0))
3505     return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
3506 
3507   // fold (A - (0-B)) -> A+B
3508   if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0)))
3509     return DAG.getNode(ISD::ADD, DL, VT, N0, N1.getOperand(1));
3510 
3511   // fold A-(A-B) -> B
3512   if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(0))
3513     return N1.getOperand(1);
3514 
3515   // fold (A+B)-A -> B
3516   if (N0.getOpcode() == ISD::ADD && N0.getOperand(0) == N1)
3517     return N0.getOperand(1);
3518 
3519   // fold (A+B)-B -> A
3520   if (N0.getOpcode() == ISD::ADD && N0.getOperand(1) == N1)
3521     return N0.getOperand(0);
3522 
3523   // fold (A+C1)-C2 -> A+(C1-C2)
3524   if (N0.getOpcode() == ISD::ADD) {
3525     SDValue N01 = N0.getOperand(1);
3526     if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N01, N1}))
3527       return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), NewC);
3528   }
3529 
3530   // fold C2-(A+C1) -> (C2-C1)-A
3531   if (N1.getOpcode() == ISD::ADD) {
3532     SDValue N11 = N1.getOperand(1);
3533     if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0, N11}))
3534       return DAG.getNode(ISD::SUB, DL, VT, NewC, N1.getOperand(0));
3535   }
3536 
3537   // fold (A-C1)-C2 -> A-(C1+C2)
3538   if (N0.getOpcode() == ISD::SUB) {
3539     SDValue N01 = N0.getOperand(1);
3540     if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N01, N1}))
3541       return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), NewC);
3542   }
3543 
3544   // fold (c1-A)-c2 -> (c1-c2)-A
3545   if (N0.getOpcode() == ISD::SUB) {
3546     SDValue N00 = N0.getOperand(0);
3547     if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N00, N1}))
3548       return DAG.getNode(ISD::SUB, DL, VT, NewC, N0.getOperand(1));
3549   }
3550 
3551   // fold ((A+(B+or-C))-B) -> A+or-C
3552   if (N0.getOpcode() == ISD::ADD &&
3553       (N0.getOperand(1).getOpcode() == ISD::SUB ||
3554        N0.getOperand(1).getOpcode() == ISD::ADD) &&
3555       N0.getOperand(1).getOperand(0) == N1)
3556     return DAG.getNode(N0.getOperand(1).getOpcode(), DL, VT, N0.getOperand(0),
3557                        N0.getOperand(1).getOperand(1));
3558 
3559   // fold ((A+(C+B))-B) -> A+C
3560   if (N0.getOpcode() == ISD::ADD && N0.getOperand(1).getOpcode() == ISD::ADD &&
3561       N0.getOperand(1).getOperand(1) == N1)
3562     return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0),
3563                        N0.getOperand(1).getOperand(0));
3564 
3565   // fold ((A-(B-C))-C) -> A-B
3566   if (N0.getOpcode() == ISD::SUB && N0.getOperand(1).getOpcode() == ISD::SUB &&
3567       N0.getOperand(1).getOperand(1) == N1)
3568     return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0),
3569                        N0.getOperand(1).getOperand(0));
3570 
3571   // fold (A-(B-C)) -> A+(C-B)
3572   if (N1.getOpcode() == ISD::SUB && N1.hasOneUse())
3573     return DAG.getNode(ISD::ADD, DL, VT, N0,
3574                        DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(1),
3575                                    N1.getOperand(0)));
3576 
3577   // A - (A & B)  ->  A & (~B)
3578   if (N1.getOpcode() == ISD::AND) {
3579     SDValue A = N1.getOperand(0);
3580     SDValue B = N1.getOperand(1);
3581     if (A != N0)
3582       std::swap(A, B);
3583     if (A == N0 &&
3584         (N1.hasOneUse() || isConstantOrConstantVector(B, /*NoOpaques=*/true))) {
3585       SDValue InvB =
3586           DAG.getNode(ISD::XOR, DL, VT, B, DAG.getAllOnesConstant(DL, VT));
3587       return DAG.getNode(ISD::AND, DL, VT, A, InvB);
3588     }
3589   }
3590 
3591   // fold (X - (-Y * Z)) -> (X + (Y * Z))
3592   if (N1.getOpcode() == ISD::MUL && N1.hasOneUse()) {
3593     if (N1.getOperand(0).getOpcode() == ISD::SUB &&
3594         isNullOrNullSplat(N1.getOperand(0).getOperand(0))) {
3595       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT,
3596                                 N1.getOperand(0).getOperand(1),
3597                                 N1.getOperand(1));
3598       return DAG.getNode(ISD::ADD, DL, VT, N0, Mul);
3599     }
3600     if (N1.getOperand(1).getOpcode() == ISD::SUB &&
3601         isNullOrNullSplat(N1.getOperand(1).getOperand(0))) {
3602       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT,
3603                                 N1.getOperand(0),
3604                                 N1.getOperand(1).getOperand(1));
3605       return DAG.getNode(ISD::ADD, DL, VT, N0, Mul);
3606     }
3607   }
3608 
3609   // If either operand of a sub is undef, the result is undef
3610   if (N0.isUndef())
3611     return N0;
3612   if (N1.isUndef())
3613     return N1;
3614 
3615   if (SDValue V = foldAddSubBoolOfMaskedVal(N, DAG))
3616     return V;
3617 
3618   if (SDValue V = foldAddSubOfSignBit(N, DAG))
3619     return V;
3620 
3621   if (SDValue V = foldAddSubMasked1(false, N0, N1, DAG, SDLoc(N)))
3622     return V;
3623 
3624   if (SDValue V = foldSubToUSubSat(VT, N))
3625     return V;
3626 
3627   // (x - y) - 1  ->  add (xor y, -1), x
3628   if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB && isOneOrOneSplat(N1)) {
3629     SDValue Xor = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(1),
3630                               DAG.getAllOnesConstant(DL, VT));
3631     return DAG.getNode(ISD::ADD, DL, VT, Xor, N0.getOperand(0));
3632   }
3633 
3634   // Look for:
3635   //   sub y, (xor x, -1)
3636   // And if the target does not like this form then turn into:
3637   //   add (add x, y), 1
3638   if (TLI.preferIncOfAddToSubOfNot(VT) && N1.hasOneUse() && isBitwiseNot(N1)) {
3639     SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, N1.getOperand(0));
3640     return DAG.getNode(ISD::ADD, DL, VT, Add, DAG.getConstant(1, DL, VT));
3641   }
3642 
3643   // Hoist one-use addition by non-opaque constant:
3644   //   (x + C) - y  ->  (x - y) + C
3645   if (N0.hasOneUse() && N0.getOpcode() == ISD::ADD &&
3646       isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
3647     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), N1);
3648     return DAG.getNode(ISD::ADD, DL, VT, Sub, N0.getOperand(1));
3649   }
3650   // y - (x + C)  ->  (y - x) - C
3651   if (N1.hasOneUse() && N1.getOpcode() == ISD::ADD &&
3652       isConstantOrConstantVector(N1.getOperand(1), /*NoOpaques=*/true)) {
3653     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, N1.getOperand(0));
3654     return DAG.getNode(ISD::SUB, DL, VT, Sub, N1.getOperand(1));
3655   }
3656   // (x - C) - y  ->  (x - y) - C
3657   // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
3658   if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB &&
3659       isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
3660     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), N1);
3661     return DAG.getNode(ISD::SUB, DL, VT, Sub, N0.getOperand(1));
3662   }
3663   // (C - x) - y  ->  C - (x + y)
3664   if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB &&
3665       isConstantOrConstantVector(N0.getOperand(0), /*NoOpaques=*/true)) {
3666     SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1), N1);
3667     return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), Add);
3668   }
3669 
3670   // If the target's bool is represented as 0/-1, prefer to make this 'add 0/-1'
3671   // rather than 'sub 0/1' (the sext should get folded).
3672   // sub X, (zext i1 Y) --> add X, (sext i1 Y)
3673   if (N1.getOpcode() == ISD::ZERO_EXTEND &&
3674       N1.getOperand(0).getScalarValueSizeInBits() == 1 &&
3675       TLI.getBooleanContents(VT) ==
3676           TargetLowering::ZeroOrNegativeOneBooleanContent) {
3677     SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N1.getOperand(0));
3678     return DAG.getNode(ISD::ADD, DL, VT, N0, SExt);
3679   }
3680 
3681   // fold Y = sra (X, size(X)-1); sub (xor (X, Y), Y) -> (abs X)
3682   if (TLI.isOperationLegalOrCustom(ISD::ABS, VT)) {
3683     if (N0.getOpcode() == ISD::XOR && N1.getOpcode() == ISD::SRA) {
3684       SDValue X0 = N0.getOperand(0), X1 = N0.getOperand(1);
3685       SDValue S0 = N1.getOperand(0);
3686       if ((X0 == S0 && X1 == N1) || (X0 == N1 && X1 == S0))
3687         if (ConstantSDNode *C = isConstOrConstSplat(N1.getOperand(1)))
3688           if (C->getAPIntValue() == (VT.getScalarSizeInBits() - 1))
3689             return DAG.getNode(ISD::ABS, SDLoc(N), VT, S0);
3690     }
3691   }
3692 
3693   // If the relocation model supports it, consider symbol offsets.
3694   if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(N0))
3695     if (!LegalOperations && TLI.isOffsetFoldingLegal(GA)) {
3696       // fold (sub Sym, c) -> Sym-c
3697       if (N1C && GA->getOpcode() == ISD::GlobalAddress)
3698         return DAG.getGlobalAddress(GA->getGlobal(), SDLoc(N1C), VT,
3699                                     GA->getOffset() -
3700                                         (uint64_t)N1C->getSExtValue());
3701       // fold (sub Sym+c1, Sym+c2) -> c1-c2
3702       if (GlobalAddressSDNode *GB = dyn_cast<GlobalAddressSDNode>(N1))
3703         if (GA->getGlobal() == GB->getGlobal())
3704           return DAG.getConstant((uint64_t)GA->getOffset() - GB->getOffset(),
3705                                  DL, VT);
3706     }
3707 
3708   // sub X, (sextinreg Y i1) -> add X, (and Y 1)
3709   if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
3710     VTSDNode *TN = cast<VTSDNode>(N1.getOperand(1));
3711     if (TN->getVT() == MVT::i1) {
3712       SDValue ZExt = DAG.getNode(ISD::AND, DL, VT, N1.getOperand(0),
3713                                  DAG.getConstant(1, DL, VT));
3714       return DAG.getNode(ISD::ADD, DL, VT, N0, ZExt);
3715     }
3716   }
3717 
3718   // canonicalize (sub X, (vscale * C)) to (add X, (vscale * -C))
3719   if (N1.getOpcode() == ISD::VSCALE) {
3720     const APInt &IntVal = N1.getConstantOperandAPInt(0);
3721     return DAG.getNode(ISD::ADD, DL, VT, N0, DAG.getVScale(DL, VT, -IntVal));
3722   }
3723 
3724   // canonicalize (sub X, step_vector(C)) to (add X, step_vector(-C))
3725   if (N1.getOpcode() == ISD::STEP_VECTOR && N1.hasOneUse()) {
3726     APInt NewStep = -N1.getConstantOperandAPInt(0);
3727     return DAG.getNode(ISD::ADD, DL, VT, N0,
3728                        DAG.getStepVector(DL, VT, NewStep));
3729   }
3730 
3731   // Prefer an add for more folding potential and possibly better codegen:
3732   // sub N0, (lshr N10, width-1) --> add N0, (ashr N10, width-1)
3733   if (!LegalOperations && N1.getOpcode() == ISD::SRL && N1.hasOneUse()) {
3734     SDValue ShAmt = N1.getOperand(1);
3735     ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt);
3736     if (ShAmtC &&
3737         ShAmtC->getAPIntValue() == (N1.getScalarValueSizeInBits() - 1)) {
3738       SDValue SRA = DAG.getNode(ISD::SRA, DL, VT, N1.getOperand(0), ShAmt);
3739       return DAG.getNode(ISD::ADD, DL, VT, N0, SRA);
3740     }
3741   }
3742 
3743   // As with the previous fold, prefer add for more folding potential.
3744   // Subtracting SMIN/0 is the same as adding SMIN/0:
3745   // N0 - (X << BW-1) --> N0 + (X << BW-1)
3746   if (N1.getOpcode() == ISD::SHL) {
3747     ConstantSDNode *ShlC = isConstOrConstSplat(N1.getOperand(1));
3748     if (ShlC && ShlC->getAPIntValue() == VT.getScalarSizeInBits() - 1)
3749       return DAG.getNode(ISD::ADD, DL, VT, N1, N0);
3750   }
3751 
3752   if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT)) {
3753     // (sub Carry, X)  ->  (addcarry (sub 0, X), 0, Carry)
3754     if (SDValue Carry = getAsCarry(TLI, N0)) {
3755       SDValue X = N1;
3756       SDValue Zero = DAG.getConstant(0, DL, VT);
3757       SDValue NegX = DAG.getNode(ISD::SUB, DL, VT, Zero, X);
3758       return DAG.getNode(ISD::ADDCARRY, DL,
3759                          DAG.getVTList(VT, Carry.getValueType()), NegX, Zero,
3760                          Carry);
3761     }
3762   }
3763 
3764   // If there's no chance of borrowing from adjacent bits, then sub is xor:
3765   // sub C0, X --> xor X, C0
3766   if (ConstantSDNode *C0 = isConstOrConstSplat(N0)) {
3767     if (!C0->isOpaque()) {
3768       const APInt &C0Val = C0->getAPIntValue();
3769       const APInt &MaybeOnes = ~DAG.computeKnownBits(N1).Zero;
3770       if ((C0Val - MaybeOnes) == (C0Val ^ MaybeOnes))
3771         return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
3772     }
3773   }
3774 
3775   return SDValue();
3776 }
3777 
visitSUBSAT(SDNode * N)3778 SDValue DAGCombiner::visitSUBSAT(SDNode *N) {
3779   SDValue N0 = N->getOperand(0);
3780   SDValue N1 = N->getOperand(1);
3781   EVT VT = N0.getValueType();
3782   SDLoc DL(N);
3783 
3784   // fold (sub_sat x, undef) -> 0
3785   if (N0.isUndef() || N1.isUndef())
3786     return DAG.getConstant(0, DL, VT);
3787 
3788   // fold (sub_sat x, x) -> 0
3789   if (N0 == N1)
3790     return DAG.getConstant(0, DL, VT);
3791 
3792   // fold (sub_sat c1, c2) -> c3
3793   if (SDValue C = DAG.FoldConstantArithmetic(N->getOpcode(), DL, VT, {N0, N1}))
3794     return C;
3795 
3796   // fold vector ops
3797   if (VT.isVector()) {
3798     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
3799       return FoldedVOp;
3800 
3801     // fold (sub_sat x, 0) -> x, vector edition
3802     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
3803       return N0;
3804   }
3805 
3806   // fold (sub_sat x, 0) -> x
3807   if (isNullConstant(N1))
3808     return N0;
3809 
3810   return SDValue();
3811 }
3812 
visitSUBC(SDNode * N)3813 SDValue DAGCombiner::visitSUBC(SDNode *N) {
3814   SDValue N0 = N->getOperand(0);
3815   SDValue N1 = N->getOperand(1);
3816   EVT VT = N0.getValueType();
3817   SDLoc DL(N);
3818 
3819   // If the flag result is dead, turn this into an SUB.
3820   if (!N->hasAnyUseOfValue(1))
3821     return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
3822                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3823 
3824   // fold (subc x, x) -> 0 + no borrow
3825   if (N0 == N1)
3826     return CombineTo(N, DAG.getConstant(0, DL, VT),
3827                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3828 
3829   // fold (subc x, 0) -> x + no borrow
3830   if (isNullConstant(N1))
3831     return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3832 
3833   // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1) + no borrow
3834   if (isAllOnesConstant(N0))
3835     return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0),
3836                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3837 
3838   return SDValue();
3839 }
3840 
visitSUBO(SDNode * N)3841 SDValue DAGCombiner::visitSUBO(SDNode *N) {
3842   SDValue N0 = N->getOperand(0);
3843   SDValue N1 = N->getOperand(1);
3844   EVT VT = N0.getValueType();
3845   bool IsSigned = (ISD::SSUBO == N->getOpcode());
3846 
3847   EVT CarryVT = N->getValueType(1);
3848   SDLoc DL(N);
3849 
3850   // If the flag result is dead, turn this into an SUB.
3851   if (!N->hasAnyUseOfValue(1))
3852     return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
3853                      DAG.getUNDEF(CarryVT));
3854 
3855   // fold (subo x, x) -> 0 + no borrow
3856   if (N0 == N1)
3857     return CombineTo(N, DAG.getConstant(0, DL, VT),
3858                      DAG.getConstant(0, DL, CarryVT));
3859 
3860   ConstantSDNode *N1C = getAsNonOpaqueConstant(N1);
3861 
3862   // fold (subox, c) -> (addo x, -c)
3863   if (IsSigned && N1C && !N1C->getAPIntValue().isMinSignedValue()) {
3864     return DAG.getNode(ISD::SADDO, DL, N->getVTList(), N0,
3865                        DAG.getConstant(-N1C->getAPIntValue(), DL, VT));
3866   }
3867 
3868   // fold (subo x, 0) -> x + no borrow
3869   if (isNullOrNullSplat(N1))
3870     return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT));
3871 
3872   // Canonicalize (usubo -1, x) -> ~x, i.e. (xor x, -1) + no borrow
3873   if (!IsSigned && isAllOnesOrAllOnesSplat(N0))
3874     return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0),
3875                      DAG.getConstant(0, DL, CarryVT));
3876 
3877   return SDValue();
3878 }
3879 
visitSUBE(SDNode * N)3880 SDValue DAGCombiner::visitSUBE(SDNode *N) {
3881   SDValue N0 = N->getOperand(0);
3882   SDValue N1 = N->getOperand(1);
3883   SDValue CarryIn = N->getOperand(2);
3884 
3885   // fold (sube x, y, false) -> (subc x, y)
3886   if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
3887     return DAG.getNode(ISD::SUBC, SDLoc(N), N->getVTList(), N0, N1);
3888 
3889   return SDValue();
3890 }
3891 
visitSUBCARRY(SDNode * N)3892 SDValue DAGCombiner::visitSUBCARRY(SDNode *N) {
3893   SDValue N0 = N->getOperand(0);
3894   SDValue N1 = N->getOperand(1);
3895   SDValue CarryIn = N->getOperand(2);
3896 
3897   // fold (subcarry x, y, false) -> (usubo x, y)
3898   if (isNullConstant(CarryIn)) {
3899     if (!LegalOperations ||
3900         TLI.isOperationLegalOrCustom(ISD::USUBO, N->getValueType(0)))
3901       return DAG.getNode(ISD::USUBO, SDLoc(N), N->getVTList(), N0, N1);
3902   }
3903 
3904   return SDValue();
3905 }
3906 
visitSSUBO_CARRY(SDNode * N)3907 SDValue DAGCombiner::visitSSUBO_CARRY(SDNode *N) {
3908   SDValue N0 = N->getOperand(0);
3909   SDValue N1 = N->getOperand(1);
3910   SDValue CarryIn = N->getOperand(2);
3911 
3912   // fold (ssubo_carry x, y, false) -> (ssubo x, y)
3913   if (isNullConstant(CarryIn)) {
3914     if (!LegalOperations ||
3915         TLI.isOperationLegalOrCustom(ISD::SSUBO, N->getValueType(0)))
3916       return DAG.getNode(ISD::SSUBO, SDLoc(N), N->getVTList(), N0, N1);
3917   }
3918 
3919   return SDValue();
3920 }
3921 
3922 // Notice that "mulfix" can be any of SMULFIX, SMULFIXSAT, UMULFIX and
3923 // UMULFIXSAT here.
visitMULFIX(SDNode * N)3924 SDValue DAGCombiner::visitMULFIX(SDNode *N) {
3925   SDValue N0 = N->getOperand(0);
3926   SDValue N1 = N->getOperand(1);
3927   SDValue Scale = N->getOperand(2);
3928   EVT VT = N0.getValueType();
3929 
3930   // fold (mulfix x, undef, scale) -> 0
3931   if (N0.isUndef() || N1.isUndef())
3932     return DAG.getConstant(0, SDLoc(N), VT);
3933 
3934   // Canonicalize constant to RHS (vector doesn't have to splat)
3935   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
3936      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
3937     return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0, Scale);
3938 
3939   // fold (mulfix x, 0, scale) -> 0
3940   if (isNullConstant(N1))
3941     return DAG.getConstant(0, SDLoc(N), VT);
3942 
3943   return SDValue();
3944 }
3945 
/// Combine a MUL node: constant folding, canonicalization, strength reduction
/// of multiplies-by-constant into shifts/adds/subs, vscale/step_vector folds,
/// and reassociation. The order of the folds below is significant.
SDValue DAGCombiner::visitMUL(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  SDLoc DL(N);

  // fold (mul x, undef) -> 0
  if (N0.isUndef() || N1.isUndef())
    return DAG.getConstant(0, DL, VT);

  // fold (mul c1, c2) -> c1*c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::MUL, DL, VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS (vector doesn't have to splat)
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(ISD::MUL, DL, VT, N1, N0);

  // Set when N1 is a constant (scalar, or a splat for vectors); ConstValue1
  // then holds the (element-width) constant value.
  bool N1IsConst = false;
  bool N1IsOpaqueConst = false;
  APInt ConstValue1;

  // fold vector ops
  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

    N1IsConst = ISD::isConstantSplatVector(N1.getNode(), ConstValue1);
    assert((!N1IsConst ||
            ConstValue1.getBitWidth() == VT.getScalarSizeInBits()) &&
           "Splat APInt should be element width");
  } else {
    N1IsConst = isa<ConstantSDNode>(N1);
    if (N1IsConst) {
      ConstValue1 = cast<ConstantSDNode>(N1)->getAPIntValue();
      N1IsOpaqueConst = cast<ConstantSDNode>(N1)->isOpaque();
    }
  }

  // fold (mul x, 0) -> 0
  if (N1IsConst && ConstValue1.isZero())
    return N1;

  // fold (mul x, 1) -> x
  if (N1IsConst && ConstValue1.isOne())
    return N0;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // fold (mul x, -1) -> 0-x
  if (N1IsConst && ConstValue1.isAllOnes())
    return DAG.getNode(ISD::SUB, DL, VT,
                       DAG.getConstant(0, DL, VT), N0);

  // fold (mul x, (1 << c)) -> x << c
  if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
      DAG.isKnownToBeAPowerOfTwo(N1) &&
      (!VT.isVector() || Level <= AfterLegalizeVectorOps)) {
    SDValue LogBase2 = BuildLogBase2(N1, DL);
    EVT ShiftVT = getShiftAmountTy(N0.getValueType());
    SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
    return DAG.getNode(ISD::SHL, DL, VT, N0, Trunc);
  }

  // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
  if (N1IsConst && !N1IsOpaqueConst && ConstValue1.isNegatedPowerOf2()) {
    unsigned Log2Val = (-ConstValue1).logBase2();
    // FIXME: If the input is something that is easily negated (e.g. a
    // single-use add), we should put the negate there.
    return DAG.getNode(ISD::SUB, DL, VT,
                       DAG.getConstant(0, DL, VT),
                       DAG.getNode(ISD::SHL, DL, VT, N0,
                            DAG.getConstant(Log2Val, DL,
                                      getShiftAmountTy(N0.getValueType()))));
  }

  // Try to transform:
  // (1) multiply-by-(power-of-2 +/- 1) into shift and add/sub.
  // mul x, (2^N + 1) --> add (shl x, N), x
  // mul x, (2^N - 1) --> sub (shl x, N), x
  // Examples: x * 33 --> (x << 5) + x
  //           x * 15 --> (x << 4) - x
  //           x * -33 --> -((x << 5) + x)
  //           x * -15 --> -((x << 4) - x) ; this reduces --> x - (x << 4)
  // (2) multiply-by-(power-of-2 +/- power-of-2) into shifts and add/sub.
  // mul x, (2^N + 2^M) --> (add (shl x, N), (shl x, M))
  // mul x, (2^N - 2^M) --> (sub (shl x, N), (shl x, M))
  // Examples: x * 0x8800 --> (x << 15) + (x << 11)
  //           x * 0xf800 --> (x << 16) - (x << 11)
  //           x * -0x8800 --> -((x << 15) + (x << 11))
  //           x * -0xf800 --> -((x << 16) - (x << 11)) ; (x << 11) - (x << 16)
  if (N1IsConst && TLI.decomposeMulByConstant(*DAG.getContext(), VT, N1)) {
    // TODO: We could handle more general decomposition of any constant by
    //       having the target set a limit on number of ops and making a
    //       callback to determine that sequence (similar to sqrt expansion).
    unsigned MathOp = ISD::DELETED_NODE;
    APInt MulC = ConstValue1.abs();
    // The constant `2` should be treated as (2^0 + 1).
    unsigned TZeros = MulC == 2 ? 0 : MulC.countTrailingZeros();
    MulC.lshrInPlace(TZeros);
    if ((MulC - 1).isPowerOf2())
      MathOp = ISD::ADD;
    else if ((MulC + 1).isPowerOf2())
      MathOp = ISD::SUB;

    if (MathOp != ISD::DELETED_NODE) {
      unsigned ShAmt =
          MathOp == ISD::ADD ? (MulC - 1).logBase2() : (MulC + 1).logBase2();
      ShAmt += TZeros;
      assert(ShAmt < VT.getScalarSizeInBits() &&
             "multiply-by-constant generated out of bounds shift");
      SDValue Shl =
          DAG.getNode(ISD::SHL, DL, VT, N0, DAG.getConstant(ShAmt, DL, VT));
      SDValue R =
          TZeros ? DAG.getNode(MathOp, DL, VT, Shl,
                               DAG.getNode(ISD::SHL, DL, VT, N0,
                                           DAG.getConstant(TZeros, DL, VT)))
                 : DAG.getNode(MathOp, DL, VT, Shl, N0);
      // Negative multiplier: negate the shift/add result.
      if (ConstValue1.isNegative())
        R = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), R);
      return R;
    }
  }

  // (mul (shl X, c1), c2) -> (mul X, c2 << c1)
  if (N0.getOpcode() == ISD::SHL) {
    SDValue N01 = N0.getOperand(1);
    if (SDValue C3 = DAG.FoldConstantArithmetic(ISD::SHL, DL, VT, {N1, N01}))
      return DAG.getNode(ISD::MUL, DL, VT, N0.getOperand(0), C3);
  }

  // Change (mul (shl X, C), Y) -> (shl (mul X, Y), C) when the shift has one
  // use.
  {
    SDValue Sh, Y;

    // Check for both (mul (shl X, C), Y)  and  (mul Y, (shl X, C)).
    if (N0.getOpcode() == ISD::SHL &&
        isConstantOrConstantVector(N0.getOperand(1)) && N0->hasOneUse()) {
      Sh = N0; Y = N1;
    } else if (N1.getOpcode() == ISD::SHL &&
               isConstantOrConstantVector(N1.getOperand(1)) &&
               N1->hasOneUse()) {
      Sh = N1; Y = N0;
    }

    if (Sh.getNode()) {
      SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, Sh.getOperand(0), Y);
      return DAG.getNode(ISD::SHL, DL, VT, Mul, Sh.getOperand(1));
    }
  }

  // fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2)
  if (DAG.isConstantIntBuildVectorOrConstantInt(N1) &&
      N0.getOpcode() == ISD::ADD &&
      DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1)) &&
      isMulAddWithConstProfitable(N, N0, N1))
    return DAG.getNode(
        ISD::ADD, DL, VT,
        DAG.getNode(ISD::MUL, SDLoc(N0), VT, N0.getOperand(0), N1),
        DAG.getNode(ISD::MUL, SDLoc(N1), VT, N0.getOperand(1), N1));

  // Fold (mul (vscale * C0), C1) to (vscale * (C0 * C1)).
  if (N0.getOpcode() == ISD::VSCALE)
    if (ConstantSDNode *NC1 = isConstOrConstSplat(N1)) {
      const APInt &C0 = N0.getConstantOperandAPInt(0);
      const APInt &C1 = NC1->getAPIntValue();
      return DAG.getVScale(DL, VT, C0 * C1);
    }

  // Fold (mul step_vector(C0), C1) to (step_vector(C0 * C1)).
  APInt MulVal;
  if (N0.getOpcode() == ISD::STEP_VECTOR)
    if (ISD::isConstantSplatVector(N1.getNode(), MulVal)) {
      const APInt &C0 = N0.getConstantOperandAPInt(0);
      APInt NewStep = C0 * MulVal;
      return DAG.getStepVector(DL, VT, NewStep);
    }

  // Fold ((mul x, 0/undef) -> 0,
  //       (mul x, 1) -> x) -> x)
  // -> and(x, mask)
  // We can replace vectors with '0' and '1' factors with a clearing mask.
  if (VT.isFixedLengthVector()) {
    unsigned NumElts = VT.getVectorNumElements();
    SmallBitVector ClearMask;
    ClearMask.reserve(NumElts);
    // Records one bit per element: true = element multiplied by 0/undef
    // (clear it), false = element must be multiplied by exactly 1.
    auto IsClearMask = [&ClearMask](ConstantSDNode *V) {
      if (!V || V->isZero()) {
        ClearMask.push_back(true);
        return true;
      }
      ClearMask.push_back(false);
      return V->isOne();
    };
    if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::AND, VT)) &&
        ISD::matchUnaryPredicate(N1, IsClearMask, /*AllowUndefs*/ true)) {
      assert(N1.getOpcode() == ISD::BUILD_VECTOR && "Unknown constant vector");
      EVT LegalSVT = N1.getOperand(0).getValueType();
      SDValue Zero = DAG.getConstant(0, DL, LegalSVT);
      SDValue AllOnes = DAG.getAllOnesConstant(DL, LegalSVT);
      SmallVector<SDValue, 16> Mask(NumElts, AllOnes);
      for (unsigned I = 0; I != NumElts; ++I)
        if (ClearMask[I])
          Mask[I] = Zero;
      return DAG.getNode(ISD::AND, DL, VT, N0, DAG.getBuildVector(VT, DL, Mask));
    }
  }

  // reassociate mul
  if (SDValue RMUL = reassociateOps(ISD::MUL, DL, N0, N1, N->getFlags()))
    return RMUL;

  // Simplify the operands using demanded-bits information.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  return SDValue();
}
4167 
4168 /// Return true if divmod libcall is available.
isDivRemLibcallAvailable(SDNode * Node,bool isSigned,const TargetLowering & TLI)4169 static bool isDivRemLibcallAvailable(SDNode *Node, bool isSigned,
4170                                      const TargetLowering &TLI) {
4171   RTLIB::Libcall LC;
4172   EVT NodeType = Node->getValueType(0);
4173   if (!NodeType.isSimple())
4174     return false;
4175   switch (NodeType.getSimpleVT().SimpleTy) {
4176   default: return false; // No libcall for vector types.
4177   case MVT::i8:   LC= isSigned ? RTLIB::SDIVREM_I8  : RTLIB::UDIVREM_I8;  break;
4178   case MVT::i16:  LC= isSigned ? RTLIB::SDIVREM_I16 : RTLIB::UDIVREM_I16; break;
4179   case MVT::i32:  LC= isSigned ? RTLIB::SDIVREM_I32 : RTLIB::UDIVREM_I32; break;
4180   case MVT::i64:  LC= isSigned ? RTLIB::SDIVREM_I64 : RTLIB::UDIVREM_I64; break;
4181   case MVT::i128: LC= isSigned ? RTLIB::SDIVREM_I128:RTLIB::UDIVREM_I128; break;
4182   }
4183 
4184   return TLI.getLibcallName(LC) != nullptr;
4185 }
4186 
/// Issue divrem if both quotient and remainder are needed.
SDValue DAGCombiner::useDivRem(SDNode *Node) {
  if (Node->use_empty())
    return SDValue(); // This is a dead node, leave it alone.

  unsigned Opcode = Node->getOpcode();
  bool isSigned = (Opcode == ISD::SDIV) || (Opcode == ISD::SREM);
  unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM;

  // DivMod lib calls can still work on non-legal types if using lib-calls.
  EVT VT = Node->getValueType(0);
  if (VT.isVector() || !VT.isInteger())
    return SDValue();

  if (!TLI.isTypeLegal(VT) && !TLI.isOperationCustom(DivRemOpc, VT))
    return SDValue();

  // If DIVREM is going to get expanded into a libcall,
  // but there is no libcall available, then don't combine.
  if (!TLI.isOperationLegalOrCustom(DivRemOpc, VT) &&
      !isDivRemLibcallAvailable(Node, isSigned, TLI))
    return SDValue();

  // If div is legal, it's better to do the normal expansion
  // OtherOpcode is the opcode that pairs with this one to form DIVREM
  // (REM when this node is a DIV, and vice versa).
  unsigned OtherOpcode = 0;
  if ((Opcode == ISD::SDIV) || (Opcode == ISD::UDIV)) {
    OtherOpcode = isSigned ? ISD::SREM : ISD::UREM;
    if (TLI.isOperationLegalOrCustom(Opcode, VT))
      return SDValue();
  } else {
    OtherOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
    if (TLI.isOperationLegalOrCustom(OtherOpcode, VT))
      return SDValue();
  }

  // Walk all users of the dividend looking for div/rem/divrem nodes with the
  // same operands, and rewrite them all to share one DIVREM node.
  SDValue Op0 = Node->getOperand(0);
  SDValue Op1 = Node->getOperand(1);
  SDValue combined;
  for (SDNode *User : Op0->uses()) {
    if (User == Node || User->getOpcode() == ISD::DELETED_NODE ||
        User->use_empty())
      continue;
    // Convert the other matching node(s), too;
    // otherwise, the DIVREM may get target-legalized into something
    // target-specific that we won't be able to recognize.
    unsigned UserOpc = User->getOpcode();
    if ((UserOpc == Opcode || UserOpc == OtherOpcode || UserOpc == DivRemOpc) &&
        User->getOperand(0) == Op0 &&
        User->getOperand(1) == Op1) {
      if (!combined) {
        if (UserOpc == OtherOpcode) {
          // First match is the complementary op: create a fresh DIVREM.
          SDVTList VTs = DAG.getVTList(VT, VT);
          combined = DAG.getNode(DivRemOpc, SDLoc(Node), VTs, Op0, Op1);
        } else if (UserOpc == DivRemOpc) {
          // A DIVREM with these operands already exists: reuse it.
          combined = SDValue(User, 0);
        } else {
          assert(UserOpc == Opcode);
          continue;
        }
      }
      // Quotient users take result 0 of the DIVREM, remainder users result 1.
      if (UserOpc == ISD::SDIV || UserOpc == ISD::UDIV)
        CombineTo(User, combined);
      else if (UserOpc == ISD::SREM || UserOpc == ISD::UREM)
        CombineTo(User, combined.getValue(1));
    }
  }
  return combined;
}
4255 
/// Folds common to all of SDIV/UDIV/SREM/UREM: undef and zero operands,
/// identical operands, and trivial (1 / i1) divisors. The check order below
/// is significant: undef/zero cases must be handled before the others.
static SDValue simplifyDivRem(SDNode *N, SelectionDAG &DAG) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  unsigned Opc = N->getOpcode();
  bool IsDiv = (ISD::SDIV == Opc) || (ISD::UDIV == Opc);
  ConstantSDNode *N1C = isConstOrConstSplat(N1);

  // X / undef -> undef
  // X % undef -> undef
  // X / 0 -> undef
  // X % 0 -> undef
  // NOTE: This includes vectors where any divisor element is zero/undef.
  if (DAG.isUndef(Opc, {N0, N1}))
    return DAG.getUNDEF(VT);

  // undef / X -> 0
  // undef % X -> 0
  if (N0.isUndef())
    return DAG.getConstant(0, DL, VT);

  // 0 / X -> 0
  // 0 % X -> 0
  ConstantSDNode *N0C = isConstOrConstSplat(N0);
  if (N0C && N0C->isZero())
    return N0;

  // X / X -> 1
  // X % X -> 0
  if (N0 == N1)
    return DAG.getConstant(IsDiv ? 1 : 0, DL, VT);

  // X / 1 -> X
  // X % 1 -> 0
  // If this is a boolean op (single-bit element type), we can't have
  // division-by-zero or remainder-by-zero, so assume the divisor is 1.
  // TODO: Similarly, if we're zero-extending a boolean divisor, then assume
  // it's a 1.
  if ((N1C && N1C->isOne()) || (VT.getScalarType() == MVT::i1))
    return IsDiv ? N0 : DAG.getConstant(0, DL, VT);

  return SDValue();
}
4301 
/// Combine an SDIV node: constant folding, special divisors (-1, INT_MIN),
/// strength reduction via visitSDIVLike, and SDIV+SREM -> SDIVREM merging.
SDValue DAGCombiner::visitSDIV(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  EVT CCVT = getSetCCResultType(VT);
  SDLoc DL(N);

  // fold (sdiv c1, c2) -> c1/c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::SDIV, DL, VT, {N0, N1}))
    return C;

  // fold vector ops
  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

  // fold (sdiv X, -1) -> 0-X
  ConstantSDNode *N1C = isConstOrConstSplat(N1);
  if (N1C && N1C->isAllOnes())
    return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), N0);

  // fold (sdiv X, MIN_SIGNED) -> select(X == MIN_SIGNED, 1, 0)
  if (N1C && N1C->getAPIntValue().isMinSignedValue())
    return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
                         DAG.getConstant(1, DL, VT),
                         DAG.getConstant(0, DL, VT));

  if (SDValue V = simplifyDivRem(N, DAG))
    return V;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // If we know the sign bits of both operands are zero, strength reduce to a
  // udiv instead.  Handles (X&15) /s 4 -> X&15 >> 2
  if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
    return DAG.getNode(ISD::UDIV, DL, N1.getValueType(), N0, N1);

  if (SDValue V = visitSDIVLike(N0, N1, N)) {
    // If the corresponding remainder node exists, update its users with
    // (Dividend - (Quotient * Divisor).
    if (SDNode *RemNode = DAG.getNodeIfExists(ISD::SREM, N->getVTList(),
                                              { N0, N1 })) {
      SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1);
      SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
      AddToWorklist(Mul.getNode());
      AddToWorklist(Sub.getNode());
      CombineTo(RemNode, Sub);
    }
    return V;
  }

  // sdiv, srem -> sdivrem
  // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
  // true.  Otherwise, we break the simplification logic in visitREM().
  AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
  if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
    if (SDValue DivRem = useDivRem(N))
        return DivRem;

  return SDValue();
}
4364 
isDivisorPowerOfTwo(SDValue Divisor)4365 static bool isDivisorPowerOfTwo(SDValue Divisor) {
4366   // Helper for determining whether a value is a power-2 constant scalar or a
4367   // vector of such elements.
4368   auto IsPowerOfTwo = [](ConstantSDNode *C) {
4369     if (C->isZero() || C->isOpaque())
4370       return false;
4371     if (C->getAPIntValue().isPowerOf2())
4372       return true;
4373     if (C->getAPIntValue().isNegatedPowerOf2())
4374       return true;
4375     return false;
4376   };
4377 
4378   return ISD::matchUnaryPredicate(Divisor, IsPowerOfTwo);
4379 }
4380 
/// Strength-reduce an SDIV: expand division by a (possibly negative) power of
/// two into shift/add/select ops, or build a multiply-based sequence via
/// BuildSDIV when integer division is expensive.
SDValue DAGCombiner::visitSDIVLike(SDValue N0, SDValue N1, SDNode *N) {
  SDLoc DL(N);
  EVT VT = N->getValueType(0);
  EVT CCVT = getSetCCResultType(VT);
  unsigned BitWidth = VT.getScalarSizeInBits();

  // fold (sdiv X, pow2) -> simple ops after legalize
  // FIXME: We check for the exact bit here because the generic lowering gives
  // better results in that case. The target-specific lowering should learn how
  // to handle exact sdivs efficiently.
  if (!N->getFlags().hasExact() && isDivisorPowerOfTwo(N1)) {
    // Target-specific implementation of sdiv x, pow2.
    if (SDValue Res = BuildSDIVPow2(N))
      return Res;

    // Create constants that are functions of the shift amount value.
    EVT ShiftAmtTy = getShiftAmountTy(N0.getValueType());
    SDValue Bits = DAG.getConstant(BitWidth, DL, ShiftAmtTy);
    SDValue C1 = DAG.getNode(ISD::CTTZ, DL, VT, N1);
    C1 = DAG.getZExtOrTrunc(C1, DL, ShiftAmtTy);
    SDValue Inexact = DAG.getNode(ISD::SUB, DL, ShiftAmtTy, Bits, C1);
    // Bail out unless the shift amounts fold to constants.
    if (!isConstantOrConstantVector(Inexact))
      return SDValue();

    // Splat the sign bit into the register
    SDValue Sign = DAG.getNode(ISD::SRA, DL, VT, N0,
                               DAG.getConstant(BitWidth - 1, DL, ShiftAmtTy));
    AddToWorklist(Sign.getNode());

    // Add (N0 < 0) ? abs2 - 1 : 0;
    SDValue Srl = DAG.getNode(ISD::SRL, DL, VT, Sign, Inexact);
    AddToWorklist(Srl.getNode());
    SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, Srl);
    AddToWorklist(Add.getNode());
    SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Add, C1);
    AddToWorklist(Sra.getNode());

    // Special case: (sdiv X, 1) -> X
    // Special Case: (sdiv X, -1) -> 0-X
    SDValue One = DAG.getConstant(1, DL, VT);
    SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
    SDValue IsOne = DAG.getSetCC(DL, CCVT, N1, One, ISD::SETEQ);
    SDValue IsAllOnes = DAG.getSetCC(DL, CCVT, N1, AllOnes, ISD::SETEQ);
    SDValue IsOneOrAllOnes = DAG.getNode(ISD::OR, DL, CCVT, IsOne, IsAllOnes);
    Sra = DAG.getSelect(DL, VT, IsOneOrAllOnes, N0, Sra);

    // If dividing by a positive value, we're done. Otherwise, the result must
    // be negated.
    SDValue Zero = DAG.getConstant(0, DL, VT);
    SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, Zero, Sra);

    // FIXME: Use SELECT_CC once we improve SELECT_CC constant-folding.
    SDValue IsNeg = DAG.getSetCC(DL, CCVT, N1, Zero, ISD::SETLT);
    SDValue Res = DAG.getSelect(DL, VT, IsNeg, Sub, Sra);
    return Res;
  }

  // If integer divide is expensive and we satisfy the requirements, emit an
  // alternate sequence.  Targets may check function attributes for size/speed
  // trade-offs.
  AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
  if (isConstantOrConstantVector(N1) &&
      !TLI.isIntDivCheap(N->getValueType(0), Attr))
    if (SDValue Op = BuildSDIV(N))
      return Op;

  return SDValue();
}
4449 
/// Combine a UDIV node: constant-fold, handle special divisors (-1 and the
/// generic div/rem simplifications), defer power-of-two and multiply-based
/// expansions to visitUDIVLike, and finally try to merge with a matching
/// UREM into a single UDIVREM when profitable.
SDValue DAGCombiner::visitUDIV(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  EVT CCVT = getSetCCResultType(VT);
  SDLoc DL(N);

  // fold (udiv c1, c2) -> c1/c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::UDIV, DL, VT, {N0, N1}))
    return C;

  // fold vector ops
  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

  // fold (udiv X, -1) -> select(X == -1, 1, 0)
  // Unsigned division by all-ones yields 1 only when X is also all-ones,
  // otherwise 0.
  ConstantSDNode *N1C = isConstOrConstSplat(N1);
  if (N1C && N1C->isAllOnes())
    return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
                         DAG.getConstant(1, DL, VT),
                         DAG.getConstant(0, DL, VT));

  if (SDValue V = simplifyDivRem(N, DAG))
    return V;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  if (SDValue V = visitUDIVLike(N0, N1, N)) {
    // If the corresponding remainder node exists, update its users with
    // (Dividend - (Quotient * Divisor)).
    if (SDNode *RemNode = DAG.getNodeIfExists(ISD::UREM, N->getVTList(),
                                              { N0, N1 })) {
      SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1);
      SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
      AddToWorklist(Mul.getNode());
      AddToWorklist(Sub.getNode());
      CombineTo(RemNode, Sub);
    }
    return V;
  }

  // udiv, urem -> udivrem
  // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
  // true.  Otherwise, we break the simplification logic in visitREM().
  AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
  if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
    if (SDValue DivRem = useDivRem(N))
        return DivRem;

  return SDValue();
}
4503 
/// Shared udiv-simplification logic used by both visitUDIV and visitREM:
/// try to replace (udiv N0, N1) with shifts (power-of-two divisors) or a
/// multiply-based expansion (BuildUDIV). Returns the replacement quotient
/// value, or a null SDValue if no simplification applies.
SDValue DAGCombiner::visitUDIVLike(SDValue N0, SDValue N1, SDNode *N) {
  SDLoc DL(N);
  EVT VT = N->getValueType(0);

  // fold (udiv x, (1 << c)) -> x >>u c
  if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
      DAG.isKnownToBeAPowerOfTwo(N1)) {
    SDValue LogBase2 = BuildLogBase2(N1, DL);
    AddToWorklist(LogBase2.getNode());

    // The shift amount may be narrower/wider than the log2 value.
    EVT ShiftVT = getShiftAmountTy(N0.getValueType());
    SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
    AddToWorklist(Trunc.getNode());
    return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
  }

  // fold (udiv x, (shl c, y)) -> x >>u (log2(c)+y) iff c is power of 2
  if (N1.getOpcode() == ISD::SHL) {
    SDValue N10 = N1.getOperand(0);
    if (isConstantOrConstantVector(N10, /*NoOpaques*/ true) &&
        DAG.isKnownToBeAPowerOfTwo(N10)) {
      SDValue LogBase2 = BuildLogBase2(N10, DL);
      AddToWorklist(LogBase2.getNode());

      // Add log2(c) to the original shift amount y, in y's type.
      EVT ADDVT = N1.getOperand(1).getValueType();
      SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT);
      AddToWorklist(Trunc.getNode());
      SDValue Add = DAG.getNode(ISD::ADD, DL, ADDVT, N1.getOperand(1), Trunc);
      AddToWorklist(Add.getNode());
      return DAG.getNode(ISD::SRL, DL, VT, N0, Add);
    }
  }

  // fold (udiv x, c) -> alternate
  // Only expand to a multiply-based sequence when integer division is not
  // cheap on the target (may honor minsize/optsize attributes).
  AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
  if (isConstantOrConstantVector(N1) &&
      !TLI.isIntDivCheap(N->getValueType(0), Attr))
    if (SDValue Op = BuildUDIV(N))
      return Op;

  return SDValue();
}
4546 
buildOptimizedSREM(SDValue N0,SDValue N1,SDNode * N)4547 SDValue DAGCombiner::buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N) {
4548   if (!N->getFlags().hasExact() && isDivisorPowerOfTwo(N1) &&
4549       !DAG.doesNodeExist(ISD::SDIV, N->getVTList(), {N0, N1})) {
4550     // Target-specific implementation of srem x, pow2.
4551     if (SDValue Res = BuildSREMPow2(N))
4552       return Res;
4553   }
4554   return SDValue();
4555 }
4556 
// Handles both ISD::SREM and ISD::UREM.
SDValue DAGCombiner::visitREM(SDNode *N) {
  unsigned Opcode = N->getOpcode();
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  EVT CCVT = getSetCCResultType(VT);

  bool isSigned = (Opcode == ISD::SREM);
  SDLoc DL(N);

  // fold (rem c1, c2) -> c1%c2
  if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
    return C;

  // fold (urem X, -1) -> select(FX == -1, 0, FX)
  // Freeze the numerator to avoid a miscompile with an undefined value.
  if (!isSigned && llvm::isAllOnesOrAllOnesSplat(N1, /*AllowUndefs*/ false)) {
    SDValue F0 = DAG.getFreeze(N0);
    SDValue EqualsNeg1 = DAG.getSetCC(DL, CCVT, F0, N1, ISD::SETEQ);
    return DAG.getSelect(DL, VT, EqualsNeg1, DAG.getConstant(0, DL, VT), F0);
  }

  if (SDValue V = simplifyDivRem(N, DAG))
    return V;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  if (isSigned) {
    // If we know the sign bits of both operands are zero, strength reduce to a
    // urem instead.  Handles (X & 0x0FFFFFFF) %s 16 -> X&15
    if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
      return DAG.getNode(ISD::UREM, DL, VT, N0, N1);
  } else {
    if (DAG.isKnownToBeAPowerOfTwo(N1)) {
      // fold (urem x, pow2) -> (and x, pow2-1)
      SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
      SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, NegOne);
      AddToWorklist(Add.getNode());
      return DAG.getNode(ISD::AND, DL, VT, N0, Add);
    }
    // fold (urem x, (shl pow2, y)) -> (and x, (add (shl pow2, y), -1))
    // fold (urem x, (lshr pow2, y)) -> (and x, (add (lshr pow2, y), -1))
    // TODO: We should sink the following into isKnownToBePowerOfTwo
    // using a OrZero parameter analogous to our handling in ValueTracking.
    if ((N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) &&
        DAG.isKnownToBeAPowerOfTwo(N1.getOperand(0))) {
      SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
      SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, NegOne);
      AddToWorklist(Add.getNode());
      return DAG.getNode(ISD::AND, DL, VT, N0, Add);
    }
  }

  AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();

  // If X/C can be simplified by the division-by-constant logic, lower
  // X%C to the equivalent of X-X/C*C.
  // Reuse the SDIVLike/UDIVLike combines - to avoid mangling nodes, the
  // speculative DIV must not cause a DIVREM conversion.  We guard against this
  // by skipping the simplification if isIntDivCheap().  When div is not cheap,
  // combine will not return a DIVREM.  Regardless, checking cheapness here
  // makes sense since the simplification results in fatter code.
  if (DAG.isKnownNeverZero(N1) && !TLI.isIntDivCheap(VT, Attr)) {
    if (isSigned) {
      // check if we can build faster implementation for srem
      if (SDValue OptimizedRem = buildOptimizedSREM(N0, N1, N))
        return OptimizedRem;
    }

    SDValue OptimizedDiv =
        isSigned ? visitSDIVLike(N0, N1, N) : visitUDIVLike(N0, N1, N);
    if (OptimizedDiv.getNode() && OptimizedDiv.getNode() != N) {
      // If the equivalent Div node also exists, update its users.
      unsigned DivOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
      if (SDNode *DivNode = DAG.getNodeIfExists(DivOpcode, N->getVTList(),
                                                { N0, N1 }))
        CombineTo(DivNode, OptimizedDiv);
      SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, OptimizedDiv, N1);
      SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
      AddToWorklist(OptimizedDiv.getNode());
      AddToWorklist(Mul.getNode());
      return Sub;
    }
  }

  // sdiv, srem -> sdivrem (and udiv, urem -> udivrem)
  if (SDValue DivRem = useDivRem(N))
    return DivRem.getValue(1);

  return SDValue();
}
4650 
/// Combine a MULHS (signed multiply-high) node: constant folding, trivial
/// operand folds (0, 1, undef), and widening to a full multiply plus shift
/// when the double-width type is legal.
SDValue DAGCombiner::visitMULHS(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  // fold (mulhs c1, c2)
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::MULHS, DL, VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS.
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(ISD::MULHS, DL, N->getVTList(), N1, N0);

  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

    // fold (mulhs x, 0) -> 0
    // do not return N1, because undef node may exist.
    if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
      return DAG.getConstant(0, DL, VT);
  }

  // fold (mulhs x, 0) -> 0
  if (isNullConstant(N1))
    return N1;

  // fold (mulhs x, 1) -> (sra x, size(x)-1)
  // The high half of x*1 is just the sign bit of x replicated.
  if (isOneConstant(N1))
    return DAG.getNode(ISD::SRA, DL, N0.getValueType(), N0,
                       DAG.getConstant(N0.getScalarValueSizeInBits() - 1, DL,
                                       getShiftAmountTy(N0.getValueType())));

  // fold (mulhs x, undef) -> 0
  if (N0.isUndef() || N1.isUndef())
    return DAG.getConstant(0, DL, VT);

  // If the type twice as wide is legal, transform the mulhs to a wider multiply
  // plus a shift.
  if (!TLI.isOperationLegalOrCustom(ISD::MULHS, VT) && VT.isSimple() &&
      !VT.isVector()) {
    MVT Simple = VT.getSimpleVT();
    unsigned SimpleSize = Simple.getSizeInBits();
    EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
    if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
      N0 = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N0);
      N1 = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N1);
      N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
      // Shift the wide product right to move the high half down, then
      // truncate back to the original type.
      N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
            DAG.getConstant(SimpleSize, DL,
                            getShiftAmountTy(N1.getValueType())));
      return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
    }
  }

  return SDValue();
}
4710 
/// Combine a MULHU (unsigned multiply-high) node: constant folding, trivial
/// operand folds (0, 1, undef), power-of-two shift folds, and widening to a
/// full multiply plus shift when the double-width type is legal.
SDValue DAGCombiner::visitMULHU(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  // fold (mulhu c1, c2)
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::MULHU, DL, VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS.
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(ISD::MULHU, DL, N->getVTList(), N1, N0);

  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

    // fold (mulhu x, 0) -> 0
    // do not return N1, because undef node may exist.
    if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
      return DAG.getConstant(0, DL, VT);
  }

  // fold (mulhu x, 0) -> 0
  if (isNullConstant(N1))
    return N1;

  // fold (mulhu x, 1) -> 0
  // The high half of an unsigned x*1 is always zero.
  if (isOneConstant(N1))
    return DAG.getConstant(0, DL, N0.getValueType());

  // fold (mulhu x, undef) -> 0
  if (N0.isUndef() || N1.isUndef())
    return DAG.getConstant(0, DL, VT);

  // fold (mulhu x, (1 << c)) -> x >> (bitwidth - c)
  if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
      DAG.isKnownToBeAPowerOfTwo(N1) && hasOperation(ISD::SRL, VT)) {
    unsigned NumEltBits = VT.getScalarSizeInBits();
    SDValue LogBase2 = BuildLogBase2(N1, DL);
    SDValue SRLAmt = DAG.getNode(
        ISD::SUB, DL, VT, DAG.getConstant(NumEltBits, DL, VT), LogBase2);
    EVT ShiftVT = getShiftAmountTy(N0.getValueType());
    SDValue Trunc = DAG.getZExtOrTrunc(SRLAmt, DL, ShiftVT);
    return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
  }

  // If the type twice as wide is legal, transform the mulhu to a wider multiply
  // plus a shift.
  if (!TLI.isOperationLegalOrCustom(ISD::MULHU, VT) && VT.isSimple() &&
      !VT.isVector()) {
    MVT Simple = VT.getSimpleVT();
    unsigned SimpleSize = Simple.getSizeInBits();
    EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
    if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
      N0 = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N0);
      N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N1);
      N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
      // Shift the wide product right to move the high half down, then
      // truncate back to the original type.
      N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
            DAG.getConstant(SimpleSize, DL,
                            getShiftAmountTy(N1.getValueType())));
      return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
    }
  }

  // Simplify the operands using demanded-bits information.
  // We don't have demanded bits support for MULHU so this just enables constant
  // folding based on known bits.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  return SDValue();
}
4786 
visitAVG(SDNode * N)4787 SDValue DAGCombiner::visitAVG(SDNode *N) {
4788   unsigned Opcode = N->getOpcode();
4789   SDValue N0 = N->getOperand(0);
4790   SDValue N1 = N->getOperand(1);
4791   EVT VT = N->getValueType(0);
4792   SDLoc DL(N);
4793 
4794   // fold (avg c1, c2)
4795   if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
4796     return C;
4797 
4798   // canonicalize constant to RHS.
4799   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4800       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4801     return DAG.getNode(Opcode, DL, N->getVTList(), N1, N0);
4802 
4803   if (VT.isVector()) {
4804     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4805       return FoldedVOp;
4806 
4807     // fold (avgfloor x, 0) -> x >> 1
4808     if (ISD::isConstantSplatVectorAllZeros(N1.getNode())) {
4809       if (Opcode == ISD::AVGFLOORS)
4810         return DAG.getNode(ISD::SRA, DL, VT, N0, DAG.getConstant(1, DL, VT));
4811       if (Opcode == ISD::AVGFLOORU)
4812         return DAG.getNode(ISD::SRL, DL, VT, N0, DAG.getConstant(1, DL, VT));
4813     }
4814   }
4815 
4816   // fold (avg x, undef) -> x
4817   if (N0.isUndef())
4818     return N1;
4819   if (N1.isUndef())
4820     return N0;
4821 
4822   // TODO If we use avg for scalars anywhere, we can add (avgfl x, 0) -> x >> 1
4823 
4824   return SDValue();
4825 }
4826 
4827 /// Perform optimizations common to nodes that compute two values. LoOp and HiOp
4828 /// give the opcodes for the two computations that are being performed. Return
4829 /// true if a simplification was made.
SimplifyNodeWithTwoResults(SDNode * N,unsigned LoOp,unsigned HiOp)4830 SDValue DAGCombiner::SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
4831                                                 unsigned HiOp) {
4832   // If the high half is not needed, just compute the low half.
4833   bool HiExists = N->hasAnyUseOfValue(1);
4834   if (!HiExists && (!LegalOperations ||
4835                     TLI.isOperationLegalOrCustom(LoOp, N->getValueType(0)))) {
4836     SDValue Res = DAG.getNode(LoOp, SDLoc(N), N->getValueType(0), N->ops());
4837     return CombineTo(N, Res, Res);
4838   }
4839 
4840   // If the low half is not needed, just compute the high half.
4841   bool LoExists = N->hasAnyUseOfValue(0);
4842   if (!LoExists && (!LegalOperations ||
4843                     TLI.isOperationLegalOrCustom(HiOp, N->getValueType(1)))) {
4844     SDValue Res = DAG.getNode(HiOp, SDLoc(N), N->getValueType(1), N->ops());
4845     return CombineTo(N, Res, Res);
4846   }
4847 
4848   // If both halves are used, return as it is.
4849   if (LoExists && HiExists)
4850     return SDValue();
4851 
4852   // If the two computed results can be simplified separately, separate them.
4853   if (LoExists) {
4854     SDValue Lo = DAG.getNode(LoOp, SDLoc(N), N->getValueType(0), N->ops());
4855     AddToWorklist(Lo.getNode());
4856     SDValue LoOpt = combine(Lo.getNode());
4857     if (LoOpt.getNode() && LoOpt.getNode() != Lo.getNode() &&
4858         (!LegalOperations ||
4859          TLI.isOperationLegalOrCustom(LoOpt.getOpcode(), LoOpt.getValueType())))
4860       return CombineTo(N, LoOpt, LoOpt);
4861   }
4862 
4863   if (HiExists) {
4864     SDValue Hi = DAG.getNode(HiOp, SDLoc(N), N->getValueType(1), N->ops());
4865     AddToWorklist(Hi.getNode());
4866     SDValue HiOpt = combine(Hi.getNode());
4867     if (HiOpt.getNode() && HiOpt != Hi &&
4868         (!LegalOperations ||
4869          TLI.isOperationLegalOrCustom(HiOpt.getOpcode(), HiOpt.getValueType())))
4870       return CombineTo(N, HiOpt, HiOpt);
4871   }
4872 
4873   return SDValue();
4874 }
4875 
visitSMUL_LOHI(SDNode * N)4876 SDValue DAGCombiner::visitSMUL_LOHI(SDNode *N) {
4877   if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHS))
4878     return Res;
4879 
4880   SDValue N0 = N->getOperand(0);
4881   SDValue N1 = N->getOperand(1);
4882   EVT VT = N->getValueType(0);
4883   SDLoc DL(N);
4884 
4885   // canonicalize constant to RHS (vector doesn't have to splat)
4886   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4887       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4888     return DAG.getNode(ISD::SMUL_LOHI, DL, N->getVTList(), N1, N0);
4889 
4890   // If the type is twice as wide is legal, transform the mulhu to a wider
4891   // multiply plus a shift.
4892   if (VT.isSimple() && !VT.isVector()) {
4893     MVT Simple = VT.getSimpleVT();
4894     unsigned SimpleSize = Simple.getSizeInBits();
4895     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
4896     if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
4897       SDValue Lo = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N0);
4898       SDValue Hi = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N1);
4899       Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
4900       // Compute the high part as N1.
4901       Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
4902             DAG.getConstant(SimpleSize, DL,
4903                             getShiftAmountTy(Lo.getValueType())));
4904       Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
4905       // Compute the low part as N0.
4906       Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
4907       return CombineTo(N, Lo, Hi);
4908     }
4909   }
4910 
4911   return SDValue();
4912 }
4913 
visitUMUL_LOHI(SDNode * N)4914 SDValue DAGCombiner::visitUMUL_LOHI(SDNode *N) {
4915   if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHU))
4916     return Res;
4917 
4918   SDValue N0 = N->getOperand(0);
4919   SDValue N1 = N->getOperand(1);
4920   EVT VT = N->getValueType(0);
4921   SDLoc DL(N);
4922 
4923   // canonicalize constant to RHS (vector doesn't have to splat)
4924   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4925       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4926     return DAG.getNode(ISD::UMUL_LOHI, DL, N->getVTList(), N1, N0);
4927 
4928   // (umul_lohi N0, 0) -> (0, 0)
4929   if (isNullConstant(N1)) {
4930     SDValue Zero = DAG.getConstant(0, DL, VT);
4931     return CombineTo(N, Zero, Zero);
4932   }
4933 
4934   // (umul_lohi N0, 1) -> (N0, 0)
4935   if (isOneConstant(N1)) {
4936     SDValue Zero = DAG.getConstant(0, DL, VT);
4937     return CombineTo(N, N0, Zero);
4938   }
4939 
4940   // If the type is twice as wide is legal, transform the mulhu to a wider
4941   // multiply plus a shift.
4942   if (VT.isSimple() && !VT.isVector()) {
4943     MVT Simple = VT.getSimpleVT();
4944     unsigned SimpleSize = Simple.getSizeInBits();
4945     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
4946     if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
4947       SDValue Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N0);
4948       SDValue Hi = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N1);
4949       Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
4950       // Compute the high part as N1.
4951       Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
4952             DAG.getConstant(SimpleSize, DL,
4953                             getShiftAmountTy(Lo.getValueType())));
4954       Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
4955       // Compute the low part as N0.
4956       Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
4957       return CombineTo(N, Lo, Hi);
4958     }
4959   }
4960 
4961   return SDValue();
4962 }
4963 
/// Combine SMULO/UMULO (multiply-with-overflow) nodes: constant folding of
/// both results, canonicalization, trivial folds, and replacement by a plain
/// MUL when overflow is provably impossible.
SDValue DAGCombiner::visitMULO(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  bool IsSigned = (ISD::SMULO == N->getOpcode());

  EVT CarryVT = N->getValueType(1);
  SDLoc DL(N);

  ConstantSDNode *N0C = isConstOrConstSplat(N0);
  ConstantSDNode *N1C = isConstOrConstSplat(N1);

  // fold operation with constant operands.
  // TODO: Move this to FoldConstantArithmetic when it supports nodes with
  // multiple results.
  if (N0C && N1C) {
    bool Overflow;
    APInt Result =
        IsSigned ? N0C->getAPIntValue().smul_ov(N1C->getAPIntValue(), Overflow)
                 : N0C->getAPIntValue().umul_ov(N1C->getAPIntValue(), Overflow);
    return CombineTo(N, DAG.getConstant(Result, DL, VT),
                     DAG.getBoolConstant(Overflow, DL, CarryVT, CarryVT));
  }

  // canonicalize constant to RHS.
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(N->getOpcode(), DL, N->getVTList(), N1, N0);

  // fold (mulo x, 0) -> 0 + no carry out
  if (isNullOrNullSplat(N1))
    return CombineTo(N, DAG.getConstant(0, DL, VT),
                     DAG.getConstant(0, DL, CarryVT));

  // (mulo x, 2) -> (addo x, x)
  // FIXME: This needs a freeze.
  if (N1C && N1C->getAPIntValue() == 2 &&
      (!IsSigned || VT.getScalarSizeInBits() > 2))
    return DAG.getNode(IsSigned ? ISD::SADDO : ISD::UADDO, DL,
                       N->getVTList(), N0, N0);

  if (IsSigned) {
    // A 1 bit SMULO overflows if both inputs are 1.
    if (VT.getScalarSizeInBits() == 1) {
      SDValue And = DAG.getNode(ISD::AND, DL, VT, N0, N1);
      return CombineTo(N, And,
                       DAG.getSetCC(DL, CarryVT, And,
                                    DAG.getConstant(0, DL, VT), ISD::SETNE));
    }

    // Multiplying n * m significant bits yields a result of n + m significant
    // bits. If the total number of significant bits does not exceed the
    // result bit width (minus 1), there is no overflow.
    unsigned SignBits = DAG.ComputeNumSignBits(N0);
    if (SignBits > 1)
      SignBits += DAG.ComputeNumSignBits(N1);
    if (SignBits > VT.getScalarSizeInBits() + 1)
      return CombineTo(N, DAG.getNode(ISD::MUL, DL, VT, N0, N1),
                       DAG.getConstant(0, DL, CarryVT));
  } else {
    // Unsigned case: if the product of the known maximum values cannot
    // overflow, the multiply can never overflow either.
    KnownBits N1Known = DAG.computeKnownBits(N1);
    KnownBits N0Known = DAG.computeKnownBits(N0);
    bool Overflow;
    (void)N0Known.getMaxValue().umul_ov(N1Known.getMaxValue(), Overflow);
    if (!Overflow)
      return CombineTo(N, DAG.getNode(ISD::MUL, DL, VT, N0, N1),
                       DAG.getConstant(0, DL, CarryVT));
  }

  return SDValue();
}
5035 
// Function to calculate whether the Min/Max pair of SDNodes (potentially
// swapped around) make a signed saturate pattern, clamping to between a signed
// saturate of -2^(BW-1) and 2^(BW-1)-1, or an unsigned saturate of 0 and 2^BW.
// Returns the node being clamped and the bitwidth of the clamp in BW. Should
// work with both SMIN/SMAX nodes and setcc/select combo. The operands are the
// same as SimplifySelectCC. N0<N1 ? N2 : N3.
static SDValue isSaturatingMinMax(SDValue N0, SDValue N1, SDValue N2,
                                  SDValue N3, ISD::CondCode CC, unsigned &BW,
                                  bool &Unsigned) {
  // Classify a select_cc-shaped tuple as an SMIN/SMAX, or 0 if it is neither.
  auto isSignedMinMax = [&](SDValue N0, SDValue N1, SDValue N2, SDValue N3,
                            ISD::CondCode CC) {
    // The compare and select operand should be the same or the select operands
    // should be truncated versions of the comparison.
    if (N0 != N2 && (N2.getOpcode() != ISD::TRUNCATE || N0 != N2.getOperand(0)))
      return 0;
    // The constants need to be the same or a truncated version of each other.
    ConstantSDNode *N1C = isConstOrConstSplat(N1);
    ConstantSDNode *N3C = isConstOrConstSplat(N3);
    if (!N1C || !N3C)
      return 0;
    const APInt &C1 = N1C->getAPIntValue();
    const APInt &C2 = N3C->getAPIntValue();
    if (C1.getBitWidth() < C2.getBitWidth() || C1 != C2.sext(C1.getBitWidth()))
      return 0;
    return CC == ISD::SETLT ? ISD::SMIN : (CC == ISD::SETGT ? ISD::SMAX : 0);
  };

  // Check the initial value is a SMIN/SMAX equivalent.
  unsigned Opcode0 = isSignedMinMax(N0, N1, N2, N3, CC);
  if (!Opcode0)
    return SDValue();

  // Normalize the inner node (whatever its form) into select_cc-style
  // operands N00..N03 with condition N0CC so it can be classified too.
  SDValue N00, N01, N02, N03;
  ISD::CondCode N0CC;
  switch (N0.getOpcode()) {
  case ISD::SMIN:
  case ISD::SMAX:
    N00 = N02 = N0.getOperand(0);
    N01 = N03 = N0.getOperand(1);
    N0CC = N0.getOpcode() == ISD::SMIN ? ISD::SETLT : ISD::SETGT;
    break;
  case ISD::SELECT_CC:
    N00 = N0.getOperand(0);
    N01 = N0.getOperand(1);
    N02 = N0.getOperand(2);
    N03 = N0.getOperand(3);
    N0CC = cast<CondCodeSDNode>(N0.getOperand(4))->get();
    break;
  case ISD::SELECT:
  case ISD::VSELECT:
    if (N0.getOperand(0).getOpcode() != ISD::SETCC)
      return SDValue();
    N00 = N0.getOperand(0).getOperand(0);
    N01 = N0.getOperand(0).getOperand(1);
    N02 = N0.getOperand(1);
    N03 = N0.getOperand(2);
    N0CC = cast<CondCodeSDNode>(N0.getOperand(0).getOperand(2))->get();
    break;
  default:
    return SDValue();
  }

  // The inner node must be the opposite min/max of the outer one, i.e. one
  // SMIN and one SMAX, for the pair to form a clamp.
  unsigned Opcode1 = isSignedMinMax(N00, N01, N02, N03, N0CC);
  if (!Opcode1 || Opcode0 == Opcode1)
    return SDValue();

  // Pick out which constant is the lower clamp bound and which the upper.
  ConstantSDNode *MinCOp = isConstOrConstSplat(Opcode0 == ISD::SMIN ? N1 : N01);
  ConstantSDNode *MaxCOp = isConstOrConstSplat(Opcode0 == ISD::SMIN ? N01 : N1);
  if (!MinCOp || !MaxCOp || MinCOp->getValueType(0) != MaxCOp->getValueType(0))
    return SDValue();

  const APInt &MinC = MinCOp->getAPIntValue();
  const APInt &MaxC = MaxCOp->getAPIntValue();
  APInt MinCPlus1 = MinC + 1;
  // Signed saturate: bounds are -2^(BW-1) and 2^(BW-1)-1.
  if (-MaxC == MinCPlus1 && MinCPlus1.isPowerOf2()) {
    BW = MinCPlus1.exactLogBase2() + 1;
    Unsigned = false;
    return N02;
  }

  // Unsigned saturate: bounds are 0 and 2^BW-1.
  if (MaxC == 0 && MinCPlus1.isPowerOf2()) {
    BW = MinCPlus1.exactLogBase2();
    Unsigned = true;
    return N02;
  }

  return SDValue();
}
5124 
PerformMinMaxFpToSatCombine(SDValue N0,SDValue N1,SDValue N2,SDValue N3,ISD::CondCode CC,SelectionDAG & DAG)5125 static SDValue PerformMinMaxFpToSatCombine(SDValue N0, SDValue N1, SDValue N2,
5126                                            SDValue N3, ISD::CondCode CC,
5127                                            SelectionDAG &DAG) {
5128   unsigned BW;
5129   bool Unsigned;
5130   SDValue Fp = isSaturatingMinMax(N0, N1, N2, N3, CC, BW, Unsigned);
5131   if (!Fp || Fp.getOpcode() != ISD::FP_TO_SINT)
5132     return SDValue();
5133   EVT FPVT = Fp.getOperand(0).getValueType();
5134   EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), BW);
5135   if (FPVT.isVector())
5136     NewVT = EVT::getVectorVT(*DAG.getContext(), NewVT,
5137                              FPVT.getVectorElementCount());
5138   unsigned NewOpc = Unsigned ? ISD::FP_TO_UINT_SAT : ISD::FP_TO_SINT_SAT;
5139   if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(NewOpc, FPVT, NewVT))
5140     return SDValue();
5141   SDLoc DL(Fp);
5142   SDValue Sat = DAG.getNode(NewOpc, DL, NewVT, Fp.getOperand(0),
5143                             DAG.getValueType(NewVT.getScalarType()));
5144   return Unsigned ? DAG.getZExtOrTrunc(Sat, DL, N2->getValueType(0))
5145                   : DAG.getSExtOrTrunc(Sat, DL, N2->getValueType(0));
5146 }
5147 
// Fold UMIN(FPTOUI(X), (2^n)-1) — possibly expressed through a
// select/vselect/select_cc — into FP_TO_UINT_SAT saturating at n bits.
// N0/N1 are the setcc operands, N2/N3 the select operands, CC the setcc
// predicate. Returns the replacement value, or an empty SDValue if the
// pattern does not match or the target rejects the conversion.
static SDValue PerformUMinFpToSatCombine(SDValue N0, SDValue N1, SDValue N2,
                                         SDValue N3, ISD::CondCode CC,
                                         SelectionDAG &DAG) {
  // We are looking for UMIN(FPTOUI(X), (2^n)-1), which may have come via a
  // select/vselect/select_cc. The two operand pairs for the select (N2/N3) may
  // be truncated versions of the setcc (N0/N1).
  if ((N0 != N2 &&
       (N2.getOpcode() != ISD::TRUNCATE || N0 != N2.getOperand(0))) ||
      N0.getOpcode() != ISD::FP_TO_UINT || CC != ISD::SETULT)
    return SDValue();
  ConstantSDNode *N1C = isConstOrConstSplat(N1);
  ConstantSDNode *N3C = isConstOrConstSplat(N3);
  if (!N1C || !N3C)
    return SDValue();
  // The compare constant C1 and the select constant C3 must both be the same
  // (2^n)-1 value; C3 may be a narrower (truncated) version of C1.
  const APInt &C1 = N1C->getAPIntValue();
  const APInt &C3 = N3C->getAPIntValue();
  if (!(C1 + 1).isPowerOf2() || C1.getBitWidth() < C3.getBitWidth() ||
      C1 != C3.zext(C1.getBitWidth()))
    return SDValue();

  // Saturation width n = log2(C1 + 1).
  unsigned BW = (C1 + 1).exactLogBase2();
  EVT FPVT = N0.getOperand(0).getValueType();
  EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), BW);
  if (FPVT.isVector())
    NewVT = EVT::getVectorVT(*DAG.getContext(), NewVT,
                             FPVT.getVectorElementCount());
  // Give the target a chance to veto the saturating conversion.
  if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(ISD::FP_TO_UINT_SAT,
                                                        FPVT, NewVT))
    return SDValue();

  SDValue Sat =
      DAG.getNode(ISD::FP_TO_UINT_SAT, SDLoc(N0), NewVT, N0.getOperand(0),
                  DAG.getValueType(NewVT.getScalarType()));
  // Widen/narrow the saturated result back to the select's result type.
  return DAG.getZExtOrTrunc(Sat, SDLoc(N0), N3.getValueType());
}
5183 
// Combine an integer min/max node (SMIN/SMAX/UMIN/UMAX).
SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  unsigned Opcode = N->getOpcode();
  SDLoc DL(N);

  // fold operation with constant operands.
  if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
    return C;

  // If the operands are the same, this is a no-op.
  if (N0 == N1)
    return N0;

  // canonicalize constant to RHS
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(Opcode, DL, VT, N1, N0);

  // fold vector ops
  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

  // If the sign bits are zero, flip between UMIN/UMAX and SMIN/SMAX.
  // Only do this if the current op isn't legal and the flipped is.
  if (!TLI.isOperationLegal(Opcode, VT) &&
      (N0.isUndef() || DAG.SignBitIsZero(N0)) &&
      (N1.isUndef() || DAG.SignBitIsZero(N1))) {
    unsigned AltOpcode;
    switch (Opcode) {
    case ISD::SMIN: AltOpcode = ISD::UMIN; break;
    case ISD::SMAX: AltOpcode = ISD::UMAX; break;
    case ISD::UMIN: AltOpcode = ISD::SMIN; break;
    case ISD::UMAX: AltOpcode = ISD::SMAX; break;
    default: llvm_unreachable("Unknown MINMAX opcode");
    }
    if (TLI.isOperationLegal(AltOpcode, VT))
      return DAG.getNode(AltOpcode, DL, VT, N0, N1);
  }

  // Try folding min/max of an FP-to-int conversion into a saturating convert.
  if (Opcode == ISD::SMIN || Opcode == ISD::SMAX)
    if (SDValue S = PerformMinMaxFpToSatCombine(
            N0, N1, N0, N1, Opcode == ISD::SMIN ? ISD::SETLT : ISD::SETGT, DAG))
      return S;
  if (Opcode == ISD::UMIN)
    if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N0, N1, ISD::SETULT, DAG))
      return S;

  // Simplify the operands using demanded-bits information.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  return SDValue();
}
5240 
/// If this is a bitwise logic instruction and both operands have the same
/// opcode, try to sink the other opcode after the logic instruction.
/// e.g. (or (trunc X), (trunc Y)) --> (trunc (or X, Y)).
SDValue DAGCombiner::hoistLogicOpWithSameOpcodeHands(SDNode *N) {
  SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  unsigned LogicOpcode = N->getOpcode();
  unsigned HandOpcode = N0.getOpcode();
  assert((LogicOpcode == ISD::AND || LogicOpcode == ISD::OR ||
          LogicOpcode == ISD::XOR) && "Expected logic opcode");
  assert(HandOpcode == N1.getOpcode() && "Bad input!");

  // Bail early if none of these transforms apply.
  if (N0.getNumOperands() == 0)
    return SDValue();

  // FIXME: We should check number of uses of the operands to not increase
  //        the instruction count for all transforms.

  // Handle size-changing casts.
  SDValue X = N0.getOperand(0);
  SDValue Y = N1.getOperand(0);
  EVT XVT = X.getValueType();
  SDLoc DL(N);
  if (HandOpcode == ISD::ANY_EXTEND || HandOpcode == ISD::ZERO_EXTEND ||
      HandOpcode == ISD::SIGN_EXTEND) {
    // If both operands have other uses, this transform would create extra
    // instructions without eliminating anything.
    if (!N0.hasOneUse() && !N1.hasOneUse())
      return SDValue();
    // We need matching integer source types.
    if (XVT != Y.getValueType())
      return SDValue();
    // Don't create an illegal op during or after legalization. Don't ever
    // create an unsupported vector op.
    if ((VT.isVector() || LegalOperations) &&
        !TLI.isOperationLegalOrCustom(LogicOpcode, XVT))
      return SDValue();
    // Avoid infinite looping with PromoteIntBinOp.
    // TODO: Should we apply desirable/legal constraints to all opcodes?
    if (HandOpcode == ISD::ANY_EXTEND && LegalTypes &&
        !TLI.isTypeDesirableForOp(LogicOpcode, XVT))
      return SDValue();
    // logic_op (hand_op X), (hand_op Y) --> hand_op (logic_op X, Y)
    SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
    return DAG.getNode(HandOpcode, DL, VT, Logic);
  }

  // logic_op (truncate x), (truncate y) --> truncate (logic_op x, y)
  if (HandOpcode == ISD::TRUNCATE) {
    // If both operands have other uses, this transform would create extra
    // instructions without eliminating anything.
    if (!N0.hasOneUse() && !N1.hasOneUse())
      return SDValue();
    // We need matching source types.
    if (XVT != Y.getValueType())
      return SDValue();
    // Don't create an illegal op during or after legalization.
    if (LegalOperations && !TLI.isOperationLegal(LogicOpcode, XVT))
      return SDValue();
    // Be extra careful sinking truncate. If it's free, there's no benefit in
    // widening a binop. Also, don't create a logic op on an illegal type.
    if (TLI.isZExtFree(VT, XVT) && TLI.isTruncateFree(XVT, VT))
      return SDValue();
    if (!TLI.isTypeLegal(XVT))
      return SDValue();
    SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
    return DAG.getNode(HandOpcode, DL, VT, Logic);
  }

  // For binops SHL/SRL/SRA/AND:
  //   logic_op (OP x, z), (OP y, z) --> OP (logic_op x, y), z
  if ((HandOpcode == ISD::SHL || HandOpcode == ISD::SRL ||
       HandOpcode == ISD::SRA || HandOpcode == ISD::AND) &&
      N0.getOperand(1) == N1.getOperand(1)) {
    // If either operand has other uses, this transform is not an improvement.
    if (!N0.hasOneUse() || !N1.hasOneUse())
      return SDValue();
    SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
    return DAG.getNode(HandOpcode, DL, VT, Logic, N0.getOperand(1));
  }

  // Unary ops: logic_op (bswap x), (bswap y) --> bswap (logic_op x, y)
  if (HandOpcode == ISD::BSWAP) {
    // If either operand has other uses, this transform is not an improvement.
    if (!N0.hasOneUse() || !N1.hasOneUse())
      return SDValue();
    SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
    return DAG.getNode(HandOpcode, DL, VT, Logic);
  }

  // Simplify xor/and/or (bitcast(A), bitcast(B)) -> bitcast(op (A,B))
  // Only perform this optimization up until type legalization, before
  // LegalizeVectorOps. LegalizeVectorOps promotes vector operations by
  // adding bitcasts. For example (xor v4i32) is promoted to (v2i64), and
  // we don't want to undo this promotion.
  // We also handle SCALAR_TO_VECTOR because xor/or/and operations are cheaper
  // on scalars.
  if ((HandOpcode == ISD::BITCAST || HandOpcode == ISD::SCALAR_TO_VECTOR) &&
       Level <= AfterLegalizeTypes) {
    // Input types must be integer and the same.
    if (XVT.isInteger() && XVT == Y.getValueType() &&
        !(VT.isVector() && TLI.isTypeLegal(VT) &&
          !XVT.isVector() && !TLI.isTypeLegal(XVT))) {
      SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
      return DAG.getNode(HandOpcode, DL, VT, Logic);
    }
  }

  // Xor/and/or are indifferent to the swizzle operation (shuffle of one value).
  // Simplify xor/and/or (shuff(A), shuff(B)) -> shuff(op (A,B))
  // If both shuffles use the same mask, and both shuffle within a single
  // vector, then it is worthwhile to move the swizzle after the operation.
  // The type-legalizer generates this pattern when loading illegal
  // vector types from memory. In many cases this allows additional shuffle
  // optimizations.
  // There are other cases where moving the shuffle after the xor/and/or
  // is profitable even if shuffles don't perform a swizzle.
  // If both shuffles use the same mask, and both shuffles have the same first
  // or second operand, then it might still be profitable to move the shuffle
  // after the xor/and/or operation.
  if (HandOpcode == ISD::VECTOR_SHUFFLE && Level < AfterLegalizeDAG) {
    auto *SVN0 = cast<ShuffleVectorSDNode>(N0);
    auto *SVN1 = cast<ShuffleVectorSDNode>(N1);
    assert(X.getValueType() == Y.getValueType() &&
           "Inputs to shuffles are not the same type");

    // Check that both shuffles use the same mask. The masks are known to be of
    // the same length because the result vector type is the same.
    // Check also that shuffles have only one use to avoid introducing extra
    // instructions.
    if (!SVN0->hasOneUse() || !SVN1->hasOneUse() ||
        !SVN0->getMask().equals(SVN1->getMask()))
      return SDValue();

    // Don't try to fold this node if it requires introducing a
    // build vector of all zeros that might be illegal at this stage.
    SDValue ShOp = N0.getOperand(1);
    if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
      ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);

    // (logic_op (shuf (A, C), shuf (B, C))) --> shuf (logic_op (A, B), C)
    if (N0.getOperand(1) == N1.getOperand(1) && ShOp.getNode()) {
      SDValue Logic = DAG.getNode(LogicOpcode, DL, VT,
                                  N0.getOperand(0), N1.getOperand(0));
      return DAG.getVectorShuffle(VT, DL, Logic, ShOp, SVN0->getMask());
    }

    // Don't try to fold this node if it requires introducing a
    // build vector of all zeros that might be illegal at this stage.
    ShOp = N0.getOperand(0);
    if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
      ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);

    // (logic_op (shuf (C, A), shuf (C, B))) --> shuf (C, logic_op (A, B))
    if (N0.getOperand(0) == N1.getOperand(0) && ShOp.getNode()) {
      SDValue Logic = DAG.getNode(LogicOpcode, DL, VT, N0.getOperand(1),
                                  N1.getOperand(1));
      return DAG.getVectorShuffle(VT, DL, ShOp, Logic, SVN0->getMask());
    }
  }

  return SDValue();
}
5404 
/// Try to make (and/or setcc (LL, LR), setcc (RL, RR)) more efficient.
/// IsAnd selects between AND and OR semantics for the logic op combining the
/// two compares. Returns the simplified value or an empty SDValue.
SDValue DAGCombiner::foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
                                       const SDLoc &DL) {
  SDValue LL, LR, RL, RR, N0CC, N1CC;
  if (!isSetCCEquivalent(N0, LL, LR, N0CC) ||
      !isSetCCEquivalent(N1, RL, RR, N1CC))
    return SDValue();

  assert(N0.getValueType() == N1.getValueType() &&
         "Unexpected operand types for bitwise logic op");
  assert(LL.getValueType() == LR.getValueType() &&
         RL.getValueType() == RR.getValueType() &&
         "Unexpected operand types for setcc");

  // If we're here post-legalization or the logic op type is not i1, the logic
  // op type must match a setcc result type. Also, all folds require new
  // operations on the left and right operands, so those types must match.
  EVT VT = N0.getValueType();
  EVT OpVT = LL.getValueType();
  if (LegalOperations || VT.getScalarType() != MVT::i1)
    if (VT != getSetCCResultType(OpVT))
      return SDValue();
  if (OpVT != RL.getValueType())
    return SDValue();

  ISD::CondCode CC0 = cast<CondCodeSDNode>(N0CC)->get();
  ISD::CondCode CC1 = cast<CondCodeSDNode>(N1CC)->get();
  bool IsInteger = OpVT.isInteger();
  if (LR == RR && CC0 == CC1 && IsInteger) {
    bool IsZero = isNullOrNullSplat(LR);
    bool IsNeg1 = isAllOnesOrAllOnesSplat(LR);

    // All bits clear?
    bool AndEqZero = IsAnd && CC1 == ISD::SETEQ && IsZero;
    // All sign bits clear?
    bool AndGtNeg1 = IsAnd && CC1 == ISD::SETGT && IsNeg1;
    // Any bits set?
    bool OrNeZero = !IsAnd && CC1 == ISD::SETNE && IsZero;
    // Any sign bits set?
    bool OrLtZero = !IsAnd && CC1 == ISD::SETLT && IsZero;

    // (and (seteq X,  0), (seteq Y,  0)) --> (seteq (or X, Y),  0)
    // (and (setgt X, -1), (setgt Y, -1)) --> (setgt (or X, Y), -1)
    // (or  (setne X,  0), (setne Y,  0)) --> (setne (or X, Y),  0)
    // (or  (setlt X,  0), (setlt Y,  0)) --> (setlt (or X, Y),  0)
    if (AndEqZero || AndGtNeg1 || OrNeZero || OrLtZero) {
      SDValue Or = DAG.getNode(ISD::OR, SDLoc(N0), OpVT, LL, RL);
      AddToWorklist(Or.getNode());
      return DAG.getSetCC(DL, VT, Or, LR, CC1);
    }

    // All bits set?
    bool AndEqNeg1 = IsAnd && CC1 == ISD::SETEQ && IsNeg1;
    // All sign bits set?
    bool AndLtZero = IsAnd && CC1 == ISD::SETLT && IsZero;
    // Any bits clear?
    bool OrNeNeg1 = !IsAnd && CC1 == ISD::SETNE && IsNeg1;
    // Any sign bits clear?
    bool OrGtNeg1 = !IsAnd && CC1 == ISD::SETGT && IsNeg1;

    // (and (seteq X, -1), (seteq Y, -1)) --> (seteq (and X, Y), -1)
    // (and (setlt X,  0), (setlt Y,  0)) --> (setlt (and X, Y),  0)
    // (or  (setne X, -1), (setne Y, -1)) --> (setne (and X, Y), -1)
    // (or  (setgt X, -1), (setgt Y, -1)) --> (setgt (and X, Y), -1)
    if (AndEqNeg1 || AndLtZero || OrNeNeg1 || OrGtNeg1) {
      SDValue And = DAG.getNode(ISD::AND, SDLoc(N0), OpVT, LL, RL);
      AddToWorklist(And.getNode());
      return DAG.getSetCC(DL, VT, And, LR, CC1);
    }
  }

  // TODO: What is the 'or' equivalent of this fold?
  // (and (setne X, 0), (setne X, -1)) --> (setuge (add X, 1), 2)
  if (IsAnd && LL == RL && CC0 == CC1 && OpVT.getScalarSizeInBits() > 1 &&
      IsInteger && CC0 == ISD::SETNE &&
      ((isNullConstant(LR) && isAllOnesConstant(RR)) ||
       (isAllOnesConstant(LR) && isNullConstant(RR)))) {
    SDValue One = DAG.getConstant(1, DL, OpVT);
    SDValue Two = DAG.getConstant(2, DL, OpVT);
    SDValue Add = DAG.getNode(ISD::ADD, SDLoc(N0), OpVT, LL, One);
    AddToWorklist(Add.getNode());
    return DAG.getSetCC(DL, VT, Add, Two, ISD::SETUGE);
  }

  // Try more general transforms if the predicates match and the only user of
  // the compares is the 'and' or 'or'.
  if (IsInteger && TLI.convertSetCCLogicToBitwiseLogic(OpVT) && CC0 == CC1 &&
      N0.hasOneUse() && N1.hasOneUse()) {
    // and (seteq A, B), (seteq C, D) --> seteq (or (xor A, B), (xor C, D)), 0
    // or  (setne A, B), (setne C, D) --> setne (or (xor A, B), (xor C, D)), 0
    if ((IsAnd && CC1 == ISD::SETEQ) || (!IsAnd && CC1 == ISD::SETNE)) {
      SDValue XorL = DAG.getNode(ISD::XOR, SDLoc(N0), OpVT, LL, LR);
      SDValue XorR = DAG.getNode(ISD::XOR, SDLoc(N1), OpVT, RL, RR);
      SDValue Or = DAG.getNode(ISD::OR, DL, OpVT, XorL, XorR);
      SDValue Zero = DAG.getConstant(0, DL, OpVT);
      return DAG.getSetCC(DL, VT, Or, Zero, CC1);
    }

    // Turn compare of constants whose difference is 1 bit into add+and+setcc.
    if ((IsAnd && CC1 == ISD::SETNE) || (!IsAnd && CC1 == ISD::SETEQ)) {
      // Match a shared variable operand and 2 non-opaque constant operands.
      auto MatchDiffPow2 = [&](ConstantSDNode *C0, ConstantSDNode *C1) {
        // The difference of the constants must be a single bit.
        const APInt &CMax =
            APIntOps::umax(C0->getAPIntValue(), C1->getAPIntValue());
        const APInt &CMin =
            APIntOps::umin(C0->getAPIntValue(), C1->getAPIntValue());
        return !C0->isOpaque() && !C1->isOpaque() && (CMax - CMin).isPowerOf2();
      };
      if (LL == RL && ISD::matchBinaryPredicate(LR, RR, MatchDiffPow2)) {
        // and/or (setcc X, CMax, ne), (setcc X, CMin, ne/eq) -->
        // setcc ((sub X, CMin), ~(CMax - CMin)), 0, ne/eq
        SDValue Max = DAG.getNode(ISD::UMAX, DL, OpVT, LR, RR);
        SDValue Min = DAG.getNode(ISD::UMIN, DL, OpVT, LR, RR);
        SDValue Offset = DAG.getNode(ISD::SUB, DL, OpVT, LL, Min);
        SDValue Diff = DAG.getNode(ISD::SUB, DL, OpVT, Max, Min);
        SDValue Mask = DAG.getNOT(DL, Diff, OpVT);
        SDValue And = DAG.getNode(ISD::AND, DL, OpVT, Offset, Mask);
        SDValue Zero = DAG.getConstant(0, DL, OpVT);
        return DAG.getSetCC(DL, VT, And, Zero, CC0);
      }
    }
  }

  // Canonicalize equivalent operands to LL == RL.
  if (LL == RR && LR == RL) {
    CC1 = ISD::getSetCCSwappedOperands(CC1);
    std::swap(RL, RR);
  }

  // (and (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
  // (or  (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
  if (LL == RL && LR == RR) {
    ISD::CondCode NewCC = IsAnd ? ISD::getSetCCAndOperation(CC0, CC1, OpVT)
                                : ISD::getSetCCOrOperation(CC0, CC1, OpVT);
    if (NewCC != ISD::SETCC_INVALID &&
        (!LegalOperations ||
         (TLI.isCondCodeLegal(NewCC, LL.getSimpleValueType()) &&
          TLI.isOperationLegal(ISD::SETCC, OpVT))))
      return DAG.getSetCC(DL, VT, LL, LR, NewCC);
  }

  return SDValue();
}
5549 
/// This contains all DAGCombine rules which reduce two values combined by
/// an And operation to a single value. This makes them reusable in the context
/// of visitSELECT(). Rules involving constants are not included as
/// visitSELECT() already handles those cases.
SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1, SDNode *N) {
  EVT VT = N1.getValueType();
  SDLoc DL(N);

  // fold (and x, undef) -> 0
  if (N0.isUndef() || N1.isUndef())
    return DAG.getConstant(0, DL, VT);

  if (SDValue V = foldLogicOfSetCCs(true, N0, N1, DL))
    return V;

  // TODO: Rewrite this to return a new 'AND' instead of using CombineTo.
  if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::SRL &&
      VT.getSizeInBits() <= 64 && N0->hasOneUse()) {
    if (ConstantSDNode *ADDI = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
      if (ConstantSDNode *SRLI = dyn_cast<ConstantSDNode>(N1.getOperand(1))) {
        // Look for (and (add x, c1), (lshr y, c2)). If C1 wasn't a legal
        // immediate for an add, but it is legal if its top c2 bits are set,
        // transform the ADD so the immediate doesn't need to be materialized
        // in a register.
        APInt ADDC = ADDI->getAPIntValue();
        APInt SRLC = SRLI->getAPIntValue();
        if (ADDC.getMinSignedBits() <= 64 &&
            SRLC.ult(VT.getSizeInBits()) &&
            !TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
          APInt Mask = APInt::getHighBitsSet(VT.getSizeInBits(),
                                             SRLC.getZExtValue());
          if (DAG.MaskedValueIsZero(N0.getOperand(1), Mask)) {
            ADDC |= Mask;
            if (TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
              SDLoc DL0(N0);
              SDValue NewAdd =
                DAG.getNode(ISD::ADD, DL0, VT,
                            N0.getOperand(0), DAG.getConstant(ADDC, DL, VT));
              // Replace the ADD in place; the AND result is unchanged because
              // the extra set bits are masked away by the SRL's known zeros.
              CombineTo(N0.getNode(), NewAdd);
              // Return N so it doesn't get rechecked!
              return SDValue(N, 0);
            }
          }
        }
      }
    }
  }

  // Reduce bit extract of low half of an integer to the narrower type.
  // (and (srl i64:x, K), KMask) ->
  //   (i64 zero_extend (and (srl (i32 (trunc i64:x)), K)), KMask)
  if (N0.getOpcode() == ISD::SRL && N0.hasOneUse()) {
    if (ConstantSDNode *CAnd = dyn_cast<ConstantSDNode>(N1)) {
      if (ConstantSDNode *CShift = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
        unsigned Size = VT.getSizeInBits();
        const APInt &AndMask = CAnd->getAPIntValue();
        unsigned ShiftBits = CShift->getZExtValue();

        // Bail out, this node will probably disappear anyway.
        if (ShiftBits == 0)
          return SDValue();

        unsigned MaskBits = AndMask.countTrailingOnes();
        EVT HalfVT = EVT::getIntegerVT(*DAG.getContext(), Size / 2);

        if (AndMask.isMask() &&
            // Required bits must not span the two halves of the integer and
            // must fit in the half size type.
            (ShiftBits + MaskBits <= Size / 2) &&
            TLI.isNarrowingProfitable(VT, HalfVT) &&
            TLI.isTypeDesirableForOp(ISD::AND, HalfVT) &&
            TLI.isTypeDesirableForOp(ISD::SRL, HalfVT) &&
            TLI.isTruncateFree(VT, HalfVT) &&
            TLI.isZExtFree(HalfVT, VT)) {
          // The isNarrowingProfitable is to avoid regressions on PPC and
          // AArch64 which match a few 64-bit bit insert / bit extract patterns
          // on downstream users of this. Those patterns could probably be
          // extended to handle extensions mixed in.

          // NOTE(review): SL is declared as an SDValue but is passed below
          // where SDLoc parameters are expected; this appears to rely on an
          // implicit SDValue -> SDLoc conversion. Confirm — `SDLoc SL(N0)`
          // would state the intent more clearly.
          SDValue SL(N0);
          assert(MaskBits <= Size);

          // Extracting the highest bit of the low half.
          EVT ShiftVT = TLI.getShiftAmountTy(HalfVT, DAG.getDataLayout());
          SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, HalfVT,
                                      N0.getOperand(0));

          SDValue NewMask = DAG.getConstant(AndMask.trunc(Size / 2), SL, HalfVT);
          SDValue ShiftK = DAG.getConstant(ShiftBits, SL, ShiftVT);
          SDValue Shift = DAG.getNode(ISD::SRL, SL, HalfVT, Trunc, ShiftK);
          SDValue And = DAG.getNode(ISD::AND, SL, HalfVT, Shift, NewMask);
          return DAG.getNode(ISD::ZERO_EXTEND, SL, VT, And);
        }
      }
    }
  }

  return SDValue();
}
5649 
isAndLoadExtLoad(ConstantSDNode * AndC,LoadSDNode * LoadN,EVT LoadResultTy,EVT & ExtVT)5650 bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
5651                                    EVT LoadResultTy, EVT &ExtVT) {
5652   if (!AndC->getAPIntValue().isMask())
5653     return false;
5654 
5655   unsigned ActiveBits = AndC->getAPIntValue().countTrailingOnes();
5656 
5657   ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
5658   EVT LoadedVT = LoadN->getMemoryVT();
5659 
5660   if (ExtVT == LoadedVT &&
5661       (!LegalOperations ||
5662        TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))) {
5663     // ZEXTLOAD will match without needing to change the size of the value being
5664     // loaded.
5665     return true;
5666   }
5667 
5668   // Do not change the width of a volatile or atomic loads.
5669   if (!LoadN->isSimple())
5670     return false;
5671 
5672   // Do not generate loads of non-round integer types since these can
5673   // be expensive (and would be wrong if the type is not byte sized).
5674   if (!LoadedVT.bitsGT(ExtVT) || !ExtVT.isRound())
5675     return false;
5676 
5677   if (LegalOperations &&
5678       !TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))
5679     return false;
5680 
5681   if (!TLI.shouldReduceLoadWidth(LoadN, ISD::ZEXTLOAD, ExtVT))
5682     return false;
5683 
5684   return true;
5685 }
5686 
/// Check whether replacing the memory access in LDST with a narrower access
/// of type MemVT (read/written at a bit offset of ShAmt) is legal and
/// profitable. ExtType describes the extension kind for the load case.
bool DAGCombiner::isLegalNarrowLdSt(LSBaseSDNode *LDST,
                                    ISD::LoadExtType ExtType, EVT &MemVT,
                                    unsigned ShAmt) {
  if (!LDST)
    return false;
  // Only allow byte offsets.
  if (ShAmt % 8)
    return false;

  // Do not generate loads of non-round integer types since these can
  // be expensive (and would be wrong if the type is not byte sized).
  if (!MemVT.isRound())
    return false;

  // Don't change the width of a volatile or atomic loads.
  if (!LDST->isSimple())
    return false;

  EVT LdStMemVT = LDST->getMemoryVT();

  // Bail out when changing the scalable property, since we can't be sure that
  // we're actually narrowing here.
  if (LdStMemVT.isScalableVector() != MemVT.isScalableVector())
    return false;

  // Verify that we are actually reducing a load width here.
  if (LdStMemVT.bitsLT(MemVT))
    return false;

  // Ensure that this isn't going to produce an unsupported memory access.
  if (ShAmt) {
    assert(ShAmt % 8 == 0 && "ShAmt is byte offset");
    const unsigned ByteShAmt = ShAmt / 8;
    // The narrow access's alignment is the original alignment combined with
    // the byte offset at which it starts.
    const Align LDSTAlign = LDST->getAlign();
    const Align NarrowAlign = commonAlignment(LDSTAlign, ByteShAmt);
    if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
                                LDST->getAddressSpace(), NarrowAlign,
                                LDST->getMemOperand()->getFlags()))
      return false;
  }

  // It's not possible to generate a constant of extended or untyped type.
  EVT PtrType = LDST->getBasePtr().getValueType();
  if (PtrType == MVT::Untyped || PtrType.isExtended())
    return false;

  if (isa<LoadSDNode>(LDST)) {
    LoadSDNode *Load = cast<LoadSDNode>(LDST);
    // Don't transform one with multiple uses, this would require adding a new
    // load.
    if (!SDValue(Load, 0).hasOneUse())
      return false;

    if (LegalOperations &&
        !TLI.isLoadExtLegal(ExtType, Load->getValueType(0), MemVT))
      return false;

    // For the transform to be legal, the load must produce only two values
    // (the value loaded and the chain).  Don't transform a pre-increment
    // load, for example, which produces an extra value.  Otherwise the
    // transformation is not equivalent, and the downstream logic to replace
    // uses gets things wrong.
    if (Load->getNumValues() > 2)
      return false;

    // If the load that we're shrinking is an extload and we're not just
    // discarding the extension we can't simply shrink the load. Bail.
    // TODO: It would be possible to merge the extensions in some cases.
    if (Load->getExtensionType() != ISD::NON_EXTLOAD &&
        Load->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
      return false;

    if (!TLI.shouldReduceLoadWidth(Load, ExtType, MemVT))
      return false;
  } else {
    assert(isa<StoreSDNode>(LDST) && "It is not a Load nor a Store SDNode");
    StoreSDNode *Store = cast<StoreSDNode>(LDST);
    // Can't write outside the original store
    if (Store->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
      return false;

    if (LegalOperations &&
        !TLI.isTruncStoreLegal(Store->getValue().getValueType(), MemVT))
      return false;
  }
  return true;
}
5774 
/// Recursively walk the operand tree of N (OR/XOR/AND, zero-extensions and
/// loads), collecting in Loads every load that can be narrowed to the width
/// of Mask. Constant operands that need shrinking record their parent node in
/// NodesWithConsts. At most one other node is allowed; it is returned through
/// NodeToMask to be masked explicitly by the caller. Returns false if any
/// operand makes the transform invalid.
bool DAGCombiner::SearchForAndLoads(SDNode *N,
                                    SmallVectorImpl<LoadSDNode*> &Loads,
                                    SmallPtrSetImpl<SDNode*> &NodesWithConsts,
                                    ConstantSDNode *Mask,
                                    SDNode *&NodeToMask) {
  // Recursively search for the operands, looking for loads which can be
  // narrowed.
  for (SDValue Op : N->op_values()) {
    if (Op.getValueType().isVector())
      return false;

    // Some constants may need fixing up later if they are too large.
    if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
      if ((N->getOpcode() == ISD::OR || N->getOpcode() == ISD::XOR) &&
          (Mask->getAPIntValue() & C->getAPIntValue()) != C->getAPIntValue())
        NodesWithConsts.insert(N);
      continue;
    }

    if (!Op.hasOneUse())
      return false;

    switch(Op.getOpcode()) {
    case ISD::LOAD: {
      auto *Load = cast<LoadSDNode>(Op);
      EVT ExtVT;
      if (isAndLoadExtLoad(Mask, Load, Load->getValueType(0), ExtVT) &&
          isLegalNarrowLdSt(Load, ISD::ZEXTLOAD, ExtVT)) {

        // ZEXTLOAD is already small enough.
        if (Load->getExtensionType() == ISD::ZEXTLOAD &&
            ExtVT.bitsGE(Load->getMemoryVT()))
          continue;

        // Use LE to convert equal sized loads to zext.
        if (ExtVT.bitsLE(Load->getMemoryVT()))
          Loads.push_back(Load);

        continue;
      }
      return false;
    }
    case ISD::ZERO_EXTEND:
    case ISD::AssertZext: {
      unsigned ActiveBits = Mask->getAPIntValue().countTrailingOnes();
      EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
      EVT VT = Op.getOpcode() == ISD::AssertZext ?
        cast<VTSDNode>(Op.getOperand(1))->getVT() :
        Op.getOperand(0).getValueType();

      // We can accept extending nodes if the mask is wider or an equal
      // width to the original type.
      if (ExtVT.bitsGE(VT))
        continue;
      break;
    }
    case ISD::OR:
    case ISD::XOR:
    case ISD::AND:
      // Recurse into nested bitwise-logic trees.
      if (!SearchForAndLoads(Op.getNode(), Loads, NodesWithConsts, Mask,
                             NodeToMask))
        return false;
      continue;
    }

    // Allow one node which will be masked along with any loads found.
    if (NodeToMask)
      return false;

    // Also ensure that the node to be masked only produces one data result.
    NodeToMask = Op.getNode();
    if (NodeToMask->getNumValues() > 1) {
      bool HasValue = false;
      for (unsigned i = 0, e = NodeToMask->getNumValues(); i < e; ++i) {
        MVT VT = SDValue(NodeToMask, i).getSimpleValueType();
        if (VT != MVT::Glue && VT != MVT::Other) {
          if (HasValue) {
            // More than one data result: can't mask this node.
            NodeToMask = nullptr;
            return false;
          }
          HasValue = true;
        }
      }
      assert(HasValue && "Node to be masked has no data result?");
    }
  }
  return true;
}
5863 
/// Try to eliminate the AND node \p N by pushing its constant low-bit mask
/// back through the tree of logic/extend nodes found by SearchForAndLoads,
/// so that every load feeding the tree can be narrowed instead.
/// Returns true (and rewrites the DAG) on success.
bool DAGCombiner::BackwardsPropagateMask(SDNode *N) {
  // Only an AND with a constant, contiguous low-bit mask is handled.
  auto *Mask = dyn_cast<ConstantSDNode>(N->getOperand(1));
  if (!Mask)
    return false;

  if (!Mask->getAPIntValue().isMask())
    return false;

  // No need to do anything if the and directly uses a load.
  if (isa<LoadSDNode>(N->getOperand(0)))
    return false;

  SmallVector<LoadSDNode*, 8> Loads;
  SmallPtrSet<SDNode*, 2> NodesWithConsts;
  SDNode *FixupNode = nullptr;
  if (SearchForAndLoads(N, Loads, NodesWithConsts, Mask, FixupNode)) {
    // Without at least one load to narrow, the transform has no benefit.
    if (Loads.size() == 0)
      return false;

    LLVM_DEBUG(dbgs() << "Backwards propagate AND: "; N->dump());
    SDValue MaskOp = N->getOperand(1);

    // If it exists, fixup the single node we allow in the tree that needs
    // masking.
    if (FixupNode) {
      LLVM_DEBUG(dbgs() << "First, need to fix up: "; FixupNode->dump());
      SDValue And = DAG.getNode(ISD::AND, SDLoc(FixupNode),
                                FixupNode->getValueType(0),
                                SDValue(FixupNode, 0), MaskOp);
      DAG.ReplaceAllUsesOfValueWith(SDValue(FixupNode, 0), And);
      // The RAUW above also rewrote the new AND's own operand; restore it to
      // point at the fixup node. (getNode may have folded to a non-AND node,
      // hence the opcode check.)
      if (And.getOpcode() == ISD ::AND)
        DAG.UpdateNodeOperands(And.getNode(), SDValue(FixupNode, 0), MaskOp);
    }

    // Narrow any constants that need it.
    for (auto *LogicN : NodesWithConsts) {
      SDValue Op0 = LogicN->getOperand(0);
      SDValue Op1 = LogicN->getOperand(1);

      // Canonicalize so the constant operand (if any) is Op1.
      if (isa<ConstantSDNode>(Op0))
          std::swap(Op0, Op1);

      SDValue And = DAG.getNode(ISD::AND, SDLoc(Op1), Op1.getValueType(),
                                Op1, MaskOp);

      DAG.UpdateNodeOperands(LogicN, Op0, And);
    }

    // Create narrow loads.
    for (auto *Load : Loads) {
      LLVM_DEBUG(dbgs() << "Propagate AND back to: "; Load->dump());
      SDValue And = DAG.getNode(ISD::AND, SDLoc(Load), Load->getValueType(0),
                                SDValue(Load, 0), MaskOp);
      DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), And);
      // As above: undo the self-referential rewrite done by the RAUW.
      if (And.getOpcode() == ISD ::AND)
        And = SDValue(
            DAG.UpdateNodeOperands(And.getNode(), SDValue(Load, 0), MaskOp), 0);
      // The temporary AND lets reduceLoadWidth see the mask and prove the
      // narrowing is legal.
      SDValue NewLoad = reduceLoadWidth(And.getNode());
      assert(NewLoad &&
             "Shouldn't be masking the load if it can't be narrowed");
      CombineTo(Load, NewLoad, NewLoad.getValue(1));
    }
    // Every leaf is now masked, so the original AND itself is redundant.
    DAG.ReplaceAllUsesWith(N, N->getOperand(0).getNode());
    return true;
  }
  return false;
}
5931 
5932 // Unfold
5933 //    x &  (-1 'logical shift' y)
5934 // To
5935 //    (x 'opposite logical shift' y) 'logical shift' y
5936 // if it is better for performance.
unfoldExtremeBitClearingToShifts(SDNode * N)5937 SDValue DAGCombiner::unfoldExtremeBitClearingToShifts(SDNode *N) {
5938   assert(N->getOpcode() == ISD::AND);
5939 
5940   SDValue N0 = N->getOperand(0);
5941   SDValue N1 = N->getOperand(1);
5942 
5943   // Do we actually prefer shifts over mask?
5944   if (!TLI.shouldFoldMaskToVariableShiftPair(N0))
5945     return SDValue();
5946 
5947   // Try to match  (-1 '[outer] logical shift' y)
5948   unsigned OuterShift;
5949   unsigned InnerShift; // The opposite direction to the OuterShift.
5950   SDValue Y;           // Shift amount.
5951   auto matchMask = [&OuterShift, &InnerShift, &Y](SDValue M) -> bool {
5952     if (!M.hasOneUse())
5953       return false;
5954     OuterShift = M->getOpcode();
5955     if (OuterShift == ISD::SHL)
5956       InnerShift = ISD::SRL;
5957     else if (OuterShift == ISD::SRL)
5958       InnerShift = ISD::SHL;
5959     else
5960       return false;
5961     if (!isAllOnesConstant(M->getOperand(0)))
5962       return false;
5963     Y = M->getOperand(1);
5964     return true;
5965   };
5966 
5967   SDValue X;
5968   if (matchMask(N1))
5969     X = N0;
5970   else if (matchMask(N0))
5971     X = N1;
5972   else
5973     return SDValue();
5974 
5975   SDLoc DL(N);
5976   EVT VT = N->getValueType(0);
5977 
5978   //     tmp = x   'opposite logical shift' y
5979   SDValue T0 = DAG.getNode(InnerShift, DL, VT, X, Y);
5980   //     ret = tmp 'logical shift' y
5981   SDValue T1 = DAG.getNode(OuterShift, DL, VT, T0, Y);
5982 
5983   return T1;
5984 }
5985 
5986 /// Try to replace shift/logic that tests if a bit is clear with mask + setcc.
5987 /// For a target with a bit test, this is expected to become test + set and save
5988 /// at least 1 instruction.
static SDValue combineShiftAnd1ToBitTest(SDNode *And, SelectionDAG &DAG) {
  assert(And->getOpcode() == ISD::AND && "Expected an 'and' op");

  // This is probably not worthwhile without a supported type.
  EVT VT = And->getValueType(0);
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  if (!TLI.isTypeLegal(VT))
    return SDValue();

  // Look through an optional extension.
  SDValue And0 = And->getOperand(0), And1 = And->getOperand(1);
  if (And0.getOpcode() == ISD::ANY_EXTEND && And0.hasOneUse())
    And0 = And0.getOperand(0);
  // The pattern requires masking with the constant 1.
  if (!isOneConstant(And1) || !And0.hasOneUse())
    return SDValue();

  SDValue Src = And0;

  // Attempt to find a 'not' op.
  // TODO: Should we favor test+set even without the 'not' op?
  bool FoundNot = false;
  if (isBitwiseNot(Src)) {
    FoundNot = true;
    Src = Src.getOperand(0);

    // Look though an optional truncation. The source operand may not be the
    // same type as the original 'and', but that is ok because we are masking
    // off everything but the low bit.
    if (Src.getOpcode() == ISD::TRUNCATE && Src.hasOneUse())
      Src = Src.getOperand(0);
  }

  // Match a shift-right by constant.
  if (Src.getOpcode() != ISD::SRL || !Src.hasOneUse())
    return SDValue();

  // We might have looked through casts that make this transform invalid.
  // TODO: If the source type is wider than the result type, do the mask and
  //       compare in the source type.
  unsigned VTBitWidth = VT.getScalarSizeInBits();
  SDValue ShiftAmt = Src.getOperand(1);
  auto *ShiftAmtC = dyn_cast<ConstantSDNode>(ShiftAmt);
  // The shift amount must be a constant smaller than the result width, or the
  // single-bit mask built below would be invalid.
  if (!ShiftAmtC || !ShiftAmtC->getAPIntValue().ult(VTBitWidth))
    return SDValue();

  // Set source to shift source.
  Src = Src.getOperand(0);

  // Try again to find a 'not' op.
  // TODO: Should we favor test+set even with two 'not' ops?
  if (!FoundNot) {
    if (!isBitwiseNot(Src))
      return SDValue();
    Src = Src.getOperand(0);
  }

  // Only profitable if the target reports a bit-test for this operand/amount.
  if (!TLI.hasBitTest(Src, ShiftAmt))
    return SDValue();

  // Turn this into a bit-test pattern using mask op + setcc:
  // and (not (srl X, C)), 1 --> (and X, 1<<C) == 0
  // and (srl (not X), C)), 1 --> (and X, 1<<C) == 0
  SDLoc DL(And);
  SDValue X = DAG.getZExtOrTrunc(Src, DL, VT);
  EVT CCVT = TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
  SDValue Mask = DAG.getConstant(
      APInt::getOneBitSet(VTBitWidth, ShiftAmtC->getZExtValue()), DL, VT);
  SDValue NewAnd = DAG.getNode(ISD::AND, DL, VT, X, Mask);
  SDValue Zero = DAG.getConstant(0, DL, VT);
  SDValue Setcc = DAG.getSetCC(DL, CCVT, NewAnd, Zero, ISD::SETEQ);
  return DAG.getZExtOrTrunc(Setcc, DL, VT);
}
6061 
6062 /// For targets that support usubsat, match a bit-hack form of that operation
6063 /// that ends in 'and' and convert it.
foldAndToUsubsat(SDNode * N,SelectionDAG & DAG)6064 static SDValue foldAndToUsubsat(SDNode *N, SelectionDAG &DAG) {
6065   SDValue N0 = N->getOperand(0);
6066   SDValue N1 = N->getOperand(1);
6067   EVT VT = N1.getValueType();
6068 
6069   // Canonicalize SRA as operand 1.
6070   if (N0.getOpcode() == ISD::SRA)
6071     std::swap(N0, N1);
6072 
6073   // xor/add with SMIN (signmask) are logically equivalent.
6074   if (N0.getOpcode() != ISD::XOR && N0.getOpcode() != ISD::ADD)
6075     return SDValue();
6076 
6077   if (N1.getOpcode() != ISD::SRA || !N0.hasOneUse() || !N1.hasOneUse() ||
6078       N0.getOperand(0) != N1.getOperand(0))
6079     return SDValue();
6080 
6081   unsigned BitWidth = VT.getScalarSizeInBits();
6082   ConstantSDNode *XorC = isConstOrConstSplat(N0.getOperand(1), true);
6083   ConstantSDNode *SraC = isConstOrConstSplat(N1.getOperand(1), true);
6084   if (!XorC || !XorC->getAPIntValue().isSignMask() ||
6085       !SraC || SraC->getAPIntValue() != BitWidth - 1)
6086     return SDValue();
6087 
6088   // (i8 X ^ 128) & (i8 X s>> 7) --> usubsat X, 128
6089   // (i8 X + 128) & (i8 X s>> 7) --> usubsat X, 128
6090   SDLoc DL(N);
6091   SDValue SignMask = DAG.getConstant(XorC->getAPIntValue(), DL, VT);
6092   return DAG.getNode(ISD::USUBSAT, DL, VT, N0.getOperand(0), SignMask);
6093 }
6094 
6095 /// Given a bitwise logic operation N with a matching bitwise logic operand,
6096 /// fold a pattern where 2 of the source operands are identically shifted
6097 /// values. For example:
6098 /// ((X0 << Y) | Z) | (X1 << Y) --> ((X0 | X1) << Y) | Z
foldLogicOfShifts(SDNode * N,SDValue LogicOp,SDValue ShiftOp,SelectionDAG & DAG)6099 static SDValue foldLogicOfShifts(SDNode *N, SDValue LogicOp, SDValue ShiftOp,
6100                                  SelectionDAG &DAG) {
6101   unsigned LogicOpcode = N->getOpcode();
6102   assert((LogicOpcode == ISD::AND || LogicOpcode == ISD::OR ||
6103           LogicOpcode == ISD::XOR)
6104          && "Expected bitwise logic operation");
6105 
6106   if (!LogicOp.hasOneUse() || !ShiftOp.hasOneUse())
6107     return SDValue();
6108 
6109   // Match another bitwise logic op and a shift.
6110   unsigned ShiftOpcode = ShiftOp.getOpcode();
6111   if (LogicOp.getOpcode() != LogicOpcode ||
6112       !(ShiftOpcode == ISD::SHL || ShiftOpcode == ISD::SRL ||
6113         ShiftOpcode == ISD::SRA))
6114     return SDValue();
6115 
6116   // Match another shift op inside the first logic operand. Handle both commuted
6117   // possibilities.
6118   // LOGIC (LOGIC (SH X0, Y), Z), (SH X1, Y) --> LOGIC (SH (LOGIC X0, X1), Y), Z
6119   // LOGIC (LOGIC Z, (SH X0, Y)), (SH X1, Y) --> LOGIC (SH (LOGIC X0, X1), Y), Z
6120   SDValue X1 = ShiftOp.getOperand(0);
6121   SDValue Y = ShiftOp.getOperand(1);
6122   SDValue X0, Z;
6123   if (LogicOp.getOperand(0).getOpcode() == ShiftOpcode &&
6124       LogicOp.getOperand(0).getOperand(1) == Y) {
6125     X0 = LogicOp.getOperand(0).getOperand(0);
6126     Z = LogicOp.getOperand(1);
6127   } else if (LogicOp.getOperand(1).getOpcode() == ShiftOpcode &&
6128              LogicOp.getOperand(1).getOperand(1) == Y) {
6129     X0 = LogicOp.getOperand(1).getOperand(0);
6130     Z = LogicOp.getOperand(0);
6131   } else {
6132     return SDValue();
6133   }
6134 
6135   EVT VT = N->getValueType(0);
6136   SDLoc DL(N);
6137   SDValue LogicX = DAG.getNode(LogicOpcode, DL, VT, X0, X1);
6138   SDValue NewShift = DAG.getNode(ShiftOpcode, DL, VT, LogicX, Y);
6139   return DAG.getNode(LogicOpcode, DL, VT, NewShift, Z);
6140 }
6141 
/// Combiner entry point for ISD::AND nodes. Tries constant folding,
/// canonicalization, vector-specific folds, load narrowing / zext-load
/// conversion, mask-to-shift unfolds and several other AND-specific
/// patterns. Returns the replacement value, SDValue(N, 0) when N was
/// updated in place, or an empty SDValue when no combine applied.
SDValue DAGCombiner::visitAND(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N1.getValueType();

  // x & x --> x
  if (N0 == N1)
    return N0;

  // fold (and c1, c2) -> c1&c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::AND, SDLoc(N), VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(ISD::AND, SDLoc(N), VT, N1, N0);

  // fold vector ops
  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N)))
      return FoldedVOp;

    // fold (and x, 0) -> 0, vector edition
    if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
      // do not return N1, because undef node may exist in N1
      return DAG.getConstant(APInt::getZero(N1.getScalarValueSizeInBits()),
                             SDLoc(N), N1.getValueType());

    // fold (and x, -1) -> x, vector edition
    if (ISD::isConstantSplatVectorAllOnes(N1.getNode()))
      return N0;

    // fold (and (masked_load) (splat_vec (x, ...))) to zext_masked_load
    auto *MLoad = dyn_cast<MaskedLoadSDNode>(N0);
    ConstantSDNode *Splat = isConstOrConstSplat(N1, true, true);
    if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD && N0.hasOneUse() &&
        Splat && N1.hasOneUse()) {
      EVT LoadVT = MLoad->getMemoryVT();
      EVT ExtVT = VT;
      if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, ExtVT, LoadVT)) {
        // For this AND to be a zero extension of the masked load the elements
        // of the BuildVec must mask the bottom bits of the extended element
        // type
        uint64_t ElementSize =
            LoadVT.getVectorElementType().getScalarSizeInBits();
        if (Splat->getAPIntValue().isMask(ElementSize)) {
          return DAG.getMaskedLoad(
              ExtVT, SDLoc(N), MLoad->getChain(), MLoad->getBasePtr(),
              MLoad->getOffset(), MLoad->getMask(), MLoad->getPassThru(),
              LoadVT, MLoad->getMemOperand(), MLoad->getAddressingMode(),
              ISD::ZEXTLOAD, MLoad->isExpandingLoad());
        }
      }
    }
  }

  // fold (and x, -1) -> x
  if (isAllOnesConstant(N1))
    return N0;

  // if (and x, c) is known to be zero, return 0
  unsigned BitWidth = VT.getScalarSizeInBits();
  ConstantSDNode *N1C = isConstOrConstSplat(N1);
  if (N1C && DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnes(BitWidth)))
    return DAG.getConstant(0, SDLoc(N), VT);

  // fold the AND into a select's arms where possible.
  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // reassociate and
  if (SDValue RAND = reassociateOps(ISD::AND, SDLoc(N), N0, N1, N->getFlags()))
    return RAND;

  // Try to convert a constant mask AND into a shuffle clear mask.
  if (VT.isVector())
    if (SDValue Shuffle = XformToShuffleWithZero(N))
      return Shuffle;

  if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
    return Combined;

  // fold (and (or x, C), D) -> D if (C & D) == D
  auto MatchSubset = [](ConstantSDNode *LHS, ConstantSDNode *RHS) {
    return RHS->getAPIntValue().isSubsetOf(LHS->getAPIntValue());
  };
  if (N0.getOpcode() == ISD::OR &&
      ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchSubset))
    return N1;
  // fold (and (any_ext V), c) -> (zero_ext V) if 'and' only clears top bits.
  if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
    SDValue N0Op0 = N0.getOperand(0);
    APInt Mask = ~N1C->getAPIntValue();
    Mask = Mask.trunc(N0Op0.getScalarValueSizeInBits());
    if (DAG.MaskedValueIsZero(N0Op0, Mask)) {
      SDValue Zext = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N),
                                 N0.getValueType(), N0Op0);

      // Replace uses of the AND with uses of the Zero extend node.
      CombineTo(N, Zext);

      // We actually want to replace all uses of the any_extend with the
      // zero_extend, to avoid duplicating things.  This will later cause this
      // AND to be folded.
      CombineTo(N0.getNode(), Zext);
      return SDValue(N, 0);   // Return N so it doesn't get rechecked!
    }
  }

  // similarly fold (and (X (load ([non_ext|any_ext|zero_ext] V))), c) ->
  // (X (load ([non_ext|zero_ext] V))) if 'and' only clears top bits which must
  // already be zero by virtue of the width of the base type of the load.
  //
  // the 'X' node here can either be nothing or an extract_vector_elt to catch
  // more cases.
  if ((N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
       N0.getValueSizeInBits() == N0.getOperand(0).getScalarValueSizeInBits() &&
       N0.getOperand(0).getOpcode() == ISD::LOAD &&
       N0.getOperand(0).getResNo() == 0) ||
      (N0.getOpcode() == ISD::LOAD && N0.getResNo() == 0)) {
    LoadSDNode *Load = cast<LoadSDNode>( (N0.getOpcode() == ISD::LOAD) ?
                                         N0 : N0.getOperand(0) );

    // Get the constant (if applicable) the zero'th operand is being ANDed with.
    // This can be a pure constant or a vector splat, in which case we treat the
    // vector as a scalar and use the splat value.
    APInt Constant = APInt::getZero(1);
    if (const ConstantSDNode *C = isConstOrConstSplat(
            N1, /*AllowUndef=*/false, /*AllowTruncation=*/true)) {
      Constant = C->getAPIntValue();
    } else if (BuildVectorSDNode *Vector = dyn_cast<BuildVectorSDNode>(N1)) {
      APInt SplatValue, SplatUndef;
      unsigned SplatBitSize;
      bool HasAnyUndefs;
      bool IsSplat = Vector->isConstantSplat(SplatValue, SplatUndef,
                                             SplatBitSize, HasAnyUndefs);
      if (IsSplat) {
        // Undef bits can contribute to a possible optimisation if set, so
        // set them.
        SplatValue |= SplatUndef;

        // The splat value may be something like "0x00FFFFFF", which means 0 for
        // the first vector value and FF for the rest, repeating. We need a mask
        // that will apply equally to all members of the vector, so AND all the
        // lanes of the constant together.
        unsigned EltBitWidth = Vector->getValueType(0).getScalarSizeInBits();

        // If the splat value has been compressed to a bitlength lower
        // than the size of the vector lane, we need to re-expand it to
        // the lane size.
        if (EltBitWidth > SplatBitSize)
          for (SplatValue = SplatValue.zextOrTrunc(EltBitWidth);
               SplatBitSize < EltBitWidth; SplatBitSize = SplatBitSize * 2)
            SplatValue |= SplatValue.shl(SplatBitSize);

        // Make sure that variable 'Constant' is only set if 'SplatBitSize' is a
        // multiple of 'BitWidth'. Otherwise, we could propagate a wrong value.
        if ((SplatBitSize % EltBitWidth) == 0) {
          Constant = APInt::getAllOnes(EltBitWidth);
          for (unsigned i = 0, n = (SplatBitSize / EltBitWidth); i < n; ++i)
            Constant &= SplatValue.extractBits(EltBitWidth, i * EltBitWidth);
        }
      }
    }

    // If we want to change an EXTLOAD to a ZEXTLOAD, ensure a ZEXTLOAD is
    // actually legal and isn't going to get expanded, else this is a false
    // optimisation.
    bool CanZextLoadProfitably = TLI.isLoadExtLegal(ISD::ZEXTLOAD,
                                                    Load->getValueType(0),
                                                    Load->getMemoryVT());

    // Resize the constant to the same size as the original memory access before
    // extension. If it is still the AllOnesValue then this AND is completely
    // unneeded.
    Constant = Constant.zextOrTrunc(Load->getMemoryVT().getScalarSizeInBits());

    // B is true when the AND's mask cannot change the loaded value.
    bool B;
    switch (Load->getExtensionType()) {
    default: B = false; break;
    case ISD::EXTLOAD: B = CanZextLoadProfitably; break;
    case ISD::ZEXTLOAD:
    case ISD::NON_EXTLOAD: B = true; break;
    }

    if (B && Constant.isAllOnes()) {
      // If the load type was an EXTLOAD, convert to ZEXTLOAD in order to
      // preserve semantics once we get rid of the AND.
      SDValue NewLoad(Load, 0);

      // Fold the AND away. NewLoad may get replaced immediately.
      CombineTo(N, (N0.getNode() == Load) ? NewLoad : N0);

      if (Load->getExtensionType() == ISD::EXTLOAD) {
        NewLoad = DAG.getLoad(Load->getAddressingMode(), ISD::ZEXTLOAD,
                              Load->getValueType(0), SDLoc(Load),
                              Load->getChain(), Load->getBasePtr(),
                              Load->getOffset(), Load->getMemoryVT(),
                              Load->getMemOperand());
        // Replace uses of the EXTLOAD with the new ZEXTLOAD.
        if (Load->getNumValues() == 3) {
          // PRE/POST_INC loads have 3 values.
          SDValue To[] = { NewLoad.getValue(0), NewLoad.getValue(1),
                           NewLoad.getValue(2) };
          CombineTo(Load, To, 3, true);
        } else {
          CombineTo(Load, NewLoad.getValue(0), NewLoad.getValue(1));
        }
      }

      return SDValue(N, 0); // Return N so it doesn't get rechecked!
    }
  }

  if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR && N0.hasOneUse() && N1C &&
      ISD::isExtOpcode(N0.getOperand(0).getOpcode())) {
    SDValue Ext = N0.getOperand(0);
    EVT ExtVT = Ext->getValueType(0);
    SDValue Extendee = Ext->getOperand(0);

    unsigned ScalarWidth = Extendee.getValueType().getScalarSizeInBits();
    if (N1C->getAPIntValue().isMask(ScalarWidth) &&
        (!LegalOperations || TLI.isOperationLegal(ISD::ZERO_EXTEND, ExtVT))) {
      //    (and (extract_subvector (zext|anyext|sext v) _) iN_mask)
      // => (extract_subvector (iN_zeroext v))
      SDValue ZeroExtExtendee =
          DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), ExtVT, Extendee);

      return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), VT, ZeroExtExtendee,
                         N0.getOperand(1));
    }
  }

  // fold (and (masked_gather x)) -> (zext_masked_gather x)
  if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(N0)) {
    EVT MemVT = GN0->getMemoryVT();
    EVT ScalarVT = MemVT.getScalarType();

    if (SDValue(GN0, 0).hasOneUse() &&
        isConstantSplatVectorMaskForType(N1.getNode(), ScalarVT) &&
        TLI.isVectorLoadExtDesirable(SDValue(SDValue(GN0, 0)))) {
      SDValue Ops[] = {GN0->getChain(),   GN0->getPassThru(), GN0->getMask(),
                       GN0->getBasePtr(), GN0->getIndex(),    GN0->getScale()};

      SDValue ZExtLoad = DAG.getMaskedGather(
          DAG.getVTList(VT, MVT::Other), MemVT, SDLoc(N), Ops,
          GN0->getMemOperand(), GN0->getIndexType(), ISD::ZEXTLOAD);

      CombineTo(N, ZExtLoad);
      AddToWorklist(ZExtLoad.getNode());
      // Avoid recheck of N.
      return SDValue(N, 0);
    }
  }

  // fold (and (load x), 255) -> (zextload x, i8)
  // fold (and (extload x, i16), 255) -> (zextload x, i8)
  if (N1C && N0.getOpcode() == ISD::LOAD && !VT.isVector())
    if (SDValue Res = reduceLoadWidth(N))
      return Res;

  if (LegalTypes) {
    // Attempt to propagate the AND back up to the leaves which, if they're
    // loads, can be combined to narrow loads and the AND node can be removed.
    // Perform after legalization so that extend nodes will already be
    // combined into the loads.
    if (BackwardsPropagateMask(N))
      return SDValue(N, 0);
  }

  if (SDValue Combined = visitANDLike(N0, N1, N))
    return Combined;

  // Simplify: (and (op x...), (op y...))  -> (op (and x, y))
  if (N0.getOpcode() == N1.getOpcode())
    if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
      return V;

  // Try both operand orders for the shared-shift fold.
  if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG))
    return R;
  if (SDValue R = foldLogicOfShifts(N, N1, N0, DAG))
    return R;

  // Masking the negated extension of a boolean is just the zero-extended
  // boolean:
  // and (sub 0, zext(bool X)), 1 --> zext(bool X)
  // and (sub 0, sext(bool X)), 1 --> zext(bool X)
  //
  // Note: the SimplifyDemandedBits fold below can make an information-losing
  // transform, and then we have no way to find this better fold.
  if (N1C && N1C->isOne() && N0.getOpcode() == ISD::SUB) {
    if (isNullOrNullSplat(N0.getOperand(0))) {
      SDValue SubRHS = N0.getOperand(1);
      if (SubRHS.getOpcode() == ISD::ZERO_EXTEND &&
          SubRHS.getOperand(0).getScalarValueSizeInBits() == 1)
        return SubRHS;
      if (SubRHS.getOpcode() == ISD::SIGN_EXTEND &&
          SubRHS.getOperand(0).getScalarValueSizeInBits() == 1)
        return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, SubRHS.getOperand(0));
    }
  }

  // fold (and (sign_extend_inreg x, i16 to i32), 1) -> (and x, 1)
  // fold (and (sra)) -> (and (srl)) when possible.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  // fold (zext_inreg (extload x)) -> (zextload x)
  // fold (zext_inreg (sextload x)) -> (zextload x) iff load has one use
  if (ISD::isUNINDEXEDLoad(N0.getNode()) &&
      (ISD::isEXTLoad(N0.getNode()) ||
       (ISD::isSEXTLoad(N0.getNode()) && N0.hasOneUse()))) {
    LoadSDNode *LN0 = cast<LoadSDNode>(N0);
    EVT MemVT = LN0->getMemoryVT();
    // If we zero all the possible extended bits, then we can turn this into
    // a zextload if we are running before legalize or the operation is legal.
    unsigned ExtBitSize = N1.getScalarValueSizeInBits();
    unsigned MemBitSize = MemVT.getScalarSizeInBits();
    APInt ExtBits = APInt::getHighBitsSet(ExtBitSize, ExtBitSize - MemBitSize);
    if (DAG.MaskedValueIsZero(N1, ExtBits) &&
        ((!LegalOperations && LN0->isSimple()) ||
         TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT))) {
      SDValue ExtLoad =
          DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(N0), VT, LN0->getChain(),
                         LN0->getBasePtr(), MemVT, LN0->getMemOperand());
      AddToWorklist(N);
      CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
      return SDValue(N, 0); // Return N so it doesn't get rechecked!
    }
  }

  // fold (and (or (srl N, 8), (shl N, 8)), 0xffff) -> (srl (bswap N), const)
  if (N1C && N1C->getAPIntValue() == 0xffff && N0.getOpcode() == ISD::OR) {
    if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
                                           N0.getOperand(1), false))
      return BSwap;
  }

  if (SDValue Shifts = unfoldExtremeBitClearingToShifts(N))
    return Shifts;

  if (SDValue V = combineShiftAnd1ToBitTest(N, DAG))
    return V;

  // Recognize the following pattern:
  //
  // AndVT = (and (sign_extend NarrowVT to AndVT) #bitmask)
  //
  // where bitmask is a mask that clears the upper bits of AndVT. The
  // number of bits in bitmask must be a power of two.
  auto IsAndZeroExtMask = [](SDValue LHS, SDValue RHS) {
    if (LHS->getOpcode() != ISD::SIGN_EXTEND)
      return false;

    auto *C = dyn_cast<ConstantSDNode>(RHS);
    if (!C)
      return false;

    if (!C->getAPIntValue().isMask(
            LHS.getOperand(0).getValueType().getFixedSizeInBits()))
      return false;

    return true;
  };

  // Replace (and (sign_extend ...) #bitmask) with (zero_extend ...).
  if (IsAndZeroExtMask(N0, N1))
    return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, N0.getOperand(0));

  if (hasOperation(ISD::USUBSAT, VT))
    if (SDValue V = foldAndToUsubsat(N, DAG))
      return V;

  return SDValue();
}
6517 
6518 /// Match (a >> 8) | (a << 8) as (bswap a) >> 16.
SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
                                        bool DemandHighBits) {
  if (!LegalOperations)
    return SDValue();

  // Only scalar integer types with a legal/custom BSWAP are handled.
  EVT VT = N->getValueType(0);
  if (VT != MVT::i64 && VT != MVT::i32 && VT != MVT::i16)
    return SDValue();
  if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
    return SDValue();

  // Recognize (and (shl a, 8), 0xff00), (and (srl a, 8), 0xff)
  // Canonicalize so an AND-of-SHL ends up in N0 and an AND-of-SRL in N1.
  bool LookPassAnd0 = false;
  bool LookPassAnd1 = false;
  if (N0.getOpcode() == ISD::AND && N0.getOperand(0).getOpcode() == ISD::SRL)
    std::swap(N0, N1);
  if (N1.getOpcode() == ISD::AND && N1.getOperand(0).getOpcode() == ISD::SHL)
    std::swap(N0, N1);
  if (N0.getOpcode() == ISD::AND) {
    if (!N0->hasOneUse())
      return SDValue();
    ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
    // Also handle 0xffff since the LHS is guaranteed to have zeros there.
    // This is needed for X86.
    if (!N01C || (N01C->getZExtValue() != 0xFF00 &&
                  N01C->getZExtValue() != 0xFFFF))
      return SDValue();
    N0 = N0.getOperand(0);
    LookPassAnd0 = true;
  }

  if (N1.getOpcode() == ISD::AND) {
    if (!N1->hasOneUse())
      return SDValue();
    ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(N1.getOperand(1));
    if (!N11C || N11C->getZExtValue() != 0xFF)
      return SDValue();
    N1 = N1.getOperand(0);
    LookPassAnd1 = true;
  }

  // After looking through the masks, we need exactly one SHL and one SRL,
  // both by a constant amount of 8.
  if (N0.getOpcode() == ISD::SRL && N1.getOpcode() == ISD::SHL)
    std::swap(N0, N1);
  if (N0.getOpcode() != ISD::SHL || N1.getOpcode() != ISD::SRL)
    return SDValue();
  if (!N0->hasOneUse() || !N1->hasOneUse())
    return SDValue();

  ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
  ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(N1.getOperand(1));
  if (!N01C || !N11C)
    return SDValue();
  if (N01C->getZExtValue() != 8 || N11C->getZExtValue() != 8)
    return SDValue();

  // Look for (shl (and a, 0xff), 8), (srl (and a, 0xff00), 8)
  SDValue N00 = N0->getOperand(0);
  if (!LookPassAnd0 && N00.getOpcode() == ISD::AND) {
    if (!N00->hasOneUse())
      return SDValue();
    ConstantSDNode *N001C = dyn_cast<ConstantSDNode>(N00.getOperand(1));
    if (!N001C || N001C->getZExtValue() != 0xFF)
      return SDValue();
    N00 = N00.getOperand(0);
    LookPassAnd0 = true;
  }

  SDValue N10 = N1->getOperand(0);
  if (!LookPassAnd1 && N10.getOpcode() == ISD::AND) {
    if (!N10->hasOneUse())
      return SDValue();
    ConstantSDNode *N101C = dyn_cast<ConstantSDNode>(N10.getOperand(1));
    // Also allow 0xFFFF since the bits will be shifted out. This is needed
    // for X86.
    if (!N101C || (N101C->getZExtValue() != 0xFF00 &&
                   N101C->getZExtValue() != 0xFFFF))
      return SDValue();
    N10 = N10.getOperand(0);
    LookPassAnd1 = true;
  }

  // Both shifts must come from the same source value.
  if (N00 != N10)
    return SDValue();

  // Make sure everything beyond the low halfword gets set to zero since the SRL
  // 16 will clear the top bits.
  unsigned OpSizeInBits = VT.getSizeInBits();
  if (OpSizeInBits > 16) {
    // If the left-shift isn't masked out then the only way this is a bswap is
    // if all bits beyond the low 8 are 0. In that case the entire pattern
    // reduces to a left shift anyway: leave it for other parts of the combiner.
    if (DemandHighBits && !LookPassAnd0)
      return SDValue();

    // However, if the right shift isn't masked out then it might be because
    // it's not needed. See if we can spot that too. If the high bits aren't
    // demanded, we only need bits 23:16 to be zero. Otherwise, we need all
    // upper bits to be zero.
    if (!LookPassAnd1) {
      unsigned HighBit = DemandHighBits ? OpSizeInBits : 24;
      if (!DAG.MaskedValueIsZero(N10,
                                 APInt::getBitsSet(OpSizeInBits, 16, HighBit)))
        return SDValue();
    }
  }

  // For wider types the bswap puts the halfword into the high bits, so shift
  // it back down.
  SDValue Res = DAG.getNode(ISD::BSWAP, SDLoc(N), VT, N00);
  if (OpSizeInBits > 16) {
    SDLoc DL(N);
    Res = DAG.getNode(ISD::SRL, DL, VT, Res,
                      DAG.getConstant(OpSizeInBits - 16, DL,
                                      getShiftAmountTy(VT)));
  }
  return Res;
}
6634 
6635 /// Return true if the specified node is an element that makes up a 32-bit
6636 /// packed halfword byteswap.
6637 /// ((x & 0x000000ff) << 8) |
6638 /// ((x & 0x0000ff00) >> 8) |
6639 /// ((x & 0x00ff0000) << 8) |
6640 /// ((x & 0xff000000) >> 8)
static bool isBSwapHWordElement(SDValue N, MutableArrayRef<SDNode *> Parts) {
  // Each candidate element must feed only the OR tree being matched;
  // otherwise replacing it would duplicate work.
  if (!N->hasOneUse())
    return false;

  // The element is either an AND mask or a shift-by-8 (with the mask on the
  // other side of the pair).
  unsigned Opc = N.getOpcode();
  if (Opc != ISD::AND && Opc != ISD::SHL && Opc != ISD::SRL)
    return false;

  // Its operand must be the complementary op: shift under a mask, or mask
  // under a shift.
  SDValue N0 = N.getOperand(0);
  unsigned Opc0 = N0.getOpcode();
  if (Opc0 != ISD::AND && Opc0 != ISD::SHL && Opc0 != ISD::SRL)
    return false;

  ConstantSDNode *N1C = nullptr;
  // SHL or SRL: look upstream for AND mask operand
  if (Opc == ISD::AND)
    N1C = dyn_cast<ConstantSDNode>(N.getOperand(1));
  else if (Opc0 == ISD::AND)
    N1C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
  if (!N1C)
    return false;

  // Map the mask constant to which source byte (0..3) this element extracts.
  unsigned MaskByteOffset;
  switch (N1C->getZExtValue()) {
  default:
    return false;
  case 0xFF:       MaskByteOffset = 0; break;
  case 0xFF00:     MaskByteOffset = 1; break;
  case 0xFFFF:
    // In case demanded bits didn't clear the bits that will be shifted out.
    // This is needed for X86.
    if (Opc == ISD::SRL || (Opc == ISD::AND && Opc0 == ISD::SHL)) {
      MaskByteOffset = 1;
      break;
    }
    return false;
  case 0xFF0000:   MaskByteOffset = 2; break;
  case 0xFF000000: MaskByteOffset = 3; break;
  }

  // Look for (x & 0xff) << 8 as well as ((x << 8) & 0xff00).
  // In every case the shift amount must be exactly 8 and its direction must
  // agree with which byte the mask selects.
  if (Opc == ISD::AND) {
    if (MaskByteOffset == 0 || MaskByteOffset == 2) {
      // (x >> 8) & 0xff
      // (x >> 8) & 0xff0000
      if (Opc0 != ISD::SRL)
        return false;
      ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
      if (!C || C->getZExtValue() != 8)
        return false;
    } else {
      // (x << 8) & 0xff00
      // (x << 8) & 0xff000000
      if (Opc0 != ISD::SHL)
        return false;
      ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
      if (!C || C->getZExtValue() != 8)
        return false;
    }
  } else if (Opc == ISD::SHL) {
    // (x & 0xff) << 8
    // (x & 0xff0000) << 8
    if (MaskByteOffset != 0 && MaskByteOffset != 2)
      return false;
    ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1));
    if (!C || C->getZExtValue() != 8)
      return false;
  } else { // Opc == ISD::SRL
    // (x & 0xff00) >> 8
    // (x & 0xff000000) >> 8
    if (MaskByteOffset != 1 && MaskByteOffset != 3)
      return false;
    ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1));
    if (!C || C->getZExtValue() != 8)
      return false;
  }

  // Each byte slot may be claimed only once across the whole OR tree.
  if (Parts[MaskByteOffset])
    return false;

  // Record the common source value for this byte; MatchBSwapHWord later
  // checks all four slots point at the same node.
  Parts[MaskByteOffset] = N0.getOperand(0).getNode();
  return true;
}
6724 
6725 // Match 2 elements of a packed halfword bswap.
isBSwapHWordPair(SDValue N,MutableArrayRef<SDNode * > Parts)6726 static bool isBSwapHWordPair(SDValue N, MutableArrayRef<SDNode *> Parts) {
6727   if (N.getOpcode() == ISD::OR)
6728     return isBSwapHWordElement(N.getOperand(0), Parts) &&
6729            isBSwapHWordElement(N.getOperand(1), Parts);
6730 
6731   if (N.getOpcode() == ISD::SRL && N.getOperand(0).getOpcode() == ISD::BSWAP) {
6732     ConstantSDNode *C = isConstOrConstSplat(N.getOperand(1));
6733     if (!C || C->getAPIntValue() != 16)
6734       return false;
6735     Parts[0] = Parts[1] = N.getOperand(0).getOperand(0).getNode();
6736     return true;
6737   }
6738 
6739   return false;
6740 }
6741 
6742 // Match this pattern:
6743 //   (or (and (shl (A, 8)), 0xff00ff00), (and (srl (A, 8)), 0x00ff00ff))
6744 // And rewrite this to:
6745 //   (rotr (bswap A), 16)
matchBSwapHWordOrAndAnd(const TargetLowering & TLI,SelectionDAG & DAG,SDNode * N,SDValue N0,SDValue N1,EVT VT,EVT ShiftAmountTy)6746 static SDValue matchBSwapHWordOrAndAnd(const TargetLowering &TLI,
6747                                        SelectionDAG &DAG, SDNode *N, SDValue N0,
6748                                        SDValue N1, EVT VT, EVT ShiftAmountTy) {
6749   assert(N->getOpcode() == ISD::OR && VT == MVT::i32 &&
6750          "MatchBSwapHWordOrAndAnd: expecting i32");
6751   if (!TLI.isOperationLegalOrCustom(ISD::ROTR, VT))
6752     return SDValue();
6753   if (N0.getOpcode() != ISD::AND || N1.getOpcode() != ISD::AND)
6754     return SDValue();
6755   // TODO: this is too restrictive; lifting this restriction requires more tests
6756   if (!N0->hasOneUse() || !N1->hasOneUse())
6757     return SDValue();
6758   ConstantSDNode *Mask0 = isConstOrConstSplat(N0.getOperand(1));
6759   ConstantSDNode *Mask1 = isConstOrConstSplat(N1.getOperand(1));
6760   if (!Mask0 || !Mask1)
6761     return SDValue();
6762   if (Mask0->getAPIntValue() != 0xff00ff00 ||
6763       Mask1->getAPIntValue() != 0x00ff00ff)
6764     return SDValue();
6765   SDValue Shift0 = N0.getOperand(0);
6766   SDValue Shift1 = N1.getOperand(0);
6767   if (Shift0.getOpcode() != ISD::SHL || Shift1.getOpcode() != ISD::SRL)
6768     return SDValue();
6769   ConstantSDNode *ShiftAmt0 = isConstOrConstSplat(Shift0.getOperand(1));
6770   ConstantSDNode *ShiftAmt1 = isConstOrConstSplat(Shift1.getOperand(1));
6771   if (!ShiftAmt0 || !ShiftAmt1)
6772     return SDValue();
6773   if (ShiftAmt0->getAPIntValue() != 8 || ShiftAmt1->getAPIntValue() != 8)
6774     return SDValue();
6775   if (Shift0.getOperand(0) != Shift1.getOperand(0))
6776     return SDValue();
6777 
6778   SDLoc DL(N);
6779   SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT, Shift0.getOperand(0));
6780   SDValue ShAmt = DAG.getConstant(16, DL, ShiftAmountTy);
6781   return DAG.getNode(ISD::ROTR, DL, VT, BSwap, ShAmt);
6782 }
6783 
6784 /// Match a 32-bit packed halfword bswap. That is
6785 /// ((x & 0x000000ff) << 8) |
6786 /// ((x & 0x0000ff00) >> 8) |
6787 /// ((x & 0x00ff0000) << 8) |
6788 /// ((x & 0xff000000) >> 8)
6789 /// => (rotl (bswap x), 16)
MatchBSwapHWord(SDNode * N,SDValue N0,SDValue N1)6790 SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
6791   if (!LegalOperations)
6792     return SDValue();
6793 
6794   EVT VT = N->getValueType(0);
6795   if (VT != MVT::i32)
6796     return SDValue();
6797   if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
6798     return SDValue();
6799 
6800   if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N0, N1, VT,
6801                                               getShiftAmountTy(VT)))
6802   return BSwap;
6803 
6804   // Try again with commuted operands.
6805   if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N1, N0, VT,
6806                                               getShiftAmountTy(VT)))
6807   return BSwap;
6808 
6809 
6810   // Look for either
6811   // (or (bswaphpair), (bswaphpair))
6812   // (or (or (bswaphpair), (and)), (and))
6813   // (or (or (and), (bswaphpair)), (and))
6814   SDNode *Parts[4] = {};
6815 
6816   if (isBSwapHWordPair(N0, Parts)) {
6817     // (or (or (and), (and)), (or (and), (and)))
6818     if (!isBSwapHWordPair(N1, Parts))
6819       return SDValue();
6820   } else if (N0.getOpcode() == ISD::OR) {
6821     // (or (or (or (and), (and)), (and)), (and))
6822     if (!isBSwapHWordElement(N1, Parts))
6823       return SDValue();
6824     SDValue N00 = N0.getOperand(0);
6825     SDValue N01 = N0.getOperand(1);
6826     if (!(isBSwapHWordElement(N01, Parts) && isBSwapHWordPair(N00, Parts)) &&
6827         !(isBSwapHWordElement(N00, Parts) && isBSwapHWordPair(N01, Parts)))
6828       return SDValue();
6829   } else {
6830     return SDValue();
6831   }
6832 
6833   // Make sure the parts are all coming from the same node.
6834   if (Parts[0] != Parts[1] || Parts[0] != Parts[2] || Parts[0] != Parts[3])
6835     return SDValue();
6836 
6837   SDLoc DL(N);
6838   SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT,
6839                               SDValue(Parts[0], 0));
6840 
6841   // Result of the bswap should be rotated by 16. If it's not legal, then
6842   // do  (x << 16) | (x >> 16).
6843   SDValue ShAmt = DAG.getConstant(16, DL, getShiftAmountTy(VT));
6844   if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT))
6845     return DAG.getNode(ISD::ROTL, DL, VT, BSwap, ShAmt);
6846   if (TLI.isOperationLegalOrCustom(ISD::ROTR, VT))
6847     return DAG.getNode(ISD::ROTR, DL, VT, BSwap, ShAmt);
6848   return DAG.getNode(ISD::OR, DL, VT,
6849                      DAG.getNode(ISD::SHL, DL, VT, BSwap, ShAmt),
6850                      DAG.getNode(ISD::SRL, DL, VT, BSwap, ShAmt));
6851 }
6852 
6853 /// This contains all DAGCombine rules which reduce two values combined by
6854 /// an Or operation to a single value \see visitANDLike().
visitORLike(SDValue N0,SDValue N1,SDNode * N)6855 SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, SDNode *N) {
6856   EVT VT = N1.getValueType();
6857   SDLoc DL(N);
6858 
6859   // fold (or x, undef) -> -1
6860   if (!LegalOperations && (N0.isUndef() || N1.isUndef()))
6861     return DAG.getAllOnesConstant(DL, VT);
6862 
6863   if (SDValue V = foldLogicOfSetCCs(false, N0, N1, DL))
6864     return V;
6865 
6866   // (or (and X, C1), (and Y, C2))  -> (and (or X, Y), C3) if possible.
6867   if (N0.getOpcode() == ISD::AND && N1.getOpcode() == ISD::AND &&
6868       // Don't increase # computations.
6869       (N0->hasOneUse() || N1->hasOneUse())) {
6870     // We can only do this xform if we know that bits from X that are set in C2
6871     // but not in C1 are already zero.  Likewise for Y.
6872     if (const ConstantSDNode *N0O1C =
6873         getAsNonOpaqueConstant(N0.getOperand(1))) {
6874       if (const ConstantSDNode *N1O1C =
6875           getAsNonOpaqueConstant(N1.getOperand(1))) {
6876         // We can only do this xform if we know that bits from X that are set in
6877         // C2 but not in C1 are already zero.  Likewise for Y.
6878         const APInt &LHSMask = N0O1C->getAPIntValue();
6879         const APInt &RHSMask = N1O1C->getAPIntValue();
6880 
6881         if (DAG.MaskedValueIsZero(N0.getOperand(0), RHSMask&~LHSMask) &&
6882             DAG.MaskedValueIsZero(N1.getOperand(0), LHSMask&~RHSMask)) {
6883           SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
6884                                   N0.getOperand(0), N1.getOperand(0));
6885           return DAG.getNode(ISD::AND, DL, VT, X,
6886                              DAG.getConstant(LHSMask | RHSMask, DL, VT));
6887         }
6888       }
6889     }
6890   }
6891 
6892   // (or (and X, M), (and X, N)) -> (and X, (or M, N))
6893   if (N0.getOpcode() == ISD::AND &&
6894       N1.getOpcode() == ISD::AND &&
6895       N0.getOperand(0) == N1.getOperand(0) &&
6896       // Don't increase # computations.
6897       (N0->hasOneUse() || N1->hasOneUse())) {
6898     SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
6899                             N0.getOperand(1), N1.getOperand(1));
6900     return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), X);
6901   }
6902 
6903   return SDValue();
6904 }
6905 
6906 /// OR combines for which the commuted variant will be tried as well.
visitORCommutative(SelectionDAG & DAG,SDValue N0,SDValue N1,SDNode * N)6907 static SDValue visitORCommutative(SelectionDAG &DAG, SDValue N0, SDValue N1,
6908                                   SDNode *N) {
6909   EVT VT = N0.getValueType();
6910   if (N0.getOpcode() == ISD::AND) {
6911     SDValue N00 = N0.getOperand(0);
6912     SDValue N01 = N0.getOperand(1);
6913 
6914     // fold (or (and X, (xor Y, -1)), Y) -> (or X, Y)
6915     // TODO: Set AllowUndefs = true.
6916     if (getBitwiseNotOperand(N01, N00,
6917                              /* AllowUndefs */ false) == N1)
6918       return DAG.getNode(ISD::OR, SDLoc(N), VT, N00, N1);
6919 
6920     // fold (or (and (xor Y, -1), X), Y) -> (or X, Y)
6921     if (getBitwiseNotOperand(N00, N01,
6922                              /* AllowUndefs */ false) == N1)
6923       return DAG.getNode(ISD::OR, SDLoc(N), VT, N01, N1);
6924   }
6925 
6926   if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG))
6927     return R;
6928 
6929   auto peekThroughZext = [](SDValue V) {
6930     if (V->getOpcode() == ISD::ZERO_EXTEND)
6931       return V->getOperand(0);
6932     return V;
6933   };
6934 
6935   // (fshl X, ?, Y) | (shl X, Y) --> fshl X, ?, Y
6936   if (N0.getOpcode() == ISD::FSHL && N1.getOpcode() == ISD::SHL &&
6937       N0.getOperand(0) == N1.getOperand(0) &&
6938       peekThroughZext(N0.getOperand(2)) == peekThroughZext(N1.getOperand(1)))
6939     return N0;
6940 
6941   // (fshr ?, X, Y) | (srl X, Y) --> fshr ?, X, Y
6942   if (N0.getOpcode() == ISD::FSHR && N1.getOpcode() == ISD::SRL &&
6943       N0.getOperand(1) == N1.getOperand(0) &&
6944       peekThroughZext(N0.getOperand(2)) == peekThroughZext(N1.getOperand(1)))
6945     return N0;
6946 
6947   return SDValue();
6948 }
6949 
/// Visit an ISD::OR node and try the OR-specific combines in a fixed order:
/// trivial identities, constant folding, canonicalization, vector folds,
/// bswap/rotate/load matching, and finally demanded-bits simplification.
/// Returns the replacement value, or an empty SDValue if nothing fired.
SDValue DAGCombiner::visitOR(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N1.getValueType();

  // x | x --> x
  if (N0 == N1)
    return N0;

  // fold (or c1, c2) -> c1|c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N), VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(ISD::OR, SDLoc(N), VT, N1, N0);

  // fold vector ops
  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N)))
      return FoldedVOp;

    // fold (or x, 0) -> x, vector edition
    if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
      return N0;

    // fold (or x, -1) -> -1, vector edition
    if (ISD::isConstantSplatVectorAllOnes(N1.getNode()))
      // do not return N1, because undef node may exist in N1
      return DAG.getAllOnesConstant(SDLoc(N), N1.getValueType());

    // fold (or (shuf A, V_0, MA), (shuf B, V_0, MB)) -> (shuf A, B, Mask)
    // Do this only if the resulting type / shuffle is legal.
    auto *SV0 = dyn_cast<ShuffleVectorSDNode>(N0);
    auto *SV1 = dyn_cast<ShuffleVectorSDNode>(N1);
    if (SV0 && SV1 && TLI.isTypeLegal(VT)) {
      bool ZeroN00 = ISD::isBuildVectorAllZeros(N0.getOperand(0).getNode());
      bool ZeroN01 = ISD::isBuildVectorAllZeros(N0.getOperand(1).getNode());
      bool ZeroN10 = ISD::isBuildVectorAllZeros(N1.getOperand(0).getNode());
      bool ZeroN11 = ISD::isBuildVectorAllZeros(N1.getOperand(1).getNode());
      // Ensure both shuffles have a zero input.
      if ((ZeroN00 != ZeroN01) && (ZeroN10 != ZeroN11)) {
        assert((!ZeroN00 || !ZeroN01) && "Both inputs zero!");
        assert((!ZeroN10 || !ZeroN11) && "Both inputs zero!");
        bool CanFold = true;
        int NumElts = VT.getVectorNumElements();
        SmallVector<int, 4> Mask(NumElts, -1);

        // Build a combined mask: for each lane exactly one shuffle must
        // produce the (nonzero) value, the other must produce zero.
        for (int i = 0; i != NumElts; ++i) {
          int M0 = SV0->getMaskElt(i);
          int M1 = SV1->getMaskElt(i);

          // Determine if either index is pointing to a zero vector.
          bool M0Zero = M0 < 0 || (ZeroN00 == (M0 < NumElts));
          bool M1Zero = M1 < 0 || (ZeroN10 == (M1 < NumElts));

          // If one element is zero and the otherside is undef, keep undef.
          // This also handles the case that both are undef.
          if ((M0Zero && M1 < 0) || (M1Zero && M0 < 0))
            continue;

          // Make sure only one of the elements is zero.
          if (M0Zero == M1Zero) {
            CanFold = false;
            break;
          }

          assert((M0 >= 0 || M1 >= 0) && "Undef index!");

          // We have a zero and non-zero element. If the non-zero came from
          // SV0 make the index a LHS index. If it came from SV1, make it
          // a RHS index. We need to mod by NumElts because we don't care
          // which operand it came from in the original shuffles.
          Mask[i] = M1Zero ? M0 % NumElts : (M1 % NumElts) + NumElts;
        }

        if (CanFold) {
          SDValue NewLHS = ZeroN00 ? N0.getOperand(1) : N0.getOperand(0);
          SDValue NewRHS = ZeroN10 ? N1.getOperand(1) : N1.getOperand(0);

          SDValue LegalShuffle =
              TLI.buildLegalVectorShuffle(VT, SDLoc(N), NewLHS, NewRHS,
                                          Mask, DAG);
          if (LegalShuffle)
            return LegalShuffle;
        }
      }
    }
  }

  // fold (or x, 0) -> x
  if (isNullConstant(N1))
    return N0;

  // fold (or x, -1) -> -1
  if (isAllOnesConstant(N1))
    return N1;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // fold (or x, c) -> c iff (x & ~c) == 0
  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
  if (N1C && DAG.MaskedValueIsZero(N0, ~N1C->getAPIntValue()))
    return N1;

  if (SDValue Combined = visitORLike(N0, N1, N))
    return Combined;

  if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
    return Combined;

  // Recognize halfword bswaps as (bswap + rotl 16) or (bswap + shl 16)
  if (SDValue BSwap = MatchBSwapHWord(N, N0, N1))
    return BSwap;
  if (SDValue BSwap = MatchBSwapHWordLow(N, N0, N1))
    return BSwap;

  // reassociate or
  if (SDValue ROR = reassociateOps(ISD::OR, SDLoc(N), N0, N1, N->getFlags()))
    return ROR;

  // Canonicalize (or (and X, c1), c2) -> (and (or X, c2), c1|c2)
  // iff (c1 & c2) != 0 or c1/c2 are undef.
  auto MatchIntersect = [](ConstantSDNode *C1, ConstantSDNode *C2) {
    return !C1 || !C2 || C1->getAPIntValue().intersects(C2->getAPIntValue());
  };
  if (N0.getOpcode() == ISD::AND && N0->hasOneUse() &&
      ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchIntersect, true)) {
    if (SDValue COR = DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N1), VT,
                                                 {N1, N0.getOperand(1)})) {
      SDValue IOR = DAG.getNode(ISD::OR, SDLoc(N0), VT, N0.getOperand(0), N1);
      AddToWorklist(IOR.getNode());
      return DAG.getNode(ISD::AND, SDLoc(N), VT, COR, IOR);
    }
  }

  // Try commutative folds with operands in both orders.
  if (SDValue Combined = visitORCommutative(DAG, N0, N1, N))
    return Combined;
  if (SDValue Combined = visitORCommutative(DAG, N1, N0, N))
    return Combined;

  // Simplify: (or (op x...), (op y...))  -> (op (or x, y))
  if (N0.getOpcode() == N1.getOpcode())
    if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
      return V;

  // See if this is some rotate idiom.
  if (SDValue Rot = MatchRotate(N0, N1, SDLoc(N)))
    return Rot;

  if (SDValue Load = MatchLoadCombine(N))
    return Load;

  // Simplify the operands using demanded-bits information.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  // If OR can be rewritten into ADD, try combines based on ADD.
  if ((!LegalOperations || TLI.isOperationLegal(ISD::ADD, VT)) &&
      DAG.haveNoCommonBitsSet(N0, N1))
    if (SDValue Combined = visitADDLike(N))
      return Combined;

  return SDValue();
}
7117 
stripConstantMask(SelectionDAG & DAG,SDValue Op,SDValue & Mask)7118 static SDValue stripConstantMask(SelectionDAG &DAG, SDValue Op, SDValue &Mask) {
7119   if (Op.getOpcode() == ISD::AND &&
7120       DAG.isConstantIntBuildVectorOrConstantInt(Op.getOperand(1))) {
7121     Mask = Op.getOperand(1);
7122     return Op.getOperand(0);
7123   }
7124   return Op;
7125 }
7126 
7127 /// Match "(X shl/srl V1) & V2" where V2 may not be present.
matchRotateHalf(SelectionDAG & DAG,SDValue Op,SDValue & Shift,SDValue & Mask)7128 static bool matchRotateHalf(SelectionDAG &DAG, SDValue Op, SDValue &Shift,
7129                             SDValue &Mask) {
7130   Op = stripConstantMask(DAG, Op, Mask);
7131   if (Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SHL) {
7132     Shift = Op;
7133     return true;
7134   }
7135   return false;
7136 }
7137 
7138 /// Helper function for visitOR to extract the needed side of a rotate idiom
7139 /// from a shl/srl/mul/udiv.  This is meant to handle cases where
7140 /// InstCombine merged some outside op with one of the shifts from
7141 /// the rotate pattern.
7142 /// \returns An empty \c SDValue if the needed shift couldn't be extracted.
7143 /// Otherwise, returns an expansion of \p ExtractFrom based on the following
7144 /// patterns:
7145 ///
7146 ///   (or (add v v) (shrl v bitwidth-1)):
7147 ///     expands (add v v) -> (shl v 1)
7148 ///
7149 ///   (or (mul v c0) (shrl (mul v c1) c2)):
7150 ///     expands (mul v c0) -> (shl (mul v c1) c3)
7151 ///
7152 ///   (or (udiv v c0) (shl (udiv v c1) c2)):
7153 ///     expands (udiv v c0) -> (shrl (udiv v c1) c3)
7154 ///
7155 ///   (or (shl v c0) (shrl (shl v c1) c2)):
7156 ///     expands (shl v c0) -> (shl (shl v c1) c3)
7157 ///
7158 ///   (or (shrl v c0) (shl (shrl v c1) c2)):
7159 ///     expands (shrl v c0) -> (shrl (shrl v c1) c3)
7160 ///
7161 /// Such that in all cases, c3+c2==bitwidth(op v c1).
static SDValue extractShiftForRotate(SelectionDAG &DAG, SDValue OppShift,
                                     SDValue ExtractFrom, SDValue &Mask,
                                     const SDLoc &DL) {
  assert(OppShift && ExtractFrom && "Empty SDValue");
  // The known half of the rotate must itself be a plain shift.
  if (OppShift.getOpcode() != ISD::SHL && OppShift.getOpcode() != ISD::SRL)
    return SDValue();

  // Peel an optional constant mask off the side we are extracting from.
  ExtractFrom = stripConstantMask(DAG, ExtractFrom, Mask);

  // Value and Type of the shift.
  SDValue OppShiftLHS = OppShift.getOperand(0);
  EVT ShiftedVT = OppShiftLHS.getValueType();

  // Amount of the existing shift.
  ConstantSDNode *OppShiftCst = isConstOrConstSplat(OppShift.getOperand(1));

  // (add v v) -> (shl v 1)
  // TODO: Should this be a general DAG canonicalization?
  if (OppShift.getOpcode() == ISD::SRL && OppShiftCst &&
      ExtractFrom.getOpcode() == ISD::ADD &&
      ExtractFrom.getOperand(0) == ExtractFrom.getOperand(1) &&
      ExtractFrom.getOperand(0) == OppShiftLHS &&
      OppShiftCst->getAPIntValue() == ShiftedVT.getScalarSizeInBits() - 1)
    return DAG.getNode(ISD::SHL, DL, ShiftedVT, OppShiftLHS,
                       DAG.getShiftAmountConstant(1, ShiftedVT, DL));

  // Preconditions:
  //    (or (op0 v c0) (shiftl/r (op0 v c1) c2))
  //
  // Find opcode of the needed shift to be extracted from (op0 v c0).
  unsigned Opcode = ISD::DELETED_NODE;
  bool IsMulOrDiv = false;
  // Set Opcode and IsMulOrDiv if the extract opcode matches the needed shift
  // opcode or its arithmetic (mul or udiv) variant.
  auto SelectOpcode = [&](unsigned NeededShift, unsigned MulOrDivVariant) {
    IsMulOrDiv = ExtractFrom.getOpcode() == MulOrDivVariant;
    if (!IsMulOrDiv && ExtractFrom.getOpcode() != NeededShift)
      return false;
    Opcode = NeededShift;
    return true;
  };
  // op0 must be either the needed shift opcode or the mul/udiv equivalent
  // that the needed shift can be extracted from. The needed shift is the
  // opposite direction of the existing one (SRL pairs with SHL/MUL, SHL
  // pairs with SRL/UDIV).
  if ((OppShift.getOpcode() != ISD::SRL || !SelectOpcode(ISD::SHL, ISD::MUL)) &&
      (OppShift.getOpcode() != ISD::SHL || !SelectOpcode(ISD::SRL, ISD::UDIV)))
    return SDValue();

  // op0 must be the same opcode on both sides, have the same LHS argument,
  // and produce the same value type.
  if (OppShiftLHS.getOpcode() != ExtractFrom.getOpcode() ||
      OppShiftLHS.getOperand(0) != ExtractFrom.getOperand(0) ||
      ShiftedVT != ExtractFrom.getValueType())
    return SDValue();

  // Constant mul/udiv/shift amount from the RHS of the shift's LHS op.
  ConstantSDNode *OppLHSCst = isConstOrConstSplat(OppShiftLHS.getOperand(1));
  // Constant mul/udiv/shift amount from the RHS of the ExtractFrom op.
  ConstantSDNode *ExtractFromCst =
      isConstOrConstSplat(ExtractFrom.getOperand(1));
  // TODO: We should be able to handle non-uniform constant vectors for these
  // values.
  // Check that we have constant values (all three must also be nonzero).
  if (!OppShiftCst || !OppShiftCst->getAPIntValue() ||
      !OppLHSCst || !OppLHSCst->getAPIntValue() ||
      !ExtractFromCst || !ExtractFromCst->getAPIntValue())
    return SDValue();

  // Compute the shift amount we need to extract to complete the rotate.
  const unsigned VTWidth = ShiftedVT.getScalarSizeInBits();
  if (OppShiftCst->getAPIntValue().ugt(VTWidth))
    return SDValue();
  APInt NeededShiftAmt = VTWidth - OppShiftCst->getAPIntValue();
  // Normalize the bitwidth of the two mul/udiv/shift constant operands.
  APInt ExtractFromAmt = ExtractFromCst->getAPIntValue();
  APInt OppLHSAmt = OppLHSCst->getAPIntValue();
  zeroExtendToMatch(ExtractFromAmt, OppLHSAmt);

  // Now try extract the needed shift from the ExtractFrom op and see if the
  // result matches up with the existing shift's LHS op.
  if (IsMulOrDiv) {
    // Op to extract from is a mul or udiv by a constant.
    // Check:
    //     c2 / (1 << (bitwidth(op0 v c0) - c1)) == c0
    //     c2 % (1 << (bitwidth(op0 v c0) - c1)) == 0
    const APInt ExtractDiv = APInt::getOneBitSet(ExtractFromAmt.getBitWidth(),
                                                 NeededShiftAmt.getZExtValue());
    APInt ResultAmt;
    APInt Rem;
    APInt::udivrem(ExtractFromAmt, ExtractDiv, ResultAmt, Rem);
    if (Rem != 0 || ResultAmt != OppLHSAmt)
      return SDValue();
  } else {
    // Op to extract from is a shift by a constant.
    // Check:
    //      c2 - (bitwidth(op0 v c0) - c1) == c0
    if (OppLHSAmt != ExtractFromAmt - NeededShiftAmt.zextOrTrunc(
                                          ExtractFromAmt.getBitWidth()))
      return SDValue();
  }

  // Return the expanded shift op that should allow a rotate to be formed.
  EVT ShiftVT = OppShift.getOperand(1).getValueType();
  EVT ResVT = ExtractFrom.getValueType();
  SDValue NewShiftNode = DAG.getConstant(NeededShiftAmt, DL, ShiftVT);
  return DAG.getNode(Opcode, DL, ResVT, OppShiftLHS, NewShiftNode);
}
7267 
7268 // Return true if we can prove that, whenever Neg and Pos are both in the
7269 // range [0, EltSize), Neg == (Pos == 0 ? 0 : EltSize - Pos).  This means that
7270 // for two opposing shifts shift1 and shift2 and a value X with OpBits bits:
7271 //
7272 //     (or (shift1 X, Neg), (shift2 X, Pos))
7273 //
7274 // reduces to a rotate in direction shift2 by Pos or (equivalently) a rotate
7275 // in direction shift1 by Neg.  The range [0, EltSize) means that we only need
7276 // to consider shift amounts with defined behavior.
7277 //
7278 // The IsRotate flag should be set when the LHS of both shifts is the same.
7279 // Otherwise if matching a general funnel shift, it should be clear.
matchRotateSub(SDValue Pos,SDValue Neg,unsigned EltSize,SelectionDAG & DAG,bool IsRotate)7280 static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
7281                            SelectionDAG &DAG, bool IsRotate) {
7282   const auto &TLI = DAG.getTargetLoweringInfo();
7283   // If EltSize is a power of 2 then:
7284   //
7285   //  (a) (Pos == 0 ? 0 : EltSize - Pos) == (EltSize - Pos) & (EltSize - 1)
7286   //  (b) Neg == Neg & (EltSize - 1) whenever Neg is in [0, EltSize).
7287   //
7288   // So if EltSize is a power of 2 and Neg is (and Neg', EltSize-1), we check
7289   // for the stronger condition:
7290   //
7291   //     Neg & (EltSize - 1) == (EltSize - Pos) & (EltSize - 1)    [A]
7292   //
7293   // for all Neg and Pos.  Since Neg & (EltSize - 1) == Neg' & (EltSize - 1)
7294   // we can just replace Neg with Neg' for the rest of the function.
7295   //
7296   // In other cases we check for the even stronger condition:
7297   //
7298   //     Neg == EltSize - Pos                                    [B]
7299   //
7300   // for all Neg and Pos.  Note that the (or ...) then invokes undefined
7301   // behavior if Pos == 0 (and consequently Neg == EltSize).
7302   //
7303   // We could actually use [A] whenever EltSize is a power of 2, but the
7304   // only extra cases that it would match are those uninteresting ones
7305   // where Neg and Pos are never in range at the same time.  E.g. for
7306   // EltSize == 32, using [A] would allow a Neg of the form (sub 64, Pos)
7307   // as well as (sub 32, Pos), but:
7308   //
7309   //     (or (shift1 X, (sub 64, Pos)), (shift2 X, Pos))
7310   //
7311   // always invokes undefined behavior for 32-bit X.
7312   //
7313   // Below, Mask == EltSize - 1 when using [A] and is all-ones otherwise.
7314   // This allows us to peek through any operations that only affect Mask's
7315   // un-demanded bits.
7316   //
7317   // NOTE: We can only do this when matching operations which won't modify the
7318   // least Log2(EltSize) significant bits and not a general funnel shift.
7319   unsigned MaskLoBits = 0;
7320   if (IsRotate && isPowerOf2_64(EltSize)) {
7321     unsigned Bits = Log2_64(EltSize);
7322     unsigned NegBits = Neg.getScalarValueSizeInBits();
7323     if (NegBits >= Bits) {
7324       APInt DemandedBits = APInt::getLowBitsSet(NegBits, Bits);
7325       if (SDValue Inner =
7326               TLI.SimplifyMultipleUseDemandedBits(Neg, DemandedBits, DAG)) {
7327         Neg = Inner;
7328         MaskLoBits = Bits;
7329       }
7330     }
7331   }
7332 
7333   // Check whether Neg has the form (sub NegC, NegOp1) for some NegC and NegOp1.
7334   if (Neg.getOpcode() != ISD::SUB)
7335     return false;
7336   ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(0));
7337   if (!NegC)
7338     return false;
7339   SDValue NegOp1 = Neg.getOperand(1);
7340 
7341   // On the RHS of [A], if Pos is the result of operation on Pos' that won't
7342   // affect Mask's demanded bits, just replace Pos with Pos'. These operations
7343   // are redundant for the purpose of the equality.
7344   if (MaskLoBits) {
7345     unsigned PosBits = Pos.getScalarValueSizeInBits();
7346     if (PosBits >= MaskLoBits) {
7347       APInt DemandedBits = APInt::getLowBitsSet(PosBits, MaskLoBits);
7348       if (SDValue Inner =
7349               TLI.SimplifyMultipleUseDemandedBits(Pos, DemandedBits, DAG)) {
7350         Pos = Inner;
7351       }
7352     }
7353   }
7354 
7355   // The condition we need is now:
7356   //
7357   //     (NegC - NegOp1) & Mask == (EltSize - Pos) & Mask
7358   //
7359   // If NegOp1 == Pos then we need:
7360   //
7361   //              EltSize & Mask == NegC & Mask
7362   //
7363   // (because "x & Mask" is a truncation and distributes through subtraction).
7364   //
7365   // We also need to account for a potential truncation of NegOp1 if the amount
7366   // has already been legalized to a shift amount type.
7367   APInt Width;
7368   if ((Pos == NegOp1) ||
7369       (NegOp1.getOpcode() == ISD::TRUNCATE && Pos == NegOp1.getOperand(0)))
7370     Width = NegC->getAPIntValue();
7371 
7372   // Check for cases where Pos has the form (add NegOp1, PosC) for some PosC.
7373   // Then the condition we want to prove becomes:
7374   //
7375   //     (NegC - NegOp1) & Mask == (EltSize - (NegOp1 + PosC)) & Mask
7376   //
7377   // which, again because "x & Mask" is a truncation, becomes:
7378   //
7379   //                NegC & Mask == (EltSize - PosC) & Mask
7380   //             EltSize & Mask == (NegC + PosC) & Mask
7381   else if (Pos.getOpcode() == ISD::ADD && Pos.getOperand(0) == NegOp1) {
7382     if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1)))
7383       Width = PosC->getAPIntValue() + NegC->getAPIntValue();
7384     else
7385       return false;
7386   } else
7387     return false;
7388 
7389   // Now we just need to check that EltSize & Mask == Width & Mask.
7390   if (MaskLoBits)
7391     // EltSize & Mask is 0 since Mask is EltSize - 1.
7392     return Width.getLoBits(MaskLoBits) == 0;
7393   return Width == EltSize;
7394 }
7395 
7396 // A subroutine of MatchRotate used once we have found an OR of two opposite
7397 // shifts of Shifted.  If Neg == <operand size> - Pos then the OR reduces
7398 // to both (PosOpcode Shifted, Pos) and (NegOpcode Shifted, Neg), with the
7399 // former being preferred if supported.  InnerPos and InnerNeg are Pos and
7400 // Neg with outer conversions stripped away.
MatchRotatePosNeg(SDValue Shifted,SDValue Pos,SDValue Neg,SDValue InnerPos,SDValue InnerNeg,bool HasPos,unsigned PosOpcode,unsigned NegOpcode,const SDLoc & DL)7401 SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
7402                                        SDValue Neg, SDValue InnerPos,
7403                                        SDValue InnerNeg, bool HasPos,
7404                                        unsigned PosOpcode, unsigned NegOpcode,
7405                                        const SDLoc &DL) {
7406   // fold (or (shl x, (*ext y)),
7407   //          (srl x, (*ext (sub 32, y)))) ->
7408   //   (rotl x, y) or (rotr x, (sub 32, y))
7409   //
7410   // fold (or (shl x, (*ext (sub 32, y))),
7411   //          (srl x, (*ext y))) ->
7412   //   (rotr x, y) or (rotl x, (sub 32, y))
7413   EVT VT = Shifted.getValueType();
7414   if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits(), DAG,
7415                      /*IsRotate*/ true)) {
7416     return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, Shifted,
7417                        HasPos ? Pos : Neg);
7418   }
7419 
7420   return SDValue();
7421 }
7422 
7423 // A subroutine of MatchRotate used once we have found an OR of two opposite
7424 // shifts of N0 + N1.  If Neg == <operand size> - Pos then the OR reduces
7425 // to both (PosOpcode N0, N1, Pos) and (NegOpcode N0, N1, Neg), with the
7426 // former being preferred if supported.  InnerPos and InnerNeg are Pos and
7427 // Neg with outer conversions stripped away.
7428 // TODO: Merge with MatchRotatePosNeg.
SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
                                       SDValue Neg, SDValue InnerPos,
                                       SDValue InnerNeg, bool HasPos,
                                       unsigned PosOpcode, unsigned NegOpcode,
                                       const SDLoc &DL) {
  EVT VT = N0.getValueType();
  unsigned EltBits = VT.getScalarSizeInBits();

  // fold (or (shl x0, (*ext y)),
  //          (srl x1, (*ext (sub 32, y)))) ->
  //   (fshl x0, x1, y) or (fshr x0, x1, (sub 32, y))
  //
  // fold (or (shl x0, (*ext (sub 32, y))),
  //          (srl x1, (*ext y))) ->
  //   (fshr x0, x1, y) or (fshl x0, x1, (sub 32, y))
  //
  // When both shifted operands are the same value (N0 == N1) this is a
  // rotate, which allows matchRotateSub to peek through operations that only
  // affect the low log2(EltBits) bits of the shift amount.
  if (matchRotateSub(InnerPos, InnerNeg, EltBits, DAG, /*IsRotate*/ N0 == N1)) {
    return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, N0, N1,
                       HasPos ? Pos : Neg);
  }

  // Matching the shift+xor cases, we can't easily use the xor'd shift amount
  // so for now just use the PosOpcode case if its legal.
  // TODO: When can we use the NegOpcode case?
  if (PosOpcode == ISD::FSHL && isPowerOf2_32(EltBits)) {
    // Returns true iff Op is (BinOpc _, C) where C is a constant (or constant
    // splat) equal to Imm.
    auto IsBinOpImm = [](SDValue Op, unsigned BinOpc, unsigned Imm) {
      if (Op.getOpcode() != BinOpc)
        return false;
      ConstantSDNode *Cst = isConstOrConstSplat(Op.getOperand(1));
      return Cst && (Cst->getAPIntValue() == Imm);
    };

    // fold (or (shl x0, y), (srl (srl x1, 1), (xor y, 31)))
    //   -> (fshl x0, x1, y)
    // The extra srl-by-1 makes the (xor y, 31) amount act as (sub 32, y).
    if (IsBinOpImm(N1, ISD::SRL, 1) &&
        IsBinOpImm(InnerNeg, ISD::XOR, EltBits - 1) &&
        InnerPos == InnerNeg.getOperand(0) &&
        TLI.isOperationLegalOrCustom(ISD::FSHL, VT)) {
      return DAG.getNode(ISD::FSHL, DL, VT, N0, N1.getOperand(0), Pos);
    }

    // fold (or (shl (shl x0, 1), (xor y, 31)), (srl x1, y))
    //   -> (fshr x0, x1, y)
    // Mirror of the case above: the extra shl-by-1 complements (xor y, 31).
    if (IsBinOpImm(N0, ISD::SHL, 1) &&
        IsBinOpImm(InnerPos, ISD::XOR, EltBits - 1) &&
        InnerNeg == InnerPos.getOperand(0) &&
        TLI.isOperationLegalOrCustom(ISD::FSHR, VT)) {
      return DAG.getNode(ISD::FSHR, DL, VT, N0.getOperand(0), N1, Neg);
    }

    // fold (or (shl (add x0, x0), (xor y, 31)), (srl x1, y))
    //   -> (fshr x0, x1, y)
    // Same as the previous case with (add x0, x0) standing in for (shl x0, 1).
    // TODO: Should add(x,x) -> shl(x,1) be a general DAG canonicalization?
    if (N0.getOpcode() == ISD::ADD && N0.getOperand(0) == N0.getOperand(1) &&
        IsBinOpImm(InnerPos, ISD::XOR, EltBits - 1) &&
        InnerNeg == InnerPos.getOperand(0) &&
        TLI.isOperationLegalOrCustom(ISD::FSHR, VT)) {
      return DAG.getNode(ISD::FSHR, DL, VT, N0.getOperand(0), N1, Neg);
    }
  }

  return SDValue();
}
7491 
7492 // MatchRotate - Handle an 'or' of two operands.  If this is one of the many
7493 // idioms for rotate, and if the target supports rotation instructions, generate
7494 // a rot[lr]. This also matches funnel shift patterns, similar to rotation but
7495 // with different shifted sources.
SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
  EVT VT = LHS.getValueType();

  // The target must have at least one rotate/funnel flavor.
  // We still try to match rotate by constant pre-legalization.
  // TODO: Support pre-legalization funnel-shift by constant.
  bool HasROTL = hasOperation(ISD::ROTL, VT);
  bool HasROTR = hasOperation(ISD::ROTR, VT);
  bool HasFSHL = hasOperation(ISD::FSHL, VT);
  bool HasFSHR = hasOperation(ISD::FSHR, VT);

  // If the type is going to be promoted and the target has enabled custom
  // lowering for rotate, allow matching rotate by non-constants. Only allow
  // this for scalar types.
  if (VT.isScalarInteger() && TLI.getTypeAction(*DAG.getContext(), VT) ==
                                  TargetLowering::TypePromoteInteger) {
    HasROTL |= TLI.getOperationAction(ISD::ROTL, VT) == TargetLowering::Custom;
    HasROTR |= TLI.getOperationAction(ISD::ROTR, VT) == TargetLowering::Custom;
  }

  if (LegalOperations && !HasROTL && !HasROTR && !HasFSHL && !HasFSHR)
    return SDValue();

  // Check for truncated rotate.
  // If both halves are truncated from the same wider type, try to match the
  // rotate at the wider width and truncate the result instead.
  if (LHS.getOpcode() == ISD::TRUNCATE && RHS.getOpcode() == ISD::TRUNCATE &&
      LHS.getOperand(0).getValueType() == RHS.getOperand(0).getValueType()) {
    assert(LHS.getValueType() == RHS.getValueType());
    if (SDValue Rot = MatchRotate(LHS.getOperand(0), RHS.getOperand(0), DL)) {
      return DAG.getNode(ISD::TRUNCATE, SDLoc(LHS), LHS.getValueType(), Rot);
    }
  }

  // Match "(X shl/srl V1) & V2" where V2 may not be present.
  SDValue LHSShift;   // The shift.
  SDValue LHSMask;    // AND value if any.
  matchRotateHalf(DAG, LHS, LHSShift, LHSMask);

  SDValue RHSShift;   // The shift.
  SDValue RHSMask;    // AND value if any.
  matchRotateHalf(DAG, RHS, RHSShift, RHSMask);

  // If neither side matched a rotate half, bail
  if (!LHSShift && !RHSShift)
    return SDValue();

  // InstCombine may have combined a constant shl, srl, mul, or udiv with one
  // side of the rotate, so try to handle that here. In all cases we need to
  // pass the matched shift from the opposite side to compute the opcode and
  // needed shift amount to extract.  We still want to do this if both sides
  // matched a rotate half because one half may be a potential overshift that
  // can be broken down (ie if InstCombine merged two shl or srl ops into a
  // single one).

  // Have LHS side of the rotate, try to extract the needed shift from the RHS.
  if (LHSShift)
    if (SDValue NewRHSShift =
            extractShiftForRotate(DAG, LHSShift, RHS, RHSMask, DL))
      RHSShift = NewRHSShift;
  // Have RHS side of the rotate, try to extract the needed shift from the LHS.
  if (RHSShift)
    if (SDValue NewLHSShift =
            extractShiftForRotate(DAG, RHSShift, LHS, LHSMask, DL))
      LHSShift = NewLHSShift;

  // If a side is still missing, nothing else we can do.
  if (!RHSShift || !LHSShift)
    return SDValue();

  // At this point we've matched or extracted a shift op on each side.

  if (LHSShift.getOpcode() == RHSShift.getOpcode())
    return SDValue(); // Shifts must disagree.

  // Canonicalize shl to left side in a shl/srl pair.
  if (RHSShift.getOpcode() == ISD::SHL) {
    std::swap(LHS, RHS);
    std::swap(LHSShift, RHSShift);
    std::swap(LHSMask, RHSMask);
  }

  // Something has gone wrong - we've lost the shl/srl pair - bail.
  if (LHSShift.getOpcode() != ISD::SHL || RHSShift.getOpcode() != ISD::SRL)
    return SDValue();

  unsigned EltSizeInBits = VT.getScalarSizeInBits();
  SDValue LHSShiftArg = LHSShift.getOperand(0);
  SDValue LHSShiftAmt = LHSShift.getOperand(1);
  SDValue RHSShiftArg = RHSShift.getOperand(0);
  SDValue RHSShiftAmt = RHSShift.getOperand(1);

  // True iff the two constant shift amounts sum to exactly the element width.
  auto MatchRotateSum = [EltSizeInBits](ConstantSDNode *LHS,
                                        ConstantSDNode *RHS) {
    return (LHS->getAPIntValue() + RHS->getAPIntValue()) == EltSizeInBits;
  };

  // Re-apply any AND masks that were matched on either shifted operand to the
  // combined result so the fold stays bit-accurate.
  auto ApplyMasks = [&](SDValue Res) {
    // If there is an AND of either shifted operand, apply it to the result.
    if (LHSMask.getNode() || RHSMask.getNode()) {
      SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
      SDValue Mask = AllOnes;

      if (LHSMask.getNode()) {
        // Keep the bits contributed by the RHS half untouched; only the LHS
        // half's bits are subject to LHSMask.
        SDValue RHSBits = DAG.getNode(ISD::SRL, DL, VT, AllOnes, RHSShiftAmt);
        Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
                           DAG.getNode(ISD::OR, DL, VT, LHSMask, RHSBits));
      }
      if (RHSMask.getNode()) {
        // Symmetric case for the RHS half's mask.
        SDValue LHSBits = DAG.getNode(ISD::SHL, DL, VT, AllOnes, LHSShiftAmt);
        Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
                           DAG.getNode(ISD::OR, DL, VT, RHSMask, LHSBits));
      }

      Res = DAG.getNode(ISD::AND, DL, VT, Res, Mask);
    }

    return Res;
  };

  // TODO: Support pre-legalization funnel-shift by constant.
  // A rotate has the same value shifted in both directions; otherwise it can
  // only be a funnel shift.
  bool IsRotate = LHSShift.getOperand(0) == RHSShift.getOperand(0);
  if (!IsRotate && !(HasFSHL || HasFSHR)) {
    if (TLI.isTypeLegal(VT) && LHS.hasOneUse() && RHS.hasOneUse() &&
        ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) {
      // Look for a disguised rotate by constant.
      // The common shifted operand X may be hidden inside another 'or'.
      SDValue X, Y;
      auto matchOr = [&X, &Y](SDValue Or, SDValue CommonOp) {
        if (!Or.hasOneUse() || Or.getOpcode() != ISD::OR)
          return false;
        if (CommonOp == Or.getOperand(0)) {
          X = CommonOp;
          Y = Or.getOperand(1);
          return true;
        }
        if (CommonOp == Or.getOperand(1)) {
          X = CommonOp;
          Y = Or.getOperand(0);
          return true;
        }
        return false;
      };

      SDValue Res;
      if (matchOr(LHSShiftArg, RHSShiftArg)) {
        // (shl (X | Y), C1) | (srl X, C2) --> (rotl X, C1) | (shl Y, C1)
        SDValue RotX = DAG.getNode(ISD::ROTL, DL, VT, X, LHSShiftAmt);
        SDValue ShlY = DAG.getNode(ISD::SHL, DL, VT, Y, LHSShiftAmt);
        Res = DAG.getNode(ISD::OR, DL, VT, RotX, ShlY);
      } else if (matchOr(RHSShiftArg, LHSShiftArg)) {
        // (shl X, C1) | (srl (X | Y), C2) --> (rotl X, C1) | (srl Y, C2)
        SDValue RotX = DAG.getNode(ISD::ROTL, DL, VT, X, LHSShiftAmt);
        SDValue SrlY = DAG.getNode(ISD::SRL, DL, VT, Y, RHSShiftAmt);
        Res = DAG.getNode(ISD::OR, DL, VT, RotX, SrlY);
      } else {
        return SDValue();
      }

      return ApplyMasks(Res);
    }

    return SDValue(); // Requires funnel shift support.
  }

  // fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1)
  // fold (or (shl x, C1), (srl x, C2)) -> (rotr x, C2)
  // fold (or (shl x, C1), (srl y, C2)) -> (fshl x, y, C1)
  // fold (or (shl x, C1), (srl y, C2)) -> (fshr x, y, C2)
  // iff C1+C2 == EltSizeInBits
  if (ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) {
    SDValue Res;
    if (IsRotate && (HasROTL || HasROTR || !(HasFSHL || HasFSHR))) {
      bool UseROTL = !LegalOperations || HasROTL;
      Res = DAG.getNode(UseROTL ? ISD::ROTL : ISD::ROTR, DL, VT, LHSShiftArg,
                        UseROTL ? LHSShiftAmt : RHSShiftAmt);
    } else {
      bool UseFSHL = !LegalOperations || HasFSHL;
      Res = DAG.getNode(UseFSHL ? ISD::FSHL : ISD::FSHR, DL, VT, LHSShiftArg,
                        RHSShiftArg, UseFSHL ? LHSShiftAmt : RHSShiftAmt);
    }

    return ApplyMasks(Res);
  }

  // Even pre-legalization, we can't easily rotate/funnel-shift by a variable
  // shift.
  if (!HasROTL && !HasROTR && !HasFSHL && !HasFSHR)
    return SDValue();

  // If there is a mask here, and we have a variable shift, we can't be sure
  // that we're masking out the right stuff.
  if (LHSMask.getNode() || RHSMask.getNode())
    return SDValue();

  // If the shift amount is sign/zext/any-extended just peel it off.
  SDValue LExtOp0 = LHSShiftAmt;
  SDValue RExtOp0 = RHSShiftAmt;
  if ((LHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
       LHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
       LHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
       LHSShiftAmt.getOpcode() == ISD::TRUNCATE) &&
      (RHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
       RHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
       RHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
       RHSShiftAmt.getOpcode() == ISD::TRUNCATE)) {
    LExtOp0 = LHSShiftAmt.getOperand(0);
    RExtOp0 = RHSShiftAmt.getOperand(0);
  }

  // Try matching a rotate in both directions before falling back to funnel
  // shifts.
  if (IsRotate && (HasROTL || HasROTR)) {
    SDValue TryL =
        MatchRotatePosNeg(LHSShiftArg, LHSShiftAmt, RHSShiftAmt, LExtOp0,
                          RExtOp0, HasROTL, ISD::ROTL, ISD::ROTR, DL);
    if (TryL)
      return TryL;

    SDValue TryR =
        MatchRotatePosNeg(RHSShiftArg, RHSShiftAmt, LHSShiftAmt, RExtOp0,
                          LExtOp0, HasROTR, ISD::ROTR, ISD::ROTL, DL);
    if (TryR)
      return TryR;
  }

  SDValue TryL =
      MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, LHSShiftAmt, RHSShiftAmt,
                        LExtOp0, RExtOp0, HasFSHL, ISD::FSHL, ISD::FSHR, DL);
  if (TryL)
    return TryL;

  SDValue TryR =
      MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, RHSShiftAmt, LHSShiftAmt,
                        RExtOp0, LExtOp0, HasFSHR, ISD::FSHR, ISD::FSHL, DL);
  if (TryR)
    return TryR;

  return SDValue();
}
7732 
7733 namespace {
7734 
7735 /// Represents known origin of an individual byte in load combine pattern. The
7736 /// value of the byte is either constant zero or comes from memory.
7737 struct ByteProvider {
7738   // For constant zero providers Load is set to nullptr. For memory providers
7739   // Load represents the node which loads the byte from memory.
7740   // ByteOffset is the offset of the byte in the value produced by the load.
7741   LoadSDNode *Load = nullptr;
7742   unsigned ByteOffset = 0;
7743 
7744   ByteProvider() = default;
7745 
getMemory__anon54f00e401611::ByteProvider7746   static ByteProvider getMemory(LoadSDNode *Load, unsigned ByteOffset) {
7747     return ByteProvider(Load, ByteOffset);
7748   }
7749 
getConstantZero__anon54f00e401611::ByteProvider7750   static ByteProvider getConstantZero() { return ByteProvider(nullptr, 0); }
7751 
isConstantZero__anon54f00e401611::ByteProvider7752   bool isConstantZero() const { return !Load; }
isMemory__anon54f00e401611::ByteProvider7753   bool isMemory() const { return Load; }
7754 
operator ==__anon54f00e401611::ByteProvider7755   bool operator==(const ByteProvider &Other) const {
7756     return Other.Load == Load && Other.ByteOffset == ByteOffset;
7757   }
7758 
7759 private:
ByteProvider__anon54f00e401611::ByteProvider7760   ByteProvider(LoadSDNode *Load, unsigned ByteOffset)
7761       : Load(Load), ByteOffset(ByteOffset) {}
7762 };
7763 
7764 } // end anonymous namespace
7765 
7766 /// Recursively traverses the expression calculating the origin of the requested
7767 /// byte of the given value. Returns None if the provider can't be calculated.
7768 ///
7769 /// For all the values except the root of the expression verifies that the value
7770 /// has exactly one use and if it's not true return None. This way if the origin
7771 /// of the byte is returned it's guaranteed that the values which contribute to
7772 /// the byte are not used outside of this expression.
7773 ///
7774 /// Because the parts of the expression are not allowed to have more than one
7775 /// use this function iterates over trees, not DAGs. So it never visits the same
7776 /// node more than once.
7777 static const Optional<ByteProvider>
calculateByteProvider(SDValue Op,unsigned Index,unsigned Depth,bool Root=false)7778 calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
7779                       bool Root = false) {
7780   // Typical i64 by i8 pattern requires recursion up to 8 calls depth
7781   if (Depth == 10)
7782     return None;
7783 
7784   if (!Root && !Op.hasOneUse())
7785     return None;
7786 
7787   assert(Op.getValueType().isScalarInteger() && "can't handle other types");
7788   unsigned BitWidth = Op.getValueSizeInBits();
7789   if (BitWidth % 8 != 0)
7790     return None;
7791   unsigned ByteWidth = BitWidth / 8;
7792   assert(Index < ByteWidth && "invalid index requested");
7793   (void) ByteWidth;
7794 
7795   switch (Op.getOpcode()) {
7796   case ISD::OR: {
7797     auto LHS = calculateByteProvider(Op->getOperand(0), Index, Depth + 1);
7798     if (!LHS)
7799       return None;
7800     auto RHS = calculateByteProvider(Op->getOperand(1), Index, Depth + 1);
7801     if (!RHS)
7802       return None;
7803 
7804     if (LHS->isConstantZero())
7805       return RHS;
7806     if (RHS->isConstantZero())
7807       return LHS;
7808     return None;
7809   }
7810   case ISD::SHL: {
7811     auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
7812     if (!ShiftOp)
7813       return None;
7814 
7815     uint64_t BitShift = ShiftOp->getZExtValue();
7816     if (BitShift % 8 != 0)
7817       return None;
7818     uint64_t ByteShift = BitShift / 8;
7819 
7820     return Index < ByteShift
7821                ? ByteProvider::getConstantZero()
7822                : calculateByteProvider(Op->getOperand(0), Index - ByteShift,
7823                                        Depth + 1);
7824   }
7825   case ISD::ANY_EXTEND:
7826   case ISD::SIGN_EXTEND:
7827   case ISD::ZERO_EXTEND: {
7828     SDValue NarrowOp = Op->getOperand(0);
7829     unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
7830     if (NarrowBitWidth % 8 != 0)
7831       return None;
7832     uint64_t NarrowByteWidth = NarrowBitWidth / 8;
7833 
7834     if (Index >= NarrowByteWidth)
7835       return Op.getOpcode() == ISD::ZERO_EXTEND
7836                  ? Optional<ByteProvider>(ByteProvider::getConstantZero())
7837                  : None;
7838     return calculateByteProvider(NarrowOp, Index, Depth + 1);
7839   }
7840   case ISD::BSWAP:
7841     return calculateByteProvider(Op->getOperand(0), ByteWidth - Index - 1,
7842                                  Depth + 1);
7843   case ISD::LOAD: {
7844     auto L = cast<LoadSDNode>(Op.getNode());
7845     if (!L->isSimple() || L->isIndexed())
7846       return None;
7847 
7848     unsigned NarrowBitWidth = L->getMemoryVT().getSizeInBits();
7849     if (NarrowBitWidth % 8 != 0)
7850       return None;
7851     uint64_t NarrowByteWidth = NarrowBitWidth / 8;
7852 
7853     if (Index >= NarrowByteWidth)
7854       return L->getExtensionType() == ISD::ZEXTLOAD
7855                  ? Optional<ByteProvider>(ByteProvider::getConstantZero())
7856                  : None;
7857     return ByteProvider::getMemory(L, Index);
7858   }
7859   }
7860 
7861   return None;
7862 }
7863 
// In a little-endian value of BW bytes, logical byte i lives at memory
// offset i — the identity mapping.
static unsigned littleEndianByteAt(unsigned BW, unsigned i) {
  (void)BW;
  return i;
}
7867 
// In a big-endian value of BW bytes, logical byte i lives at the mirrored
// memory offset.
static unsigned bigEndianByteAt(unsigned BW, unsigned i) {
  return BW - 1 - i;
}
7871 
7872 // Check if the bytes offsets we are looking at match with either big or
7873 // little endian value loaded. Return true for big endian, false for little
7874 // endian, and None if match failed.
isBigEndian(const ArrayRef<int64_t> ByteOffsets,int64_t FirstOffset)7875 static Optional<bool> isBigEndian(const ArrayRef<int64_t> ByteOffsets,
7876                                   int64_t FirstOffset) {
7877   // The endian can be decided only when it is 2 bytes at least.
7878   unsigned Width = ByteOffsets.size();
7879   if (Width < 2)
7880     return None;
7881 
7882   bool BigEndian = true, LittleEndian = true;
7883   for (unsigned i = 0; i < Width; i++) {
7884     int64_t CurrentByteOffset = ByteOffsets[i] - FirstOffset;
7885     LittleEndian &= CurrentByteOffset == littleEndianByteAt(Width, i);
7886     BigEndian &= CurrentByteOffset == bigEndianByteAt(Width, i);
7887     if (!BigEndian && !LittleEndian)
7888       return None;
7889   }
7890 
7891   assert((BigEndian != LittleEndian) && "It should be either big endian or"
7892                                         "little endian");
7893   return BigEndian;
7894 }
7895 
stripTruncAndExt(SDValue Value)7896 static SDValue stripTruncAndExt(SDValue Value) {
7897   switch (Value.getOpcode()) {
7898   case ISD::TRUNCATE:
7899   case ISD::ZERO_EXTEND:
7900   case ISD::SIGN_EXTEND:
7901   case ISD::ANY_EXTEND:
7902     return stripTruncAndExt(Value.getOperand(0));
7903   }
7904   return Value;
7905 }
7906 
7907 /// Match a pattern where a wide type scalar value is stored by several narrow
7908 /// stores. Fold it into a single store or a BSWAP and a store if the targets
7909 /// supports it.
7910 ///
7911 /// Assuming little endian target:
7912 ///  i8 *p = ...
7913 ///  i32 val = ...
7914 ///  p[0] = (val >> 0) & 0xFF;
7915 ///  p[1] = (val >> 8) & 0xFF;
7916 ///  p[2] = (val >> 16) & 0xFF;
7917 ///  p[3] = (val >> 24) & 0xFF;
7918 /// =>
7919 ///  *((i32)p) = val;
7920 ///
7921 ///  i8 *p = ...
7922 ///  i32 val = ...
7923 ///  p[0] = (val >> 24) & 0xFF;
7924 ///  p[1] = (val >> 16) & 0xFF;
7925 ///  p[2] = (val >> 8) & 0xFF;
7926 ///  p[3] = (val >> 0) & 0xFF;
7927 /// =>
7928 ///  *((i32)p) = BSWAP(val);
mergeTruncStores(StoreSDNode * N)7929 SDValue DAGCombiner::mergeTruncStores(StoreSDNode *N) {
7930   // The matching looks for "store (trunc x)" patterns that appear early but are
7931   // likely to be replaced by truncating store nodes during combining.
7932   // TODO: If there is evidence that running this later would help, this
7933   //       limitation could be removed. Legality checks may need to be added
7934   //       for the created store and optional bswap/rotate.
7935   if (LegalOperations || OptLevel == CodeGenOpt::None)
7936     return SDValue();
7937 
7938   // We only handle merging simple stores of 1-4 bytes.
7939   // TODO: Allow unordered atomics when wider type is legal (see D66309)
7940   EVT MemVT = N->getMemoryVT();
7941   if (!(MemVT == MVT::i8 || MemVT == MVT::i16 || MemVT == MVT::i32) ||
7942       !N->isSimple() || N->isIndexed())
7943     return SDValue();
7944 
7945   // Collect all of the stores in the chain.
7946   SDValue Chain = N->getChain();
7947   SmallVector<StoreSDNode *, 8> Stores = {N};
7948   while (auto *Store = dyn_cast<StoreSDNode>(Chain)) {
7949     // All stores must be the same size to ensure that we are writing all of the
7950     // bytes in the wide value.
7951     // TODO: We could allow multiple sizes by tracking each stored byte.
7952     if (Store->getMemoryVT() != MemVT || !Store->isSimple() ||
7953         Store->isIndexed())
7954       return SDValue();
7955     Stores.push_back(Store);
7956     Chain = Store->getChain();
7957   }
7958   // There is no reason to continue if we do not have at least a pair of stores.
7959   if (Stores.size() < 2)
7960     return SDValue();
7961 
7962   // Handle simple types only.
7963   LLVMContext &Context = *DAG.getContext();
7964   unsigned NumStores = Stores.size();
7965   unsigned NarrowNumBits = N->getMemoryVT().getScalarSizeInBits();
7966   unsigned WideNumBits = NumStores * NarrowNumBits;
7967   EVT WideVT = EVT::getIntegerVT(Context, WideNumBits);
7968   if (WideVT != MVT::i16 && WideVT != MVT::i32 && WideVT != MVT::i64)
7969     return SDValue();
7970 
7971   // Check if all bytes of the source value that we are looking at are stored
7972   // to the same base address. Collect offsets from Base address into OffsetMap.
7973   SDValue SourceValue;
7974   SmallVector<int64_t, 8> OffsetMap(NumStores, INT64_MAX);
7975   int64_t FirstOffset = INT64_MAX;
7976   StoreSDNode *FirstStore = nullptr;
7977   Optional<BaseIndexOffset> Base;
7978   for (auto *Store : Stores) {
7979     // All the stores store different parts of the CombinedValue. A truncate is
7980     // required to get the partial value.
7981     SDValue Trunc = Store->getValue();
7982     if (Trunc.getOpcode() != ISD::TRUNCATE)
7983       return SDValue();
7984     // Other than the first/last part, a shift operation is required to get the
7985     // offset.
7986     int64_t Offset = 0;
7987     SDValue WideVal = Trunc.getOperand(0);
7988     if ((WideVal.getOpcode() == ISD::SRL || WideVal.getOpcode() == ISD::SRA) &&
7989         isa<ConstantSDNode>(WideVal.getOperand(1))) {
7990       // The shift amount must be a constant multiple of the narrow type.
7991       // It is translated to the offset address in the wide source value "y".
7992       //
7993       // x = srl y, ShiftAmtC
7994       // i8 z = trunc x
7995       // store z, ...
7996       uint64_t ShiftAmtC = WideVal.getConstantOperandVal(1);
7997       if (ShiftAmtC % NarrowNumBits != 0)
7998         return SDValue();
7999 
8000       Offset = ShiftAmtC / NarrowNumBits;
8001       WideVal = WideVal.getOperand(0);
8002     }
8003 
8004     // Stores must share the same source value with different offsets.
8005     // Truncate and extends should be stripped to get the single source value.
8006     if (!SourceValue)
8007       SourceValue = WideVal;
8008     else if (stripTruncAndExt(SourceValue) != stripTruncAndExt(WideVal))
8009       return SDValue();
8010     else if (SourceValue.getValueType() != WideVT) {
8011       if (WideVal.getValueType() == WideVT ||
8012           WideVal.getScalarValueSizeInBits() >
8013               SourceValue.getScalarValueSizeInBits())
8014         SourceValue = WideVal;
8015       // Give up if the source value type is smaller than the store size.
8016       if (SourceValue.getScalarValueSizeInBits() < WideVT.getScalarSizeInBits())
8017         return SDValue();
8018     }
8019 
8020     // Stores must share the same base address.
8021     BaseIndexOffset Ptr = BaseIndexOffset::match(Store, DAG);
8022     int64_t ByteOffsetFromBase = 0;
8023     if (!Base)
8024       Base = Ptr;
8025     else if (!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
8026       return SDValue();
8027 
8028     // Remember the first store.
8029     if (ByteOffsetFromBase < FirstOffset) {
8030       FirstStore = Store;
8031       FirstOffset = ByteOffsetFromBase;
8032     }
8033     // Map the offset in the store and the offset in the combined value, and
8034     // early return if it has been set before.
8035     if (Offset < 0 || Offset >= NumStores || OffsetMap[Offset] != INT64_MAX)
8036       return SDValue();
8037     OffsetMap[Offset] = ByteOffsetFromBase;
8038   }
8039 
8040   assert(FirstOffset != INT64_MAX && "First byte offset must be set");
8041   assert(FirstStore && "First store must be set");
8042 
8043   // Check that a store of the wide type is both allowed and fast on the target
8044   const DataLayout &Layout = DAG.getDataLayout();
8045   bool Fast = false;
8046   bool Allowed = TLI.allowsMemoryAccess(Context, Layout, WideVT,
8047                                         *FirstStore->getMemOperand(), &Fast);
8048   if (!Allowed || !Fast)
8049     return SDValue();
8050 
8051   // Check if the pieces of the value are going to the expected places in memory
8052   // to merge the stores.
8053   auto checkOffsets = [&](bool MatchLittleEndian) {
8054     if (MatchLittleEndian) {
8055       for (unsigned i = 0; i != NumStores; ++i)
8056         if (OffsetMap[i] != i * (NarrowNumBits / 8) + FirstOffset)
8057           return false;
8058     } else { // MatchBigEndian by reversing loop counter.
8059       for (unsigned i = 0, j = NumStores - 1; i != NumStores; ++i, --j)
8060         if (OffsetMap[j] != i * (NarrowNumBits / 8) + FirstOffset)
8061           return false;
8062     }
8063     return true;
8064   };
8065 
8066   // Check if the offsets line up for the native data layout of this target.
8067   bool NeedBswap = false;
8068   bool NeedRotate = false;
8069   if (!checkOffsets(Layout.isLittleEndian())) {
8070     // Special-case: check if byte offsets line up for the opposite endian.
8071     if (NarrowNumBits == 8 && checkOffsets(Layout.isBigEndian()))
8072       NeedBswap = true;
8073     else if (NumStores == 2 && checkOffsets(Layout.isBigEndian()))
8074       NeedRotate = true;
8075     else
8076       return SDValue();
8077   }
8078 
8079   SDLoc DL(N);
8080   if (WideVT != SourceValue.getValueType()) {
8081     assert(SourceValue.getValueType().getScalarSizeInBits() > WideNumBits &&
8082            "Unexpected store value to merge");
8083     SourceValue = DAG.getNode(ISD::TRUNCATE, DL, WideVT, SourceValue);
8084   }
8085 
8086   // Before legalize we can introduce illegal bswaps/rotates which will be later
8087   // converted to an explicit bswap sequence. This way we end up with a single
8088   // store and byte shuffling instead of several stores and byte shuffling.
8089   if (NeedBswap) {
8090     SourceValue = DAG.getNode(ISD::BSWAP, DL, WideVT, SourceValue);
8091   } else if (NeedRotate) {
8092     assert(WideNumBits % 2 == 0 && "Unexpected type for rotate");
8093     SDValue RotAmt = DAG.getConstant(WideNumBits / 2, DL, WideVT);
8094     SourceValue = DAG.getNode(ISD::ROTR, DL, WideVT, SourceValue, RotAmt);
8095   }
8096 
8097   SDValue NewStore =
8098       DAG.getStore(Chain, DL, SourceValue, FirstStore->getBasePtr(),
8099                    FirstStore->getPointerInfo(), FirstStore->getAlign());
8100 
8101   // Rely on other DAG combine rules to remove the other individual stores.
8102   DAG.ReplaceAllUsesWith(N, NewStore.getNode());
8103   return NewStore;
8104 }
8105 
8106 /// Match a pattern where a wide type scalar value is loaded by several narrow
8107 /// loads and combined by shifts and ors. Fold it into a single load or a load
8108 /// and a BSWAP if the targets supports it.
8109 ///
8110 /// Assuming little endian target:
8111 ///  i8 *a = ...
8112 ///  i32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
8113 /// =>
8114 ///  i32 val = *((i32)a)
8115 ///
8116 ///  i8 *a = ...
8117 ///  i32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
8118 /// =>
8119 ///  i32 val = BSWAP(*((i32)a))
8120 ///
8121 /// TODO: This rule matches complex patterns with OR node roots and doesn't
8122 /// interact well with the worklist mechanism. When a part of the pattern is
8123 /// updated (e.g. one of the loads) its direct users are put into the worklist,
8124 /// but the root node of the pattern which triggers the load combine is not
8125 /// necessarily a direct user of the changed node. For example, once the address
8126 /// of t28 load is reassociated load combine won't be triggered:
8127 ///             t25: i32 = add t4, Constant:i32<2>
8128 ///           t26: i64 = sign_extend t25
8129 ///        t27: i64 = add t2, t26
8130 ///       t28: i8,ch = load<LD1[%tmp9]> t0, t27, undef:i64
8131 ///     t29: i32 = zero_extend t28
8132 ///   t32: i32 = shl t29, Constant:i8<8>
8133 /// t33: i32 = or t23, t32
8134 /// As a possible fix visitLoad can check if the load can be a part of a load
8135 /// combine pattern and add corresponding OR roots to the worklist.
SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
  assert(N->getOpcode() == ISD::OR &&
         "Can only match load combining against OR nodes");

  // Handles simple types only
  EVT VT = N->getValueType(0);
  if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
    return SDValue();
  unsigned ByteWidth = VT.getSizeInBits() / 8;

  // Map a byte index inside a load's result value to the byte's offset within
  // the memory the load reads, accounting for target endianness.
  bool IsBigEndianTarget = DAG.getDataLayout().isBigEndian();
  auto MemoryByteOffset = [&] (ByteProvider P) {
    assert(P.isMemory() && "Must be a memory byte provider");
    unsigned LoadBitWidth = P.Load->getMemoryVT().getSizeInBits();
    assert(LoadBitWidth % 8 == 0 &&
           "can only analyze providers for individual bytes not bit");
    unsigned LoadByteWidth = LoadBitWidth / 8;
    return IsBigEndianTarget
            ? bigEndianByteAt(LoadByteWidth, P.ByteOffset)
            : littleEndianByteAt(LoadByteWidth, P.ByteOffset);
  };

  Optional<BaseIndexOffset> Base;
  SDValue Chain;

  SmallPtrSet<LoadSDNode *, 8> Loads;
  Optional<ByteProvider> FirstByteProvider;
  int64_t FirstOffset = INT64_MAX;

  // Check if all the bytes of the OR we are looking at are loaded from the same
  // base address. Collect bytes offsets from Base address in ByteOffsets.
  SmallVector<int64_t, 8> ByteOffsets(ByteWidth);
  unsigned ZeroExtendedBytes = 0;
  // Walk from the most significant byte down so that a run of constant-zero
  // high bytes (a zero-extended load) can be recognized as it is encountered.
  for (int i = ByteWidth - 1; i >= 0; --i) {
    auto P = calculateByteProvider(SDValue(N, 0), i, 0, /*Root=*/true);
    if (!P)
      return SDValue();

    if (P->isConstantZero()) {
      // It's OK for the N most significant bytes to be 0, we can just
      // zero-extend the load.
      if (++ZeroExtendedBytes != (ByteWidth - static_cast<unsigned>(i)))
        return SDValue();
      continue;
    }
    assert(P->isMemory() && "provenance should either be memory or zero");

    LoadSDNode *L = P->Load;
    assert(L->hasNUsesOfValue(1, 0) && L->isSimple() &&
           !L->isIndexed() &&
           "Must be enforced by calculateByteProvider");
    assert(L->getOffset().isUndef() && "Unindexed load must have undef offset");

    // All loads must share the same chain
    SDValue LChain = L->getChain();
    if (!Chain)
      Chain = LChain;
    else if (Chain != LChain)
      return SDValue();

    // Loads must share the same base address
    BaseIndexOffset Ptr = BaseIndexOffset::match(L, DAG);
    int64_t ByteOffsetFromBase = 0;
    if (!Base)
      Base = Ptr;
    else if (!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
      return SDValue();

    // Calculate the offset of the current byte from the base address
    ByteOffsetFromBase += MemoryByteOffset(*P);
    ByteOffsets[i] = ByteOffsetFromBase;

    // Remember the first byte load
    if (ByteOffsetFromBase < FirstOffset) {
      FirstByteProvider = P;
      FirstOffset = ByteOffsetFromBase;
    }

    Loads.insert(L);
  }
  assert(!Loads.empty() && "All the bytes of the value must be loaded from "
         "memory, so there must be at least one load which produces the value");
  assert(Base && "Base address of the accessed memory location must be set");
  assert(FirstOffset != INT64_MAX && "First byte offset must be set");

  bool NeedsZext = ZeroExtendedBytes > 0;

  // The memory type only covers the bytes actually provided by loads; the
  // zero high bytes are produced by zero-extending the load.
  EVT MemVT =
      EVT::getIntegerVT(*DAG.getContext(), (ByteWidth - ZeroExtendedBytes) * 8);

  if (!MemVT.isSimple())
    return SDValue();

  // Before legalize we can introduce too wide illegal loads which will be later
  // split into legal sized loads. This enables us to combine i64 load by i8
  // patterns to a couple of i32 loads on 32 bit targets.
  if (LegalOperations &&
      !TLI.isOperationLegal(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD,
                            MemVT))
    return SDValue();

  // Check if the bytes of the OR we are looking at match with either big or
  // little endian value load
  Optional<bool> IsBigEndian = isBigEndian(
      makeArrayRef(ByteOffsets).drop_back(ZeroExtendedBytes), FirstOffset);
  if (!IsBigEndian)
    return SDValue();

  assert(FirstByteProvider && "must be set");

  // Ensure that the first byte is loaded from zero offset of the first load.
  // So the combined value can be loaded from the first load address.
  if (MemoryByteOffset(*FirstByteProvider) != 0)
    return SDValue();
  LoadSDNode *FirstLoad = FirstByteProvider->Load;

  // The node we are looking at matches with the pattern, check if we can
  // replace it with a single (possibly zero-extended) load and bswap + shift if
  // needed.

  // If the load needs byte swap check if the target supports it
  bool NeedsBswap = IsBigEndianTarget != *IsBigEndian;

  // Before legalize we can introduce illegal bswaps which will be later
  // converted to an explicit bswap sequence. This way we end up with a single
  // load and byte shuffling instead of several loads and byte shuffling.
  // We do not introduce illegal bswaps when zero-extending as this tends to
  // introduce too many arithmetic instructions.
  if (NeedsBswap && (LegalOperations || NeedsZext) &&
      !TLI.isOperationLegal(ISD::BSWAP, VT))
    return SDValue();

  // If we need to bswap and zero extend, we have to insert a shift. Check that
  // it is legal.
  if (NeedsBswap && NeedsZext && LegalOperations &&
      !TLI.isOperationLegal(ISD::SHL, VT))
    return SDValue();

  // Check that a load of the wide type is both allowed and fast on the target
  bool Fast = false;
  bool Allowed =
      TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
                             *FirstLoad->getMemOperand(), &Fast);
  if (!Allowed || !Fast)
    return SDValue();

  SDValue NewLoad =
      DAG.getExtLoad(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, SDLoc(N), VT,
                     Chain, FirstLoad->getBasePtr(),
                     FirstLoad->getPointerInfo(), MemVT, FirstLoad->getAlign());

  // Transfer chain users from old loads to the new load.
  for (LoadSDNode *L : Loads)
    DAG.ReplaceAllUsesOfValueWith(SDValue(L, 1), SDValue(NewLoad.getNode(), 1));

  if (!NeedsBswap)
    return NewLoad;

  // When both zero-extending and byte-swapping, shift the loaded bits up first
  // so the subsequent BSWAP places the zero bytes in the low end.
  SDValue ShiftedLoad =
      NeedsZext
          ? DAG.getNode(ISD::SHL, SDLoc(N), VT, NewLoad,
                        DAG.getShiftAmountConstant(ZeroExtendedBytes * 8, VT,
                                                   SDLoc(N), LegalOperations))
          : NewLoad;
  return DAG.getNode(ISD::BSWAP, SDLoc(N), VT, ShiftedLoad);
}
8302 
8303 // If the target has andn, bsl, or a similar bit-select instruction,
8304 // we want to unfold masked merge, with canonical pattern of:
8305 //   |        A  |  |B|
8306 //   ((x ^ y) & m) ^ y
8307 //    |  D  |
8308 // Into:
8309 //   (x & m) | (y & ~m)
8310 // If y is a constant, m is not a 'not', and the 'andn' does not work with
8311 // immediates, we unfold into a different pattern:
8312 //   ~(~x & m) & (m | y)
8313 // If x is a constant, m is a 'not', and the 'andn' does not work with
8314 // immediates, we unfold into a different pattern:
8315 //   (x | ~m) & ~(~m & ~y)
8316 // NOTE: we don't unfold the pattern if 'xor' is actually a 'not', because at
8317 //       the very least that breaks andnpd / andnps patterns, and because those
8318 //       patterns are simplified in IR and shouldn't be created in the DAG
unfoldMaskedMerge(SDNode * N)8319 SDValue DAGCombiner::unfoldMaskedMerge(SDNode *N) {
8320   assert(N->getOpcode() == ISD::XOR);
8321 
8322   // Don't touch 'not' (i.e. where y = -1).
8323   if (isAllOnesOrAllOnesSplat(N->getOperand(1)))
8324     return SDValue();
8325 
8326   EVT VT = N->getValueType(0);
8327 
8328   // There are 3 commutable operators in the pattern,
8329   // so we have to deal with 8 possible variants of the basic pattern.
8330   SDValue X, Y, M;
8331   auto matchAndXor = [&X, &Y, &M](SDValue And, unsigned XorIdx, SDValue Other) {
8332     if (And.getOpcode() != ISD::AND || !And.hasOneUse())
8333       return false;
8334     SDValue Xor = And.getOperand(XorIdx);
8335     if (Xor.getOpcode() != ISD::XOR || !Xor.hasOneUse())
8336       return false;
8337     SDValue Xor0 = Xor.getOperand(0);
8338     SDValue Xor1 = Xor.getOperand(1);
8339     // Don't touch 'not' (i.e. where y = -1).
8340     if (isAllOnesOrAllOnesSplat(Xor1))
8341       return false;
8342     if (Other == Xor0)
8343       std::swap(Xor0, Xor1);
8344     if (Other != Xor1)
8345       return false;
8346     X = Xor0;
8347     Y = Xor1;
8348     M = And.getOperand(XorIdx ? 0 : 1);
8349     return true;
8350   };
8351 
8352   SDValue N0 = N->getOperand(0);
8353   SDValue N1 = N->getOperand(1);
8354   if (!matchAndXor(N0, 0, N1) && !matchAndXor(N0, 1, N1) &&
8355       !matchAndXor(N1, 0, N0) && !matchAndXor(N1, 1, N0))
8356     return SDValue();
8357 
8358   // Don't do anything if the mask is constant. This should not be reachable.
8359   // InstCombine should have already unfolded this pattern, and DAGCombiner
8360   // probably shouldn't produce it, too.
8361   if (isa<ConstantSDNode>(M.getNode()))
8362     return SDValue();
8363 
8364   // We can transform if the target has AndNot
8365   if (!TLI.hasAndNot(M))
8366     return SDValue();
8367 
8368   SDLoc DL(N);
8369 
8370   // If Y is a constant, check that 'andn' works with immediates. Unless M is
8371   // a bitwise not that would already allow ANDN to be used.
8372   if (!TLI.hasAndNot(Y) && !isBitwiseNot(M)) {
8373     assert(TLI.hasAndNot(X) && "Only mask is a variable? Unreachable.");
8374     // If not, we need to do a bit more work to make sure andn is still used.
8375     SDValue NotX = DAG.getNOT(DL, X, VT);
8376     SDValue LHS = DAG.getNode(ISD::AND, DL, VT, NotX, M);
8377     SDValue NotLHS = DAG.getNOT(DL, LHS, VT);
8378     SDValue RHS = DAG.getNode(ISD::OR, DL, VT, M, Y);
8379     return DAG.getNode(ISD::AND, DL, VT, NotLHS, RHS);
8380   }
8381 
8382   // If X is a constant and M is a bitwise not, check that 'andn' works with
8383   // immediates.
8384   if (!TLI.hasAndNot(X) && isBitwiseNot(M)) {
8385     assert(TLI.hasAndNot(Y) && "Only mask is a variable? Unreachable.");
8386     // If not, we need to do a bit more work to make sure andn is still used.
8387     SDValue NotM = M.getOperand(0);
8388     SDValue LHS = DAG.getNode(ISD::OR, DL, VT, X, NotM);
8389     SDValue NotY = DAG.getNOT(DL, Y, VT);
8390     SDValue RHS = DAG.getNode(ISD::AND, DL, VT, NotM, NotY);
8391     SDValue NotRHS = DAG.getNOT(DL, RHS, VT);
8392     return DAG.getNode(ISD::AND, DL, VT, LHS, NotRHS);
8393   }
8394 
8395   SDValue LHS = DAG.getNode(ISD::AND, DL, VT, X, M);
8396   SDValue NotM = DAG.getNOT(DL, M, VT);
8397   SDValue RHS = DAG.getNode(ISD::AND, DL, VT, Y, NotM);
8398 
8399   return DAG.getNode(ISD::OR, DL, VT, LHS, RHS);
8400 }
8401 
// Combine an ISD::XOR node. Each fold below returns early on success; the
// ordering of the folds is significant.
SDValue DAGCombiner::visitXOR(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  SDLoc DL(N);

  // fold (xor undef, undef) -> 0. This is a common idiom (misuse).
  if (N0.isUndef() && N1.isUndef())
    return DAG.getConstant(0, DL, VT);

  // fold (xor x, undef) -> undef
  if (N0.isUndef())
    return N0;
  if (N1.isUndef())
    return N1;

  // fold (xor c1, c2) -> c1^c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::XOR, DL, VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(ISD::XOR, DL, VT, N1, N0);

  // fold vector ops
  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

    // fold (xor x, 0) -> x, vector edition
    if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
      return N0;
  }

  // fold (xor x, 0) -> x
  if (isNullConstant(N1))
    return N0;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // reassociate xor
  if (SDValue RXOR = reassociateOps(ISD::XOR, DL, N0, N1, N->getFlags()))
    return RXOR;

  // look for 'add-like' folds:
  // XOR(N0,MIN_SIGNED_VALUE) == ADD(N0,MIN_SIGNED_VALUE)
  if ((!LegalOperations || TLI.isOperationLegal(ISD::ADD, VT)) &&
      isMinSignedConstant(N1))
    if (SDValue Combined = visitADDLike(N))
      return Combined;

  // fold !(x cc y) -> (x !cc y)
  unsigned N0Opcode = N0.getOpcode();
  SDValue LHS, RHS, CC;
  if (TLI.isConstTrueVal(N1) &&
      isSetCCEquivalent(N0, LHS, RHS, CC, /*MatchStrict*/ true)) {
    ISD::CondCode NotCC = ISD::getSetCCInverse(cast<CondCodeSDNode>(CC)->get(),
                                               LHS.getValueType());
    if (!LegalOperations ||
        TLI.isCondCodeLegal(NotCC, LHS.getSimpleValueType())) {
      switch (N0Opcode) {
      default:
        llvm_unreachable("Unhandled SetCC Equivalent!");
      case ISD::SETCC:
        return DAG.getSetCC(SDLoc(N0), VT, LHS, RHS, NotCC);
      case ISD::SELECT_CC:
        return DAG.getSelectCC(SDLoc(N0), LHS, RHS, N0.getOperand(2),
                               N0.getOperand(3), NotCC);
      case ISD::STRICT_FSETCC:
      case ISD::STRICT_FSETCCS: {
        if (N0.hasOneUse()) {
          // FIXME Can we handle multiple uses? Could we token factor the chain
          // results from the new/old setcc?
          // Strict FP setcc nodes carry a chain; the replacement must keep the
          // old chain users pointed at the new node's chain result.
          SDValue SetCC =
              DAG.getSetCC(SDLoc(N0), VT, LHS, RHS, NotCC,
                           N0.getOperand(0), N0Opcode == ISD::STRICT_FSETCCS);
          CombineTo(N, SetCC);
          DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), SetCC.getValue(1));
          recursivelyDeleteUnusedNodes(N0.getNode());
          return SDValue(N, 0); // Return N so it doesn't get rechecked!
        }
        break;
      }
      }
    }
  }

  // fold (not (zext (setcc x, y))) -> (zext (not (setcc x, y)))
  if (isOneConstant(N1) && N0Opcode == ISD::ZERO_EXTEND && N0.hasOneUse() &&
      isSetCCEquivalent(N0.getOperand(0), LHS, RHS, CC)){
    SDValue V = N0.getOperand(0);
    SDLoc DL0(N0);
    V = DAG.getNode(ISD::XOR, DL0, V.getValueType(), V,
                    DAG.getConstant(1, DL0, V.getValueType()));
    AddToWorklist(V.getNode());
    return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, V);
  }

  // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are setcc
  if (isOneConstant(N1) && VT == MVT::i1 && N0.hasOneUse() &&
      (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
    SDValue N00 = N0.getOperand(0), N01 = N0.getOperand(1);
    if (isOneUseSetCC(N01) || isOneUseSetCC(N00)) {
      unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
      N00 = DAG.getNode(ISD::XOR, SDLoc(N00), VT, N00, N1); // N00 = ~N00
      N01 = DAG.getNode(ISD::XOR, SDLoc(N01), VT, N01, N1); // N01 = ~N01
      AddToWorklist(N00.getNode()); AddToWorklist(N01.getNode());
      return DAG.getNode(NewOpcode, DL, VT, N00, N01);
    }
  }
  // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are constants
  if (isAllOnesConstant(N1) && N0.hasOneUse() &&
      (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
    SDValue N00 = N0.getOperand(0), N01 = N0.getOperand(1);
    if (isa<ConstantSDNode>(N01) || isa<ConstantSDNode>(N00)) {
      unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
      N00 = DAG.getNode(ISD::XOR, SDLoc(N00), VT, N00, N1); // N00 = ~N00
      N01 = DAG.getNode(ISD::XOR, SDLoc(N01), VT, N01, N1); // N01 = ~N01
      AddToWorklist(N00.getNode()); AddToWorklist(N01.getNode());
      return DAG.getNode(NewOpcode, DL, VT, N00, N01);
    }
  }

  // fold (not (neg x)) -> (add X, -1)
  // FIXME: This can be generalized to (not (sub Y, X)) -> (add X, ~Y) if
  // Y is a constant or the subtract has a single use.
  if (isAllOnesConstant(N1) && N0.getOpcode() == ISD::SUB &&
      isNullConstant(N0.getOperand(0))) {
    return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1),
                       DAG.getAllOnesConstant(DL, VT));
  }

  // fold (not (add X, -1)) -> (neg X)
  if (isAllOnesConstant(N1) && N0.getOpcode() == ISD::ADD &&
      isAllOnesOrAllOnesSplat(N0.getOperand(1))) {
    return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
                       N0.getOperand(0));
  }

  // fold (xor (and x, y), y) -> (and (not x), y)
  if (N0Opcode == ISD::AND && N0.hasOneUse() && N0->getOperand(1) == N1) {
    SDValue X = N0.getOperand(0);
    SDValue NotX = DAG.getNOT(SDLoc(X), X, VT);
    AddToWorklist(NotX.getNode());
    return DAG.getNode(ISD::AND, DL, VT, NotX, N1);
  }

  // fold Y = sra (X, size(X)-1); xor (add (X, Y), Y) -> (abs X)
  if (TLI.isOperationLegalOrCustom(ISD::ABS, VT)) {
    // The add and the sra may appear as either operand of the xor.
    SDValue A = N0Opcode == ISD::ADD ? N0 : N1;
    SDValue S = N0Opcode == ISD::SRA ? N0 : N1;
    if (A.getOpcode() == ISD::ADD && S.getOpcode() == ISD::SRA) {
      SDValue A0 = A.getOperand(0), A1 = A.getOperand(1);
      SDValue S0 = S.getOperand(0);
      if ((A0 == S && A1 == S0) || (A1 == S && A0 == S0))
        if (ConstantSDNode *C = isConstOrConstSplat(S.getOperand(1)))
          if (C->getAPIntValue() == (VT.getScalarSizeInBits() - 1))
            return DAG.getNode(ISD::ABS, DL, VT, S0);
    }
  }

  // fold (xor x, x) -> 0
  if (N0 == N1)
    return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);

  // fold (xor (shl 1, x), -1) -> (rotl ~1, x)
  // Here is a concrete example of this equivalence:
  // i16   x ==  14
  // i16 shl ==   1 << 14  == 16384 == 0b0100000000000000
  // i16 xor == ~(1 << 14) == 49151 == 0b1011111111111111
  //
  // =>
  //
  // i16     ~1      == 0b1111111111111110
  // i16 rol(~1, 14) == 0b1011111111111111
  //
  // Some additional tips to help conceptualize this transform:
  // - Try to see the operation as placing a single zero in a value of all ones.
  // - There exists no value for x which would allow the result to contain zero.
  // - Values of x larger than the bitwidth are undefined and do not require a
  //   consistent result.
  // - Pushing the zero left requires shifting one bits in from the right.
  // A rotate left of ~1 is a nice way of achieving the desired result.
  if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT) && N0Opcode == ISD::SHL &&
      isAllOnesConstant(N1) && isOneConstant(N0.getOperand(0))) {
    return DAG.getNode(ISD::ROTL, DL, VT, DAG.getConstant(~1, DL, VT),
                       N0.getOperand(1));
  }

  // Simplify: xor (op x...), (op y...)  -> (op (xor x, y))
  if (N0Opcode == N1.getOpcode())
    if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
      return V;

  // Try the shift-folds with the operands in both orders.
  if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG))
    return R;
  if (SDValue R = foldLogicOfShifts(N, N1, N0, DAG))
    return R;

  // Unfold  ((x ^ y) & m) ^ y  into  (x & m) | (y & ~m)  if profitable
  if (SDValue MM = unfoldMaskedMerge(N))
    return MM;

  // Simplify the expression using non-local knowledge.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
    return Combined;

  return SDValue();
}
8616 
8617 /// If we have a shift-by-constant of a bitwise logic op that itself has a
8618 /// shift-by-constant operand with identical opcode, we may be able to convert
8619 /// that into 2 independent shifts followed by the logic op. This is a
8620 /// throughput improvement.
combineShiftOfShiftedLogic(SDNode * Shift,SelectionDAG & DAG)8621 static SDValue combineShiftOfShiftedLogic(SDNode *Shift, SelectionDAG &DAG) {
8622   // Match a one-use bitwise logic op.
8623   SDValue LogicOp = Shift->getOperand(0);
8624   if (!LogicOp.hasOneUse())
8625     return SDValue();
8626 
8627   unsigned LogicOpcode = LogicOp.getOpcode();
8628   if (LogicOpcode != ISD::AND && LogicOpcode != ISD::OR &&
8629       LogicOpcode != ISD::XOR)
8630     return SDValue();
8631 
8632   // Find a matching one-use shift by constant.
8633   unsigned ShiftOpcode = Shift->getOpcode();
8634   SDValue C1 = Shift->getOperand(1);
8635   ConstantSDNode *C1Node = isConstOrConstSplat(C1);
8636   assert(C1Node && "Expected a shift with constant operand");
8637   const APInt &C1Val = C1Node->getAPIntValue();
8638   auto matchFirstShift = [&](SDValue V, SDValue &ShiftOp,
8639                              const APInt *&ShiftAmtVal) {
8640     if (V.getOpcode() != ShiftOpcode || !V.hasOneUse())
8641       return false;
8642 
8643     ConstantSDNode *ShiftCNode = isConstOrConstSplat(V.getOperand(1));
8644     if (!ShiftCNode)
8645       return false;
8646 
8647     // Capture the shifted operand and shift amount value.
8648     ShiftOp = V.getOperand(0);
8649     ShiftAmtVal = &ShiftCNode->getAPIntValue();
8650 
8651     // Shift amount types do not have to match their operand type, so check that
8652     // the constants are the same width.
8653     if (ShiftAmtVal->getBitWidth() != C1Val.getBitWidth())
8654       return false;
8655 
8656     // The fold is not valid if the sum of the shift values exceeds bitwidth.
8657     if ((*ShiftAmtVal + C1Val).uge(V.getScalarValueSizeInBits()))
8658       return false;
8659 
8660     return true;
8661   };
8662 
8663   // Logic ops are commutative, so check each operand for a match.
8664   SDValue X, Y;
8665   const APInt *C0Val;
8666   if (matchFirstShift(LogicOp.getOperand(0), X, C0Val))
8667     Y = LogicOp.getOperand(1);
8668   else if (matchFirstShift(LogicOp.getOperand(1), X, C0Val))
8669     Y = LogicOp.getOperand(0);
8670   else
8671     return SDValue();
8672 
8673   // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1)
8674   SDLoc DL(Shift);
8675   EVT VT = Shift->getValueType(0);
8676   EVT ShiftAmtVT = Shift->getOperand(1).getValueType();
8677   SDValue ShiftSumC = DAG.getConstant(*C0Val + C1Val, DL, ShiftAmtVT);
8678   SDValue NewShift1 = DAG.getNode(ShiftOpcode, DL, VT, X, ShiftSumC);
8679   SDValue NewShift2 = DAG.getNode(ShiftOpcode, DL, VT, Y, C1);
8680   return DAG.getNode(LogicOpcode, DL, VT, NewShift1, NewShift2);
8681 }
8682 
8683 /// Handle transforms common to the three shifts, when the shift amount is a
8684 /// constant.
8685 /// We are looking for: (shift being one of shl/sra/srl)
8686 ///   shift (binop X, C0), C1
8687 /// And want to transform into:
8688 ///   binop (shift X, C1), (shift C0, C1)
visitShiftByConstant(SDNode * N)8689 SDValue DAGCombiner::visitShiftByConstant(SDNode *N) {
8690   assert(isConstOrConstSplat(N->getOperand(1)) && "Expected constant operand");
8691 
8692   // Do not turn a 'not' into a regular xor.
8693   if (isBitwiseNot(N->getOperand(0)))
8694     return SDValue();
8695 
8696   // The inner binop must be one-use, since we want to replace it.
8697   SDValue LHS = N->getOperand(0);
8698   if (!LHS.hasOneUse() || !TLI.isDesirableToCommuteWithShift(N, Level))
8699     return SDValue();
8700 
8701   // TODO: This is limited to early combining because it may reveal regressions
8702   //       otherwise. But since we just checked a target hook to see if this is
8703   //       desirable, that should have filtered out cases where this interferes
8704   //       with some other pattern matching.
8705   if (!LegalTypes)
8706     if (SDValue R = combineShiftOfShiftedLogic(N, DAG))
8707       return R;
8708 
8709   // We want to pull some binops through shifts, so that we have (and (shift))
8710   // instead of (shift (and)), likewise for add, or, xor, etc.  This sort of
8711   // thing happens with address calculations, so it's important to canonicalize
8712   // it.
8713   switch (LHS.getOpcode()) {
8714   default:
8715     return SDValue();
8716   case ISD::OR:
8717   case ISD::XOR:
8718   case ISD::AND:
8719     break;
8720   case ISD::ADD:
8721     if (N->getOpcode() != ISD::SHL)
8722       return SDValue(); // only shl(add) not sr[al](add).
8723     break;
8724   }
8725 
8726   // We require the RHS of the binop to be a constant and not opaque as well.
8727   ConstantSDNode *BinOpCst = getAsNonOpaqueConstant(LHS.getOperand(1));
8728   if (!BinOpCst)
8729     return SDValue();
8730 
8731   // FIXME: disable this unless the input to the binop is a shift by a constant
8732   // or is copy/select. Enable this in other cases when figure out it's exactly
8733   // profitable.
8734   SDValue BinOpLHSVal = LHS.getOperand(0);
8735   bool IsShiftByConstant = (BinOpLHSVal.getOpcode() == ISD::SHL ||
8736                             BinOpLHSVal.getOpcode() == ISD::SRA ||
8737                             BinOpLHSVal.getOpcode() == ISD::SRL) &&
8738                            isa<ConstantSDNode>(BinOpLHSVal.getOperand(1));
8739   bool IsCopyOrSelect = BinOpLHSVal.getOpcode() == ISD::CopyFromReg ||
8740                         BinOpLHSVal.getOpcode() == ISD::SELECT;
8741 
8742   if (!IsShiftByConstant && !IsCopyOrSelect)
8743     return SDValue();
8744 
8745   if (IsCopyOrSelect && N->hasOneUse())
8746     return SDValue();
8747 
8748   // Fold the constants, shifting the binop RHS by the shift amount.
8749   SDLoc DL(N);
8750   EVT VT = N->getValueType(0);
8751   SDValue NewRHS = DAG.getNode(N->getOpcode(), DL, VT, LHS.getOperand(1),
8752                                N->getOperand(1));
8753   assert(isa<ConstantSDNode>(NewRHS) && "Folding was not successful!");
8754 
8755   SDValue NewShift = DAG.getNode(N->getOpcode(), DL, VT, LHS.getOperand(0),
8756                                  N->getOperand(1));
8757   return DAG.getNode(LHS.getOpcode(), DL, VT, NewShift, NewRHS);
8758 }
8759 
distributeTruncateThroughAnd(SDNode * N)8760 SDValue DAGCombiner::distributeTruncateThroughAnd(SDNode *N) {
8761   assert(N->getOpcode() == ISD::TRUNCATE);
8762   assert(N->getOperand(0).getOpcode() == ISD::AND);
8763 
8764   // (truncate:TruncVT (and N00, N01C)) -> (and (truncate:TruncVT N00), TruncC)
8765   EVT TruncVT = N->getValueType(0);
8766   if (N->hasOneUse() && N->getOperand(0).hasOneUse() &&
8767       TLI.isTypeDesirableForOp(ISD::AND, TruncVT)) {
8768     SDValue N01 = N->getOperand(0).getOperand(1);
8769     if (isConstantOrConstantVector(N01, /* NoOpaques */ true)) {
8770       SDLoc DL(N);
8771       SDValue N00 = N->getOperand(0).getOperand(0);
8772       SDValue Trunc00 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N00);
8773       SDValue Trunc01 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N01);
8774       AddToWorklist(Trunc00.getNode());
8775       AddToWorklist(Trunc01.getNode());
8776       return DAG.getNode(ISD::AND, DL, TruncVT, Trunc00, Trunc01);
8777     }
8778   }
8779 
8780   return SDValue();
8781 }
8782 
// Combine an ISD::ROTL / ISD::ROTR node. Tries, in order: rotate-by-zero
// elimination, modulo reduction of the rotate amount, rot16-by-8 -> bswap,
// demanded-bits simplification, distributing a truncate through an AND in
// the amount operand, and merging two consecutive rotates.
SDValue DAGCombiner::visitRotate(SDNode *N) {
  SDLoc dl(N);
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  unsigned Bitsize = VT.getScalarSizeInBits();

  // fold (rot x, 0) -> x
  if (isNullOrNullSplat(N1))
    return N0;

  // fold (rot x, c) -> x iff (c % BitSize) == 0
  // For a power-of-2 bitsize, the low log2(Bitsize) bits of the amount are
  // exactly (c % Bitsize), so all of them being known-zero means the rotate
  // amount is a multiple of the bitsize.
  if (isPowerOf2_32(Bitsize) && Bitsize > 1) {
    APInt ModuloMask(N1.getScalarValueSizeInBits(), Bitsize - 1);
    if (DAG.MaskedValueIsZero(N1, ModuloMask))
      return N0;
  }

  // fold (rot x, c) -> (rot x, c % BitSize)
  bool OutOfRange = false;
  auto MatchOutOfRange = [Bitsize, &OutOfRange](ConstantSDNode *C) {
    // Always return true so every element is visited; OutOfRange records
    // whether any element actually needs the modulo reduction.
    OutOfRange |= C->getAPIntValue().uge(Bitsize);
    return true;
  };
  if (ISD::matchUnaryPredicate(N1, MatchOutOfRange) && OutOfRange) {
    EVT AmtVT = N1.getValueType();
    SDValue Bits = DAG.getConstant(Bitsize, dl, AmtVT);
    if (SDValue Amt =
            DAG.FoldConstantArithmetic(ISD::UREM, dl, AmtVT, {N1, Bits}))
      return DAG.getNode(N->getOpcode(), dl, VT, N0, Amt);
  }

  // rot i16 X, 8 --> bswap X
  auto *RotAmtC = isConstOrConstSplat(N1);
  if (RotAmtC && RotAmtC->getAPIntValue() == 8 &&
      VT.getScalarSizeInBits() == 16 && hasOperation(ISD::BSWAP, VT))
    return DAG.getNode(ISD::BSWAP, dl, VT, N0);

  // Simplify the operands using demanded-bits information.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  // fold (rot* x, (trunc (and y, c))) -> (rot* x, (and (trunc y), (trunc c))).
  if (N1.getOpcode() == ISD::TRUNCATE &&
      N1.getOperand(0).getOpcode() == ISD::AND) {
    if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
      return DAG.getNode(N->getOpcode(), dl, VT, N0, NewOp1);
  }

  unsigned NextOp = N0.getOpcode();

  // fold (rot* (rot* x, c2), c1)
  //   -> (rot* x, ((c1 % bitsize) +- (c2 % bitsize)) % bitsize)
  if (NextOp == ISD::ROTL || NextOp == ISD::ROTR) {
    SDNode *C1 = DAG.isConstantIntBuildVectorOrConstantInt(N1);
    SDNode *C2 = DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1));
    if (C1 && C2 && C1->getValueType(0) == C2->getValueType(0)) {
      EVT ShiftVT = C1->getValueType(0);
      // Same-direction rotates add their amounts; opposite directions
      // subtract.
      bool SameSide = (N->getOpcode() == NextOp);
      unsigned CombineOp = SameSide ? ISD::ADD : ISD::SUB;
      SDValue BitsizeC = DAG.getConstant(Bitsize, dl, ShiftVT);
      // Normalize each amount modulo the bitsize before combining so the
      // ADD/SUB result stays meaningful in the shift type.
      SDValue Norm1 = DAG.FoldConstantArithmetic(ISD::UREM, dl, ShiftVT,
                                                 {N1, BitsizeC});
      SDValue Norm2 = DAG.FoldConstantArithmetic(ISD::UREM, dl, ShiftVT,
                                                 {N0.getOperand(1), BitsizeC});
      if (Norm1 && Norm2)
        if (SDValue CombinedShift = DAG.FoldConstantArithmetic(
                CombineOp, dl, ShiftVT, {Norm1, Norm2})) {
          SDValue CombinedShiftNorm = DAG.FoldConstantArithmetic(
              ISD::UREM, dl, ShiftVT, {CombinedShift, BitsizeC});
          return DAG.getNode(N->getOpcode(), dl, VT, N0->getOperand(0),
                             CombinedShiftNorm);
        }
    }
  }
  return SDValue();
}
8860 
// Combine an ISD::SHL node. Attempts a long, order-sensitive sequence of
// folds: constant folding, vector simplifications, known-zero elimination,
// shift-amount narrowing, merging with inner shifts/extends, converting
// shl-of-sra to a mask, distributing over ADD/OR/MUL, and VSCALE /
// STEP_VECTOR constant folds.
SDValue DAGCombiner::visitSHL(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  if (SDValue V = DAG.simplifyShift(N0, N1))
    return V;

  EVT VT = N0.getValueType();
  EVT ShiftVT = N1.getValueType();
  unsigned OpSizeInBits = VT.getScalarSizeInBits();

  // fold (shl c1, c2) -> c1<<c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N), VT, {N0, N1}))
    return C;

  // fold vector ops
  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N)))
      return FoldedVOp;

    BuildVectorSDNode *N1CV = dyn_cast<BuildVectorSDNode>(N1);
    // If setcc produces all-one true value then:
    // (shl (and (setcc) N01CV) N1CV) -> (and (setcc) N01CV<<N1CV)
    if (N1CV && N1CV->isConstant()) {
      if (N0.getOpcode() == ISD::AND) {
        SDValue N00 = N0->getOperand(0);
        SDValue N01 = N0->getOperand(1);
        BuildVectorSDNode *N01CV = dyn_cast<BuildVectorSDNode>(N01);

        // Requires ZeroOrNegativeOneBooleanContent: the setcc lanes are
        // all-zeros or all-ones, so shifting only the mask is equivalent.
        if (N01CV && N01CV->isConstant() && N00.getOpcode() == ISD::SETCC &&
            TLI.getBooleanContents(N00.getOperand(0).getValueType()) ==
                TargetLowering::ZeroOrNegativeOneBooleanContent) {
          if (SDValue C =
                  DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N), VT, {N01, N1}))
            return DAG.getNode(ISD::AND, SDLoc(N), VT, N00, C);
        }
      }
    }
  }

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // if (shl x, c) is known to be zero, return 0
  if (DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnes(OpSizeInBits)))
    return DAG.getConstant(0, SDLoc(N), VT);

  // fold (shl x, (trunc (and y, c))) -> (shl x, (and (trunc y), (trunc c))).
  if (N1.getOpcode() == ISD::TRUNCATE &&
      N1.getOperand(0).getOpcode() == ISD::AND) {
    if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
      return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, NewOp1);
  }

  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  // fold (shl (shl x, c1), c2) -> 0 or (shl x, (add c1, c2))
  if (N0.getOpcode() == ISD::SHL) {
    // c1 + c2 >= bitwidth shifts everything out -> 0.
    auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
                                          ConstantSDNode *RHS) {
      APInt c1 = LHS->getAPIntValue();
      APInt c2 = RHS->getAPIntValue();
      // Widen both amounts (plus an overflow bit) so the sum cannot wrap.
      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
      return (c1 + c2).uge(OpSizeInBits);
    };
    if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
      return DAG.getConstant(0, SDLoc(N), VT);

    auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
                                       ConstantSDNode *RHS) {
      APInt c1 = LHS->getAPIntValue();
      APInt c2 = RHS->getAPIntValue();
      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
      return (c1 + c2).ult(OpSizeInBits);
    };
    if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
      SDLoc DL(N);
      SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
      return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Sum);
    }
  }

  // fold (shl (ext (shl x, c1)), c2) -> (shl (ext x), (add c1, c2))
  // For this to be valid, the second form must not preserve any of the bits
  // that are shifted out by the inner shift in the first form.  This means
  // the outer shift size must be >= the number of bits added by the ext.
  // As a corollary, we don't care what kind of ext it is.
  if ((N0.getOpcode() == ISD::ZERO_EXTEND ||
       N0.getOpcode() == ISD::ANY_EXTEND ||
       N0.getOpcode() == ISD::SIGN_EXTEND) &&
      N0.getOperand(0).getOpcode() == ISD::SHL) {
    SDValue N0Op0 = N0.getOperand(0);
    SDValue InnerShiftAmt = N0Op0.getOperand(1);
    EVT InnerVT = N0Op0.getValueType();
    uint64_t InnerBitwidth = InnerVT.getScalarSizeInBits();

    auto MatchOutOfRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
                                                         ConstantSDNode *RHS) {
      APInt c1 = LHS->getAPIntValue();
      APInt c2 = RHS->getAPIntValue();
      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
      // c2 (outer amount) must cover the extension bits; combined amount
      // out of range means the result is zero.
      return c2.uge(OpSizeInBits - InnerBitwidth) &&
             (c1 + c2).uge(OpSizeInBits);
    };
    if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchOutOfRange,
                                  /*AllowUndefs*/ false,
                                  /*AllowTypeMismatch*/ true))
      return DAG.getConstant(0, SDLoc(N), VT);

    auto MatchInRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
                                                      ConstantSDNode *RHS) {
      APInt c1 = LHS->getAPIntValue();
      APInt c2 = RHS->getAPIntValue();
      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
      return c2.uge(OpSizeInBits - InnerBitwidth) &&
             (c1 + c2).ult(OpSizeInBits);
    };
    if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchInRange,
                                  /*AllowUndefs*/ false,
                                  /*AllowTypeMismatch*/ true)) {
      SDLoc DL(N);
      SDValue Ext = DAG.getNode(N0.getOpcode(), DL, VT, N0Op0.getOperand(0));
      SDValue Sum = DAG.getZExtOrTrunc(InnerShiftAmt, DL, ShiftVT);
      Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, Sum, N1);
      return DAG.getNode(ISD::SHL, DL, VT, Ext, Sum);
    }
  }

  // fold (shl (zext (srl x, C)), C) -> (zext (shl (srl x, C), C))
  // Only fold this if the inner zext has no other uses to avoid increasing
  // the total number of instructions.
  if (N0.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse() &&
      N0.getOperand(0).getOpcode() == ISD::SRL) {
    SDValue N0Op0 = N0.getOperand(0);
    SDValue InnerShiftAmt = N0Op0.getOperand(1);

    // Both shift amounts must be equal and in range of the narrow type.
    auto MatchEqual = [VT](ConstantSDNode *LHS, ConstantSDNode *RHS) {
      APInt c1 = LHS->getAPIntValue();
      APInt c2 = RHS->getAPIntValue();
      zeroExtendToMatch(c1, c2);
      return c1.ult(VT.getScalarSizeInBits()) && (c1 == c2);
    };
    if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchEqual,
                                  /*AllowUndefs*/ false,
                                  /*AllowTypeMismatch*/ true)) {
      SDLoc DL(N);
      EVT InnerShiftAmtVT = N0Op0.getOperand(1).getValueType();
      SDValue NewSHL = DAG.getZExtOrTrunc(N1, DL, InnerShiftAmtVT);
      NewSHL = DAG.getNode(ISD::SHL, DL, N0Op0.getValueType(), N0Op0, NewSHL);
      AddToWorklist(NewSHL.getNode());
      return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N0), VT, NewSHL);
    }
  }

  if (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SRA) {
    // True when LHS amount <= RHS amount and both are in range.
    auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
                                           ConstantSDNode *RHS) {
      const APInt &LHSC = LHS->getAPIntValue();
      const APInt &RHSC = RHS->getAPIntValue();
      return LHSC.ult(OpSizeInBits) && RHSC.ult(OpSizeInBits) &&
             LHSC.getZExtValue() <= RHSC.getZExtValue();
    };

    SDLoc DL(N);

    // fold (shl (sr[la] exact X,  C1), C2) -> (shl    X, (C2-C1)) if C1 <= C2
    // fold (shl (sr[la] exact X,  C1), C2) -> (sr[la] X, (C2-C1)) if C1 >= C2
    if (N0->getFlags().hasExact()) {
      if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount,
                                    /*AllowUndefs*/ false,
                                    /*AllowTypeMismatch*/ true)) {
        SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
        SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01);
        return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff);
      }
      if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
                                    /*AllowUndefs*/ false,
                                    /*AllowTypeMismatch*/ true)) {
        SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
        SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1);
        return DAG.getNode(N0.getOpcode(), DL, VT, N0.getOperand(0), Diff);
      }
    }

    // fold (shl (srl x, c1), c2) -> (and (shl x, (sub c2, c1), MASK) or
    //                               (and (srl x, (sub c1, c2), MASK)
    // Only fold this if the inner shift has no other uses -- if it does,
    // folding this will increase the total number of instructions.
    if (N0.getOpcode() == ISD::SRL &&
        (N0.getOperand(1) == N1 || N0.hasOneUse()) &&
        TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
      if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
                                    /*AllowUndefs*/ false,
                                    /*AllowTypeMismatch*/ true)) {
        SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
        SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1);
        SDValue Mask = DAG.getAllOnesConstant(DL, VT);
        Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, N01);
        Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, Diff);
        SDValue Shift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Diff);
        return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
      }
      if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount,
                                    /*AllowUndefs*/ false,
                                    /*AllowTypeMismatch*/ true)) {
        SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
        SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01);
        SDValue Mask = DAG.getAllOnesConstant(DL, VT);
        Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, N1);
        SDValue Shift = DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff);
        return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
      }
    }
  }

  // fold (shl (sra x, c1), c1) -> (and x, (shl -1, c1))
  if (N0.getOpcode() == ISD::SRA && N1 == N0.getOperand(1) &&
      isConstantOrConstantVector(N1, /* No Opaques */ true)) {
    SDLoc DL(N);
    SDValue AllBits = DAG.getAllOnesConstant(DL, VT);
    SDValue HiBitsMask = DAG.getNode(ISD::SHL, DL, VT, AllBits, N1);
    return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), HiBitsMask);
  }

  // fold (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2)
  // fold (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2)
  // Variant of version done on multiply, except mul by a power of 2 is turned
  // into a shift.
  if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::OR) &&
      N0->hasOneUse() &&
      isConstantOrConstantVector(N1, /* No Opaques */ true) &&
      isConstantOrConstantVector(N0.getOperand(1), /* No Opaques */ true) &&
      TLI.isDesirableToCommuteWithShift(N, Level)) {
    SDValue Shl0 = DAG.getNode(ISD::SHL, SDLoc(N0), VT, N0.getOperand(0), N1);
    SDValue Shl1 = DAG.getNode(ISD::SHL, SDLoc(N1), VT, N0.getOperand(1), N1);
    AddToWorklist(Shl0.getNode());
    AddToWorklist(Shl1.getNode());
    return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, Shl0, Shl1);
  }

  // fold (shl (mul x, c1), c2) -> (mul x, c1 << c2)
  if (N0.getOpcode() == ISD::MUL && N0->hasOneUse()) {
    SDValue N01 = N0.getOperand(1);
    if (SDValue Shl =
            DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N1), VT, {N01, N1}))
      return DAG.getNode(ISD::MUL, SDLoc(N), VT, N0.getOperand(0), Shl);
  }

  ConstantSDNode *N1C = isConstOrConstSplat(N1);
  if (N1C && !N1C->isOpaque())
    if (SDValue NewSHL = visitShiftByConstant(N))
      return NewSHL;

  // Fold (shl (vscale * C0), C1) to (vscale * (C0 << C1)).
  if (N0.getOpcode() == ISD::VSCALE)
    if (ConstantSDNode *NC1 = isConstOrConstSplat(N->getOperand(1))) {
      const APInt &C0 = N0.getConstantOperandAPInt(0);
      const APInt &C1 = NC1->getAPIntValue();
      return DAG.getVScale(SDLoc(N), VT, C0 << C1);
    }

  // Fold (shl step_vector(C0), C1) to (step_vector(C0 << C1)).
  APInt ShlVal;
  if (N0.getOpcode() == ISD::STEP_VECTOR)
    if (ISD::isConstantSplatVector(N1.getNode(), ShlVal)) {
      const APInt &C0 = N0.getConstantOperandAPInt(0);
      if (ShlVal.ult(C0.getBitWidth())) {
        APInt NewStep = C0 << ShlVal;
        return DAG.getStepVector(SDLoc(N), VT, NewStep);
      }
    }

  return SDValue();
}
9135 
// Transform a right shift of a multiply into a multiply-high.
// Examples:
// (srl (mul (zext i32:$a to i64), (zext i32:$b to i64)), 32) -> (mulhu $a, $b)
// (sra (mul (sext i32:$a to i64), (sext i32:$b to i64)), 32) -> (mulhs $a, $b)
static SDValue combineShiftToMULH(SDNode *N, SelectionDAG &DAG,
                                  const TargetLowering &TLI) {
  assert((N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) &&
         "SRL or SRA node is required here!");

  // Check the shift amount. Proceed with the transformation if the shift
  // amount is constant.
  ConstantSDNode *ShiftAmtSrc = isConstOrConstSplat(N->getOperand(1));
  if (!ShiftAmtSrc)
    return SDValue();

  SDLoc DL(N);

  // The operation feeding into the shift must be a multiply.
  SDValue ShiftOperand = N->getOperand(0);
  if (ShiftOperand.getOpcode() != ISD::MUL)
    return SDValue();

  // Both operands must be equivalent extend nodes.
  SDValue LeftOp = ShiftOperand.getOperand(0);
  SDValue RightOp = ShiftOperand.getOperand(1);

  bool IsSignExt = LeftOp.getOpcode() == ISD::SIGN_EXTEND;
  bool IsZeroExt = LeftOp.getOpcode() == ISD::ZERO_EXTEND;

  if (!IsSignExt && !IsZeroExt)
    return SDValue();

  EVT NarrowVT = LeftOp.getOperand(0).getValueType();
  unsigned NarrowVTSize = NarrowVT.getScalarSizeInBits();

  // The RHS may be a constant instead of an extend; accept it if it fits in
  // the narrow type (signed/unsigned depending on the extension kind) and
  // truncate it for the narrow multiply.
  SDValue MulhRightOp;
  if (ConstantSDNode *Constant = isConstOrConstSplat(RightOp)) {
    unsigned ActiveBits = IsSignExt
                              ? Constant->getAPIntValue().getMinSignedBits()
                              : Constant->getAPIntValue().getActiveBits();
    if (ActiveBits > NarrowVTSize)
      return SDValue();
    MulhRightOp = DAG.getConstant(
        Constant->getAPIntValue().trunc(NarrowVT.getScalarSizeInBits()), DL,
        NarrowVT);
  } else {
    if (LeftOp.getOpcode() != RightOp.getOpcode())
      return SDValue();
    // Check that the two extend nodes are the same type.
    if (NarrowVT != RightOp.getOperand(0).getValueType())
      return SDValue();
    MulhRightOp = RightOp.getOperand(0);
  }

  EVT WideVT = LeftOp.getValueType();
  // Proceed with the transformation if the wide types match.
  assert((WideVT == RightOp.getValueType()) &&
         "Cannot have a multiply node with two different operand types.");

  // Proceed with the transformation if the wide type is twice as large
  // as the narrow type.
  if (WideVT.getScalarSizeInBits() != 2 * NarrowVTSize)
    return SDValue();

  // Check the shift amount with the narrow type size.
  // Proceed with the transformation if the shift amount is the width
  // of the narrow type.
  unsigned ShiftAmt = ShiftAmtSrc->getZExtValue();
  if (ShiftAmt != NarrowVTSize)
    return SDValue();

  // If the operation feeding into the MUL is a sign extend (sext),
  // we use mulhs. Otherwise, zero extends (zext) use mulhu.
  unsigned MulhOpcode = IsSignExt ? ISD::MULHS : ISD::MULHU;

  // Combine to mulh if mulh is legal/custom for the narrow type on the target.
  if (!TLI.isOperationLegalOrCustom(MulhOpcode, NarrowVT))
    return SDValue();

  // Re-extend the narrow mulh result back to the original wide type, using
  // the extension kind matching the original shift (sra -> sext, srl -> zext).
  SDValue Result =
      DAG.getNode(MulhOpcode, DL, NarrowVT, LeftOp.getOperand(0), MulhRightOp);
  return (N->getOpcode() == ISD::SRA ? DAG.getSExtOrTrunc(Result, DL, WideVT)
                                     : DAG.getZExtOrTrunc(Result, DL, WideVT));
}
9220 
// Combine an ISD::SRA node. Tries, in order: constant folding, no-op
// detection via sign-bit counting, vector simplifications, shl/sra ->
// sign_extend_inreg, merging consecutive sra's, converting shift pairs to
// trunc+sext when truncation is free, demanded-bits simplification,
// sra -> srl when the sign bit is known zero, mulh formation, and
// narrowing sign-extending loads.
SDValue DAGCombiner::visitSRA(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  if (SDValue V = DAG.simplifyShift(N0, N1))
    return V;

  EVT VT = N0.getValueType();
  unsigned OpSizeInBits = VT.getScalarSizeInBits();

  // fold (sra c1, c2) -> c1 >> c2 (arithmetic)
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::SRA, SDLoc(N), VT, {N0, N1}))
    return C;

  // Arithmetic shifting an all-sign-bit value is a no-op.
  // fold (sra 0, x) -> 0
  // fold (sra -1, x) -> -1
  if (DAG.ComputeNumSignBits(N0) == OpSizeInBits)
    return N0;

  // fold vector ops
  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N)))
      return FoldedVOp;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // fold (sra (shl x, c1), c1) -> sext_inreg for some c1 and target supports
  // sext_inreg.
  ConstantSDNode *N1C = isConstOrConstSplat(N1);
  if (N1C && N0.getOpcode() == ISD::SHL && N1 == N0.getOperand(1)) {
    unsigned LowBits = OpSizeInBits - (unsigned)N1C->getZExtValue();
    EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), LowBits);
    if (VT.isVector())
      ExtVT = EVT::getVectorVT(*DAG.getContext(), ExtVT,
                               VT.getVectorElementCount());
    if (!LegalOperations ||
        TLI.getOperationAction(ISD::SIGN_EXTEND_INREG, ExtVT) ==
        TargetLowering::Legal)
      return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT,
                         N0.getOperand(0), DAG.getValueType(ExtVT));
    // Even if we can't convert to sext_inreg, we might be able to remove
    // this shift pair if the input is already sign extended.
    if (DAG.ComputeNumSignBits(N0.getOperand(0)) > N1C->getZExtValue())
      return N0.getOperand(0);
  }

  // fold (sra (sra x, c1), c2) -> (sra x, (add c1, c2))
  // clamp (add c1, c2) to max shift.
  if (N0.getOpcode() == ISD::SRA) {
    SDLoc DL(N);
    EVT ShiftVT = N1.getValueType();
    EVT ShiftSVT = ShiftVT.getScalarType();
    SmallVector<SDValue, 16> ShiftValues;

    auto SumOfShifts = [&](ConstantSDNode *LHS, ConstantSDNode *RHS) {
      APInt c1 = LHS->getAPIntValue();
      APInt c2 = RHS->getAPIntValue();
      // Widen both amounts (plus an overflow bit) so the sum cannot wrap.
      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
      APInt Sum = c1 + c2;
      // An sra amount is clamped to bitwidth-1: shifting further still just
      // replicates the sign bit.
      unsigned ShiftSum =
          Sum.uge(OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue();
      ShiftValues.push_back(DAG.getConstant(ShiftSum, DL, ShiftSVT));
      return true;
    };
    if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), SumOfShifts)) {
      SDValue ShiftValue;
      if (N1.getOpcode() == ISD::BUILD_VECTOR)
        ShiftValue = DAG.getBuildVector(ShiftVT, DL, ShiftValues);
      else if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
        assert(ShiftValues.size() == 1 &&
               "Expected matchBinaryPredicate to return one element for "
               "SPLAT_VECTORs");
        ShiftValue = DAG.getSplatVector(ShiftVT, DL, ShiftValues[0]);
      } else
        ShiftValue = ShiftValues[0];
      return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0), ShiftValue);
    }
  }

  // fold (sra (shl X, m), (sub result_size, n))
  // -> (sign_extend (trunc (shl X, (sub (sub result_size, n), m)))) for
  // result_size - n != m.
  // If truncate is free for the target sext(shl) is likely to result in better
  // code.
  if (N0.getOpcode() == ISD::SHL && N1C) {
    // Get the two constants of the shifts, CN0 = m, CN = n.
    const ConstantSDNode *N01C = isConstOrConstSplat(N0.getOperand(1));
    if (N01C) {
      LLVMContext &Ctx = *DAG.getContext();
      // Determine what the truncate's result bitsize and type would be.
      EVT TruncVT = EVT::getIntegerVT(Ctx, OpSizeInBits - N1C->getZExtValue());

      if (VT.isVector())
        TruncVT = EVT::getVectorVT(Ctx, TruncVT, VT.getVectorElementCount());

      // Determine the residual right-shift amount.
      int ShiftAmt = N1C->getZExtValue() - N01C->getZExtValue();

      // If the shift is not a no-op (in which case this should be just a sign
      // extend already), the truncated to type is legal, sign_extend is legal
      // on that type, and the truncate to that type is both legal and free,
      // perform the transform.
      if ((ShiftAmt > 0) &&
          TLI.isOperationLegalOrCustom(ISD::SIGN_EXTEND, TruncVT) &&
          TLI.isOperationLegalOrCustom(ISD::TRUNCATE, VT) &&
          TLI.isTruncateFree(VT, TruncVT)) {
        SDLoc DL(N);
        SDValue Amt = DAG.getConstant(ShiftAmt, DL,
            getShiftAmountTy(N0.getOperand(0).getValueType()));
        SDValue Shift = DAG.getNode(ISD::SRL, DL, VT,
                                    N0.getOperand(0), Amt);
        SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, TruncVT,
                                    Shift);
        return DAG.getNode(ISD::SIGN_EXTEND, DL,
                           N->getValueType(0), Trunc);
      }
    }
  }

  // We convert trunc/ext to opposing shifts in IR, but casts may be cheaper.
  //   sra (add (shl X, N1C), AddC), N1C -->
  //   sext (add (trunc X to (width - N1C)), AddC')
  //   sra (sub AddC, (shl X, N1C)), N1C -->
  //   sext (sub AddC1',(trunc X to (width - N1C)))
  if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::SUB) && N1C &&
      N0.hasOneUse()) {
    bool IsAdd = N0.getOpcode() == ISD::ADD;
    // The shl operand: LHS of an add, RHS of a sub.
    SDValue Shl = N0.getOperand(IsAdd ? 0 : 1);
    if (Shl.getOpcode() == ISD::SHL && Shl.getOperand(1) == N1 &&
        Shl.hasOneUse()) {
      // TODO: AddC does not need to be a splat.
      if (ConstantSDNode *AddC =
              isConstOrConstSplat(N0.getOperand(IsAdd ? 1 : 0))) {
        // Determine what the truncate's type would be and ask the target if
        // that is a free operation.
        LLVMContext &Ctx = *DAG.getContext();
        unsigned ShiftAmt = N1C->getZExtValue();
        EVT TruncVT = EVT::getIntegerVT(Ctx, OpSizeInBits - ShiftAmt);
        if (VT.isVector())
          TruncVT = EVT::getVectorVT(Ctx, TruncVT, VT.getVectorElementCount());

        // TODO: The simple type check probably belongs in the default hook
        //       implementation and/or target-specific overrides (because
        //       non-simple types likely require masking when legalized), but
        //       that restriction may conflict with other transforms.
        if (TruncVT.isSimple() && isTypeLegal(TruncVT) &&
            TLI.isTruncateFree(VT, TruncVT)) {
          SDLoc DL(N);
          SDValue Trunc = DAG.getZExtOrTrunc(Shl.getOperand(0), DL, TruncVT);
          // AddC' = (AddC >> N1C) truncated to the narrow type.
          SDValue ShiftC =
              DAG.getConstant(AddC->getAPIntValue().lshr(ShiftAmt).trunc(
                                  TruncVT.getScalarSizeInBits()),
                              DL, TruncVT);
          SDValue Add;
          if (IsAdd)
            Add = DAG.getNode(ISD::ADD, DL, TruncVT, Trunc, ShiftC);
          else
            Add = DAG.getNode(ISD::SUB, DL, TruncVT, ShiftC, Trunc);
          return DAG.getSExtOrTrunc(Add, DL, VT);
        }
      }
    }
  }

  // fold (sra x, (trunc (and y, c))) -> (sra x, (and (trunc y), (trunc c))).
  if (N1.getOpcode() == ISD::TRUNCATE &&
      N1.getOperand(0).getOpcode() == ISD::AND) {
    if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
      return DAG.getNode(ISD::SRA, SDLoc(N), VT, N0, NewOp1);
  }

  // fold (sra (trunc (sra x, c1)), c2) -> (trunc (sra x, c1 + c2))
  // fold (sra (trunc (srl x, c1)), c2) -> (trunc (sra x, c1 + c2))
  //      if c1 is equal to the number of bits the trunc removes
  // TODO - support non-uniform vector shift amounts.
  if (N0.getOpcode() == ISD::TRUNCATE &&
      (N0.getOperand(0).getOpcode() == ISD::SRL ||
       N0.getOperand(0).getOpcode() == ISD::SRA) &&
      N0.getOperand(0).hasOneUse() &&
      N0.getOperand(0).getOperand(1).hasOneUse() && N1C) {
    SDValue N0Op0 = N0.getOperand(0);
    if (ConstantSDNode *LargeShift = isConstOrConstSplat(N0Op0.getOperand(1))) {
      EVT LargeVT = N0Op0.getValueType();
      unsigned TruncBits = LargeVT.getScalarSizeInBits() - OpSizeInBits;
      if (LargeShift->getAPIntValue() == TruncBits) {
        SDLoc DL(N);
        EVT LargeShiftVT = getShiftAmountTy(LargeVT);
        SDValue Amt = DAG.getZExtOrTrunc(N1, DL, LargeShiftVT);
        Amt = DAG.getNode(ISD::ADD, DL, LargeShiftVT, Amt,
                          DAG.getConstant(TruncBits, DL, LargeShiftVT));
        SDValue SRA =
            DAG.getNode(ISD::SRA, DL, LargeVT, N0Op0.getOperand(0), Amt);
        return DAG.getNode(ISD::TRUNCATE, DL, VT, SRA);
      }
    }
  }

  // Simplify, based on bits shifted out of the LHS.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  // If the sign bit is known to be zero, switch this to a SRL.
  if (DAG.SignBitIsZero(N0))
    return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0, N1);

  if (N1C && !N1C->isOpaque())
    if (SDValue NewSRA = visitShiftByConstant(N))
      return NewSRA;

  // Try to transform this shift into a multiply-high if
  // it matches the appropriate pattern detected in combineShiftToMULH.
  if (SDValue MULH = combineShiftToMULH(N, DAG, TLI))
    return MULH;

  // Attempt to convert a sra of a load into a narrower sign-extending load.
  if (SDValue NarrowLoad = reduceLoadWidth(N))
    return NarrowLoad;

  return SDValue();
}
9442 
/// Combine an ISD::SRL node: constant-fold, merge shift pairs, strip
/// extensions, and canonicalize patterns. Returns the replacement value, or
/// an empty SDValue if no fold applied. Folds are attempted in order; each
/// early-returns so at most one rewrite happens per visit.
SDValue DAGCombiner::visitSRL(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  if (SDValue V = DAG.simplifyShift(N0, N1))
    return V;

  EVT VT = N0.getValueType();
  EVT ShiftVT = N1.getValueType();
  unsigned OpSizeInBits = VT.getScalarSizeInBits();

  // fold (srl c1, c2) -> c1 >>u c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::SRL, SDLoc(N), VT, {N0, N1}))
    return C;

  // fold vector ops
  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N)))
      return FoldedVOp;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // if (srl x, c) is known to be zero, return 0
  ConstantSDNode *N1C = isConstOrConstSplat(N1);
  if (N1C &&
      DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnes(OpSizeInBits)))
    return DAG.getConstant(0, SDLoc(N), VT);

  // fold (srl (srl x, c1), c2) -> 0 or (srl x, (add c1, c2))
  if (N0.getOpcode() == ISD::SRL) {
    // The shift amounts are summed in a type widened by one bit so the
    // addition itself cannot wrap; a sum >= bitwidth shifts everything out.
    auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
                                          ConstantSDNode *RHS) {
      APInt c1 = LHS->getAPIntValue();
      APInt c2 = RHS->getAPIntValue();
      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
      return (c1 + c2).uge(OpSizeInBits);
    };
    if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
      return DAG.getConstant(0, SDLoc(N), VT);

    auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
                                       ConstantSDNode *RHS) {
      APInt c1 = LHS->getAPIntValue();
      APInt c2 = RHS->getAPIntValue();
      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
      return (c1 + c2).ult(OpSizeInBits);
    };
    if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
      SDLoc DL(N);
      SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
      return DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Sum);
    }
  }

  if (N1C && N0.getOpcode() == ISD::TRUNCATE &&
      N0.getOperand(0).getOpcode() == ISD::SRL) {
    SDValue InnerShift = N0.getOperand(0);
    // TODO - support non-uniform vector shift amounts.
    if (auto *N001C = isConstOrConstSplat(InnerShift.getOperand(1))) {
      uint64_t c1 = N001C->getZExtValue();
      uint64_t c2 = N1C->getZExtValue();
      EVT InnerShiftVT = InnerShift.getValueType();
      EVT ShiftAmtVT = InnerShift.getOperand(1).getValueType();
      uint64_t InnerShiftSize = InnerShiftVT.getScalarSizeInBits();
      // srl (trunc (srl x, c1)), c2 --> 0 or (trunc (srl x, (add c1, c2)))
      // This is only valid if the OpSizeInBits + c1 = size of inner shift.
      // In that case the truncate drops exactly the bits the inner shift
      // already cleared, so the combined shift amount is safe.
      if (c1 + OpSizeInBits == InnerShiftSize) {
        SDLoc DL(N);
        if (c1 + c2 >= InnerShiftSize)
          return DAG.getConstant(0, DL, VT);
        SDValue NewShiftAmt = DAG.getConstant(c1 + c2, DL, ShiftAmtVT);
        SDValue NewShift = DAG.getNode(ISD::SRL, DL, InnerShiftVT,
                                       InnerShift.getOperand(0), NewShiftAmt);
        return DAG.getNode(ISD::TRUNCATE, DL, VT, NewShift);
      }
      // In the more general case, we can clear the high bits after the shift:
      // srl (trunc (srl x, c1)), c2 --> trunc (and (srl x, (c1+c2)), Mask)
      // One-use checks keep us from duplicating the inner shift.
      if (N0.hasOneUse() && InnerShift.hasOneUse() &&
          c1 + c2 < InnerShiftSize) {
        SDLoc DL(N);
        SDValue NewShiftAmt = DAG.getConstant(c1 + c2, DL, ShiftAmtVT);
        SDValue NewShift = DAG.getNode(ISD::SRL, DL, InnerShiftVT,
                                       InnerShift.getOperand(0), NewShiftAmt);
        SDValue Mask = DAG.getConstant(APInt::getLowBitsSet(InnerShiftSize,
                                                            OpSizeInBits - c2),
                                       DL, InnerShiftVT);
        SDValue And = DAG.getNode(ISD::AND, DL, InnerShiftVT, NewShift, Mask);
        return DAG.getNode(ISD::TRUNCATE, DL, VT, And);
      }
    }
  }

  // fold (srl (shl x, c1), c2) -> (and (shl x, (sub c1, c2)), MASK) or
  //                               (and (srl x, (sub c2, c1)), MASK)
  // depending on which shift amount is larger; the target hook decides
  // whether the shift+mask form is preferable.
  if (N0.getOpcode() == ISD::SHL &&
      (N0.getOperand(1) == N1 || N0->hasOneUse()) &&
      TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
    auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
                                           ConstantSDNode *RHS) {
      const APInt &LHSC = LHS->getAPIntValue();
      const APInt &RHSC = RHS->getAPIntValue();
      return LHSC.ult(OpSizeInBits) && RHSC.ult(OpSizeInBits) &&
             LHSC.getZExtValue() <= RHSC.getZExtValue();
    };
    if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
                                  /*AllowUndefs*/ false,
                                  /*AllowTypeMismatch*/ true)) {
      SDLoc DL(N);
      SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
      SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1);
      SDValue Mask = DAG.getAllOnesConstant(DL, VT);
      Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, N01);
      Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, Diff);
      SDValue Shift = DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff);
      return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
    }
    if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount,
                                  /*AllowUndefs*/ false,
                                  /*AllowTypeMismatch*/ true)) {
      SDLoc DL(N);
      SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
      SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01);
      SDValue Mask = DAG.getAllOnesConstant(DL, VT);
      Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, N1);
      SDValue Shift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Diff);
      return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
    }
  }

  // fold (srl (anyextend x), c) -> (and (anyextend (srl x, c)), mask)
  // TODO - support non-uniform vector shift amounts.
  if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
    // Shifting in all undef bits?
    EVT SmallVT = N0.getOperand(0).getValueType();
    unsigned BitSize = SmallVT.getScalarSizeInBits();
    if (N1C->getAPIntValue().uge(BitSize))
      return DAG.getUNDEF(VT);

    if (!LegalTypes || TLI.isTypeDesirableForOp(ISD::SRL, SmallVT)) {
      uint64_t ShiftAmt = N1C->getZExtValue();
      SDLoc DL0(N0);
      SDValue SmallShift = DAG.getNode(ISD::SRL, DL0, SmallVT,
                                       N0.getOperand(0),
                          DAG.getConstant(ShiftAmt, DL0,
                                          getShiftAmountTy(SmallVT)));
      AddToWorklist(SmallShift.getNode());
      // Mask off the (undef) bits that would have been shifted in from the
      // anyext's high part.
      APInt Mask = APInt::getLowBitsSet(OpSizeInBits, OpSizeInBits - ShiftAmt);
      SDLoc DL(N);
      return DAG.getNode(ISD::AND, DL, VT,
                         DAG.getNode(ISD::ANY_EXTEND, DL, VT, SmallShift),
                         DAG.getConstant(Mask, DL, VT));
    }
  }

  // fold (srl (sra X, Y), 31) -> (srl X, 31).  This srl only looks at the sign
  // bit, which is unmodified by sra.
  if (N1C && N1C->getAPIntValue() == (OpSizeInBits - 1)) {
    if (N0.getOpcode() == ISD::SRA)
      return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0.getOperand(0), N1);
  }

  // fold (srl (ctlz x), "5") -> x  iff x has one bit set (the low bit).
  if (N1C && N0.getOpcode() == ISD::CTLZ &&
      N1C->getAPIntValue() == Log2_32(OpSizeInBits)) {
    KnownBits Known = DAG.computeKnownBits(N0.getOperand(0));

    // If any of the input bits are KnownOne, then the input couldn't be all
    // zeros, thus the result of the srl will always be zero.
    if (Known.One.getBoolValue()) return DAG.getConstant(0, SDLoc(N0), VT);

    // If all of the bits input the to ctlz node are known to be zero, then
    // the result of the ctlz is "32" and the result of the shift is one.
    APInt UnknownBits = ~Known.Zero;
    if (UnknownBits == 0) return DAG.getConstant(1, SDLoc(N0), VT);

    // Otherwise, check to see if there is exactly one bit input to the ctlz.
    if (UnknownBits.isPowerOf2()) {
      // Okay, we know that only that the single bit specified by UnknownBits
      // could be set on input to the CTLZ node. If this bit is set, the SRL
      // will return 0, if it is clear, it returns 1. Change the CTLZ/SRL pair
      // to an SRL/XOR pair, which is likely to simplify more.
      unsigned ShAmt = UnknownBits.countTrailingZeros();
      SDValue Op = N0.getOperand(0);

      if (ShAmt) {
        SDLoc DL(N0);
        Op = DAG.getNode(ISD::SRL, DL, VT, Op,
                  DAG.getConstant(ShAmt, DL,
                                  getShiftAmountTy(Op.getValueType())));
        AddToWorklist(Op.getNode());
      }

      SDLoc DL(N);
      return DAG.getNode(ISD::XOR, DL, VT,
                         Op, DAG.getConstant(1, DL, VT));
    }
  }

  // fold (srl x, (trunc (and y, c))) -> (srl x, (and (trunc y), (trunc c))).
  if (N1.getOpcode() == ISD::TRUNCATE &&
      N1.getOperand(0).getOpcode() == ISD::AND) {
    if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
      return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0, NewOp1);
  }

  // fold operands of srl based on knowledge that the low bits are not
  // demanded.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  if (N1C && !N1C->isOpaque())
    if (SDValue NewSRL = visitShiftByConstant(N))
      return NewSRL;

  // Attempt to convert a srl of a load into a narrower zero-extending load.
  if (SDValue NarrowLoad = reduceLoadWidth(N))
    return NarrowLoad;

  // Here is a common situation. We want to optimize:
  //
  //   %a = ...
  //   %b = and i32 %a, 2
  //   %c = srl i32 %b, 1
  //   brcond i32 %c ...
  //
  // into
  //
  //   %a = ...
  //   %b = and %a, 2
  //   %c = setcc eq %b, 0
  //   brcond %c ...
  //
  // However when after the source operand of SRL is optimized into AND, the SRL
  // itself may not be optimized further. Look for it and add the BRCOND into
  // the worklist.
  if (N->hasOneUse()) {
    SDNode *Use = *N->use_begin();
    if (Use->getOpcode() == ISD::BRCOND)
      AddToWorklist(Use);
    else if (Use->getOpcode() == ISD::TRUNCATE && Use->hasOneUse()) {
      // Also look pass the truncate.
      Use = *Use->use_begin();
      if (Use->getOpcode() == ISD::BRCOND)
        AddToWorklist(Use);
    }
  }

  // Try to transform this shift into a multiply-high if
  // it matches the appropriate pattern detected in combineShiftToMULH.
  if (SDValue MULH = combineShiftToMULH(N, DAG, TLI))
    return MULH;

  return SDValue();
}
9697 
/// Combine an ISD::FSHL/FSHR (funnel shift) node. Handles shift amounts that
/// are known-zero modulo the bitwidth, constant-amount degenerate cases
/// (zero/undef halves), merging of consecutive loads, and recognition of
/// rotates. Returns the replacement value or an empty SDValue.
SDValue DAGCombiner::visitFunnelShift(SDNode *N) {
  EVT VT = N->getValueType(0);
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue N2 = N->getOperand(2);
  bool IsFSHL = N->getOpcode() == ISD::FSHL;
  unsigned BitWidth = VT.getScalarSizeInBits();

  // fold (fshl N0, N1, 0) -> N0
  // fold (fshr N0, N1, 0) -> N1
  // Only valid when the amount is taken modulo a power-of-2 bitwidth, so the
  // low log2(BitWidth) bits being zero means the effective amount is zero.
  if (isPowerOf2_32(BitWidth))
    if (DAG.MaskedValueIsZero(
            N2, APInt(N2.getScalarValueSizeInBits(), BitWidth - 1)))
      return IsFSHL ? N0 : N1;

  auto IsUndefOrZero = [](SDValue V) {
    return V.isUndef() || isNullOrNullSplat(V, /*AllowUndefs*/ true);
  };

  // TODO - support non-uniform vector shift amounts.
  if (ConstantSDNode *Cst = isConstOrConstSplat(N2)) {
    EVT ShAmtTy = N2.getValueType();

    // fold (fsh* N0, N1, c) -> (fsh* N0, N1, c % BitWidth)
    if (Cst->getAPIntValue().uge(BitWidth)) {
      uint64_t RotAmt = Cst->getAPIntValue().urem(BitWidth);
      return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N0, N1,
                         DAG.getConstant(RotAmt, SDLoc(N), ShAmtTy));
    }

    // From here on ShAmt is known to be in [0, BitWidth).
    unsigned ShAmt = Cst->getZExtValue();
    if (ShAmt == 0)
      return IsFSHL ? N0 : N1;

    // fold fshl(undef_or_zero, N1, C) -> lshr(N1, BW-C)
    // fold fshr(undef_or_zero, N1, C) -> lshr(N1, C)
    // fold fshl(N0, undef_or_zero, C) -> shl(N0, C)
    // fold fshr(N0, undef_or_zero, C) -> shl(N0, BW-C)
    if (IsUndefOrZero(N0))
      return DAG.getNode(ISD::SRL, SDLoc(N), VT, N1,
                         DAG.getConstant(IsFSHL ? BitWidth - ShAmt : ShAmt,
                                         SDLoc(N), ShAmtTy));
    if (IsUndefOrZero(N1))
      return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0,
                         DAG.getConstant(IsFSHL ? ShAmt : BitWidth - ShAmt,
                                         SDLoc(N), ShAmtTy));

    // fold (fshl ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
    // fold (fshr ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
    // TODO - bigendian support once we have test coverage.
    // TODO - can we merge this with CombineConseutiveLoads/MatchLoadCombine?
    // TODO - permit LHS EXTLOAD if extensions are shifted out.
    // Requires a byte-aligned shift so the result is a plain (offset) load.
    if ((BitWidth % 8) == 0 && (ShAmt % 8) == 0 && !VT.isVector() &&
        !DAG.getDataLayout().isBigEndian()) {
      auto *LHS = dyn_cast<LoadSDNode>(N0);
      auto *RHS = dyn_cast<LoadSDNode>(N1);
      if (LHS && RHS && LHS->isSimple() && RHS->isSimple() &&
          LHS->getAddressSpace() == RHS->getAddressSpace() &&
          (LHS->hasOneUse() || RHS->hasOneUse()) && ISD::isNON_EXTLoad(RHS) &&
          ISD::isNON_EXTLoad(LHS)) {
        if (DAG.areNonVolatileConsecutiveLoads(LHS, RHS, BitWidth / 8, 1)) {
          SDLoc DL(RHS);
          uint64_t PtrOff =
              IsFSHL ? (((BitWidth - ShAmt) % BitWidth) / 8) : (ShAmt / 8);
          Align NewAlign = commonAlignment(RHS->getAlign(), PtrOff);
          bool Fast = false;
          // Only form the new load if the target says the (possibly
          // misaligned) access is both allowed and fast.
          if (TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
                                     RHS->getAddressSpace(), NewAlign,
                                     RHS->getMemOperand()->getFlags(), &Fast) &&
              Fast) {
            SDValue NewPtr = DAG.getMemBasePlusOffset(
                RHS->getBasePtr(), TypeSize::Fixed(PtrOff), DL);
            AddToWorklist(NewPtr.getNode());
            SDValue Load = DAG.getLoad(
                VT, DL, RHS->getChain(), NewPtr,
                RHS->getPointerInfo().getWithOffset(PtrOff), NewAlign,
                RHS->getMemOperand()->getFlags(), RHS->getAAInfo());
            // Replace the old load's chain with the new load's chain.
            WorklistRemover DeadNodes(*this);
            DAG.ReplaceAllUsesOfValueWith(N1.getValue(1), Load.getValue(1));
            return Load;
          }
        }
      }
    }
  }

  // fold fshr(undef_or_zero, N1, N2) -> lshr(N1, N2)
  // fold fshl(N0, undef_or_zero, N2) -> shl(N0, N2)
  // iff We know the shift amount is in range.
  // TODO: when is it worth doing SUB(BW, N2) as well?
  if (isPowerOf2_32(BitWidth)) {
    APInt ModuloBits(N2.getScalarValueSizeInBits(), BitWidth - 1);
    if (IsUndefOrZero(N0) && !IsFSHL && DAG.MaskedValueIsZero(N2, ~ModuloBits))
      return DAG.getNode(ISD::SRL, SDLoc(N), VT, N1, N2);
    if (IsUndefOrZero(N1) && IsFSHL && DAG.MaskedValueIsZero(N2, ~ModuloBits))
      return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, N2);
  }

  // fold (fshl N0, N0, N2) -> (rotl N0, N2)
  // fold (fshr N0, N0, N2) -> (rotr N0, N2)
  // TODO: Investigate flipping this rotate if only one is legal, if funnel shift
  // is legal as well we might be better off avoiding non-constant (BW - N2).
  unsigned RotOpc = IsFSHL ? ISD::ROTL : ISD::ROTR;
  if (N0 == N1 && hasOperation(RotOpc, VT))
    return DAG.getNode(RotOpc, SDLoc(N), VT, N0, N2);

  // Simplify, based on bits shifted out of N0/N1.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  return SDValue();
}
9811 
visitSHLSAT(SDNode * N)9812 SDValue DAGCombiner::visitSHLSAT(SDNode *N) {
9813   SDValue N0 = N->getOperand(0);
9814   SDValue N1 = N->getOperand(1);
9815   if (SDValue V = DAG.simplifyShift(N0, N1))
9816     return V;
9817 
9818   EVT VT = N0.getValueType();
9819 
9820   // fold (*shlsat c1, c2) -> c1<<c2
9821   if (SDValue C =
9822           DAG.FoldConstantArithmetic(N->getOpcode(), SDLoc(N), VT, {N0, N1}))
9823     return C;
9824 
9825   ConstantSDNode *N1C = isConstOrConstSplat(N1);
9826 
9827   if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::SHL, VT)) {
9828     // fold (sshlsat x, c) -> (shl x, c)
9829     if (N->getOpcode() == ISD::SSHLSAT && N1C &&
9830         N1C->getAPIntValue().ult(DAG.ComputeNumSignBits(N0)))
9831       return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, N1);
9832 
9833     // fold (ushlsat x, c) -> (shl x, c)
9834     if (N->getOpcode() == ISD::USHLSAT && N1C &&
9835         N1C->getAPIntValue().ule(
9836             DAG.computeKnownBits(N0).countMinLeadingZeros()))
9837       return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, N1);
9838   }
9839 
9840   return SDValue();
9841 }
9842 
9843 // Given a ABS node, detect the following pattern:
9844 // (ABS (SUB (EXTEND a), (EXTEND b))).
9845 // Generates UABD/SABD instruction.
combineABSToABD(SDNode * N,SelectionDAG & DAG,const TargetLowering & TLI)9846 static SDValue combineABSToABD(SDNode *N, SelectionDAG &DAG,
9847                                const TargetLowering &TLI) {
9848   SDValue AbsOp1 = N->getOperand(0);
9849   SDValue Op0, Op1;
9850 
9851   if (AbsOp1.getOpcode() != ISD::SUB)
9852     return SDValue();
9853 
9854   Op0 = AbsOp1.getOperand(0);
9855   Op1 = AbsOp1.getOperand(1);
9856 
9857   unsigned Opc0 = Op0.getOpcode();
9858   // Check if the operands of the sub are (zero|sign)-extended.
9859   if (Opc0 != Op1.getOpcode() ||
9860       (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND))
9861     return SDValue();
9862 
9863   EVT VT = N->getValueType(0);
9864   EVT VT1 = Op0.getOperand(0).getValueType();
9865   EVT VT2 = Op1.getOperand(0).getValueType();
9866   unsigned ABDOpcode = (Opc0 == ISD::SIGN_EXTEND) ? ISD::ABDS : ISD::ABDU;
9867 
9868   // fold abs(sext(x) - sext(y)) -> zext(abds(x, y))
9869   // fold abs(zext(x) - zext(y)) -> zext(abdu(x, y))
9870   // NOTE: Extensions must be equivalent.
9871   if (VT1 == VT2 && TLI.isOperationLegalOrCustom(ABDOpcode, VT1)) {
9872     Op0 = Op0.getOperand(0);
9873     Op1 = Op1.getOperand(0);
9874     SDValue ABD = DAG.getNode(ABDOpcode, SDLoc(N), VT1, Op0, Op1);
9875     return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, ABD);
9876   }
9877 
9878   // fold abs(sext(x) - sext(y)) -> abds(sext(x), sext(y))
9879   // fold abs(zext(x) - zext(y)) -> abdu(zext(x), zext(y))
9880   if (TLI.isOperationLegalOrCustom(ABDOpcode, VT))
9881     return DAG.getNode(ABDOpcode, SDLoc(N), VT, Op0, Op1);
9882 
9883   return SDValue();
9884 }
9885 
visitABS(SDNode * N)9886 SDValue DAGCombiner::visitABS(SDNode *N) {
9887   SDValue N0 = N->getOperand(0);
9888   EVT VT = N->getValueType(0);
9889 
9890   // fold (abs c1) -> c2
9891   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
9892     return DAG.getNode(ISD::ABS, SDLoc(N), VT, N0);
9893   // fold (abs (abs x)) -> (abs x)
9894   if (N0.getOpcode() == ISD::ABS)
9895     return N0;
9896   // fold (abs x) -> x iff not-negative
9897   if (DAG.SignBitIsZero(N0))
9898     return N0;
9899 
9900   if (SDValue ABD = combineABSToABD(N, DAG, TLI))
9901     return ABD;
9902 
9903   return SDValue();
9904 }
9905 
/// Combine an ISD::BSWAP node: constant-fold, cancel double bswaps, and
/// canonicalize against bitreverse and byte-multiple shifts.
SDValue DAGCombiner::visitBSWAP(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  // fold (bswap c1) -> c2
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
    return DAG.getNode(ISD::BSWAP, DL, VT, N0);
  // fold (bswap (bswap x)) -> x
  if (N0.getOpcode() == ISD::BSWAP)
    return N0.getOperand(0);

  // Canonicalize bswap(bitreverse(x)) -> bitreverse(bswap(x)). If bitreverse
  // isn't supported, it will be expanded to bswap followed by a manual reversal
  // of bits in each byte. By placing bswaps before bitreverse, we can remove
  // the two bswaps if the bitreverse gets expanded.
  if (N0.getOpcode() == ISD::BITREVERSE && N0.hasOneUse()) {
    SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT, N0.getOperand(0));
    return DAG.getNode(ISD::BITREVERSE, DL, VT, BSwap);
  }

  // fold (bswap shl(x,c)) -> (zext(bswap(trunc(shl(x,sub(c,bw/2))))))
  // iff x >= bw/2 (i.e. lower half is known zero)
  // The shl guarantees the low half of the wide value is zero, so the bswap
  // result lives entirely in the low half and can be computed in HalfVT.
  unsigned BW = VT.getScalarSizeInBits();
  if (BW >= 32 && N0.getOpcode() == ISD::SHL && N0.hasOneUse()) {
    auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1));
    EVT HalfVT = EVT::getIntegerVT(*DAG.getContext(), BW / 2);
    // Amount must be in range, at least bw/2, and a multiple of 16 so that
    // whole halfword lanes move; the half type and bswap on it must be OK.
    if (ShAmt && ShAmt->getAPIntValue().ult(BW) &&
        ShAmt->getZExtValue() >= (BW / 2) &&
        (ShAmt->getZExtValue() % 16) == 0 && TLI.isTypeLegal(HalfVT) &&
        TLI.isTruncateFree(VT, HalfVT) &&
        (!LegalOperations || hasOperation(ISD::BSWAP, HalfVT))) {
      SDValue Res = N0.getOperand(0);
      if (uint64_t NewShAmt = (ShAmt->getZExtValue() - (BW / 2)))
        Res = DAG.getNode(ISD::SHL, DL, VT, Res,
                          DAG.getConstant(NewShAmt, DL, getShiftAmountTy(VT)));
      Res = DAG.getZExtOrTrunc(Res, DL, HalfVT);
      Res = DAG.getNode(ISD::BSWAP, DL, HalfVT, Res);
      return DAG.getZExtOrTrunc(Res, DL, VT);
    }
  }

  // Try to canonicalize bswap-of-logical-shift-by-8-bit-multiple as
  // inverse-shift-of-bswap:
  // bswap (X u<< C) --> (bswap X) u>> C
  // bswap (X u>> C) --> (bswap X) u<< C
  if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
      N0.hasOneUse()) {
    auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1));
    if (ShAmt && ShAmt->getAPIntValue().ult(BW) &&
        ShAmt->getZExtValue() % 8 == 0) {
      SDValue NewSwap = DAG.getNode(ISD::BSWAP, DL, VT, N0.getOperand(0));
      unsigned InverseShift = N0.getOpcode() == ISD::SHL ? ISD::SRL : ISD::SHL;
      return DAG.getNode(InverseShift, DL, VT, NewSwap, N0.getOperand(1));
    }
  }

  return SDValue();
}
9965 
visitBITREVERSE(SDNode * N)9966 SDValue DAGCombiner::visitBITREVERSE(SDNode *N) {
9967   SDValue N0 = N->getOperand(0);
9968   EVT VT = N->getValueType(0);
9969 
9970   // fold (bitreverse c1) -> c2
9971   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
9972     return DAG.getNode(ISD::BITREVERSE, SDLoc(N), VT, N0);
9973   // fold (bitreverse (bitreverse x)) -> x
9974   if (N0.getOpcode() == ISD::BITREVERSE)
9975     return N0.getOperand(0);
9976   return SDValue();
9977 }
9978 
visitCTLZ(SDNode * N)9979 SDValue DAGCombiner::visitCTLZ(SDNode *N) {
9980   SDValue N0 = N->getOperand(0);
9981   EVT VT = N->getValueType(0);
9982 
9983   // fold (ctlz c1) -> c2
9984   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
9985     return DAG.getNode(ISD::CTLZ, SDLoc(N), VT, N0);
9986 
9987   // If the value is known never to be zero, switch to the undef version.
9988   if (!LegalOperations || TLI.isOperationLegal(ISD::CTLZ_ZERO_UNDEF, VT)) {
9989     if (DAG.isKnownNeverZero(N0))
9990       return DAG.getNode(ISD::CTLZ_ZERO_UNDEF, SDLoc(N), VT, N0);
9991   }
9992 
9993   return SDValue();
9994 }
9995 
visitCTLZ_ZERO_UNDEF(SDNode * N)9996 SDValue DAGCombiner::visitCTLZ_ZERO_UNDEF(SDNode *N) {
9997   SDValue N0 = N->getOperand(0);
9998   EVT VT = N->getValueType(0);
9999 
10000   // fold (ctlz_zero_undef c1) -> c2
10001   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
10002     return DAG.getNode(ISD::CTLZ_ZERO_UNDEF, SDLoc(N), VT, N0);
10003   return SDValue();
10004 }
10005 
visitCTTZ(SDNode * N)10006 SDValue DAGCombiner::visitCTTZ(SDNode *N) {
10007   SDValue N0 = N->getOperand(0);
10008   EVT VT = N->getValueType(0);
10009 
10010   // fold (cttz c1) -> c2
10011   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
10012     return DAG.getNode(ISD::CTTZ, SDLoc(N), VT, N0);
10013 
10014   // If the value is known never to be zero, switch to the undef version.
10015   if (!LegalOperations || TLI.isOperationLegal(ISD::CTTZ_ZERO_UNDEF, VT)) {
10016     if (DAG.isKnownNeverZero(N0))
10017       return DAG.getNode(ISD::CTTZ_ZERO_UNDEF, SDLoc(N), VT, N0);
10018   }
10019 
10020   return SDValue();
10021 }
10022 
visitCTTZ_ZERO_UNDEF(SDNode * N)10023 SDValue DAGCombiner::visitCTTZ_ZERO_UNDEF(SDNode *N) {
10024   SDValue N0 = N->getOperand(0);
10025   EVT VT = N->getValueType(0);
10026 
10027   // fold (cttz_zero_undef c1) -> c2
10028   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
10029     return DAG.getNode(ISD::CTTZ_ZERO_UNDEF, SDLoc(N), VT, N0);
10030   return SDValue();
10031 }
10032 
visitCTPOP(SDNode * N)10033 SDValue DAGCombiner::visitCTPOP(SDNode *N) {
10034   SDValue N0 = N->getOperand(0);
10035   EVT VT = N->getValueType(0);
10036 
10037   // fold (ctpop c1) -> c2
10038   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
10039     return DAG.getNode(ISD::CTPOP, SDLoc(N), VT, N0);
10040   return SDValue();
10041 }
10042 
10043 // FIXME: This should be checking for no signed zeros on individual operands, as
10044 // well as no nans.
isLegalToCombineMinNumMaxNum(SelectionDAG & DAG,SDValue LHS,SDValue RHS,const TargetLowering & TLI)10045 static bool isLegalToCombineMinNumMaxNum(SelectionDAG &DAG, SDValue LHS,
10046                                          SDValue RHS,
10047                                          const TargetLowering &TLI) {
10048   const TargetOptions &Options = DAG.getTarget().Options;
10049   EVT VT = LHS.getValueType();
10050 
10051   return Options.NoSignedZerosFPMath && VT.isFloatingPoint() &&
10052          TLI.isProfitableToCombineMinNumMaxNum(VT) &&
10053          DAG.isKnownNeverNaN(LHS) && DAG.isKnownNeverNaN(RHS);
10054 }
10055 
/// Generate Min/Max node.
/// Turn a select-of-setcc into fminnum/fmaxnum (or the *_IEEE variants) when
/// the select operands are exactly the compared values (possibly swapped).
/// Callers must already have established that the operands are never NaN
/// (see isLegalToCombineMinNumMaxNum), which is why ordered and unordered
/// condition codes are treated alike here.
static SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
                                   SDValue RHS, SDValue True, SDValue False,
                                   ISD::CondCode CC, const TargetLowering &TLI,
                                   SelectionDAG &DAG) {
  // The select must pick between the setcc's own operands, in either order;
  // (LHS == True) below distinguishes min-like from max-like selects.
  if (!(LHS == True && RHS == False) && !(LHS == False && RHS == True))
    return SDValue();

  // Check legality of the non-IEEE opcode on the legalized type, since the
  // fminnum/fmaxnum may only become available after type legalization.
  EVT TransformVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
  switch (CC) {
  case ISD::SETOLT:
  case ISD::SETOLE:
  case ISD::SETLT:
  case ISD::SETLE:
  case ISD::SETULT:
  case ISD::SETULE: {
    // Since it's known never nan to get here already, either fminnum or
    // fminnum_ieee are OK. Try the ieee version first, since it's fminnum is
    // expanded in terms of it.
    unsigned IEEEOpcode = (LHS == True) ? ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE;
    if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT))
      return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS);

    unsigned Opcode = (LHS == True) ? ISD::FMINNUM : ISD::FMAXNUM;
    if (TLI.isOperationLegalOrCustom(Opcode, TransformVT))
      return DAG.getNode(Opcode, DL, VT, LHS, RHS);
    return SDValue();
  }
  case ISD::SETOGT:
  case ISD::SETOGE:
  case ISD::SETGT:
  case ISD::SETGE:
  case ISD::SETUGT:
  case ISD::SETUGE: {
    // Mirror of the less-than cases: greater-than selects the max when the
    // true value is the compared LHS.
    unsigned IEEEOpcode = (LHS == True) ? ISD::FMAXNUM_IEEE : ISD::FMINNUM_IEEE;
    if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT))
      return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS);

    unsigned Opcode = (LHS == True) ? ISD::FMAXNUM : ISD::FMINNUM;
    if (TLI.isOperationLegalOrCustom(Opcode, TransformVT))
      return DAG.getNode(Opcode, DL, VT, LHS, RHS);
    return SDValue();
  }
  default:
    // Equality and one-sided comparisons don't correspond to min/max.
    return SDValue();
  }
}
10103 
10104 /// If a (v)select has a condition value that is a sign-bit test, try to smear
10105 /// the condition operand sign-bit across the value width and use it as a mask.
foldSelectOfConstantsUsingSra(SDNode * N,SelectionDAG & DAG)10106 static SDValue foldSelectOfConstantsUsingSra(SDNode *N, SelectionDAG &DAG) {
10107   SDValue Cond = N->getOperand(0);
10108   SDValue C1 = N->getOperand(1);
10109   SDValue C2 = N->getOperand(2);
10110   if (!isConstantOrConstantVector(C1) || !isConstantOrConstantVector(C2))
10111     return SDValue();
10112 
10113   EVT VT = N->getValueType(0);
10114   if (Cond.getOpcode() != ISD::SETCC || !Cond.hasOneUse() ||
10115       VT != Cond.getOperand(0).getValueType())
10116     return SDValue();
10117 
10118   // The inverted-condition + commuted-select variants of these patterns are
10119   // canonicalized to these forms in IR.
10120   SDValue X = Cond.getOperand(0);
10121   SDValue CondC = Cond.getOperand(1);
10122   ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
10123   if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(CondC) &&
10124       isAllOnesOrAllOnesSplat(C2)) {
10125     // i32 X > -1 ? C1 : -1 --> (X >>s 31) | C1
10126     SDLoc DL(N);
10127     SDValue ShAmtC = DAG.getConstant(X.getScalarValueSizeInBits() - 1, DL, VT);
10128     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, X, ShAmtC);
10129     return DAG.getNode(ISD::OR, DL, VT, Sra, C1);
10130   }
10131   if (CC == ISD::SETLT && isNullOrNullSplat(CondC) && isNullOrNullSplat(C2)) {
10132     // i8 X < 0 ? C1 : 0 --> (X >>s 7) & C1
10133     SDLoc DL(N);
10134     SDValue ShAmtC = DAG.getConstant(X.getScalarValueSizeInBits() - 1, DL, VT);
10135     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, X, ShAmtC);
10136     return DAG.getNode(ISD::AND, DL, VT, Sra, C1);
10137   }
10138   return SDValue();
10139 }
10140 
/// Fold (select Cond, C1, C2) with integer constants C1/C2 into bit math
/// (zext/sext/xor/add/shl) to avoid a conditional move where profitable.
SDValue DAGCombiner::foldSelectOfConstants(SDNode *N) {
  SDValue Cond = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue N2 = N->getOperand(2);
  EVT VT = N->getValueType(0);
  EVT CondVT = Cond.getValueType();
  SDLoc DL(N);

  if (!VT.isInteger())
    return SDValue();

  // Both select arms must be constants for any of these folds.
  auto *C1 = dyn_cast<ConstantSDNode>(N1);
  auto *C2 = dyn_cast<ConstantSDNode>(N2);
  if (!C1 || !C2)
    return SDValue();

  // Only do this before legalization to avoid conflicting with target-specific
  // transforms in the other direction (create a select from a zext/sext). There
  // is also a target-independent combine here in DAGCombiner in the other
  // direction for (select Cond, -1, 0) when the condition is not i1.
  if (CondVT == MVT::i1 && !LegalOperations) {
    if (C1->isZero() && C2->isOne()) {
      // select Cond, 0, 1 --> zext (!Cond)
      SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
      if (VT != MVT::i1)
        NotCond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, NotCond);
      return NotCond;
    }
    if (C1->isZero() && C2->isAllOnes()) {
      // select Cond, 0, -1 --> sext (!Cond)
      SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
      if (VT != MVT::i1)
        NotCond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, NotCond);
      return NotCond;
    }
    if (C1->isOne() && C2->isZero()) {
      // select Cond, 1, 0 --> zext (Cond)
      if (VT != MVT::i1)
        Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Cond);
      return Cond;
    }
    if (C1->isAllOnes() && C2->isZero()) {
      // select Cond, -1, 0 --> sext (Cond)
      if (VT != MVT::i1)
        Cond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Cond);
      return Cond;
    }

    // Use a target hook because some targets may prefer to transform in the
    // other direction.
    if (TLI.convertSelectOfConstantsToMath(VT)) {
      // For any constants that differ by 1, we can transform the select into an
      // extend and add.
      const APInt &C1Val = C1->getAPIntValue();
      const APInt &C2Val = C2->getAPIntValue();
      if (C1Val - 1 == C2Val) {
        // select Cond, C1, C1-1 --> add (zext Cond), C1-1
        if (VT != MVT::i1)
          Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Cond);
        return DAG.getNode(ISD::ADD, DL, VT, Cond, N2);
      }
      if (C1Val + 1 == C2Val) {
        // select Cond, C1, C1+1 --> add (sext Cond), C1+1
        if (VT != MVT::i1)
          Cond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Cond);
        return DAG.getNode(ISD::ADD, DL, VT, Cond, N2);
      }

      // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2)
      if (C1Val.isPowerOf2() && C2Val.isZero()) {
        if (VT != MVT::i1)
          Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Cond);
        SDValue ShAmtC =
            DAG.getShiftAmountConstant(C1Val.exactLogBase2(), VT, DL);
        return DAG.getNode(ISD::SHL, DL, VT, Cond, ShAmtC);
      }

      // Last resort: splat the sign bit with an sra-based fold.
      if (SDValue V = foldSelectOfConstantsUsingSra(N, DAG))
        return V;
    }

    return SDValue();
  }

  // fold (select Cond, 0, 1) -> (xor Cond, 1)
  // We can't do this reliably if integer based booleans have different contents
  // to floating point based booleans. This is because we can't tell whether we
  // have an integer-based boolean or a floating-point-based boolean unless we
  // can find the SETCC that produced it and inspect its operands. This is
  // fairly easy if C is the SETCC node, but it can potentially be
  // undiscoverable (or not reasonably discoverable). For example, it could be
  // in another basic block or it could require searching a complicated
  // expression.
  if (CondVT.isInteger() &&
      TLI.getBooleanContents(/*isVec*/false, /*isFloat*/true) ==
          TargetLowering::ZeroOrOneBooleanContent &&
      TLI.getBooleanContents(/*isVec*/false, /*isFloat*/false) ==
          TargetLowering::ZeroOrOneBooleanContent &&
      C1->isZero() && C2->isOne()) {
    SDValue NotCond =
        DAG.getNode(ISD::XOR, DL, CondVT, Cond, DAG.getConstant(1, DL, CondVT));
    if (VT.bitsEq(CondVT))
      return NotCond;
    return DAG.getZExtOrTrunc(NotCond, DL, VT);
  }

  return SDValue();
}
10249 
foldBoolSelectToLogic(SDNode * N,SelectionDAG & DAG)10250 static SDValue foldBoolSelectToLogic(SDNode *N, SelectionDAG &DAG) {
10251   assert((N->getOpcode() == ISD::SELECT || N->getOpcode() == ISD::VSELECT) &&
10252          "Expected a (v)select");
10253   SDValue Cond = N->getOperand(0);
10254   SDValue T = N->getOperand(1), F = N->getOperand(2);
10255   EVT VT = N->getValueType(0);
10256   if (VT != Cond.getValueType() || VT.getScalarSizeInBits() != 1)
10257     return SDValue();
10258 
10259   // select Cond, Cond, F --> or Cond, F
10260   // select Cond, 1, F    --> or Cond, F
10261   if (Cond == T || isOneOrOneSplat(T, /* AllowUndefs */ true))
10262     return DAG.getNode(ISD::OR, SDLoc(N), VT, Cond, F);
10263 
10264   // select Cond, T, Cond --> and Cond, T
10265   // select Cond, T, 0    --> and Cond, T
10266   if (Cond == F || isNullOrNullSplat(F, /* AllowUndefs */ true))
10267     return DAG.getNode(ISD::AND, SDLoc(N), VT, Cond, T);
10268 
10269   // select Cond, T, 1 --> or (not Cond), T
10270   if (isOneOrOneSplat(F, /* AllowUndefs */ true)) {
10271     SDValue NotCond = DAG.getNOT(SDLoc(N), Cond, VT);
10272     return DAG.getNode(ISD::OR, SDLoc(N), VT, NotCond, T);
10273   }
10274 
10275   // select Cond, 0, F --> and (not Cond), F
10276   if (isNullOrNullSplat(T, /* AllowUndefs */ true)) {
10277     SDValue NotCond = DAG.getNOT(SDLoc(N), Cond, VT);
10278     return DAG.getNode(ISD::AND, SDLoc(N), VT, NotCond, F);
10279   }
10280 
10281   return SDValue();
10282 }
10283 
foldVSelectToSignBitSplatMask(SDNode * N,SelectionDAG & DAG)10284 static SDValue foldVSelectToSignBitSplatMask(SDNode *N, SelectionDAG &DAG) {
10285   SDValue N0 = N->getOperand(0);
10286   SDValue N1 = N->getOperand(1);
10287   SDValue N2 = N->getOperand(2);
10288   EVT VT = N->getValueType(0);
10289   if (N0.getOpcode() != ISD::SETCC || !N0.hasOneUse())
10290     return SDValue();
10291 
10292   SDValue Cond0 = N0.getOperand(0);
10293   SDValue Cond1 = N0.getOperand(1);
10294   ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
10295   if (VT != Cond0.getValueType())
10296     return SDValue();
10297 
10298   // Match a signbit check of Cond0 as "Cond0 s<0". Swap select operands if the
10299   // compare is inverted from that pattern ("Cond0 s> -1").
10300   if (CC == ISD::SETLT && isNullOrNullSplat(Cond1))
10301     ; // This is the pattern we are looking for.
10302   else if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(Cond1))
10303     std::swap(N1, N2);
10304   else
10305     return SDValue();
10306 
10307   // (Cond0 s< 0) ? N1 : 0 --> (Cond0 s>> BW-1) & N1
10308   if (isNullOrNullSplat(N2)) {
10309     SDLoc DL(N);
10310     SDValue ShiftAmt = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
10311     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
10312     return DAG.getNode(ISD::AND, DL, VT, Sra, N1);
10313   }
10314 
10315   // (Cond0 s< 0) ? -1 : N2 --> (Cond0 s>> BW-1) | N2
10316   if (isAllOnesOrAllOnesSplat(N1)) {
10317     SDLoc DL(N);
10318     SDValue ShiftAmt = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
10319     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
10320     return DAG.getNode(ISD::OR, DL, VT, Sra, N2);
10321   }
10322 
10323   // If we have to invert the sign bit mask, only do that transform if the
10324   // target has a bitwise 'and not' instruction (the invert is free).
10325   // (Cond0 s< -0) ? 0 : N2 --> ~(Cond0 s>> BW-1) & N2
10326   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
10327   if (isNullOrNullSplat(N1) && TLI.hasAndNot(N1)) {
10328     SDLoc DL(N);
10329     SDValue ShiftAmt = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
10330     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
10331     SDValue Not = DAG.getNOT(DL, Sra, VT);
10332     return DAG.getNode(ISD::AND, DL, VT, Not, N2);
10333   }
10334 
10335   // TODO: There's another pattern in this family, but it may require
10336   //       implementing hasOrNot() to check for profitability:
10337   //       (Cond0 s> -1) ? -1 : N2 --> ~(Cond0 s>> BW-1) | N2
10338 
10339   return SDValue();
10340 }
10341 
/// Combine a scalar SELECT: constant folds, boolean-logic folds, select-chain
/// normalization, min/max formation, saturating-add recognition, and
/// conversion to SELECT_CC where the target supports it.
SDValue DAGCombiner::visitSELECT(SDNode *N) {
  SDValue N0 = N->getOperand(0); // condition
  SDValue N1 = N->getOperand(1); // true value
  SDValue N2 = N->getOperand(2); // false value
  EVT VT = N->getValueType(0);
  EVT VT0 = N0.getValueType();
  SDLoc DL(N);
  SDNodeFlags Flags = N->getFlags();

  if (SDValue V = DAG.simplifySelect(N0, N1, N2))
    return V;

  if (SDValue V = foldSelectOfConstants(N))
    return V;

  if (SDValue V = foldBoolSelectToLogic(N, DAG))
    return V;

  // If we can fold this based on the true/false value, do so.
  if (SimplifySelectOps(N, N1, N2))
    return SDValue(N, 0); // Don't revisit N.

  if (VT0 == MVT::i1) {
    // The code in this block deals with the following 2 equivalences:
    //    select(C0|C1, x, y) <=> select(C0, x, select(C1, x, y))
    //    select(C0&C1, x, y) <=> select(C0, select(C1, x, y), y)
    // The target can specify its preferred form with the
    // shouldNormalizeToSelectSequence() callback. However we always transform
    // to the right anyway if we find the inner select exists in the DAG anyway
    // and we always transform to the left side if we know that we can further
    // optimize the combination of the conditions.
    bool normalizeToSequence =
        TLI.shouldNormalizeToSelectSequence(*DAG.getContext(), VT);
    // select (and Cond0, Cond1), X, Y
    //   -> select Cond0, (select Cond1, X, Y), Y
    if (N0->getOpcode() == ISD::AND && N0->hasOneUse()) {
      SDValue Cond0 = N0->getOperand(0);
      SDValue Cond1 = N0->getOperand(1);
      SDValue InnerSelect =
          DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond1, N1, N2, Flags);
      // Only keep the inner select if the target wants the sequence form or
      // if the node already existed in the DAG (it has other uses).
      if (normalizeToSequence || !InnerSelect.use_empty())
        return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0,
                           InnerSelect, N2, Flags);
      // Cleanup on failure.
      if (InnerSelect.use_empty())
        recursivelyDeleteUnusedNodes(InnerSelect.getNode());
    }
    // select (or Cond0, Cond1), X, Y -> select Cond0, X, (select Cond1, X, Y)
    if (N0->getOpcode() == ISD::OR && N0->hasOneUse()) {
      SDValue Cond0 = N0->getOperand(0);
      SDValue Cond1 = N0->getOperand(1);
      SDValue InnerSelect = DAG.getNode(ISD::SELECT, DL, N1.getValueType(),
                                        Cond1, N1, N2, Flags);
      if (normalizeToSequence || !InnerSelect.use_empty())
        return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0, N1,
                           InnerSelect, Flags);
      // Cleanup on failure.
      if (InnerSelect.use_empty())
        recursivelyDeleteUnusedNodes(InnerSelect.getNode());
    }

    // select Cond0, (select Cond1, X, Y), Y -> select (and Cond0, Cond1), X, Y
    if (N1->getOpcode() == ISD::SELECT && N1->hasOneUse()) {
      SDValue N1_0 = N1->getOperand(0);
      SDValue N1_1 = N1->getOperand(1);
      SDValue N1_2 = N1->getOperand(2);
      if (N1_2 == N2 && N0.getValueType() == N1_0.getValueType()) {
        // Create the actual and node if we can generate good code for it.
        if (!normalizeToSequence) {
          SDValue And = DAG.getNode(ISD::AND, DL, N0.getValueType(), N0, N1_0);
          return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), And, N1_1,
                             N2, Flags);
        }
        // Otherwise see if we can optimize the "and" to a better pattern.
        if (SDValue Combined = visitANDLike(N0, N1_0, N)) {
          return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1_1,
                             N2, Flags);
        }
      }
    }
    // select Cond0, X, (select Cond1, X, Y) -> select (or Cond0, Cond1), X, Y
    if (N2->getOpcode() == ISD::SELECT && N2->hasOneUse()) {
      SDValue N2_0 = N2->getOperand(0);
      SDValue N2_1 = N2->getOperand(1);
      SDValue N2_2 = N2->getOperand(2);
      if (N2_1 == N1 && N0.getValueType() == N2_0.getValueType()) {
        // Create the actual or node if we can generate good code for it.
        if (!normalizeToSequence) {
          SDValue Or = DAG.getNode(ISD::OR, DL, N0.getValueType(), N0, N2_0);
          return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Or, N1,
                             N2_2, Flags);
        }
        // Otherwise see if we can optimize to a better pattern.
        if (SDValue Combined = visitORLike(N0, N2_0, N))
          return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1,
                             N2_2, Flags);
      }
    }
  }

  // select (not Cond), N1, N2 -> select Cond, N2, N1
  if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false)) {
    SDValue SelectOp = DAG.getSelect(DL, VT, F, N2, N1);
    SelectOp->setFlags(Flags);
    return SelectOp;
  }

  // Fold selects based on a setcc into other things, such as min/max/abs.
  if (N0.getOpcode() == ISD::SETCC) {
    SDValue Cond0 = N0.getOperand(0), Cond1 = N0.getOperand(1);
    ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();

    // select (fcmp lt x, y), x, y -> fminnum x, y
    // select (fcmp gt x, y), x, y -> fmaxnum x, y
    //
    // This is OK if we don't care what happens if either operand is a NaN.
    if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, N1, N2, TLI))
      if (SDValue FMinMax = combineMinNumMaxNum(DL, VT, Cond0, Cond1, N1, N2,
                                                CC, TLI, DAG))
        return FMinMax;

    // Use 'unsigned add with overflow' to optimize an unsigned saturating add.
    // This is conservatively limited to pre-legal-operations to give targets
    // a chance to reverse the transform if they want to do that. Also, it is
    // unlikely that the pattern would be formed late, so it's probably not
    // worth going through the other checks.
    if (!LegalOperations && TLI.isOperationLegalOrCustom(ISD::UADDO, VT) &&
        CC == ISD::SETUGT && N0.hasOneUse() && isAllOnesConstant(N1) &&
        N2.getOpcode() == ISD::ADD && Cond0 == N2.getOperand(0)) {
      auto *C = dyn_cast<ConstantSDNode>(N2.getOperand(1));
      auto *NotC = dyn_cast<ConstantSDNode>(Cond1);
      if (C && NotC && C->getAPIntValue() == ~NotC->getAPIntValue()) {
        // select (setcc Cond0, ~C, ugt), -1, (add Cond0, C) -->
        // uaddo Cond0, C; select uaddo.1, -1, uaddo.0
        //
        // The IR equivalent of this transform would have this form:
        //   %a = add %x, C
        //   %c = icmp ugt %x, ~C
        //   %r = select %c, -1, %a
        //   =>
        //   %u = call {iN,i1} llvm.uadd.with.overflow(%x, C)
        //   %u0 = extractvalue %u, 0
        //   %u1 = extractvalue %u, 1
        //   %r = select %u1, -1, %u0
        SDVTList VTs = DAG.getVTList(VT, VT0);
        SDValue UAO = DAG.getNode(ISD::UADDO, DL, VTs, Cond0, N2.getOperand(1));
        return DAG.getSelect(DL, VT, UAO.getValue(1), N1, UAO.getValue(0));
      }
    }

    // Prefer a single SELECT_CC node when the target supports it.
    if (TLI.isOperationLegal(ISD::SELECT_CC, VT) ||
        (!LegalOperations &&
         TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT))) {
      // Any flags available in a select/setcc fold will be on the setcc as they
      // migrated from fcmp
      Flags = N0->getFlags();
      SDValue SelectNode = DAG.getNode(ISD::SELECT_CC, DL, VT, Cond0, Cond1, N1,
                                       N2, N0.getOperand(2));
      SelectNode->setFlags(Flags);
      return SelectNode;
    }

    if (SDValue NewSel = SimplifySelect(DL, N0, N1, N2))
      return NewSel;
  }

  if (!VT.isVector())
    if (SDValue BinOp = foldSelectOfBinops(N))
      return BinOp;

  return SDValue();
}
10514 
10515 // This function assumes all the vselect's arguments are CONCAT_VECTOR
10516 // nodes and that the condition is a BV of ConstantSDNodes (or undefs).
ConvertSelectToConcatVector(SDNode * N,SelectionDAG & DAG)10517 static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
10518   SDLoc DL(N);
10519   SDValue Cond = N->getOperand(0);
10520   SDValue LHS = N->getOperand(1);
10521   SDValue RHS = N->getOperand(2);
10522   EVT VT = N->getValueType(0);
10523   int NumElems = VT.getVectorNumElements();
10524   assert(LHS.getOpcode() == ISD::CONCAT_VECTORS &&
10525          RHS.getOpcode() == ISD::CONCAT_VECTORS &&
10526          Cond.getOpcode() == ISD::BUILD_VECTOR);
10527 
10528   // CONCAT_VECTOR can take an arbitrary number of arguments. We only care about
10529   // binary ones here.
10530   if (LHS->getNumOperands() != 2 || RHS->getNumOperands() != 2)
10531     return SDValue();
10532 
10533   // We're sure we have an even number of elements due to the
10534   // concat_vectors we have as arguments to vselect.
10535   // Skip BV elements until we find one that's not an UNDEF
10536   // After we find an UNDEF element, keep looping until we get to half the
10537   // length of the BV and see if all the non-undef nodes are the same.
10538   ConstantSDNode *BottomHalf = nullptr;
10539   for (int i = 0; i < NumElems / 2; ++i) {
10540     if (Cond->getOperand(i)->isUndef())
10541       continue;
10542 
10543     if (BottomHalf == nullptr)
10544       BottomHalf = cast<ConstantSDNode>(Cond.getOperand(i));
10545     else if (Cond->getOperand(i).getNode() != BottomHalf)
10546       return SDValue();
10547   }
10548 
10549   // Do the same for the second half of the BuildVector
10550   ConstantSDNode *TopHalf = nullptr;
10551   for (int i = NumElems / 2; i < NumElems; ++i) {
10552     if (Cond->getOperand(i)->isUndef())
10553       continue;
10554 
10555     if (TopHalf == nullptr)
10556       TopHalf = cast<ConstantSDNode>(Cond.getOperand(i));
10557     else if (Cond->getOperand(i).getNode() != TopHalf)
10558       return SDValue();
10559   }
10560 
10561   assert(TopHalf && BottomHalf &&
10562          "One half of the selector was all UNDEFs and the other was all the "
10563          "same value. This should have been addressed before this function.");
10564   return DAG.getNode(
10565       ISD::CONCAT_VECTORS, DL, VT,
10566       BottomHalf->isZero() ? RHS->getOperand(0) : LHS->getOperand(0),
10567       TopHalf->isZero() ? RHS->getOperand(1) : LHS->getOperand(1));
10568 }
10569 
refineUniformBase(SDValue & BasePtr,SDValue & Index,bool IndexIsScaled,SelectionDAG & DAG)10570 bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
10571                        SelectionDAG &DAG) {
10572   if (!isNullConstant(BasePtr) || Index.getOpcode() != ISD::ADD)
10573     return false;
10574 
10575   // Only perform the transformation when existing operands can be reused.
10576   if (IndexIsScaled)
10577     return false;
10578 
10579   // For now we check only the LHS of the add.
10580   SDValue LHS = Index.getOperand(0);
10581   SDValue SplatVal = DAG.getSplatValue(LHS);
10582   if (!SplatVal || SplatVal.getValueType() != BasePtr.getValueType())
10583     return false;
10584 
10585   BasePtr = SplatVal;
10586   Index = Index.getOperand(1);
10587   return true;
10588 }
10589 
10590 // Fold sext/zext of index into index type.
refineIndexType(SDValue & Index,ISD::MemIndexType & IndexType,EVT DataVT,SelectionDAG & DAG)10591 bool refineIndexType(SDValue &Index, ISD::MemIndexType &IndexType, EVT DataVT,
10592                      SelectionDAG &DAG) {
10593   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
10594 
10595   // It's always safe to look through zero extends.
10596   if (Index.getOpcode() == ISD::ZERO_EXTEND) {
10597     SDValue Op = Index.getOperand(0);
10598     if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType(), DataVT)) {
10599       IndexType = ISD::UNSIGNED_SCALED;
10600       Index = Op;
10601       return true;
10602     }
10603     if (ISD::isIndexTypeSigned(IndexType)) {
10604       IndexType = ISD::UNSIGNED_SCALED;
10605       return true;
10606     }
10607   }
10608 
10609   // It's only safe to look through sign extends when Index is signed.
10610   if (Index.getOpcode() == ISD::SIGN_EXTEND &&
10611       ISD::isIndexTypeSigned(IndexType)) {
10612     SDValue Op = Index.getOperand(0);
10613     if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType(), DataVT)) {
10614       Index = Op;
10615       return true;
10616     }
10617   }
10618 
10619   return false;
10620 }
10621 
visitMSCATTER(SDNode * N)10622 SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
10623   MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(N);
10624   SDValue Mask = MSC->getMask();
10625   SDValue Chain = MSC->getChain();
10626   SDValue Index = MSC->getIndex();
10627   SDValue Scale = MSC->getScale();
10628   SDValue StoreVal = MSC->getValue();
10629   SDValue BasePtr = MSC->getBasePtr();
10630   ISD::MemIndexType IndexType = MSC->getIndexType();
10631   SDLoc DL(N);
10632 
10633   // Zap scatters with a zero mask.
10634   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
10635     return Chain;
10636 
10637   if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG)) {
10638     SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
10639     return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
10640                                 DL, Ops, MSC->getMemOperand(), IndexType,
10641                                 MSC->isTruncatingStore());
10642   }
10643 
10644   if (refineIndexType(Index, IndexType, StoreVal.getValueType(), DAG)) {
10645     SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
10646     return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
10647                                 DL, Ops, MSC->getMemOperand(), IndexType,
10648                                 MSC->isTruncatingStore());
10649   }
10650 
10651   return SDValue();
10652 }
10653 
/// Combine a masked store: drop zero-mask stores, convert all-ones-mask
/// stores to plain stores, try indexed forms, narrow the stored value via
/// demanded bits, and fold a truncate into a truncating masked store.
SDValue DAGCombiner::visitMSTORE(SDNode *N) {
  MaskedStoreSDNode *MST = cast<MaskedStoreSDNode>(N);
  SDValue Mask = MST->getMask();
  SDValue Chain = MST->getChain();
  SDValue Value = MST->getValue();
  SDValue Ptr = MST->getBasePtr();
  SDLoc DL(N);

  // Zap masked stores with a zero mask.
  if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
    return Chain;

  // If this is a masked store with an all-ones mask, we can use an unmasked
  // store.
  // FIXME: Can we do this for indexed, compressing, or truncating stores?
  if (ISD::isConstantSplatVectorAllOnes(Mask.getNode()) && MST->isUnindexed() &&
      !MST->isCompressingStore() && !MST->isTruncatingStore())
    return DAG.getStore(MST->getChain(), SDLoc(N), MST->getValue(),
                        MST->getBasePtr(), MST->getPointerInfo(),
                        MST->getOriginalAlign(), MachineMemOperand::MOStore,
                        MST->getAAInfo());

  // Try transforming N to an indexed store.
  if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
    return SDValue(N, 0);

  // For a truncating store, only the low memory-VT bits of the value matter;
  // try to simplify the stored value accordingly. Opaque constants are
  // skipped because they must not be reassociated/simplified.
  if (MST->isTruncatingStore() && MST->isUnindexed() &&
      Value.getValueType().isInteger() &&
      (!isa<ConstantSDNode>(Value) ||
       !cast<ConstantSDNode>(Value)->isOpaque())) {
    APInt TruncDemandedBits =
        APInt::getLowBitsSet(Value.getScalarValueSizeInBits(),
                             MST->getMemoryVT().getScalarSizeInBits());

    // See if we can simplify the operation with
    // SimplifyDemandedBits, which only works if the value has a single use.
    if (SimplifyDemandedBits(Value, TruncDemandedBits)) {
      // Re-visit the store if anything changed and the store hasn't been merged
      // with another node (N is deleted) SimplifyDemandedBits will add Value's
      // node back to the worklist if necessary, but we also need to re-visit
      // the Store node itself.
      if (N->getOpcode() != ISD::DELETED_NODE)
        AddToWorklist(N);
      return SDValue(N, 0);
    }
  }

  // If this is a TRUNC followed by a masked store, fold this into a masked
  // truncating store.  We can do this even if this is already a masked
  // truncstore.
  if ((Value.getOpcode() == ISD::TRUNCATE) && Value->hasOneUse() &&
      MST->isUnindexed() &&
      TLI.canCombineTruncStore(Value.getOperand(0).getValueType(),
                               MST->getMemoryVT(), LegalOperations)) {
    // The mask must be widened to match the pre-truncate value type.
    auto Mask = TLI.promoteTargetBoolean(DAG, MST->getMask(),
                                         Value.getOperand(0).getValueType());
    return DAG.getMaskedStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
                              MST->getOffset(), Mask, MST->getMemoryVT(),
                              MST->getMemOperand(), MST->getAddressingMode(),
                              /*IsTruncating=*/true);
  }

  return SDValue();
}
10717 
visitMGATHER(SDNode * N)10718 SDValue DAGCombiner::visitMGATHER(SDNode *N) {
10719   MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(N);
10720   SDValue Mask = MGT->getMask();
10721   SDValue Chain = MGT->getChain();
10722   SDValue Index = MGT->getIndex();
10723   SDValue Scale = MGT->getScale();
10724   SDValue PassThru = MGT->getPassThru();
10725   SDValue BasePtr = MGT->getBasePtr();
10726   ISD::MemIndexType IndexType = MGT->getIndexType();
10727   SDLoc DL(N);
10728 
10729   // Zap gathers with a zero mask.
10730   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
10731     return CombineTo(N, PassThru, MGT->getChain());
10732 
10733   if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG)) {
10734     SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
10735     return DAG.getMaskedGather(
10736         DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
10737         Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
10738   }
10739 
10740   if (refineIndexType(Index, IndexType, N->getValueType(0), DAG)) {
10741     SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
10742     return DAG.getMaskedGather(
10743         DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
10744         Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
10745   }
10746 
10747   return SDValue();
10748 }
10749 
visitMLOAD(SDNode * N)10750 SDValue DAGCombiner::visitMLOAD(SDNode *N) {
10751   MaskedLoadSDNode *MLD = cast<MaskedLoadSDNode>(N);
10752   SDValue Mask = MLD->getMask();
10753   SDLoc DL(N);
10754 
10755   // Zap masked loads with a zero mask.
10756   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
10757     return CombineTo(N, MLD->getPassThru(), MLD->getChain());
10758 
10759   // If this is a masked load with an all ones mask, we can use a unmasked load.
10760   // FIXME: Can we do this for indexed, expanding, or extending loads?
10761   if (ISD::isConstantSplatVectorAllOnes(Mask.getNode()) && MLD->isUnindexed() &&
10762       !MLD->isExpandingLoad() && MLD->getExtensionType() == ISD::NON_EXTLOAD) {
10763     SDValue NewLd = DAG.getLoad(
10764         N->getValueType(0), SDLoc(N), MLD->getChain(), MLD->getBasePtr(),
10765         MLD->getPointerInfo(), MLD->getOriginalAlign(),
10766         MachineMemOperand::MOLoad, MLD->getAAInfo(), MLD->getRanges());
10767     return CombineTo(N, NewLd, NewLd.getValue(1));
10768   }
10769 
10770   // Try transforming N to an indexed load.
10771   if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
10772     return SDValue(N, 0);
10773 
10774   return SDValue();
10775 }
10776 
10777 /// A vector select of 2 constant vectors can be simplified to math/logic to
10778 /// avoid a variable select instruction and possibly avoid constant loads.
foldVSelectOfConstants(SDNode * N)10779 SDValue DAGCombiner::foldVSelectOfConstants(SDNode *N) {
10780   SDValue Cond = N->getOperand(0);
10781   SDValue N1 = N->getOperand(1);
10782   SDValue N2 = N->getOperand(2);
10783   EVT VT = N->getValueType(0);
10784   if (!Cond.hasOneUse() || Cond.getScalarValueSizeInBits() != 1 ||
10785       !TLI.convertSelectOfConstantsToMath(VT) ||
10786       !ISD::isBuildVectorOfConstantSDNodes(N1.getNode()) ||
10787       !ISD::isBuildVectorOfConstantSDNodes(N2.getNode()))
10788     return SDValue();
10789 
10790   // Check if we can use the condition value to increment/decrement a single
10791   // constant value. This simplifies a select to an add and removes a constant
10792   // load/materialization from the general case.
10793   bool AllAddOne = true;
10794   bool AllSubOne = true;
10795   unsigned Elts = VT.getVectorNumElements();
10796   for (unsigned i = 0; i != Elts; ++i) {
10797     SDValue N1Elt = N1.getOperand(i);
10798     SDValue N2Elt = N2.getOperand(i);
10799     if (N1Elt.isUndef() || N2Elt.isUndef())
10800       continue;
10801     if (N1Elt.getValueType() != N2Elt.getValueType())
10802       continue;
10803 
10804     const APInt &C1 = cast<ConstantSDNode>(N1Elt)->getAPIntValue();
10805     const APInt &C2 = cast<ConstantSDNode>(N2Elt)->getAPIntValue();
10806     if (C1 != C2 + 1)
10807       AllAddOne = false;
10808     if (C1 != C2 - 1)
10809       AllSubOne = false;
10810   }
10811 
10812   // Further simplifications for the extra-special cases where the constants are
10813   // all 0 or all -1 should be implemented as folds of these patterns.
10814   SDLoc DL(N);
10815   if (AllAddOne || AllSubOne) {
10816     // vselect <N x i1> Cond, C+1, C --> add (zext Cond), C
10817     // vselect <N x i1> Cond, C-1, C --> add (sext Cond), C
10818     auto ExtendOpcode = AllAddOne ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND;
10819     SDValue ExtendedCond = DAG.getNode(ExtendOpcode, DL, VT, Cond);
10820     return DAG.getNode(ISD::ADD, DL, VT, ExtendedCond, N2);
10821   }
10822 
10823   // select Cond, Pow2C, 0 --> (zext Cond) << log2(Pow2C)
10824   APInt Pow2C;
10825   if (ISD::isConstantSplatVector(N1.getNode(), Pow2C) && Pow2C.isPowerOf2() &&
10826       isNullOrNullSplat(N2)) {
10827     SDValue ZextCond = DAG.getZExtOrTrunc(Cond, DL, VT);
10828     SDValue ShAmtC = DAG.getConstant(Pow2C.exactLogBase2(), DL, VT);
10829     return DAG.getNode(ISD::SHL, DL, VT, ZextCond, ShAmtC);
10830   }
10831 
10832   if (SDValue V = foldSelectOfConstantsUsingSra(N, DAG))
10833     return V;
10834 
10835   // The general case for select-of-constants:
10836   // vselect <N x i1> Cond, C1, C2 --> xor (and (sext Cond), (C1^C2)), C2
10837   // ...but that only makes sense if a vselect is slower than 2 logic ops, so
10838   // leave that to a machine-specific pass.
10839   return SDValue();
10840 }
10841 
/// Combine a VSELECT node (vector select with a vector condition).
/// Tries, in order: generic select simplification, boolean-logic folds,
/// condition-flip canonicalization, integer-abs canonicalization,
/// fmin/fmax matching, compare widening, UADDSAT/USUBSAT matching,
/// constant-condition folds, concat-vector conversion, and demanded-elts
/// simplification. Returns the replacement value or SDValue() if no
/// combine applied.
SDValue DAGCombiner::visitVSELECT(SDNode *N) {
  SDValue N0 = N->getOperand(0); // condition
  SDValue N1 = N->getOperand(1); // true value
  SDValue N2 = N->getOperand(2); // false value
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  if (SDValue V = DAG.simplifySelect(N0, N1, N2))
    return V;

  if (SDValue V = foldBoolSelectToLogic(N, DAG))
    return V;

  // vselect (not Cond), N1, N2 -> vselect Cond, N2, N1
  if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false))
    return DAG.getSelect(DL, VT, F, N2, N1);

  // Canonicalize integer abs.
  // vselect (setg[te] X,  0),  X, -X ->
  // vselect (setgt    X, -1),  X, -X ->
  // vselect (setl[te] X,  0), -X,  X ->
  // Y = sra (X, size(X)-1); xor (add (X, Y), Y)
  if (N0.getOpcode() == ISD::SETCC) {
    SDValue LHS = N0.getOperand(0), RHS = N0.getOperand(1);
    ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
    bool isAbs = false;
    bool RHSIsAllZeros = ISD::isBuildVectorAllZeros(RHS.getNode());

    // Match either arm being "0 - X" with the other arm being X itself.
    if (((RHSIsAllZeros && (CC == ISD::SETGT || CC == ISD::SETGE)) ||
         (ISD::isBuildVectorAllOnes(RHS.getNode()) && CC == ISD::SETGT)) &&
        N1 == LHS && N2.getOpcode() == ISD::SUB && N1 == N2.getOperand(1))
      isAbs = ISD::isBuildVectorAllZeros(N2.getOperand(0).getNode());
    else if ((RHSIsAllZeros && (CC == ISD::SETLT || CC == ISD::SETLE)) &&
             N2 == LHS && N1.getOpcode() == ISD::SUB && N2 == N1.getOperand(1))
      isAbs = ISD::isBuildVectorAllZeros(N1.getOperand(0).getNode());

    if (isAbs) {
      // Prefer a native ABS node when the target supports it; otherwise
      // expand via the sra/add/xor idiom described above.
      if (TLI.isOperationLegalOrCustom(ISD::ABS, VT))
        return DAG.getNode(ISD::ABS, DL, VT, LHS);

      SDValue Shift = DAG.getNode(ISD::SRA, DL, VT, LHS,
                                  DAG.getConstant(VT.getScalarSizeInBits() - 1,
                                                  DL, getShiftAmountTy(VT)));
      SDValue Add = DAG.getNode(ISD::ADD, DL, VT, LHS, Shift);
      AddToWorklist(Shift.getNode());
      AddToWorklist(Add.getNode());
      return DAG.getNode(ISD::XOR, DL, VT, Add, Shift);
    }

    // vselect x, y (fcmp lt x, y) -> fminnum x, y
    // vselect x, y (fcmp gt x, y) -> fmaxnum x, y
    //
    // This is OK if we don't care about what happens if either operand is a
    // NaN.
    //
    if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, LHS, RHS, TLI)) {
      if (SDValue FMinMax =
              combineMinNumMaxNum(DL, VT, LHS, RHS, N1, N2, CC, TLI, DAG))
        return FMinMax;
    }

    if (SDValue S = PerformMinMaxFpToSatCombine(LHS, RHS, N1, N2, CC, DAG))
      return S;
    if (SDValue S = PerformUMinFpToSatCombine(LHS, RHS, N1, N2, CC, DAG))
      return S;

    // If this select has a condition (setcc) with narrower operands than the
    // select, try to widen the compare to match the select width.
    // TODO: This should be extended to handle any constant.
    // TODO: This could be extended to handle non-loading patterns, but that
    //       requires thorough testing to avoid regressions.
    if (isNullOrNullSplat(RHS)) {
      EVT NarrowVT = LHS.getValueType();
      EVT WideVT = N1.getValueType().changeVectorElementTypeToInteger();
      EVT SetCCVT = getSetCCResultType(LHS.getValueType());
      unsigned SetCCWidth = SetCCVT.getScalarSizeInBits();
      unsigned WideWidth = WideVT.getScalarSizeInBits();
      bool IsSigned = isSignedIntSetCC(CC);
      auto LoadExtOpcode = IsSigned ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
      if (LHS.getOpcode() == ISD::LOAD && LHS.hasOneUse() &&
          SetCCWidth != 1 && SetCCWidth < WideWidth &&
          TLI.isLoadExtLegalOrCustom(LoadExtOpcode, WideVT, NarrowVT) &&
          TLI.isOperationLegalOrCustom(ISD::SETCC, WideVT)) {
        // Both compare operands can be widened for free. The LHS can use an
        // extended load, and the RHS is a constant:
        //   vselect (ext (setcc load(X), C)), N1, N2 -->
        //   vselect (setcc extload(X), C'), N1, N2
        auto ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
        SDValue WideLHS = DAG.getNode(ExtOpcode, DL, WideVT, LHS);
        SDValue WideRHS = DAG.getNode(ExtOpcode, DL, WideVT, RHS);
        EVT WideSetCCVT = getSetCCResultType(WideVT);
        SDValue WideSetCC = DAG.getSetCC(DL, WideSetCCVT, WideLHS, WideRHS, CC);
        return DAG.getSelect(DL, N1.getValueType(), WideSetCC, N1, N2);
      }
    }

    // Match VSELECTs into add with unsigned saturation.
    if (hasOperation(ISD::UADDSAT, VT)) {
      // Check if one of the arms of the VSELECT is vector with all bits set.
      // If it's on the left side invert the predicate to simplify logic below.
      SDValue Other;
      ISD::CondCode SatCC = CC;
      if (ISD::isConstantSplatVectorAllOnes(N1.getNode())) {
        Other = N2;
        SatCC = ISD::getSetCCInverse(SatCC, VT.getScalarType());
      } else if (ISD::isConstantSplatVectorAllOnes(N2.getNode())) {
        Other = N1;
      }

      if (Other && Other.getOpcode() == ISD::ADD) {
        SDValue CondLHS = LHS, CondRHS = RHS;
        SDValue OpLHS = Other.getOperand(0), OpRHS = Other.getOperand(1);

        // Canonicalize condition operands.
        if (SatCC == ISD::SETUGE) {
          std::swap(CondLHS, CondRHS);
          SatCC = ISD::SETULE;
        }

        // We can test against either of the addition operands.
        // x <= x+y ? x+y : ~0 --> uaddsat x, y
        // x+y >= x ? x+y : ~0 --> uaddsat x, y
        if (SatCC == ISD::SETULE && Other == CondRHS &&
            (OpLHS == CondLHS || OpRHS == CondLHS))
          return DAG.getNode(ISD::UADDSAT, DL, VT, OpLHS, OpRHS);

        if (OpRHS.getOpcode() == CondRHS.getOpcode() &&
            (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
             OpRHS.getOpcode() == ISD::SPLAT_VECTOR) &&
            CondLHS == OpLHS) {
          // If the RHS is a constant we have to reverse the const
          // canonicalization.
          // x >= ~C ? x+C : ~0 --> uaddsat x, C
          auto MatchUADDSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
            return Cond->getAPIntValue() == ~Op->getAPIntValue();
          };
          if (SatCC == ISD::SETULE &&
              ISD::matchBinaryPredicate(OpRHS, CondRHS, MatchUADDSAT))
            return DAG.getNode(ISD::UADDSAT, DL, VT, OpLHS, OpRHS);
        }
      }
    }

    // Match VSELECTs into sub with unsigned saturation.
    if (hasOperation(ISD::USUBSAT, VT)) {
      // Check if one of the arms of the VSELECT is a zero vector. If it's on
      // the left side invert the predicate to simplify logic below.
      SDValue Other;
      ISD::CondCode SatCC = CC;
      if (ISD::isConstantSplatVectorAllZeros(N1.getNode())) {
        Other = N2;
        SatCC = ISD::getSetCCInverse(SatCC, VT.getScalarType());
      } else if (ISD::isConstantSplatVectorAllZeros(N2.getNode())) {
        Other = N1;
      }

      // zext(x) >= y ? trunc(zext(x) - y) : 0
      // --> usubsat(trunc(zext(x)),trunc(umin(y,SatLimit)))
      // zext(x) >  y ? trunc(zext(x) - y) : 0
      // --> usubsat(trunc(zext(x)),trunc(umin(y,SatLimit)))
      if (Other && Other.getOpcode() == ISD::TRUNCATE &&
          Other.getOperand(0).getOpcode() == ISD::SUB &&
          (SatCC == ISD::SETUGE || SatCC == ISD::SETUGT)) {
        SDValue OpLHS = Other.getOperand(0).getOperand(0);
        SDValue OpRHS = Other.getOperand(0).getOperand(1);
        if (LHS == OpLHS && RHS == OpRHS && LHS.getOpcode() == ISD::ZERO_EXTEND)
          if (SDValue R = getTruncatedUSUBSAT(VT, LHS.getValueType(), LHS, RHS,
                                              DAG, DL))
            return R;
      }

      // The generic matching below handles any binary "Other" arm; the SUB,
      // ADD and XOR forms are distinguished by the checks inside.
      if (Other && Other.getNumOperands() == 2) {
        SDValue CondRHS = RHS;
        SDValue OpLHS = Other.getOperand(0), OpRHS = Other.getOperand(1);

        if (OpLHS == LHS) {
          // Look for a general sub with unsigned saturation first.
          // x >= y ? x-y : 0 --> usubsat x, y
          // x >  y ? x-y : 0 --> usubsat x, y
          if ((SatCC == ISD::SETUGE || SatCC == ISD::SETUGT) &&
              Other.getOpcode() == ISD::SUB && OpRHS == CondRHS)
            return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);

          if (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
              OpRHS.getOpcode() == ISD::SPLAT_VECTOR) {
            if (CondRHS.getOpcode() == ISD::BUILD_VECTOR ||
                CondRHS.getOpcode() == ISD::SPLAT_VECTOR) {
              // If the RHS is a constant we have to reverse the const
              // canonicalization.
              // x > C-1 ? x+-C : 0 --> usubsat x, C
              auto MatchUSUBSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
                return (!Op && !Cond) ||
                       (Op && Cond &&
                        Cond->getAPIntValue() == (-Op->getAPIntValue() - 1));
              };
              if (SatCC == ISD::SETUGT && Other.getOpcode() == ISD::ADD &&
                  ISD::matchBinaryPredicate(OpRHS, CondRHS, MatchUSUBSAT,
                                            /*AllowUndefs*/ true)) {
                OpRHS = DAG.getNode(ISD::SUB, DL, VT,
                                    DAG.getConstant(0, DL, VT), OpRHS);
                return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
              }

              // Another special case: If C was a sign bit, the sub has been
              // canonicalized into a xor.
              // FIXME: Would it be better to use computeKnownBits to
              // determine whether it's safe to decanonicalize the xor?
              // x s< 0 ? x^C : 0 --> usubsat x, C
              APInt SplatValue;
              if (SatCC == ISD::SETLT && Other.getOpcode() == ISD::XOR &&
                  ISD::isConstantSplatVector(OpRHS.getNode(), SplatValue) &&
                  ISD::isConstantSplatVectorAllZeros(CondRHS.getNode()) &&
                  SplatValue.isSignMask()) {
                // Note that we have to rebuild the RHS constant here to
                // ensure we don't rely on particular values of undef lanes.
                OpRHS = DAG.getConstant(SplatValue, DL, VT);
                return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
              }
            }
          }
        }
      }
    }
  }

  if (SimplifySelectOps(N, N1, N2))
    return SDValue(N, 0);  // Don't revisit N.

  // Fold (vselect all_ones, N1, N2) -> N1
  if (ISD::isConstantSplatVectorAllOnes(N0.getNode()))
    return N1;
  // Fold (vselect all_zeros, N1, N2) -> N2
  if (ISD::isConstantSplatVectorAllZeros(N0.getNode()))
    return N2;

  // The ConvertSelectToConcatVector function is assuming both the above
  // checks for (vselect (build_vector all_{ones,zeros}) ...) have been made
  // and addressed.
  if (N1.getOpcode() == ISD::CONCAT_VECTORS &&
      N2.getOpcode() == ISD::CONCAT_VECTORS &&
      ISD::isBuildVectorOfConstantSDNodes(N0.getNode())) {
    if (SDValue CV = ConvertSelectToConcatVector(N, DAG))
      return CV;
  }

  if (SDValue V = foldVSelectOfConstants(N))
    return V;

  if (hasOperation(ISD::SRA, VT))
    if (SDValue V = foldVSelectToSignBitSplatMask(N, DAG))
      return V;

  // Returning SDValue(N, 0) signals that N itself was updated in place.
  if (SimplifyDemandedVectorElts(SDValue(N, 0)))
    return SDValue(N, 0);

  return SDValue();
}
11099 
visitSELECT_CC(SDNode * N)11100 SDValue DAGCombiner::visitSELECT_CC(SDNode *N) {
11101   SDValue N0 = N->getOperand(0);
11102   SDValue N1 = N->getOperand(1);
11103   SDValue N2 = N->getOperand(2);
11104   SDValue N3 = N->getOperand(3);
11105   SDValue N4 = N->getOperand(4);
11106   ISD::CondCode CC = cast<CondCodeSDNode>(N4)->get();
11107 
11108   // fold select_cc lhs, rhs, x, x, cc -> x
11109   if (N2 == N3)
11110     return N2;
11111 
11112   // Determine if the condition we're dealing with is constant
11113   if (SDValue SCC = SimplifySetCC(getSetCCResultType(N0.getValueType()), N0, N1,
11114                                   CC, SDLoc(N), false)) {
11115     AddToWorklist(SCC.getNode());
11116 
11117     // cond always true -> true val
11118     // cond always false -> false val
11119     if (auto *SCCC = dyn_cast<ConstantSDNode>(SCC.getNode()))
11120       return SCCC->isZero() ? N3 : N2;
11121 
11122     // When the condition is UNDEF, just return the first operand. This is
11123     // coherent the DAG creation, no setcc node is created in this case
11124     if (SCC->isUndef())
11125       return N2;
11126 
11127     // Fold to a simpler select_cc
11128     if (SCC.getOpcode() == ISD::SETCC) {
11129       SDValue SelectOp = DAG.getNode(
11130           ISD::SELECT_CC, SDLoc(N), N2.getValueType(), SCC.getOperand(0),
11131           SCC.getOperand(1), N2, N3, SCC.getOperand(2));
11132       SelectOp->setFlags(SCC->getFlags());
11133       return SelectOp;
11134     }
11135   }
11136 
11137   // If we can fold this based on the true/false value, do so.
11138   if (SimplifySelectOps(N, N2, N3))
11139     return SDValue(N, 0);  // Don't revisit N.
11140 
11141   // fold select_cc into other things, such as min/max/abs
11142   return SimplifySelectCC(SDLoc(N), N0, N1, N2, N3, CC);
11143 }
11144 
visitSETCC(SDNode * N)11145 SDValue DAGCombiner::visitSETCC(SDNode *N) {
11146   // setcc is very commonly used as an argument to brcond. This pattern
11147   // also lend itself to numerous combines and, as a result, it is desired
11148   // we keep the argument to a brcond as a setcc as much as possible.
11149   bool PreferSetCC =
11150       N->hasOneUse() && N->use_begin()->getOpcode() == ISD::BRCOND;
11151 
11152   ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get();
11153   EVT VT = N->getValueType(0);
11154 
11155   //   SETCC(FREEZE(X), CONST, Cond)
11156   // =>
11157   //   FREEZE(SETCC(X, CONST, Cond))
11158   // This is correct if FREEZE(X) has one use and SETCC(FREEZE(X), CONST, Cond)
11159   // isn't equivalent to true or false.
11160   // For example, SETCC(FREEZE(X), -128, SETULT) cannot be folded to
11161   // FREEZE(SETCC(X, -128, SETULT)) because X can be poison.
11162   //
11163   // This transformation is beneficial because visitBRCOND can fold
11164   // BRCOND(FREEZE(X)) to BRCOND(X).
11165 
11166   // Conservatively optimize integer comparisons only.
11167   if (PreferSetCC) {
11168     // Do this only when SETCC is going to be used by BRCOND.
11169 
11170     SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
11171     ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
11172     ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
11173     bool Updated = false;
11174 
11175     // Is 'X Cond C' always true or false?
11176     auto IsAlwaysTrueOrFalse = [](ISD::CondCode Cond, ConstantSDNode *C) {
11177       bool False = (Cond == ISD::SETULT && C->isZero()) ||
11178                    (Cond == ISD::SETLT  && C->isMinSignedValue()) ||
11179                    (Cond == ISD::SETUGT && C->isAllOnes()) ||
11180                    (Cond == ISD::SETGT  && C->isMaxSignedValue());
11181       bool True =  (Cond == ISD::SETULE && C->isAllOnes()) ||
11182                    (Cond == ISD::SETLE  && C->isMaxSignedValue()) ||
11183                    (Cond == ISD::SETUGE && C->isZero()) ||
11184                    (Cond == ISD::SETGE  && C->isMinSignedValue());
11185       return True || False;
11186     };
11187 
11188     if (N0->getOpcode() == ISD::FREEZE && N0.hasOneUse() && N1C) {
11189       if (!IsAlwaysTrueOrFalse(Cond, N1C)) {
11190         N0 = N0->getOperand(0);
11191         Updated = true;
11192       }
11193     }
11194     if (N1->getOpcode() == ISD::FREEZE && N1.hasOneUse() && N0C) {
11195       if (!IsAlwaysTrueOrFalse(ISD::getSetCCSwappedOperands(Cond),
11196                                N0C)) {
11197         N1 = N1->getOperand(0);
11198         Updated = true;
11199       }
11200     }
11201 
11202     if (Updated)
11203       return DAG.getFreeze(DAG.getSetCC(SDLoc(N), VT, N0, N1, Cond));
11204   }
11205 
11206   SDValue Combined = SimplifySetCC(VT, N->getOperand(0), N->getOperand(1), Cond,
11207                                    SDLoc(N), !PreferSetCC);
11208 
11209   if (!Combined)
11210     return SDValue();
11211 
11212   // If we prefer to have a setcc, and we don't, we'll try our best to
11213   // recreate one using rebuildSetCC.
11214   if (PreferSetCC && Combined.getOpcode() != ISD::SETCC) {
11215     SDValue NewSetCC = rebuildSetCC(Combined);
11216 
11217     // We don't have anything interesting to combine to.
11218     if (NewSetCC.getNode() == N)
11219       return SDValue();
11220 
11221     if (NewSetCC)
11222       return NewSetCC;
11223   }
11224 
11225   return Combined;
11226 }
11227 
visitSETCCCARRY(SDNode * N)11228 SDValue DAGCombiner::visitSETCCCARRY(SDNode *N) {
11229   SDValue LHS = N->getOperand(0);
11230   SDValue RHS = N->getOperand(1);
11231   SDValue Carry = N->getOperand(2);
11232   SDValue Cond = N->getOperand(3);
11233 
11234   // If Carry is false, fold to a regular SETCC.
11235   if (isNullConstant(Carry))
11236     return DAG.getNode(ISD::SETCC, SDLoc(N), N->getVTList(), LHS, RHS, Cond);
11237 
11238   return SDValue();
11239 }
11240 
11241 /// Check if N satisfies:
11242 ///   N is used once.
11243 ///   N is a Load.
11244 ///   The load is compatible with ExtOpcode. It means
11245 ///     If load has explicit zero/sign extension, ExpOpcode must have the same
11246 ///     extension.
11247 ///     Otherwise returns true.
isCompatibleLoad(SDValue N,unsigned ExtOpcode)11248 static bool isCompatibleLoad(SDValue N, unsigned ExtOpcode) {
11249   if (!N.hasOneUse())
11250     return false;
11251 
11252   if (!isa<LoadSDNode>(N))
11253     return false;
11254 
11255   LoadSDNode *Load = cast<LoadSDNode>(N);
11256   ISD::LoadExtType LoadExt = Load->getExtensionType();
11257   if (LoadExt == ISD::NON_EXTLOAD || LoadExt == ISD::EXTLOAD)
11258     return true;
11259 
11260   // Now LoadExt is either SEXTLOAD or ZEXTLOAD, ExtOpcode must have the same
11261   // extension.
11262   if ((LoadExt == ISD::SEXTLOAD && ExtOpcode != ISD::SIGN_EXTEND) ||
11263       (LoadExt == ISD::ZEXTLOAD && ExtOpcode != ISD::ZERO_EXTEND))
11264     return false;
11265 
11266   return true;
11267 }
11268 
11269 /// Fold
11270 ///   (sext (select c, load x, load y)) -> (select c, sextload x, sextload y)
11271 ///   (zext (select c, load x, load y)) -> (select c, zextload x, zextload y)
11272 ///   (aext (select c, load x, load y)) -> (select c, extload x, extload y)
11273 /// This function is called by the DAGCombiner when visiting sext/zext/aext
11274 /// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
tryToFoldExtendSelectLoad(SDNode * N,const TargetLowering & TLI,SelectionDAG & DAG)11275 static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI,
11276                                          SelectionDAG &DAG) {
11277   unsigned Opcode = N->getOpcode();
11278   SDValue N0 = N->getOperand(0);
11279   EVT VT = N->getValueType(0);
11280   SDLoc DL(N);
11281 
11282   assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
11283           Opcode == ISD::ANY_EXTEND) &&
11284          "Expected EXTEND dag node in input!");
11285 
11286   if (!(N0->getOpcode() == ISD::SELECT || N0->getOpcode() == ISD::VSELECT) ||
11287       !N0.hasOneUse())
11288     return SDValue();
11289 
11290   SDValue Op1 = N0->getOperand(1);
11291   SDValue Op2 = N0->getOperand(2);
11292   if (!isCompatibleLoad(Op1, Opcode) || !isCompatibleLoad(Op2, Opcode))
11293     return SDValue();
11294 
11295   auto ExtLoadOpcode = ISD::EXTLOAD;
11296   if (Opcode == ISD::SIGN_EXTEND)
11297     ExtLoadOpcode = ISD::SEXTLOAD;
11298   else if (Opcode == ISD::ZERO_EXTEND)
11299     ExtLoadOpcode = ISD::ZEXTLOAD;
11300 
11301   LoadSDNode *Load1 = cast<LoadSDNode>(Op1);
11302   LoadSDNode *Load2 = cast<LoadSDNode>(Op2);
11303   if (!TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load1->getMemoryVT()) ||
11304       !TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load2->getMemoryVT()))
11305     return SDValue();
11306 
11307   SDValue Ext1 = DAG.getNode(Opcode, DL, VT, Op1);
11308   SDValue Ext2 = DAG.getNode(Opcode, DL, VT, Op2);
11309   return DAG.getSelect(DL, VT, N0->getOperand(0), Ext1, Ext2);
11310 }
11311 
11312 /// Try to fold a sext/zext/aext dag node into a ConstantSDNode or
11313 /// a build_vector of constants.
11314 /// This function is called by the DAGCombiner when visiting sext/zext/aext
11315 /// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
11316 /// Vector extends are not folded if operations are legal; this is to
11317 /// avoid introducing illegal build_vector dag nodes.
static SDValue tryToFoldExtendOfConstant(SDNode *N, const TargetLowering &TLI,
                                         SelectionDAG &DAG, bool LegalTypes) {
  unsigned Opcode = N->getOpcode();
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
         Opcode == ISD::ANY_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG ||
         Opcode == ISD::ZERO_EXTEND_VECTOR_INREG)
         && "Expected EXTEND dag node in input!");

  // fold (sext c1) -> c1
  // fold (zext c1) -> c1
  // fold (aext c1) -> c1
  // getNode constant-folds an extend of a ConstantSDNode.
  if (isa<ConstantSDNode>(N0))
    return DAG.getNode(Opcode, DL, VT, N0);

  // fold (sext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
  // fold (zext (select cond, c1, c2)) -> (select cond, zext c1, zext c2)
  // fold (aext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
  if (N0->getOpcode() == ISD::SELECT) {
    SDValue Op1 = N0->getOperand(1);
    SDValue Op2 = N0->getOperand(2);
    // Skip zext when the target can do it for free; presumably the free zext
    // of the select is preferable then — TODO confirm against targets.
    if (isa<ConstantSDNode>(Op1) && isa<ConstantSDNode>(Op2) &&
        (Opcode != ISD::ZERO_EXTEND || !TLI.isZExtFree(N0.getValueType(), VT))) {
      // For any_extend, choose sign extension of the constants to allow a
      // possible further transform to sign_extend_inreg.i.e.
      //
      // t1: i8 = select t0, Constant:i8<-1>, Constant:i8<0>
      // t2: i64 = any_extend t1
      // -->
      // t3: i64 = select t0, Constant:i64<-1>, Constant:i64<0>
      // -->
      // t4: i64 = sign_extend_inreg t3
      unsigned FoldOpc = Opcode;
      if (FoldOpc == ISD::ANY_EXTEND)
        FoldOpc = ISD::SIGN_EXTEND;
      return DAG.getSelect(DL, VT, N0->getOperand(0),
                           DAG.getNode(FoldOpc, DL, VT, Op1),
                           DAG.getNode(FoldOpc, DL, VT, Op2));
    }
  }

  // fold (sext (build_vector AllConstants) -> (build_vector AllConstants)
  // fold (zext (build_vector AllConstants) -> (build_vector AllConstants)
  // fold (aext (build_vector AllConstants) -> (build_vector AllConstants)
  // Only fold vectors of constants when the scalar type is (or can be) legal,
  // to avoid creating illegal build_vector nodes.
  EVT SVT = VT.getScalarType();
  if (!(VT.isVector() && (!LegalTypes || TLI.isTypeLegal(SVT)) &&
      ISD::isBuildVectorOfConstantSDNodes(N0.getNode())))
    return SDValue();

  // We can fold this node into a build_vector.
  unsigned VTBits = SVT.getSizeInBits();
  unsigned EVTBits = N0->getValueType(0).getScalarSizeInBits();
  SmallVector<SDValue, 8> Elts;
  unsigned NumElts = VT.getVectorNumElements();

  // For zero-extensions, UNDEF elements still guarantee to have the upper
  // bits set to zero.
  bool IsZext =
      Opcode == ISD::ZERO_EXTEND || Opcode == ISD::ZERO_EXTEND_VECTOR_INREG;

  for (unsigned i = 0; i != NumElts; ++i) {
    SDValue Op = N0.getOperand(i);
    if (Op.isUndef()) {
      Elts.push_back(IsZext ? DAG.getConstant(0, DL, SVT) : DAG.getUNDEF(SVT));
      continue;
    }

    // NOTE: this SDLoc shadows the outer DL, so the constants built below
    // carry this element's debug location, not the extend node's.
    SDLoc DL(Op);
    // Get the constant value and if needed trunc it to the size of the type.
    // Nodes like build_vector might have constants wider than the scalar type.
    APInt C = cast<ConstantSDNode>(Op)->getAPIntValue().zextOrTrunc(EVTBits);
    if (Opcode == ISD::SIGN_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG)
      Elts.push_back(DAG.getConstant(C.sext(VTBits), DL, SVT));
    else
      Elts.push_back(DAG.getConstant(C.zext(VTBits), DL, SVT));
  }

  return DAG.getBuildVector(VT, DL, Elts);
}
11400 
11401 // ExtendUsesToFormExtLoad - Trying to extend uses of a load to enable this:
11402 // "fold ({s|z|a}ext (load x)) -> ({s|z|a}ext (truncate ({s|z|a}extload x)))"
11403 // transformation. Returns true if extension are possible and the above
11404 // mentioned transformation is profitable.
ExtendUsesToFormExtLoad(EVT VT,SDNode * N,SDValue N0,unsigned ExtOpc,SmallVectorImpl<SDNode * > & ExtendNodes,const TargetLowering & TLI)11405 static bool ExtendUsesToFormExtLoad(EVT VT, SDNode *N, SDValue N0,
11406                                     unsigned ExtOpc,
11407                                     SmallVectorImpl<SDNode *> &ExtendNodes,
11408                                     const TargetLowering &TLI) {
11409   bool HasCopyToRegUses = false;
11410   bool isTruncFree = TLI.isTruncateFree(VT, N0.getValueType());
11411   for (SDNode::use_iterator UI = N0->use_begin(), UE = N0->use_end(); UI != UE;
11412        ++UI) {
11413     SDNode *User = *UI;
11414     if (User == N)
11415       continue;
11416     if (UI.getUse().getResNo() != N0.getResNo())
11417       continue;
11418     // FIXME: Only extend SETCC N, N and SETCC N, c for now.
11419     if (ExtOpc != ISD::ANY_EXTEND && User->getOpcode() == ISD::SETCC) {
11420       ISD::CondCode CC = cast<CondCodeSDNode>(User->getOperand(2))->get();
11421       if (ExtOpc == ISD::ZERO_EXTEND && ISD::isSignedIntSetCC(CC))
11422         // Sign bits will be lost after a zext.
11423         return false;
11424       bool Add = false;
11425       for (unsigned i = 0; i != 2; ++i) {
11426         SDValue UseOp = User->getOperand(i);
11427         if (UseOp == N0)
11428           continue;
11429         if (!isa<ConstantSDNode>(UseOp))
11430           return false;
11431         Add = true;
11432       }
11433       if (Add)
11434         ExtendNodes.push_back(User);
11435       continue;
11436     }
11437     // If truncates aren't free and there are users we can't
11438     // extend, it isn't worthwhile.
11439     if (!isTruncFree)
11440       return false;
11441     // Remember if this value is live-out.
11442     if (User->getOpcode() == ISD::CopyToReg)
11443       HasCopyToRegUses = true;
11444   }
11445 
11446   if (HasCopyToRegUses) {
11447     bool BothLiveOut = false;
11448     for (SDNode::use_iterator UI = N->use_begin(), UE = N->use_end();
11449          UI != UE; ++UI) {
11450       SDUse &Use = UI.getUse();
11451       if (Use.getResNo() == 0 && Use.getUser()->getOpcode() == ISD::CopyToReg) {
11452         BothLiveOut = true;
11453         break;
11454       }
11455     }
11456     if (BothLiveOut)
11457       // Both unextended and extended values are live out. There had better be
11458       // a good reason for the transformation.
11459       return ExtendNodes.size();
11460   }
11461   return true;
11462 }
11463 
ExtendSetCCUses(const SmallVectorImpl<SDNode * > & SetCCs,SDValue OrigLoad,SDValue ExtLoad,ISD::NodeType ExtType)11464 void DAGCombiner::ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
11465                                   SDValue OrigLoad, SDValue ExtLoad,
11466                                   ISD::NodeType ExtType) {
11467   // Extend SetCC uses if necessary.
11468   SDLoc DL(ExtLoad);
11469   for (SDNode *SetCC : SetCCs) {
11470     SmallVector<SDValue, 4> Ops;
11471 
11472     for (unsigned j = 0; j != 2; ++j) {
11473       SDValue SOp = SetCC->getOperand(j);
11474       if (SOp == OrigLoad)
11475         Ops.push_back(ExtLoad);
11476       else
11477         Ops.push_back(DAG.getNode(ExtType, DL, ExtLoad->getValueType(0), SOp));
11478     }
11479 
11480     Ops.push_back(SetCC->getOperand(2));
11481     CombineTo(SetCC, DAG.getNode(ISD::SETCC, DL, SetCC->getValueType(0), Ops));
11482   }
11483 }
11484 
11485 // FIXME: Bring more similar combines here, common to sext/zext (maybe aext?).
/// Split an extend-of-load over an illegal (but splittable) vector type into
/// a concat of smaller, legal extending loads, and rewrite uses of the
/// original load accordingly. Returns SDValue(N, 0) on success so N is not
/// revisited, or SDValue() when the combine does not apply.
SDValue DAGCombiner::CombineExtLoad(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT DstVT = N->getValueType(0);
  EVT SrcVT = N0.getValueType();

  assert((N->getOpcode() == ISD::SIGN_EXTEND ||
          N->getOpcode() == ISD::ZERO_EXTEND) &&
         "Unexpected node type (not an extend)!");

  // fold (sext (load x)) to multiple smaller sextloads; same for zext.
  // For example, on a target with legal v4i32, but illegal v8i32, turn:
  //   (v8i32 (sext (v8i16 (load x))))
  // into:
  //   (v8i32 (concat_vectors (v4i32 (sextload x)),
  //                          (v4i32 (sextload (x + 16)))))
  // Where uses of the original load, i.e.:
  //   (v8i16 (load x))
  // are replaced with:
  //   (v8i16 (truncate
  //     (v8i32 (concat_vectors (v4i32 (sextload x)),
  //                            (v4i32 (sextload (x + 16)))))))
  //
  // This combine is only applicable to illegal, but splittable, vectors.
  // All legal types, and illegal non-vector types, are handled elsewhere.
  // This combine is controlled by TargetLowering::isVectorLoadExtDesirable.
  //
  if (N0->getOpcode() != ISD::LOAD)
    return SDValue();

  LoadSDNode *LN0 = cast<LoadSDNode>(N0);

  // Only plain, unindexed, single-use, non-atomic/non-volatile loads of
  // power-of-2 vector types qualify, and only if the target wants it.
  if (!ISD::isNON_EXTLoad(LN0) || !ISD::isUNINDEXEDLoad(LN0) ||
      !N0.hasOneUse() || !LN0->isSimple() ||
      !DstVT.isVector() || !DstVT.isPow2VectorType() ||
      !TLI.isVectorLoadExtDesirable(SDValue(N, 0)))
    return SDValue();

  // Bail out unless the other uses of the load can be extended too; setcc
  // users that must be widened are collected in SetCCs.
  SmallVector<SDNode *, 4> SetCCs;
  if (!ExtendUsesToFormExtLoad(DstVT, N, N0, N->getOpcode(), SetCCs, TLI))
    return SDValue();

  ISD::LoadExtType ExtType =
      N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SEXTLOAD : ISD::ZEXTLOAD;

  // Try to split the vector types to get down to legal types.
  EVT SplitSrcVT = SrcVT;
  EVT SplitDstVT = DstVT;
  while (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT) &&
         SplitSrcVT.getVectorNumElements() > 1) {
    SplitDstVT = DAG.GetSplitDestVTs(SplitDstVT).first;
    SplitSrcVT = DAG.GetSplitDestVTs(SplitSrcVT).first;
  }

  if (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT))
    return SDValue();

  assert(!DstVT.isScalableVector() && "Unexpected scalable vector type");

  SDLoc DL(N);
  const unsigned NumSplits =
      DstVT.getVectorNumElements() / SplitDstVT.getVectorNumElements();
  const unsigned Stride = SplitSrcVT.getStoreSize();
  SmallVector<SDValue, 4> Loads;
  SmallVector<SDValue, 4> Chains;

  // Emit one extending load per split, stepping the pointer by Stride bytes.
  SDValue BasePtr = LN0->getBasePtr();
  for (unsigned Idx = 0; Idx < NumSplits; Idx++) {
    const unsigned Offset = Idx * Stride;
    const Align Align = commonAlignment(LN0->getAlign(), Offset);

    SDValue SplitLoad = DAG.getExtLoad(
        ExtType, SDLoc(LN0), SplitDstVT, LN0->getChain(), BasePtr,
        LN0->getPointerInfo().getWithOffset(Offset), SplitSrcVT, Align,
        LN0->getMemOperand()->getFlags(), LN0->getAAInfo());

    BasePtr = DAG.getMemBasePlusOffset(BasePtr, TypeSize::Fixed(Stride), DL);

    Loads.push_back(SplitLoad.getValue(0));
    Chains.push_back(SplitLoad.getValue(1));
  }

  // Merge the split chains and concatenate the split values.
  SDValue NewChain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Chains);
  SDValue NewValue = DAG.getNode(ISD::CONCAT_VECTORS, DL, DstVT, Loads);

  // Simplify TF.
  AddToWorklist(NewChain.getNode());

  CombineTo(N, NewValue);

  // Replace uses of the original load (before extension)
  // with a truncate of the concatenated sextloaded vectors.
  SDValue Trunc =
      DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), NewValue);
  ExtendSetCCUses(SetCCs, N0, NewValue, (ISD::NodeType)N->getOpcode());
  CombineTo(N0.getNode(), Trunc, NewChain);
  return SDValue(N, 0); // Return N so it doesn't get rechecked!
}
11583 
11584 // fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
11585 //      (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
SDValue DAGCombiner::CombineZExtLogicopShiftLoad(SDNode *N) {
  assert(N->getOpcode() == ISD::ZERO_EXTEND);
  EVT VT = N->getValueType(0);
  EVT OrigVT = N->getOperand(0).getValueType();
  // If the zext is free for the target anyway, there is nothing to gain by
  // restructuring the expression around a zextload.
  if (TLI.isZExtFree(OrigVT, VT))
    return SDValue();

  // and/or/xor
  // Match the logic op with a constant RHS; after operation legalization the
  // logic op must also be legal in the wide result type.
  SDValue N0 = N->getOperand(0);
  if (!(N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR ||
        N0.getOpcode() == ISD::XOR) ||
      N0.getOperand(1).getOpcode() != ISD::Constant ||
      (LegalOperations && !TLI.isOperationLegal(N0.getOpcode(), VT)))
    return SDValue();

  // shl/shr
  // Same requirements for the shift feeding the logic op.
  SDValue N1 = N0->getOperand(0);
  if (!(N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) ||
      N1.getOperand(1).getOpcode() != ISD::Constant ||
      (LegalOperations && !TLI.isOperationLegal(N1.getOpcode(), VT)))
    return SDValue();

  // load
  // The shifted value must be an unindexed load that can become a zextload.
  // An existing sextload is rejected because re-typing it would change the
  // value of the extended bits.
  if (!isa<LoadSDNode>(N1.getOperand(0)))
    return SDValue();
  LoadSDNode *Load = cast<LoadSDNode>(N1.getOperand(0));
  EVT MemVT = Load->getMemoryVT();
  if (!TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) ||
      Load->getExtensionType() == ISD::SEXTLOAD || Load->isIndexed())
    return SDValue();


  // If the shift op is SHL, the logic op must be AND, otherwise the result
  // will be wrong.
  if (N1.getOpcode() == ISD::SHL && N0.getOpcode() != ISD::AND)
    return SDValue();

  if (!N0.hasOneUse() || !N1.hasOneUse())
    return SDValue();

  // Any extra (setcc) users of the load must be rewritable in terms of the
  // extended value; collect them so they can be updated below.
  SmallVector<SDNode*, 4> SetCCs;
  if (!ExtendUsesToFormExtLoad(VT, N1.getNode(), N1.getOperand(0),
                               ISD::ZERO_EXTEND, SetCCs, TLI))
    return SDValue();

  // Actually do the transformation.
  SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(Load), VT,
                                   Load->getChain(), Load->getBasePtr(),
                                   Load->getMemoryVT(), Load->getMemOperand());

  // Rebuild the shift in the wide type with the original shift amount.
  SDLoc DL1(N1);
  SDValue Shift = DAG.getNode(N1.getOpcode(), DL1, VT, ExtLoad,
                              N1.getOperand(1));

  // Zero-extend the logic-op constant to the wide type as well.
  APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
  SDLoc DL0(N0);
  SDValue And = DAG.getNode(N0.getOpcode(), DL0, VT, Shift,
                            DAG.getConstant(Mask, DL0, VT));

  ExtendSetCCUses(SetCCs, N1.getOperand(0), ExtLoad, ISD::ZERO_EXTEND);
  CombineTo(N, And);
  if (SDValue(Load, 0).hasOneUse()) {
    // The load value feeds only this pattern: just forward the new chain.
    DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), ExtLoad.getValue(1));
  } else {
    // Other users still want the narrow value: give them a truncate of the
    // extended load, plus the new chain.
    SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(Load),
                                Load->getValueType(0), ExtLoad);
    CombineTo(Load, Trunc, ExtLoad.getValue(1));
  }

  // N0 is dead at this point.
  recursivelyDeleteUnusedNodes(N0.getNode());

  return SDValue(N,0); // Return N so it doesn't get rechecked!
}
11660 
11661 /// If we're narrowing or widening the result of a vector select and the final
11662 /// size is the same size as a setcc (compare) feeding the select, then try to
11663 /// apply the cast operation to the select's operands because matching vector
11664 /// sizes for a select condition and other operands should be more efficient.
matchVSelectOpSizesWithSetCC(SDNode * Cast)11665 SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
11666   unsigned CastOpcode = Cast->getOpcode();
11667   assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
11668           CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
11669           CastOpcode == ISD::FP_ROUND) &&
11670          "Unexpected opcode for vector select narrowing/widening");
11671 
11672   // We only do this transform before legal ops because the pattern may be
11673   // obfuscated by target-specific operations after legalization. Do not create
11674   // an illegal select op, however, because that may be difficult to lower.
11675   EVT VT = Cast->getValueType(0);
11676   if (LegalOperations || !TLI.isOperationLegalOrCustom(ISD::VSELECT, VT))
11677     return SDValue();
11678 
11679   SDValue VSel = Cast->getOperand(0);
11680   if (VSel.getOpcode() != ISD::VSELECT || !VSel.hasOneUse() ||
11681       VSel.getOperand(0).getOpcode() != ISD::SETCC)
11682     return SDValue();
11683 
11684   // Does the setcc have the same vector size as the casted select?
11685   SDValue SetCC = VSel.getOperand(0);
11686   EVT SetCCVT = getSetCCResultType(SetCC.getOperand(0).getValueType());
11687   if (SetCCVT.getSizeInBits() != VT.getSizeInBits())
11688     return SDValue();
11689 
11690   // cast (vsel (setcc X), A, B) --> vsel (setcc X), (cast A), (cast B)
11691   SDValue A = VSel.getOperand(1);
11692   SDValue B = VSel.getOperand(2);
11693   SDValue CastA, CastB;
11694   SDLoc DL(Cast);
11695   if (CastOpcode == ISD::FP_ROUND) {
11696     // FP_ROUND (fptrunc) has an extra flag operand to pass along.
11697     CastA = DAG.getNode(CastOpcode, DL, VT, A, Cast->getOperand(1));
11698     CastB = DAG.getNode(CastOpcode, DL, VT, B, Cast->getOperand(1));
11699   } else {
11700     CastA = DAG.getNode(CastOpcode, DL, VT, A);
11701     CastB = DAG.getNode(CastOpcode, DL, VT, B);
11702   }
11703   return DAG.getNode(ISD::VSELECT, DL, VT, SetCC, CastA, CastB);
11704 }
11705 
11706 // fold ([s|z]ext ([s|z]extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
11707 // fold ([s|z]ext (     extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
static SDValue tryToFoldExtOfExtload(SelectionDAG &DAG, DAGCombiner &Combiner,
                                     const TargetLowering &TLI, EVT VT,
                                     bool LegalOperations, SDNode *N,
                                     SDValue N0, ISD::LoadExtType ExtLoadType) {
  SDNode *N0Node = N0.getNode();
  // The operand must already be an extending load of the matching signedness,
  // or an any-extend load whose high bits we are free to define.
  bool isAExtLoad = (ExtLoadType == ISD::SEXTLOAD) ? ISD::isSEXTLoad(N0Node)
                                                   : ISD::isZEXTLoad(N0Node);
  if ((!isAExtLoad && !ISD::isEXTLoad(N0Node)) ||
      !ISD::isUNINDEXEDLoad(N0Node) || !N0.hasOneUse())
    return SDValue();

  LoadSDNode *LN0 = cast<LoadSDNode>(N0);
  EVT MemVT = LN0->getMemoryVT();
  // After legalization, for non-simple loads, or for vector results, only
  // proceed when the wider extload is actually legal for the target.
  if ((LegalOperations || !LN0->isSimple() ||
       VT.isVector()) &&
      !TLI.isLoadExtLegal(ExtLoadType, VT, MemVT))
    return SDValue();

  // Build the wide extload, then replace the extend (value 0) and forward
  // the original load's chain (value 1) to the new load's chain.
  SDValue ExtLoad =
      DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(),
                     LN0->getBasePtr(), MemVT, LN0->getMemOperand());
  Combiner.CombineTo(N, ExtLoad);
  DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
  if (LN0->use_empty())
    Combiner.recursivelyDeleteUnusedNodes(LN0);
  return SDValue(N, 0); // Return N so it doesn't get rechecked!
}
11735 
11736 // fold ([s|z]ext (load x)) -> ([s|z]ext (truncate ([s|z]extload x)))
11737 // Only generate vector extloads when 1) they're legal, and 2) they are
11738 // deemed desirable by the target.
static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner,
                                  const TargetLowering &TLI, EVT VT,
                                  bool LegalOperations, SDNode *N, SDValue N0,
                                  ISD::LoadExtType ExtLoadType,
                                  ISD::NodeType ExtOpc) {
  // TODO: isFixedLengthVector() should be removed and any negative effects on
  // code generation being the result of that target's implementation of
  // isVectorLoadExtDesirable().
  // Bail out unless N0 is a plain unindexed load and, when legality matters
  // (after legalization, fixed-length vectors, or non-simple loads), the
  // wide extload is legal for the target.
  if (!ISD::isNON_EXTLoad(N0.getNode()) ||
      !ISD::isUNINDEXEDLoad(N0.getNode()) ||
      ((LegalOperations || VT.isFixedLengthVector() ||
        !cast<LoadSDNode>(N0)->isSimple()) &&
       !TLI.isLoadExtLegal(ExtLoadType, VT, N0.getValueType())))
    return {};

  bool DoXform = true;
  SmallVector<SDNode *, 4> SetCCs;
  // If the load has other users, they must all be rewritable in terms of the
  // extended value (qualifying setcc users are collected for rewriting).
  if (!N0.hasOneUse())
    DoXform = ExtendUsesToFormExtLoad(VT, N, N0, ExtOpc, SetCCs, TLI);
  if (VT.isVector())
    DoXform &= TLI.isVectorLoadExtDesirable(SDValue(N, 0));
  if (!DoXform)
    return {};

  LoadSDNode *LN0 = cast<LoadSDNode>(N0);
  SDValue ExtLoad = DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(),
                                   LN0->getBasePtr(), N0.getValueType(),
                                   LN0->getMemOperand());
  Combiner.ExtendSetCCUses(SetCCs, N0, ExtLoad, ExtOpc);
  // If the load value is used only by N, replace it via CombineTo N.
  bool NoReplaceTrunc = SDValue(LN0, 0).hasOneUse();
  Combiner.CombineTo(N, ExtLoad);
  if (NoReplaceTrunc) {
    DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
    Combiner.recursivelyDeleteUnusedNodes(LN0);
  } else {
    // Other users still need the narrow value: hand them a truncate of the
    // wide load along with the new chain.
    SDValue Trunc =
        DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), ExtLoad);
    Combiner.CombineTo(LN0, Trunc, ExtLoad.getValue(1));
  }
  return SDValue(N, 0); // Return N so it doesn't get rechecked!
}
11781 
tryToFoldExtOfMaskedLoad(SelectionDAG & DAG,const TargetLowering & TLI,EVT VT,SDNode * N,SDValue N0,ISD::LoadExtType ExtLoadType,ISD::NodeType ExtOpc)11782 static SDValue tryToFoldExtOfMaskedLoad(SelectionDAG &DAG,
11783                                         const TargetLowering &TLI, EVT VT,
11784                                         SDNode *N, SDValue N0,
11785                                         ISD::LoadExtType ExtLoadType,
11786                                         ISD::NodeType ExtOpc) {
11787   if (!N0.hasOneUse())
11788     return SDValue();
11789 
11790   MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(N0);
11791   if (!Ld || Ld->getExtensionType() != ISD::NON_EXTLOAD)
11792     return SDValue();
11793 
11794   if (!TLI.isLoadExtLegalOrCustom(ExtLoadType, VT, Ld->getValueType(0)))
11795     return SDValue();
11796 
11797   if (!TLI.isVectorLoadExtDesirable(SDValue(N, 0)))
11798     return SDValue();
11799 
11800   SDLoc dl(Ld);
11801   SDValue PassThru = DAG.getNode(ExtOpc, dl, VT, Ld->getPassThru());
11802   SDValue NewLoad = DAG.getMaskedLoad(
11803       VT, dl, Ld->getChain(), Ld->getBasePtr(), Ld->getOffset(), Ld->getMask(),
11804       PassThru, Ld->getMemoryVT(), Ld->getMemOperand(), Ld->getAddressingMode(),
11805       ExtLoadType, Ld->isExpandingLoad());
11806   DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), SDValue(NewLoad.getNode(), 1));
11807   return NewLoad;
11808 }
11809 
foldExtendedSignBitTest(SDNode * N,SelectionDAG & DAG,bool LegalOperations)11810 static SDValue foldExtendedSignBitTest(SDNode *N, SelectionDAG &DAG,
11811                                        bool LegalOperations) {
11812   assert((N->getOpcode() == ISD::SIGN_EXTEND ||
11813           N->getOpcode() == ISD::ZERO_EXTEND) && "Expected sext or zext");
11814 
11815   SDValue SetCC = N->getOperand(0);
11816   if (LegalOperations || SetCC.getOpcode() != ISD::SETCC ||
11817       !SetCC.hasOneUse() || SetCC.getValueType() != MVT::i1)
11818     return SDValue();
11819 
11820   SDValue X = SetCC.getOperand(0);
11821   SDValue Ones = SetCC.getOperand(1);
11822   ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get();
11823   EVT VT = N->getValueType(0);
11824   EVT XVT = X.getValueType();
11825   // setge X, C is canonicalized to setgt, so we do not need to match that
11826   // pattern. The setlt sibling is folded in SimplifySelectCC() because it does
11827   // not require the 'not' op.
11828   if (CC == ISD::SETGT && isAllOnesConstant(Ones) && VT == XVT) {
11829     // Invert and smear/shift the sign bit:
11830     // sext i1 (setgt iN X, -1) --> sra (not X), (N - 1)
11831     // zext i1 (setgt iN X, -1) --> srl (not X), (N - 1)
11832     SDLoc DL(N);
11833     unsigned ShCt = VT.getSizeInBits() - 1;
11834     const TargetLowering &TLI = DAG.getTargetLoweringInfo();
11835     if (!TLI.shouldAvoidTransformToShift(VT, ShCt)) {
11836       SDValue NotX = DAG.getNOT(DL, X, VT);
11837       SDValue ShiftAmount = DAG.getConstant(ShCt, DL, VT);
11838       auto ShiftOpcode =
11839         N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SRA : ISD::SRL;
11840       return DAG.getNode(ShiftOpcode, DL, VT, NotX, ShiftAmount);
11841     }
11842   }
11843   return SDValue();
11844 }
11845 
SDValue DAGCombiner::foldSextSetcc(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  if (N0.getOpcode() != ISD::SETCC)
    return SDValue();

  SDValue N00 = N0.getOperand(0);
  SDValue N01 = N0.getOperand(1);
  ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
  EVT VT = N->getValueType(0);
  EVT N00VT = N00.getValueType();
  SDLoc DL(N);

  // Propagate fast-math-flags.
  SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());

  // On some architectures (such as SSE/NEON/etc) the SETCC result type is
  // the same size as the compared operands. Try to optimize sext(setcc())
  // if this is the case.
  if (VT.isVector() && !LegalOperations &&
      TLI.getBooleanContents(N00VT) ==
          TargetLowering::ZeroOrNegativeOneBooleanContent) {
    EVT SVT = getSetCCResultType(N00VT);

    // If we already have the desired type, don't change it.
    if (SVT != N0.getValueType()) {
      // We know that the # elements of the results is the same as the
      // # elements of the compare (and the # elements of the compare result
      // for that matter).  Check to see that they are the same size.  If so,
      // we know that the element size of the sext'd result matches the
      // element size of the compare operands.
      if (VT.getSizeInBits() == SVT.getSizeInBits())
        return DAG.getSetCC(DL, VT, N00, N01, CC);

      // If the desired elements are smaller or larger than the source
      // elements, we can use a matching integer vector type and then
      // truncate/sign extend.
      EVT MatchingVecType = N00VT.changeVectorElementTypeToInteger();
      if (SVT == MatchingVecType) {
        SDValue VsetCC = DAG.getSetCC(DL, MatchingVecType, N00, N01, CC);
        return DAG.getSExtOrTrunc(VsetCC, DL, VT);
      }
    }

    // Try to eliminate the sext of a setcc by zexting the compare operands.
    if (N0.hasOneUse() && TLI.isOperationLegalOrCustom(ISD::SETCC, VT) &&
        !TLI.isOperationLegalOrCustom(ISD::SETCC, SVT)) {
      // A signed compare needs sign-extended operands; an unsigned (or
      // equality) compare can use the cheaper zero extension.
      bool IsSignedCmp = ISD::isSignedIntSetCC(CC);
      unsigned LoadOpcode = IsSignedCmp ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
      unsigned ExtOpcode = IsSignedCmp ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;

      // We have an unsupported narrow vector compare op that would be legal
      // if extended to the destination type. See if the compare operands
      // can be freely extended to the destination type.
      auto IsFreeToExtend = [&](SDValue V) {
        if (isConstantOrConstantVector(V, /*NoOpaques*/ true))
          return true;
        // Match a simple, non-extended load that can be converted to a
        // legal {z/s}ext-load.
        // TODO: Allow widening of an existing {z/s}ext-load?
        if (!(ISD::isNON_EXTLoad(V.getNode()) &&
              ISD::isUNINDEXEDLoad(V.getNode()) &&
              cast<LoadSDNode>(V)->isSimple() &&
              TLI.isLoadExtLegal(LoadOpcode, VT, V.getValueType())))
          return false;

        // Non-chain users of this value must either be the setcc in this
        // sequence or extends that can be folded into the new {z/s}ext-load.
        for (SDNode::use_iterator UI = V->use_begin(), UE = V->use_end();
             UI != UE; ++UI) {
          // Skip uses of the chain and the setcc.
          SDNode *User = *UI;
          if (UI.getUse().getResNo() != 0 || User == N0.getNode())
            continue;
          // Extra users must have exactly the same cast we are about to create.
          // TODO: This restriction could be eased if ExtendUsesToFormExtLoad()
          //       is enhanced similarly.
          if (User->getOpcode() != ExtOpcode || User->getValueType(0) != VT)
            return false;
        }
        return true;
      };

      if (IsFreeToExtend(N00) && IsFreeToExtend(N01)) {
        SDValue Ext0 = DAG.getNode(ExtOpcode, DL, VT, N00);
        SDValue Ext1 = DAG.getNode(ExtOpcode, DL, VT, N01);
        return DAG.getSetCC(DL, VT, Ext0, Ext1, CC);
      }
    }
  }

  // sext(setcc x, y, cc) -> (select (setcc x, y, cc), T, 0)
  // Here, T can be 1 or -1, depending on the type of the setcc and
  // getBooleanContents().
  unsigned SetCCWidth = N0.getScalarValueSizeInBits();

  // To determine the "true" side of the select, we need to know the high bit
  // of the value returned by the setcc if it evaluates to true.
  // If the type of the setcc is i1, then the true case of the select is just
  // sext(i1 1), that is, -1.
  // If the type of the setcc is larger (say, i8) then the value of the high
  // bit depends on getBooleanContents(), so ask TLI for a real "true" value
  // of the appropriate width.
  SDValue ExtTrueVal = (SetCCWidth == 1)
                           ? DAG.getAllOnesConstant(DL, VT)
                           : DAG.getBoolConstant(true, DL, VT, N00VT);
  SDValue Zero = DAG.getConstant(0, DL, VT);
  if (SDValue SCC = SimplifySelectCC(DL, N00, N01, ExtTrueVal, Zero, CC, true))
    return SCC;

  if (!VT.isVector() && !TLI.convertSelectOfConstantsToMath(VT)) {
    EVT SetCCVT = getSetCCResultType(N00VT);
    // Don't do this transform for i1 because there's a select transform
    // that would reverse it.
    // TODO: We should not do this transform at all without a target hook
    // because a sext is likely cheaper than a select?
    if (SetCCVT.getScalarSizeInBits() != 1 &&
        (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, N00VT))) {
      SDValue SetCC = DAG.getSetCC(DL, SetCCVT, N00, N01, CC);
      return DAG.getSelect(DL, VT, SetCC, ExtTrueVal, Zero);
    }
  }

  return SDValue();
}
11970 
SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  // sext(undef) = 0 because the top bit will all be the same.
  if (N0.isUndef())
    return DAG.getConstant(0, DL, VT);

  if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
    return Res;

  // fold (sext (sext x)) -> (sext x)
  // fold (sext (aext x)) -> (sext x)
  if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND)
    return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N0.getOperand(0));

  if (N0.getOpcode() == ISD::TRUNCATE) {
    // fold (sext (truncate (load x))) -> (sext (smaller load x))
    // fold (sext (truncate (srl (load x), c))) -> (sext (smaller load (x+c/n)))
    if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
      SDNode *oye = N0.getOperand(0).getNode();
      if (NarrowLoad.getNode() != N0.getNode()) {
        CombineTo(N0.getNode(), NarrowLoad);
        // CombineTo deleted the truncate, if needed, but not what's under it.
        AddToWorklist(oye);
      }
      return SDValue(N, 0);   // Return N so it doesn't get rechecked!
    }

    // See if the value being truncated is already sign extended.  If so, just
    // eliminate the trunc/sext pair.
    SDValue Op = N0.getOperand(0);
    unsigned OpBits   = Op.getScalarValueSizeInBits();
    unsigned MidBits  = N0.getScalarValueSizeInBits();
    unsigned DestBits = VT.getScalarSizeInBits();
    unsigned NumSignBits = DAG.ComputeNumSignBits(Op);

    if (OpBits == DestBits) {
      // Op is i32, Mid is i8, and Dest is i32.  If Op has more than 24 sign
      // bits, it is already ready.
      if (NumSignBits > DestBits-MidBits)
        return Op;
    } else if (OpBits < DestBits) {
      // Op is i32, Mid is i8, and Dest is i64.  If Op has more than 24 sign
      // bits, just sext from i32.
      if (NumSignBits > OpBits-MidBits)
        return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op);
    } else {
      // Op is i64, Mid is i8, and Dest is i32.  If Op has more than 56 sign
      // bits, just truncate to i32.
      if (NumSignBits > OpBits-MidBits)
        return DAG.getNode(ISD::TRUNCATE, DL, VT, Op);
    }

    // fold (sext (truncate x)) -> (sextinreg x).
    if (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND_INREG,
                                                 N0.getValueType())) {
      // Bring the source to the destination width first, then sign-extend
      // in-register from the truncated width.
      if (OpBits < DestBits)
        Op = DAG.getNode(ISD::ANY_EXTEND, SDLoc(N0), VT, Op);
      else if (OpBits > DestBits)
        Op = DAG.getNode(ISD::TRUNCATE, SDLoc(N0), VT, Op);
      return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, Op,
                         DAG.getValueType(N0.getValueType()));
    }
  }

  // Try to simplify (sext (load x)).
  if (SDValue foldedExt =
          tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
                             ISD::SEXTLOAD, ISD::SIGN_EXTEND))
    return foldedExt;

  if (SDValue foldedExt =
      tryToFoldExtOfMaskedLoad(DAG, TLI, VT, N, N0, ISD::SEXTLOAD,
                               ISD::SIGN_EXTEND))
    return foldedExt;

  // fold (sext (load x)) to multiple smaller sextloads.
  // Only on illegal but splittable vectors.
  if (SDValue ExtLoad = CombineExtLoad(N))
    return ExtLoad;

  // Try to simplify (sext (sextload x)).
  if (SDValue foldedExt = tryToFoldExtOfExtload(
          DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::SEXTLOAD))
    return foldedExt;

  // fold (sext (and/or/xor (load x), cst)) ->
  //      (and/or/xor (sextload x), (sext cst))
  // Only before legalization, and only when the logic op is already legal
  // in the wide type.
  if ((N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR ||
       N0.getOpcode() == ISD::XOR) &&
      isa<LoadSDNode>(N0.getOperand(0)) &&
      N0.getOperand(1).getOpcode() == ISD::Constant &&
      (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
    LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
    EVT MemVT = LN00->getMemoryVT();
    if (TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, MemVT) &&
      LN00->getExtensionType() != ISD::ZEXTLOAD && LN00->isUnindexed()) {
      SmallVector<SDNode*, 4> SetCCs;
      bool DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
                                             ISD::SIGN_EXTEND, SetCCs, TLI);
      if (DoXform) {
        SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(LN00), VT,
                                         LN00->getChain(), LN00->getBasePtr(),
                                         LN00->getMemoryVT(),
                                         LN00->getMemOperand());
        // Sign-extend the logic-op constant to the wide type.
        APInt Mask = N0.getConstantOperandAPInt(1).sext(VT.getSizeInBits());
        SDValue And = DAG.getNode(N0.getOpcode(), DL, VT,
                                  ExtLoad, DAG.getConstant(Mask, DL, VT));
        ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::SIGN_EXTEND);
        bool NoReplaceTruncAnd = !N0.hasOneUse();
        bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
        CombineTo(N, And);
        // If N0 has multiple uses, change other uses as well.
        if (NoReplaceTruncAnd) {
          SDValue TruncAnd =
              DAG.getNode(ISD::TRUNCATE, DL, N0.getValueType(), And);
          CombineTo(N0.getNode(), TruncAnd);
        }
        if (NoReplaceTrunc) {
          DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1));
        } else {
          // Other users of the load still need the narrow value.
          SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00),
                                      LN00->getValueType(0), ExtLoad);
          CombineTo(LN00, Trunc, ExtLoad.getValue(1));
        }
        return SDValue(N,0); // Return N so it doesn't get rechecked!
      }
    }
  }

  if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
    return V;

  if (SDValue V = foldSextSetcc(N))
    return V;

  // fold (sext x) -> (zext x) if the sign bit is known zero.
  if ((!LegalOperations || TLI.isOperationLegal(ISD::ZERO_EXTEND, VT)) &&
      DAG.SignBitIsZero(N0))
    return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0);

  if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
    return NewVSel;

  // Eliminate this sign extend by doing a negation in the destination type:
  // sext i32 (0 - (zext i8 X to i32)) to i64 --> 0 - (zext i8 X to i64)
  if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
      isNullOrNullSplat(N0.getOperand(0)) &&
      N0.getOperand(1).getOpcode() == ISD::ZERO_EXTEND &&
      TLI.isOperationLegalOrCustom(ISD::SUB, VT)) {
    SDValue Zext = DAG.getZExtOrTrunc(N0.getOperand(1).getOperand(0), DL, VT);
    return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), Zext);
  }
  // Eliminate this sign extend by doing a decrement in the destination type:
  // sext i32 ((zext i8 X to i32) + (-1)) to i64 --> (zext i8 X to i64) + (-1)
  if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
      isAllOnesOrAllOnesSplat(N0.getOperand(1)) &&
      N0.getOperand(0).getOpcode() == ISD::ZERO_EXTEND &&
      TLI.isOperationLegalOrCustom(ISD::ADD, VT)) {
    SDValue Zext = DAG.getZExtOrTrunc(N0.getOperand(0).getOperand(0), DL, VT);
    return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT));
  }

  // fold sext (not i1 X) -> add (zext i1 X), -1
  // TODO: This could be extended to handle bool vectors.
  if (N0.getValueType() == MVT::i1 && isBitwiseNot(N0) && N0.hasOneUse() &&
      (!LegalOperations || (TLI.isOperationLegal(ISD::ZERO_EXTEND, VT) &&
                            TLI.isOperationLegal(ISD::ADD, VT)))) {
    // If we can eliminate the 'not', the sext form should be better
    if (SDValue NewXor = visitXOR(N0.getNode())) {
      // Returning N0 is a form of in-visit replacement that may have
      // invalidated N0.
      if (NewXor.getNode() == N0.getNode()) {
        // Return SDValue here as the xor should have already been replaced in
        // this sext.
        return SDValue();
      }

      // Return a new sext with the new xor.
      return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, NewXor);
    }

    SDValue Zext = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0));
    return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT));
  }

  if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG))
    return Res;

  return SDValue();
}
12164 
12165 // isTruncateOf - If N is a truncate of some other value, return true, record
12166 // the value being truncated in Op and which of Op's bits are zero/one in Known.
12167 // This function computes KnownBits to avoid a duplicated call to
12168 // computeKnownBits in the caller.
isTruncateOf(SelectionDAG & DAG,SDValue N,SDValue & Op,KnownBits & Known)12169 static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op,
12170                          KnownBits &Known) {
12171   if (N->getOpcode() == ISD::TRUNCATE) {
12172     Op = N->getOperand(0);
12173     Known = DAG.computeKnownBits(Op);
12174     return true;
12175   }
12176 
12177   if (N.getOpcode() != ISD::SETCC ||
12178       N.getValueType().getScalarType() != MVT::i1 ||
12179       cast<CondCodeSDNode>(N.getOperand(2))->get() != ISD::SETNE)
12180     return false;
12181 
12182   SDValue Op0 = N->getOperand(0);
12183   SDValue Op1 = N->getOperand(1);
12184   assert(Op0.getValueType() == Op1.getValueType());
12185 
12186   if (isNullOrNullSplat(Op0))
12187     Op = Op1;
12188   else if (isNullOrNullSplat(Op1))
12189     Op = Op0;
12190   else
12191     return false;
12192 
12193   Known = DAG.computeKnownBits(Op);
12194 
12195   return (Known.Zero | 1).isAllOnes();
12196 }
12197 
12198 /// Given an extending node with a pop-count operand, if the target does not
12199 /// support a pop-count in the narrow source type but does support it in the
12200 /// destination type, widen the pop-count to the destination type.
widenCtPop(SDNode * Extend,SelectionDAG & DAG)12201 static SDValue widenCtPop(SDNode *Extend, SelectionDAG &DAG) {
12202   assert((Extend->getOpcode() == ISD::ZERO_EXTEND ||
12203           Extend->getOpcode() == ISD::ANY_EXTEND) && "Expected extend op");
12204 
12205   SDValue CtPop = Extend->getOperand(0);
12206   if (CtPop.getOpcode() != ISD::CTPOP || !CtPop.hasOneUse())
12207     return SDValue();
12208 
12209   EVT VT = Extend->getValueType(0);
12210   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
12211   if (TLI.isOperationLegalOrCustom(ISD::CTPOP, CtPop.getValueType()) ||
12212       !TLI.isOperationLegalOrCustom(ISD::CTPOP, VT))
12213     return SDValue();
12214 
12215   // zext (ctpop X) --> ctpop (zext X)
12216   SDLoc DL(Extend);
12217   SDValue NewZext = DAG.getZExtOrTrunc(CtPop.getOperand(0), DL, VT);
12218   return DAG.getNode(ISD::CTPOP, DL, VT, NewZext);
12219 }
12220 
/// Combine an ISD::ZERO_EXTEND node. Attempts a sequence of folds (constant
/// extension, collapsing extend chains, truncate elimination, narrowing loads
/// into zextloads, setcc lowering, shift reassociation, ctpop widening) and
/// returns the replacement value, or an empty SDValue if nothing applied.
SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  // zext(undef) = 0
  if (N0.isUndef())
    return DAG.getConstant(0, SDLoc(N), VT);

  if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
    return Res;

  // fold (zext (zext x)) -> (zext x)
  // fold (zext (aext x)) -> (zext x)
  if (N0.getOpcode() == ISD::ZERO_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND)
    return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT,
                       N0.getOperand(0));

  // fold (zext (truncate x)) -> (zext x) or
  //      (zext (truncate x)) -> (truncate x)
  // This is valid when the truncated bits of x are already zero.
  SDValue Op;
  KnownBits Known;
  if (isTruncateOf(DAG, N0, Op, Known)) {
    // TruncatedBits is the mask of bits that the truncate discarded (and the
    // zext would have to reintroduce as zero). The fold is safe only when
    // those bits are already known zero in Op.
    APInt TruncatedBits =
      (Op.getScalarValueSizeInBits() == N0.getScalarValueSizeInBits()) ?
      APInt(Op.getScalarValueSizeInBits(), 0) :
      APInt::getBitsSet(Op.getScalarValueSizeInBits(),
                        N0.getScalarValueSizeInBits(),
                        std::min(Op.getScalarValueSizeInBits(),
                                 VT.getScalarSizeInBits()));
    if (TruncatedBits.isSubsetOf(Known.Zero))
      return DAG.getZExtOrTrunc(Op, SDLoc(N), VT);
  }

  // fold (zext (truncate x)) -> (and x, mask)
  if (N0.getOpcode() == ISD::TRUNCATE) {
    // fold (zext (truncate (load x))) -> (zext (smaller load x))
    // fold (zext (truncate (srl (load x), c))) -> (zext (smaller load (x+c/n)))
    if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
      SDNode *oye = N0.getOperand(0).getNode();
      if (NarrowLoad.getNode() != N0.getNode()) {
        CombineTo(N0.getNode(), NarrowLoad);
        // CombineTo deleted the truncate, if needed, but not what's under it.
        AddToWorklist(oye);
      }
      return SDValue(N, 0); // Return N so it doesn't get rechecked!
    }

    EVT SrcVT = N0.getOperand(0).getValueType();
    EVT MinVT = N0.getValueType();

    // Try to mask before the extension to avoid having to generate a larger mask,
    // possibly over several sub-vectors.
    if (SrcVT.bitsLT(VT) && VT.isVector()) {
      if (!LegalOperations || (TLI.isOperationLegal(ISD::AND, SrcVT) &&
                               TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) {
        SDValue Op = N0.getOperand(0);
        Op = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT);
        AddToWorklist(Op.getNode());
        SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, SDLoc(N), VT);
        // Transfer the debug info; the new node is equivalent to N0.
        DAG.transferDbgValues(N0, ZExtOrTrunc);
        return ZExtOrTrunc;
      }
    }

    // Otherwise extend first (any-extend is fine; the mask clears the high
    // bits), then mask in the wide type.
    if (!LegalOperations || TLI.isOperationLegal(ISD::AND, VT)) {
      SDValue Op = DAG.getAnyExtOrTrunc(N0.getOperand(0), SDLoc(N), VT);
      AddToWorklist(Op.getNode());
      SDValue And = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT);
      // We may safely transfer the debug info describing the truncate node over
      // to the equivalent and operation.
      DAG.transferDbgValues(N0, And);
      return And;
    }
  }

  // Fold (zext (and (trunc x), cst)) -> (and x, cst),
  // if either of the casts is not free.
  if (N0.getOpcode() == ISD::AND &&
      N0.getOperand(0).getOpcode() == ISD::TRUNCATE &&
      N0.getOperand(1).getOpcode() == ISD::Constant &&
      (!TLI.isTruncateFree(N0.getOperand(0).getOperand(0).getValueType(),
                           N0.getValueType()) ||
       !TLI.isZExtFree(N0.getValueType(), VT))) {
    SDValue X = N0.getOperand(0).getOperand(0);
    X = DAG.getAnyExtOrTrunc(X, SDLoc(X), VT);
    // Zero-extend the narrow mask so it clears the bits the original zext
    // would have cleared.
    APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
    SDLoc DL(N);
    return DAG.getNode(ISD::AND, DL, VT,
                       X, DAG.getConstant(Mask, DL, VT));
  }

  // Try to simplify (zext (load x)).
  if (SDValue foldedExt =
          tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
                             ISD::ZEXTLOAD, ISD::ZERO_EXTEND))
    return foldedExt;

  if (SDValue foldedExt =
      tryToFoldExtOfMaskedLoad(DAG, TLI, VT, N, N0, ISD::ZEXTLOAD,
                               ISD::ZERO_EXTEND))
    return foldedExt;

  // fold (zext (load x)) to multiple smaller zextloads.
  // Only on illegal but splittable vectors.
  if (SDValue ExtLoad = CombineExtLoad(N))
    return ExtLoad;

  // fold (zext (and/or/xor (load x), cst)) ->
  //      (and/or/xor (zextload x), (zext cst))
  // Unless (and (load x) cst) will match as a zextload already and has
  // additional users.
  // Note: only attempted before operation legalization, and additionally
  // only when the logic op is already legal in VT.
  if ((N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR ||
       N0.getOpcode() == ISD::XOR) &&
      isa<LoadSDNode>(N0.getOperand(0)) &&
      N0.getOperand(1).getOpcode() == ISD::Constant &&
      (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
    LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
    EVT MemVT = LN00->getMemoryVT();
    if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) &&
        LN00->getExtensionType() != ISD::SEXTLOAD && LN00->isUnindexed()) {
      bool DoXform = true;
      SmallVector<SDNode*, 4> SetCCs;
      if (!N0.hasOneUse()) {
        // If the AND could already be matched as a zextload on its own, don't
        // disturb it for the sake of this fold.
        if (N0.getOpcode() == ISD::AND) {
          auto *AndC = cast<ConstantSDNode>(N0.getOperand(1));
          EVT LoadResultTy = AndC->getValueType(0);
          EVT ExtVT;
          if (isAndLoadExtLoad(AndC, LN00, LoadResultTy, ExtVT))
            DoXform = false;
        }
      }
      if (DoXform)
        DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
                                          ISD::ZERO_EXTEND, SetCCs, TLI);
      if (DoXform) {
        SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(LN00), VT,
                                         LN00->getChain(), LN00->getBasePtr(),
                                         LN00->getMemoryVT(),
                                         LN00->getMemOperand());
        APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
        SDLoc DL(N);
        SDValue And = DAG.getNode(N0.getOpcode(), DL, VT,
                                  ExtLoad, DAG.getConstant(Mask, DL, VT));
        ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::ZERO_EXTEND);
        // Record multi-use facts BEFORE CombineTo mutates the use lists.
        bool NoReplaceTruncAnd = !N0.hasOneUse();
        bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
        CombineTo(N, And);
        // If N0 has multiple uses, change other uses as well.
        if (NoReplaceTruncAnd) {
          SDValue TruncAnd =
              DAG.getNode(ISD::TRUNCATE, DL, N0.getValueType(), And);
          CombineTo(N0.getNode(), TruncAnd);
        }
        if (NoReplaceTrunc) {
          // Only the chain use of the old load remains; wire it to the new
          // load's chain.
          DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1));
        } else {
          // Other users of the old load value remain; give them a truncate of
          // the new, wider load.
          SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00),
                                      LN00->getValueType(0), ExtLoad);
          CombineTo(LN00, Trunc, ExtLoad.getValue(1));
        }
        return SDValue(N,0); // Return N so it doesn't get rechecked!
      }
    }
  }

  // fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
  //      (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
  if (SDValue ZExtLoad = CombineZExtLogicopShiftLoad(N))
    return ZExtLoad;

  // Try to simplify (zext (zextload x)).
  if (SDValue foldedExt = tryToFoldExtOfExtload(
          DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::ZEXTLOAD))
    return foldedExt;

  if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
    return V;

  if (N0.getOpcode() == ISD::SETCC) {
    // Propagate fast-math-flags.
    SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());

    // Only do this before legalize for now.
    if (!LegalOperations && VT.isVector() &&
        N0.getValueType().getVectorElementType() == MVT::i1) {
      EVT N00VT = N0.getOperand(0).getValueType();
      if (getSetCCResultType(N00VT) == N0.getValueType())
        return SDValue();

      // We know that the # elements of the results is the same as the #
      // elements of the compare (and the # elements of the compare result for
      // that matter). Check to see that they are the same size. If so, we know
      // that the element size of the sext'd result matches the element size of
      // the compare operands.
      SDLoc DL(N);
      if (VT.getSizeInBits() == N00VT.getSizeInBits()) {
        // zext(setcc) -> zext_in_reg(vsetcc) for vectors.
        SDValue VSetCC = DAG.getNode(ISD::SETCC, DL, VT, N0.getOperand(0),
                                     N0.getOperand(1), N0.getOperand(2));
        return DAG.getZeroExtendInReg(VSetCC, DL, N0.getValueType());
      }

      // If the desired elements are smaller or larger than the source
      // elements we can use a matching integer vector type and then
      // truncate/any extend followed by zext_in_reg.
      EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
      SDValue VsetCC =
          DAG.getNode(ISD::SETCC, DL, MatchingVectorType, N0.getOperand(0),
                      N0.getOperand(1), N0.getOperand(2));
      return DAG.getZeroExtendInReg(DAG.getAnyExtOrTrunc(VsetCC, DL, VT), DL,
                                    N0.getValueType());
    }

    // zext(setcc x,y,cc) -> zext(select x, y, true, false, cc)
    SDLoc DL(N);
    EVT N0VT = N0.getValueType();
    EVT N00VT = N0.getOperand(0).getValueType();
    if (SDValue SCC = SimplifySelectCC(
            DL, N0.getOperand(0), N0.getOperand(1),
            DAG.getBoolConstant(true, DL, N0VT, N00VT),
            DAG.getBoolConstant(false, DL, N0VT, N00VT),
            cast<CondCodeSDNode>(N0.getOperand(2))->get(), true))
      return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, SCC);
  }

  // (zext (shl (zext x), cst)) -> (shl (zext x), cst)
  if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
      isa<ConstantSDNode>(N0.getOperand(1)) &&
      N0.getOperand(0).getOpcode() == ISD::ZERO_EXTEND &&
      N0.hasOneUse()) {
    SDValue ShAmt = N0.getOperand(1);
    if (N0.getOpcode() == ISD::SHL) {
      SDValue InnerZExt = N0.getOperand(0);
      // If the original shl may be shifting out bits, do not perform this
      // transformation.
      // KnownZeroBits = number of high bits the inner zext guarantees zero.
      unsigned KnownZeroBits = InnerZExt.getValueSizeInBits() -
        InnerZExt.getOperand(0).getValueSizeInBits();
      if (cast<ConstantSDNode>(ShAmt)->getAPIntValue().ugt(KnownZeroBits))
        return SDValue();
    }

    SDLoc DL(N);

    // Ensure that the shift amount is wide enough for the shifted value.
    if (Log2_32_Ceil(VT.getSizeInBits()) > ShAmt.getValueSizeInBits())
      ShAmt = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, ShAmt);

    return DAG.getNode(N0.getOpcode(), DL, VT,
                       DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0)),
                       ShAmt);
  }

  if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
    return NewVSel;

  if (SDValue NewCtPop = widenCtPop(N, DAG))
    return NewCtPop;

  if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG))
    return Res;

  return SDValue();
}
12486 
/// Combine an ISD::ANY_EXTEND node. Attempts folds parallel to
/// visitZERO_EXTEND: constant extension, collapsing extend chains, truncate
/// elimination, turning loads into extloads, setcc lowering, and ctpop
/// widening. Returns the replacement value, or an empty SDValue.
SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  // aext(undef) = undef
  if (N0.isUndef())
    return DAG.getUNDEF(VT);

  if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
    return Res;

  // fold (aext (aext x)) -> (aext x)
  // fold (aext (zext x)) -> (zext x)
  // fold (aext (sext x)) -> (sext x)
  if (N0.getOpcode() == ISD::ANY_EXTEND  ||
      N0.getOpcode() == ISD::ZERO_EXTEND ||
      N0.getOpcode() == ISD::SIGN_EXTEND)
    return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, N0.getOperand(0));

  // fold (aext (truncate (load x))) -> (aext (smaller load x))
  // fold (aext (truncate (srl (load x), c))) -> (aext (small load (x+c/n)))
  if (N0.getOpcode() == ISD::TRUNCATE) {
    if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
      SDNode *oye = N0.getOperand(0).getNode();
      if (NarrowLoad.getNode() != N0.getNode()) {
        CombineTo(N0.getNode(), NarrowLoad);
        // CombineTo deleted the truncate, if needed, but not what's under it.
        AddToWorklist(oye);
      }
      return SDValue(N, 0);   // Return N so it doesn't get rechecked!
    }
  }

  // fold (aext (truncate x))
  // An any_extend of a truncate places no requirement on the reintroduced
  // high bits, so the pair collapses to an extend-or-truncate of x.
  if (N0.getOpcode() == ISD::TRUNCATE)
    return DAG.getAnyExtOrTrunc(N0.getOperand(0), SDLoc(N), VT);

  // Fold (aext (and (trunc x), cst)) -> (and x, cst)
  // if the trunc is not free.
  if (N0.getOpcode() == ISD::AND &&
      N0.getOperand(0).getOpcode() == ISD::TRUNCATE &&
      N0.getOperand(1).getOpcode() == ISD::Constant &&
      !TLI.isTruncateFree(N0.getOperand(0).getOperand(0).getValueType(),
                          N0.getValueType())) {
    SDLoc DL(N);
    SDValue X = DAG.getAnyExtOrTrunc(N0.getOperand(0).getOperand(0), DL, VT);
    // Extending a constant is expected to constant-fold.
    SDValue Y = DAG.getNode(ISD::ANY_EXTEND, DL, VT, N0.getOperand(1));
    assert(isa<ConstantSDNode>(Y) && "Expected constant to be folded!");
    return DAG.getNode(ISD::AND, DL, VT, X, Y);
  }

  // fold (aext (load x)) -> (aext (truncate (extload x)))
  // None of the supported targets knows how to perform load and any_ext
  // on vectors in one instruction, so attempt to fold to zext instead.
  if (VT.isVector()) {
    // Try to simplify (zext (load x)).
    if (SDValue foldedExt =
            tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
                               ISD::ZEXTLOAD, ISD::ZERO_EXTEND))
      return foldedExt;
  } else if (ISD::isNON_EXTLoad(N0.getNode()) &&
             ISD::isUNINDEXEDLoad(N0.getNode()) &&
             TLI.isLoadExtLegal(ISD::EXTLOAD, VT, N0.getValueType())) {
    bool DoXform = true;
    SmallVector<SDNode *, 4> SetCCs;
    // With multiple users, only transform if all of them can be updated to
    // consume the extended load (setcc users are rewritten via SetCCs).
    if (!N0.hasOneUse())
      DoXform =
          ExtendUsesToFormExtLoad(VT, N, N0, ISD::ANY_EXTEND, SetCCs, TLI);
    if (DoXform) {
      LoadSDNode *LN0 = cast<LoadSDNode>(N0);
      SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, SDLoc(N), VT,
                                       LN0->getChain(), LN0->getBasePtr(),
                                       N0.getValueType(), LN0->getMemOperand());
      ExtendSetCCUses(SetCCs, N0, ExtLoad, ISD::ANY_EXTEND);
      // If the load value is used only by N, replace it via CombineTo N.
      bool NoReplaceTrunc = N0.hasOneUse();
      CombineTo(N, ExtLoad);
      if (NoReplaceTrunc) {
        DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
        recursivelyDeleteUnusedNodes(LN0);
      } else {
        // Remaining users of the old load get a truncate of the wide load.
        SDValue Trunc =
            DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), ExtLoad);
        CombineTo(LN0, Trunc, ExtLoad.getValue(1));
      }
      return SDValue(N, 0); // Return N so it doesn't get rechecked!
    }
  }

  // fold (aext (zextload x)) -> (aext (truncate (zextload x)))
  // fold (aext (sextload x)) -> (aext (truncate (sextload x)))
  // fold (aext ( extload x)) -> (aext (truncate (extload  x)))
  if (N0.getOpcode() == ISD::LOAD && !ISD::isNON_EXTLoad(N0.getNode()) &&
      ISD::isUNINDEXEDLoad(N0.getNode()) && N0.hasOneUse()) {
    LoadSDNode *LN0 = cast<LoadSDNode>(N0);
    ISD::LoadExtType ExtType = LN0->getExtensionType();
    EVT MemVT = LN0->getMemoryVT();
    if (!LegalOperations || TLI.isLoadExtLegal(ExtType, VT, MemVT)) {
      // Re-issue the existing extending load directly at the wider type,
      // preserving its extension kind.
      SDValue ExtLoad = DAG.getExtLoad(ExtType, SDLoc(N),
                                       VT, LN0->getChain(), LN0->getBasePtr(),
                                       MemVT, LN0->getMemOperand());
      CombineTo(N, ExtLoad);
      DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
      recursivelyDeleteUnusedNodes(LN0);
      return SDValue(N, 0);   // Return N so it doesn't get rechecked!
    }
  }

  if (N0.getOpcode() == ISD::SETCC) {
    // Propagate fast-math-flags.
    SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());

    // For vectors:
    // aext(setcc) -> vsetcc
    // aext(setcc) -> truncate(vsetcc)
    // aext(setcc) -> aext(vsetcc)
    // Only do this before legalize for now.
    if (VT.isVector() && !LegalOperations) {
      EVT N00VT = N0.getOperand(0).getValueType();
      if (getSetCCResultType(N00VT) == N0.getValueType())
        return SDValue();

      // We know that the # elements of the results is the same as the
      // # elements of the compare (and the # elements of the compare result
      // for that matter).  Check to see that they are the same size.  If so,
      // we know that the element size of the sext'd result matches the
      // element size of the compare operands.
      if (VT.getSizeInBits() == N00VT.getSizeInBits())
        return DAG.getSetCC(SDLoc(N), VT, N0.getOperand(0),
                             N0.getOperand(1),
                             cast<CondCodeSDNode>(N0.getOperand(2))->get());

      // If the desired elements are smaller or larger than the source
      // elements we can use a matching integer vector type and then
      // truncate/any extend
      EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
      SDValue VsetCC =
        DAG.getSetCC(SDLoc(N), MatchingVectorType, N0.getOperand(0),
                      N0.getOperand(1),
                      cast<CondCodeSDNode>(N0.getOperand(2))->get());
      return DAG.getAnyExtOrTrunc(VsetCC, SDLoc(N), VT);
    }

    // aext(setcc x,y,cc) -> select_cc x, y, 1, 0, cc
    SDLoc DL(N);
    if (SDValue SCC = SimplifySelectCC(
            DL, N0.getOperand(0), N0.getOperand(1), DAG.getConstant(1, DL, VT),
            DAG.getConstant(0, DL, VT),
            cast<CondCodeSDNode>(N0.getOperand(2))->get(), true))
      return SCC;
  }

  if (SDValue NewCtPop = widenCtPop(N, DAG))
    return NewCtPop;

  if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG))
    return Res;

  return SDValue();
}
12647 
visitAssertExt(SDNode * N)12648 SDValue DAGCombiner::visitAssertExt(SDNode *N) {
12649   unsigned Opcode = N->getOpcode();
12650   SDValue N0 = N->getOperand(0);
12651   SDValue N1 = N->getOperand(1);
12652   EVT AssertVT = cast<VTSDNode>(N1)->getVT();
12653 
12654   // fold (assert?ext (assert?ext x, vt), vt) -> (assert?ext x, vt)
12655   if (N0.getOpcode() == Opcode &&
12656       AssertVT == cast<VTSDNode>(N0.getOperand(1))->getVT())
12657     return N0;
12658 
12659   if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
12660       N0.getOperand(0).getOpcode() == Opcode) {
12661     // We have an assert, truncate, assert sandwich. Make one stronger assert
12662     // by asserting on the smallest asserted type to the larger source type.
12663     // This eliminates the later assert:
12664     // assert (trunc (assert X, i8) to iN), i1 --> trunc (assert X, i1) to iN
12665     // assert (trunc (assert X, i1) to iN), i8 --> trunc (assert X, i1) to iN
12666     SDLoc DL(N);
12667     SDValue BigA = N0.getOperand(0);
12668     EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT();
12669     EVT MinAssertVT = AssertVT.bitsLT(BigA_AssertVT) ? AssertVT : BigA_AssertVT;
12670     SDValue MinAssertVTVal = DAG.getValueType(MinAssertVT);
12671     SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(),
12672                                     BigA.getOperand(0), MinAssertVTVal);
12673     return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert);
12674   }
12675 
12676   // If we have (AssertZext (truncate (AssertSext X, iX)), iY) and Y is smaller
12677   // than X. Just move the AssertZext in front of the truncate and drop the
12678   // AssertSExt.
12679   if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
12680       N0.getOperand(0).getOpcode() == ISD::AssertSext &&
12681       Opcode == ISD::AssertZext) {
12682     SDValue BigA = N0.getOperand(0);
12683     EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT();
12684     if (AssertVT.bitsLT(BigA_AssertVT)) {
12685       SDLoc DL(N);
12686       SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(),
12687                                       BigA.getOperand(0), N1);
12688       return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert);
12689     }
12690   }
12691 
12692   return SDValue();
12693 }
12694 
visitAssertAlign(SDNode * N)12695 SDValue DAGCombiner::visitAssertAlign(SDNode *N) {
12696   SDLoc DL(N);
12697 
12698   Align AL = cast<AssertAlignSDNode>(N)->getAlign();
12699   SDValue N0 = N->getOperand(0);
12700 
12701   // Fold (assertalign (assertalign x, AL0), AL1) ->
12702   // (assertalign x, max(AL0, AL1))
12703   if (auto *AAN = dyn_cast<AssertAlignSDNode>(N0))
12704     return DAG.getAssertAlign(DL, N0.getOperand(0),
12705                               std::max(AL, AAN->getAlign()));
12706 
12707   // In rare cases, there are trivial arithmetic ops in source operands. Sink
12708   // this assert down to source operands so that those arithmetic ops could be
12709   // exposed to the DAG combining.
12710   switch (N0.getOpcode()) {
12711   default:
12712     break;
12713   case ISD::ADD:
12714   case ISD::SUB: {
12715     unsigned AlignShift = Log2(AL);
12716     SDValue LHS = N0.getOperand(0);
12717     SDValue RHS = N0.getOperand(1);
12718     unsigned LHSAlignShift = DAG.computeKnownBits(LHS).countMinTrailingZeros();
12719     unsigned RHSAlignShift = DAG.computeKnownBits(RHS).countMinTrailingZeros();
12720     if (LHSAlignShift >= AlignShift || RHSAlignShift >= AlignShift) {
12721       if (LHSAlignShift < AlignShift)
12722         LHS = DAG.getAssertAlign(DL, LHS, AL);
12723       if (RHSAlignShift < AlignShift)
12724         RHS = DAG.getAssertAlign(DL, RHS, AL);
12725       return DAG.getNode(N0.getOpcode(), DL, N0.getValueType(), LHS, RHS);
12726     }
12727     break;
12728   }
12729   }
12730 
12731   return SDValue();
12732 }
12733 
12734 /// If the result of a load is shifted/masked/truncated to an effectively
12735 /// narrower type, try to transform the load to a narrower type and/or
12736 /// use an extending load.
reduceLoadWidth(SDNode * N)12737 SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
12738   unsigned Opc = N->getOpcode();
12739 
12740   ISD::LoadExtType ExtType = ISD::NON_EXTLOAD;
12741   SDValue N0 = N->getOperand(0);
12742   EVT VT = N->getValueType(0);
12743   EVT ExtVT = VT;
12744 
12745   // This transformation isn't valid for vector loads.
12746   if (VT.isVector())
12747     return SDValue();
12748 
12749   // The ShAmt variable is used to indicate that we've consumed a right
12750   // shift. I.e. we want to narrow the width of the load by skipping to load the
12751   // ShAmt least significant bits.
12752   unsigned ShAmt = 0;
12753   // A special case is when the least significant bits from the load are masked
12754   // away, but using an AND rather than a right shift. HasShiftedOffset is used
12755   // to indicate that the narrowed load should be left-shifted ShAmt bits to get
12756   // the result.
12757   bool HasShiftedOffset = false;
12758   // Special case: SIGN_EXTEND_INREG is basically truncating to ExtVT then
12759   // extended to VT.
12760   if (Opc == ISD::SIGN_EXTEND_INREG) {
12761     ExtType = ISD::SEXTLOAD;
12762     ExtVT = cast<VTSDNode>(N->getOperand(1))->getVT();
12763   } else if (Opc == ISD::SRL || Opc == ISD::SRA) {
12764     // Another special-case: SRL/SRA is basically zero/sign-extending a narrower
12765     // value, or it may be shifting a higher subword, half or byte into the
12766     // lowest bits.
12767 
12768     // Only handle shift with constant shift amount, and the shiftee must be a
12769     // load.
12770     auto *LN = dyn_cast<LoadSDNode>(N0);
12771     auto *N1C = dyn_cast<ConstantSDNode>(N->getOperand(1));
12772     if (!N1C || !LN)
12773       return SDValue();
12774     // If the shift amount is larger than the memory type then we're not
12775     // accessing any of the loaded bytes.
12776     ShAmt = N1C->getZExtValue();
12777     uint64_t MemoryWidth = LN->getMemoryVT().getScalarSizeInBits();
12778     if (MemoryWidth <= ShAmt)
12779       return SDValue();
12780     // Attempt to fold away the SRL by using ZEXTLOAD and SRA by using SEXTLOAD.
12781     ExtType = Opc == ISD::SRL ? ISD::ZEXTLOAD : ISD::SEXTLOAD;
12782     ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShAmt);
12783     // If original load is a SEXTLOAD then we can't simply replace it by a
12784     // ZEXTLOAD (we could potentially replace it by a more narrow SEXTLOAD
12785     // followed by a ZEXT, but that is not handled at the moment). Similarly if
12786     // the original load is a ZEXTLOAD and we want to use a SEXTLOAD.
12787     if ((LN->getExtensionType() == ISD::SEXTLOAD ||
12788          LN->getExtensionType() == ISD::ZEXTLOAD) &&
12789         LN->getExtensionType() != ExtType)
12790       return SDValue();
12791   } else if (Opc == ISD::AND) {
12792     // An AND with a constant mask is the same as a truncate + zero-extend.
12793     auto AndC = dyn_cast<ConstantSDNode>(N->getOperand(1));
12794     if (!AndC)
12795       return SDValue();
12796 
12797     const APInt &Mask = AndC->getAPIntValue();
12798     unsigned ActiveBits = 0;
12799     if (Mask.isMask()) {
12800       ActiveBits = Mask.countTrailingOnes();
12801     } else if (Mask.isShiftedMask(ShAmt, ActiveBits)) {
12802       HasShiftedOffset = true;
12803     } else {
12804       return SDValue();
12805     }
12806 
12807     ExtType = ISD::ZEXTLOAD;
12808     ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
12809   }
12810 
12811   // In case Opc==SRL we've already prepared ExtVT/ExtType/ShAmt based on doing
12812   // a right shift. Here we redo some of those checks, to possibly adjust the
12813   // ExtVT even further based on "a masking AND". We could also end up here for
12814   // other reasons (e.g. based on Opc==TRUNCATE) and that is why some checks
12815   // need to be done here as well.
12816   if (Opc == ISD::SRL || N0.getOpcode() == ISD::SRL) {
12817     SDValue SRL = Opc == ISD::SRL ? SDValue(N, 0) : N0;
12818     // Bail out when the SRL has more than one use. This is done for historical
12819     // (undocumented) reasons. Maybe intent was to guard the AND-masking below
12820     // check below? And maybe it could be non-profitable to do the transform in
12821     // case the SRL has multiple uses and we get here with Opc!=ISD::SRL?
12822     // FIXME: Can't we just skip this check for the Opc==ISD::SRL case.
12823     if (!SRL.hasOneUse())
12824       return SDValue();
12825 
12826     // Only handle shift with constant shift amount, and the shiftee must be a
12827     // load.
12828     auto *LN = dyn_cast<LoadSDNode>(SRL.getOperand(0));
12829     auto *SRL1C = dyn_cast<ConstantSDNode>(SRL.getOperand(1));
12830     if (!SRL1C || !LN)
12831       return SDValue();
12832 
12833     // If the shift amount is larger than the input type then we're not
12834     // accessing any of the loaded bytes.  If the load was a zextload/extload
12835     // then the result of the shift+trunc is zero/undef (handled elsewhere).
12836     ShAmt = SRL1C->getZExtValue();
12837     uint64_t MemoryWidth = LN->getMemoryVT().getSizeInBits();
12838     if (ShAmt >= MemoryWidth)
12839       return SDValue();
12840 
12841     // Because a SRL must be assumed to *need* to zero-extend the high bits
12842     // (as opposed to anyext the high bits), we can't combine the zextload
12843     // lowering of SRL and an sextload.
12844     if (LN->getExtensionType() == ISD::SEXTLOAD)
12845       return SDValue();
12846 
12847     // Avoid reading outside the memory accessed by the original load (could
12848     // happened if we only adjust the load base pointer by ShAmt). Instead we
12849     // try to narrow the load even further. The typical scenario here is:
12850     //   (i64 (truncate (i96 (srl (load x), 64)))) ->
12851     //     (i64 (truncate (i96 (zextload (load i32 + offset) from i32))))
12852     if (ExtVT.getScalarSizeInBits() > MemoryWidth - ShAmt) {
12853       // Don't replace sextload by zextload.
12854       if (ExtType == ISD::SEXTLOAD)
12855         return SDValue();
12856       // Narrow the load.
12857       ExtType = ISD::ZEXTLOAD;
12858       ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShAmt);
12859     }
12860 
12861     // If the SRL is only used by a masking AND, we may be able to adjust
12862     // the ExtVT to make the AND redundant.
12863     SDNode *Mask = *(SRL->use_begin());
12864     if (SRL.hasOneUse() && Mask->getOpcode() == ISD::AND &&
12865         isa<ConstantSDNode>(Mask->getOperand(1))) {
12866       const APInt& ShiftMask = Mask->getConstantOperandAPInt(1);
12867       if (ShiftMask.isMask()) {
12868         EVT MaskedVT = EVT::getIntegerVT(*DAG.getContext(),
12869                                          ShiftMask.countTrailingOnes());
12870         // If the mask is smaller, recompute the type.
12871         if ((ExtVT.getScalarSizeInBits() > MaskedVT.getScalarSizeInBits()) &&
12872             TLI.isLoadExtLegal(ExtType, SRL.getValueType(), MaskedVT))
12873           ExtVT = MaskedVT;
12874       }
12875     }
12876 
12877     N0 = SRL.getOperand(0);
12878   }
12879 
12880   // If the load is shifted left (and the result isn't shifted back right), we
12881   // can fold a truncate through the shift. The typical scenario is that N
12882   // points at a TRUNCATE here so the attempted fold is:
12883   //   (truncate (shl (load x), c))) -> (shl (narrow load x), c)
12884   // ShLeftAmt will indicate how much a narrowed load should be shifted left.
12885   unsigned ShLeftAmt = 0;
12886   if (ShAmt == 0 && N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
12887       ExtVT == VT && TLI.isNarrowingProfitable(N0.getValueType(), VT)) {
12888     if (ConstantSDNode *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
12889       ShLeftAmt = N01->getZExtValue();
12890       N0 = N0.getOperand(0);
12891     }
12892   }
12893 
12894   // If we haven't found a load, we can't narrow it.
12895   if (!isa<LoadSDNode>(N0))
12896     return SDValue();
12897 
12898   LoadSDNode *LN0 = cast<LoadSDNode>(N0);
12899   // Reducing the width of a volatile load is illegal.  For atomics, we may be
12900   // able to reduce the width provided we never widen again. (see D66309)
12901   if (!LN0->isSimple() ||
12902       !isLegalNarrowLdSt(LN0, ExtType, ExtVT, ShAmt))
12903     return SDValue();
12904 
12905   auto AdjustBigEndianShift = [&](unsigned ShAmt) {
12906     unsigned LVTStoreBits =
12907         LN0->getMemoryVT().getStoreSizeInBits().getFixedSize();
12908     unsigned EVTStoreBits = ExtVT.getStoreSizeInBits().getFixedSize();
12909     return LVTStoreBits - EVTStoreBits - ShAmt;
12910   };
12911 
12912   // We need to adjust the pointer to the load by ShAmt bits in order to load
12913   // the correct bytes.
12914   unsigned PtrAdjustmentInBits =
12915       DAG.getDataLayout().isBigEndian() ? AdjustBigEndianShift(ShAmt) : ShAmt;
12916 
12917   uint64_t PtrOff = PtrAdjustmentInBits / 8;
12918   Align NewAlign = commonAlignment(LN0->getAlign(), PtrOff);
12919   SDLoc DL(LN0);
12920   // The original load itself didn't wrap, so an offset within it doesn't.
12921   SDNodeFlags Flags;
12922   Flags.setNoUnsignedWrap(true);
12923   SDValue NewPtr = DAG.getMemBasePlusOffset(LN0->getBasePtr(),
12924                                             TypeSize::Fixed(PtrOff), DL, Flags);
12925   AddToWorklist(NewPtr.getNode());
12926 
12927   SDValue Load;
12928   if (ExtType == ISD::NON_EXTLOAD)
12929     Load = DAG.getLoad(VT, DL, LN0->getChain(), NewPtr,
12930                        LN0->getPointerInfo().getWithOffset(PtrOff), NewAlign,
12931                        LN0->getMemOperand()->getFlags(), LN0->getAAInfo());
12932   else
12933     Load = DAG.getExtLoad(ExtType, DL, VT, LN0->getChain(), NewPtr,
12934                           LN0->getPointerInfo().getWithOffset(PtrOff), ExtVT,
12935                           NewAlign, LN0->getMemOperand()->getFlags(),
12936                           LN0->getAAInfo());
12937 
12938   // Replace the old load's chain with the new load's chain.
12939   WorklistRemover DeadNodes(*this);
12940   DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));
12941 
12942   // Shift the result left, if we've swallowed a left shift.
12943   SDValue Result = Load;
12944   if (ShLeftAmt != 0) {
12945     EVT ShImmTy = getShiftAmountTy(Result.getValueType());
12946     if (!isUIntN(ShImmTy.getScalarSizeInBits(), ShLeftAmt))
12947       ShImmTy = VT;
12948     // If the shift amount is as large as the result size (but, presumably,
12949     // no larger than the source) then the useful bits of the result are
12950     // zero; we can't simply return the shortened shift, because the result
12951     // of that operation is undefined.
12952     if (ShLeftAmt >= VT.getScalarSizeInBits())
12953       Result = DAG.getConstant(0, DL, VT);
12954     else
12955       Result = DAG.getNode(ISD::SHL, DL, VT,
12956                           Result, DAG.getConstant(ShLeftAmt, DL, ShImmTy));
12957   }
12958 
12959   if (HasShiftedOffset) {
12960     // We're using a shifted mask, so the load now has an offset. This means
12961     // that data has been loaded into the lower bytes than it would have been
12962     // before, so we need to shl the loaded data into the correct position in the
12963     // register.
12964     SDValue ShiftC = DAG.getConstant(ShAmt, DL, VT);
12965     Result = DAG.getNode(ISD::SHL, DL, VT, Result, ShiftC);
12966     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result);
12967   }
12968 
12969   // Return the new loaded value.
12970   return Result;
12971 }
12972 
visitSIGN_EXTEND_INREG(SDNode * N)12973 SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
12974   SDValue N0 = N->getOperand(0);
12975   SDValue N1 = N->getOperand(1);
12976   EVT VT = N->getValueType(0);
12977   EVT ExtVT = cast<VTSDNode>(N1)->getVT();
12978   unsigned VTBits = VT.getScalarSizeInBits();
12979   unsigned ExtVTBits = ExtVT.getScalarSizeInBits();
12980 
12981   // sext_vector_inreg(undef) = 0 because the top bit will all be the same.
12982   if (N0.isUndef())
12983     return DAG.getConstant(0, SDLoc(N), VT);
12984 
12985   // fold (sext_in_reg c1) -> c1
12986   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
12987     return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, N0, N1);
12988 
12989   // If the input is already sign extended, just drop the extension.
12990   if (ExtVTBits >= DAG.ComputeMaxSignificantBits(N0))
12991     return N0;
12992 
12993   // fold (sext_in_reg (sext_in_reg x, VT2), VT1) -> (sext_in_reg x, minVT) pt2
12994   if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
12995       ExtVT.bitsLT(cast<VTSDNode>(N0.getOperand(1))->getVT()))
12996     return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, N0.getOperand(0),
12997                        N1);
12998 
12999   // fold (sext_in_reg (sext x)) -> (sext x)
13000   // fold (sext_in_reg (aext x)) -> (sext x)
13001   // if x is small enough or if we know that x has more than 1 sign bit and the
13002   // sign_extend_inreg is extending from one of them.
13003   if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) {
13004     SDValue N00 = N0.getOperand(0);
13005     unsigned N00Bits = N00.getScalarValueSizeInBits();
13006     if ((N00Bits <= ExtVTBits ||
13007          DAG.ComputeMaxSignificantBits(N00) <= ExtVTBits) &&
13008         (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT)))
13009       return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00);
13010   }
13011 
13012   // fold (sext_in_reg (*_extend_vector_inreg x)) -> (sext_vector_inreg x)
13013   // if x is small enough or if we know that x has more than 1 sign bit and the
13014   // sign_extend_inreg is extending from one of them.
13015   if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
13016       N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG ||
13017       N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG) {
13018     SDValue N00 = N0.getOperand(0);
13019     unsigned N00Bits = N00.getScalarValueSizeInBits();
13020     unsigned DstElts = N0.getValueType().getVectorMinNumElements();
13021     unsigned SrcElts = N00.getValueType().getVectorMinNumElements();
13022     bool IsZext = N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG;
13023     APInt DemandedSrcElts = APInt::getLowBitsSet(SrcElts, DstElts);
13024     if ((N00Bits == ExtVTBits ||
13025          (!IsZext && (N00Bits < ExtVTBits ||
13026                       DAG.ComputeMaxSignificantBits(N00) <= ExtVTBits))) &&
13027         (!LegalOperations ||
13028          TLI.isOperationLegal(ISD::SIGN_EXTEND_VECTOR_INREG, VT)))
13029       return DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, SDLoc(N), VT, N00);
13030   }
13031 
13032   // fold (sext_in_reg (zext x)) -> (sext x)
13033   // iff we are extending the source sign bit.
13034   if (N0.getOpcode() == ISD::ZERO_EXTEND) {
13035     SDValue N00 = N0.getOperand(0);
13036     if (N00.getScalarValueSizeInBits() == ExtVTBits &&
13037         (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT)))
13038       return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00, N1);
13039   }
13040 
13041   // fold (sext_in_reg x) -> (zext_in_reg x) if the sign bit is known zero.
13042   if (DAG.MaskedValueIsZero(N0, APInt::getOneBitSet(VTBits, ExtVTBits - 1)))
13043     return DAG.getZeroExtendInReg(N0, SDLoc(N), ExtVT);
13044 
13045   // fold operands of sext_in_reg based on knowledge that the top bits are not
13046   // demanded.
13047   if (SimplifyDemandedBits(SDValue(N, 0)))
13048     return SDValue(N, 0);
13049 
13050   // fold (sext_in_reg (load x)) -> (smaller sextload x)
13051   // fold (sext_in_reg (srl (load x), c)) -> (smaller sextload (x+c/evtbits))
13052   if (SDValue NarrowLoad = reduceLoadWidth(N))
13053     return NarrowLoad;
13054 
13055   // fold (sext_in_reg (srl X, 24), i8) -> (sra X, 24)
13056   // fold (sext_in_reg (srl X, 23), i8) -> (sra X, 23) iff possible.
13057   // We already fold "(sext_in_reg (srl X, 25), i8) -> srl X, 25" above.
13058   if (N0.getOpcode() == ISD::SRL) {
13059     if (auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1)))
13060       if (ShAmt->getAPIntValue().ule(VTBits - ExtVTBits)) {
13061         // We can turn this into an SRA iff the input to the SRL is already sign
13062         // extended enough.
13063         unsigned InSignBits = DAG.ComputeNumSignBits(N0.getOperand(0));
13064         if (((VTBits - ExtVTBits) - ShAmt->getZExtValue()) < InSignBits)
13065           return DAG.getNode(ISD::SRA, SDLoc(N), VT, N0.getOperand(0),
13066                              N0.getOperand(1));
13067       }
13068   }
13069 
13070   // fold (sext_inreg (extload x)) -> (sextload x)
13071   // If sextload is not supported by target, we can only do the combine when
13072   // load has one use. Doing otherwise can block folding the extload with other
13073   // extends that the target does support.
13074   if (ISD::isEXTLoad(N0.getNode()) &&
13075       ISD::isUNINDEXEDLoad(N0.getNode()) &&
13076       ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
13077       ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple() &&
13078         N0.hasOneUse()) ||
13079        TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT))) {
13080     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
13081     SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(N), VT,
13082                                      LN0->getChain(),
13083                                      LN0->getBasePtr(), ExtVT,
13084                                      LN0->getMemOperand());
13085     CombineTo(N, ExtLoad);
13086     CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
13087     AddToWorklist(ExtLoad.getNode());
13088     return SDValue(N, 0);   // Return N so it doesn't get rechecked!
13089   }
13090 
13091   // fold (sext_inreg (zextload x)) -> (sextload x) iff load has one use
13092   if (ISD::isZEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) &&
13093       N0.hasOneUse() &&
13094       ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
13095       ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple()) &&
13096        TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT))) {
13097     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
13098     SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(N), VT,
13099                                      LN0->getChain(),
13100                                      LN0->getBasePtr(), ExtVT,
13101                                      LN0->getMemOperand());
13102     CombineTo(N, ExtLoad);
13103     CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
13104     return SDValue(N, 0);   // Return N so it doesn't get rechecked!
13105   }
13106 
13107   // fold (sext_inreg (masked_load x)) -> (sext_masked_load x)
13108   // ignore it if the masked load is already sign extended
13109   if (MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(N0)) {
13110     if (ExtVT == Ld->getMemoryVT() && N0.hasOneUse() &&
13111         Ld->getExtensionType() != ISD::LoadExtType::NON_EXTLOAD &&
13112         TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT)) {
13113       SDValue ExtMaskedLoad = DAG.getMaskedLoad(
13114           VT, SDLoc(N), Ld->getChain(), Ld->getBasePtr(), Ld->getOffset(),
13115           Ld->getMask(), Ld->getPassThru(), ExtVT, Ld->getMemOperand(),
13116           Ld->getAddressingMode(), ISD::SEXTLOAD, Ld->isExpandingLoad());
13117       CombineTo(N, ExtMaskedLoad);
13118       CombineTo(N0.getNode(), ExtMaskedLoad, ExtMaskedLoad.getValue(1));
13119       return SDValue(N, 0); // Return N so it doesn't get rechecked!
13120     }
13121   }
13122 
13123   // fold (sext_inreg (masked_gather x)) -> (sext_masked_gather x)
13124   if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(N0)) {
13125     if (SDValue(GN0, 0).hasOneUse() &&
13126         ExtVT == GN0->getMemoryVT() &&
13127         TLI.isVectorLoadExtDesirable(SDValue(SDValue(GN0, 0)))) {
13128       SDValue Ops[] = {GN0->getChain(),   GN0->getPassThru(), GN0->getMask(),
13129                        GN0->getBasePtr(), GN0->getIndex(),    GN0->getScale()};
13130 
13131       SDValue ExtLoad = DAG.getMaskedGather(
13132           DAG.getVTList(VT, MVT::Other), ExtVT, SDLoc(N), Ops,
13133           GN0->getMemOperand(), GN0->getIndexType(), ISD::SEXTLOAD);
13134 
13135       CombineTo(N, ExtLoad);
13136       CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
13137       AddToWorklist(ExtLoad.getNode());
13138       return SDValue(N, 0); // Return N so it doesn't get rechecked!
13139     }
13140   }
13141 
13142   // Form (sext_inreg (bswap >> 16)) or (sext_inreg (rotl (bswap) 16))
13143   if (ExtVTBits <= 16 && N0.getOpcode() == ISD::OR) {
13144     if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
13145                                            N0.getOperand(1), false))
13146       return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, BSwap, N1);
13147   }
13148 
13149   return SDValue();
13150 }
13151 
visitEXTEND_VECTOR_INREG(SDNode * N)13152 SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
13153   SDValue N0 = N->getOperand(0);
13154   EVT VT = N->getValueType(0);
13155 
13156   // {s/z}ext_vector_inreg(undef) = 0 because the top bits must be the same.
13157   if (N0.isUndef())
13158     return DAG.getConstant(0, SDLoc(N), VT);
13159 
13160   if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
13161     return Res;
13162 
13163   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
13164     return SDValue(N, 0);
13165 
13166   return SDValue();
13167 }
13168 
// Combine an ISD::TRUNCATE node. Tries, in order: trivial/constant folds,
// cancelling against extends, narrowing vector extracts and loads, and
// pushing the truncate through selects, shifts, build_vectors and binary
// ops. Returns the replacement value or an empty SDValue.
SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  EVT SrcVT = N0.getValueType();
  bool isLE = DAG.getDataLayout().isLittleEndian();

  // noop truncate
  if (SrcVT == VT)
    return N0;

  // fold (truncate (truncate x)) -> (truncate x)
  if (N0.getOpcode() == ISD::TRUNCATE)
    return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0.getOperand(0));

  // fold (truncate c1) -> c1
  // getNode constant-folds; only use the result if it actually produced a
  // different node than N.
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) {
    SDValue C = DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0);
    if (C.getNode() != N)
      return C;
  }

  // fold (truncate (ext x)) -> (ext x) or (truncate x) or x
  if (N0.getOpcode() == ISD::ZERO_EXTEND ||
      N0.getOpcode() == ISD::SIGN_EXTEND ||
      N0.getOpcode() == ISD::ANY_EXTEND) {
    // if the source is smaller than the dest, we still need an extend.
    if (N0.getOperand(0).getValueType().bitsLT(VT))
      return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, N0.getOperand(0));
    // if the source is larger than the dest, than we just need the truncate.
    if (N0.getOperand(0).getValueType().bitsGT(VT))
      return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0.getOperand(0));
    // if the source and dest are the same type, we can drop both the extend
    // and the truncate.
    return N0.getOperand(0);
  }

  // Try to narrow a truncate-of-sext_in_reg to the destination type:
  // trunc (sign_ext_inreg X, iM) to iN --> sign_ext_inreg (trunc X to iN), iM
  if (!LegalTypes && N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
      N0.hasOneUse()) {
    SDValue X = N0.getOperand(0);
    SDValue ExtVal = N0.getOperand(1);
    EVT ExtVT = cast<VTSDNode>(ExtVal)->getVT();
    if (ExtVT.bitsLT(VT)) {
      SDValue TrX = DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, X);
      return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, TrX, ExtVal);
    }
  }

  // If this is anyext(trunc), don't fold it, allow ourselves to be folded.
  if (N->hasOneUse() && (N->use_begin()->getOpcode() == ISD::ANY_EXTEND))
    return SDValue();

  // Fold extract-and-trunc into a narrow extract. For example:
  //   i64 x = EXTRACT_VECTOR_ELT(v2i64 val, i32 1)
  //   i32 y = TRUNCATE(i64 x)
  //        -- becomes --
  //   v16i8 b = BITCAST (v2i64 val)
  //   i8 x = EXTRACT_VECTOR_ELT(v16i8 b, i32 8)
  //
  // Note: We only run this optimization after type legalization (which often
  // creates this pattern) and before operation legalization after which
  // we need to be more careful about the vector instructions that we generate.
  if (N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
      LegalTypes && !LegalOperations && N0->hasOneUse() && VT != MVT::i1) {
    EVT VecTy = N0.getOperand(0).getValueType();
    EVT ExTy = N0.getValueType();
    EVT TrTy = N->getValueType(0);

    auto EltCnt = VecTy.getVectorElementCount();
    unsigned SizeRatio = ExTy.getSizeInBits()/TrTy.getSizeInBits();
    auto NewEltCnt = EltCnt * SizeRatio;

    EVT NVT = EVT::getVectorVT(*DAG.getContext(), TrTy, NewEltCnt);
    assert(NVT.getSizeInBits() == VecTy.getSizeInBits() && "Invalid Size");

    SDValue EltNo = N0->getOperand(1);
    if (isa<ConstantSDNode>(EltNo) && isTypeLegal(NVT)) {
      int Elt = cast<ConstantSDNode>(EltNo)->getZExtValue();
      // Big-endian targets keep the wanted low bits in the last of the
      // SizeRatio narrow elements.
      int Index = isLE ? (Elt*SizeRatio) : (Elt*SizeRatio + (SizeRatio-1));

      SDLoc DL(N);
      return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, TrTy,
                         DAG.getBitcast(NVT, N0.getOperand(0)),
                         DAG.getVectorIdxConstant(Index, DL));
    }
  }

  // trunc (select c, a, b) -> select c, (trunc a), (trunc b)
  if (N0.getOpcode() == ISD::SELECT && N0.hasOneUse()) {
    if ((!LegalOperations || TLI.isOperationLegal(ISD::SELECT, SrcVT)) &&
        TLI.isTruncateFree(SrcVT, VT)) {
      SDLoc SL(N0);
      SDValue Cond = N0.getOperand(0);
      SDValue TruncOp0 = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(1));
      SDValue TruncOp1 = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(2));
      return DAG.getNode(ISD::SELECT, SDLoc(N), VT, Cond, TruncOp0, TruncOp1);
    }
  }

  // trunc (shl x, K) -> shl (trunc x), K => K < VT.getScalarSizeInBits()
  if (N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
      (!LegalOperations || TLI.isOperationLegal(ISD::SHL, VT)) &&
      TLI.isTypeDesirableForOp(ISD::SHL, VT)) {
    SDValue Amt = N0.getOperand(1);
    KnownBits Known = DAG.computeKnownBits(Amt);
    unsigned Size = VT.getScalarSizeInBits();
    // Only safe if the shift amount provably fits within the narrow type.
    if (Known.countMaxActiveBits() <= Log2_32(Size)) {
      SDLoc SL(N);
      EVT AmtVT = TLI.getShiftAmountTy(VT, DAG.getDataLayout());

      SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(0));
      if (AmtVT != Amt.getValueType()) {
        Amt = DAG.getZExtOrTrunc(Amt, SL, AmtVT);
        AddToWorklist(Amt.getNode());
      }
      return DAG.getNode(ISD::SHL, SL, VT, Trunc, Amt);
    }
  }

  if (SDValue V = foldSubToUSubSat(VT, N0.getNode()))
    return V;

  // Attempt to pre-truncate BUILD_VECTOR sources.
  if (N0.getOpcode() == ISD::BUILD_VECTOR && !LegalOperations &&
      TLI.isTruncateFree(SrcVT.getScalarType(), VT.getScalarType()) &&
      // Avoid creating illegal types if running after type legalizer.
      (!LegalTypes || TLI.isTypeLegal(VT.getScalarType()))) {
    SDLoc DL(N);
    EVT SVT = VT.getScalarType();
    SmallVector<SDValue, 8> TruncOps;
    for (const SDValue &Op : N0->op_values()) {
      SDValue TruncOp = DAG.getNode(ISD::TRUNCATE, DL, SVT, Op);
      TruncOps.push_back(TruncOp);
    }
    return DAG.getBuildVector(VT, DL, TruncOps);
  }

  // Fold a series of buildvector, bitcast, and truncate if possible.
  // For example fold
  //   (2xi32 trunc (bitcast ((4xi32)buildvector x, x, y, y) 2xi64)) to
  //   (2xi32 (buildvector x, y)).
  if (Level == AfterLegalizeVectorOps && VT.isVector() &&
      N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
      N0.getOperand(0).getOpcode() == ISD::BUILD_VECTOR &&
      N0.getOperand(0).hasOneUse()) {
    SDValue BuildVect = N0.getOperand(0);
    EVT BuildVectEltTy = BuildVect.getValueType().getVectorElementType();
    EVT TruncVecEltTy = VT.getVectorElementType();

    // Check that the element types match.
    if (BuildVectEltTy == TruncVecEltTy) {
      // Now we only need to compute the offset of the truncated elements.
      unsigned BuildVecNumElts =  BuildVect.getNumOperands();
      unsigned TruncVecNumElts = VT.getVectorNumElements();
      unsigned TruncEltOffset = BuildVecNumElts / TruncVecNumElts;

      assert((BuildVecNumElts % TruncVecNumElts) == 0 &&
             "Invalid number of elements");

      SmallVector<SDValue, 8> Opnds;
      for (unsigned i = 0, e = BuildVecNumElts; i != e; i += TruncEltOffset)
        Opnds.push_back(BuildVect.getOperand(i));

      return DAG.getBuildVector(VT, SDLoc(N), Opnds);
    }
  }

  // fold (truncate (load x)) -> (smaller load x)
  // fold (truncate (srl (load x), c)) -> (smaller load (x+c/evtbits))
  if (!LegalTypes || TLI.isTypeDesirableForOp(N0.getOpcode(), VT)) {
    if (SDValue Reduced = reduceLoadWidth(N))
      return Reduced;

    // Handle the case where the load remains an extending load even
    // after truncation.
    if (N0.hasOneUse() && ISD::isUNINDEXEDLoad(N0.getNode())) {
      LoadSDNode *LN0 = cast<LoadSDNode>(N0);
      if (LN0->isSimple() && LN0->getMemoryVT().bitsLT(VT)) {
        SDValue NewLoad = DAG.getExtLoad(LN0->getExtensionType(), SDLoc(LN0),
                                         VT, LN0->getChain(), LN0->getBasePtr(),
                                         LN0->getMemoryVT(),
                                         LN0->getMemOperand());
        DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), NewLoad.getValue(1));
        return NewLoad;
      }
    }
  }

  // fold (trunc (concat ... x ...)) -> (concat ..., (trunc x), ...)),
  // where ... are all 'undef'.
  if (N0.getOpcode() == ISD::CONCAT_VECTORS && !LegalTypes) {
    SmallVector<EVT, 8> VTs;
    SDValue V;
    unsigned Idx = 0;
    unsigned NumDefs = 0;

    for (unsigned i = 0, e = N0.getNumOperands(); i != e; ++i) {
      SDValue X = N0.getOperand(i);
      if (!X.isUndef()) {
        V = X;
        Idx = i;
        NumDefs++;
      }
      // Stop if more than one members are non-undef.
      if (NumDefs > 1)
        break;

      VTs.push_back(EVT::getVectorVT(*DAG.getContext(),
                                     VT.getVectorElementType(),
                                     X.getValueType().getVectorElementCount()));
    }

    if (NumDefs == 0)
      return DAG.getUNDEF(VT);

    if (NumDefs == 1) {
      assert(V.getNode() && "The single defined operand is empty!");
      SmallVector<SDValue, 8> Opnds;
      for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
        if (i != Idx) {
          Opnds.push_back(DAG.getUNDEF(VTs[i]));
          continue;
        }
        SDValue NV = DAG.getNode(ISD::TRUNCATE, SDLoc(V), VTs[i], V);
        AddToWorklist(NV.getNode());
        Opnds.push_back(NV);
      }
      return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Opnds);
    }
  }

  // Fold truncate of a bitcast of a vector to an extract of the low vector
  // element.
  //
  // e.g. trunc (i64 (bitcast v2i32:x)) -> extract_vector_elt v2i32:x, idx
  if (N0.getOpcode() == ISD::BITCAST && !VT.isVector()) {
    SDValue VecSrc = N0.getOperand(0);
    EVT VecSrcVT = VecSrc.getValueType();
    if (VecSrcVT.isVector() && VecSrcVT.getScalarType() == VT &&
        (!LegalOperations ||
         TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VecSrcVT))) {
      SDLoc SL(N);

      // The low-order element is element 0 on little-endian and the last
      // element on big-endian.
      unsigned Idx = isLE ? 0 : VecSrcVT.getVectorNumElements() - 1;
      return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, VT, VecSrc,
                         DAG.getVectorIdxConstant(Idx, SL));
    }
  }

  // Simplify the operands using demanded-bits information.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  // See if we can simplify the input to this truncate through knowledge that
  // only the low bits are being used.
  // For example "trunc (or (shl x, 8), y)" // -> trunc y
  // Currently we only perform this optimization on scalars because vectors
  // may have different active low bits.
  if (!VT.isVector()) {
    APInt Mask =
        APInt::getLowBitsSet(N0.getValueSizeInBits(), VT.getSizeInBits());
    if (SDValue Shorter = DAG.GetDemandedBits(N0, Mask))
      return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Shorter);
  }

  // fold (truncate (extract_subvector(ext x))) ->
  //      (extract_subvector x)
  // TODO: This can be generalized to cover cases where the truncate and extract
  // do not fully cancel each other out.
  if (!LegalTypes && N0.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
    SDValue N00 = N0.getOperand(0);
    if (N00.getOpcode() == ISD::SIGN_EXTEND ||
        N00.getOpcode() == ISD::ZERO_EXTEND ||
        N00.getOpcode() == ISD::ANY_EXTEND) {
      if (N00.getOperand(0)->getValueType(0).getVectorElementType() ==
          VT.getVectorElementType())
        return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N0->getOperand(0)), VT,
                           N00.getOperand(0), N0.getOperand(1));
    }
  }

  if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
    return NewVSel;

  // Narrow a suitable binary operation with a non-opaque constant operand by
  // moving it ahead of the truncate. This is limited to pre-legalization
  // because targets may prefer a wider type during later combines and invert
  // this transform.
  switch (N0.getOpcode()) {
  case ISD::ADD:
  case ISD::SUB:
  case ISD::MUL:
  case ISD::AND:
  case ISD::OR:
  case ISD::XOR:
    if (!LegalOperations && N0.hasOneUse() &&
        (isConstantOrConstantVector(N0.getOperand(0), true) ||
         isConstantOrConstantVector(N0.getOperand(1), true))) {
      // TODO: We already restricted this to pre-legalization, but for vectors
      // we are extra cautious to not create an unsupported operation.
      // Target-specific changes are likely needed to avoid regressions here.
      if (VT.isScalarInteger() || TLI.isOperationLegal(N0.getOpcode(), VT)) {
        SDLoc DL(N);
        SDValue NarrowL = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
        SDValue NarrowR = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(1));
        return DAG.getNode(N0.getOpcode(), DL, VT, NarrowL, NarrowR);
      }
    }
    break;
  case ISD::ADDE:
  case ISD::ADDCARRY:
    // (trunc adde(X, Y, Carry)) -> (adde trunc(X), trunc(Y), Carry)
    // (trunc addcarry(X, Y, Carry)) -> (addcarry trunc(X), trunc(Y), Carry)
    // When the adde's carry is not used.
    // We only do for addcarry before legalize operation
    if (((!LegalOperations && N0.getOpcode() == ISD::ADDCARRY) ||
         TLI.isOperationLegal(N0.getOpcode(), VT)) &&
        N0.hasOneUse() && !N0->hasAnyUseOfValue(1)) {
      SDLoc DL(N);
      SDValue X = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
      SDValue Y = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(1));
      SDVTList VTs = DAG.getVTList(VT, N0->getValueType(1));
      return DAG.getNode(N0.getOpcode(), DL, VTs, X, Y, N0.getOperand(2));
    }
    break;
  case ISD::USUBSAT:
    // Truncate the USUBSAT only if LHS is a known zero-extension, its not
    // enough to know that the upper bits are zero we must ensure that we don't
    // introduce an extra truncate.
    if (!LegalOperations && N0.hasOneUse() &&
        N0.getOperand(0).getOpcode() == ISD::ZERO_EXTEND &&
        N0.getOperand(0).getOperand(0).getScalarValueSizeInBits() <=
            VT.getScalarSizeInBits() &&
        hasOperation(N0.getOpcode(), VT)) {
      return getTruncatedUSUBSAT(VT, SrcVT, N0.getOperand(0), N0.getOperand(1),
                                 DAG, SDLoc(N));
    }
    break;
  }

  return SDValue();
}
13512 
getBuildPairElt(SDNode * N,unsigned i)13513 static SDNode *getBuildPairElt(SDNode *N, unsigned i) {
13514   SDValue Elt = N->getOperand(i);
13515   if (Elt.getOpcode() != ISD::MERGE_VALUES)
13516     return Elt.getNode();
13517   return Elt.getOperand(Elt.getResNo()).getNode();
13518 }
13519 
13520 /// build_pair (load, load) -> load
13521 /// if load locations are consecutive.
CombineConsecutiveLoads(SDNode * N,EVT VT)13522 SDValue DAGCombiner::CombineConsecutiveLoads(SDNode *N, EVT VT) {
13523   assert(N->getOpcode() == ISD::BUILD_PAIR);
13524 
13525   auto *LD1 = dyn_cast<LoadSDNode>(getBuildPairElt(N, 0));
13526   auto *LD2 = dyn_cast<LoadSDNode>(getBuildPairElt(N, 1));
13527 
13528   // A BUILD_PAIR is always having the least significant part in elt 0 and the
13529   // most significant part in elt 1. So when combining into one large load, we
13530   // need to consider the endianness.
13531   if (DAG.getDataLayout().isBigEndian())
13532     std::swap(LD1, LD2);
13533 
13534   if (!LD1 || !LD2 || !ISD::isNON_EXTLoad(LD1) || !ISD::isNON_EXTLoad(LD2) ||
13535       !LD1->hasOneUse() || !LD2->hasOneUse() ||
13536       LD1->getAddressSpace() != LD2->getAddressSpace())
13537     return SDValue();
13538 
13539   bool LD1Fast = false;
13540   EVT LD1VT = LD1->getValueType(0);
13541   unsigned LD1Bytes = LD1VT.getStoreSize();
13542   if ((!LegalOperations || TLI.isOperationLegal(ISD::LOAD, VT)) &&
13543       DAG.areNonVolatileConsecutiveLoads(LD2, LD1, LD1Bytes, 1) &&
13544       TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
13545                              *LD1->getMemOperand(), &LD1Fast) && LD1Fast)
13546     return DAG.getLoad(VT, SDLoc(N), LD1->getChain(), LD1->getBasePtr(),
13547                        LD1->getPointerInfo(), LD1->getAlign());
13548 
13549   return SDValue();
13550 }
13551 
getPPCf128HiElementSelector(const SelectionDAG & DAG)13552 static unsigned getPPCf128HiElementSelector(const SelectionDAG &DAG) {
13553   // On little-endian machines, bitcasting from ppcf128 to i128 does swap the Hi
13554   // and Lo parts; on big-endian machines it doesn't.
13555   return DAG.getDataLayout().isBigEndian() ? 1 : 0;
13556 }
13557 
foldBitcastedFPLogic(SDNode * N,SelectionDAG & DAG,const TargetLowering & TLI)13558 static SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
13559                                     const TargetLowering &TLI) {
13560   // If this is not a bitcast to an FP type or if the target doesn't have
13561   // IEEE754-compliant FP logic, we're done.
13562   EVT VT = N->getValueType(0);
13563   if (!VT.isFloatingPoint() || !TLI.hasBitPreservingFPLogic(VT))
13564     return SDValue();
13565 
13566   // TODO: Handle cases where the integer constant is a different scalar
13567   // bitwidth to the FP.
13568   SDValue N0 = N->getOperand(0);
13569   EVT SourceVT = N0.getValueType();
13570   if (VT.getScalarSizeInBits() != SourceVT.getScalarSizeInBits())
13571     return SDValue();
13572 
13573   unsigned FPOpcode;
13574   APInt SignMask;
13575   switch (N0.getOpcode()) {
13576   case ISD::AND:
13577     FPOpcode = ISD::FABS;
13578     SignMask = ~APInt::getSignMask(SourceVT.getScalarSizeInBits());
13579     break;
13580   case ISD::XOR:
13581     FPOpcode = ISD::FNEG;
13582     SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits());
13583     break;
13584   case ISD::OR:
13585     FPOpcode = ISD::FABS;
13586     SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits());
13587     break;
13588   default:
13589     return SDValue();
13590   }
13591 
13592   // Fold (bitcast int (and (bitcast fp X to int), 0x7fff...) to fp) -> fabs X
13593   // Fold (bitcast int (xor (bitcast fp X to int), 0x8000...) to fp) -> fneg X
13594   // Fold (bitcast int (or (bitcast fp X to int), 0x8000...) to fp) ->
13595   //   fneg (fabs X)
13596   SDValue LogicOp0 = N0.getOperand(0);
13597   ConstantSDNode *LogicOp1 = isConstOrConstSplat(N0.getOperand(1), true);
13598   if (LogicOp1 && LogicOp1->getAPIntValue() == SignMask &&
13599       LogicOp0.getOpcode() == ISD::BITCAST &&
13600       LogicOp0.getOperand(0).getValueType() == VT) {
13601     SDValue FPOp = DAG.getNode(FPOpcode, SDLoc(N), VT, LogicOp0.getOperand(0));
13602     NumFPLogicOpsConv++;
13603     if (N0.getOpcode() == ISD::OR)
13604       return DAG.getNode(ISD::FNEG, SDLoc(N), VT, FPOp);
13605     return FPOp;
13606   }
13607 
13608   return SDValue();
13609 }
13610 
/// Combine an ISD::BITCAST node: fold casts of undef/constants, merge nested
/// bitcasts, rewrite bitcast-of-load into a load of the new type, turn
/// integer sign-bit logic into FP ops (and vice versa for fneg/fabs/
/// fcopysign, including the two-element ppc_fp128 representation), merge
/// BUILD_PAIRs of consecutive loads, and strip double bitcasts around
/// shuffles. Returns an empty SDValue when no fold applies.
SDValue DAGCombiner::visitBITCAST(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  // bitcast(undef) is still undef, just of the new type.
  if (N0.isUndef())
    return DAG.getUNDEF(VT);

  // If the input is a BUILD_VECTOR with all constant elements, fold this now.
  // Only do this before legalize types, unless both types are integer and the
  // scalar type is legal. Only do this before legalize ops, since the target
  // may be depending on the bitcast.
  // First check to see if this is all constant.
  // TODO: Support FP bitcasts after legalize types.
  if (VT.isVector() &&
      (!LegalTypes ||
       (!LegalOperations && VT.isInteger() && N0.getValueType().isInteger() &&
        TLI.isTypeLegal(VT.getVectorElementType()))) &&
      N0.getOpcode() == ISD::BUILD_VECTOR && N0->hasOneUse() &&
      cast<BuildVectorSDNode>(N0)->isConstant())
    return ConstantFoldBITCASTofBUILD_VECTOR(N0.getNode(),
                                             VT.getVectorElementType());

  // If the input is a constant, let getNode fold it.
  if (isIntOrFPConstant(N0)) {
    // If we can't allow illegal operations, we need to check that this is just
    // a fp -> int or int -> fp conversion and that the resulting operation will
    // be legal.
    if (!LegalOperations ||
        (isa<ConstantSDNode>(N0) && VT.isFloatingPoint() && !VT.isVector() &&
         TLI.isOperationLegal(ISD::ConstantFP, VT)) ||
        (isa<ConstantFPSDNode>(N0) && VT.isInteger() && !VT.isVector() &&
         TLI.isOperationLegal(ISD::Constant, VT))) {
      SDValue C = DAG.getBitcast(VT, N0);
      // Guard against getBitcast returning N itself, which would "combine"
      // the node to itself and make no progress.
      if (C.getNode() != N)
        return C;
    }
  }

  // (conv (conv x, t1), t2) -> (conv x, t2)
  if (N0.getOpcode() == ISD::BITCAST)
    return DAG.getBitcast(VT, N0.getOperand(0));

  // fold (conv (load x)) -> (load (conv*)x)
  // If the resultant load doesn't need a higher alignment than the original!
  if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
      // Do not remove the cast if the types differ in endian layout.
      TLI.hasBigEndianPartOrdering(N0.getValueType(), DAG.getDataLayout()) ==
          TLI.hasBigEndianPartOrdering(VT, DAG.getDataLayout()) &&
      // If the load is volatile, we only want to change the load type if the
      // resulting load is legal. Otherwise we might increase the number of
      // memory accesses. We don't care if the original type was legal or not
      // as we assume software couldn't rely on the number of accesses of an
      // illegal type.
      ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple()) ||
       TLI.isOperationLegal(ISD::LOAD, VT))) {
    LoadSDNode *LN0 = cast<LoadSDNode>(N0);

    if (TLI.isLoadBitCastBeneficial(N0.getValueType(), VT, DAG,
                                    *LN0->getMemOperand())) {
      SDValue Load =
          DAG.getLoad(VT, SDLoc(N), LN0->getChain(), LN0->getBasePtr(),
                      LN0->getPointerInfo(), LN0->getAlign(),
                      LN0->getMemOperand()->getFlags(), LN0->getAAInfo());
      // Redirect users of the old load's chain to the new load's chain so
      // the old load becomes dead.
      DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));
      return Load;
    }
  }

  if (SDValue V = foldBitcastedFPLogic(N, DAG, TLI))
    return V;

  // fold (bitconvert (fneg x)) -> (xor (bitconvert x), signbit)
  // fold (bitconvert (fabs x)) -> (and (bitconvert x), (not signbit))
  //
  // For ppc_fp128:
  // fold (bitcast (fneg x)) ->
  //     flipbit = signbit
  //     (xor (bitcast x) (build_pair flipbit, flipbit))
  //
  // fold (bitcast (fabs x)) ->
  //     flipbit = (and (extract_element (bitcast x), 0), signbit)
  //     (xor (bitcast x) (build_pair flipbit, flipbit))
  // This often reduces constant pool loads.
  if (((N0.getOpcode() == ISD::FNEG && !TLI.isFNegFree(N0.getValueType())) ||
       (N0.getOpcode() == ISD::FABS && !TLI.isFAbsFree(N0.getValueType()))) &&
      N0->hasOneUse() && VT.isInteger() && !VT.isVector() &&
      !N0.getValueType().isVector()) {
    SDValue NewConv = DAG.getBitcast(VT, N0.getOperand(0));
    AddToWorklist(NewConv.getNode());

    SDLoc DL(N);
    if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
      assert(VT.getSizeInBits() == 128);
      // Sign mask for one 64-bit half of the ppc_fp128 pair.
      SDValue SignBit = DAG.getConstant(
          APInt::getSignMask(VT.getSizeInBits() / 2), SDLoc(N0), MVT::i64);
      SDValue FlipBit;
      if (N0.getOpcode() == ISD::FNEG) {
        FlipBit = SignBit;
        AddToWorklist(FlipBit.getNode());
      } else {
        assert(N0.getOpcode() == ISD::FABS);
        // For fabs, only flip the sign bit when it is currently set: extract
        // the Hi half and AND with the sign mask to compute the flip bit.
        SDValue Hi =
            DAG.getNode(ISD::EXTRACT_ELEMENT, SDLoc(NewConv), MVT::i64, NewConv,
                        DAG.getIntPtrConstant(getPPCf128HiElementSelector(DAG),
                                              SDLoc(NewConv)));
        AddToWorklist(Hi.getNode());
        FlipBit = DAG.getNode(ISD::AND, SDLoc(N0), MVT::i64, Hi, SignBit);
        AddToWorklist(FlipBit.getNode());
      }
      SDValue FlipBits =
          DAG.getNode(ISD::BUILD_PAIR, SDLoc(N0), VT, FlipBit, FlipBit);
      AddToWorklist(FlipBits.getNode());
      return DAG.getNode(ISD::XOR, DL, VT, NewConv, FlipBits);
    }
    APInt SignBit = APInt::getSignMask(VT.getSizeInBits());
    if (N0.getOpcode() == ISD::FNEG)
      return DAG.getNode(ISD::XOR, DL, VT,
                         NewConv, DAG.getConstant(SignBit, DL, VT));
    assert(N0.getOpcode() == ISD::FABS);
    return DAG.getNode(ISD::AND, DL, VT,
                       NewConv, DAG.getConstant(~SignBit, DL, VT));
  }

  // fold (bitconvert (fcopysign cst, x)) ->
  //         (or (and (bitconvert x), sign), (and cst, (not sign)))
  // Note that we don't handle (copysign x, cst) because this can always be
  // folded to an fneg or fabs.
  //
  // For ppc_fp128:
  // fold (bitcast (fcopysign cst, x)) ->
  //     flipbit = (and (extract_element
  //                     (xor (bitcast cst), (bitcast x)), 0),
  //                    signbit)
  //     (xor (bitcast cst) (build_pair flipbit, flipbit))
  if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse() &&
      isa<ConstantFPSDNode>(N0.getOperand(0)) && VT.isInteger() &&
      !VT.isVector()) {
    unsigned OrigXWidth = N0.getOperand(1).getValueSizeInBits();
    EVT IntXVT = EVT::getIntegerVT(*DAG.getContext(), OrigXWidth);
    if (isTypeLegal(IntXVT)) {
      SDValue X = DAG.getBitcast(IntXVT, N0.getOperand(1));
      AddToWorklist(X.getNode());

      // If X has a different width than the result/lhs, sext it or truncate it.
      unsigned VTWidth = VT.getSizeInBits();
      if (OrigXWidth < VTWidth) {
        X = DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, X);
        AddToWorklist(X.getNode());
      } else if (OrigXWidth > VTWidth) {
        // To get the sign bit in the right place, we have to shift it right
        // before truncating.
        SDLoc DL(X);
        X = DAG.getNode(ISD::SRL, DL,
                        X.getValueType(), X,
                        DAG.getConstant(OrigXWidth-VTWidth, DL,
                                        X.getValueType()));
        AddToWorklist(X.getNode());
        X = DAG.getNode(ISD::TRUNCATE, SDLoc(X), VT, X);
        AddToWorklist(X.getNode());
      }

      if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
        // The flip bit is the sign bit of (cst XOR x), i.e. set exactly when
        // the two sign bits differ.
        APInt SignBit = APInt::getSignMask(VT.getSizeInBits() / 2);
        SDValue Cst = DAG.getBitcast(VT, N0.getOperand(0));
        AddToWorklist(Cst.getNode());
        SDValue X = DAG.getBitcast(VT, N0.getOperand(1));
        AddToWorklist(X.getNode());
        SDValue XorResult = DAG.getNode(ISD::XOR, SDLoc(N0), VT, Cst, X);
        AddToWorklist(XorResult.getNode());
        SDValue XorResult64 = DAG.getNode(
            ISD::EXTRACT_ELEMENT, SDLoc(XorResult), MVT::i64, XorResult,
            DAG.getIntPtrConstant(getPPCf128HiElementSelector(DAG),
                                  SDLoc(XorResult)));
        AddToWorklist(XorResult64.getNode());
        SDValue FlipBit =
            DAG.getNode(ISD::AND, SDLoc(XorResult64), MVT::i64, XorResult64,
                        DAG.getConstant(SignBit, SDLoc(XorResult64), MVT::i64));
        AddToWorklist(FlipBit.getNode());
        SDValue FlipBits =
            DAG.getNode(ISD::BUILD_PAIR, SDLoc(N0), VT, FlipBit, FlipBit);
        AddToWorklist(FlipBits.getNode());
        return DAG.getNode(ISD::XOR, SDLoc(N), VT, Cst, FlipBits);
      }
      // Combine the sign bit of x with the non-sign bits of cst.
      APInt SignBit = APInt::getSignMask(VT.getSizeInBits());
      X = DAG.getNode(ISD::AND, SDLoc(X), VT,
                      X, DAG.getConstant(SignBit, SDLoc(X), VT));
      AddToWorklist(X.getNode());

      SDValue Cst = DAG.getBitcast(VT, N0.getOperand(0));
      Cst = DAG.getNode(ISD::AND, SDLoc(Cst), VT,
                        Cst, DAG.getConstant(~SignBit, SDLoc(Cst), VT));
      AddToWorklist(Cst.getNode());

      return DAG.getNode(ISD::OR, SDLoc(N), VT, X, Cst);
    }
  }

  // bitconvert(build_pair(ld, ld)) -> ld iff load locations are consecutive.
  if (N0.getOpcode() == ISD::BUILD_PAIR)
    if (SDValue CombineLD = CombineConsecutiveLoads(N0.getNode(), VT))
      return CombineLD;

  // Remove double bitcasts from shuffles - this is often a legacy of
  // XformToShuffleWithZero being used to combine bitmaskings (of
  // float vectors bitcast to integer vectors) into shuffles.
  // bitcast(shuffle(bitcast(s0),bitcast(s1))) -> shuffle(s0,s1)
  if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT) && VT.isVector() &&
      N0->getOpcode() == ISD::VECTOR_SHUFFLE && N0.hasOneUse() &&
      VT.getVectorNumElements() >= N0.getValueType().getVectorNumElements() &&
      !(VT.getVectorNumElements() % N0.getValueType().getVectorNumElements())) {
    ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N0);

    // If operands are a bitcast, peek through if it casts the original VT.
    // If operands are a constant, just bitcast back to original VT.
    auto PeekThroughBitcast = [&](SDValue Op) {
      if (Op.getOpcode() == ISD::BITCAST &&
          Op.getOperand(0).getValueType() == VT)
        return SDValue(Op.getOperand(0));
      if (Op.isUndef() || isAnyConstantBuildVector(Op))
        return DAG.getBitcast(VT, Op);
      return SDValue();
    };

    // FIXME: If either input vector is bitcast, try to convert the shuffle to
    // the result type of this bitcast. This would eliminate at least one
    // bitcast. See the transform in InstCombine.
    SDValue SV0 = PeekThroughBitcast(N0->getOperand(0));
    SDValue SV1 = PeekThroughBitcast(N0->getOperand(1));
    if (!(SV0 && SV1))
      return SDValue();

    // Each source element covers MaskScale destination elements, so widen
    // the shuffle mask accordingly (undef lanes stay undef).
    int MaskScale =
        VT.getVectorNumElements() / N0.getValueType().getVectorNumElements();
    SmallVector<int, 8> NewMask;
    for (int M : SVN->getMask())
      for (int i = 0; i != MaskScale; ++i)
        NewMask.push_back(M < 0 ? -1 : M * MaskScale + i);

    SDValue LegalShuffle =
        TLI.buildLegalVectorShuffle(VT, SDLoc(N), SV0, SV1, NewMask, DAG);
    if (LegalShuffle)
      return LegalShuffle;
  }

  return SDValue();
}
13857 
visitBUILD_PAIR(SDNode * N)13858 SDValue DAGCombiner::visitBUILD_PAIR(SDNode *N) {
13859   EVT VT = N->getValueType(0);
13860   return CombineConsecutiveLoads(N, VT);
13861 }
13862 
visitFREEZE(SDNode * N)13863 SDValue DAGCombiner::visitFREEZE(SDNode *N) {
13864   SDValue N0 = N->getOperand(0);
13865 
13866   if (DAG.isGuaranteedNotToBeUndefOrPoison(N0, /*PoisonOnly*/ false))
13867     return N0;
13868 
13869   // Fold freeze(bitcast(x)) -> bitcast(freeze(x)).
13870   // TODO: Replace with pushFreezeToPreventPoisonFromPropagating fold.
13871   if (N0.getOpcode() == ISD::BITCAST)
13872     return DAG.getBitcast(N->getValueType(0),
13873                           DAG.getNode(ISD::FREEZE, SDLoc(N0),
13874                                       N0.getOperand(0).getValueType(),
13875                                       N0.getOperand(0)));
13876 
13877   return SDValue();
13878 }
13879 
/// We know that BV is a build_vector node with Constant, ConstantFP or Undef
/// operands. DstEltVT indicates the destination element value type.
/// Handles same-size element casts directly; for differing sizes it recurses
/// through an all-integer intermediate form and repacks the raw bits.
SDValue DAGCombiner::
ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) {
  EVT SrcEltVT = BV->getValueType(0).getVectorElementType();

  // If this is already the right type, we're done.
  if (SrcEltVT == DstEltVT) return SDValue(BV, 0);

  unsigned SrcBitSize = SrcEltVT.getSizeInBits();
  unsigned DstBitSize = DstEltVT.getSizeInBits();

  // If this is a conversion of N elements of one type to N elements of another
  // type, convert each element.  This handles FP<->INT cases.
  if (SrcBitSize == DstBitSize) {
    SmallVector<SDValue, 8> Ops;
    for (SDValue Op : BV->op_values()) {
      // If the vector element type is not legal, the BUILD_VECTOR operands
      // are promoted and implicitly truncated.  Make that explicit here.
      if (Op.getValueType() != SrcEltVT)
        Op = DAG.getNode(ISD::TRUNCATE, SDLoc(BV), SrcEltVT, Op);
      Ops.push_back(DAG.getBitcast(DstEltVT, Op));
      AddToWorklist(Ops.back().getNode());
    }
    EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT,
                              BV->getValueType(0).getVectorNumElements());
    return DAG.getBuildVector(VT, SDLoc(BV), Ops);
  }

  // Otherwise, we're growing or shrinking the elements.  To avoid having to
  // handle annoying details of growing/shrinking FP values, we convert them to
  // int first.
  if (SrcEltVT.isFloatingPoint()) {
    // Convert the input float vector to a int vector where the elements are the
    // same sizes.
    EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), SrcEltVT.getSizeInBits());
    // Recurse: same-size FP->int cast, handled by the branch above.
    BV = ConstantFoldBITCASTofBUILD_VECTOR(BV, IntVT).getNode();
    SrcEltVT = IntVT;
  }

  // Now we know the input is an integer vector.  If the output is a FP type,
  // convert to integer first, then to FP of the right size.
  if (DstEltVT.isFloatingPoint()) {
    EVT TmpVT = EVT::getIntegerVT(*DAG.getContext(), DstEltVT.getSizeInBits());
    // Recurse: int->int resize, then a same-size int->FP cast.
    SDNode *Tmp = ConstantFoldBITCASTofBUILD_VECTOR(BV, TmpVT).getNode();

    // Next, convert to FP elements of the same size.
    return ConstantFoldBITCASTofBUILD_VECTOR(Tmp, DstEltVT);
  }

  // Okay, we know the src/dst types are both integers of differing types.
  assert(SrcEltVT.isInteger() && DstEltVT.isInteger());

  // TODO: Should ConstantFoldBITCASTofBUILD_VECTOR always take a
  // BuildVectorSDNode?
  auto *BVN = cast<BuildVectorSDNode>(BV);

  // Extract the constant raw bit data.
  BitVector UndefElements;
  SmallVector<APInt> RawBits;
  bool IsLE = DAG.getDataLayout().isLittleEndian();
  // getConstantRawBits repacks the element bits at the new width; it fails
  // (and we give up) if the operands aren't all constant/undef.
  if (!BVN->getConstantRawBits(IsLE, DstBitSize, RawBits, UndefElements))
    return SDValue();

  SDLoc DL(BV);
  // Rebuild the vector from the repacked bits, preserving undef lanes.
  SmallVector<SDValue, 8> Ops;
  for (unsigned I = 0, E = RawBits.size(); I != E; ++I) {
    if (UndefElements[I])
      Ops.push_back(DAG.getUNDEF(DstEltVT));
    else
      Ops.push_back(DAG.getConstant(RawBits[I], DL, DstEltVT));
  }

  EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT, Ops.size());
  return DAG.getBuildVector(VT, DL, Ops);
}
13956 
13957 // Returns true if floating point contraction is allowed on the FMUL-SDValue
13958 // `N`
isContractableFMUL(const TargetOptions & Options,SDValue N)13959 static bool isContractableFMUL(const TargetOptions &Options, SDValue N) {
13960   assert(N.getOpcode() == ISD::FMUL);
13961 
13962   return Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath ||
13963          N->getFlags().hasAllowContract();
13964 }
13965 
13966 // Returns true if `N` can assume no infinities involved in its computation.
hasNoInfs(const TargetOptions & Options,SDValue N)13967 static bool hasNoInfs(const TargetOptions &Options, SDValue N) {
13968   return Options.NoInfsFPMath || N->getFlags().hasNoInfs();
13969 }
13970 
/// Try to perform FMA combining on a given FADD node.
/// Attempts, in priority order: (fadd (fmul x,y), z) -> fma; reassociating an
/// existing fused op's FMUL addend into a nested fma; and, looking through
/// FP_EXTEND, the extended variants of the same patterns. Returns an empty
/// SDValue if no combine applies.
SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  SDLoc SL(N);

  const TargetOptions &Options = DAG.getTarget().Options;

  // Floating-point multiply-add with intermediate rounding.
  bool HasFMAD = (LegalOperations && TLI.isFMADLegal(DAG, N));

  // Floating-point multiply-add without intermediate rounding.
  bool HasFMA =
      TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
      (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));

  // No valid opcode, do not combine.
  if (!HasFMAD && !HasFMA)
    return SDValue();

  bool CanReassociate =
      Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
  bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
                              Options.UnsafeFPMath || HasFMAD);
  // If the addition is not contractable, do not combine.
  if (!AllowFusionGlobally && !N->getFlags().hasAllowContract())
    return SDValue();

  // Some targets prefer to form FMAs in the machine combiner instead.
  if (TLI.generateFMAsInMachineCombiner(VT, OptLevel))
    return SDValue();

  // Always prefer FMAD to FMA for precision.
  unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
  bool Aggressive = TLI.enableAggressiveFMAFusion(VT);

  // Is the node already a fused multiply-add of some kind?
  auto isFusedOp = [&](SDValue N) {
    unsigned Opcode = N.getOpcode();
    return Opcode == ISD::FMA || Opcode == ISD::FMAD;
  };

  // Is the node an FMUL and contractable either due to global flags or
  // SDNodeFlags.
  auto isContractableFMUL = [AllowFusionGlobally](SDValue N) {
    if (N.getOpcode() != ISD::FMUL)
      return false;
    return AllowFusionGlobally || N->getFlags().hasAllowContract();
  };
  // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
  // prefer to fold the multiply with fewer uses.
  if (Aggressive && isContractableFMUL(N0) && isContractableFMUL(N1)) {
    if (N0->use_size() > N1->use_size())
      std::swap(N0, N1);
  }

  // fold (fadd (fmul x, y), z) -> (fma x, y, z)
  if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) {
    return DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0),
                       N0.getOperand(1), N1);
  }

  // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
  // Note: Commutes FADD operands.
  if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) {
    return DAG.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0),
                       N1.getOperand(1), N0);
  }

  // fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E)
  // fadd E, (fma A, B, (fmul C, D)) --> fma A, B, (fma C, D, E)
  // This requires reassociation because it changes the order of operations.
  SDValue FMA, E;
  if (CanReassociate && isFusedOp(N0) &&
      N0.getOperand(2).getOpcode() == ISD::FMUL && N0.hasOneUse() &&
      N0.getOperand(2).hasOneUse()) {
    FMA = N0;
    E = N1;
  } else if (CanReassociate && isFusedOp(N1) &&
             N1.getOperand(2).getOpcode() == ISD::FMUL && N1.hasOneUse() &&
             N1.getOperand(2).hasOneUse()) {
    FMA = N1;
    E = N0;
  }
  if (FMA && E) {
    SDValue A = FMA.getOperand(0);
    SDValue B = FMA.getOperand(1);
    SDValue C = FMA.getOperand(2).getOperand(0);
    SDValue D = FMA.getOperand(2).getOperand(1);
    SDValue CDE = DAG.getNode(PreferredFusedOpcode, SL, VT, C, D, E);
    return DAG.getNode(PreferredFusedOpcode, SL, VT, A, B, CDE);
  }

  // Look through FP_EXTEND nodes to do more combining.

  // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
  if (N0.getOpcode() == ISD::FP_EXTEND) {
    SDValue N00 = N0.getOperand(0);
    if (isContractableFMUL(N00) &&
        TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                            N00.getValueType())) {
      return DAG.getNode(PreferredFusedOpcode, SL, VT,
                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
                         N1);
    }
  }

  // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x)
  // Note: Commutes FADD operands.
  if (N1.getOpcode() == ISD::FP_EXTEND) {
    SDValue N10 = N1.getOperand(0);
    if (isContractableFMUL(N10) &&
        TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                            N10.getValueType())) {
      return DAG.getNode(PreferredFusedOpcode, SL, VT,
                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0)),
                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)),
                         N0);
    }
  }

  // More folding opportunities when target permits.
  if (Aggressive) {
    // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
    //   -> (fma x, y, (fma (fpext u), (fpext v), z))
    auto FoldFAddFMAFPExtFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
                                    SDValue Z) {
      return DAG.getNode(PreferredFusedOpcode, SL, VT, X, Y,
                         DAG.getNode(PreferredFusedOpcode, SL, VT,
                                     DAG.getNode(ISD::FP_EXTEND, SL, VT, U),
                                     DAG.getNode(ISD::FP_EXTEND, SL, VT, V),
                                     Z));
    };
    if (isFusedOp(N0)) {
      SDValue N02 = N0.getOperand(2);
      if (N02.getOpcode() == ISD::FP_EXTEND) {
        SDValue N020 = N02.getOperand(0);
        if (isContractableFMUL(N020) &&
            TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                                N020.getValueType())) {
          return FoldFAddFMAFPExtFMul(N0.getOperand(0), N0.getOperand(1),
                                      N020.getOperand(0), N020.getOperand(1),
                                      N1);
        }
      }
    }

    // fold (fadd (fpext (fma x, y, (fmul u, v))), z)
    //   -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
    // FIXME: This turns two single-precision and one double-precision
    // operation into two double-precision operations, which might not be
    // interesting for all targets, especially GPUs.
    auto FoldFAddFPExtFMAFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
                                    SDValue Z) {
      return DAG.getNode(
          PreferredFusedOpcode, SL, VT, DAG.getNode(ISD::FP_EXTEND, SL, VT, X),
          DAG.getNode(ISD::FP_EXTEND, SL, VT, Y),
          DAG.getNode(PreferredFusedOpcode, SL, VT,
                      DAG.getNode(ISD::FP_EXTEND, SL, VT, U),
                      DAG.getNode(ISD::FP_EXTEND, SL, VT, V), Z));
    };
    if (N0.getOpcode() == ISD::FP_EXTEND) {
      SDValue N00 = N0.getOperand(0);
      if (isFusedOp(N00)) {
        SDValue N002 = N00.getOperand(2);
        if (isContractableFMUL(N002) &&
            TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                                N00.getValueType())) {
          return FoldFAddFPExtFMAFMul(N00.getOperand(0), N00.getOperand(1),
                                      N002.getOperand(0), N002.getOperand(1),
                                      N1);
        }
      }
    }

    // fold (fadd x, (fma y, z, (fpext (fmul u, v)))
    //   -> (fma y, z, (fma (fpext u), (fpext v), x))
    if (isFusedOp(N1)) {
      SDValue N12 = N1.getOperand(2);
      if (N12.getOpcode() == ISD::FP_EXTEND) {
        SDValue N120 = N12.getOperand(0);
        if (isContractableFMUL(N120) &&
            TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                                N120.getValueType())) {
          return FoldFAddFMAFPExtFMul(N1.getOperand(0), N1.getOperand(1),
                                      N120.getOperand(0), N120.getOperand(1),
                                      N0);
        }
      }
    }

    // fold (fadd x, (fpext (fma y, z, (fmul u, v)))
    //   -> (fma (fpext y), (fpext z), (fma (fpext u), (fpext v), x))
    // FIXME: This turns two single-precision and one double-precision
    // operation into two double-precision operations, which might not be
    // interesting for all targets, especially GPUs.
    if (N1.getOpcode() == ISD::FP_EXTEND) {
      SDValue N10 = N1.getOperand(0);
      if (isFusedOp(N10)) {
        SDValue N102 = N10.getOperand(2);
        if (isContractableFMUL(N102) &&
            TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                                N10.getValueType())) {
          return FoldFAddFPExtFMAFMul(N10.getOperand(0), N10.getOperand(1),
                                      N102.getOperand(0), N102.getOperand(1),
                                      N0);
        }
      }
    }
  }

  return SDValue();
}
14184 
/// Try to perform FMA combining on a given FSUB node.
///
/// Folds (fsub (fmul x, y), z) and many related shapes — commuted operands,
/// FNEG- or FP_EXTEND-wrapped multiplies, and operands that are already
/// FMA/FMAD nodes — into FMA/FMAD nodes when the target and the fast-math
/// flags permit. Returns the replacement value, or an empty SDValue when no
/// combine applies.
SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  SDLoc SL(N);

  const TargetOptions &Options = DAG.getTarget().Options;
  // Floating-point multiply-add with intermediate rounding.
  bool HasFMAD = (LegalOperations && TLI.isFMADLegal(DAG, N));

  // Floating-point multiply-add without intermediate rounding.
  bool HasFMA =
      TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
      (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));

  // No valid opcode, do not combine.
  if (!HasFMAD && !HasFMA)
    return SDValue();

  const SDNodeFlags Flags = N->getFlags();
  // Fusion is allowed for any contractable FMUL when the global
  // fp-contract=fast / unsafe-fp-math modes are on, or when the target
  // exposes FMAD (which fuses with intermediate rounding by definition).
  bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
                              Options.UnsafeFPMath || HasFMAD);

  // If the subtraction is not contractable, do not combine.
  if (!AllowFusionGlobally && !N->getFlags().hasAllowContract())
    return SDValue();

  // Defer FMA formation to the machine combiner if the target prefers that.
  if (TLI.generateFMAsInMachineCombiner(VT, OptLevel))
    return SDValue();

  // Always prefer FMAD to FMA for precision.
  unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
  bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
  bool NoSignedZero = Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros();

  // Is the node an FMUL and contractable either due to global flags or
  // SDNodeFlags.
  auto isContractableFMUL = [AllowFusionGlobally](SDValue N) {
    if (N.getOpcode() != ISD::FMUL)
      return false;
    return AllowFusionGlobally || N->getFlags().hasAllowContract();
  };

  // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
  auto tryToFoldXYSubZ = [&](SDValue XY, SDValue Z) {
    if (isContractableFMUL(XY) && (Aggressive || XY->hasOneUse())) {
      return DAG.getNode(PreferredFusedOpcode, SL, VT, XY.getOperand(0),
                         XY.getOperand(1), DAG.getNode(ISD::FNEG, SL, VT, Z));
    }
    return SDValue();
  };

  // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
  // Note: Commutes FSUB operands.
  auto tryToFoldXSubYZ = [&](SDValue X, SDValue YZ) {
    if (isContractableFMUL(YZ) && (Aggressive || YZ->hasOneUse())) {
      return DAG.getNode(PreferredFusedOpcode, SL, VT,
                         DAG.getNode(ISD::FNEG, SL, VT, YZ.getOperand(0)),
                         YZ.getOperand(1), X);
    }
    return SDValue();
  };

  // If we have two choices trying to fold (fsub (fmul u, v), (fmul x, y)),
  // prefer to fold the multiply with fewer uses.
  // N0->use_size() > N1->use_size() means N1 is the cheaper multiply to
  // absorb, so try the operand-commuting fold (which consumes N1) first.
  if (isContractableFMUL(N0) && isContractableFMUL(N1) &&
      (N0->use_size() > N1->use_size())) {
    // fold (fsub (fmul a, b), (fmul c, d)) -> (fma (fneg c), d, (fmul a, b))
    if (SDValue V = tryToFoldXSubYZ(N0, N1))
      return V;
    // fold (fsub (fmul a, b), (fmul c, d)) -> (fma a, b, (fneg (fmul c, d)))
    if (SDValue V = tryToFoldXYSubZ(N0, N1))
      return V;
  } else {
    // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
    if (SDValue V = tryToFoldXYSubZ(N0, N1))
      return V;
    // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
    if (SDValue V = tryToFoldXSubYZ(N0, N1))
      return V;
  }

  // fold (fsub (fneg (fmul, x, y)), z) -> (fma (fneg x), y, (fneg z))
  if (N0.getOpcode() == ISD::FNEG && isContractableFMUL(N0.getOperand(0)) &&
      (Aggressive || (N0->hasOneUse() && N0.getOperand(0).hasOneUse()))) {
    SDValue N00 = N0.getOperand(0).getOperand(0);
    SDValue N01 = N0.getOperand(0).getOperand(1);
    return DAG.getNode(PreferredFusedOpcode, SL, VT,
                       DAG.getNode(ISD::FNEG, SL, VT, N00), N01,
                       DAG.getNode(ISD::FNEG, SL, VT, N1));
  }

  // Look through FP_EXTEND nodes to do more combining.

  // fold (fsub (fpext (fmul x, y)), z)
  //   -> (fma (fpext x), (fpext y), (fneg z))
  if (N0.getOpcode() == ISD::FP_EXTEND) {
    SDValue N00 = N0.getOperand(0);
    if (isContractableFMUL(N00) &&
        TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                            N00.getValueType())) {
      return DAG.getNode(PreferredFusedOpcode, SL, VT,
                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
                         DAG.getNode(ISD::FNEG, SL, VT, N1));
    }
  }

  // fold (fsub x, (fpext (fmul y, z)))
  //   -> (fma (fneg (fpext y)), (fpext z), x)
  // Note: Commutes FSUB operands.
  if (N1.getOpcode() == ISD::FP_EXTEND) {
    SDValue N10 = N1.getOperand(0);
    if (isContractableFMUL(N10) &&
        TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                            N10.getValueType())) {
      return DAG.getNode(
          PreferredFusedOpcode, SL, VT,
          DAG.getNode(ISD::FNEG, SL, VT,
                      DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0))),
          DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), N0);
    }
  }

  // fold (fsub (fpext (fneg (fmul, x, y))), z)
  //   -> (fneg (fma (fpext x), (fpext y), z))
  // Note: This could be removed with appropriate canonicalization of the
  // input expression into (fneg (fadd (fpext (fmul, x, y)), z). However, the
  // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent
  // from implementing the canonicalization in visitFSUB.
  if (N0.getOpcode() == ISD::FP_EXTEND) {
    SDValue N00 = N0.getOperand(0);
    if (N00.getOpcode() == ISD::FNEG) {
      SDValue N000 = N00.getOperand(0);
      if (isContractableFMUL(N000) &&
          TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                              N00.getValueType())) {
        return DAG.getNode(
            ISD::FNEG, SL, VT,
            DAG.getNode(PreferredFusedOpcode, SL, VT,
                        DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)),
                        DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)),
                        N1));
      }
    }
  }

  // fold (fsub (fneg (fpext (fmul, x, y))), z)
  //   -> (fneg (fma (fpext x)), (fpext y), z)
  // Note: This could be removed with appropriate canonicalization of the
  // input expression into (fneg (fadd (fpext (fmul, x, y)), z). However, the
  // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent
  // from implementing the canonicalization in visitFSUB.
  if (N0.getOpcode() == ISD::FNEG) {
    SDValue N00 = N0.getOperand(0);
    if (N00.getOpcode() == ISD::FP_EXTEND) {
      SDValue N000 = N00.getOperand(0);
      if (isContractableFMUL(N000) &&
          TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                              N000.getValueType())) {
        return DAG.getNode(
            ISD::FNEG, SL, VT,
            DAG.getNode(PreferredFusedOpcode, SL, VT,
                        DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)),
                        DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)),
                        N1));
      }
    }
  }

  // Reassociation is required for the multi-step folds below, since they
  // change the association of the surrounding operations.
  auto isReassociable = [Options](SDNode *N) {
    return Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
  };

  auto isContractableAndReassociableFMUL = [isContractableFMUL,
                                            isReassociable](SDValue N) {
    return isContractableFMUL(N) && isReassociable(N.getNode());
  };

  auto isFusedOp = [&](SDValue N) {
    unsigned Opcode = N.getOpcode();
    return Opcode == ISD::FMA || Opcode == ISD::FMAD;
  };

  // More folding opportunities when target permits.
  if (Aggressive && isReassociable(N)) {
    bool CanFuse = Options.UnsafeFPMath || N->getFlags().hasAllowContract();
    // fold (fsub (fma x, y, (fmul u, v)), z)
    //   -> (fma x, y (fma u, v, (fneg z)))
    if (CanFuse && isFusedOp(N0) &&
        isContractableAndReassociableFMUL(N0.getOperand(2)) &&
        N0->hasOneUse() && N0.getOperand(2)->hasOneUse()) {
      return DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0),
                         N0.getOperand(1),
                         DAG.getNode(PreferredFusedOpcode, SL, VT,
                                     N0.getOperand(2).getOperand(0),
                                     N0.getOperand(2).getOperand(1),
                                     DAG.getNode(ISD::FNEG, SL, VT, N1)));
    }

    // fold (fsub x, (fma y, z, (fmul u, v)))
    //   -> (fma (fneg y), z, (fma (fneg u), v, x))
    if (CanFuse && isFusedOp(N1) &&
        isContractableAndReassociableFMUL(N1.getOperand(2)) &&
        N1->hasOneUse() && NoSignedZero) {
      SDValue N20 = N1.getOperand(2).getOperand(0);
      SDValue N21 = N1.getOperand(2).getOperand(1);
      return DAG.getNode(
          PreferredFusedOpcode, SL, VT,
          DAG.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)), N1.getOperand(1),
          DAG.getNode(PreferredFusedOpcode, SL, VT,
                      DAG.getNode(ISD::FNEG, SL, VT, N20), N21, N0));
    }

    // fold (fsub (fma x, y, (fpext (fmul u, v))), z)
    //   -> (fma x, y (fma (fpext u), (fpext v), (fneg z)))
    if (isFusedOp(N0) && N0->hasOneUse()) {
      SDValue N02 = N0.getOperand(2);
      if (N02.getOpcode() == ISD::FP_EXTEND) {
        SDValue N020 = N02.getOperand(0);
        if (isContractableAndReassociableFMUL(N020) &&
            TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                                N020.getValueType())) {
          return DAG.getNode(
              PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1),
              DAG.getNode(
                  PreferredFusedOpcode, SL, VT,
                  DAG.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(0)),
                  DAG.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(1)),
                  DAG.getNode(ISD::FNEG, SL, VT, N1)));
        }
      }
    }

    // fold (fsub (fpext (fma x, y, (fmul u, v))), z)
    //   -> (fma (fpext x), (fpext y),
    //           (fma (fpext u), (fpext v), (fneg z)))
    // FIXME: This turns two single-precision and one double-precision
    // operation into two double-precision operations, which might not be
    // interesting for all targets, especially GPUs.
    if (N0.getOpcode() == ISD::FP_EXTEND) {
      SDValue N00 = N0.getOperand(0);
      if (isFusedOp(N00)) {
        SDValue N002 = N00.getOperand(2);
        if (isContractableAndReassociableFMUL(N002) &&
            TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                                N00.getValueType())) {
          return DAG.getNode(
              PreferredFusedOpcode, SL, VT,
              DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
              DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
              DAG.getNode(
                  PreferredFusedOpcode, SL, VT,
                  DAG.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(0)),
                  DAG.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(1)),
                  DAG.getNode(ISD::FNEG, SL, VT, N1)));
        }
      }
    }

    // fold (fsub x, (fma y, z, (fpext (fmul u, v))))
    //   -> (fma (fneg y), z, (fma (fneg (fpext u)), (fpext v), x))
    if (isFusedOp(N1) && N1.getOperand(2).getOpcode() == ISD::FP_EXTEND &&
        N1->hasOneUse()) {
      SDValue N120 = N1.getOperand(2).getOperand(0);
      if (isContractableAndReassociableFMUL(N120) &&
          TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                              N120.getValueType())) {
        SDValue N1200 = N120.getOperand(0);
        SDValue N1201 = N120.getOperand(1);
        return DAG.getNode(
            PreferredFusedOpcode, SL, VT,
            DAG.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)), N1.getOperand(1),
            DAG.getNode(PreferredFusedOpcode, SL, VT,
                        DAG.getNode(ISD::FNEG, SL, VT,
                                    DAG.getNode(ISD::FP_EXTEND, SL, VT, N1200)),
                        DAG.getNode(ISD::FP_EXTEND, SL, VT, N1201), N0));
      }
    }

    // fold (fsub x, (fpext (fma y, z, (fmul u, v))))
    //   -> (fma (fneg (fpext y)), (fpext z),
    //           (fma (fneg (fpext u)), (fpext v), x))
    // FIXME: This turns two single-precision and one double-precision
    // operation into two double-precision operations, which might not be
    // interesting for all targets, especially GPUs.
    if (N1.getOpcode() == ISD::FP_EXTEND && isFusedOp(N1.getOperand(0))) {
      SDValue CvtSrc = N1.getOperand(0);
      SDValue N100 = CvtSrc.getOperand(0);
      SDValue N101 = CvtSrc.getOperand(1);
      SDValue N102 = CvtSrc.getOperand(2);
      if (isContractableAndReassociableFMUL(N102) &&
          TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                              CvtSrc.getValueType())) {
        SDValue N1020 = N102.getOperand(0);
        SDValue N1021 = N102.getOperand(1);
        return DAG.getNode(
            PreferredFusedOpcode, SL, VT,
            DAG.getNode(ISD::FNEG, SL, VT,
                        DAG.getNode(ISD::FP_EXTEND, SL, VT, N100)),
            DAG.getNode(ISD::FP_EXTEND, SL, VT, N101),
            DAG.getNode(PreferredFusedOpcode, SL, VT,
                        DAG.getNode(ISD::FNEG, SL, VT,
                                    DAG.getNode(ISD::FP_EXTEND, SL, VT, N1020)),
                        DAG.getNode(ISD::FP_EXTEND, SL, VT, N1021), N0));
      }
    }
  }

  return SDValue();
}
14497 
/// Try to perform FMA combining on a given FMUL node based on the distributive
/// law x * (y + 1) = x * y + x and variants thereof (commuted versions,
/// subtraction instead of addition).
///
/// Returns the fused replacement node, or an empty SDValue when no pattern
/// matches or the target/flags disallow fusion.
SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  SDLoc SL(N);

  assert(N->getOpcode() == ISD::FMUL && "Expected FMUL Operation");

  const TargetOptions &Options = DAG.getTarget().Options;

  // The transforms below are incorrect when x == 0 and y == inf, because the
  // intermediate multiplication produces a nan.
  // Pick whichever operand is the FADD; if neither is, FAdd is just N1 and
  // the opcode checks inside the lambdas below will reject the fold anyway.
  SDValue FAdd = N0.getOpcode() == ISD::FADD ? N0 : N1;
  if (!hasNoInfs(Options, FAdd))
    return SDValue();

  // Floating-point multiply-add without intermediate rounding.
  bool HasFMA =
      isContractableFMUL(Options, SDValue(N, 0)) &&
      TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
      (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));

  // Floating-point multiply-add with intermediate rounding. This can result
  // in a less precise result due to the changed rounding order.
  bool HasFMAD = Options.UnsafeFPMath &&
                 (LegalOperations && TLI.isFMADLegal(DAG, N));

  // No valid opcode, do not combine.
  if (!HasFMAD && !HasFMA)
    return SDValue();

  // Always prefer FMAD to FMA for precision.
  unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
  bool Aggressive = TLI.enableAggressiveFMAFusion(VT);

  // fold (fmul (fadd x0, +1.0), y) -> (fma x0, y, y)
  // fold (fmul (fadd x0, -1.0), y) -> (fma x0, y, (fneg y))
  auto FuseFADD = [&](SDValue X, SDValue Y) {
    if (X.getOpcode() == ISD::FADD && (Aggressive || X->hasOneUse())) {
      if (auto *C = isConstOrConstSplatFP(X.getOperand(1), true)) {
        if (C->isExactlyValue(+1.0))
          return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
                             Y);
        if (C->isExactlyValue(-1.0))
          return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
                             DAG.getNode(ISD::FNEG, SL, VT, Y));
      }
    }
    return SDValue();
  };

  // Try both commutations of the FMUL operands.
  if (SDValue FMA = FuseFADD(N0, N1))
    return FMA;
  if (SDValue FMA = FuseFADD(N1, N0))
    return FMA;

  // fold (fmul (fsub +1.0, x1), y) -> (fma (fneg x1), y, y)
  // fold (fmul (fsub -1.0, x1), y) -> (fma (fneg x1), y, (fneg y))
  // fold (fmul (fsub x0, +1.0), y) -> (fma x0, y, (fneg y))
  // fold (fmul (fsub x0, -1.0), y) -> (fma x0, y, y)
  auto FuseFSUB = [&](SDValue X, SDValue Y) {
    if (X.getOpcode() == ISD::FSUB && (Aggressive || X->hasOneUse())) {
      if (auto *C0 = isConstOrConstSplatFP(X.getOperand(0), true)) {
        if (C0->isExactlyValue(+1.0))
          return DAG.getNode(PreferredFusedOpcode, SL, VT,
                             DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
                             Y);
        if (C0->isExactlyValue(-1.0))
          return DAG.getNode(PreferredFusedOpcode, SL, VT,
                             DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
                             DAG.getNode(ISD::FNEG, SL, VT, Y));
      }
      if (auto *C1 = isConstOrConstSplatFP(X.getOperand(1), true)) {
        if (C1->isExactlyValue(+1.0))
          return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
                             DAG.getNode(ISD::FNEG, SL, VT, Y));
        if (C1->isExactlyValue(-1.0))
          return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
                             Y);
      }
    }
    return SDValue();
  };

  // Try both commutations of the FMUL operands.
  if (SDValue FMA = FuseFSUB(N0, N1))
    return FMA;
  if (SDValue FMA = FuseFSUB(N1, N0))
    return FMA;

  return SDValue();
}
14592 
/// Combine an FADD node: constant folding, canonicalization, fneg->fsub
/// rewrites, reassociation-based simplifications (gated on unsafe-math or
/// reassoc+nsz flags), and finally FADD->FMA fusion.
SDValue DAGCombiner::visitFADD(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  bool N0CFP = DAG.isConstantFPBuildVectorOrConstantFP(N0);
  bool N1CFP = DAG.isConstantFPBuildVectorOrConstantFP(N1);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);
  const TargetOptions &Options = DAG.getTarget().Options;
  SDNodeFlags Flags = N->getFlags();
  // Any node created through DAG below inherits N's flags.
  SelectionDAG::FlagInserter FlagsInserter(DAG, N);

  if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
    return R;

  // fold (fadd c1, c2) -> c1 + c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::FADD, DL, VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS
  if (N0CFP && !N1CFP)
    return DAG.getNode(ISD::FADD, DL, VT, N1, N0);

  // fold vector ops
  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

  // N0 + -0.0 --> N0 (also allowed with +0.0 and fast-math)
  ConstantFPSDNode *N1C = isConstOrConstSplatFP(N1, true);
  if (N1C && N1C->isZero())
    if (N1C->isNegative() || Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())
      return N0;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // fold (fadd A, (fneg B)) -> (fsub A, B)
  if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT))
    if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
            N1, DAG, LegalOperations, ForCodeSize))
      return DAG.getNode(ISD::FSUB, DL, VT, N0, NegN1);

  // fold (fadd (fneg A), B) -> (fsub B, A)
  if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT))
    if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
            N0, DAG, LegalOperations, ForCodeSize))
      return DAG.getNode(ISD::FSUB, DL, VT, N1, NegN0);

  // Matches a one-use (fmul B, -2.0) (scalar or splat constant).
  auto isFMulNegTwo = [](SDValue FMul) {
    if (!FMul.hasOneUse() || FMul.getOpcode() != ISD::FMUL)
      return false;
    auto *C = isConstOrConstSplatFP(FMul.getOperand(1), true);
    return C && C->isExactlyValue(-2.0);
  };

  // fadd (fmul B, -2.0), A --> fsub A, (fadd B, B)
  if (isFMulNegTwo(N0)) {
    SDValue B = N0.getOperand(0);
    SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B);
    return DAG.getNode(ISD::FSUB, DL, VT, N1, Add);
  }
  // fadd A, (fmul B, -2.0) --> fsub A, (fadd B, B)
  if (isFMulNegTwo(N1)) {
    SDValue B = N1.getOperand(0);
    SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B);
    return DAG.getNode(ISD::FSUB, DL, VT, N0, Add);
  }

  // No FP constant should be created after legalization as Instruction
  // Selection pass has a hard time dealing with FP constants.
  bool AllowNewConst = (Level < AfterLegalizeDAG);

  // If nnan is enabled, fold lots of things.
  if ((Options.NoNaNsFPMath || Flags.hasNoNaNs()) && AllowNewConst) {
    // If allowed, fold (fadd (fneg x), x) -> 0.0
    if (N0.getOpcode() == ISD::FNEG && N0.getOperand(0) == N1)
      return DAG.getConstantFP(0.0, DL, VT);

    // If allowed, fold (fadd x, (fneg x)) -> 0.0
    if (N1.getOpcode() == ISD::FNEG && N1.getOperand(0) == N0)
      return DAG.getConstantFP(0.0, DL, VT);
  }

  // If 'unsafe math' or reassoc and nsz, fold lots of things.
  // TODO: break out portions of the transformations below for which Unsafe is
  //       considered and which do not require both nsz and reassoc
  if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
       (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
      AllowNewConst) {
    // fadd (fadd x, c1), c2 -> fadd x, c1 + c2
    if (N1CFP && N0.getOpcode() == ISD::FADD &&
        DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) {
      SDValue NewC = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1), N1);
      return DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(0), NewC);
    }

    // We can fold chains of FADD's of the same value into multiplications.
    // This transform is not safe in general because we are reducing the number
    // of rounding steps.
    if (TLI.isOperationLegalOrCustom(ISD::FMUL, VT) && !N0CFP && !N1CFP) {
      if (N0.getOpcode() == ISD::FMUL) {
        bool CFP00 = DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(0));
        bool CFP01 = DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1));

        // (fadd (fmul x, c), x) -> (fmul x, c+1)
        if (CFP01 && !CFP00 && N0.getOperand(0) == N1) {
          SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1),
                                       DAG.getConstantFP(1.0, DL, VT));
          return DAG.getNode(ISD::FMUL, DL, VT, N1, NewCFP);
        }

        // (fadd (fmul x, c), (fadd x, x)) -> (fmul x, c+2)
        if (CFP01 && !CFP00 && N1.getOpcode() == ISD::FADD &&
            N1.getOperand(0) == N1.getOperand(1) &&
            N0.getOperand(0) == N1.getOperand(0)) {
          SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1),
                                       DAG.getConstantFP(2.0, DL, VT));
          return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), NewCFP);
        }
      }

      if (N1.getOpcode() == ISD::FMUL) {
        bool CFP10 = DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(0));
        bool CFP11 = DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(1));

        // (fadd x, (fmul x, c)) -> (fmul x, c+1)
        if (CFP11 && !CFP10 && N1.getOperand(0) == N0) {
          SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N1.getOperand(1),
                                       DAG.getConstantFP(1.0, DL, VT));
          return DAG.getNode(ISD::FMUL, DL, VT, N0, NewCFP);
        }

        // (fadd (fadd x, x), (fmul x, c)) -> (fmul x, c+2)
        if (CFP11 && !CFP10 && N0.getOpcode() == ISD::FADD &&
            N0.getOperand(0) == N0.getOperand(1) &&
            N1.getOperand(0) == N0.getOperand(0)) {
          SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N1.getOperand(1),
                                       DAG.getConstantFP(2.0, DL, VT));
          return DAG.getNode(ISD::FMUL, DL, VT, N1.getOperand(0), NewCFP);
        }
      }

      if (N0.getOpcode() == ISD::FADD) {
        bool CFP00 = DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(0));
        // (fadd (fadd x, x), x) -> (fmul x, 3.0)
        if (!CFP00 && N0.getOperand(0) == N0.getOperand(1) &&
            (N0.getOperand(0) == N1)) {
          return DAG.getNode(ISD::FMUL, DL, VT, N1,
                             DAG.getConstantFP(3.0, DL, VT));
        }
      }

      if (N1.getOpcode() == ISD::FADD) {
        bool CFP10 = DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(0));
        // (fadd x, (fadd x, x)) -> (fmul x, 3.0)
        if (!CFP10 && N1.getOperand(0) == N1.getOperand(1) &&
            N1.getOperand(0) == N0) {
          return DAG.getNode(ISD::FMUL, DL, VT, N0,
                             DAG.getConstantFP(3.0, DL, VT));
        }
      }

      // (fadd (fadd x, x), (fadd x, x)) -> (fmul x, 4.0)
      if (N0.getOpcode() == ISD::FADD && N1.getOpcode() == ISD::FADD &&
          N0.getOperand(0) == N0.getOperand(1) &&
          N1.getOperand(0) == N1.getOperand(1) &&
          N0.getOperand(0) == N1.getOperand(0)) {
        return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0),
                           DAG.getConstantFP(4.0, DL, VT));
      }
    }
  } // enable-unsafe-fp-math

  // FADD -> FMA combines:
  if (SDValue Fused = visitFADDForFMACombine(N)) {
    AddToWorklist(Fused.getNode());
    return Fused;
  }
  return SDValue();
}
14773 
visitSTRICT_FADD(SDNode * N)14774 SDValue DAGCombiner::visitSTRICT_FADD(SDNode *N) {
14775   SDValue Chain = N->getOperand(0);
14776   SDValue N0 = N->getOperand(1);
14777   SDValue N1 = N->getOperand(2);
14778   EVT VT = N->getValueType(0);
14779   EVT ChainVT = N->getValueType(1);
14780   SDLoc DL(N);
14781   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
14782 
14783   // fold (strict_fadd A, (fneg B)) -> (strict_fsub A, B)
14784   if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::STRICT_FSUB, VT))
14785     if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
14786             N1, DAG, LegalOperations, ForCodeSize)) {
14787       return DAG.getNode(ISD::STRICT_FSUB, DL, DAG.getVTList(VT, ChainVT),
14788                          {Chain, N0, NegN1});
14789     }
14790 
14791   // fold (strict_fadd (fneg A), B) -> (strict_fsub B, A)
14792   if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::STRICT_FSUB, VT))
14793     if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
14794             N0, DAG, LegalOperations, ForCodeSize)) {
14795       return DAG.getNode(ISD::STRICT_FSUB, DL, DAG.getVTList(VT, ChainVT),
14796                          {Chain, N1, NegN0});
14797     }
14798   return SDValue();
14799 }
14800 
/// Combine an FSUB node: constant folding, identity/zero folds (gated on the
/// signed-zero and NaN fast-math flags), fneg rewrites, and FSUB->FMA fusion.
SDValue DAGCombiner::visitFSUB(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, true);
  ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);
  const TargetOptions &Options = DAG.getTarget().Options;
  const SDNodeFlags Flags = N->getFlags();
  // Any node created through DAG below inherits N's flags.
  SelectionDAG::FlagInserter FlagsInserter(DAG, N);

  if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
    return R;

  // fold (fsub c1, c2) -> c1-c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::FSUB, DL, VT, {N0, N1}))
    return C;

  // fold vector ops
  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // (fsub A, 0) -> A
  // Subtracting +0.0 is only an identity when signed zeros can be ignored
  // (-0.0 - +0.0 == -0.0, but the fold would yield -0.0 - which is fine -
  // whereas +0.0 would be wrong for A == -0.0 without nsz).
  if (N1CFP && N1CFP->isZero()) {
    if (!N1CFP->isNegative() || Options.NoSignedZerosFPMath ||
        Flags.hasNoSignedZeros()) {
      return N0;
    }
  }

  if (N0 == N1) {
    // (fsub x, x) -> 0.0
    // Only valid with nnan: NaN - NaN is NaN, not 0.0.
    if (Options.NoNaNsFPMath || Flags.hasNoNaNs())
      return DAG.getConstantFP(0.0f, DL, VT);
  }

  // (fsub -0.0, N1) -> -N1
  if (N0CFP && N0CFP->isZero()) {
    if (N0CFP->isNegative() ||
        (Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())) {
      // We cannot replace an FSUB(+-0.0,X) with FNEG(X) when denormals are
      // flushed to zero, unless all users treat denorms as zero (DAZ).
      // FIXME: This transform will change the sign of a NaN and the behavior
      // of a signaling NaN. It is only valid when a NoNaN flag is present.
      DenormalMode DenormMode = DAG.getDenormalMode(VT);
      if (DenormMode == DenormalMode::getIEEE()) {
        if (SDValue NegN1 =
                TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize))
          return NegN1;
        if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))
          return DAG.getNode(ISD::FNEG, DL, VT, N1);
      }
    }
  }

  // These cancellation folds need reassoc+nsz (or the unsafe-math globals).
  if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
       (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
      N1.getOpcode() == ISD::FADD) {
    // X - (X + Y) -> -Y
    if (N0 == N1->getOperand(0))
      return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(1));
    // X - (Y + X) -> -Y
    if (N0 == N1->getOperand(1))
      return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(0));
  }

  // fold (fsub A, (fneg B)) -> (fadd A, B)
  if (SDValue NegN1 =
          TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize))
    return DAG.getNode(ISD::FADD, DL, VT, N0, NegN1);

  // FSUB -> FMA combines:
  if (SDValue Fused = visitFSUBForFMACombine(N)) {
    AddToWorklist(Fused.getNode());
    return Fused;
  }

  return SDValue();
}
14884 
/// Combine an FMUL node. Tries, in order: generic FP-binop simplification,
/// constant folding, constant-to-RHS canonicalization, vector-op folds,
/// select hoisting, reassociation folds (fast-math only), special constants
/// (2.0 / -1.0), double-negation elimination, select-based fabs patterns,
/// and finally FMA formation.
SDValue DAGCombiner::visitFMUL(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);
  const TargetOptions &Options = DAG.getTarget().Options;
  const SDNodeFlags Flags = N->getFlags();
  // Nodes created while this is in scope inherit N's fast-math flags.
  SelectionDAG::FlagInserter FlagsInserter(DAG, N);

  if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
    return R;

  // fold (fmul c1, c2) -> c1*c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::FMUL, DL, VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS
  if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
     !DAG.isConstantFPBuildVectorOrConstantFP(N1))
    return DAG.getNode(ISD::FMUL, DL, VT, N1, N0);

  // fold vector ops
  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // Reassociation-dependent folds: only legal under global unsafe math or
  // the 'reassoc' flag on this node.
  if (Options.UnsafeFPMath || Flags.hasAllowReassociation()) {
    // fmul (fmul X, C1), C2 -> fmul X, C1 * C2
    if (DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
        N0.getOpcode() == ISD::FMUL) {
      SDValue N00 = N0.getOperand(0);
      SDValue N01 = N0.getOperand(1);
      // Avoid an infinite loop by making sure that N00 is not a constant
      // (the inner multiply has not been constant folded yet).
      if (DAG.isConstantFPBuildVectorOrConstantFP(N01) &&
          !DAG.isConstantFPBuildVectorOrConstantFP(N00)) {
        SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, N01, N1);
        return DAG.getNode(ISD::FMUL, DL, VT, N00, MulConsts);
      }
    }

    // Match a special-case: we convert X * 2.0 into fadd.
    // fmul (fadd X, X), C -> fmul X, 2.0 * C
    if (N0.getOpcode() == ISD::FADD && N0.hasOneUse() &&
        N0.getOperand(0) == N0.getOperand(1)) {
      const SDValue Two = DAG.getConstantFP(2.0, DL, VT);
      SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, Two, N1);
      return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), MulConsts);
    }
  }

  // fold (fmul X, 2.0) -> (fadd X, X)
  if (N1CFP && N1CFP->isExactlyValue(+2.0))
    return DAG.getNode(ISD::FADD, DL, VT, N0, N0);

  // fold (fmul X, -1.0) -> (fsub -0.0, X)
  if (N1CFP && N1CFP->isExactlyValue(-1.0)) {
    if (!LegalOperations || TLI.isOperationLegal(ISD::FSUB, VT)) {
      return DAG.getNode(ISD::FSUB, DL, VT,
                         DAG.getConstantFP(-0.0, DL, VT), N0, Flags);
    }
  }

  // -N0 * -N1 --> N0 * N1
  // Only profitable when at least one of the negations is strictly cheaper
  // than the original operand; otherwise we could loop with the inverse fold.
  TargetLowering::NegatibleCost CostN0 =
      TargetLowering::NegatibleCost::Expensive;
  TargetLowering::NegatibleCost CostN1 =
      TargetLowering::NegatibleCost::Expensive;
  SDValue NegN0 =
      TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
  SDValue NegN1 =
      TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
  if (NegN0 && NegN1 &&
      (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
       CostN1 == TargetLowering::NegatibleCost::Cheaper))
    return DAG.getNode(ISD::FMUL, DL, VT, NegN0, NegN1);

  // fold (fmul X, (select (fcmp X > 0.0), -1.0, 1.0)) -> (fneg (fabs X))
  // fold (fmul X, (select (fcmp X > 0.0), 1.0, -1.0)) -> (fabs X)
  // Requires nnan/nsz because the select-based form and fabs differ on
  // NaN inputs and on the sign of zero.
  if (Flags.hasNoNaNs() && Flags.hasNoSignedZeros() &&
      (N0.getOpcode() == ISD::SELECT || N1.getOpcode() == ISD::SELECT) &&
      TLI.isOperationLegal(ISD::FABS, VT)) {
    SDValue Select = N0, X = N1;
    if (Select.getOpcode() != ISD::SELECT)
      std::swap(Select, X);

    SDValue Cond = Select.getOperand(0);
    auto TrueOpnd  = dyn_cast<ConstantFPSDNode>(Select.getOperand(1));
    auto FalseOpnd = dyn_cast<ConstantFPSDNode>(Select.getOperand(2));

    // Match: select (setcc X, 0.0, cc), C1, C2 where C1/C2 are +/-1.0.
    if (TrueOpnd && FalseOpnd &&
        Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == X &&
        isa<ConstantFPSDNode>(Cond.getOperand(1)) &&
        cast<ConstantFPSDNode>(Cond.getOperand(1))->isExactlyValue(0.0)) {
      ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
      switch (CC) {
      default: break;
      // For "less than" conditions, swap the select arms so both directions
      // are handled by the "greater than" logic below.
      case ISD::SETOLT:
      case ISD::SETULT:
      case ISD::SETOLE:
      case ISD::SETULE:
      case ISD::SETLT:
      case ISD::SETLE:
        std::swap(TrueOpnd, FalseOpnd);
        LLVM_FALLTHROUGH;
      case ISD::SETOGT:
      case ISD::SETUGT:
      case ISD::SETOGE:
      case ISD::SETUGE:
      case ISD::SETGT:
      case ISD::SETGE:
        if (TrueOpnd->isExactlyValue(-1.0) && FalseOpnd->isExactlyValue(1.0) &&
            TLI.isOperationLegal(ISD::FNEG, VT))
          return DAG.getNode(ISD::FNEG, DL, VT,
                   DAG.getNode(ISD::FABS, DL, VT, X));
        if (TrueOpnd->isExactlyValue(1.0) && FalseOpnd->isExactlyValue(-1.0))
          return DAG.getNode(ISD::FABS, DL, VT, X);

        break;
      }
    }
  }

  // FMUL -> FMA combines:
  if (SDValue Fused = visitFMULForFMADistributiveCombine(N)) {
    AddToWorklist(Fused.getNode());
    return Fused;
  }

  return SDValue();
}
15020 
/// Combine an FMA node (N0 * N1 + N2). Tries constant folding, operand
/// canonicalization, reassociation-based simplifications (fast-math only),
/// special multiplier constants (+/-1.0), and finally folding the whole
/// expression under an FNEG when that is cheaper.
SDValue DAGCombiner::visitFMA(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue N2 = N->getOperand(2);
  ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
  ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);
  const TargetOptions &Options = DAG.getTarget().Options;
  // FMA nodes have flags that propagate to the created nodes.
  SelectionDAG::FlagInserter FlagsInserter(DAG, N);

  bool CanReassociate =
      Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();

  // Constant fold FMA.
  // getNode() performs the actual fold when all three operands are constants.
  if (isa<ConstantFPSDNode>(N0) &&
      isa<ConstantFPSDNode>(N1) &&
      isa<ConstantFPSDNode>(N2)) {
    return DAG.getNode(ISD::FMA, DL, VT, N0, N1, N2);
  }

  // (-N0 * -N1) + N2 --> (N0 * N1) + N2
  // Only when at least one negation is strictly cheaper, to avoid cycling.
  TargetLowering::NegatibleCost CostN0 =
      TargetLowering::NegatibleCost::Expensive;
  TargetLowering::NegatibleCost CostN1 =
      TargetLowering::NegatibleCost::Expensive;
  SDValue NegN0 =
      TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
  SDValue NegN1 =
      TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
  if (NegN0 && NegN1 &&
      (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
       CostN1 == TargetLowering::NegatibleCost::Cheaper))
    return DAG.getNode(ISD::FMA, DL, VT, NegN0, NegN1, N2);

  // FIXME: use fast math flags instead of Options.UnsafeFPMath
  // 0 * X + Y --> Y is not safe in general (0 * Inf = NaN, signed zeros),
  // hence the unsafe-math gate.
  if (Options.UnsafeFPMath) {
    if (N0CFP && N0CFP->isZero())
      return N2;
    if (N1CFP && N1CFP->isZero())
      return N2;
  }

  // fma (1.0, X, Y) -> fadd X, Y / fma (X, 1.0, Y) -> fadd X, Y
  if (N0CFP && N0CFP->isExactlyValue(1.0))
    return DAG.getNode(ISD::FADD, SDLoc(N), VT, N1, N2);
  if (N1CFP && N1CFP->isExactlyValue(1.0))
    return DAG.getNode(ISD::FADD, SDLoc(N), VT, N0, N2);

  // Canonicalize (fma c, x, y) -> (fma x, c, y)
  if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
     !DAG.isConstantFPBuildVectorOrConstantFP(N1))
    return DAG.getNode(ISD::FMA, SDLoc(N), VT, N1, N0, N2);

  if (CanReassociate) {
    // (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2)
    if (N2.getOpcode() == ISD::FMUL && N0 == N2.getOperand(0) &&
        DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
        DAG.isConstantFPBuildVectorOrConstantFP(N2.getOperand(1))) {
      return DAG.getNode(ISD::FMUL, DL, VT, N0,
                         DAG.getNode(ISD::FADD, DL, VT, N1, N2.getOperand(1)));
    }

    // (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y)
    if (N0.getOpcode() == ISD::FMUL &&
        DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
        DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) {
      return DAG.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
                         DAG.getNode(ISD::FMUL, DL, VT, N1, N0.getOperand(1)),
                         N2);
    }
  }

  // (fma x, 1, y) -> (fadd x, y)
  // (fma x, -1, y) -> (fadd (fneg x), y)
  if (N1CFP) {
    if (N1CFP->isExactlyValue(1.0))
      return DAG.getNode(ISD::FADD, DL, VT, N0, N2);

    if (N1CFP->isExactlyValue(-1.0) &&
        (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))) {
      SDValue RHSNeg = DAG.getNode(ISD::FNEG, DL, VT, N0);
      AddToWorklist(RHSNeg.getNode());
      return DAG.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
    }

    // fma (fneg x), K, y -> fma x, -K, y
    if (N0.getOpcode() == ISD::FNEG &&
        (TLI.isOperationLegal(ISD::ConstantFP, VT) ||
         (N1.hasOneUse() && !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT,
                                              ForCodeSize)))) {
      return DAG.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
                         DAG.getNode(ISD::FNEG, DL, VT, N1), N2);
    }
  }

  if (CanReassociate) {
    // (fma x, c, x) -> (fmul x, (c+1))
    if (N1CFP && N0 == N2) {
      return DAG.getNode(
          ISD::FMUL, DL, VT, N0,
          DAG.getNode(ISD::FADD, DL, VT, N1, DAG.getConstantFP(1.0, DL, VT)));
    }

    // (fma x, c, (fneg x)) -> (fmul x, (c-1))
    if (N1CFP && N2.getOpcode() == ISD::FNEG && N2.getOperand(0) == N0) {
      return DAG.getNode(
          ISD::FMUL, DL, VT, N0,
          DAG.getNode(ISD::FADD, DL, VT, N1, DAG.getConstantFP(-1.0, DL, VT)));
    }
  }

  // fold ((fma (fneg X), Y, (fneg Z)) -> fneg (fma X, Y, Z))
  // fold ((fma X, (fneg Y), (fneg Z)) -> fneg (fma X, Y, Z))
  // Skip when FNEG is free: the fold would not save anything then.
  if (!TLI.isFNegFree(VT))
    if (SDValue Neg = TLI.getCheaperNegatedExpression(
            SDValue(N, 0), DAG, LegalOperations, ForCodeSize))
      return DAG.getNode(ISD::FNEG, DL, VT, Neg);
  return SDValue();
}
15140 
15141 // Combine multiple FDIVs with the same divisor into multiple FMULs by the
15142 // reciprocal.
15143 // E.g., (a / D; b / D;) -> (recip = 1.0 / D; a * recip; b * recip)
15144 // Notice that this is not always beneficial. One reason is different targets
15145 // may have different costs for FDIV and FMUL, so sometimes the cost of two
15146 // FDIVs may be lower than the cost of one FDIV and two FMULs. Another reason
15147 // is the critical path is increased from "one FDIV" to "one FDIV + one FMUL".
/// Replace multiple FDIVs sharing the divisor N->getOperand(1) with one
/// reciprocal (1.0 / divisor) plus one FMUL per dividend. Requires unsafe
/// math or the 'arcp' flag, and enough divisor uses to beat the extra FDIV
/// per the target's combineRepeatedFPDivisors() threshold. Rewrites users
/// in place via CombineTo and returns SDValue(N, 0) on success.
SDValue DAGCombiner::combineRepeatedFPDivisors(SDNode *N) {
  // TODO: Limit this transform based on optsize/minsize - it always creates at
  //       least 1 extra instruction. But the perf win may be substantial enough
  //       that only minsize should restrict this.
  bool UnsafeMath = DAG.getTarget().Options.UnsafeFPMath;
  const SDNodeFlags Flags = N->getFlags();
  if (LegalDAG || (!UnsafeMath && !Flags.hasAllowReciprocal()))
    return SDValue();

  // Skip if current node is a reciprocal/fneg-reciprocal.
  SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
  ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, /* AllowUndefs */ true);
  if (N0CFP && (N0CFP->isExactlyValue(1.0) || N0CFP->isExactlyValue(-1.0)))
    return SDValue();

  // Exit early if the target does not want this transform or if there can't
  // possibly be enough uses of the divisor to make the transform worthwhile.
  unsigned MinUses = TLI.combineRepeatedFPDivisors();

  // For splat vectors, scale the number of uses by the splat factor. If we can
  // convert the division into a scalar op, that will likely be much faster.
  unsigned NumElts = 1;
  EVT VT = N->getValueType(0);
  if (VT.isVector() && DAG.isSplatValue(N1))
    NumElts = VT.getVectorMinNumElements();

  // Cheap upper-bound check first: use_size() counts all uses, not just FDIVs.
  if (!MinUses || (N1->use_size() * NumElts) < MinUses)
    return SDValue();

  // Find all FDIV users of the same divisor.
  // Use a set because duplicates may be present in the user list.
  SetVector<SDNode *> Users;
  for (auto *U : N1->uses()) {
    if (U->getOpcode() == ISD::FDIV && U->getOperand(1) == N1) {
      // Skip X/sqrt(X) that has not been simplified to sqrt(X) yet.
      if (U->getOperand(1).getOpcode() == ISD::FSQRT &&
          U->getOperand(0) == U->getOperand(1).getOperand(0) &&
          U->getFlags().hasAllowReassociation() &&
          U->getFlags().hasNoSignedZeros())
        continue;

      // This division is eligible for optimization only if global unsafe math
      // is enabled or if this division allows reciprocal formation.
      if (UnsafeMath || U->getFlags().hasAllowReciprocal())
        Users.insert(U);
    }
  }

  // Now that we have the actual number of divisor uses, make sure it meets
  // the minimum threshold specified by the target.
  if ((Users.size() * NumElts) < MinUses)
    return SDValue();

  SDLoc DL(N);
  SDValue FPOne = DAG.getConstantFP(1.0, DL, VT);
  SDValue Reciprocal = DAG.getNode(ISD::FDIV, DL, VT, FPOne, N1, Flags);

  // Dividend / Divisor -> Dividend * Reciprocal
  for (auto *U : Users) {
    SDValue Dividend = U->getOperand(0);
    if (Dividend != FPOne) {
      SDValue NewNode = DAG.getNode(ISD::FMUL, SDLoc(U), VT, Dividend,
                                    Reciprocal, Flags);
      CombineTo(U, NewNode);
    } else if (U != Reciprocal.getNode()) {
      // In the absence of fast-math-flags, this user node is always the
      // same node as Reciprocal, but with FMF they may be different nodes.
      CombineTo(U, Reciprocal);
    }
  }
  return SDValue(N, 0);  // N was replaced.
}
15220 
/// Combine an FDIV node. Tries generic simplification, constant folding,
/// vector folds, select hoisting, the repeated-divisor reciprocal rewrite,
/// reciprocal-by-constant, rsqrt/reciprocal estimate formation (fast-math
/// only), X/sqrt(X) -> sqrt(X), and double-negation elimination.
SDValue DAGCombiner::visitFDIV(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);
  const TargetOptions &Options = DAG.getTarget().Options;
  SDNodeFlags Flags = N->getFlags();
  // Nodes created while this is in scope inherit N's fast-math flags.
  SelectionDAG::FlagInserter FlagsInserter(DAG, N);

  if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
    return R;

  // fold (fdiv c1, c2) -> c1/c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::FDIV, DL, VT, {N0, N1}))
    return C;

  // fold vector ops
  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // Share one reciprocal among several divisions by the same divisor.
  if (SDValue V = combineRepeatedFPDivisors(N))
    return V;

  // Reciprocal-based rewrites: only legal under unsafe math or 'arcp'.
  if (Options.UnsafeFPMath || Flags.hasAllowReciprocal()) {
    // fold (fdiv X, c2) -> fmul X, 1/c2 if losing precision is acceptable.
    if (auto *N1CFP = dyn_cast<ConstantFPSDNode>(N1)) {
      // Compute the reciprocal 1.0 / c2.
      const APFloat &N1APF = N1CFP->getValueAPF();
      APFloat Recip(N1APF.getSemantics(), 1); // 1.0
      APFloat::opStatus st = Recip.divide(N1APF, APFloat::rmNearestTiesToEven);
      // Only do the transform if the reciprocal is a legal fp immediate that
      // isn't too nasty (eg NaN, denormal, ...).
      if ((st == APFloat::opOK || st == APFloat::opInexact) && // Not too nasty
          (!LegalOperations ||
           // FIXME: custom lowering of ConstantFP might fail (see e.g. ARM
           // backend)... we should handle this gracefully after Legalize.
           // TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT) ||
           TLI.isOperationLegal(ISD::ConstantFP, VT) ||
           TLI.isFPImmLegal(Recip, VT, ForCodeSize)))
        return DAG.getNode(ISD::FMUL, DL, VT, N0,
                           DAG.getConstantFP(Recip, DL, VT));
    }

    // If this FDIV is part of a reciprocal square root, it may be folded
    // into a target-specific square root estimate instruction.
    if (N1.getOpcode() == ISD::FSQRT) {
      // X / sqrt(Y) --> X * rsqrt(Y)
      if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0), Flags))
        return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
    } else if (N1.getOpcode() == ISD::FP_EXTEND &&
               N1.getOperand(0).getOpcode() == ISD::FSQRT) {
      // X / fpext(sqrt(Y)) --> X * fpext(rsqrt(Y))
      if (SDValue RV =
              buildRsqrtEstimate(N1.getOperand(0).getOperand(0), Flags)) {
        RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(N1), VT, RV);
        AddToWorklist(RV.getNode());
        return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
      }
    } else if (N1.getOpcode() == ISD::FP_ROUND &&
               N1.getOperand(0).getOpcode() == ISD::FSQRT) {
      // X / fpround(sqrt(Y)) --> X * fpround(rsqrt(Y))
      if (SDValue RV =
              buildRsqrtEstimate(N1.getOperand(0).getOperand(0), Flags)) {
        RV = DAG.getNode(ISD::FP_ROUND, SDLoc(N1), VT, RV, N1.getOperand(1));
        AddToWorklist(RV.getNode());
        return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
      }
    } else if (N1.getOpcode() == ISD::FMUL) {
      // Look through an FMUL. Even though this won't remove the FDIV directly,
      // it's still worthwhile to get rid of the FSQRT if possible.
      SDValue Sqrt, Y;
      if (N1.getOperand(0).getOpcode() == ISD::FSQRT) {
        Sqrt = N1.getOperand(0);
        Y = N1.getOperand(1);
      } else if (N1.getOperand(1).getOpcode() == ISD::FSQRT) {
        Sqrt = N1.getOperand(1);
        Y = N1.getOperand(0);
      }
      if (Sqrt.getNode()) {
        // If the other multiply operand is known positive, pull it into the
        // sqrt. That will eliminate the division if we convert to an estimate.
        if (Flags.hasAllowReassociation() && N1.hasOneUse() &&
            N1->getFlags().hasAllowReassociation() && Sqrt.hasOneUse()) {
          SDValue A;
          if (Y.getOpcode() == ISD::FABS && Y.hasOneUse())
            A = Y.getOperand(0);
          else if (Y == Sqrt.getOperand(0))
            A = Y;
          if (A) {
            // X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z)
            // X / (A * sqrt(A))       --> X / sqrt(A*A*A) --> X * rsqrt(A*A*A)
            SDValue AA = DAG.getNode(ISD::FMUL, DL, VT, A, A);
            SDValue AAZ =
                DAG.getNode(ISD::FMUL, DL, VT, AA, Sqrt.getOperand(0));
            if (SDValue Rsqrt = buildRsqrtEstimate(AAZ, Flags))
              return DAG.getNode(ISD::FMUL, DL, VT, N0, Rsqrt);

            // Estimate creation failed. Clean up speculatively created nodes.
            recursivelyDeleteUnusedNodes(AAZ.getNode());
          }
        }

        // We found a FSQRT, so try to make this fold:
        // X / (Y * sqrt(Z)) -> X * (rsqrt(Z) / Y)
        if (SDValue Rsqrt = buildRsqrtEstimate(Sqrt.getOperand(0), Flags)) {
          SDValue Div = DAG.getNode(ISD::FDIV, SDLoc(N1), VT, Rsqrt, Y);
          AddToWorklist(Div.getNode());
          return DAG.getNode(ISD::FMUL, DL, VT, N0, Div);
        }
      }
    }

    // Fold into a reciprocal estimate and multiply instead of a real divide.
    if (Options.NoInfsFPMath || Flags.hasNoInfs())
      if (SDValue RV = BuildDivEstimate(N0, N1, Flags))
        return RV;
  }

  // Fold X/Sqrt(X) -> Sqrt(X)
  if ((Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros()) &&
      (Options.UnsafeFPMath || Flags.hasAllowReassociation()))
    if (N1.getOpcode() == ISD::FSQRT && N0 == N1.getOperand(0))
      return N1;

  // (fdiv (fneg X), (fneg Y)) -> (fdiv X, Y)
  // Only profitable if at least one negation is strictly cheaper.
  TargetLowering::NegatibleCost CostN0 =
      TargetLowering::NegatibleCost::Expensive;
  TargetLowering::NegatibleCost CostN1 =
      TargetLowering::NegatibleCost::Expensive;
  SDValue NegN0 =
      TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
  SDValue NegN1 =
      TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
  if (NegN0 && NegN1 &&
      (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
       CostN1 == TargetLowering::NegatibleCost::Cheaper))
    return DAG.getNode(ISD::FDIV, SDLoc(N), VT, NegN0, NegN1);

  return SDValue();
}
15362 
visitFREM(SDNode * N)15363 SDValue DAGCombiner::visitFREM(SDNode *N) {
15364   SDValue N0 = N->getOperand(0);
15365   SDValue N1 = N->getOperand(1);
15366   EVT VT = N->getValueType(0);
15367   SDNodeFlags Flags = N->getFlags();
15368   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
15369 
15370   if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
15371     return R;
15372 
15373   // fold (frem c1, c2) -> fmod(c1,c2)
15374   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FREM, SDLoc(N), VT, {N0, N1}))
15375     return C;
15376 
15377   if (SDValue NewSel = foldBinOpIntoSelect(N))
15378     return NewSel;
15379 
15380   return SDValue();
15381 }
15382 
visitFSQRT(SDNode * N)15383 SDValue DAGCombiner::visitFSQRT(SDNode *N) {
15384   SDNodeFlags Flags = N->getFlags();
15385   const TargetOptions &Options = DAG.getTarget().Options;
15386 
15387   // Require 'ninf' flag since sqrt(+Inf) = +Inf, but the estimation goes as:
15388   // sqrt(+Inf) == rsqrt(+Inf) * +Inf = 0 * +Inf = NaN
15389   if (!Flags.hasApproximateFuncs() ||
15390       (!Options.NoInfsFPMath && !Flags.hasNoInfs()))
15391     return SDValue();
15392 
15393   SDValue N0 = N->getOperand(0);
15394   if (TLI.isFsqrtCheap(N0, DAG))
15395     return SDValue();
15396 
15397   // FSQRT nodes have flags that propagate to the created nodes.
15398   // TODO: If this is N0/sqrt(N0), and we reach this node before trying to
15399   //       transform the fdiv, we may produce a sub-optimal estimate sequence
15400   //       because the reciprocal calculation may not have to filter out a
15401   //       0.0 input.
15402   return buildSqrtEstimate(N0, Flags);
15403 }
15404 
15405 /// copysign(x, fp_extend(y)) -> copysign(x, y)
15406 /// copysign(x, fp_round(y)) -> copysign(x, y)
CanCombineFCOPYSIGN_EXTEND_ROUND(SDNode * N)15407 static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(SDNode *N) {
15408   SDValue N1 = N->getOperand(1);
15409   if ((N1.getOpcode() == ISD::FP_EXTEND ||
15410        N1.getOpcode() == ISD::FP_ROUND)) {
15411     EVT N1VT = N1->getValueType(0);
15412     EVT N1Op0VT = N1->getOperand(0).getValueType();
15413 
15414     // Always fold no-op FP casts.
15415     if (N1VT == N1Op0VT)
15416       return true;
15417 
15418     // Do not optimize out type conversion of f128 type yet.
15419     // For some targets like x86_64, configuration is changed to keep one f128
15420     // value in one SSE register, but instruction selection cannot handle
15421     // FCOPYSIGN on SSE registers yet.
15422     if (N1Op0VT == MVT::f128)
15423       return false;
15424 
15425     // Avoid mismatched vector operand types, for better instruction selection.
15426     if (N1Op0VT.isVector())
15427       return false;
15428 
15429     return true;
15430   }
15431   return false;
15432 }
15433 
/// Combine an FCOPYSIGN node: constant folding, converting a constant sign
/// operand into fabs/fneg(fabs), and looking through operations on either
/// operand that do not change the relevant sign information.
SDValue DAGCombiner::visitFCOPYSIGN(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);

  // fold (fcopysign c1, c2) -> fcopysign(c1,c2)
  if (SDValue C =
          DAG.FoldConstantArithmetic(ISD::FCOPYSIGN, SDLoc(N), VT, {N0, N1}))
    return C;

  if (ConstantFPSDNode *N1C = isConstOrConstSplatFP(N->getOperand(1))) {
    const APFloat &V = N1C->getValueAPF();
    // copysign(x, c1) -> fabs(x)       iff ispos(c1)
    // copysign(x, c1) -> fneg(fabs(x)) iff isneg(c1)
    if (!V.isNegative()) {
      if (!LegalOperations || TLI.isOperationLegal(ISD::FABS, VT))
        return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0);
    } else {
      if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))
        return DAG.getNode(ISD::FNEG, SDLoc(N), VT,
                           DAG.getNode(ISD::FABS, SDLoc(N0), VT, N0));
    }
  }

  // The sign of N0 is irrelevant, so strip anything that only changes it:
  // copysign(fabs(x), y) -> copysign(x, y)
  // copysign(fneg(x), y) -> copysign(x, y)
  // copysign(copysign(x,z), y) -> copysign(x, y)
  if (N0.getOpcode() == ISD::FABS || N0.getOpcode() == ISD::FNEG ||
      N0.getOpcode() == ISD::FCOPYSIGN)
    return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0.getOperand(0), N1);

  // copysign(x, abs(y)) -> abs(x)
  if (N1.getOpcode() == ISD::FABS)
    return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0);

  // copysign(x, copysign(y,z)) -> copysign(x, z)
  if (N1.getOpcode() == ISD::FCOPYSIGN)
    return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0, N1.getOperand(1));

  // copysign(x, fp_extend(y)) -> copysign(x, y)
  // copysign(x, fp_round(y)) -> copysign(x, y)
  if (CanCombineFCOPYSIGN_EXTEND_ROUND(N))
    return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0, N1.getOperand(0));

  return SDValue();
}
15480 
/// Combine an FPOW node with a constant exponent: x ** (1/3) becomes a
/// cbrt() libcall, and x ** 0.25 / x ** 0.75 become sqrt sequences, all
/// gated on the appropriate fast-math flags.
SDValue DAGCombiner::visitFPOW(SDNode *N) {
  ConstantFPSDNode *ExponentC = isConstOrConstSplatFP(N->getOperand(1));
  if (!ExponentC)
    return SDValue();
  // Nodes created below inherit N's fast-math flags.
  SelectionDAG::FlagInserter FlagsInserter(DAG, N);

  // Try to convert x ** (1/3) into cube root.
  // TODO: Handle the various flavors of long double.
  // TODO: Since we're approximating, we don't need an exact 1/3 exponent.
  //       Some range near 1/3 should be fine.
  EVT VT = N->getValueType(0);
  if ((VT == MVT::f32 && ExponentC->getValueAPF().isExactlyValue(1.0f/3.0f)) ||
      (VT == MVT::f64 && ExponentC->getValueAPF().isExactlyValue(1.0/3.0))) {
    // pow(-0.0, 1/3) = +0.0; cbrt(-0.0) = -0.0.
    // pow(-inf, 1/3) = +inf; cbrt(-inf) = -inf.
    // pow(-val, 1/3) =  nan; cbrt(-val) = -num.
    // For regular numbers, rounding may cause the results to differ.
    // Therefore, we require { nsz ninf nnan afn } for this transform.
    // TODO: We could select out the special cases if we don't have nsz/ninf.
    SDNodeFlags Flags = N->getFlags();
    if (!Flags.hasNoSignedZeros() || !Flags.hasNoInfs() || !Flags.hasNoNaNs() ||
        !Flags.hasApproximateFuncs())
      return SDValue();

    // Do not create a cbrt() libcall if the target does not have it, and do not
    // turn a pow that has lowering support into a cbrt() libcall.
    if (!DAG.getLibInfo().has(LibFunc_cbrt) ||
        (!DAG.getTargetLoweringInfo().isOperationExpand(ISD::FPOW, VT) &&
         DAG.getTargetLoweringInfo().isOperationExpand(ISD::FCBRT, VT)))
      return SDValue();

    return DAG.getNode(ISD::FCBRT, SDLoc(N), VT, N->getOperand(0));
  }

  // Try to convert x ** (1/4) and x ** (3/4) into square roots.
  // x ** (1/2) is canonicalized to sqrt, so we do not bother with that case.
  // TODO: This could be extended (using a target hook) to handle smaller
  // power-of-2 fractional exponents.
  bool ExponentIs025 = ExponentC->getValueAPF().isExactlyValue(0.25);
  bool ExponentIs075 = ExponentC->getValueAPF().isExactlyValue(0.75);
  if (ExponentIs025 || ExponentIs075) {
    // pow(-0.0, 0.25) = +0.0; sqrt(sqrt(-0.0)) = -0.0.
    // pow(-inf, 0.25) = +inf; sqrt(sqrt(-inf)) =  NaN.
    // pow(-0.0, 0.75) = +0.0; sqrt(-0.0) * sqrt(sqrt(-0.0)) = +0.0.
    // pow(-inf, 0.75) = +inf; sqrt(-inf) * sqrt(sqrt(-inf)) =  NaN.
    // For regular numbers, rounding may cause the results to differ.
    // Therefore, we require { nsz ninf afn } for this transform.
    // TODO: We could select out the special cases if we don't have nsz/ninf.
    SDNodeFlags Flags = N->getFlags();

    // We only need no signed zeros for the 0.25 case.
    if ((!Flags.hasNoSignedZeros() && ExponentIs025) || !Flags.hasNoInfs() ||
        !Flags.hasApproximateFuncs())
      return SDValue();

    // Don't double the number of libcalls. We are trying to inline fast code.
    if (!DAG.getTargetLoweringInfo().isOperationLegalOrCustom(ISD::FSQRT, VT))
      return SDValue();

    // Assume that libcalls are the smallest code.
    // TODO: This restriction should probably be lifted for vectors.
    if (ForCodeSize)
      return SDValue();

    // pow(X, 0.25) --> sqrt(sqrt(X))
    SDLoc DL(N);
    SDValue Sqrt = DAG.getNode(ISD::FSQRT, DL, VT, N->getOperand(0));
    SDValue SqrtSqrt = DAG.getNode(ISD::FSQRT, DL, VT, Sqrt);
    if (ExponentIs025)
      return SqrtSqrt;
    // pow(X, 0.75) --> sqrt(X) * sqrt(sqrt(X))
    return DAG.getNode(ISD::FMUL, DL, VT, Sqrt, SqrtSqrt);
  }

  return SDValue();
}
15557 
foldFPToIntToFP(SDNode * N,SelectionDAG & DAG,const TargetLowering & TLI)15558 static SDValue foldFPToIntToFP(SDNode *N, SelectionDAG &DAG,
15559                                const TargetLowering &TLI) {
15560   // We only do this if the target has legal ftrunc. Otherwise, we'd likely be
15561   // replacing casts with a libcall. We also must be allowed to ignore -0.0
15562   // because FTRUNC will return -0.0 for (-1.0, -0.0), but using integer
15563   // conversions would return +0.0.
15564   // FIXME: We should be able to use node-level FMF here.
15565   // TODO: If strict math, should we use FABS (+ range check for signed cast)?
15566   EVT VT = N->getValueType(0);
15567   if (!TLI.isOperationLegal(ISD::FTRUNC, VT) ||
15568       !DAG.getTarget().Options.NoSignedZerosFPMath)
15569     return SDValue();
15570 
15571   // fptosi/fptoui round towards zero, so converting from FP to integer and
15572   // back is the same as an 'ftrunc': [us]itofp (fpto[us]i X) --> ftrunc X
15573   SDValue N0 = N->getOperand(0);
15574   if (N->getOpcode() == ISD::SINT_TO_FP && N0.getOpcode() == ISD::FP_TO_SINT &&
15575       N0.getOperand(0).getValueType() == VT)
15576     return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0.getOperand(0));
15577 
15578   if (N->getOpcode() == ISD::UINT_TO_FP && N0.getOpcode() == ISD::FP_TO_UINT &&
15579       N0.getOperand(0).getValueType() == VT)
15580     return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0.getOperand(0));
15581 
15582   return SDValue();
15583 }
15584 
/// Combine a SINT_TO_FP node: fold undef/constant inputs, switch to
/// UINT_TO_FP when the sign bit is known zero and that is the legal form,
/// convert boolean setcc inputs into FP selects, and recognize
/// integer round trips as ftrunc.
SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  EVT OpVT = N0.getValueType();

  // [us]itofp(undef) = 0, because the result value is bounded.
  if (N0.isUndef())
    return DAG.getConstantFP(0.0, SDLoc(N), VT);

  // fold (sint_to_fp c1) -> c1fp
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      // ...but only if the target supports immediate floating-point values
      (!LegalOperations ||
       TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
    return DAG.getNode(ISD::SINT_TO_FP, SDLoc(N), VT, N0);

  // If the input is a legal type, and SINT_TO_FP is not legal on this target,
  // but UINT_TO_FP is legal on this target, try to convert.
  if (!hasOperation(ISD::SINT_TO_FP, OpVT) &&
      hasOperation(ISD::UINT_TO_FP, OpVT)) {
    // If the sign bit is known to be zero, we can change this to UINT_TO_FP.
    if (DAG.SignBitIsZero(N0))
      return DAG.getNode(ISD::UINT_TO_FP, SDLoc(N), VT, N0);
  }

  // The next optimizations are desirable only if SELECT_CC can be lowered.
  // fold (sint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), -1.0, 0.0)
  // An i1 setcc sign-extends to -1 or 0, hence the -1.0 true value.
  if (N0.getOpcode() == ISD::SETCC && N0.getValueType() == MVT::i1 &&
      !VT.isVector() &&
      (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
    SDLoc DL(N);
    return DAG.getSelect(DL, VT, N0, DAG.getConstantFP(-1.0, DL, VT),
                         DAG.getConstantFP(0.0, DL, VT));
  }

  // fold (sint_to_fp (zext (setcc x, y, cc))) ->
  //      (select (setcc x, y, cc), 1.0, 0.0)
  // The zext makes the boolean 1 or 0, hence the +1.0 true value here.
  if (N0.getOpcode() == ISD::ZERO_EXTEND &&
      N0.getOperand(0).getOpcode() == ISD::SETCC && !VT.isVector() &&
      (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
    SDLoc DL(N);
    return DAG.getSelect(DL, VT, N0.getOperand(0),
                         DAG.getConstantFP(1.0, DL, VT),
                         DAG.getConstantFP(0.0, DL, VT));
  }

  // sitofp (fptosi X) --> ftrunc X (when -0.0 can be ignored).
  if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI))
    return FTrunc;

  return SDValue();
}
15636 
visitUINT_TO_FP(SDNode * N)15637 SDValue DAGCombiner::visitUINT_TO_FP(SDNode *N) {
15638   SDValue N0 = N->getOperand(0);
15639   EVT VT = N->getValueType(0);
15640   EVT OpVT = N0.getValueType();
15641 
15642   // [us]itofp(undef) = 0, because the result value is bounded.
15643   if (N0.isUndef())
15644     return DAG.getConstantFP(0.0, SDLoc(N), VT);
15645 
15646   // fold (uint_to_fp c1) -> c1fp
15647   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
15648       // ...but only if the target supports immediate floating-point values
15649       (!LegalOperations ||
15650        TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
15651     return DAG.getNode(ISD::UINT_TO_FP, SDLoc(N), VT, N0);
15652 
15653   // If the input is a legal type, and UINT_TO_FP is not legal on this target,
15654   // but SINT_TO_FP is legal on this target, try to convert.
15655   if (!hasOperation(ISD::UINT_TO_FP, OpVT) &&
15656       hasOperation(ISD::SINT_TO_FP, OpVT)) {
15657     // If the sign bit is known to be zero, we can change this to SINT_TO_FP.
15658     if (DAG.SignBitIsZero(N0))
15659       return DAG.getNode(ISD::SINT_TO_FP, SDLoc(N), VT, N0);
15660   }
15661 
15662   // fold (uint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), 1.0, 0.0)
15663   if (N0.getOpcode() == ISD::SETCC && !VT.isVector() &&
15664       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
15665     SDLoc DL(N);
15666     return DAG.getSelect(DL, VT, N0, DAG.getConstantFP(1.0, DL, VT),
15667                          DAG.getConstantFP(0.0, DL, VT));
15668   }
15669 
15670   if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI))
15671     return FTrunc;
15672 
15673   return SDValue();
15674 }
15675 
// Fold (fp_to_{s/u}int ({s/u}int_to_fpx)) -> zext x, sext x, trunc x, or x
static SDValue FoldIntToFPToInt(SDNode *N, SelectionDAG &DAG) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  if (N0.getOpcode() != ISD::UINT_TO_FP && N0.getOpcode() != ISD::SINT_TO_FP)
    return SDValue();

  // The original integer value that went through the FP round trip.
  SDValue Src = N0.getOperand(0);
  EVT SrcVT = Src.getValueType();
  bool IsInputSigned = N0.getOpcode() == ISD::SINT_TO_FP;
  bool IsOutputSigned = N->getOpcode() == ISD::FP_TO_SINT;

  // We can safely assume the conversion won't overflow the output range,
  // because (for example) (uint8_t)18293.f is undefined behavior.

  // Since we can assume the conversion won't overflow, our decision as to
  // whether the input will fit in the float should depend on the minimum
  // of the input range and output range.

  // This means this is also safe for a signed input and unsigned output, since
  // a negative input would lead to undefined behavior.
  // A signed input has one fewer magnitude bit than its bit width, hence
  // the subtraction of IsInputSigned below.
  unsigned InputSize = (int)SrcVT.getScalarSizeInBits() - IsInputSigned;
  unsigned OutputSize = (int)VT.getScalarSizeInBits();
  unsigned ActualSize = std::min(InputSize, OutputSize);
  const fltSemantics &sem = DAG.EVTToAPFloatSemantics(N0.getValueType());

  // We can only fold away the float conversion if the input range can be
  // represented exactly in the float range.
  if (APFloat::semanticsPrecision(sem) >= ActualSize) {
    if (VT.getScalarSizeInBits() > SrcVT.getScalarSizeInBits()) {
      // Sign-extend only when both sides are signed; any unsigned endpoint
      // means the value is known non-negative, so zero-extension suffices.
      unsigned ExtOp = IsInputSigned && IsOutputSigned ? ISD::SIGN_EXTEND
                                                       : ISD::ZERO_EXTEND;
      return DAG.getNode(ExtOp, SDLoc(N), VT, Src);
    }
    if (VT.getScalarSizeInBits() < SrcVT.getScalarSizeInBits())
      return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Src);
    // Same width: the round trip is an integer identity.
    return DAG.getBitcast(VT, Src);
  }
  return SDValue();
}
15717 
visitFP_TO_SINT(SDNode * N)15718 SDValue DAGCombiner::visitFP_TO_SINT(SDNode *N) {
15719   SDValue N0 = N->getOperand(0);
15720   EVT VT = N->getValueType(0);
15721 
15722   // fold (fp_to_sint undef) -> undef
15723   if (N0.isUndef())
15724     return DAG.getUNDEF(VT);
15725 
15726   // fold (fp_to_sint c1fp) -> c1
15727   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
15728     return DAG.getNode(ISD::FP_TO_SINT, SDLoc(N), VT, N0);
15729 
15730   return FoldIntToFPToInt(N, DAG);
15731 }
15732 
visitFP_TO_UINT(SDNode * N)15733 SDValue DAGCombiner::visitFP_TO_UINT(SDNode *N) {
15734   SDValue N0 = N->getOperand(0);
15735   EVT VT = N->getValueType(0);
15736 
15737   // fold (fp_to_uint undef) -> undef
15738   if (N0.isUndef())
15739     return DAG.getUNDEF(VT);
15740 
15741   // fold (fp_to_uint c1fp) -> c1
15742   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
15743     return DAG.getNode(ISD::FP_TO_UINT, SDLoc(N), VT, N0);
15744 
15745   return FoldIntToFPToInt(N, DAG);
15746 }
15747 
visitFP_ROUND(SDNode * N)15748 SDValue DAGCombiner::visitFP_ROUND(SDNode *N) {
15749   SDValue N0 = N->getOperand(0);
15750   SDValue N1 = N->getOperand(1);
15751   ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
15752   EVT VT = N->getValueType(0);
15753 
15754   // fold (fp_round c1fp) -> c1fp
15755   if (N0CFP)
15756     return DAG.getNode(ISD::FP_ROUND, SDLoc(N), VT, N0, N1);
15757 
15758   // fold (fp_round (fp_extend x)) -> x
15759   if (N0.getOpcode() == ISD::FP_EXTEND && VT == N0.getOperand(0).getValueType())
15760     return N0.getOperand(0);
15761 
15762   // fold (fp_round (fp_round x)) -> (fp_round x)
15763   if (N0.getOpcode() == ISD::FP_ROUND) {
15764     const bool NIsTrunc = N->getConstantOperandVal(1) == 1;
15765     const bool N0IsTrunc = N0.getConstantOperandVal(1) == 1;
15766 
15767     // Skip this folding if it results in an fp_round from f80 to f16.
15768     //
15769     // f80 to f16 always generates an expensive (and as yet, unimplemented)
15770     // libcall to __truncxfhf2 instead of selecting native f16 conversion
15771     // instructions from f32 or f64.  Moreover, the first (value-preserving)
15772     // fp_round from f80 to either f32 or f64 may become a NOP in platforms like
15773     // x86.
15774     if (N0.getOperand(0).getValueType() == MVT::f80 && VT == MVT::f16)
15775       return SDValue();
15776 
15777     // If the first fp_round isn't a value preserving truncation, it might
15778     // introduce a tie in the second fp_round, that wouldn't occur in the
15779     // single-step fp_round we want to fold to.
15780     // In other words, double rounding isn't the same as rounding.
15781     // Also, this is a value preserving truncation iff both fp_round's are.
15782     if (DAG.getTarget().Options.UnsafeFPMath || N0IsTrunc) {
15783       SDLoc DL(N);
15784       return DAG.getNode(ISD::FP_ROUND, DL, VT, N0.getOperand(0),
15785                          DAG.getIntPtrConstant(NIsTrunc && N0IsTrunc, DL));
15786     }
15787   }
15788 
15789   // fold (fp_round (copysign X, Y)) -> (copysign (fp_round X), Y)
15790   if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse()) {
15791     SDValue Tmp = DAG.getNode(ISD::FP_ROUND, SDLoc(N0), VT,
15792                               N0.getOperand(0), N1);
15793     AddToWorklist(Tmp.getNode());
15794     return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT,
15795                        Tmp, N0.getOperand(1));
15796   }
15797 
15798   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
15799     return NewVSel;
15800 
15801   return SDValue();
15802 }
15803 
/// Combine FP_EXTEND: constant folding, looking through redundant
/// rounds/extends, and widening extending loads.
SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  // If this is fp_round(fpextend), don't fold it, allow ourselves to be folded.
  if (N->hasOneUse() &&
      N->use_begin()->getOpcode() == ISD::FP_ROUND)
    return SDValue();

  // fold (fp_extend c1fp) -> c1fp
  if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
    return DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, N0);

  // fold (fp_extend (fp16_to_fp op)) -> (fp16_to_fp op)
  if (N0.getOpcode() == ISD::FP16_TO_FP &&
      TLI.getOperationAction(ISD::FP16_TO_FP, VT) == TargetLowering::Legal)
    return DAG.getNode(ISD::FP16_TO_FP, SDLoc(N), VT, N0.getOperand(0));

  // Turn fp_extend(fp_round(X, 1)) -> x since the fp_round doesn't affect the
  // value of X.
  if (N0.getOpcode() == ISD::FP_ROUND
      && N0.getConstantOperandVal(1) == 1) {
    SDValue In = N0.getOperand(0);
    // Same type: the round/extend pair is a no-op.
    if (In.getValueType() == VT) return In;
    // Result is still narrower than X: keep a (smaller) fp_round.
    if (VT.bitsLT(In.getValueType()))
      return DAG.getNode(ISD::FP_ROUND, SDLoc(N), VT,
                         In, N0.getOperand(1));
    // Result is wider than X: extend X directly.
    return DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, In);
  }

  // fold (fpext (load x)) -> (fpext (fptrunc (extload x)))
  if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
      TLI.isLoadExtLegalOrCustom(ISD::EXTLOAD, VT, N0.getValueType())) {
    LoadSDNode *LN0 = cast<LoadSDNode>(N0);
    SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, SDLoc(N), VT,
                                     LN0->getChain(),
                                     LN0->getBasePtr(), N0.getValueType(),
                                     LN0->getMemOperand());
    // Replace this extend with the extending load, and rewrite any other
    // users of the original narrow load as a round of the wide loaded value;
    // the chain users are redirected to the new load's chain result.
    CombineTo(N, ExtLoad);
    CombineTo(N0.getNode(),
              DAG.getNode(ISD::FP_ROUND, SDLoc(N0),
                          N0.getValueType(), ExtLoad,
                          DAG.getIntPtrConstant(1, SDLoc(N0))),
              ExtLoad.getValue(1));
    return SDValue(N, 0);   // Return N so it doesn't get rechecked!
  }

  if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
    return NewVSel;

  return SDValue();
}
15856 
visitFCEIL(SDNode * N)15857 SDValue DAGCombiner::visitFCEIL(SDNode *N) {
15858   SDValue N0 = N->getOperand(0);
15859   EVT VT = N->getValueType(0);
15860 
15861   // fold (fceil c1) -> fceil(c1)
15862   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
15863     return DAG.getNode(ISD::FCEIL, SDLoc(N), VT, N0);
15864 
15865   return SDValue();
15866 }
15867 
visitFTRUNC(SDNode * N)15868 SDValue DAGCombiner::visitFTRUNC(SDNode *N) {
15869   SDValue N0 = N->getOperand(0);
15870   EVT VT = N->getValueType(0);
15871 
15872   // fold (ftrunc c1) -> ftrunc(c1)
15873   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
15874     return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0);
15875 
15876   // fold ftrunc (known rounded int x) -> x
15877   // ftrunc is a part of fptosi/fptoui expansion on some targets, so this is
15878   // likely to be generated to extract integer from a rounded floating value.
15879   switch (N0.getOpcode()) {
15880   default: break;
15881   case ISD::FRINT:
15882   case ISD::FTRUNC:
15883   case ISD::FNEARBYINT:
15884   case ISD::FFLOOR:
15885   case ISD::FCEIL:
15886     return N0;
15887   }
15888 
15889   return SDValue();
15890 }
15891 
visitFFLOOR(SDNode * N)15892 SDValue DAGCombiner::visitFFLOOR(SDNode *N) {
15893   SDValue N0 = N->getOperand(0);
15894   EVT VT = N->getValueType(0);
15895 
15896   // fold (ffloor c1) -> ffloor(c1)
15897   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
15898     return DAG.getNode(ISD::FFLOOR, SDLoc(N), VT, N0);
15899 
15900   return SDValue();
15901 }
15902 
/// Combine FNEG: constant fold, use the target's negated-expression
/// machinery, and fold the negation of a subtract.
SDValue DAGCombiner::visitFNEG(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  // Propagate this node's fast-math flags onto any nodes created below.
  SelectionDAG::FlagInserter FlagsInserter(DAG, N);

  // Constant fold FNEG.
  if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
    return DAG.getNode(ISD::FNEG, SDLoc(N), VT, N0);

  // Ask TargetLowering whether the operand can be negated more cheaply than
  // emitting an explicit fneg.
  if (SDValue NegN0 =
          TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize))
    return NegN0;

  // -(X-Y) -> (Y-X) is unsafe because when X==Y, -0.0 != +0.0
  // FIXME: This is duplicated in getNegatibleCost, but getNegatibleCost doesn't
  // know it was called from a context with a nsz flag if the input fsub does
  // not.
  if (N0.getOpcode() == ISD::FSUB &&
      (DAG.getTarget().Options.NoSignedZerosFPMath ||
       N->getFlags().hasNoSignedZeros()) && N0.hasOneUse()) {
    return DAG.getNode(ISD::FSUB, SDLoc(N), VT, N0.getOperand(1),
                       N0.getOperand(0));
  }

  if (SDValue Cast = foldSignChangeInBitcast(N))
    return Cast;

  return SDValue();
}
15932 
/// Combine FMINNUM/FMAXNUM/FMINIMUM/FMAXIMUM: constant folding,
/// canonicalization, and folds against NaN and infinity operands.
SDValue DAGCombiner::visitFMinMax(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  const SDNodeFlags Flags = N->getFlags();
  unsigned Opc = N->getOpcode();
  // FMINIMUM/FMAXIMUM propagate a NaN operand to the result; FMINNUM/FMAXNUM
  // return the other (non-NaN) operand instead.
  bool PropagatesNaN = Opc == ISD::FMINIMUM || Opc == ISD::FMAXIMUM;
  bool IsMin = Opc == ISD::FMINNUM || Opc == ISD::FMINIMUM;
  SelectionDAG::FlagInserter FlagsInserter(DAG, N);

  // Constant fold.
  if (SDValue C = DAG.FoldConstantArithmetic(Opc, SDLoc(N), VT, {N0, N1}))
    return C;

  // Canonicalize to constant on RHS.
  if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
      !DAG.isConstantFPBuildVectorOrConstantFP(N1))
    return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0);

  // The folds below need a constant (or constant splat) on the RHS.
  if (const ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1)) {
    const APFloat &AF = N1CFP->getValueAPF();

    // minnum(X, nan) -> X
    // maxnum(X, nan) -> X
    // minimum(X, nan) -> nan
    // maximum(X, nan) -> nan
    if (AF.isNaN())
      return PropagatesNaN ? N->getOperand(1) : N->getOperand(0);

    // In the following folds, inf can be replaced with the largest finite
    // float, if the ninf flag is set.
    if (AF.isInfinity() || (Flags.hasNoInfs() && AF.isLargest())) {
      // minnum(X, -inf) -> -inf
      // maxnum(X, +inf) -> +inf
      // minimum(X, -inf) -> -inf if nnan
      // maximum(X, +inf) -> +inf if nnan
      // (The nnan requirement for minimum/maximum is because a NaN X would
      // otherwise have to propagate.)
      if (IsMin == AF.isNegative() && (!PropagatesNaN || Flags.hasNoNaNs()))
        return N->getOperand(1);

      // minnum(X, +inf) -> X if nnan
      // maxnum(X, -inf) -> X if nnan
      // minimum(X, +inf) -> X
      // maximum(X, -inf) -> X
      if (IsMin != AF.isNegative() && (PropagatesNaN || Flags.hasNoNaNs()))
        return N->getOperand(0);
    }
  }

  return SDValue();
}
15983 
visitFABS(SDNode * N)15984 SDValue DAGCombiner::visitFABS(SDNode *N) {
15985   SDValue N0 = N->getOperand(0);
15986   EVT VT = N->getValueType(0);
15987 
15988   // fold (fabs c1) -> fabs(c1)
15989   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
15990     return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0);
15991 
15992   // fold (fabs (fabs x)) -> (fabs x)
15993   if (N0.getOpcode() == ISD::FABS)
15994     return N->getOperand(0);
15995 
15996   // fold (fabs (fneg x)) -> (fabs x)
15997   // fold (fabs (fcopysign x, y)) -> (fabs x)
15998   if (N0.getOpcode() == ISD::FNEG || N0.getOpcode() == ISD::FCOPYSIGN)
15999     return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0.getOperand(0));
16000 
16001   if (SDValue Cast = foldSignChangeInBitcast(N))
16002     return Cast;
16003 
16004   return SDValue();
16005 }
16006 
/// Combine BRCOND: look through FREEZE on the condition, form BR_CC from a
/// SETCC condition, and try to rebuild a SETCC from srl/xor patterns.
SDValue DAGCombiner::visitBRCOND(SDNode *N) {
  SDValue Chain = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue N2 = N->getOperand(2);

  // BRCOND(FREEZE(cond)) is equivalent to BRCOND(cond) (both are
  // nondeterministic jumps).
  if (N1->getOpcode() == ISD::FREEZE && N1.hasOneUse()) {
    return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other, Chain,
                       N1->getOperand(0), N2);
  }

  // If N is a constant we could fold this into a fallthrough or unconditional
  // branch. However that doesn't happen very often in normal code, because
  // Instcombine/SimplifyCFG should have handled the available opportunities.
  // If we did this folding here, it would be necessary to update the
  // MachineBasicBlock CFG, which is awkward.

  // fold a brcond with a setcc condition into a BR_CC node if BR_CC is legal
  // on the target.
  if (N1.getOpcode() == ISD::SETCC &&
      TLI.isOperationLegalOrCustom(ISD::BR_CC,
                                   N1.getOperand(0).getValueType())) {
    return DAG.getNode(ISD::BR_CC, SDLoc(N), MVT::Other,
                       Chain, N1.getOperand(2),
                       N1.getOperand(0), N1.getOperand(1), N2);
  }

  if (N1.hasOneUse()) {
    // rebuildSetCC calls visitXor which may change the Chain when there is a
    // STRICT_FSETCC/STRICT_FSETCCS involved. Use a handle to track changes.
    HandleSDNode ChainHandle(Chain);
    if (SDValue NewN1 = rebuildSetCC(N1))
      return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other,
                         ChainHandle.getValue(), NewN1, N2);
  }

  return SDValue();
}
16046 
/// Try to turn the branch condition \p N into an equivalent SETCC so a
/// BRCOND user can later be matched as a compare-and-branch.
SDValue DAGCombiner::rebuildSetCC(SDValue N) {
  if (N.getOpcode() == ISD::SRL ||
      (N.getOpcode() == ISD::TRUNCATE &&
       (N.getOperand(0).hasOneUse() &&
        N.getOperand(0).getOpcode() == ISD::SRL))) {
    // Look past the truncate.
    if (N.getOpcode() == ISD::TRUNCATE)
      N = N.getOperand(0);

    // Match this pattern so that we can generate simpler code:
    //
    //   %a = ...
    //   %b = and i32 %a, 2
    //   %c = srl i32 %b, 1
    //   brcond i32 %c ...
    //
    // into
    //
    //   %a = ...
    //   %b = and i32 %a, 2
    //   %c = setcc eq %b, 0
    //   brcond %c ...
    //
    // This applies only when the AND constant value has one bit set and the
    // SRL constant is equal to the log2 of the AND constant. The back-end is
    // smart enough to convert the result into a TEST/JMP sequence.
    SDValue Op0 = N.getOperand(0);
    SDValue Op1 = N.getOperand(1);

    if (Op0.getOpcode() == ISD::AND && Op1.getOpcode() == ISD::Constant) {
      SDValue AndOp1 = Op0.getOperand(1);

      if (AndOp1.getOpcode() == ISD::Constant) {
        const APInt &AndConst = cast<ConstantSDNode>(AndOp1)->getAPIntValue();

        if (AndConst.isPowerOf2() &&
            cast<ConstantSDNode>(Op1)->getAPIntValue() == AndConst.logBase2()) {
          SDLoc DL(N);
          return DAG.getSetCC(DL, getSetCCResultType(Op0.getValueType()),
                              Op0, DAG.getConstant(0, DL, Op0.getValueType()),
                              ISD::SETNE);
        }
      }
    }
  }

  // Transform (brcond (xor x, y)) -> (brcond (setcc, x, y, ne))
  // Transform (brcond (xor (xor x, y), -1)) -> (brcond (setcc, x, y, eq))
  if (N.getOpcode() == ISD::XOR) {
    // Because we may call this on a speculatively constructed
    // SimplifiedSetCC Node, we need to simplify this node first.
    // Ideally this should be folded into SimplifySetCC and not
    // here. For now, grab a handle to N so we don't lose it from
    // replacements internal to the visit.
    HandleSDNode XORHandle(N);
    while (N.getOpcode() == ISD::XOR) {
      SDValue Tmp = visitXOR(N.getNode());
      // No simplification done.
      if (!Tmp.getNode())
        break;
      // Returning N is a form of in-visit replacement that may have
      // invalidated N. Grab the (possibly updated) value from the handle.
      if (Tmp.getNode() == N.getNode())
        N = XORHandle.getValue();
      else // Node simplified. Try simplifying again.
        N = Tmp;
    }

    // visitXOR may have replaced the XOR with something else entirely; if so,
    // just hand that back as the new condition.
    if (N.getOpcode() != ISD::XOR)
      return N;

    SDValue Op0 = N->getOperand(0);
    SDValue Op1 = N->getOperand(1);

    if (Op0.getOpcode() != ISD::SETCC && Op1.getOpcode() != ISD::SETCC) {
      bool Equal = false;
      // (brcond (xor (xor x, y), -1)) -> (brcond (setcc x, y, eq))
      // Look through the outer NOT and compare the inner XOR's operands.
      if (isBitwiseNot(N) && Op0.hasOneUse() && Op0.getOpcode() == ISD::XOR &&
          Op0.getValueType() == MVT::i1) {
        N = Op0;
        Op0 = N->getOperand(0);
        Op1 = N->getOperand(1);
        Equal = true;
      }

      EVT SetCCVT = N.getValueType();
      if (LegalTypes)
        SetCCVT = getSetCCResultType(SetCCVT);
      // Replace the uses of XOR with SETCC
      return DAG.getSetCC(SDLoc(N), SetCCVT, Op0, Op1,
                          Equal ? ISD::SETEQ : ISD::SETNE);
    }
  }

  return SDValue();
}
16143 
16144 // Operand List for BR_CC: Chain, CondCC, CondLHS, CondRHS, DestBB.
16145 //
visitBR_CC(SDNode * N)16146 SDValue DAGCombiner::visitBR_CC(SDNode *N) {
16147   CondCodeSDNode *CC = cast<CondCodeSDNode>(N->getOperand(1));
16148   SDValue CondLHS = N->getOperand(2), CondRHS = N->getOperand(3);
16149 
16150   // If N is a constant we could fold this into a fallthrough or unconditional
16151   // branch. However that doesn't happen very often in normal code, because
16152   // Instcombine/SimplifyCFG should have handled the available opportunities.
16153   // If we did this folding here, it would be necessary to update the
16154   // MachineBasicBlock CFG, which is awkward.
16155 
16156   // Use SimplifySetCC to simplify SETCC's.
16157   SDValue Simp = SimplifySetCC(getSetCCResultType(CondLHS.getValueType()),
16158                                CondLHS, CondRHS, CC->get(), SDLoc(N),
16159                                false);
16160   if (Simp.getNode()) AddToWorklist(Simp.getNode());
16161 
16162   // fold to a simpler setcc
16163   if (Simp.getNode() && Simp.getOpcode() == ISD::SETCC)
16164     return DAG.getNode(ISD::BR_CC, SDLoc(N), MVT::Other,
16165                        N->getOperand(0), Simp.getOperand(2),
16166                        Simp.getOperand(0), Simp.getOperand(1),
16167                        N->getOperand(4));
16168 
16169   return SDValue();
16170 }
16171 
getCombineLoadStoreParts(SDNode * N,unsigned Inc,unsigned Dec,bool & IsLoad,bool & IsMasked,SDValue & Ptr,const TargetLowering & TLI)16172 static bool getCombineLoadStoreParts(SDNode *N, unsigned Inc, unsigned Dec,
16173                                      bool &IsLoad, bool &IsMasked, SDValue &Ptr,
16174                                      const TargetLowering &TLI) {
16175   if (LoadSDNode *LD = dyn_cast<LoadSDNode>(N)) {
16176     if (LD->isIndexed())
16177       return false;
16178     EVT VT = LD->getMemoryVT();
16179     if (!TLI.isIndexedLoadLegal(Inc, VT) && !TLI.isIndexedLoadLegal(Dec, VT))
16180       return false;
16181     Ptr = LD->getBasePtr();
16182   } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(N)) {
16183     if (ST->isIndexed())
16184       return false;
16185     EVT VT = ST->getMemoryVT();
16186     if (!TLI.isIndexedStoreLegal(Inc, VT) && !TLI.isIndexedStoreLegal(Dec, VT))
16187       return false;
16188     Ptr = ST->getBasePtr();
16189     IsLoad = false;
16190   } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(N)) {
16191     if (LD->isIndexed())
16192       return false;
16193     EVT VT = LD->getMemoryVT();
16194     if (!TLI.isIndexedMaskedLoadLegal(Inc, VT) &&
16195         !TLI.isIndexedMaskedLoadLegal(Dec, VT))
16196       return false;
16197     Ptr = LD->getBasePtr();
16198     IsMasked = true;
16199   } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(N)) {
16200     if (ST->isIndexed())
16201       return false;
16202     EVT VT = ST->getMemoryVT();
16203     if (!TLI.isIndexedMaskedStoreLegal(Inc, VT) &&
16204         !TLI.isIndexedMaskedStoreLegal(Dec, VT))
16205       return false;
16206     Ptr = ST->getBasePtr();
16207     IsLoad = false;
16208     IsMasked = true;
16209   } else {
16210     return false;
16211   }
16212   return true;
16213 }
16214 
16215 /// Try turning a load/store into a pre-indexed load/store when the base
16216 /// pointer is an add or subtract and it has other uses besides the load/store.
16217 /// After the transformation, the new indexed load/store has effectively folded
16218 /// the add/subtract in and all of its other uses are redirected to the
16219 /// new load/store.
CombineToPreIndexedLoadStore(SDNode * N)16220 bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) {
16221   if (Level < AfterLegalizeDAG)
16222     return false;
16223 
16224   bool IsLoad = true;
16225   bool IsMasked = false;
16226   SDValue Ptr;
16227   if (!getCombineLoadStoreParts(N, ISD::PRE_INC, ISD::PRE_DEC, IsLoad, IsMasked,
16228                                 Ptr, TLI))
16229     return false;
16230 
16231   // If the pointer is not an add/sub, or if it doesn't have multiple uses, bail
16232   // out.  There is no reason to make this a preinc/predec.
16233   if ((Ptr.getOpcode() != ISD::ADD && Ptr.getOpcode() != ISD::SUB) ||
16234       Ptr->hasOneUse())
16235     return false;
16236 
16237   // Ask the target to do addressing mode selection.
16238   SDValue BasePtr;
16239   SDValue Offset;
16240   ISD::MemIndexedMode AM = ISD::UNINDEXED;
16241   if (!TLI.getPreIndexedAddressParts(N, BasePtr, Offset, AM, DAG))
16242     return false;
16243 
16244   // Backends without true r+i pre-indexed forms may need to pass a
16245   // constant base with a variable offset so that constant coercion
16246   // will work with the patterns in canonical form.
16247   bool Swapped = false;
16248   if (isa<ConstantSDNode>(BasePtr)) {
16249     std::swap(BasePtr, Offset);
16250     Swapped = true;
16251   }
16252 
16253   // Don't create a indexed load / store with zero offset.
16254   if (isNullConstant(Offset))
16255     return false;
16256 
16257   // Try turning it into a pre-indexed load / store except when:
16258   // 1) The new base ptr is a frame index.
16259   // 2) If N is a store and the new base ptr is either the same as or is a
16260   //    predecessor of the value being stored.
16261   // 3) Another use of old base ptr is a predecessor of N. If ptr is folded
16262   //    that would create a cycle.
16263   // 4) All uses are load / store ops that use it as old base ptr.
16264 
16265   // Check #1.  Preinc'ing a frame index would require copying the stack pointer
16266   // (plus the implicit offset) to a register to preinc anyway.
16267   if (isa<FrameIndexSDNode>(BasePtr) || isa<RegisterSDNode>(BasePtr))
16268     return false;
16269 
16270   // Check #2.
16271   if (!IsLoad) {
16272     SDValue Val = IsMasked ? cast<MaskedStoreSDNode>(N)->getValue()
16273                            : cast<StoreSDNode>(N)->getValue();
16274 
16275     // Would require a copy.
16276     if (Val == BasePtr)
16277       return false;
16278 
16279     // Would create a cycle.
16280     if (Val == Ptr || Ptr->isPredecessorOf(Val.getNode()))
16281       return false;
16282   }
16283 
16284   // Caches for hasPredecessorHelper.
16285   SmallPtrSet<const SDNode *, 32> Visited;
16286   SmallVector<const SDNode *, 16> Worklist;
16287   Worklist.push_back(N);
16288 
16289   // If the offset is a constant, there may be other adds of constants that
16290   // can be folded with this one. We should do this to avoid having to keep
16291   // a copy of the original base pointer.
16292   SmallVector<SDNode *, 16> OtherUses;
16293   if (isa<ConstantSDNode>(Offset))
16294     for (SDNode::use_iterator UI = BasePtr->use_begin(),
16295                               UE = BasePtr->use_end();
16296          UI != UE; ++UI) {
16297       SDUse &Use = UI.getUse();
16298       // Skip the use that is Ptr and uses of other results from BasePtr's
16299       // node (important for nodes that return multiple results).
16300       if (Use.getUser() == Ptr.getNode() || Use != BasePtr)
16301         continue;
16302 
16303       if (SDNode::hasPredecessorHelper(Use.getUser(), Visited, Worklist))
16304         continue;
16305 
16306       if (Use.getUser()->getOpcode() != ISD::ADD &&
16307           Use.getUser()->getOpcode() != ISD::SUB) {
16308         OtherUses.clear();
16309         break;
16310       }
16311 
16312       SDValue Op1 = Use.getUser()->getOperand((UI.getOperandNo() + 1) & 1);
16313       if (!isa<ConstantSDNode>(Op1)) {
16314         OtherUses.clear();
16315         break;
16316       }
16317 
16318       // FIXME: In some cases, we can be smarter about this.
16319       if (Op1.getValueType() != Offset.getValueType()) {
16320         OtherUses.clear();
16321         break;
16322       }
16323 
16324       OtherUses.push_back(Use.getUser());
16325     }
16326 
16327   if (Swapped)
16328     std::swap(BasePtr, Offset);
16329 
16330   // Now check for #3 and #4.
16331   bool RealUse = false;
16332 
16333   for (SDNode *Use : Ptr->uses()) {
16334     if (Use == N)
16335       continue;
16336     if (SDNode::hasPredecessorHelper(Use, Visited, Worklist))
16337       return false;
16338 
16339     // If Ptr may be folded in addressing mode of other use, then it's
16340     // not profitable to do this transformation.
16341     if (!canFoldInAddressingMode(Ptr.getNode(), Use, DAG, TLI))
16342       RealUse = true;
16343   }
16344 
16345   if (!RealUse)
16346     return false;
16347 
16348   SDValue Result;
16349   if (!IsMasked) {
16350     if (IsLoad)
16351       Result = DAG.getIndexedLoad(SDValue(N, 0), SDLoc(N), BasePtr, Offset, AM);
16352     else
16353       Result =
16354           DAG.getIndexedStore(SDValue(N, 0), SDLoc(N), BasePtr, Offset, AM);
16355   } else {
16356     if (IsLoad)
16357       Result = DAG.getIndexedMaskedLoad(SDValue(N, 0), SDLoc(N), BasePtr,
16358                                         Offset, AM);
16359     else
16360       Result = DAG.getIndexedMaskedStore(SDValue(N, 0), SDLoc(N), BasePtr,
16361                                          Offset, AM);
16362   }
16363   ++PreIndexedNodes;
16364   ++NodesCombined;
16365   LLVM_DEBUG(dbgs() << "\nReplacing.4 "; N->dump(&DAG); dbgs() << "\nWith: ";
16366              Result.dump(&DAG); dbgs() << '\n');
16367   WorklistRemover DeadNodes(*this);
16368   if (IsLoad) {
16369     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0));
16370     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Result.getValue(2));
16371   } else {
16372     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(1));
16373   }
16374 
16375   // Finally, since the node is now dead, remove it from the graph.
16376   deleteAndRecombine(N);
16377 
16378   if (Swapped)
16379     std::swap(BasePtr, Offset);
16380 
16381   // Replace other uses of BasePtr that can be updated to use Ptr
16382   for (unsigned i = 0, e = OtherUses.size(); i != e; ++i) {
16383     unsigned OffsetIdx = 1;
16384     if (OtherUses[i]->getOperand(OffsetIdx).getNode() == BasePtr.getNode())
16385       OffsetIdx = 0;
16386     assert(OtherUses[i]->getOperand(!OffsetIdx).getNode() ==
16387            BasePtr.getNode() && "Expected BasePtr operand");
16388 
16389     // We need to replace ptr0 in the following expression:
16390     //   x0 * offset0 + y0 * ptr0 = t0
16391     // knowing that
16392     //   x1 * offset1 + y1 * ptr0 = t1 (the indexed load/store)
16393     //
16394     // where x0, x1, y0 and y1 in {-1, 1} are given by the types of the
16395     // indexed load/store and the expression that needs to be re-written.
16396     //
16397     // Therefore, we have:
16398     //   t0 = (x0 * offset0 - x1 * y0 * y1 *offset1) + (y0 * y1) * t1
16399 
16400     auto *CN = cast<ConstantSDNode>(OtherUses[i]->getOperand(OffsetIdx));
16401     const APInt &Offset0 = CN->getAPIntValue();
16402     const APInt &Offset1 = cast<ConstantSDNode>(Offset)->getAPIntValue();
16403     int X0 = (OtherUses[i]->getOpcode() == ISD::SUB && OffsetIdx == 1) ? -1 : 1;
16404     int Y0 = (OtherUses[i]->getOpcode() == ISD::SUB && OffsetIdx == 0) ? -1 : 1;
16405     int X1 = (AM == ISD::PRE_DEC && !Swapped) ? -1 : 1;
16406     int Y1 = (AM == ISD::PRE_DEC && Swapped) ? -1 : 1;
16407 
16408     unsigned Opcode = (Y0 * Y1 < 0) ? ISD::SUB : ISD::ADD;
16409 
16410     APInt CNV = Offset0;
16411     if (X0 < 0) CNV = -CNV;
16412     if (X1 * Y0 * Y1 < 0) CNV = CNV + Offset1;
16413     else CNV = CNV - Offset1;
16414 
16415     SDLoc DL(OtherUses[i]);
16416 
16417     // We can now generate the new expression.
16418     SDValue NewOp1 = DAG.getConstant(CNV, DL, CN->getValueType(0));
16419     SDValue NewOp2 = Result.getValue(IsLoad ? 1 : 0);
16420 
16421     SDValue NewUse = DAG.getNode(Opcode,
16422                                  DL,
16423                                  OtherUses[i]->getValueType(0), NewOp1, NewOp2);
16424     DAG.ReplaceAllUsesOfValueWith(SDValue(OtherUses[i], 0), NewUse);
16425     deleteAndRecombine(OtherUses[i]);
16426   }
16427 
16428   // Replace the uses of Ptr with uses of the updated base value.
16429   DAG.ReplaceAllUsesOfValueWith(Ptr, Result.getValue(IsLoad ? 1 : 0));
16430   deleteAndRecombine(Ptr.getNode());
16431   AddToWorklist(Result.getNode());
16432 
16433   return true;
16434 }
16435 
/// Return true when it is profitable and safe to fold the pointer-arithmetic
/// node \p PtrUse (an ADD/SUB user of memory op \p N's pointer operand
/// \p Ptr) into \p N as a post-indexed access. On success the target hook
/// has filled in \p BasePtr, \p Offset and the addressing mode \p AM.
shouldCombineToPostInc(SDNode * N,SDValue Ptr,SDNode * PtrUse,SDValue & BasePtr,SDValue & Offset,ISD::MemIndexedMode & AM,SelectionDAG & DAG,const TargetLowering & TLI)16436 static bool shouldCombineToPostInc(SDNode *N, SDValue Ptr, SDNode *PtrUse,
16437                                    SDValue &BasePtr, SDValue &Offset,
16438                                    ISD::MemIndexedMode &AM,
16439                                    SelectionDAG &DAG,
16440                                    const TargetLowering &TLI) {
      // Only a plain ADD/SUB of the pointer (other than N itself) can become
      // the post-increment.
16441   if (PtrUse == N ||
16442       (PtrUse->getOpcode() != ISD::ADD && PtrUse->getOpcode() != ISD::SUB))
16443     return false;
16444 
      // Ask the target to split PtrUse into a base and an offset it supports
      // in a post-indexed form.
16445   if (!TLI.getPostIndexedAddressParts(N, PtrUse, BasePtr, Offset, AM, DAG))
16446     return false;
16447 
16448   // Don't create an indexed load / store with zero offset.
16449   if (isNullConstant(Offset))
16450     return false;
16451 
      // Frame indices and register nodes cannot be rewritten to the updated
      // pointer value.
16452   if (isa<FrameIndexSDNode>(BasePtr) || isa<RegisterSDNode>(BasePtr))
16453     return false;
16454 
16455   SmallPtrSet<const SDNode *, 32> Visited;
16456   for (SDNode *Use : BasePtr->uses()) {
16457     if (Use == Ptr.getNode())
16458       continue;
16459 
16460     // Don't combine if there's a later memory user which could perform the
          // index update instead.
16461     if (isa<MemSDNode>(Use)) {
16462       bool IsLoad = true;
16463       bool IsMasked = false;
16464       SDValue OtherPtr;
16465       if (getCombineLoadStoreParts(Use, ISD::POST_INC, ISD::POST_DEC, IsLoad,
16466                                    IsMasked, OtherPtr, TLI)) {
16467         SmallVector<const SDNode *, 2> Worklist;
16468         Worklist.push_back(Use);
          // If N is a predecessor of this post-indexable user, prefer letting
          // the later user absorb the increment.
16469         if (SDNode::hasPredecessorHelper(N, Visited, Worklist))
16470           return false;
16471       }
16472     }
16473 
16474     // If all the uses are load / store addresses, then don't do the
16475     // transformation.
16476     if (Use->getOpcode() == ISD::ADD || Use->getOpcode() == ISD::SUB) {
16477       for (SDNode *UseUse : Use->uses())
16478         if (canFoldInAddressingMode(Use, UseUse, DAG, TLI))
16479           return false;
16480     }
16481   }
16482   return true;
16483 }
16484 
/// Scan the users of \p N's pointer operand for an ADD/SUB that can be
/// folded into \p N as a post-indexed load/store. Returns the chosen
/// pointer-arithmetic node, or null when no legal candidate exists.
/// \p IsLoad / \p IsMasked describe \p N; \p Ptr, \p BasePtr, \p Offset and
/// \p AM are filled in on success.
getPostIndexedLoadStoreOp(SDNode * N,bool & IsLoad,bool & IsMasked,SDValue & Ptr,SDValue & BasePtr,SDValue & Offset,ISD::MemIndexedMode & AM,SelectionDAG & DAG,const TargetLowering & TLI)16485 static SDNode *getPostIndexedLoadStoreOp(SDNode *N, bool &IsLoad,
16486                                          bool &IsMasked, SDValue &Ptr,
16487                                          SDValue &BasePtr, SDValue &Offset,
16488                                          ISD::MemIndexedMode &AM,
16489                                          SelectionDAG &DAG,
16490                                          const TargetLowering &TLI) {
      // A pointer with a single use (the memory op itself) has no add/sub
      // user that could be folded.
16491   if (!getCombineLoadStoreParts(N, ISD::POST_INC, ISD::POST_DEC, IsLoad,
16492                                 IsMasked, Ptr, TLI) ||
16493       Ptr->hasOneUse())
16494     return nullptr;
16495 
16496   // Try turning it into a post-indexed load / store except when
16497   // 1) All uses are load / store ops that use it as base ptr (and
16498   //    it may be folded as addressing mode).
16499   // 2) Op must be independent of N, i.e. Op is neither a predecessor
16500   //    nor a successor of N. Otherwise, if Op is folded that would
16501   //    create a cycle.
16502   for (SDNode *Op : Ptr->uses()) {
16503     // Check for #1.
16504     if (!shouldCombineToPostInc(N, Ptr, Op, BasePtr, Offset, AM, DAG, TLI))
16505       continue;
16506 
16507     // Check for #2.
16508     SmallPtrSet<const SDNode *, 32> Visited;
16509     SmallVector<const SDNode *, 8> Worklist;
16510     // Ptr is predecessor to both N and Op.
16511     Visited.insert(Ptr.getNode());
16512     Worklist.push_back(N);
16513     Worklist.push_back(Op);
      // Neither node may reach the other through its operands, otherwise
      // folding would create a cycle in the DAG.
16514     if (!SDNode::hasPredecessorHelper(N, Visited, Worklist) &&
16515         !SDNode::hasPredecessorHelper(Op, Visited, Worklist))
16516       return Op;
16517   }
16518   return nullptr;
16519 }
16520 
16521 /// Try to combine a load/store with an add/sub of the base pointer node into a
16522 /// post-indexed load/store. The transformation folds the add/subtract into the
16523 /// new indexed load/store effectively and all of its uses are redirected to the
16524 /// new load/store.
CombineToPostIndexedLoadStore(SDNode * N)16525 bool DAGCombiner::CombineToPostIndexedLoadStore(SDNode *N) {
      // Post-indexed forms are only created once the DAG is fully legalized.
16526   if (Level < AfterLegalizeDAG)
16527     return false;
16528 
16529   bool IsLoad = true;
16530   bool IsMasked = false;
16531   SDValue Ptr;
16532   SDValue BasePtr;
16533   SDValue Offset;
16534   ISD::MemIndexedMode AM = ISD::UNINDEXED;
      // Find an ADD/SUB of the pointer that can legally become the increment.
16535   SDNode *Op = getPostIndexedLoadStoreOp(N, IsLoad, IsMasked, Ptr, BasePtr,
16536                                          Offset, AM, DAG, TLI);
16537   if (!Op)
16538     return false;
16539 
      // Build the replacement indexed (possibly masked) load/store.
16540   SDValue Result;
16541   if (!IsMasked)
16542     Result = IsLoad ? DAG.getIndexedLoad(SDValue(N, 0), SDLoc(N), BasePtr,
16543                                          Offset, AM)
16544                     : DAG.getIndexedStore(SDValue(N, 0), SDLoc(N),
16545                                           BasePtr, Offset, AM);
16546   else
16547     Result = IsLoad ? DAG.getIndexedMaskedLoad(SDValue(N, 0), SDLoc(N),
16548                                                BasePtr, Offset, AM)
16549                     : DAG.getIndexedMaskedStore(SDValue(N, 0), SDLoc(N),
16550                                                 BasePtr, Offset, AM);
16551   ++PostIndexedNodes;
16552   ++NodesCombined;
16553   LLVM_DEBUG(dbgs() << "\nReplacing.5 "; N->dump(&DAG); dbgs() << "\nWith: ";
16554              Result.dump(&DAG); dbgs() << '\n');
16555   WorklistRemover DeadNodes(*this);
      // An indexed load produces (value, new base, chain); an indexed store
      // produces (new base, chain). Rewire N's old results accordingly.
16556   if (IsLoad) {
16557     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0));
16558     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Result.getValue(2));
16559   } else {
16560     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(1));
16561   }
16562 
16563   // Finally, since the node is now dead, remove it from the graph.
16564   deleteAndRecombine(N);
16565 
16566   // Replace the uses of Op with uses of the updated base value.
16567   DAG.ReplaceAllUsesOfValueWith(SDValue(Op, 0),
16568                                 Result.getValue(IsLoad ? 1 : 0));
16569   deleteAndRecombine(Op);
16570   return true;
16571 }
16572 
16573 /// Return the base-pointer arithmetic from an indexed \p LD.
SplitIndexingFromLoad(LoadSDNode * LD)16574 SDValue DAGCombiner::SplitIndexingFromLoad(LoadSDNode *LD) {
16575   ISD::MemIndexedMode AM = LD->getAddressingMode();
16576   assert(AM != ISD::UNINDEXED);
16577   SDValue BP = LD->getOperand(1);
16578   SDValue Inc = LD->getOperand(2);
16579 
16580   // Some backends use TargetConstants for load offsets, but don't expect
16581   // TargetConstants in general ADD nodes. We can convert these constants into
16582   // regular Constants (if the constant is not opaque).
16583   assert((Inc.getOpcode() != ISD::TargetConstant ||
16584           !cast<ConstantSDNode>(Inc)->isOpaque()) &&
16585          "Cannot split out indexing using opaque target constants");
16586   if (Inc.getOpcode() == ISD::TargetConstant) {
16587     ConstantSDNode *ConstInc = cast<ConstantSDNode>(Inc);
16588     Inc = DAG.getConstant(*ConstInc->getConstantIntValue(), SDLoc(Inc),
16589                           ConstInc->getValueType(0));
16590   }
16591 
16592   unsigned Opc =
16593       (AM == ISD::PRE_INC || AM == ISD::POST_INC ? ISD::ADD : ISD::SUB);
16594   return DAG.getNode(Opc, SDLoc(LD), BP.getSimpleValueType(), BP, Inc);
16595 }
16596 
numVectorEltsOrZero(EVT T)16597 static inline ElementCount numVectorEltsOrZero(EVT T) {
16598   return T.isVector() ? T.getVectorElementCount() : ElementCount::getFixed(0);
16599 }
16600 
getTruncatedStoreValue(StoreSDNode * ST,SDValue & Val)16601 bool DAGCombiner::getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val) {
16602   Val = ST->getValue();
16603   EVT STType = Val.getValueType();
16604   EVT STMemType = ST->getMemoryVT();
16605   if (STType == STMemType)
16606     return true;
16607   if (isTypeLegal(STMemType))
16608     return false; // fail.
16609   if (STType.isFloatingPoint() && STMemType.isFloatingPoint() &&
16610       TLI.isOperationLegal(ISD::FTRUNC, STMemType)) {
16611     Val = DAG.getNode(ISD::FTRUNC, SDLoc(ST), STMemType, Val);
16612     return true;
16613   }
16614   if (numVectorEltsOrZero(STType) == numVectorEltsOrZero(STMemType) &&
16615       STType.isInteger() && STMemType.isInteger()) {
16616     Val = DAG.getNode(ISD::TRUNCATE, SDLoc(ST), STMemType, Val);
16617     return true;
16618   }
16619   if (STType.getSizeInBits() == STMemType.getSizeInBits()) {
16620     Val = DAG.getBitcast(STMemType, Val);
16621     return true;
16622   }
16623   return false; // fail.
16624 }
16625 
extendLoadedValueToExtension(LoadSDNode * LD,SDValue & Val)16626 bool DAGCombiner::extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val) {
16627   EVT LDMemType = LD->getMemoryVT();
16628   EVT LDType = LD->getValueType(0);
16629   assert(Val.getValueType() == LDMemType &&
16630          "Attempting to extend value of non-matching type");
16631   if (LDType == LDMemType)
16632     return true;
16633   if (LDMemType.isInteger() && LDType.isInteger()) {
16634     switch (LD->getExtensionType()) {
16635     case ISD::NON_EXTLOAD:
16636       Val = DAG.getBitcast(LDType, Val);
16637       return true;
16638     case ISD::EXTLOAD:
16639       Val = DAG.getNode(ISD::ANY_EXTEND, SDLoc(LD), LDType, Val);
16640       return true;
16641     case ISD::SEXTLOAD:
16642       Val = DAG.getNode(ISD::SIGN_EXTEND, SDLoc(LD), LDType, Val);
16643       return true;
16644     case ISD::ZEXTLOAD:
16645       Val = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(LD), LDType, Val);
16646       return true;
16647     }
16648   }
16649   return false;
16650 }
16651 
/// If the store producing \p LD's chain writes all the bytes that \p LD
/// reads, forward the stored value directly to the load's users (inserting
/// truncates, extensions, masks or bitcasts as required) so the memory round
/// trip is eliminated. Returns the replacement value, or an empty SDValue.
ForwardStoreValueToDirectLoad(LoadSDNode * LD)16652 SDValue DAGCombiner::ForwardStoreValueToDirectLoad(LoadSDNode *LD) {
16653   if (OptLevel == CodeGenOpt::None || !LD->isSimple())
16654     return SDValue();
16655   SDValue Chain = LD->getOperand(0);
16656   StoreSDNode *ST = dyn_cast<StoreSDNode>(Chain.getNode());
16657   // TODO: Relax this restriction for unordered atomics (see D66309)
16658   if (!ST || !ST->isSimple())
16659     return SDValue();
16660 
16661   EVT LDType = LD->getValueType(0);
16662   EVT LDMemType = LD->getMemoryVT();
16663   EVT STMemType = ST->getMemoryVT();
16664   EVT STType = ST->getValue().getValueType();
16665 
16666   // There are two cases to consider here:
16667   //  1. The store is fixed width and the load is scalable. In this case we
16668   //     don't know at compile time if the store completely envelops the load
16669   //     so we abandon the optimisation.
16670   //  2. The store is scalable and the load is fixed width. We could
16671   //     potentially support a limited number of cases here, but there has been
16672   //     no cost-benefit analysis to prove it's worth it.
16673   bool LdStScalable = LDMemType.isScalableVector();
16674   if (LdStScalable != STMemType.isScalableVector())
16675     return SDValue();
16676 
16677   // If we are dealing with scalable vectors on a big endian platform the
16678   // calculation of offsets below becomes trickier, since we do not know at
16679   // compile time the absolute size of the vector. Until we've done more
16680   // analysis on big-endian platforms it seems better to bail out for now.
16681   if (LdStScalable && DAG.getDataLayout().isBigEndian())
16682     return SDValue();
16683 
      // The load must address the same base as the store; Offset is the byte
      // distance between the two accesses.
16684   BaseIndexOffset BasePtrLD = BaseIndexOffset::match(LD, DAG);
16685   BaseIndexOffset BasePtrST = BaseIndexOffset::match(ST, DAG);
16686   int64_t Offset;
16687   if (!BasePtrST.equalBaseIndex(BasePtrLD, DAG, Offset))
16688     return SDValue();
16689 
16690   // Normalize for Endianness. After this Offset=0 will denote that the least
16691   // significant bit in the loaded value maps to the least significant bit in
16692   // the stored value). With Offset=n (for n > 0) the loaded value starts at the
16693   // n:th least significant byte of the stored value.
16694   if (DAG.getDataLayout().isBigEndian())
16695     Offset = ((int64_t)STMemType.getStoreSizeInBits().getFixedSize() -
16696               (int64_t)LDMemType.getStoreSizeInBits().getFixedSize()) /
16697                  8 -
16698              Offset;
16699 
16700   // Check that the stored value covers all bits that are loaded.
16701   bool STCoversLD;
16702 
16703   TypeSize LdMemSize = LDMemType.getSizeInBits();
16704   TypeSize StMemSize = STMemType.getSizeInBits();
16705   if (LdStScalable)
16706     STCoversLD = (Offset == 0) && LdMemSize == StMemSize;
16707   else
16708     STCoversLD = (Offset >= 0) && (Offset * 8 + LdMemSize.getFixedSize() <=
16709                                    StMemSize.getFixedSize());
16710 
      // Helper that performs the actual replacement; an indexed load must
      // first have its pointer arithmetic split back out as an explicit
      // add/sub so the updated-pointer result stays available.
16711   auto ReplaceLd = [&](LoadSDNode *LD, SDValue Val, SDValue Chain) -> SDValue {
16712     if (LD->isIndexed()) {
16713       // Cannot handle opaque target constants and we must respect the user's
16714       // request not to split indexes from loads.
16715       if (!canSplitIdx(LD))
16716         return SDValue();
16717       SDValue Idx = SplitIndexingFromLoad(LD);
16718       SDValue Ops[] = {Val, Idx, Chain};
16719       return CombineTo(LD, Ops, 3);
16720     }
16721     return CombineTo(LD, Val, Chain);
16722   };
16723 
16724   if (!STCoversLD)
16725     return SDValue();
16726 
16727   // Memory as copy space (potentially masked).
16728   if (Offset == 0 && LDType == STType && STMemType == LDMemType) {
16729     // Simple case: Direct non-truncating forwarding
16730     if (LDType.getSizeInBits() == LdMemSize)
16731       return ReplaceLd(LD, ST->getValue(), Chain);
16732     // Can we model the truncate and extension with an and mask?
16733     if (STType.isInteger() && LDMemType.isInteger() && !STType.isVector() &&
16734         !LDMemType.isVector() && LD->getExtensionType() != ISD::SEXTLOAD) {
16735       // Mask to size of LDMemType
16736       auto Mask =
16737           DAG.getConstant(APInt::getLowBitsSet(STType.getFixedSizeInBits(),
16738                                                StMemSize.getFixedSize()),
16739                           SDLoc(ST), STType);
16740       auto Val = DAG.getNode(ISD::AND, SDLoc(LD), LDType, ST->getValue(), Mask);
16741       return ReplaceLd(LD, Val, Chain);
16742     }
16743   }
16744 
16745   // TODO: Deal with nonzero offset.
16746   if (LD->getBasePtr().isUndef() || Offset != 0)
16747     return SDValue();
16748   // Model necessary truncations / extensions.
16749   SDValue Val;
16750   // Truncate Value To Stored Memory Size.
      // Single-iteration 'do' so that each failed legality check can
      // 'continue' straight to the shared cleanup below.
16751   do {
16752     if (!getTruncatedStoreValue(ST, Val))
16753       continue;
16754     if (!isTypeLegal(LDMemType))
16755       continue;
16756     if (STMemType != LDMemType) {
16757       // TODO: Support vectors? This requires extract_subvector/bitcast.
16758       if (!STMemType.isVector() && !LDMemType.isVector() &&
16759           STMemType.isInteger() && LDMemType.isInteger())
16760         Val = DAG.getNode(ISD::TRUNCATE, SDLoc(LD), LDMemType, Val);
16761       else
16762         continue;
16763     }
16764     if (!extendLoadedValueToExtension(LD, Val))
16765       continue;
16766     return ReplaceLd(LD, Val, Chain);
16767   } while (false);
16768 
16769   // On failure, cleanup dead nodes we may have created.
16770   if (Val->use_empty())
16771     deleteAndRecombine(Val.getNode());
16772   return SDValue();
16773 }
16774 
/// Main LOAD combine: deletes dead loads, forwards a directly preceding
/// store's value, refines alignment, improves the chain dependency, and
/// finally attempts indexed-load and load-slicing transforms.
visitLOAD(SDNode * N)16775 SDValue DAGCombiner::visitLOAD(SDNode *N) {
16776   LoadSDNode *LD  = cast<LoadSDNode>(N);
16777   SDValue Chain = LD->getChain();
16778   SDValue Ptr   = LD->getBasePtr();
16779 
16780   // If load is not volatile and there are no uses of the loaded value (and
16781   // the updated indexed value in case of indexed loads), change uses of the
16782   // chain value into uses of the chain input (i.e. delete the dead load).
16783   // TODO: Allow this for unordered atomics (see D66309)
16784   if (LD->isSimple()) {
16785     if (N->getValueType(1) == MVT::Other) {
16786       // Unindexed loads.
16787       if (!N->hasAnyUseOfValue(0)) {
16788         // It's not safe to use the two value CombineTo variant here. e.g.
16789         // v1, chain2 = load chain1, loc
16790         // v2, chain3 = load chain2, loc
16791         // v3         = add v2, c
16792         // Now we replace use of chain2 with chain1.  This makes the second load
16793         // isomorphic to the one we are deleting, and thus makes this load live.
16794         LLVM_DEBUG(dbgs() << "\nReplacing.6 "; N->dump(&DAG);
16795                    dbgs() << "\nWith chain: "; Chain.dump(&DAG);
16796                    dbgs() << "\n");
16797         WorklistRemover DeadNodes(*this);
16798         DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain);
16799         AddUsersToWorklist(Chain.getNode());
16800         if (N->use_empty())
16801           deleteAndRecombine(N);
16802 
16803         return SDValue(N, 0);   // Return N so it doesn't get rechecked!
16804       }
16805     } else {
16806       // Indexed loads.
16807       assert(N->getValueType(2) == MVT::Other && "Malformed indexed loads?");
16808 
16809       // If this load has an opaque TargetConstant offset, then we cannot split
16810       // the indexing into an add/sub directly (that TargetConstant may not be
16811       // valid for a different type of node, and we cannot convert an opaque
16812       // target constant into a regular constant).
16813       bool CanSplitIdx = canSplitIdx(LD);
16814 
          // The loaded value is dead: replace it with undef, keeping the
          // pointer update as a split-out add/sub when it is still used.
16815       if (!N->hasAnyUseOfValue(0) && (CanSplitIdx || !N->hasAnyUseOfValue(1))) {
16816         SDValue Undef = DAG.getUNDEF(N->getValueType(0));
16817         SDValue Index;
16818         if (N->hasAnyUseOfValue(1) && CanSplitIdx) {
16819           Index = SplitIndexingFromLoad(LD);
16820           // Try to fold the base pointer arithmetic into subsequent loads and
16821           // stores.
16822           AddUsersToWorklist(N);
16823         } else
16824           Index = DAG.getUNDEF(N->getValueType(1));
16825         LLVM_DEBUG(dbgs() << "\nReplacing.7 "; N->dump(&DAG);
16826                    dbgs() << "\nWith: "; Undef.dump(&DAG);
16827                    dbgs() << " and 2 other values\n");
16828         WorklistRemover DeadNodes(*this);
16829         DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Undef);
16830         DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Index);
16831         DAG.ReplaceAllUsesOfValueWith(SDValue(N, 2), Chain);
16832         deleteAndRecombine(N);
16833         return SDValue(N, 0);   // Return N so it doesn't get rechecked!
16834       }
16835     }
16836   }
16837 
16838   // If this load is directly stored, replace the load value with the stored
16839   // value.
16840   if (auto V = ForwardStoreValueToDirectLoad(LD))
16841     return V;
16842 
16843   // Try to infer better alignment information than the load already has.
16844   if (OptLevel != CodeGenOpt::None && LD->isUnindexed() && !LD->isAtomic()) {
16845     if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) {
16846       if (*Alignment > LD->getAlign() &&
16847           isAligned(*Alignment, LD->getSrcValueOffset())) {
16848         SDValue NewLoad = DAG.getExtLoad(
16849             LD->getExtensionType(), SDLoc(N), LD->getValueType(0), Chain, Ptr,
16850             LD->getPointerInfo(), LD->getMemoryVT(), *Alignment,
16851             LD->getMemOperand()->getFlags(), LD->getAAInfo());
16852         // NewLoad will always be N as we are only refining the alignment
16853         assert(NewLoad.getNode() == N);
16854         (void)NewLoad;
16855       }
16856     }
16857   }
16858 
16859   if (LD->isUnindexed()) {
16860     // Walk up chain skipping non-aliasing memory nodes.
16861     SDValue BetterChain = FindBetterChain(LD, Chain);
16862 
16863     // If there is a better chain.
16864     if (Chain != BetterChain) {
16865       SDValue ReplLoad;
16866 
16867       // Replace the chain to avoid the dependency.
16868       if (LD->getExtensionType() == ISD::NON_EXTLOAD) {
16869         ReplLoad = DAG.getLoad(N->getValueType(0), SDLoc(LD),
16870                                BetterChain, Ptr, LD->getMemOperand());
16871       } else {
16872         ReplLoad = DAG.getExtLoad(LD->getExtensionType(), SDLoc(LD),
16873                                   LD->getValueType(0),
16874                                   BetterChain, Ptr, LD->getMemoryVT(),
16875                                   LD->getMemOperand());
16876       }
16877 
16878       // Create token factor to keep old chain connected.
16879       SDValue Token = DAG.getNode(ISD::TokenFactor, SDLoc(N),
16880                                   MVT::Other, Chain, ReplLoad.getValue(1));
16881 
16882       // Replace uses with load result and token factor
16883       return CombineTo(N, ReplLoad.getValue(0), Token);
16884     }
16885   }
16886 
16887   // Try transforming N to an indexed load.
16888   if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
16889     return SDValue(N, 0);
16890 
16891   // Try to slice up N to more direct loads if the slices are mapped to
16892   // different register banks or pairing can take place.
16893   if (SliceUpLoad(N))
16894     return SDValue(N, 0);
16895 
16896   return SDValue();
16897 }
16898 
16899 namespace {
16900 
16901 /// Helper structure used to slice a load in smaller loads.
16902 /// Basically a slice is obtained from the following sequence:
16903 /// Origin = load Ty1, Base
16904 /// Shift = srl Ty1 Origin, CstTy Amount
16905 /// Inst = trunc Shift to Ty2
16906 ///
16907 /// Then, it will be rewritten into:
16908 /// Slice = load SliceTy, Base + SliceOffset
16909 /// [Inst = zext Slice to Ty2], only if SliceTy <> Ty2
16910 ///
16911 /// SliceTy is deduced from the number of bits that are actually used to
16912 /// build Inst.
16913 struct LoadedSlice {
16914   /// Helper structure used to compute the cost of a slice.
16915   struct Cost {
16916     /// Are we optimizing for code size.
16917     bool ForCodeSize = false;
16918 
16919     /// Various cost.
16920     unsigned Loads = 0;
16921     unsigned Truncates = 0;
16922     unsigned CrossRegisterBanksCopies = 0;
16923     unsigned ZExts = 0;
16924     unsigned Shift = 0;
16925 
Cost__anon54f00e403b11::LoadedSlice::Cost16926     explicit Cost(bool ForCodeSize) : ForCodeSize(ForCodeSize) {}
16927 
16928     /// Get the cost of one isolated slice.
Cost__anon54f00e403b11::LoadedSlice::Cost16929     Cost(const LoadedSlice &LS, bool ForCodeSize)
16930         : ForCodeSize(ForCodeSize), Loads(1) {
16931       EVT TruncType = LS.Inst->getValueType(0);
16932       EVT LoadedType = LS.getLoadedType();
16933       if (TruncType != LoadedType &&
16934           !LS.DAG->getTargetLoweringInfo().isZExtFree(LoadedType, TruncType))
16935         ZExts = 1;
16936     }
16937 
16938     /// Account for slicing gain in the current cost.
16939     /// Slicing provide a few gains like removing a shift or a
16940     /// truncate. This method allows to grow the cost of the original
16941     /// load with the gain from this slice.
addSliceGain__anon54f00e403b11::LoadedSlice::Cost16942     void addSliceGain(const LoadedSlice &LS) {
16943       // Each slice saves a truncate.
16944       const TargetLowering &TLI = LS.DAG->getTargetLoweringInfo();
16945       if (!TLI.isTruncateFree(LS.Inst->getOperand(0).getValueType(),
16946                               LS.Inst->getValueType(0)))
16947         ++Truncates;
16948       // If there is a shift amount, this slice gets rid of it.
16949       if (LS.Shift)
16950         ++Shift;
16951       // If this slice can merge a cross register bank copy, account for it.
16952       if (LS.canMergeExpensiveCrossRegisterBankCopy())
16953         ++CrossRegisterBanksCopies;
16954     }
16955 
operator +=__anon54f00e403b11::LoadedSlice::Cost16956     Cost &operator+=(const Cost &RHS) {
16957       Loads += RHS.Loads;
16958       Truncates += RHS.Truncates;
16959       CrossRegisterBanksCopies += RHS.CrossRegisterBanksCopies;
16960       ZExts += RHS.ZExts;
16961       Shift += RHS.Shift;
16962       return *this;
16963     }
16964 
operator ==__anon54f00e403b11::LoadedSlice::Cost16965     bool operator==(const Cost &RHS) const {
16966       return Loads == RHS.Loads && Truncates == RHS.Truncates &&
16967              CrossRegisterBanksCopies == RHS.CrossRegisterBanksCopies &&
16968              ZExts == RHS.ZExts && Shift == RHS.Shift;
16969     }
16970 
operator !=__anon54f00e403b11::LoadedSlice::Cost16971     bool operator!=(const Cost &RHS) const { return !(*this == RHS); }
16972 
operator <__anon54f00e403b11::LoadedSlice::Cost16973     bool operator<(const Cost &RHS) const {
16974       // Assume cross register banks copies are as expensive as loads.
16975       // FIXME: Do we want some more target hooks?
16976       unsigned ExpensiveOpsLHS = Loads + CrossRegisterBanksCopies;
16977       unsigned ExpensiveOpsRHS = RHS.Loads + RHS.CrossRegisterBanksCopies;
16978       // Unless we are optimizing for code size, consider the
16979       // expensive operation first.
16980       if (!ForCodeSize && ExpensiveOpsLHS != ExpensiveOpsRHS)
16981         return ExpensiveOpsLHS < ExpensiveOpsRHS;
16982       return (Truncates + ZExts + Shift + ExpensiveOpsLHS) <
16983              (RHS.Truncates + RHS.ZExts + RHS.Shift + ExpensiveOpsRHS);
16984     }
16985 
operator >__anon54f00e403b11::LoadedSlice::Cost16986     bool operator>(const Cost &RHS) const { return RHS < *this; }
16987 
operator <=__anon54f00e403b11::LoadedSlice::Cost16988     bool operator<=(const Cost &RHS) const { return !(RHS < *this); }
16989 
operator >=__anon54f00e403b11::LoadedSlice::Cost16990     bool operator>=(const Cost &RHS) const { return !(*this < RHS); }
16991   };
16992 
16993   // The last instruction that represent the slice. This should be a
16994   // truncate instruction.
16995   SDNode *Inst;
16996 
16997   // The original load instruction.
16998   LoadSDNode *Origin;
16999 
17000   // The right shift amount in bits from the original load.
17001   unsigned Shift;
17002 
17003   // The DAG from which Origin came from.
17004   // This is used to get some contextual information about legal types, etc.
17005   SelectionDAG *DAG;
17006 
LoadedSlice__anon54f00e403b11::LoadedSlice17007   LoadedSlice(SDNode *Inst = nullptr, LoadSDNode *Origin = nullptr,
17008               unsigned Shift = 0, SelectionDAG *DAG = nullptr)
17009       : Inst(Inst), Origin(Origin), Shift(Shift), DAG(DAG) {}
17010 
17011   /// Get the bits used in a chunk of bits \p BitWidth large.
17012   /// \return Result is \p BitWidth and has used bits set to 1 and
17013   ///         not used bits set to 0.
getUsedBits__anon54f00e403b11::LoadedSlice17014   APInt getUsedBits() const {
17015     // Reproduce the trunc(lshr) sequence:
17016     // - Start from the truncated value.
17017     // - Zero extend to the desired bit width.
17018     // - Shift left.
17019     assert(Origin && "No original load to compare against.");
17020     unsigned BitWidth = Origin->getValueSizeInBits(0);
17021     assert(Inst && "This slice is not bound to an instruction");
17022     assert(Inst->getValueSizeInBits(0) <= BitWidth &&
17023            "Extracted slice is bigger than the whole type!");
17024     APInt UsedBits(Inst->getValueSizeInBits(0), 0);
17025     UsedBits.setAllBits();
17026     UsedBits = UsedBits.zext(BitWidth);
17027     UsedBits <<= Shift;
17028     return UsedBits;
17029   }
17030 
17031   /// Get the size of the slice to be loaded in bytes.
getLoadedSize__anon54f00e403b11::LoadedSlice17032   unsigned getLoadedSize() const {
17033     unsigned SliceSize = getUsedBits().countPopulation();
17034     assert(!(SliceSize & 0x7) && "Size is not a multiple of a byte.");
17035     return SliceSize / 8;
17036   }
17037 
17038   /// Get the type that will be loaded for this slice.
17039   /// Note: This may not be the final type for the slice.
getLoadedType__anon54f00e403b11::LoadedSlice17040   EVT getLoadedType() const {
17041     assert(DAG && "Missing context");
17042     LLVMContext &Ctxt = *DAG->getContext();
17043     return EVT::getIntegerVT(Ctxt, getLoadedSize() * 8);
17044   }
17045 
17046   /// Get the alignment of the load used for this slice.
getAlign__anon54f00e403b11::LoadedSlice17047   Align getAlign() const {
17048     Align Alignment = Origin->getAlign();
17049     uint64_t Offset = getOffsetFromBase();
17050     if (Offset != 0)
17051       Alignment = commonAlignment(Alignment, Alignment.value() + Offset);
17052     return Alignment;
17053   }
17054 
17055   /// Check if this slice can be rewritten with legal operations.
isLegal__anon54f00e403b11::LoadedSlice17056   bool isLegal() const {
17057     // An invalid slice is not legal.
17058     if (!Origin || !Inst || !DAG)
17059       return false;
17060 
17061     // Offsets are for indexed load only, we do not handle that.
17062     if (!Origin->getOffset().isUndef())
17063       return false;
17064 
17065     const TargetLowering &TLI = DAG->getTargetLoweringInfo();
17066 
17067     // Check that the type is legal.
17068     EVT SliceType = getLoadedType();
17069     if (!TLI.isTypeLegal(SliceType))
17070       return false;
17071 
17072     // Check that the load is legal for this type.
17073     if (!TLI.isOperationLegal(ISD::LOAD, SliceType))
17074       return false;
17075 
17076     // Check that the offset can be computed.
17077     // 1. Check its type.
17078     EVT PtrType = Origin->getBasePtr().getValueType();
17079     if (PtrType == MVT::Untyped || PtrType.isExtended())
17080       return false;
17081 
17082     // 2. Check that it fits in the immediate.
17083     if (!TLI.isLegalAddImmediate(getOffsetFromBase()))
17084       return false;
17085 
17086     // 3. Check that the computation is legal.
17087     if (!TLI.isOperationLegal(ISD::ADD, PtrType))
17088       return false;
17089 
17090     // Check that the zext is legal if it needs one.
17091     EVT TruncateType = Inst->getValueType(0);
17092     if (TruncateType != SliceType &&
17093         !TLI.isOperationLegal(ISD::ZERO_EXTEND, TruncateType))
17094       return false;
17095 
17096     return true;
17097   }
17098 
17099   /// Get the offset in bytes of this slice in the original chunk of
17100   /// bits.
17101   /// \pre DAG != nullptr.
getOffsetFromBase__anon54f00e403b11::LoadedSlice17102   uint64_t getOffsetFromBase() const {
17103     assert(DAG && "Missing context.");
17104     bool IsBigEndian = DAG->getDataLayout().isBigEndian();
17105     assert(!(Shift & 0x7) && "Shifts not aligned on Bytes are not supported.");
17106     uint64_t Offset = Shift / 8;
17107     unsigned TySizeInBytes = Origin->getValueSizeInBits(0) / 8;
17108     assert(!(Origin->getValueSizeInBits(0) & 0x7) &&
17109            "The size of the original loaded type is not a multiple of a"
17110            " byte.");
17111     // If Offset is bigger than TySizeInBytes, it means we are loading all
17112     // zeros. This should have been optimized before in the process.
17113     assert(TySizeInBytes > Offset &&
17114            "Invalid shift amount for given loaded size");
17115     if (IsBigEndian)
17116       Offset = TySizeInBytes - Offset - getLoadedSize();
17117     return Offset;
17118   }
17119 
17120   /// Generate the sequence of instructions to load the slice
17121   /// represented by this object and redirect the uses of this slice to
17122   /// this new sequence of instructions.
17123   /// \pre this->Inst && this->Origin are valid Instructions and this
17124   /// object passed the legal check: LoadedSlice::isLegal returned true.
17125   /// \return The last instruction of the sequence used to load the slice.
loadSlice__anon54f00e403b11::LoadedSlice17126   SDValue loadSlice() const {
17127     assert(Inst && Origin && "Unable to replace a non-existing slice.");
17128     const SDValue &OldBaseAddr = Origin->getBasePtr();
17129     SDValue BaseAddr = OldBaseAddr;
17130     // Get the offset in that chunk of bytes w.r.t. the endianness.
17131     int64_t Offset = static_cast<int64_t>(getOffsetFromBase());
17132     assert(Offset >= 0 && "Offset too big to fit in int64_t!");
17133     if (Offset) {
17134       // BaseAddr = BaseAddr + Offset.
17135       EVT ArithType = BaseAddr.getValueType();
17136       SDLoc DL(Origin);
17137       BaseAddr = DAG->getNode(ISD::ADD, DL, ArithType, BaseAddr,
17138                               DAG->getConstant(Offset, DL, ArithType));
17139     }
17140 
17141     // Create the type of the loaded slice according to its size.
17142     EVT SliceType = getLoadedType();
17143 
17144     // Create the load for the slice.
17145     SDValue LastInst =
17146         DAG->getLoad(SliceType, SDLoc(Origin), Origin->getChain(), BaseAddr,
17147                      Origin->getPointerInfo().getWithOffset(Offset), getAlign(),
17148                      Origin->getMemOperand()->getFlags());
17149     // If the final type is not the same as the loaded type, this means that
17150     // we have to pad with zero. Create a zero extend for that.
17151     EVT FinalType = Inst->getValueType(0);
17152     if (SliceType != FinalType)
17153       LastInst =
17154           DAG->getNode(ISD::ZERO_EXTEND, SDLoc(LastInst), FinalType, LastInst);
17155     return LastInst;
17156   }
17157 
17158   /// Check if this slice can be merged with an expensive cross register
17159   /// bank copy. E.g.,
17160   /// i = load i32
17161   /// f = bitcast i32 i to float
canMergeExpensiveCrossRegisterBankCopy__anon54f00e403b11::LoadedSlice17162   bool canMergeExpensiveCrossRegisterBankCopy() const {
17163     if (!Inst || !Inst->hasOneUse())
17164       return false;
17165     SDNode *Use = *Inst->use_begin();
17166     if (Use->getOpcode() != ISD::BITCAST)
17167       return false;
17168     assert(DAG && "Missing context");
17169     const TargetLowering &TLI = DAG->getTargetLoweringInfo();
17170     EVT ResVT = Use->getValueType(0);
17171     const TargetRegisterClass *ResRC =
17172         TLI.getRegClassFor(ResVT.getSimpleVT(), Use->isDivergent());
17173     const TargetRegisterClass *ArgRC =
17174         TLI.getRegClassFor(Use->getOperand(0).getValueType().getSimpleVT(),
17175                            Use->getOperand(0)->isDivergent());
17176     if (ArgRC == ResRC || !TLI.isOperationLegal(ISD::LOAD, ResVT))
17177       return false;
17178 
17179     // At this point, we know that we perform a cross-register-bank copy.
17180     // Check if it is expensive.
17181     const TargetRegisterInfo *TRI = DAG->getSubtarget().getRegisterInfo();
17182     // Assume bitcasts are cheap, unless both register classes do not
17183     // explicitly share a common sub class.
17184     if (!TRI || TRI->getCommonSubClass(ArgRC, ResRC))
17185       return false;
17186 
17187     // Check if it will be merged with the load.
17188     // 1. Check the alignment / fast memory access constraint.
17189     bool IsFast = false;
17190     if (!TLI.allowsMemoryAccess(*DAG->getContext(), DAG->getDataLayout(), ResVT,
17191                                 Origin->getAddressSpace(), getAlign(),
17192                                 Origin->getMemOperand()->getFlags(), &IsFast) ||
17193         !IsFast)
17194       return false;
17195 
17196     // 2. Check that the load is a legal operation for that type.
17197     if (!TLI.isOperationLegal(ISD::LOAD, ResVT))
17198       return false;
17199 
17200     // 3. Check that we do not have a zext in the way.
17201     if (Inst->getValueType(0) != getLoadedType())
17202       return false;
17203 
17204     return true;
17205   }
17206 };
17207 
17208 } // end anonymous namespace
17209 
17210 /// Check that all bits set in \p UsedBits form a dense region, i.e.,
17211 /// \p UsedBits looks like 0..0 1..1 0..0.
areUsedBitsDense(const APInt & UsedBits)17212 static bool areUsedBitsDense(const APInt &UsedBits) {
17213   // If all the bits are one, this is dense!
17214   if (UsedBits.isAllOnes())
17215     return true;
17216 
17217   // Get rid of the unused bits on the right.
17218   APInt NarrowedUsedBits = UsedBits.lshr(UsedBits.countTrailingZeros());
17219   // Get rid of the unused bits on the left.
17220   if (NarrowedUsedBits.countLeadingZeros())
17221     NarrowedUsedBits = NarrowedUsedBits.trunc(NarrowedUsedBits.getActiveBits());
17222   // Check that the chunk of bits is completely used.
17223   return NarrowedUsedBits.isAllOnes();
17224 }
17225 
17226 /// Check whether or not \p First and \p Second are next to each other
17227 /// in memory. This means that there is no hole between the bits loaded
17228 /// by \p First and the bits loaded by \p Second.
areSlicesNextToEachOther(const LoadedSlice & First,const LoadedSlice & Second)17229 static bool areSlicesNextToEachOther(const LoadedSlice &First,
17230                                      const LoadedSlice &Second) {
17231   assert(First.Origin == Second.Origin && First.Origin &&
17232          "Unable to match different memory origins.");
17233   APInt UsedBits = First.getUsedBits();
17234   assert((UsedBits & Second.getUsedBits()) == 0 &&
17235          "Slices are not supposed to overlap.");
17236   UsedBits |= Second.getUsedBits();
17237   return areUsedBitsDense(UsedBits);
17238 }
17239 
/// Adjust the \p GlobalLSCost according to the target
/// paring capabilities and the layout of the slices.
/// Each time two adjacent, same-typed slices can be served by one paired
/// load, one load is subtracted from \p GlobalLSCost.
/// \pre \p GlobalLSCost should account for at least as many loads as
/// there is in the slices in \p LoadedSlices.
static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices,
                                 LoadedSlice::Cost &GlobalLSCost) {
  unsigned NumberOfSlices = LoadedSlices.size();
  // If there is less than 2 elements, no pairing is possible.
  if (NumberOfSlices < 2)
    return;

  // Sort the slices so that elements that are likely to be next to each
  // other in memory are next to each other in the list.
  llvm::sort(LoadedSlices, [](const LoadedSlice &LHS, const LoadedSlice &RHS) {
    assert(LHS.Origin == RHS.Origin && "Different bases not implemented.");
    return LHS.getOffsetFromBase() < RHS.getOffsetFromBase();
  });
  const TargetLowering &TLI = LoadedSlices[0].DAG->getTargetLoweringInfo();
  // Walk the sorted slices with a sliding window of two.
  // First (resp. Second) is the first (resp. Second) potentially candidate
  // to be placed in a paired load.
  const LoadedSlice *First = nullptr;
  const LoadedSlice *Second = nullptr;
  // Note: the loop increment also slides the window (First = Second), so a
  // plain `continue` keeps the current slice as the next pair's start,
  // whereas setting Second = nullptr before continuing restarts the pair.
  for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice,
                // Set the beginning of the pair.
                                                           First = Second) {
    Second = &LoadedSlices[CurrSlice];

    // If First is NULL, it means we start a new pair.
    // Get to the next slice.
    if (!First)
      continue;

    EVT LoadedType = First->getLoadedType();

    // If the types of the slices are different, we cannot pair them.
    if (LoadedType != Second->getLoadedType())
      continue;

    // Check if the target supplies paired loads for this type.
    Align RequiredAlignment;
    if (!TLI.hasPairedLoad(LoadedType, RequiredAlignment)) {
      // move to the next pair, this type is hopeless.
      Second = nullptr;
      continue;
    }
    // Check if we meet the alignment requirement.
    if (First->getAlign() < RequiredAlignment)
      continue;

    // Check that both loads are next to each other in memory.
    if (!areSlicesNextToEachOther(*First, *Second))
      continue;

    // A paired load replaces two loads with one: credit the saving.
    assert(GlobalLSCost.Loads > 0 && "We save more loads than we created!");
    --GlobalLSCost.Loads;
    // Move to the next pair.
    Second = nullptr;
  }
}
17299 
17300 /// Check the profitability of all involved LoadedSlice.
17301 /// Currently, it is considered profitable if there is exactly two
17302 /// involved slices (1) which are (2) next to each other in memory, and
17303 /// whose cost (\see LoadedSlice::Cost) is smaller than the original load (3).
17304 ///
17305 /// Note: The order of the elements in \p LoadedSlices may be modified, but not
17306 /// the elements themselves.
17307 ///
17308 /// FIXME: When the cost model will be mature enough, we can relax
17309 /// constraints (1) and (2).
isSlicingProfitable(SmallVectorImpl<LoadedSlice> & LoadedSlices,const APInt & UsedBits,bool ForCodeSize)17310 static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices,
17311                                 const APInt &UsedBits, bool ForCodeSize) {
17312   unsigned NumberOfSlices = LoadedSlices.size();
17313   if (StressLoadSlicing)
17314     return NumberOfSlices > 1;
17315 
17316   // Check (1).
17317   if (NumberOfSlices != 2)
17318     return false;
17319 
17320   // Check (2).
17321   if (!areUsedBitsDense(UsedBits))
17322     return false;
17323 
17324   // Check (3).
17325   LoadedSlice::Cost OrigCost(ForCodeSize), GlobalSlicingCost(ForCodeSize);
17326   // The original code has one big load.
17327   OrigCost.Loads = 1;
17328   for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice) {
17329     const LoadedSlice &LS = LoadedSlices[CurrSlice];
17330     // Accumulate the cost of all the slices.
17331     LoadedSlice::Cost SliceCost(LS, ForCodeSize);
17332     GlobalSlicingCost += SliceCost;
17333 
17334     // Account as cost in the original configuration the gain obtained
17335     // with the current slices.
17336     OrigCost.addSliceGain(LS);
17337   }
17338 
17339   // If the target supports paired load, adjust the cost accordingly.
17340   adjustCostForPairing(LoadedSlices, GlobalSlicingCost);
17341   return OrigCost > GlobalSlicingCost;
17342 }
17343 
/// If the given load, \p LI, is used only by trunc or trunc(lshr)
/// operations, split it in the various pieces being extracted.
///
/// This sort of thing is introduced by SROA.
/// This slicing takes care not to insert overlapping loads.
/// \pre LI is a simple load (i.e., not an atomic or volatile load).
bool DAGCombiner::SliceUpLoad(SDNode *N) {
  // Only attempt slicing after the DAG has been legalized: the narrowed
  // loads created below are checked for legality (LoadedSlice::isLegal),
  // not re-legalized.
  if (Level < AfterLegalizeDAG)
    return false;

  // Only simple, normal (non-extending, non-indexed) integer loads are
  // candidates.
  LoadSDNode *LD = cast<LoadSDNode>(N);
  if (!LD->isSimple() || !ISD::isNormalLoad(LD) ||
      !LD->getValueType(0).isInteger())
    return false;

  // The algorithm to split up a load of a scalable vector into individual
  // elements currently requires knowing the length of the loaded type,
  // so will need adjusting to work on scalable vectors.
  if (LD->getValueType(0).isScalableVector())
    return false;

  // Keep track of already used bits to detect overlapping values.
  // In that case, we will just abort the transformation.
  APInt UsedBits(LD->getValueSizeInBits(0), 0);

  SmallVector<LoadedSlice, 4> LoadedSlices;

  // Check if this load is used as several smaller chunks of bits.
  // Basically, look for uses in trunc or trunc(lshr) and record a new chain
  // of computation for each trunc.
  for (SDNode::use_iterator UI = LD->use_begin(), UIEnd = LD->use_end();
       UI != UIEnd; ++UI) {
    // Skip the uses of the chain.
    if (UI.getUse().getResNo() != 0)
      continue;

    SDNode *User = *UI;
    unsigned Shift = 0;

    // Check if this is a trunc(lshr): if so, record the shift amount and
    // step down to the (single) user of the shift.
    if (User->getOpcode() == ISD::SRL && User->hasOneUse() &&
        isa<ConstantSDNode>(User->getOperand(1))) {
      Shift = User->getConstantOperandVal(1);
      User = *User->use_begin();
    }

    // At this point, User is a Truncate, iff we encountered, trunc or
    // trunc(lshr). Any other use pattern defeats the transformation.
    if (User->getOpcode() != ISD::TRUNCATE)
      return false;

    // The width of the type must be a power of 2 and greater than 8-bits.
    // Otherwise the load cannot be represented in LLVM IR.
    // Moreover, if we shifted with a non-8-bits multiple, the slice
    // will be across several bytes. We do not support that.
    unsigned Width = User->getValueSizeInBits(0);
    if (Width < 8 || !isPowerOf2_32(Width) || (Shift & 0x7))
      return false;

    // Build the slice for this chain of computations.
    LoadedSlice LS(User, LD, Shift, &DAG);
    APInt CurrentUsedBits = LS.getUsedBits();

    // Check if this slice overlaps with another.
    if ((CurrentUsedBits & UsedBits) != 0)
      return false;
    // Update the bits used globally.
    UsedBits |= CurrentUsedBits;

    // Check if the new slice would be legal.
    if (!LS.isLegal())
      return false;

    // Record the slice.
    LoadedSlices.push_back(LS);
  }

  // Abort slicing if it does not seem to be profitable.
  if (!isSlicingProfitable(LoadedSlices, UsedBits, ForCodeSize))
    return false;

  ++SlicedLoads;

  // Rewrite each chain to use an independent load.
  // By construction, each chain can be represented by a unique load.

  // Prepare the argument for the new token factor for all the slices.
  SmallVector<SDValue, 8> ArgChains;
  for (const LoadedSlice &LS : LoadedSlices) {
    SDValue SliceInst = LS.loadSlice();
    CombineTo(LS.Inst, SliceInst, true);
    // loadSlice() may have wrapped the load in a zext; step through it to
    // reach the load node itself, whose chain result is needed below.
    if (SliceInst.getOpcode() != ISD::LOAD)
      SliceInst = SliceInst.getOperand(0);
    assert(SliceInst->getOpcode() == ISD::LOAD &&
           "It takes more than a zext to get to the loaded slice!!");
    ArgChains.push_back(SliceInst.getValue(1));
  }

  // Tie the chains of all the new loads together with a TokenFactor and
  // use it to replace the chain result of the original load.
  SDValue Chain = DAG.getNode(ISD::TokenFactor, SDLoc(LD), MVT::Other,
                              ArgChains);
  DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain);
  AddToWorklist(Chain.getNode());
  return true;
}
17448 
/// Check to see if V is (and load (ptr), imm), where the load is having
/// specific bytes cleared out.  If so, return the byte size being masked out
/// and the shift amount.
/// \return (MaskedBytes, ByteShift) on success, (0, 0) on failure.
static std::pair<unsigned, unsigned>
CheckForMaskedLoad(SDValue V, SDValue Ptr, SDValue Chain) {
  // (0, 0) signals "no match" to the caller.
  std::pair<unsigned, unsigned> Result(0, 0);

  // Check for the structure we're looking for: (and (load Ptr), constant).
  if (V->getOpcode() != ISD::AND ||
      !isa<ConstantSDNode>(V->getOperand(1)) ||
      !ISD::isNormalLoad(V->getOperand(0).getNode()))
    return Result;

  // Check the chain and pointer.
  LoadSDNode *LD = cast<LoadSDNode>(V->getOperand(0));
  if (LD->getBasePtr() != Ptr) return Result;  // Not from same pointer.

  // This only handles simple types.
  if (V.getValueType() != MVT::i16 &&
      V.getValueType() != MVT::i32 &&
      V.getValueType() != MVT::i64)
    return Result;

  // Check the constant mask.  Invert it so that the bits being masked out are
  // 0 and the bits being kept are 1.  Use getSExtValue so that leading bits
  // follow the sign bit for uniformity.
  uint64_t NotMask = ~cast<ConstantSDNode>(V->getOperand(1))->getSExtValue();
  unsigned NotMaskLZ = countLeadingZeros(NotMask);
  if (NotMaskLZ & 7) return Result;  // Must be multiple of a byte.
  unsigned NotMaskTZ = countTrailingZeros(NotMask);
  if (NotMaskTZ & 7) return Result;  // Must be multiple of a byte.
  if (NotMaskLZ == 64) return Result;  // All zero mask.

  // See if we have a continuous run of bits.  If so, we have 0*1+0*
  if (countTrailingOnes(NotMask >> NotMaskTZ) + NotMaskTZ + NotMaskLZ != 64)
    return Result;

  // Adjust NotMaskLZ down to be from the actual size of the int instead of i64.
  if (V.getValueType() != MVT::i64 && NotMaskLZ)
    NotMaskLZ -= 64-V.getValueSizeInBits();

  // Only 1-, 2-, and 4-byte cleared regions can be re-expressed as a
  // narrow store of a simple type.
  unsigned MaskedBytes = (V.getValueSizeInBits()-NotMaskLZ-NotMaskTZ)/8;
  switch (MaskedBytes) {
  case 1:
  case 2:
  case 4: break;
  default: return Result; // All one mask, or 5-byte mask.
  }

  // Verify that the first bit starts at a multiple of mask so that the access
  // is aligned the same as the access width.
  if (NotMaskTZ && NotMaskTZ/8 % MaskedBytes) return Result;

  // For narrowing to be valid, it must be the case that the load the
  // immediately preceding memory operation before the store.
  if (LD == Chain.getNode())
    ; // ok.
  else if (Chain->getOpcode() == ISD::TokenFactor &&
           SDValue(LD, 1).hasOneUse()) {
    // LD has only 1 chain use so there are no indirect dependencies.
    if (!LD->isOperandOf(Chain.getNode()))
      return Result;
  } else
    return Result; // Fail.

  // Match found: report the cleared-byte count and the byte offset of the
  // cleared region within the value.
  Result.first = MaskedBytes;
  Result.second = NotMaskTZ/8;
  return Result;
}
17518 
/// Check to see if IVal is something that provides a value as specified by
/// MaskInfo. If so, replace the specified store with a narrower store of
/// truncated IVal.
/// \param MaskInfo (NumBytes, ByteShift) as produced by CheckForMaskedLoad.
/// \return The new narrow store, or an empty SDValue if the rewrite is not
/// possible or not legal.
static SDValue
ShrinkLoadReplaceStoreWithStore(const std::pair<unsigned, unsigned> &MaskInfo,
                                SDValue IVal, StoreSDNode *St,
                                DAGCombiner *DC) {
  unsigned NumBytes = MaskInfo.first;
  unsigned ByteShift = MaskInfo.second;
  SelectionDAG &DAG = DC->getDAG();

  // Check to see if IVal is all zeros in the part being masked in by the 'or'
  // that uses this.  If not, this is not a replacement.
  APInt Mask = ~APInt::getBitsSet(IVal.getValueSizeInBits(),
                                  ByteShift*8, (ByteShift+NumBytes)*8);
  if (!DAG.MaskedValueIsZero(IVal, Mask)) return SDValue();

  // Check that it is legal on the target to do this.  It is legal if the new
  // VT we're shrinking to (i8/i16/i32) is legal or we're still before type
  // legalization. If the source type is legal, but the store type isn't, see
  // if we can use a truncating store.
  MVT VT = MVT::getIntegerVT(NumBytes * 8);
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  bool UseTruncStore;
  if (DC->isTypeLegal(VT))
    UseTruncStore = false;
  else if (TLI.isTypeLegal(IVal.getValueType()) &&
           TLI.isTruncStoreLegal(IVal.getValueType(), VT))
    UseTruncStore = true;
  else
    return SDValue();
  // Check that the target doesn't think this is a bad idea.
  if (St->getMemOperand() &&
      !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
                              *St->getMemOperand()))
    return SDValue();

  // Okay, we can do this!  Replace the 'St' store with a store of IVal that is
  // shifted by ByteShift and truncated down to NumBytes.
  if (ByteShift) {
    SDLoc DL(IVal);
    IVal = DAG.getNode(ISD::SRL, DL, IVal.getValueType(), IVal,
                       DAG.getConstant(ByteShift*8, DL,
                                    DC->getShiftAmountTy(IVal.getValueType())));
  }

  // Figure out the offset for the store and the alignment of the access.
  // On big-endian targets the byte offset counts from the opposite end.
  unsigned StOffset;
  if (DAG.getDataLayout().isLittleEndian())
    StOffset = ByteShift;
  else
    StOffset = IVal.getValueType().getStoreSize() - ByteShift - NumBytes;

  SDValue Ptr = St->getBasePtr();
  if (StOffset) {
    SDLoc DL(IVal);
    Ptr = DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(StOffset), DL);
  }

  ++OpsNarrowed;
  if (UseTruncStore)
    return DAG.getTruncStore(St->getChain(), SDLoc(St), IVal, Ptr,
                             St->getPointerInfo().getWithOffset(StOffset),
                             VT, St->getOriginalAlign());

  // Truncate down to the new size.
  IVal = DAG.getNode(ISD::TRUNCATE, SDLoc(IVal), VT, IVal);

  return DAG
      .getStore(St->getChain(), SDLoc(St), IVal, Ptr,
                St->getPointerInfo().getWithOffset(StOffset),
                St->getOriginalAlign());
}
17592 
/// Look for sequence of load / op / store where op is one of 'or', 'xor', and
/// 'and' of immediates. If 'op' is only touching some of the loaded bits, try
/// narrowing the load and store if it would end up being a win for performance
/// or code size.
SDValue DAGCombiner::ReduceLoadOpStoreWidth(SDNode *N) {
  StoreSDNode *ST  = cast<StoreSDNode>(N);
  if (!ST->isSimple())
    return SDValue();

  SDValue Chain = ST->getChain();
  SDValue Value = ST->getValue();
  SDValue Ptr   = ST->getBasePtr();
  EVT VT = Value.getValueType();

  if (ST->isTruncatingStore() || VT.isVector())
    return SDValue();

  unsigned Opc = Value.getOpcode();

  // Only a single-use or/xor/and of the stored value is handled.
  if ((Opc != ISD::OR && Opc != ISD::XOR && Opc != ISD::AND) ||
      !Value.hasOneUse())
    return SDValue();

  // If this is "store (or X, Y), P" and X is "(and (load P), cst)", where cst
  // is a byte mask indicating a consecutive number of bytes, check to see if
  // Y is known to provide just those bytes.  If so, we try to replace the
  // load + replace + store sequence with a single (narrower) store, which makes
  // the load dead.
  if (Opc == ISD::OR && EnableShrinkLoadReplaceStoreWithStore) {
    std::pair<unsigned, unsigned> MaskedLoad;
    MaskedLoad = CheckForMaskedLoad(Value.getOperand(0), Ptr, Chain);
    if (MaskedLoad.first)
      if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskedLoad,
                                                  Value.getOperand(1), ST,this))
        return NewST;

    // Or is commutative, so try swapping X and Y.
    MaskedLoad = CheckForMaskedLoad(Value.getOperand(1), Ptr, Chain);
    if (MaskedLoad.first)
      if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskedLoad,
                                                  Value.getOperand(0), ST,this))
        return NewST;
  }

  if (!EnableReduceLoadOpStoreWidth)
    return SDValue();

  // The second operand of the op must be an immediate for the narrowing
  // analysis below.
  if (Value.getOperand(1).getOpcode() != ISD::Constant)
    return SDValue();

  SDValue N0 = Value.getOperand(0);
  // The op's other operand must be the load being stored back, with the
  // store chained directly on the load (no intervening memory operations).
  if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
      Chain == SDValue(N0.getNode(), 1)) {
    LoadSDNode *LD = cast<LoadSDNode>(N0);
    if (LD->getBasePtr() != Ptr ||
        LD->getPointerInfo().getAddrSpace() !=
        ST->getPointerInfo().getAddrSpace())
      return SDValue();

    // Find the type to narrow it the load / op / store to.
    SDValue N1 = Value.getOperand(1);
    unsigned BitWidth = N1.getValueSizeInBits();
    APInt Imm = cast<ConstantSDNode>(N1)->getAPIntValue();
    // For AND, invert the mask so that the bits actually being changed
    // are the set bits, as for OR/XOR.
    if (Opc == ISD::AND)
      Imm ^= APInt::getAllOnes(BitWidth);
    // A no-op mask (nothing changed, or everything changed) is not worth
    // narrowing.
    if (Imm == 0 || Imm.isAllOnes())
      return SDValue();
    // [ShAmt, MSB] delimits the run of bits the op actually modifies;
    // NewBW is the smallest power-of-2 width covering that run.
    unsigned ShAmt = Imm.countTrailingZeros();
    unsigned MSB = BitWidth - Imm.countLeadingZeros() - 1;
    unsigned NewBW = NextPowerOf2(MSB - ShAmt);
    EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), NewBW);
    // The narrowing should be profitable, the load/store operation should be
    // legal (or custom) and the store size should be equal to the NewVT width.
    while (NewBW < BitWidth &&
           (NewVT.getStoreSizeInBits() != NewBW ||
            !TLI.isOperationLegalOrCustom(Opc, NewVT) ||
            !TLI.isNarrowingProfitable(VT, NewVT))) {
      NewBW = NextPowerOf2(NewBW);
      NewVT = EVT::getIntegerVT(*DAG.getContext(), NewBW);
    }
    if (NewBW >= BitWidth)
      return SDValue();

    // If the lsb changed does not start at the type bitwidth boundary,
    // start at the previous one.
    if (ShAmt % NewBW)
      ShAmt = (((ShAmt + NewBW - 1) / NewBW) * NewBW) - NewBW;
    // All modified bits must fall inside the NewBW-wide window at ShAmt.
    APInt Mask = APInt::getBitsSet(BitWidth, ShAmt,
                                   std::min(BitWidth, ShAmt + NewBW));
    if ((Imm & Mask) == Imm) {
      // Narrow the immediate to the window (undoing the AND inversion).
      APInt NewImm = (Imm & Mask).lshr(ShAmt).trunc(NewBW);
      if (Opc == ISD::AND)
        NewImm ^= APInt::getAllOnes(NewBW);
      uint64_t PtrOff = ShAmt / 8;
      // For big endian targets, we need to adjust the offset to the pointer to
      // load the correct bytes.
      if (DAG.getDataLayout().isBigEndian())
        PtrOff = (BitWidth + 7 - NewBW) / 8 - PtrOff;

      // The narrowed access must be allowed (and fast) at its new alignment.
      bool IsFast = false;
      Align NewAlign = commonAlignment(LD->getAlign(), PtrOff);
      if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), NewVT,
                                  LD->getAddressSpace(), NewAlign,
                                  LD->getMemOperand()->getFlags(), &IsFast) ||
          !IsFast)
        return SDValue();

      // Build the narrow load / op / store sequence.
      SDValue NewPtr =
          DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(PtrOff), SDLoc(LD));
      SDValue NewLD =
          DAG.getLoad(NewVT, SDLoc(N0), LD->getChain(), NewPtr,
                      LD->getPointerInfo().getWithOffset(PtrOff), NewAlign,
                      LD->getMemOperand()->getFlags(), LD->getAAInfo());
      SDValue NewVal = DAG.getNode(Opc, SDLoc(Value), NewVT, NewLD,
                                   DAG.getConstant(NewImm, SDLoc(Value),
                                                   NewVT));
      SDValue NewST =
          DAG.getStore(Chain, SDLoc(N), NewVal, NewPtr,
                       ST->getPointerInfo().getWithOffset(PtrOff), NewAlign);

      AddToWorklist(NewPtr.getNode());
      AddToWorklist(NewLD.getNode());
      AddToWorklist(NewVal.getNode());
      WorklistRemover DeadNodes(*this);
      // Redirect the old load's chain users to the new narrow load.
      DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), NewLD.getValue(1));
      ++OpsNarrowed;
      return NewST;
    }
  }

  return SDValue();
}
17725 
17726 /// For a given floating point load / store pair, if the load value isn't used
17727 /// by any other operations, then consider transforming the pair to integer
17728 /// load / store operations if the target deems the transformation profitable.
TransformFPLoadStorePair(SDNode * N)17729 SDValue DAGCombiner::TransformFPLoadStorePair(SDNode *N) {
17730   StoreSDNode *ST  = cast<StoreSDNode>(N);
17731   SDValue Value = ST->getValue();
17732   if (ISD::isNormalStore(ST) && ISD::isNormalLoad(Value.getNode()) &&
17733       Value.hasOneUse()) {
17734     LoadSDNode *LD = cast<LoadSDNode>(Value);
17735     EVT VT = LD->getMemoryVT();
17736     if (!VT.isFloatingPoint() ||
17737         VT != ST->getMemoryVT() ||
17738         LD->isNonTemporal() ||
17739         ST->isNonTemporal() ||
17740         LD->getPointerInfo().getAddrSpace() != 0 ||
17741         ST->getPointerInfo().getAddrSpace() != 0)
17742       return SDValue();
17743 
17744     TypeSize VTSize = VT.getSizeInBits();
17745 
17746     // We don't know the size of scalable types at compile time so we cannot
17747     // create an integer of the equivalent size.
17748     if (VTSize.isScalable())
17749       return SDValue();
17750 
17751     bool FastLD = false, FastST = false;
17752     EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), VTSize.getFixedSize());
17753     if (!TLI.isOperationLegal(ISD::LOAD, IntVT) ||
17754         !TLI.isOperationLegal(ISD::STORE, IntVT) ||
17755         !TLI.isDesirableToTransformToIntegerOp(ISD::LOAD, VT) ||
17756         !TLI.isDesirableToTransformToIntegerOp(ISD::STORE, VT) ||
17757         !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), IntVT,
17758                                 *LD->getMemOperand(), &FastLD) ||
17759         !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), IntVT,
17760                                 *ST->getMemOperand(), &FastST) ||
17761         !FastLD || !FastST)
17762       return SDValue();
17763 
17764     SDValue NewLD =
17765         DAG.getLoad(IntVT, SDLoc(Value), LD->getChain(), LD->getBasePtr(),
17766                     LD->getPointerInfo(), LD->getAlign());
17767 
17768     SDValue NewST =
17769         DAG.getStore(ST->getChain(), SDLoc(N), NewLD, ST->getBasePtr(),
17770                      ST->getPointerInfo(), ST->getAlign());
17771 
17772     AddToWorklist(NewLD.getNode());
17773     AddToWorklist(NewST.getNode());
17774     WorklistRemover DeadNodes(*this);
17775     DAG.ReplaceAllUsesOfValueWith(Value.getValue(1), NewLD.getValue(1));
17776     ++LdStFP2Int;
17777     return NewST;
17778   }
17779 
17780   return SDValue();
17781 }
17782 
17783 // This is a helper function for visitMUL to check the profitability
17784 // of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
17785 // MulNode is the original multiply, AddNode is (add x, c1),
17786 // and ConstNode is c2.
17787 //
17788 // If the (add x, c1) has multiple uses, we could increase
17789 // the number of adds if we make this transformation.
17790 // It would only be worth doing this if we can remove a
17791 // multiply in the process. Check for that here.
17792 // To illustrate:
17793 //     (A + c1) * c3
17794 //     (A + c2) * c3
17795 // We're checking for cases where we have common "c3 * A" expressions.
isMulAddWithConstProfitable(SDNode * MulNode,SDValue AddNode,SDValue ConstNode)17796 bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode, SDValue AddNode,
17797                                               SDValue ConstNode) {
17798   APInt Val;
17799 
17800   // If the add only has one use, and the target thinks the folding is
17801   // profitable or does not lead to worse code, this would be OK to do.
17802   if (AddNode->hasOneUse() &&
17803       TLI.isMulAddWithConstProfitable(AddNode, ConstNode))
17804     return true;
17805 
17806   // Walk all the users of the constant with which we're multiplying.
17807   for (SDNode *Use : ConstNode->uses()) {
17808     if (Use == MulNode) // This use is the one we're on right now. Skip it.
17809       continue;
17810 
17811     if (Use->getOpcode() == ISD::MUL) { // We have another multiply use.
17812       SDNode *OtherOp;
17813       SDNode *MulVar = AddNode.getOperand(0).getNode();
17814 
17815       // OtherOp is what we're multiplying against the constant.
17816       if (Use->getOperand(0) == ConstNode)
17817         OtherOp = Use->getOperand(1).getNode();
17818       else
17819         OtherOp = Use->getOperand(0).getNode();
17820 
17821       // Check to see if multiply is with the same operand of our "add".
17822       //
17823       //     ConstNode  = CONST
17824       //     Use = ConstNode * A  <-- visiting Use. OtherOp is A.
17825       //     ...
17826       //     AddNode  = (A + c1)  <-- MulVar is A.
17827       //         = AddNode * ConstNode   <-- current visiting instruction.
17828       //
17829       // If we make this transformation, we will have a common
17830       // multiply (ConstNode * A) that we can save.
17831       if (OtherOp == MulVar)
17832         return true;
17833 
17834       // Now check to see if a future expansion will give us a common
17835       // multiply.
17836       //
17837       //     ConstNode  = CONST
17838       //     AddNode    = (A + c1)
17839       //     ...   = AddNode * ConstNode <-- current visiting instruction.
17840       //     ...
17841       //     OtherOp = (A + c2)
17842       //     Use     = OtherOp * ConstNode <-- visiting Use.
17843       //
17844       // If we make this transformation, we will have a common
17845       // multiply (CONST * A) after we also do the same transformation
17846       // to the "t2" instruction.
17847       if (OtherOp->getOpcode() == ISD::ADD &&
17848           DAG.isConstantIntBuildVectorOrConstantInt(OtherOp->getOperand(1)) &&
17849           OtherOp->getOperand(0).getNode() == MulVar)
17850         return true;
17851     }
17852   }
17853 
17854   // Didn't find a case where this would be profitable.
17855   return false;
17856 }
17857 
getMergeStoreChains(SmallVectorImpl<MemOpLink> & StoreNodes,unsigned NumStores)17858 SDValue DAGCombiner::getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
17859                                          unsigned NumStores) {
17860   SmallVector<SDValue, 8> Chains;
17861   SmallPtrSet<const SDNode *, 8> Visited;
17862   SDLoc StoreDL(StoreNodes[0].MemNode);
17863 
17864   for (unsigned i = 0; i < NumStores; ++i) {
17865     Visited.insert(StoreNodes[i].MemNode);
17866   }
17867 
17868   // don't include nodes that are children or repeated nodes.
17869   for (unsigned i = 0; i < NumStores; ++i) {
17870     if (Visited.insert(StoreNodes[i].MemNode->getChain().getNode()).second)
17871       Chains.push_back(StoreNodes[i].MemNode->getChain());
17872   }
17873 
17874   assert(Chains.size() > 0 && "Chain should have generated a chain");
17875   return DAG.getTokenFactor(StoreDL, Chains);
17876 }
17877 
/// Replace the first \p NumStores stores in \p StoreNodes with a single wider
/// store of their combined value.
///
/// \param StoreNodes  Consecutive store candidates; element 0 is the
///                    lowest-address store and supplies the merged store's
///                    base pointer and alignment.
/// \param MemVT       Memory type of each individual store.
/// \param NumStores   How many leading entries of StoreNodes to merge.
/// \param IsConstantSrc  True when the stored values are int/FP constants
///                       (vs. extracted vector elements).
/// \param UseVector   Emit the merged value as a vector (CONCAT_VECTORS /
///                    BUILD_VECTOR) instead of one wide integer.
/// \param UseTrunc    Emit a truncating store of a type-promoted integer
///                    constant; mutually exclusive with UseVector.
/// \returns true if the stores were merged and replaced in the DAG.
bool DAGCombiner::mergeStoresOfConstantsOrVecElts(
    SmallVectorImpl<MemOpLink> &StoreNodes, EVT MemVT, unsigned NumStores,
    bool IsConstantSrc, bool UseVector, bool UseTrunc) {
  // Make sure we have something to merge.
  if (NumStores < 2)
    return false;

  assert((!UseTrunc || !UseVector) &&
         "This optimization cannot emit a vector truncating store");

  // The latest Node in the DAG.
  SDLoc DL(StoreNodes[0].MemNode);

  TypeSize ElementSizeBits = MemVT.getStoreSizeInBits();
  unsigned SizeInBits = NumStores * ElementSizeBits;
  unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;

  // The merged store reuses the first store's MMO flags, so all stores must
  // agree on them; AA metadata from all stores is concatenated instead.
  Optional<MachineMemOperand::Flags> Flags;
  AAMDNodes AAInfo;
  for (unsigned I = 0; I != NumStores; ++I) {
    StoreSDNode *St = cast<StoreSDNode>(StoreNodes[I].MemNode);
    if (!Flags) {
      Flags = St->getMemOperand()->getFlags();
      AAInfo = St->getAAInfo();
      continue;
    }
    // Skip merging if there's an inconsistent flag.
    if (Flags != St->getMemOperand()->getFlags())
      return false;
    // Concatenate AA metadata.
    AAInfo = AAInfo.concat(St->getAAInfo());
  }

  EVT StoreTy;
  if (UseVector) {
    unsigned Elts = NumStores * NumMemElts;
    // Get the type for the merged vector store.
    StoreTy = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
  } else
    StoreTy = EVT::getIntegerVT(*DAG.getContext(), SizeInBits);

  SDValue StoredVal;
  if (UseVector) {
    if (IsConstantSrc) {
      // Gather the per-store constants as operands for a BUILD_VECTOR (or
      // CONCAT_VECTORS when each element is itself a vector).
      SmallVector<SDValue, 8> BuildVector;
      for (unsigned I = 0; I != NumStores; ++I) {
        StoreSDNode *St = cast<StoreSDNode>(StoreNodes[I].MemNode);
        SDValue Val = St->getValue();
        // If constant is of the wrong type, convert it now.
        if (MemVT != Val.getValueType()) {
          Val = peekThroughBitcasts(Val);
          // Deal with constants of wrong size.
          if (ElementSizeBits != Val.getValueSizeInBits()) {
            EVT IntMemVT =
                EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits());
            if (isa<ConstantFPSDNode>(Val)) {
              // Not clear how to truncate FP values.
              return false;
            }

            if (auto *C = dyn_cast<ConstantSDNode>(Val))
              Val = DAG.getConstant(C->getAPIntValue()
                                        .zextOrTrunc(Val.getValueSizeInBits())
                                        .zextOrTrunc(ElementSizeBits),
                                    SDLoc(C), IntMemVT);
          }
          // Make sure correctly size type is the correct type.
          Val = DAG.getBitcast(MemVT, Val);
        }
        BuildVector.push_back(Val);
      }
      StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS
                                               : ISD::BUILD_VECTOR,
                              DL, StoreTy, BuildVector);
    } else {
      SmallVector<SDValue, 8> Ops;
      for (unsigned i = 0; i < NumStores; ++i) {
        StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
        SDValue Val = peekThroughBitcasts(St->getValue());
        // All operands of BUILD_VECTOR / CONCAT_VECTOR must be of
        // type MemVT. If the underlying value is not the correct
        // type, but it is an extraction of an appropriate vector we
        // can recast Val to be of the correct type. This may require
        // converting between EXTRACT_VECTOR_ELT and
        // EXTRACT_SUBVECTOR.
        if ((MemVT != Val.getValueType()) &&
            (Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
             Val.getOpcode() == ISD::EXTRACT_SUBVECTOR)) {
          EVT MemVTScalarTy = MemVT.getScalarType();
          // We may need to add a bitcast here to get types to line up.
          if (MemVTScalarTy != Val.getValueType().getScalarType()) {
            Val = DAG.getBitcast(MemVT, Val);
          } else {
            unsigned OpC = MemVT.isVector() ? ISD::EXTRACT_SUBVECTOR
                                            : ISD::EXTRACT_VECTOR_ELT;
            SDValue Vec = Val.getOperand(0);
            SDValue Idx = Val.getOperand(1);
            Val = DAG.getNode(OpC, SDLoc(Val), MemVT, Vec, Idx);
          }
        }
        Ops.push_back(Val);
      }

      // Build the extracted vector elements back into a vector.
      StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS
                                               : ISD::BUILD_VECTOR,
                              DL, StoreTy, Ops);
    }
  } else {
    // We should always use a vector store when merging extracted vector
    // elements, so this path implies a store of constants.
    assert(IsConstantSrc && "Merged vector elements should use vector store");

    APInt StoreInt(SizeInBits, 0);

    // Construct a single integer constant which is made of the smaller
    // constant inputs. On little-endian targets the lowest-address element
    // must land in the low bits, so the stores are visited in reverse order
    // while each iteration shifts the accumulator left.
    bool IsLE = DAG.getDataLayout().isLittleEndian();
    for (unsigned i = 0; i < NumStores; ++i) {
      unsigned Idx = IsLE ? (NumStores - 1 - i) : i;
      StoreSDNode *St  = cast<StoreSDNode>(StoreNodes[Idx].MemNode);

      SDValue Val = St->getValue();
      Val = peekThroughBitcasts(Val);
      StoreInt <<= ElementSizeBits;
      if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val)) {
        StoreInt |= C->getAPIntValue()
                        .zextOrTrunc(ElementSizeBits)
                        .zextOrTrunc(SizeInBits);
      } else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Val)) {
        StoreInt |= C->getValueAPF()
                        .bitcastToAPInt()
                        .zextOrTrunc(ElementSizeBits)
                        .zextOrTrunc(SizeInBits);
        // If fp truncation is necessary give up for now.
        if (MemVT.getSizeInBits() != ElementSizeBits)
          return false;
      } else {
        llvm_unreachable("Invalid constant element type");
      }
    }

    // Create the new Load and Store operations.
    StoredVal = DAG.getConstant(StoreInt, DL, StoreTy);
  }

  LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
  SDValue NewChain = getMergeStoreChains(StoreNodes, NumStores);

  // make sure we use trunc store if it's necessary to be legal.
  SDValue NewStore;
  if (!UseTrunc) {
    NewStore = DAG.getStore(NewChain, DL, StoredVal, FirstInChain->getBasePtr(),
                            FirstInChain->getPointerInfo(),
                            FirstInChain->getAlign(), *Flags, AAInfo);
  } else { // Must be realized as a trunc store
    // Widen the constant to the promoted type, then store it back with the
    // original (narrow) memory type as the truncating store's memory VT.
    EVT LegalizedStoredValTy =
        TLI.getTypeToTransformTo(*DAG.getContext(), StoredVal.getValueType());
    unsigned LegalizedStoreSize = LegalizedStoredValTy.getSizeInBits();
    ConstantSDNode *C = cast<ConstantSDNode>(StoredVal);
    SDValue ExtendedStoreVal =
        DAG.getConstant(C->getAPIntValue().zextOrTrunc(LegalizedStoreSize), DL,
                        LegalizedStoredValTy);
    NewStore = DAG.getTruncStore(
        NewChain, DL, ExtendedStoreVal, FirstInChain->getBasePtr(),
        FirstInChain->getPointerInfo(), StoredVal.getValueType() /*TVT*/,
        FirstInChain->getAlign(), *Flags, AAInfo);
  }

  // Replace all merged stores with the new store.
  for (unsigned i = 0; i < NumStores; ++i)
    CombineTo(StoreNodes[i].MemNode, NewStore);

  AddToWorklist(NewChain.getNode());
  return true;
}
18054 
/// Populate \p StoreNodes with stores that could be merged with \p St, and
/// set \p RootNode to the common chain ancestor the search was rooted at.
///
/// Candidates must share St's base pointer (recorded with their byte offset
/// from it), match its store-source kind (load / constant / extract), and
/// pass per-kind legality checks. The list is left unsorted and may contain
/// overlapping or non-consecutive offsets; callers sort and prune it.
void DAGCombiner::getStoreMergeCandidates(
    StoreSDNode *St, SmallVectorImpl<MemOpLink> &StoreNodes,
    SDNode *&RootNode) {
  // This holds the base pointer, index, and the offset in bytes from the base
  // pointer. We must have a base and an offset. Do not handle stores to undef
  // base pointers.
  BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
  if (!BasePtr.getBase().getNode() || BasePtr.getBase().isUndef())
    return;

  SDValue Val = peekThroughBitcasts(St->getValue());
  StoreSource StoreSrc = getStoreSource(Val);
  assert(StoreSrc != StoreSource::Unknown && "Expected known source for store");

  // Match on loadbaseptr if relevant.
  EVT MemVT = St->getMemoryVT();
  BaseIndexOffset LBasePtr;
  EVT LoadVT;
  if (StoreSrc == StoreSource::Load) {
    auto *Ld = cast<LoadSDNode>(Val);
    LBasePtr = BaseIndexOffset::match(Ld, DAG);
    LoadVT = Ld->getMemoryVT();
    // Load and store should be the same type.
    if (MemVT != LoadVT)
      return;
    // Loads must only have one use.
    if (!Ld->hasNUsesOfValue(1, 0))
      return;
    // The memory operands must not be volatile/indexed/atomic.
    // TODO: May be able to relax for unordered atomics (see D66309)
    if (!Ld->isSimple() || Ld->isIndexed())
      return;
  }
  // Returns true if Other can be merged with St; on success, Ptr/Offset give
  // Other's decomposed address relative to St's base.
  auto CandidateMatch = [&](StoreSDNode *Other, BaseIndexOffset &Ptr,
                            int64_t &Offset) -> bool {
    // The memory operands must not be volatile/indexed/atomic.
    // TODO: May be able to relax for unordered atomics (see D66309)
    if (!Other->isSimple() || Other->isIndexed())
      return false;
    // Don't mix temporal stores with non-temporal stores.
    if (St->isNonTemporal() != Other->isNonTemporal())
      return false;
    SDValue OtherBC = peekThroughBitcasts(Other->getValue());
    // Allow merging constants of different types as integers.
    bool NoTypeMatch = (MemVT.isInteger()) ? !MemVT.bitsEq(Other->getMemoryVT())
                                           : Other->getMemoryVT() != MemVT;
    switch (StoreSrc) {
    case StoreSource::Load: {
      if (NoTypeMatch)
        return false;
      // The Load's Base Ptr must also match.
      auto *OtherLd = dyn_cast<LoadSDNode>(OtherBC);
      if (!OtherLd)
        return false;
      BaseIndexOffset LPtr = BaseIndexOffset::match(OtherLd, DAG);
      if (LoadVT != OtherLd->getMemoryVT())
        return false;
      // Loads must only have one use.
      if (!OtherLd->hasNUsesOfValue(1, 0))
        return false;
      // The memory operands must not be volatile/indexed/atomic.
      // TODO: May be able to relax for unordered atomics (see D66309)
      if (!OtherLd->isSimple() || OtherLd->isIndexed())
        return false;
      // Don't mix temporal loads with non-temporal loads.
      if (cast<LoadSDNode>(Val)->isNonTemporal() != OtherLd->isNonTemporal())
        return false;
      if (!(LBasePtr.equalBaseIndex(LPtr, DAG)))
        return false;
      break;
    }
    case StoreSource::Constant:
      if (NoTypeMatch)
        return false;
      if (!isIntOrFPConstant(OtherBC))
        return false;
      break;
    case StoreSource::Extract:
      // Do not merge truncated stores here.
      if (Other->isTruncatingStore())
        return false;
      if (!MemVT.bitsEq(OtherBC.getValueType()))
        return false;
      if (OtherBC.getOpcode() != ISD::EXTRACT_VECTOR_ELT &&
          OtherBC.getOpcode() != ISD::EXTRACT_SUBVECTOR)
        return false;
      break;
    default:
      llvm_unreachable("Unhandled store source for merging");
    }
    Ptr = BaseIndexOffset::match(Other, DAG);
    return (BasePtr.equalBaseIndex(Ptr, DAG, Offset));
  };

  // Check if the pair of StoreNode and the RootNode already bail out many
  // times which is over the limit in dependence check.
  auto OverLimitInDependenceCheck = [&](SDNode *StoreNode,
                                        SDNode *RootNode) -> bool {
    auto RootCount = StoreRootCountMap.find(StoreNode);
    return RootCount != StoreRootCountMap.end() &&
           RootCount->second.first == RootNode &&
           RootCount->second.second > StoreMergeDependenceLimit;
  };

  // Append the store at *UseIter to StoreNodes if it is a matching candidate
  // reached through a chain edge and hasn't exceeded the dependence-check
  // bail-out limit for this root.
  auto TryToAddCandidate = [&](SDNode::use_iterator UseIter) {
    // This must be a chain use.
    if (UseIter.getOperandNo() != 0)
      return;
    if (auto *OtherStore = dyn_cast<StoreSDNode>(*UseIter)) {
      BaseIndexOffset Ptr;
      int64_t PtrDiff;
      if (CandidateMatch(OtherStore, Ptr, PtrDiff) &&
          !OverLimitInDependenceCheck(OtherStore, RootNode))
        StoreNodes.push_back(MemOpLink(OtherStore, PtrDiff));
    }
  };

  // We are looking for a root node which is an ancestor to all mergable
  // stores. We search up through a load, to our root and then down
  // through all children. For instance we will find Store{1,2,3} if
  // St is Store1, Store2. or Store3 where the root is not a load
  // which is always true for nonvolatile ops. TODO: Expand
  // the search to find all valid candidates through multiple layers of loads.
  //
  // Root
  // |-------|-------|
  // Load    Load    Store3
  // |       |
  // Store1   Store2
  //
  // FIXME: We should be able to climb and
  // descend TokenFactors to find candidates as well.

  RootNode = St->getChain().getNode();

  unsigned NumNodesExplored = 0;
  const unsigned MaxSearchNodes = 1024;
  if (auto *Ldn = dyn_cast<LoadSDNode>(RootNode)) {
    RootNode = Ldn->getChain().getNode();
    for (auto I = RootNode->use_begin(), E = RootNode->use_end();
         I != E && NumNodesExplored < MaxSearchNodes; ++I, ++NumNodesExplored) {
      if (I.getOperandNo() == 0 && isa<LoadSDNode>(*I)) { // walk down chain
        for (auto I2 = (*I)->use_begin(), E2 = (*I)->use_end(); I2 != E2; ++I2)
          TryToAddCandidate(I2);
      }
      // Check stores that depend on the root (e.g. Store 3 in the chart above).
      if (I.getOperandNo() == 0 && isa<StoreSDNode>(*I)) {
        TryToAddCandidate(I);
      }
    }
  } else {
    for (auto I = RootNode->use_begin(), E = RootNode->use_end();
         I != E && NumNodesExplored < MaxSearchNodes; ++I, ++NumNodesExplored)
      TryToAddCandidate(I);
  }
}
18211 
// We need to check that merging these stores does not cause a loop in the
// DAG. Any store candidate may depend on another candidate indirectly through
// its operands. Check in parallel by searching up from operands of candidates.
//
// Returns true when it is safe to merge the first NumStores candidates in
// StoreNodes (no candidate is a predecessor of another); false when a cycle
// would form or the bounded search gave up.
bool DAGCombiner::checkMergeStoreCandidatesForDependencies(
    SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
    SDNode *RootNode) {
  // FIXME: We should be able to truncate a full search of
  // predecessors by doing a BFS and keeping tabs the originating
  // stores from which worklist nodes come from in a similar way to
  // TokenFactor simplification.

  SmallPtrSet<const SDNode *, 32> Visited;
  SmallVector<const SDNode *, 8> Worklist;

  // RootNode is a predecessor to all candidates so we need not search
  // past it. Add RootNode (peeking through TokenFactors). Do not count
  // these towards size check.

  Worklist.push_back(RootNode);
  while (!Worklist.empty()) {
    auto N = Worklist.pop_back_val();
    if (!Visited.insert(N).second)
      continue; // Already present in Visited.
    // Expand through TokenFactors so the whole chain frontier of the root
    // acts as a search barrier.
    if (N->getOpcode() == ISD::TokenFactor) {
      for (SDValue Op : N->ops())
        Worklist.push_back(Op.getNode());
    }
  }

  // Don't count pruning nodes towards max.
  unsigned int Max = 1024 + Visited.size();
  // Search Ops of store candidates.
  for (unsigned i = 0; i < NumStores; ++i) {
    SDNode *N = StoreNodes[i].MemNode;
    // Of the 4 Store Operands:
    //   * Chain (Op 0) -> We have already considered these
    //                     in candidate selection, but only by following the
    //                     chain dependencies. We could still have a chain
    //                     dependency to a load, that has a non-chain dep to
    //                     another load, that depends on a store, etc. So it is
    //                     possible to have dependencies that consist of a mix
    //                     of chain and non-chain deps, and we need to include
    //                     chain operands in the analysis here..
    //   * Value (Op 1) -> Cycles may happen (e.g. through load chains)
    //   * Address (Op 2) -> Merged addresses may only vary by a fixed constant,
    //                       but aren't necessarily from the same base node, so
    //                       cycles possible (e.g. via indexed store).
    //   * (Op 3) -> Represents the pre or post-indexing offset (or undef for
    //               non-indexed stores). Not constant on all targets (e.g. ARM)
    //               and so can participate in a cycle.
    for (unsigned j = 0; j < N->getNumOperands(); ++j)
      Worklist.push_back(N->getOperand(j).getNode());
  }
  // Search through DAG. We can stop early if we find a store node.
  for (unsigned i = 0; i < NumStores; ++i)
    if (SDNode::hasPredecessorHelper(StoreNodes[i].MemNode, Visited, Worklist,
                                     Max)) {
      // If the searching bail out, record the StoreNode and RootNode in the
      // StoreRootCountMap. If we have seen the pair many times over a limit,
      // we won't add the StoreNode into StoreNodes set again.
      if (Visited.size() >= Max) {
        auto &RootCount = StoreRootCountMap[StoreNodes[i].MemNode];
        if (RootCount.first == RootNode)
          RootCount.second++;
        else
          RootCount = {RootNode, 1};
      }
      return false;
    }
  return true;
}
18283 
18284 unsigned
getConsecutiveStores(SmallVectorImpl<MemOpLink> & StoreNodes,int64_t ElementSizeBytes) const18285 DAGCombiner::getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
18286                                   int64_t ElementSizeBytes) const {
18287   while (true) {
18288     // Find a store past the width of the first store.
18289     size_t StartIdx = 0;
18290     while ((StartIdx + 1 < StoreNodes.size()) &&
18291            StoreNodes[StartIdx].OffsetFromBase + ElementSizeBytes !=
18292               StoreNodes[StartIdx + 1].OffsetFromBase)
18293       ++StartIdx;
18294 
18295     // Bail if we don't have enough candidates to merge.
18296     if (StartIdx + 1 >= StoreNodes.size())
18297       return 0;
18298 
18299     // Trim stores that overlapped with the first store.
18300     if (StartIdx)
18301       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + StartIdx);
18302 
18303     // Scan the memory operations on the chain and find the first
18304     // non-consecutive store memory address.
18305     unsigned NumConsecutiveStores = 1;
18306     int64_t StartAddress = StoreNodes[0].OffsetFromBase;
18307     // Check that the addresses are consecutive starting from the second
18308     // element in the list of stores.
18309     for (unsigned i = 1, e = StoreNodes.size(); i < e; ++i) {
18310       int64_t CurrAddress = StoreNodes[i].OffsetFromBase;
18311       if (CurrAddress - StartAddress != (ElementSizeBytes * i))
18312         break;
18313       NumConsecutiveStores = i + 1;
18314     }
18315     if (NumConsecutiveStores > 1)
18316       return NumConsecutiveStores;
18317 
18318     // There are no consecutive stores at the start of the list.
18319     // Remove the first store and try again.
18320     StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 1);
18321   }
18322 }
18323 
/// Try to merge runs of constant stores from \p StoreNodes into wider
/// integer, truncating-integer, or vector stores.
///
/// \param StoreNodes  Consecutive constant-store candidates, lowest address
///                    first; merged/skipped entries are erased as the loop
///                    advances.
/// \param NumConsecutiveStores  Length of the consecutive run to consider.
/// \param MemVT       Memory type of each individual store.
/// \param RootNode    Common chain ancestor, used for the cycle check.
/// \param AllowVectors  Whether vector store types may be used.
/// \returns true if at least one merge was performed.
bool DAGCombiner::tryStoreMergeOfConstants(
    SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores,
    EVT MemVT, SDNode *RootNode, bool AllowVectors) {
  LLVMContext &Context = *DAG.getContext();
  const DataLayout &DL = DAG.getDataLayout();
  int64_t ElementSizeBytes = MemVT.getStoreSize();
  unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
  bool MadeChange = false;

  // Store the constants into memory as one consecutive store.
  while (NumConsecutiveStores >= 2) {
    LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
    unsigned FirstStoreAS = FirstInChain->getAddressSpace();
    Align FirstStoreAlign = FirstInChain->getAlign();
    // Largest prefix length (in stores) that is legal as a single integer
    // store, as a truncating store, and as a vector store, respectively.
    unsigned LastLegalType = 1;
    unsigned LastLegalVectorType = 1;
    bool LastIntegerTrunc = false;
    bool NonZero = false;
    unsigned FirstZeroAfterNonZero = NumConsecutiveStores;
    for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
      StoreSDNode *ST = cast<StoreSDNode>(StoreNodes[i].MemNode);
      SDValue StoredVal = ST->getValue();
      bool IsElementZero = false;
      if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(StoredVal))
        IsElementZero = C->isZero();
      else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(StoredVal))
        IsElementZero = C->getConstantFPValue()->isNullValue();
      // Track the first zero that follows a non-zero value; used below to
      // bound how many stores may be skipped when no merge is found.
      if (IsElementZero) {
        if (NonZero && FirstZeroAfterNonZero == NumConsecutiveStores)
          FirstZeroAfterNonZero = i;
      }
      NonZero |= !IsElementZero;

      // Find a legal type for the constant store.
      unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
      EVT StoreTy = EVT::getIntegerVT(Context, SizeInBits);
      bool IsFast = false;

      // Break early when size is too large to be legal.
      if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
        break;

      if (TLI.isTypeLegal(StoreTy) &&
          TLI.canMergeStoresTo(FirstStoreAS, StoreTy,
                               DAG.getMachineFunction()) &&
          TLI.allowsMemoryAccess(Context, DL, StoreTy,
                                 *FirstInChain->getMemOperand(), &IsFast) &&
          IsFast) {
        LastIntegerTrunc = false;
        LastLegalType = i + 1;
        // Or check whether a truncstore is legal.
      } else if (TLI.getTypeAction(Context, StoreTy) ==
                 TargetLowering::TypePromoteInteger) {
        EVT LegalizedStoredValTy =
            TLI.getTypeToTransformTo(Context, StoredVal.getValueType());
        if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) &&
            TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy,
                                 DAG.getMachineFunction()) &&
            TLI.allowsMemoryAccess(Context, DL, StoreTy,
                                   *FirstInChain->getMemOperand(), &IsFast) &&
            IsFast) {
          LastIntegerTrunc = true;
          LastLegalType = i + 1;
        }
      }

      // We only use vectors if the constant is known to be zero or the
      // target allows it and the function is not marked with the
      // noimplicitfloat attribute.
      if ((!NonZero ||
           TLI.storeOfVectorConstantIsCheap(MemVT, i + 1, FirstStoreAS)) &&
          AllowVectors) {
        // Find a legal type for the vector store.
        unsigned Elts = (i + 1) * NumMemElts;
        EVT Ty = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
        if (TLI.isTypeLegal(Ty) && TLI.isTypeLegal(MemVT) &&
            TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG.getMachineFunction()) &&
            TLI.allowsMemoryAccess(Context, DL, Ty,
                                   *FirstInChain->getMemOperand(), &IsFast) &&
            IsFast)
          LastLegalVectorType = i + 1;
      }
    }

    bool UseVector = (LastLegalVectorType > LastLegalType) && AllowVectors;
    unsigned NumElem = (UseVector) ? LastLegalVectorType : LastLegalType;
    bool UseTrunc = LastIntegerTrunc && !UseVector;

    // Check if we found a legal integer type that creates a meaningful
    // merge.
    if (NumElem < 2) {
      // We know that candidate stores are in order and of correct
      // shape. While there is no mergeable sequence from the
      // beginning one may start later in the sequence. The only
      // reason a merge of size N could have failed where another of
      // the same size would not have, is if the alignment has
      // improved or we've dropped a non-zero value. Drop as many
      // candidates as we can here.
      unsigned NumSkip = 1;
      while ((NumSkip < NumConsecutiveStores) &&
             (NumSkip < FirstZeroAfterNonZero) &&
             (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
        NumSkip++;

      StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
      NumConsecutiveStores -= NumSkip;
      continue;
    }

    // Check that we can merge these candidates without causing a cycle.
    if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem,
                                                  RootNode)) {
      StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
      NumConsecutiveStores -= NumElem;
      continue;
    }

    MadeChange |= mergeStoresOfConstantsOrVecElts(StoreNodes, MemVT, NumElem,
                                                  /*IsConstantSrc*/ true,
                                                  UseVector, UseTrunc);

    // Remove merged stores for next iteration.
    StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
    NumConsecutiveStores -= NumElem;
  }
  return MadeChange;
}
18451 
tryStoreMergeOfExtracts(SmallVectorImpl<MemOpLink> & StoreNodes,unsigned NumConsecutiveStores,EVT MemVT,SDNode * RootNode)18452 bool DAGCombiner::tryStoreMergeOfExtracts(
18453     SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores,
18454     EVT MemVT, SDNode *RootNode) {
18455   LLVMContext &Context = *DAG.getContext();
18456   const DataLayout &DL = DAG.getDataLayout();
18457   unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
18458   bool MadeChange = false;
18459 
18460   // Loop on Consecutive Stores on success.
18461   while (NumConsecutiveStores >= 2) {
18462     LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
18463     unsigned FirstStoreAS = FirstInChain->getAddressSpace();
18464     Align FirstStoreAlign = FirstInChain->getAlign();
18465     unsigned NumStoresToMerge = 1;
18466     for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
18467       // Find a legal type for the vector store.
18468       unsigned Elts = (i + 1) * NumMemElts;
18469       EVT Ty = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
18470       bool IsFast = false;
18471 
18472       // Break early when size is too large to be legal.
18473       if (Ty.getSizeInBits() > MaximumLegalStoreInBits)
18474         break;
18475 
18476       if (TLI.isTypeLegal(Ty) &&
18477           TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG.getMachineFunction()) &&
18478           TLI.allowsMemoryAccess(Context, DL, Ty,
18479                                  *FirstInChain->getMemOperand(), &IsFast) &&
18480           IsFast)
18481         NumStoresToMerge = i + 1;
18482     }
18483 
18484     // Check if we found a legal integer type creating a meaningful
18485     // merge.
18486     if (NumStoresToMerge < 2) {
18487       // We know that candidate stores are in order and of correct
18488       // shape. While there is no mergeable sequence from the
18489       // beginning one may start later in the sequence. The only
18490       // reason a merge of size N could have failed where another of
18491       // the same size would not have, is if the alignment has
18492       // improved. Drop as many candidates as we can here.
18493       unsigned NumSkip = 1;
18494       while ((NumSkip < NumConsecutiveStores) &&
18495              (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
18496         NumSkip++;
18497 
18498       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
18499       NumConsecutiveStores -= NumSkip;
18500       continue;
18501     }
18502 
18503     // Check that we can merge these candidates without causing a cycle.
18504     if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumStoresToMerge,
18505                                                   RootNode)) {
18506       StoreNodes.erase(StoreNodes.begin(),
18507                        StoreNodes.begin() + NumStoresToMerge);
18508       NumConsecutiveStores -= NumStoresToMerge;
18509       continue;
18510     }
18511 
18512     MadeChange |= mergeStoresOfConstantsOrVecElts(
18513         StoreNodes, MemVT, NumStoresToMerge, /*IsConstantSrc*/ false,
18514         /*UseVector*/ true, /*UseTrunc*/ false);
18515 
18516     StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumStoresToMerge);
18517     NumConsecutiveStores -= NumStoresToMerge;
18518   }
18519   return MadeChange;
18520 }
18521 
/// Try to replace a run of consecutive stores whose stored values are loads
/// from consecutive addresses with one wide load feeding one wide store.
/// The leading \p NumConsecutiveStores entries of \p StoreNodes are known to
/// be consecutive; \p RootNode anchors the dependency (cycle) check.
/// \p AllowVectors permits vector store types; the NonTemporal flags are
/// propagated onto the merged memory operations. Returns true if the DAG
/// was changed.
bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
                                       unsigned NumConsecutiveStores, EVT MemVT,
                                       SDNode *RootNode, bool AllowVectors,
                                       bool IsNonTemporalStore,
                                       bool IsNonTemporalLoad) {
  LLVMContext &Context = *DAG.getContext();
  const DataLayout &DL = DAG.getDataLayout();
  int64_t ElementSizeBytes = MemVT.getStoreSize();
  unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
  bool MadeChange = false;

  // Look for load nodes which are used by the stored values.
  SmallVector<MemOpLink, 8> LoadNodes;

  // Find acceptable loads. Loads need to have the same chain (token factor),
  // must not be zext, volatile, indexed, and they must be consecutive.
  BaseIndexOffset LdBasePtr;

  for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
    StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
    SDValue Val = peekThroughBitcasts(St->getValue());
    // The caller dispatches here only for StoreSource::Load candidates, so
    // the bitcast-stripped stored value is known to be a load.
    LoadSDNode *Ld = cast<LoadSDNode>(Val);

    BaseIndexOffset LdPtr = BaseIndexOffset::match(Ld, DAG);
    // If this is not the first ptr that we check.
    int64_t LdOffset = 0;
    if (LdBasePtr.getBase().getNode()) {
      // The base ptr must be the same.
      if (!LdBasePtr.equalBaseIndex(LdPtr, DAG, LdOffset))
        break;
    } else {
      // Check that all other base pointers are the same as this one.
      LdBasePtr = LdPtr;
    }

    // We found a potential memory operand to merge.
    LoadNodes.push_back(MemOpLink(Ld, LdOffset));
  }

  // Each iteration either merges a prefix of (load, store) pairs or prunes
  // unusable candidates from the front of both lists.
  while (NumConsecutiveStores >= 2 && LoadNodes.size() >= 2) {
    Align RequiredAlignment;
    bool NeedRotate = false;
    if (LoadNodes.size() == 2) {
      // If we have load/store pair instructions and we only have two values,
      // don't bother merging.
      if (TLI.hasPairedLoad(MemVT, RequiredAlignment) &&
          StoreNodes[0].MemNode->getAlign() >= RequiredAlignment) {
        StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 2);
        LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + 2);
        break;
      }
      // If the loads are reversed, see if we can rotate the halves into place.
      int64_t Offset0 = LoadNodes[0].OffsetFromBase;
      int64_t Offset1 = LoadNodes[1].OffsetFromBase;
      EVT PairVT = EVT::getIntegerVT(Context, ElementSizeBytes * 8 * 2);
      if (Offset0 - Offset1 == ElementSizeBytes &&
          (hasOperation(ISD::ROTL, PairVT) ||
           hasOperation(ISD::ROTR, PairVT))) {
        // Loading the pair as one wide value and rotating it by half its
        // width swaps the two halves into store order.
        std::swap(LoadNodes[0], LoadNodes[1]);
        NeedRotate = true;
      }
    }
    LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
    unsigned FirstStoreAS = FirstInChain->getAddressSpace();
    Align FirstStoreAlign = FirstInChain->getAlign();
    LoadSDNode *FirstLoad = cast<LoadSDNode>(LoadNodes[0].MemNode);

    // Scan the memory operations on the chain and find the first
    // non-consecutive load memory address. These variables hold the index in
    // the store node array.

    unsigned LastConsecutiveLoad = 1;

    // This variable refers to the size and not index in the array.
    unsigned LastLegalVectorType = 1;
    unsigned LastLegalIntegerType = 1;
    bool isDereferenceable = true;
    bool DoIntegerTruncate = false;
    int64_t StartAddress = LoadNodes[0].OffsetFromBase;
    SDValue LoadChain = FirstLoad->getChain();
    for (unsigned i = 1; i < LoadNodes.size(); ++i) {
      // All loads must share the same chain.
      if (LoadNodes[i].MemNode->getChain() != LoadChain)
        break;

      int64_t CurrAddress = LoadNodes[i].OffsetFromBase;
      if (CurrAddress - StartAddress != (ElementSizeBytes * i))
        break;
      LastConsecutiveLoad = i;

      // The merged load is dereferenceable only if every merged load is.
      if (isDereferenceable && !LoadNodes[i].MemNode->isDereferenceable())
        isDereferenceable = false;

      // Find a legal type for the vector store.
      unsigned Elts = (i + 1) * NumMemElts;
      EVT StoreTy = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);

      // Break early when size is too large to be legal.
      if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
        break;

      bool IsFastSt = false;
      bool IsFastLd = false;
      // Don't try vector types if we need a rotate. We may still fail the
      // legality checks for the integer type, but we can't handle the rotate
      // case with vectors.
      // FIXME: We could use a shuffle in place of the rotate.
      if (!NeedRotate && TLI.isTypeLegal(StoreTy) &&
          TLI.canMergeStoresTo(FirstStoreAS, StoreTy,
                               DAG.getMachineFunction()) &&
          TLI.allowsMemoryAccess(Context, DL, StoreTy,
                                 *FirstInChain->getMemOperand(), &IsFastSt) &&
          IsFastSt &&
          TLI.allowsMemoryAccess(Context, DL, StoreTy,
                                 *FirstLoad->getMemOperand(), &IsFastLd) &&
          IsFastLd) {
        LastLegalVectorType = i + 1;
      }

      // Find a legal type for the integer store.
      unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
      StoreTy = EVT::getIntegerVT(Context, SizeInBits);
      if (TLI.isTypeLegal(StoreTy) &&
          TLI.canMergeStoresTo(FirstStoreAS, StoreTy,
                               DAG.getMachineFunction()) &&
          TLI.allowsMemoryAccess(Context, DL, StoreTy,
                                 *FirstInChain->getMemOperand(), &IsFastSt) &&
          IsFastSt &&
          TLI.allowsMemoryAccess(Context, DL, StoreTy,
                                 *FirstLoad->getMemOperand(), &IsFastLd) &&
          IsFastLd) {
        LastLegalIntegerType = i + 1;
        DoIntegerTruncate = false;
        // Or check whether a truncstore and extload is legal.
      } else if (TLI.getTypeAction(Context, StoreTy) ==
                 TargetLowering::TypePromoteInteger) {
        EVT LegalizedStoredValTy = TLI.getTypeToTransformTo(Context, StoreTy);
        if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) &&
            TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy,
                                 DAG.getMachineFunction()) &&
            TLI.isLoadExtLegal(ISD::ZEXTLOAD, LegalizedStoredValTy, StoreTy) &&
            TLI.isLoadExtLegal(ISD::SEXTLOAD, LegalizedStoredValTy, StoreTy) &&
            TLI.isLoadExtLegal(ISD::EXTLOAD, LegalizedStoredValTy, StoreTy) &&
            TLI.allowsMemoryAccess(Context, DL, StoreTy,
                                   *FirstInChain->getMemOperand(), &IsFastSt) &&
            IsFastSt &&
            TLI.allowsMemoryAccess(Context, DL, StoreTy,
                                   *FirstLoad->getMemOperand(), &IsFastLd) &&
            IsFastLd) {
          LastLegalIntegerType = i + 1;
          DoIntegerTruncate = true;
        }
      }
    }

    // Only use vector types if the vector type is larger than the integer
    // type. If they are the same, use integers.
    bool UseVectorTy =
        LastLegalVectorType > LastLegalIntegerType && AllowVectors;
    unsigned LastLegalType =
        std::max(LastLegalVectorType, LastLegalIntegerType);

    // We add +1 here because the LastXXX variables refer to location while
    // the NumElem refers to array/index size.
    unsigned NumElem = std::min(NumConsecutiveStores, LastConsecutiveLoad + 1);
    NumElem = std::min(LastLegalType, NumElem);
    Align FirstLoadAlign = FirstLoad->getAlign();

    if (NumElem < 2) {
      // We know that candidate stores are in order and of correct
      // shape. While there is no mergeable sequence from the
      // beginning one may start later in the sequence. The only
      // reason a merge of size N could have failed where another of
      // the same size would not have is if the alignment or either
      // the load or store has improved. Drop as many candidates as we
      // can here.
      unsigned NumSkip = 1;
      while ((NumSkip < LoadNodes.size()) &&
             (LoadNodes[NumSkip].MemNode->getAlign() <= FirstLoadAlign) &&
             (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
        NumSkip++;
      StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
      LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumSkip);
      NumConsecutiveStores -= NumSkip;
      continue;
    }

    // Check that we can merge these candidates without causing a cycle.
    if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem,
                                                  RootNode)) {
      StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
      LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
      NumConsecutiveStores -= NumElem;
      continue;
    }

    // Find if it is better to use vectors or integers to load and store
    // to memory.
    EVT JointMemOpVT;
    if (UseVectorTy) {
      // Find a legal type for the vector store.
      unsigned Elts = NumElem * NumMemElts;
      JointMemOpVT = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
    } else {
      unsigned SizeInBits = NumElem * ElementSizeBytes * 8;
      JointMemOpVT = EVT::getIntegerVT(Context, SizeInBits);
    }

    SDLoc LoadDL(LoadNodes[0].MemNode);
    SDLoc StoreDL(StoreNodes[0].MemNode);

    // The merged loads are required to have the same incoming chain, so
    // using the first's chain is acceptable.

    SDValue NewStoreChain = getMergeStoreChains(StoreNodes, NumElem);
    AddToWorklist(NewStoreChain.getNode());

    MachineMemOperand::Flags LdMMOFlags =
        isDereferenceable ? MachineMemOperand::MODereferenceable
                          : MachineMemOperand::MONone;
    if (IsNonTemporalLoad)
      LdMMOFlags |= MachineMemOperand::MONonTemporal;

    MachineMemOperand::Flags StMMOFlags = IsNonTemporalStore
                                              ? MachineMemOperand::MONonTemporal
                                              : MachineMemOperand::MONone;

    SDValue NewLoad, NewStore;
    if (UseVectorTy || !DoIntegerTruncate) {
      // Simple case: one wide load feeding one wide store, with an optional
      // rotate to swap the halves of a reversed load pair.
      NewLoad = DAG.getLoad(
          JointMemOpVT, LoadDL, FirstLoad->getChain(), FirstLoad->getBasePtr(),
          FirstLoad->getPointerInfo(), FirstLoadAlign, LdMMOFlags);
      SDValue StoreOp = NewLoad;
      if (NeedRotate) {
        unsigned LoadWidth = ElementSizeBytes * 8 * 2;
        assert(JointMemOpVT == EVT::getIntegerVT(Context, LoadWidth) &&
               "Unexpected type for rotate-able load pair");
        SDValue RotAmt =
            DAG.getShiftAmountConstant(LoadWidth / 2, JointMemOpVT, LoadDL);
        // Target can convert to the identical ROTR if it does not have ROTL.
        StoreOp = DAG.getNode(ISD::ROTL, LoadDL, JointMemOpVT, NewLoad, RotAmt);
      }
      NewStore = DAG.getStore(
          NewStoreChain, StoreDL, StoreOp, FirstInChain->getBasePtr(),
          FirstInChain->getPointerInfo(), FirstStoreAlign, StMMOFlags);
    } else { // This must be the truncstore/extload case
      EVT ExtendedTy =
          TLI.getTypeToTransformTo(*DAG.getContext(), JointMemOpVT);
      NewLoad = DAG.getExtLoad(ISD::EXTLOAD, LoadDL, ExtendedTy,
                               FirstLoad->getChain(), FirstLoad->getBasePtr(),
                               FirstLoad->getPointerInfo(), JointMemOpVT,
                               FirstLoadAlign, LdMMOFlags);
      NewStore = DAG.getTruncStore(
          NewStoreChain, StoreDL, NewLoad, FirstInChain->getBasePtr(),
          FirstInChain->getPointerInfo(), JointMemOpVT,
          FirstInChain->getAlign(), FirstInChain->getMemOperand()->getFlags());
    }

    // Transfer chain users from old loads to the new load.
    for (unsigned i = 0; i < NumElem; ++i) {
      LoadSDNode *Ld = cast<LoadSDNode>(LoadNodes[i].MemNode);
      DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1),
                                    SDValue(NewLoad.getNode(), 1));
    }

    // Replace all stores with the new store. Recursively remove corresponding
    // values if they are no longer used.
    for (unsigned i = 0; i < NumElem; ++i) {
      SDValue Val = StoreNodes[i].MemNode->getOperand(1);
      CombineTo(StoreNodes[i].MemNode, NewStore);
      if (Val->use_empty())
        recursivelyDeleteUnusedNodes(Val.getNode());
    }

    MadeChange = true;
    StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
    LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
    NumConsecutiveStores -= NumElem;
  }
  return MadeChange;
}
18803 
mergeConsecutiveStores(StoreSDNode * St)18804 bool DAGCombiner::mergeConsecutiveStores(StoreSDNode *St) {
18805   if (OptLevel == CodeGenOpt::None || !EnableStoreMerging)
18806     return false;
18807 
18808   // TODO: Extend this function to merge stores of scalable vectors.
18809   // (i.e. two <vscale x 8 x i8> stores can be merged to one <vscale x 16 x i8>
18810   // store since we know <vscale x 16 x i8> is exactly twice as large as
18811   // <vscale x 8 x i8>). Until then, bail out for scalable vectors.
18812   EVT MemVT = St->getMemoryVT();
18813   if (MemVT.isScalableVector())
18814     return false;
18815   if (!MemVT.isSimple() || MemVT.getSizeInBits() * 2 > MaximumLegalStoreInBits)
18816     return false;
18817 
18818   // This function cannot currently deal with non-byte-sized memory sizes.
18819   int64_t ElementSizeBytes = MemVT.getStoreSize();
18820   if (ElementSizeBytes * 8 != (int64_t)MemVT.getSizeInBits())
18821     return false;
18822 
18823   // Do not bother looking at stored values that are not constants, loads, or
18824   // extracted vector elements.
18825   SDValue StoredVal = peekThroughBitcasts(St->getValue());
18826   const StoreSource StoreSrc = getStoreSource(StoredVal);
18827   if (StoreSrc == StoreSource::Unknown)
18828     return false;
18829 
18830   SmallVector<MemOpLink, 8> StoreNodes;
18831   SDNode *RootNode;
18832   // Find potential store merge candidates by searching through chain sub-DAG
18833   getStoreMergeCandidates(St, StoreNodes, RootNode);
18834 
18835   // Check if there is anything to merge.
18836   if (StoreNodes.size() < 2)
18837     return false;
18838 
18839   // Sort the memory operands according to their distance from the
18840   // base pointer.
18841   llvm::sort(StoreNodes, [](MemOpLink LHS, MemOpLink RHS) {
18842     return LHS.OffsetFromBase < RHS.OffsetFromBase;
18843   });
18844 
18845   bool AllowVectors = !DAG.getMachineFunction().getFunction().hasFnAttribute(
18846       Attribute::NoImplicitFloat);
18847   bool IsNonTemporalStore = St->isNonTemporal();
18848   bool IsNonTemporalLoad = StoreSrc == StoreSource::Load &&
18849                            cast<LoadSDNode>(StoredVal)->isNonTemporal();
18850 
18851   // Store Merge attempts to merge the lowest stores. This generally
18852   // works out as if successful, as the remaining stores are checked
18853   // after the first collection of stores is merged. However, in the
18854   // case that a non-mergeable store is found first, e.g., {p[-2],
18855   // p[0], p[1], p[2], p[3]}, we would fail and miss the subsequent
18856   // mergeable cases. To prevent this, we prune such stores from the
18857   // front of StoreNodes here.
18858   bool MadeChange = false;
18859   while (StoreNodes.size() > 1) {
18860     unsigned NumConsecutiveStores =
18861         getConsecutiveStores(StoreNodes, ElementSizeBytes);
18862     // There are no more stores in the list to examine.
18863     if (NumConsecutiveStores == 0)
18864       return MadeChange;
18865 
18866     // We have at least 2 consecutive stores. Try to merge them.
18867     assert(NumConsecutiveStores >= 2 && "Expected at least 2 stores");
18868     switch (StoreSrc) {
18869     case StoreSource::Constant:
18870       MadeChange |= tryStoreMergeOfConstants(StoreNodes, NumConsecutiveStores,
18871                                              MemVT, RootNode, AllowVectors);
18872       break;
18873 
18874     case StoreSource::Extract:
18875       MadeChange |= tryStoreMergeOfExtracts(StoreNodes, NumConsecutiveStores,
18876                                             MemVT, RootNode);
18877       break;
18878 
18879     case StoreSource::Load:
18880       MadeChange |= tryStoreMergeOfLoads(StoreNodes, NumConsecutiveStores,
18881                                          MemVT, RootNode, AllowVectors,
18882                                          IsNonTemporalStore, IsNonTemporalLoad);
18883       break;
18884 
18885     default:
18886       llvm_unreachable("Unhandled store source type");
18887     }
18888   }
18889   return MadeChange;
18890 }
18891 
replaceStoreChain(StoreSDNode * ST,SDValue BetterChain)18892 SDValue DAGCombiner::replaceStoreChain(StoreSDNode *ST, SDValue BetterChain) {
18893   SDLoc SL(ST);
18894   SDValue ReplStore;
18895 
18896   // Replace the chain to avoid dependency.
18897   if (ST->isTruncatingStore()) {
18898     ReplStore = DAG.getTruncStore(BetterChain, SL, ST->getValue(),
18899                                   ST->getBasePtr(), ST->getMemoryVT(),
18900                                   ST->getMemOperand());
18901   } else {
18902     ReplStore = DAG.getStore(BetterChain, SL, ST->getValue(), ST->getBasePtr(),
18903                              ST->getMemOperand());
18904   }
18905 
18906   // Create token to keep both nodes around.
18907   SDValue Token = DAG.getNode(ISD::TokenFactor, SL,
18908                               MVT::Other, ST->getChain(), ReplStore);
18909 
18910   // Make sure the new and old chains are cleaned up.
18911   AddToWorklist(Token.getNode());
18912 
18913   // Don't add users to work list.
18914   return CombineTo(ST, Token, false);
18915 }
18916 
replaceStoreOfFPConstant(StoreSDNode * ST)18917 SDValue DAGCombiner::replaceStoreOfFPConstant(StoreSDNode *ST) {
18918   SDValue Value = ST->getValue();
18919   if (Value.getOpcode() == ISD::TargetConstantFP)
18920     return SDValue();
18921 
18922   if (!ISD::isNormalStore(ST))
18923     return SDValue();
18924 
18925   SDLoc DL(ST);
18926 
18927   SDValue Chain = ST->getChain();
18928   SDValue Ptr = ST->getBasePtr();
18929 
18930   const ConstantFPSDNode *CFP = cast<ConstantFPSDNode>(Value);
18931 
18932   // NOTE: If the original store is volatile, this transform must not increase
18933   // the number of stores.  For example, on x86-32 an f64 can be stored in one
18934   // processor operation but an i64 (which is not legal) requires two.  So the
18935   // transform should not be done in this case.
18936 
18937   SDValue Tmp;
18938   switch (CFP->getSimpleValueType(0).SimpleTy) {
18939   default:
18940     llvm_unreachable("Unknown FP type");
18941   case MVT::f16:    // We don't do this for these yet.
18942   case MVT::bf16:
18943   case MVT::f80:
18944   case MVT::f128:
18945   case MVT::ppcf128:
18946     return SDValue();
18947   case MVT::f32:
18948     if ((isTypeLegal(MVT::i32) && !LegalOperations && ST->isSimple()) ||
18949         TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) {
18950       Tmp = DAG.getConstant((uint32_t)CFP->getValueAPF().
18951                             bitcastToAPInt().getZExtValue(), SDLoc(CFP),
18952                             MVT::i32);
18953       return DAG.getStore(Chain, DL, Tmp, Ptr, ST->getMemOperand());
18954     }
18955 
18956     return SDValue();
18957   case MVT::f64:
18958     if ((TLI.isTypeLegal(MVT::i64) && !LegalOperations &&
18959          ST->isSimple()) ||
18960         TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i64)) {
18961       Tmp = DAG.getConstant(CFP->getValueAPF().bitcastToAPInt().
18962                             getZExtValue(), SDLoc(CFP), MVT::i64);
18963       return DAG.getStore(Chain, DL, Tmp,
18964                           Ptr, ST->getMemOperand());
18965     }
18966 
18967     if (ST->isSimple() &&
18968         TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) {
18969       // Many FP stores are not made apparent until after legalize, e.g. for
18970       // argument passing.  Since this is so common, custom legalize the
18971       // 64-bit integer store into two 32-bit stores.
18972       uint64_t Val = CFP->getValueAPF().bitcastToAPInt().getZExtValue();
18973       SDValue Lo = DAG.getConstant(Val & 0xFFFFFFFF, SDLoc(CFP), MVT::i32);
18974       SDValue Hi = DAG.getConstant(Val >> 32, SDLoc(CFP), MVT::i32);
18975       if (DAG.getDataLayout().isBigEndian())
18976         std::swap(Lo, Hi);
18977 
18978       MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
18979       AAMDNodes AAInfo = ST->getAAInfo();
18980 
18981       SDValue St0 = DAG.getStore(Chain, DL, Lo, Ptr, ST->getPointerInfo(),
18982                                  ST->getOriginalAlign(), MMOFlags, AAInfo);
18983       Ptr = DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(4), DL);
18984       SDValue St1 = DAG.getStore(Chain, DL, Hi, Ptr,
18985                                  ST->getPointerInfo().getWithOffset(4),
18986                                  ST->getOriginalAlign(), MMOFlags, AAInfo);
18987       return DAG.getNode(ISD::TokenFactor, DL, MVT::Other,
18988                          St0, St1);
18989     }
18990 
18991     return SDValue();
18992   }
18993 }
18994 
visitSTORE(SDNode * N)18995 SDValue DAGCombiner::visitSTORE(SDNode *N) {
18996   StoreSDNode *ST  = cast<StoreSDNode>(N);
18997   SDValue Chain = ST->getChain();
18998   SDValue Value = ST->getValue();
18999   SDValue Ptr   = ST->getBasePtr();
19000 
19001   // If this is a store of a bit convert, store the input value if the
19002   // resultant store does not need a higher alignment than the original.
19003   if (Value.getOpcode() == ISD::BITCAST && !ST->isTruncatingStore() &&
19004       ST->isUnindexed()) {
19005     EVT SVT = Value.getOperand(0).getValueType();
19006     // If the store is volatile, we only want to change the store type if the
19007     // resulting store is legal. Otherwise we might increase the number of
19008     // memory accesses. We don't care if the original type was legal or not
19009     // as we assume software couldn't rely on the number of accesses of an
19010     // illegal type.
19011     // TODO: May be able to relax for unordered atomics (see D66309)
19012     if (((!LegalOperations && ST->isSimple()) ||
19013          TLI.isOperationLegal(ISD::STORE, SVT)) &&
19014         TLI.isStoreBitCastBeneficial(Value.getValueType(), SVT,
19015                                      DAG, *ST->getMemOperand())) {
19016       return DAG.getStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
19017                           ST->getMemOperand());
19018     }
19019   }
19020 
19021   // Turn 'store undef, Ptr' -> nothing.
19022   if (Value.isUndef() && ST->isUnindexed())
19023     return Chain;
19024 
19025   // Try to infer better alignment information than the store already has.
19026   if (OptLevel != CodeGenOpt::None && ST->isUnindexed() && !ST->isAtomic()) {
19027     if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) {
19028       if (*Alignment > ST->getAlign() &&
19029           isAligned(*Alignment, ST->getSrcValueOffset())) {
19030         SDValue NewStore =
19031             DAG.getTruncStore(Chain, SDLoc(N), Value, Ptr, ST->getPointerInfo(),
19032                               ST->getMemoryVT(), *Alignment,
19033                               ST->getMemOperand()->getFlags(), ST->getAAInfo());
19034         // NewStore will always be N as we are only refining the alignment
19035         assert(NewStore.getNode() == N);
19036         (void)NewStore;
19037       }
19038     }
19039   }
19040 
19041   // Try transforming a pair floating point load / store ops to integer
19042   // load / store ops.
19043   if (SDValue NewST = TransformFPLoadStorePair(N))
19044     return NewST;
19045 
19046   // Try transforming several stores into STORE (BSWAP).
19047   if (SDValue Store = mergeTruncStores(ST))
19048     return Store;
19049 
19050   if (ST->isUnindexed()) {
19051     // Walk up chain skipping non-aliasing memory nodes, on this store and any
19052     // adjacent stores.
19053     if (findBetterNeighborChains(ST)) {
19054       // replaceStoreChain uses CombineTo, which handled all of the worklist
19055       // manipulation. Return the original node to not do anything else.
19056       return SDValue(ST, 0);
19057     }
19058     Chain = ST->getChain();
19059   }
19060 
19061   // FIXME: is there such a thing as a truncating indexed store?
19062   if (ST->isTruncatingStore() && ST->isUnindexed() &&
19063       Value.getValueType().isInteger() &&
19064       (!isa<ConstantSDNode>(Value) ||
19065        !cast<ConstantSDNode>(Value)->isOpaque())) {
19066     // Convert a truncating store of a extension into a standard store.
19067     if ((Value.getOpcode() == ISD::ZERO_EXTEND ||
19068          Value.getOpcode() == ISD::SIGN_EXTEND ||
19069          Value.getOpcode() == ISD::ANY_EXTEND) &&
19070         Value.getOperand(0).getValueType() == ST->getMemoryVT() &&
19071         TLI.isOperationLegalOrCustom(ISD::STORE, ST->getMemoryVT()))
19072       return DAG.getStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
19073                           ST->getMemOperand());
19074 
19075     APInt TruncDemandedBits =
19076         APInt::getLowBitsSet(Value.getScalarValueSizeInBits(),
19077                              ST->getMemoryVT().getScalarSizeInBits());
19078 
19079     // See if we can simplify the input to this truncstore with knowledge that
19080     // only the low bits are being used.  For example:
19081     // "truncstore (or (shl x, 8), y), i8"  -> "truncstore y, i8"
19082     AddToWorklist(Value.getNode());
19083     if (SDValue Shorter = DAG.GetDemandedBits(Value, TruncDemandedBits))
19084       return DAG.getTruncStore(Chain, SDLoc(N), Shorter, Ptr, ST->getMemoryVT(),
19085                                ST->getMemOperand());
19086 
19087     // Otherwise, see if we can simplify the operation with
19088     // SimplifyDemandedBits, which only works if the value has a single use.
19089     if (SimplifyDemandedBits(Value, TruncDemandedBits)) {
19090       // Re-visit the store if anything changed and the store hasn't been merged
19091       // with another node (N is deleted) SimplifyDemandedBits will add Value's
19092       // node back to the worklist if necessary, but we also need to re-visit
19093       // the Store node itself.
19094       if (N->getOpcode() != ISD::DELETED_NODE)
19095         AddToWorklist(N);
19096       return SDValue(N, 0);
19097     }
19098   }
19099 
19100   // If this is a load followed by a store to the same location, then the store
19101   // is dead/noop.
19102   // TODO: Can relax for unordered atomics (see D66309)
19103   if (LoadSDNode *Ld = dyn_cast<LoadSDNode>(Value)) {
19104     if (Ld->getBasePtr() == Ptr && ST->getMemoryVT() == Ld->getMemoryVT() &&
19105         ST->isUnindexed() && ST->isSimple() &&
19106         Ld->getAddressSpace() == ST->getAddressSpace() &&
19107         // There can't be any side effects between the load and store, such as
19108         // a call or store.
19109         Chain.reachesChainWithoutSideEffects(SDValue(Ld, 1))) {
19110       // The store is dead, remove it.
19111       return Chain;
19112     }
19113   }
19114 
19115   // TODO: Can relax for unordered atomics (see D66309)
19116   if (StoreSDNode *ST1 = dyn_cast<StoreSDNode>(Chain)) {
19117     if (ST->isUnindexed() && ST->isSimple() &&
19118         ST1->isUnindexed() && ST1->isSimple()) {
19119       if (OptLevel != CodeGenOpt::None && ST1->getBasePtr() == Ptr &&
19120           ST1->getValue() == Value && ST->getMemoryVT() == ST1->getMemoryVT() &&
19121           ST->getAddressSpace() == ST1->getAddressSpace()) {
19122         // If this is a store followed by a store with the same value to the
19123         // same location, then the store is dead/noop.
19124         return Chain;
19125       }
19126 
19127       if (OptLevel != CodeGenOpt::None && ST1->hasOneUse() &&
19128           !ST1->getBasePtr().isUndef() &&
19129           // BaseIndexOffset and the code below requires knowing the size
19130           // of a vector, so bail out if MemoryVT is scalable.
19131           !ST->getMemoryVT().isScalableVector() &&
19132           !ST1->getMemoryVT().isScalableVector() &&
19133           ST->getAddressSpace() == ST1->getAddressSpace()) {
19134         const BaseIndexOffset STBase = BaseIndexOffset::match(ST, DAG);
19135         const BaseIndexOffset ChainBase = BaseIndexOffset::match(ST1, DAG);
19136         unsigned STBitSize = ST->getMemoryVT().getFixedSizeInBits();
19137         unsigned ChainBitSize = ST1->getMemoryVT().getFixedSizeInBits();
        // If the preceding store (ST1) writes to a subset of the current
        // store's location and no other node is chained to it, the earlier
        // store is redundant and can effectively be dropped. Do not remove
        // stores to undef as they may be used as data sinks.
19142         if (STBase.contains(DAG, STBitSize, ChainBase, ChainBitSize)) {
19143           CombineTo(ST1, ST1->getChain());
19144           return SDValue();
19145         }
19146       }
19147     }
19148   }
19149 
19150   // If this is an FP_ROUND or TRUNC followed by a store, fold this into a
19151   // truncating store.  We can do this even if this is already a truncstore.
19152   if ((Value.getOpcode() == ISD::FP_ROUND ||
19153        Value.getOpcode() == ISD::TRUNCATE) &&
19154       Value->hasOneUse() && ST->isUnindexed() &&
19155       TLI.canCombineTruncStore(Value.getOperand(0).getValueType(),
19156                                ST->getMemoryVT(), LegalOperations)) {
19157     return DAG.getTruncStore(Chain, SDLoc(N), Value.getOperand(0),
19158                              Ptr, ST->getMemoryVT(), ST->getMemOperand());
19159   }
19160 
19161   // Always perform this optimization before types are legal. If the target
19162   // prefers, also try this after legalization to catch stores that were created
19163   // by intrinsics or other nodes.
19164   if (!LegalTypes || (TLI.mergeStoresAfterLegalization(ST->getMemoryVT()))) {
19165     while (true) {
19166       // There can be multiple store sequences on the same chain.
19167       // Keep trying to merge store sequences until we are unable to do so
19168       // or until we merge the last store on the chain.
19169       bool Changed = mergeConsecutiveStores(ST);
19170       if (!Changed) break;
19171       // Return N as merge only uses CombineTo and no worklist clean
19172       // up is necessary.
19173       if (N->getOpcode() == ISD::DELETED_NODE || !isa<StoreSDNode>(N))
19174         return SDValue(N, 0);
19175     }
19176   }
19177 
19178   // Try transforming N to an indexed store.
19179   if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
19180     return SDValue(N, 0);
19181 
19182   // Turn 'store float 1.0, Ptr' -> 'store int 0x12345678, Ptr'
19183   //
19184   // Make sure to do this only after attempting to merge stores in order to
19185   //  avoid changing the types of some subset of stores due to visit order,
19186   //  preventing their merging.
19187   if (isa<ConstantFPSDNode>(ST->getValue())) {
19188     if (SDValue NewSt = replaceStoreOfFPConstant(ST))
19189       return NewSt;
19190   }
19191 
19192   if (SDValue NewSt = splitMergedValStore(ST))
19193     return NewSt;
19194 
19195   return ReduceLoadOpStoreWidth(N);
19196 }
19197 
visitLIFETIME_END(SDNode * N)19198 SDValue DAGCombiner::visitLIFETIME_END(SDNode *N) {
19199   const auto *LifetimeEnd = cast<LifetimeSDNode>(N);
19200   if (!LifetimeEnd->hasOffset())
19201     return SDValue();
19202 
19203   const BaseIndexOffset LifetimeEndBase(N->getOperand(1), SDValue(),
19204                                         LifetimeEnd->getOffset(), false);
19205 
19206   // We walk up the chains to find stores.
19207   SmallVector<SDValue, 8> Chains = {N->getOperand(0)};
19208   while (!Chains.empty()) {
19209     SDValue Chain = Chains.pop_back_val();
19210     if (!Chain.hasOneUse())
19211       continue;
19212     switch (Chain.getOpcode()) {
19213     case ISD::TokenFactor:
19214       for (unsigned Nops = Chain.getNumOperands(); Nops;)
19215         Chains.push_back(Chain.getOperand(--Nops));
19216       break;
19217     case ISD::LIFETIME_START:
19218     case ISD::LIFETIME_END:
19219       // We can forward past any lifetime start/end that can be proven not to
19220       // alias the node.
19221       if (!mayAlias(Chain.getNode(), N))
19222         Chains.push_back(Chain.getOperand(0));
19223       break;
19224     case ISD::STORE: {
19225       StoreSDNode *ST = dyn_cast<StoreSDNode>(Chain);
19226       // TODO: Can relax for unordered atomics (see D66309)
19227       if (!ST->isSimple() || ST->isIndexed())
19228         continue;
19229       const TypeSize StoreSize = ST->getMemoryVT().getStoreSize();
19230       // The bounds of a scalable store are not known until runtime, so this
19231       // store cannot be elided.
19232       if (StoreSize.isScalable())
19233         continue;
19234       const BaseIndexOffset StoreBase = BaseIndexOffset::match(ST, DAG);
19235       // If we store purely within object bounds just before its lifetime ends,
19236       // we can remove the store.
19237       if (LifetimeEndBase.contains(DAG, LifetimeEnd->getSize() * 8, StoreBase,
19238                                    StoreSize.getFixedSize() * 8)) {
19239         LLVM_DEBUG(dbgs() << "\nRemoving store:"; StoreBase.dump();
19240                    dbgs() << "\nwithin LIFETIME_END of : ";
19241                    LifetimeEndBase.dump(); dbgs() << "\n");
19242         CombineTo(ST, ST->getChain());
19243         return SDValue(N, 0);
19244       }
19245     }
19246     }
19247   }
19248   return SDValue();
19249 }
19250 
19251 /// For the instruction sequence of store below, F and I values
19252 /// are bundled together as an i64 value before being stored into memory.
/// Sometimes it is more efficient to generate separate stores for F and I,
19254 /// which can remove the bitwise instructions or sink them to colder places.
19255 ///
19256 ///   (store (or (zext (bitcast F to i32) to i64),
19257 ///              (shl (zext I to i64), 32)), addr)  -->
19258 ///   (store F, addr) and (store I, addr+4)
19259 ///
19260 /// Similarly, splitting for other merged store can also be beneficial, like:
19261 /// For pair of {i32, i32}, i64 store --> two i32 stores.
19262 /// For pair of {i32, i16}, i64 store --> two i32 stores.
19263 /// For pair of {i16, i16}, i32 store --> two i16 stores.
19264 /// For pair of {i16, i8},  i32 store --> two i16 stores.
19265 /// For pair of {i8, i8},   i16 store --> two i8 stores.
19266 ///
19267 /// We allow each target to determine specifically which kind of splitting is
19268 /// supported.
19269 ///
19270 /// The store patterns are commonly seen from the simple code snippet below
19271 /// if only std::make_pair(...) is sroa transformed before inlined into hoo.
19272 ///   void goo(const std::pair<int, float> &);
19273 ///   hoo() {
19274 ///     ...
19275 ///     goo(std::make_pair(tmp, ftmp));
19276 ///     ...
19277 ///   }
19278 ///
splitMergedValStore(StoreSDNode * ST)19279 SDValue DAGCombiner::splitMergedValStore(StoreSDNode *ST) {
19280   if (OptLevel == CodeGenOpt::None)
19281     return SDValue();
19282 
19283   // Can't change the number of memory accesses for a volatile store or break
19284   // atomicity for an atomic one.
19285   if (!ST->isSimple())
19286     return SDValue();
19287 
19288   SDValue Val = ST->getValue();
19289   SDLoc DL(ST);
19290 
19291   // Match OR operand.
19292   if (!Val.getValueType().isScalarInteger() || Val.getOpcode() != ISD::OR)
19293     return SDValue();
19294 
19295   // Match SHL operand and get Lower and Higher parts of Val.
19296   SDValue Op1 = Val.getOperand(0);
19297   SDValue Op2 = Val.getOperand(1);
19298   SDValue Lo, Hi;
19299   if (Op1.getOpcode() != ISD::SHL) {
19300     std::swap(Op1, Op2);
19301     if (Op1.getOpcode() != ISD::SHL)
19302       return SDValue();
19303   }
19304   Lo = Op2;
19305   Hi = Op1.getOperand(0);
19306   if (!Op1.hasOneUse())
19307     return SDValue();
19308 
19309   // Match shift amount to HalfValBitSize.
19310   unsigned HalfValBitSize = Val.getValueSizeInBits() / 2;
19311   ConstantSDNode *ShAmt = dyn_cast<ConstantSDNode>(Op1.getOperand(1));
19312   if (!ShAmt || ShAmt->getAPIntValue() != HalfValBitSize)
19313     return SDValue();
19314 
19315   // Lo and Hi are zero-extended from int with size less equal than 32
19316   // to i64.
19317   if (Lo.getOpcode() != ISD::ZERO_EXTEND || !Lo.hasOneUse() ||
19318       !Lo.getOperand(0).getValueType().isScalarInteger() ||
19319       Lo.getOperand(0).getValueSizeInBits() > HalfValBitSize ||
19320       Hi.getOpcode() != ISD::ZERO_EXTEND || !Hi.hasOneUse() ||
19321       !Hi.getOperand(0).getValueType().isScalarInteger() ||
19322       Hi.getOperand(0).getValueSizeInBits() > HalfValBitSize)
19323     return SDValue();
19324 
19325   // Use the EVT of low and high parts before bitcast as the input
19326   // of target query.
19327   EVT LowTy = (Lo.getOperand(0).getOpcode() == ISD::BITCAST)
19328                   ? Lo.getOperand(0).getValueType()
19329                   : Lo.getValueType();
19330   EVT HighTy = (Hi.getOperand(0).getOpcode() == ISD::BITCAST)
19331                    ? Hi.getOperand(0).getValueType()
19332                    : Hi.getValueType();
19333   if (!TLI.isMultiStoresCheaperThanBitsMerge(LowTy, HighTy))
19334     return SDValue();
19335 
19336   // Start to split store.
19337   MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
19338   AAMDNodes AAInfo = ST->getAAInfo();
19339 
19340   // Change the sizes of Lo and Hi's value types to HalfValBitSize.
19341   EVT VT = EVT::getIntegerVT(*DAG.getContext(), HalfValBitSize);
19342   Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Lo.getOperand(0));
19343   Hi = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Hi.getOperand(0));
19344 
19345   SDValue Chain = ST->getChain();
19346   SDValue Ptr = ST->getBasePtr();
19347   // Lower value store.
19348   SDValue St0 = DAG.getStore(Chain, DL, Lo, Ptr, ST->getPointerInfo(),
19349                              ST->getOriginalAlign(), MMOFlags, AAInfo);
19350   Ptr = DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(HalfValBitSize / 8), DL);
19351   // Higher value store.
19352   SDValue St1 = DAG.getStore(
19353       St0, DL, Hi, Ptr, ST->getPointerInfo().getWithOffset(HalfValBitSize / 8),
19354       ST->getOriginalAlign(), MMOFlags, AAInfo);
19355   return St1;
19356 }
19357 
19358 /// Convert a disguised subvector insertion into a shuffle:
combineInsertEltToShuffle(SDNode * N,unsigned InsIndex)19359 SDValue DAGCombiner::combineInsertEltToShuffle(SDNode *N, unsigned InsIndex) {
19360   assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT &&
19361          "Expected extract_vector_elt");
19362   SDValue InsertVal = N->getOperand(1);
19363   SDValue Vec = N->getOperand(0);
19364 
19365   // (insert_vector_elt (vector_shuffle X, Y), (extract_vector_elt X, N),
19366   // InsIndex)
19367   //   --> (vector_shuffle X, Y) and variations where shuffle operands may be
19368   //   CONCAT_VECTORS.
19369   if (Vec.getOpcode() == ISD::VECTOR_SHUFFLE && Vec.hasOneUse() &&
19370       InsertVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
19371       isa<ConstantSDNode>(InsertVal.getOperand(1))) {
19372     ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Vec.getNode());
19373     ArrayRef<int> Mask = SVN->getMask();
19374 
19375     SDValue X = Vec.getOperand(0);
19376     SDValue Y = Vec.getOperand(1);
19377 
19378     // Vec's operand 0 is using indices from 0 to N-1 and
19379     // operand 1 from N to 2N - 1, where N is the number of
19380     // elements in the vectors.
19381     SDValue InsertVal0 = InsertVal.getOperand(0);
19382     int ElementOffset = -1;
19383 
19384     // We explore the inputs of the shuffle in order to see if we find the
19385     // source of the extract_vector_elt. If so, we can use it to modify the
19386     // shuffle rather than perform an insert_vector_elt.
19387     SmallVector<std::pair<int, SDValue>, 8> ArgWorkList;
19388     ArgWorkList.emplace_back(Mask.size(), Y);
19389     ArgWorkList.emplace_back(0, X);
19390 
19391     while (!ArgWorkList.empty()) {
19392       int ArgOffset;
19393       SDValue ArgVal;
19394       std::tie(ArgOffset, ArgVal) = ArgWorkList.pop_back_val();
19395 
19396       if (ArgVal == InsertVal0) {
19397         ElementOffset = ArgOffset;
19398         break;
19399       }
19400 
19401       // Peek through concat_vector.
19402       if (ArgVal.getOpcode() == ISD::CONCAT_VECTORS) {
19403         int CurrentArgOffset =
19404             ArgOffset + ArgVal.getValueType().getVectorNumElements();
19405         int Step = ArgVal.getOperand(0).getValueType().getVectorNumElements();
19406         for (SDValue Op : reverse(ArgVal->ops())) {
19407           CurrentArgOffset -= Step;
19408           ArgWorkList.emplace_back(CurrentArgOffset, Op);
19409         }
19410 
19411         // Make sure we went through all the elements and did not screw up index
19412         // computation.
19413         assert(CurrentArgOffset == ArgOffset);
19414       }
19415     }
19416 
19417     // If we failed to find a match, see if we can replace an UNDEF shuffle
19418     // operand.
19419     if (ElementOffset == -1 && Y.isUndef() &&
19420         InsertVal0.getValueType() == Y.getValueType()) {
19421       ElementOffset = Mask.size();
19422       Y = InsertVal0;
19423     }
19424 
19425     if (ElementOffset != -1) {
19426       SmallVector<int, 16> NewMask(Mask.begin(), Mask.end());
19427 
19428       auto *ExtrIndex = cast<ConstantSDNode>(InsertVal.getOperand(1));
19429       NewMask[InsIndex] = ElementOffset + ExtrIndex->getZExtValue();
19430       assert(NewMask[InsIndex] <
19431                  (int)(2 * Vec.getValueType().getVectorNumElements()) &&
19432              NewMask[InsIndex] >= 0 && "NewMask[InsIndex] is out of bound");
19433 
19434       SDValue LegalShuffle =
19435               TLI.buildLegalVectorShuffle(Vec.getValueType(), SDLoc(N), X,
19436                                           Y, NewMask, DAG);
19437       if (LegalShuffle)
19438         return LegalShuffle;
19439     }
19440   }
19441 
19442   // insert_vector_elt V, (bitcast X from vector type), IdxC -->
19443   // bitcast(shuffle (bitcast V), (extended X), Mask)
19444   // Note: We do not use an insert_subvector node because that requires a
19445   // legal subvector type.
19446   if (InsertVal.getOpcode() != ISD::BITCAST || !InsertVal.hasOneUse() ||
19447       !InsertVal.getOperand(0).getValueType().isVector())
19448     return SDValue();
19449 
19450   SDValue SubVec = InsertVal.getOperand(0);
19451   SDValue DestVec = N->getOperand(0);
19452   EVT SubVecVT = SubVec.getValueType();
19453   EVT VT = DestVec.getValueType();
19454   unsigned NumSrcElts = SubVecVT.getVectorNumElements();
19455   // If the source only has a single vector element, the cost of creating adding
19456   // it to a vector is likely to exceed the cost of a insert_vector_elt.
19457   if (NumSrcElts == 1)
19458     return SDValue();
19459   unsigned ExtendRatio = VT.getSizeInBits() / SubVecVT.getSizeInBits();
19460   unsigned NumMaskVals = ExtendRatio * NumSrcElts;
19461 
19462   // Step 1: Create a shuffle mask that implements this insert operation. The
19463   // vector that we are inserting into will be operand 0 of the shuffle, so
19464   // those elements are just 'i'. The inserted subvector is in the first
19465   // positions of operand 1 of the shuffle. Example:
19466   // insert v4i32 V, (v2i16 X), 2 --> shuffle v8i16 V', X', {0,1,2,3,8,9,6,7}
19467   SmallVector<int, 16> Mask(NumMaskVals);
19468   for (unsigned i = 0; i != NumMaskVals; ++i) {
19469     if (i / NumSrcElts == InsIndex)
19470       Mask[i] = (i % NumSrcElts) + NumMaskVals;
19471     else
19472       Mask[i] = i;
19473   }
19474 
19475   // Bail out if the target can not handle the shuffle we want to create.
19476   EVT SubVecEltVT = SubVecVT.getVectorElementType();
19477   EVT ShufVT = EVT::getVectorVT(*DAG.getContext(), SubVecEltVT, NumMaskVals);
19478   if (!TLI.isShuffleMaskLegal(Mask, ShufVT))
19479     return SDValue();
19480 
19481   // Step 2: Create a wide vector from the inserted source vector by appending
19482   // undefined elements. This is the same size as our destination vector.
19483   SDLoc DL(N);
19484   SmallVector<SDValue, 8> ConcatOps(ExtendRatio, DAG.getUNDEF(SubVecVT));
19485   ConcatOps[0] = SubVec;
19486   SDValue PaddedSubV = DAG.getNode(ISD::CONCAT_VECTORS, DL, ShufVT, ConcatOps);
19487 
19488   // Step 3: Shuffle in the padded subvector.
19489   SDValue DestVecBC = DAG.getBitcast(ShufVT, DestVec);
19490   SDValue Shuf = DAG.getVectorShuffle(ShufVT, DL, DestVecBC, PaddedSubV, Mask);
19491   AddToWorklist(PaddedSubV.getNode());
19492   AddToWorklist(DestVecBC.getNode());
19493   AddToWorklist(Shuf.getNode());
19494   return DAG.getBitcast(VT, Shuf);
19495 }
19496 
SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
  SDValue InVec = N->getOperand(0);   // Vector being inserted into.
  SDValue InVal = N->getOperand(1);   // Scalar value being inserted.
  SDValue EltNo = N->getOperand(2);   // Insertion index (may be non-constant).
  SDLoc DL(N);

  EVT VT = InVec.getValueType();
  auto *IndexC = dyn_cast<ConstantSDNode>(EltNo);

  // Insert into out-of-bounds element is undefined.
  if (IndexC && VT.isFixedLengthVector() &&
      IndexC->getZExtValue() >= VT.getVectorNumElements())
    return DAG.getUNDEF(VT);

  // Remove redundant insertions:
  // (insert_vector_elt x (extract_vector_elt x idx) idx) -> x
  // This works even for a non-constant idx because both indices are the same
  // SDValue.
  if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
      InVec == InVal.getOperand(0) && EltNo == InVal.getOperand(1))
    return InVec;

  if (!IndexC) {
    // If this is variable insert to undef vector, it might be better to splat:
    // inselt undef, InVal, EltNo --> build_vector < InVal, InVal, ... >
    if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT)) {
      if (VT.isScalableVector())
        return DAG.getSplatVector(VT, DL, InVal);

      SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), InVal);
      return DAG.getBuildVector(VT, DL, Ops);
    }
    // Every fold below needs a known constant index.
    return SDValue();
  }

  // The folds below assume a fixed number of elements.
  if (VT.isScalableVector())
    return SDValue();

  unsigned NumElts = VT.getVectorNumElements();

  // We must know which element is being inserted for folds below here.
  unsigned Elt = IndexC->getZExtValue();

  if (SDValue Shuf = combineInsertEltToShuffle(N, Elt))
    return Shuf;

  // Handle <1 x ???> vector insertion special cases.
  if (NumElts == 1) {
    // insert_vector_elt(x, extract_vector_elt(y, 0), 0) -> y
    if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
        InVal.getOperand(0).getValueType() == VT &&
        isNullConstant(InVal.getOperand(1)))
      return InVal.getOperand(0);
  }

  // Canonicalize insert_vector_elt dag nodes.
  // Example:
  // (insert_vector_elt (insert_vector_elt A, Idx0), Idx1)
  // -> (insert_vector_elt (insert_vector_elt A, Idx1), Idx0)
  //
  // Do this only if the child insert_vector node has one use; also
  // do this only if indices are both constants and Idx1 < Idx0.
  if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT && InVec.hasOneUse()
      && isa<ConstantSDNode>(InVec.getOperand(2))) {
    unsigned OtherElt = InVec.getConstantOperandVal(2);
    if (Elt < OtherElt) {
      // Swap nodes.
      SDValue NewOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT,
                                  InVec.getOperand(0), InVal, EltNo);
      AddToWorklist(NewOp.getNode());
      return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(InVec.getNode()),
                         VT, NewOp, InVec.getOperand(1), InVec.getOperand(2));
    }
  }

  // Attempt to convert an insert_vector_elt chain into a legal build_vector.
  if (!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)) {
    // vXi1 vector - we don't need to recurse.
    if (NumElts == 1)
      return DAG.getBuildVector(VT, DL, {InVal});

    // If we haven't already collected the element, insert into the op list.
    // MaxEltVT tracks the widest integer element seen so all operands can be
    // extended to a common type before building the vector.
    EVT MaxEltVT = InVal.getValueType();
    auto AddBuildVectorOp = [&](SmallVectorImpl<SDValue> &Ops, SDValue Elt,
                                unsigned Idx) {
      if (!Ops[Idx]) {
        Ops[Idx] = Elt;
        if (VT.isInteger()) {
          EVT EltVT = Elt.getValueType();
          MaxEltVT = MaxEltVT.bitsGE(EltVT) ? MaxEltVT : EltVT;
        }
      }
    };

    // Ensure all the operands are the same value type, fill any missing
    // operands with UNDEF and create the BUILD_VECTOR.
    auto CanonicalizeBuildVector = [&](SmallVectorImpl<SDValue> &Ops) {
      assert(Ops.size() == NumElts && "Unexpected vector size");
      for (SDValue &Op : Ops) {
        if (Op)
          Op = VT.isInteger() ? DAG.getAnyExtOrTrunc(Op, DL, MaxEltVT) : Op;
        else
          Op = DAG.getUNDEF(MaxEltVT);
      }
      return DAG.getBuildVector(VT, DL, Ops);
    };

    // Ops[i] holds the value inserted at lane i; a null SDValue marks a lane
    // not yet seen. The current insert claims its lane first, so later
    // (older) inserts to the same lane are correctly ignored.
    SmallVector<SDValue, 8> Ops(NumElts, SDValue());
    Ops[Elt] = InVal;

    // Recurse up a INSERT_VECTOR_ELT chain to build a BUILD_VECTOR.
    for (SDValue CurVec = InVec; CurVec;) {
      // UNDEF - build new BUILD_VECTOR from already inserted operands.
      if (CurVec.isUndef())
        return CanonicalizeBuildVector(Ops);

      // BUILD_VECTOR - insert unused operands and build new BUILD_VECTOR.
      if (CurVec.getOpcode() == ISD::BUILD_VECTOR && CurVec.hasOneUse()) {
        for (unsigned I = 0; I != NumElts; ++I)
          AddBuildVectorOp(Ops, CurVec.getOperand(I), I);
        return CanonicalizeBuildVector(Ops);
      }

      // SCALAR_TO_VECTOR - insert unused scalar and build new BUILD_VECTOR.
      if (CurVec.getOpcode() == ISD::SCALAR_TO_VECTOR && CurVec.hasOneUse()) {
        AddBuildVectorOp(Ops, CurVec.getOperand(0), 0);
        return CanonicalizeBuildVector(Ops);
      }

      // INSERT_VECTOR_ELT - insert operand and continue up the chain.
      if (CurVec.getOpcode() == ISD::INSERT_VECTOR_ELT && CurVec.hasOneUse())
        if (auto *CurIdx = dyn_cast<ConstantSDNode>(CurVec.getOperand(2)))
          if (CurIdx->getAPIntValue().ult(NumElts)) {
            unsigned Idx = CurIdx->getZExtValue();
            AddBuildVectorOp(Ops, CurVec.getOperand(1), Idx);

            // Found entire BUILD_VECTOR.
            if (all_of(Ops, [](SDValue Op) { return !!Op; }))
              return CanonicalizeBuildVector(Ops);

            CurVec = CurVec->getOperand(0);
            continue;
          }

      // Failed to find a match in the chain - bail.
      break;
    }
  }

  return SDValue();
}
19646 
SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
                                                  SDValue EltNo,
                                                  LoadSDNode *OriginalLoad) {
  // Replace the extraction of one element from OriginalLoad (a vector load of
  // type InVecVT) with a scalar load of just that element. EVE supplies the
  // desired result type; EltNo may be constant or variable.
  assert(OriginalLoad->isSimple());

  EVT ResultVT = EVE->getValueType(0);
  EVT VecEltVT = InVecVT.getVectorElementType();

  // If the vector element type is not a multiple of a byte then we are unable
  // to correctly compute an address to load only the extracted element as a
  // scalar.
  if (!VecEltVT.isByteSized())
    return SDValue();

  // A result wider than the element will need an extending load below.
  ISD::LoadExtType ExtTy =
      ResultVT.bitsGT(VecEltVT) ? ISD::NON_EXTLOAD : ISD::EXTLOAD;
  if (!TLI.isOperationLegalOrCustom(ISD::LOAD, VecEltVT) ||
      !TLI.shouldReduceLoadWidth(OriginalLoad, ExtTy, VecEltVT))
    return SDValue();

  Align Alignment = OriginalLoad->getAlign();
  MachinePointerInfo MPI;
  SDLoc DL(EVE);
  if (auto *ConstEltNo = dyn_cast<ConstantSDNode>(EltNo)) {
    // Constant index: the byte offset is known, so keep precise pointer info.
    int Elt = ConstEltNo->getZExtValue();
    unsigned PtrOff = VecEltVT.getSizeInBits() * Elt / 8;
    MPI = OriginalLoad->getPointerInfo().getWithOffset(PtrOff);
    Alignment = commonAlignment(Alignment, PtrOff);
  } else {
    // Discard the pointer info except the address space because the memory
    // operand can't represent this new access since the offset is variable.
    MPI = MachinePointerInfo(OriginalLoad->getPointerInfo().getAddrSpace());
    Alignment = commonAlignment(Alignment, VecEltVT.getSizeInBits() / 8);
  }

  // Give up unless the target supports a fast access at the (possibly
  // reduced) alignment computed above.
  bool IsFast = false;
  if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VecEltVT,
                              OriginalLoad->getAddressSpace(), Alignment,
                              OriginalLoad->getMemOperand()->getFlags(),
                              &IsFast) ||
      !IsFast)
    return SDValue();

  SDValue NewPtr = TLI.getVectorElementPointer(DAG, OriginalLoad->getBasePtr(),
                                               InVecVT, EltNo);

  // We are replacing a vector load with a scalar load. The new load must have
  // identical memory op ordering to the original.
  SDValue Load;
  if (ResultVT.bitsGT(VecEltVT)) {
    // If the result type of vextract is wider than the load, then issue an
    // extending load instead.
    ISD::LoadExtType ExtType =
        TLI.isLoadExtLegal(ISD::ZEXTLOAD, ResultVT, VecEltVT) ? ISD::ZEXTLOAD
                                                              : ISD::EXTLOAD;
    Load = DAG.getExtLoad(ExtType, DL, ResultVT, OriginalLoad->getChain(),
                          NewPtr, MPI, VecEltVT, Alignment,
                          OriginalLoad->getMemOperand()->getFlags(),
                          OriginalLoad->getAAInfo());
    DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
  } else {
    // The result type is narrower or the same width as the vector element
    Load = DAG.getLoad(VecEltVT, DL, OriginalLoad->getChain(), NewPtr, MPI,
                       Alignment, OriginalLoad->getMemOperand()->getFlags(),
                       OriginalLoad->getAAInfo());
    DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
    // Adjust to the requested result type: truncate if narrower, otherwise
    // bitcast same-width types.
    if (ResultVT.bitsLT(VecEltVT))
      Load = DAG.getNode(ISD::TRUNCATE, DL, ResultVT, Load);
    else
      Load = DAG.getBitcast(ResultVT, Load);
  }
  ++OpsNarrowed;
  return Load;
}
19721 
19722 /// Transform a vector binary operation into a scalar binary operation by moving
19723 /// the math/logic after an extract element of a vector.
scalarizeExtractedBinop(SDNode * ExtElt,SelectionDAG & DAG,bool LegalOperations)19724 static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
19725                                        bool LegalOperations) {
19726   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
19727   SDValue Vec = ExtElt->getOperand(0);
19728   SDValue Index = ExtElt->getOperand(1);
19729   auto *IndexC = dyn_cast<ConstantSDNode>(Index);
19730   if (!IndexC || !TLI.isBinOp(Vec.getOpcode()) || !Vec.hasOneUse() ||
19731       Vec->getNumValues() != 1)
19732     return SDValue();
19733 
19734   // Targets may want to avoid this to prevent an expensive register transfer.
19735   if (!TLI.shouldScalarizeBinop(Vec))
19736     return SDValue();
19737 
19738   // Extracting an element of a vector constant is constant-folded, so this
19739   // transform is just replacing a vector op with a scalar op while moving the
19740   // extract.
19741   SDValue Op0 = Vec.getOperand(0);
19742   SDValue Op1 = Vec.getOperand(1);
19743   APInt SplatVal;
19744   if (isAnyConstantBuildVector(Op0, true) ||
19745       ISD::isConstantSplatVector(Op0.getNode(), SplatVal) ||
19746       isAnyConstantBuildVector(Op1, true) ||
19747       ISD::isConstantSplatVector(Op1.getNode(), SplatVal)) {
19748     // extractelt (binop X, C), IndexC --> binop (extractelt X, IndexC), C'
19749     // extractelt (binop C, X), IndexC --> binop C', (extractelt X, IndexC)
19750     SDLoc DL(ExtElt);
19751     EVT VT = ExtElt->getValueType(0);
19752     SDValue Ext0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Index);
19753     SDValue Ext1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op1, Index);
19754     return DAG.getNode(Vec.getOpcode(), DL, VT, Ext0, Ext1);
19755   }
19756 
19757   return SDValue();
19758 }
19759 
/// Combine an EXTRACT_VECTOR_ELT node. Tries, in order: folding extracts of
/// insert_vector_elt / scalar_to_vector / build_vector / bitcast / shuffle
/// nodes to the underlying scalar, scalarizing an extracted binop,
/// simplifying the source vector based on the demanded elements, and finally
/// narrowing an extract of a loaded vector into a scalar load.
SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
  SDValue VecOp = N->getOperand(0);
  SDValue Index = N->getOperand(1);
  EVT ScalarVT = N->getValueType(0);
  EVT VecVT = VecOp.getValueType();
  if (VecOp.isUndef())
    return DAG.getUNDEF(ScalarVT);

  // extract_vector_elt (insert_vector_elt vec, val, idx), idx) -> val
  //
  // This only really matters if the index is non-constant since other combines
  // on the constant elements already work.
  SDLoc DL(N);
  if (VecOp.getOpcode() == ISD::INSERT_VECTOR_ELT &&
      Index == VecOp.getOperand(2)) {
    SDValue Elt = VecOp.getOperand(1);
    return VecVT.isInteger() ? DAG.getAnyExtOrTrunc(Elt, DL, ScalarVT) : Elt;
  }

  // (vextract (scalar_to_vector val, 0) -> val
  if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR) {
    // Only 0'th element of SCALAR_TO_VECTOR is defined.
    if (DAG.isKnownNeverZero(Index))
      return DAG.getUNDEF(ScalarVT);

    // Check if the result type doesn't match the inserted element type. A
    // SCALAR_TO_VECTOR may truncate the inserted element and the
    // EXTRACT_VECTOR_ELT may widen the extracted vector.
    SDValue InOp = VecOp.getOperand(0);
    if (InOp.getValueType() != ScalarVT) {
      assert(InOp.getValueType().isInteger() && ScalarVT.isInteger() &&
             InOp.getValueType().bitsGT(ScalarVT));
      return DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, InOp);
    }
    return InOp;
  }

  // extract_vector_elt of out-of-bounds element -> UNDEF
  auto *IndexC = dyn_cast<ConstantSDNode>(Index);
  if (IndexC && VecVT.isFixedLengthVector() &&
      IndexC->getAPIntValue().uge(VecVT.getVectorNumElements()))
    return DAG.getUNDEF(ScalarVT);

  // extract_vector_elt (build_vector x, y), 1 -> y
  // For SPLAT_VECTOR every lane is operand 0, so no constant index is needed.
  if (((IndexC && VecOp.getOpcode() == ISD::BUILD_VECTOR) ||
       VecOp.getOpcode() == ISD::SPLAT_VECTOR) &&
      TLI.isTypeLegal(VecVT) &&
      (VecOp.hasOneUse() || TLI.aggressivelyPreferBuildVectorSources(VecVT))) {
    assert((VecOp.getOpcode() != ISD::BUILD_VECTOR ||
            VecVT.isFixedLengthVector()) &&
           "BUILD_VECTOR used for scalable vectors");
    unsigned IndexVal =
        VecOp.getOpcode() == ISD::BUILD_VECTOR ? IndexC->getZExtValue() : 0;
    SDValue Elt = VecOp.getOperand(IndexVal);
    EVT InEltVT = Elt.getValueType();

    // Sometimes build_vector's scalar input types do not match result type.
    if (ScalarVT == InEltVT)
      return Elt;

    // TODO: It may be useful to truncate if free if the build_vector implicitly
    // converts.
  }

  if (SDValue BO = scalarizeExtractedBinop(N, DAG, LegalOperations))
    return BO;

  if (VecVT.isScalableVector())
    return SDValue();

  // All the code from this point onwards assumes fixed width vectors, but it's
  // possible that some of the combinations could be made to work for scalable
  // vectors too.
  unsigned NumElts = VecVT.getVectorNumElements();
  unsigned VecEltBitWidth = VecVT.getScalarSizeInBits();

  // TODO: These transforms should not require the 'hasOneUse' restriction, but
  // there are regressions on multiple targets without it. We can end up with a
  // mess of scalar and vector code if we reduce only part of the DAG to scalar.
  if (IndexC && VecOp.getOpcode() == ISD::BITCAST && VecVT.isInteger() &&
      VecOp.hasOneUse()) {
    // The vector index of the LSBs of the source depend on the endian-ness.
    bool IsLE = DAG.getDataLayout().isLittleEndian();
    unsigned ExtractIndex = IndexC->getZExtValue();
    // extract_elt (v2i32 (bitcast i64:x)), BCTruncElt -> i32 (trunc i64:x)
    unsigned BCTruncElt = IsLE ? 0 : NumElts - 1;
    SDValue BCSrc = VecOp.getOperand(0);
    if (ExtractIndex == BCTruncElt && BCSrc.getValueType().isScalarInteger())
      return DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, BCSrc);

    if (LegalTypes && BCSrc.getValueType().isInteger() &&
        BCSrc.getOpcode() == ISD::SCALAR_TO_VECTOR) {
      // ext_elt (bitcast (scalar_to_vec i64 X to v2i64) to v4i32), TruncElt -->
      // trunc i64 X to i32
      SDValue X = BCSrc.getOperand(0);
      assert(X.getValueType().isScalarInteger() && ScalarVT.isScalarInteger() &&
             "Extract element and scalar to vector can't change element type "
             "from FP to integer.");
      unsigned XBitWidth = X.getValueSizeInBits();
      // Recompute the LSB lane relative to the scalar source width.
      BCTruncElt = IsLE ? 0 : XBitWidth / VecEltBitWidth - 1;

      // An extract element return value type can be wider than its vector
      // operand element type. In that case, the high bits are undefined, so
      // it's possible that we may need to extend rather than truncate.
      if (ExtractIndex == BCTruncElt && XBitWidth > VecEltBitWidth) {
        assert(XBitWidth % VecEltBitWidth == 0 &&
               "Scalar bitwidth must be a multiple of vector element bitwidth");
        return DAG.getAnyExtOrTrunc(X, DL, ScalarVT);
      }
    }
  }

  // Transform: (EXTRACT_VECTOR_ELT( VECTOR_SHUFFLE )) -> EXTRACT_VECTOR_ELT.
  // We only perform this optimization before the op legalization phase because
  // we may introduce new vector instructions which are not backed by TD
  // patterns. For example on AVX, extracting elements from a wide vector
  // without using extract_subvector. However, if we can find an underlying
  // scalar value, then we can always use that.
  if (IndexC && VecOp.getOpcode() == ISD::VECTOR_SHUFFLE) {
    auto *Shuf = cast<ShuffleVectorSDNode>(VecOp);
    // Find the new index to extract from.
    int OrigElt = Shuf->getMaskElt(IndexC->getZExtValue());

    // Extracting an undef index is undef.
    if (OrigElt == -1)
      return DAG.getUNDEF(ScalarVT);

    // Select the right vector half to extract from.
    SDValue SVInVec;
    if (OrigElt < (int)NumElts) {
      SVInVec = VecOp.getOperand(0);
    } else {
      SVInVec = VecOp.getOperand(1);
      OrigElt -= NumElts;
    }

    if (SVInVec.getOpcode() == ISD::BUILD_VECTOR) {
      SDValue InOp = SVInVec.getOperand(OrigElt);
      if (InOp.getValueType() != ScalarVT) {
        assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
        InOp = DAG.getSExtOrTrunc(InOp, DL, ScalarVT);
      }

      return InOp;
    }

    // FIXME: We should handle recursing on other vector shuffles and
    // scalar_to_vector here as well.

    if (!LegalOperations ||
        // FIXME: Should really be just isOperationLegalOrCustom.
        TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VecVT) ||
        TLI.isOperationExpand(ISD::VECTOR_SHUFFLE, VecVT)) {
      return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, SVInVec,
                         DAG.getVectorIdxConstant(OrigElt, DL));
    }
  }

  // If only EXTRACT_VECTOR_ELT nodes use the source vector we can
  // simplify it based on the (valid) extraction indices.
  if (llvm::all_of(VecOp->uses(), [&](SDNode *Use) {
        return Use->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
               Use->getOperand(0) == VecOp &&
               isa<ConstantSDNode>(Use->getOperand(1));
      })) {
    // Collect the set of lanes any user actually reads; out-of-range indices
    // are simply not demanded.
    APInt DemandedElts = APInt::getZero(NumElts);
    for (SDNode *Use : VecOp->uses()) {
      auto *CstElt = cast<ConstantSDNode>(Use->getOperand(1));
      if (CstElt->getAPIntValue().ult(NumElts))
        DemandedElts.setBit(CstElt->getZExtValue());
    }
    if (SimplifyDemandedVectorElts(VecOp, DemandedElts, true)) {
      // We simplified the vector operand of this extract element. If this
      // extract is not dead, visit it again so it is folded properly.
      if (N->getOpcode() != ISD::DELETED_NODE)
        AddToWorklist(N);
      return SDValue(N, 0);
    }
    APInt DemandedBits = APInt::getAllOnes(VecEltBitWidth);
    if (SimplifyDemandedBits(VecOp, DemandedBits, DemandedElts, true)) {
      // We simplified the vector operand of this extract element. If this
      // extract is not dead, visit it again so it is folded properly.
      if (N->getOpcode() != ISD::DELETED_NODE)
        AddToWorklist(N);
      return SDValue(N, 0);
    }
  }

  // Everything under here is trying to match an extract of a loaded value.
  // If the result of load has to be truncated, then it's not necessarily
  // profitable.
  // BCNumEltsChanged records whether a bitcast below altered the element
  // count, in which case shuffle masks can no longer be trusted.
  bool BCNumEltsChanged = false;
  EVT ExtVT = VecVT.getVectorElementType();
  EVT LVT = ExtVT;
  if (ScalarVT.bitsLT(LVT) && !TLI.isTruncateFree(LVT, ScalarVT))
    return SDValue();

  if (VecOp.getOpcode() == ISD::BITCAST) {
    // Don't duplicate a load with other uses.
    if (!VecOp.hasOneUse())
      return SDValue();

    EVT BCVT = VecOp.getOperand(0).getValueType();
    if (!BCVT.isVector() || ExtVT.bitsGT(BCVT.getVectorElementType()))
      return SDValue();
    if (NumElts != BCVT.getVectorNumElements())
      BCNumEltsChanged = true;
    VecOp = VecOp.getOperand(0);
    ExtVT = BCVT.getVectorElementType();
  }

  // extract (vector load $addr), i --> load $addr + i * size
  if (!LegalOperations && !IndexC && VecOp.hasOneUse() &&
      ISD::isNormalLoad(VecOp.getNode()) &&
      !Index->hasPredecessor(VecOp.getNode())) {
    auto *VecLoad = dyn_cast<LoadSDNode>(VecOp);
    if (VecLoad && VecLoad->isSimple())
      return scalarizeExtractedVectorLoad(N, VecVT, Index, VecLoad);
  }

  // Perform only after legalization to ensure build_vector / vector_shuffle
  // optimizations have already been done.
  if (!LegalOperations || !IndexC)
    return SDValue();

  // (vextract (v4f32 load $addr), c) -> (f32 load $addr+c*size)
  // (vextract (v4f32 s2v (f32 load $addr)), c) -> (f32 load $addr+c*size)
  // (vextract (v4f32 shuffle (load $addr), <1,u,u,u>), 0) -> (f32 load $addr)
  int Elt = IndexC->getZExtValue();
  LoadSDNode *LN0 = nullptr;
  if (ISD::isNormalLoad(VecOp.getNode())) {
    LN0 = cast<LoadSDNode>(VecOp);
  } else if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR &&
             VecOp.getOperand(0).getValueType() == ExtVT &&
             ISD::isNormalLoad(VecOp.getOperand(0).getNode())) {
    // Don't duplicate a load with other uses.
    if (!VecOp.hasOneUse())
      return SDValue();

    LN0 = cast<LoadSDNode>(VecOp.getOperand(0));
  }
  if (auto *Shuf = dyn_cast<ShuffleVectorSDNode>(VecOp)) {
    // (vextract (vector_shuffle (load $addr), v2, <1, u, u, u>), 1)
    // =>
    // (load $addr+1*size)

    // Don't duplicate a load with other uses.
    if (!VecOp.hasOneUse())
      return SDValue();

    // If the bit convert changed the number of elements, it is unsafe
    // to examine the mask.
    if (BCNumEltsChanged)
      return SDValue();

    // Select the input vector, guarding against out of range extract vector.
    int Idx = (Elt > (int)NumElts) ? -1 : Shuf->getMaskElt(Elt);
    VecOp = (Idx < (int)NumElts) ? VecOp.getOperand(0) : VecOp.getOperand(1);

    if (VecOp.getOpcode() == ISD::BITCAST) {
      // Don't duplicate a load with other uses.
      if (!VecOp.hasOneUse())
        return SDValue();

      VecOp = VecOp.getOperand(0);
    }
    if (ISD::isNormalLoad(VecOp.getNode())) {
      LN0 = cast<LoadSDNode>(VecOp);
      // Remap the extract index into the chosen shuffle operand.
      Elt = (Idx < (int)NumElts) ? Idx : Idx - (int)NumElts;
      Index = DAG.getConstant(Elt, DL, Index.getValueType());
    }
  } else if (VecOp.getOpcode() == ISD::CONCAT_VECTORS && !BCNumEltsChanged &&
             VecVT.getVectorElementType() == ScalarVT &&
             (!LegalTypes ||
              TLI.isTypeLegal(
                  VecOp.getOperand(0).getValueType().getVectorElementType()))) {
    // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 0
    //      -> extract_vector_elt a, 0
    // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 1
    //      -> extract_vector_elt a, 1
    // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 2
    //      -> extract_vector_elt b, 0
    // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 3
    //      -> extract_vector_elt b, 1
    SDLoc SL(N);
    EVT ConcatVT = VecOp.getOperand(0).getValueType();
    unsigned ConcatNumElts = ConcatVT.getVectorNumElements();
    SDValue NewIdx = DAG.getConstant(Elt % ConcatNumElts, SL,
                                     Index.getValueType());

    // NB: this SDValue Elt intentionally shadows the outer int Elt.
    SDValue ConcatOp = VecOp.getOperand(Elt / ConcatNumElts);
    SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL,
                              ConcatVT.getVectorElementType(),
                              ConcatOp, NewIdx);
    return DAG.getNode(ISD::BITCAST, SL, ScalarVT, Elt);
  }

  // Make sure we found a non-volatile load and the extractelement is
  // the only use.
  if (!LN0 || !LN0->hasNUsesOfValue(1,0) || !LN0->isSimple())
    return SDValue();

  // If Idx was -1 above, Elt is going to be -1, so just return undef.
  if (Elt == -1)
    return DAG.getUNDEF(LVT);

  return scalarizeExtractedVectorLoad(N, VecVT, Index, LN0);
}
20068 
// Simplify (build_vec (ext )) to (bitcast (build_vec ))
//
// If every defined operand of the build_vector is an ANY_EXTEND or
// ZERO_EXTEND from a single common source type, re-express the node as a
// wider build_vector of the (narrow) source type — placing each original
// value into the lane the extension would have produced and filling the
// remaining lanes with undef (all-any-ext) or zero — then bitcast the result
// back to the original vector type.
SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) {
  // We perform this optimization post type-legalization because
  // the type-legalizer often scalarizes integer-promoted vectors.
  // Performing this optimization before may create bit-casts which
  // will be type-legalized to complex code sequences.
  // We perform this optimization only before the operation legalizer because we
  // may introduce illegal operations.
  if (Level != AfterLegalizeVectorOps && Level != AfterLegalizeTypes)
    return SDValue();

  unsigned NumInScalars = N->getNumOperands();
  SDLoc DL(N);
  EVT VT = N->getValueType(0);

  // Check to see if this is a BUILD_VECTOR of a bunch of values
  // which come from any_extend or zero_extend nodes. If so, we can create
  // a new BUILD_VECTOR using bit-casts which may enable other BUILD_VECTOR
  // optimizations. We do not handle sign-extend because we can't fill the sign
  // using shuffles.
  EVT SourceType = MVT::Other;
  bool AllAnyExt = true;

  for (unsigned i = 0; i != NumInScalars; ++i) {
    SDValue In = N->getOperand(i);
    // Ignore undef inputs.
    if (In.isUndef()) continue;

    bool AnyExt  = In.getOpcode() == ISD::ANY_EXTEND;
    bool ZeroExt = In.getOpcode() == ISD::ZERO_EXTEND;

    // Abort if the element is not an extension.
    if (!ZeroExt && !AnyExt) {
      SourceType = MVT::Other;
      break;
    }

    // The input is a ZeroExt or AnyExt. Check the original type.
    EVT InTy = In.getOperand(0).getValueType();

    // Check that all of the widened source types are the same.
    if (SourceType == MVT::Other)
      // First time.
      SourceType = InTy;
    else if (InTy != SourceType) {
      // Multiple income types. Abort.
      SourceType = MVT::Other;
      break;
    }

    // Check if all of the extends are ANY_EXTENDs.
    AllAnyExt &= AnyExt;
  }

  // In order to have valid types, all of the inputs must be extended from the
  // same source type and all of the inputs must be any or zero extend.
  // Scalar sizes must be a power of two.
  EVT OutScalarTy = VT.getScalarType();
  bool ValidTypes = SourceType != MVT::Other &&
                 isPowerOf2_32(OutScalarTy.getSizeInBits()) &&
                 isPowerOf2_32(SourceType.getSizeInBits());

  // Create a new simpler BUILD_VECTOR sequence which other optimizations can
  // turn into a single shuffle instruction.
  if (!ValidTypes)
    return SDValue();

  // If we already have a splat buildvector, then don't fold it if it means
  // introducing zeros.
  if (!AllAnyExt && DAG.isSplatValue(SDValue(N, 0), /*AllowUndefs*/ true))
    return SDValue();

  bool isLE = DAG.getDataLayout().isLittleEndian();
  unsigned ElemRatio = OutScalarTy.getSizeInBits()/SourceType.getSizeInBits();
  assert(ElemRatio > 1 && "Invalid element size ratio");
  // Lanes not covered by an original value: undef when every input was an
  // any_extend, zero when at least one was a zero_extend.
  SDValue Filler = AllAnyExt ? DAG.getUNDEF(SourceType):
                               DAG.getConstant(0, DL, SourceType);

  unsigned NewBVElems = ElemRatio * VT.getVectorNumElements();
  SmallVector<SDValue, 8> Ops(NewBVElems, Filler);

  // Populate the new build_vector
  for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
    SDValue Cast = N->getOperand(i);
    assert((Cast.getOpcode() == ISD::ANY_EXTEND ||
            Cast.getOpcode() == ISD::ZERO_EXTEND ||
            Cast.isUndef()) && "Invalid cast opcode");
    SDValue In;
    if (Cast.isUndef())
      In = DAG.getUNDEF(SourceType);
    else
      In = Cast->getOperand(0);
    // The original value occupies the low-order narrow lane of its group on
    // little-endian targets and the high-order lane on big-endian targets.
    unsigned Index = isLE ? (i * ElemRatio) :
                            (i * ElemRatio + (ElemRatio - 1));

    assert(Index < Ops.size() && "Invalid index");
    Ops[Index] = In;
  }

  // The type of the new BUILD_VECTOR node.
  EVT VecVT = EVT::getVectorVT(*DAG.getContext(), SourceType, NewBVElems);
  assert(VecVT.getSizeInBits() == VT.getSizeInBits() &&
         "Invalid vector size");
  // Check if the new vector type is legal.
  if (!isTypeLegal(VecVT) ||
      (!TLI.isOperationLegal(ISD::BUILD_VECTOR, VecVT) &&
       TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)))
    return SDValue();

  // Make the new BUILD_VECTOR.
  SDValue BV = DAG.getBuildVector(VecVT, DL, Ops);

  // The new BUILD_VECTOR node has the potential to be further optimized.
  AddToWorklist(BV.getNode());
  // Bitcast to the desired type.
  return DAG.getBitcast(VT, BV);
}
20186 
// Simplify (build_vec (trunc $1)
//                     (trunc (srl $1 half-width))
//                     (trunc (srl $1 (2 * half-width))) ...)
// to (bitcast $1)
//
// Each element must be a truncate of the same source value shifted right by
// exactly i * element-width bits, i.e. the build_vector merely slices $1 into
// lanes — which on a little-endian target is what a bitcast does for free.
SDValue DAGCombiner::reduceBuildVecTruncToBitCast(SDNode *N) {
  assert(N->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");

  // Only for little endian
  if (!DAG.getDataLayout().isLittleEndian())
    return SDValue();

  SDLoc DL(N);
  EVT VT = N->getValueType(0);
  EVT OutScalarTy = VT.getScalarType();
  uint64_t ScalarTypeBitsize = OutScalarTy.getSizeInBits();

  // Only for power of two types to be sure that bitcast works well
  if (!isPowerOf2_64(ScalarTypeBitsize))
    return SDValue();

  unsigned NumInScalars = N->getNumOperands();

  // Look through bitcasts
  auto PeekThroughBitcast = [](SDValue Op) {
    if (Op.getOpcode() == ISD::BITCAST)
      return Op.getOperand(0);
    return Op;
  };

  // The source value where all the parts are extracted.
  SDValue Src;
  for (unsigned i = 0; i != NumInScalars; ++i) {
    SDValue In = PeekThroughBitcast(N->getOperand(i));
    // Ignore undef inputs.
    if (In.isUndef()) continue;

    if (In.getOpcode() != ISD::TRUNCATE)
      return SDValue();

    In = PeekThroughBitcast(In.getOperand(0));

    if (In.getOpcode() != ISD::SRL) {
      // For now only build_vec without shuffling, handle shifts here in the
      // future.
      // Element 0 is the only one allowed to be an unshifted truncate.
      if (i != 0)
        return SDValue();

      Src = In;
    } else {
      // In is SRL
      SDValue part = PeekThroughBitcast(In.getOperand(0));

      if (!Src) {
        Src = part;
      } else if (Src != part) {
        // Vector parts do not stem from the same variable
        return SDValue();
      }

      SDValue ShiftAmtVal = In.getOperand(1);
      if (!isa<ConstantSDNode>(ShiftAmtVal))
        return SDValue();

      uint64_t ShiftAmt = In.getConstantOperandVal(1);

      // The extracted value is not extracted at the right position
      if (ShiftAmt != i * ScalarTypeBitsize)
        return SDValue();
    }
  }

  // Only cast if the size is the same
  if (Src.getValueType().getSizeInBits() != VT.getSizeInBits())
    return SDValue();

  return DAG.getBitcast(VT, Src);
}
20264 
/// Build a VECTOR_SHUFFLE that realizes the BUILD_VECTOR \p N from the pair
/// of extracted-from vectors \p VecIn1 / \p VecIn2 (VecIn2 may be empty).
/// \p VectorMask maps each output element to its source-vector number
/// (values <= 0 mean undef/zero and are skipped); elements whose entry is
/// \p LeftIdx come from VecIn1 and \p LeftIdx + 1 from VecIn2.
/// \p DidSplitVec means VecIn2 was artificially split off VecIn1 earlier, so
/// extract indices are already relative to VecIn1 (no second-vector offset).
/// Returns an empty SDValue when the input/output types cannot be reconciled.
SDValue DAGCombiner::createBuildVecShuffle(const SDLoc &DL, SDNode *N,
                                           ArrayRef<int> VectorMask,
                                           SDValue VecIn1, SDValue VecIn2,
                                           unsigned LeftIdx, bool DidSplitVec) {
  SDValue ZeroIdx = DAG.getVectorIdxConstant(0, DL);

  EVT VT = N->getValueType(0);
  EVT InVT1 = VecIn1.getValueType();
  EVT InVT2 = VecIn2.getNode() ? VecIn2.getValueType() : InVT1;

  unsigned NumElems = VT.getVectorNumElements();
  unsigned ShuffleNumElems = NumElems;

  // If we artificially split a vector in two already, then the offsets in the
  // operands will all be based off of VecIn1, even those in VecIn2.
  unsigned Vec2Offset = DidSplitVec ? 0 : InVT1.getVectorNumElements();

  uint64_t VTSize = VT.getFixedSizeInBits();
  uint64_t InVT1Size = InVT1.getFixedSizeInBits();
  uint64_t InVT2Size = InVT2.getFixedSizeInBits();

  assert(InVT2Size <= InVT1Size &&
         "Inputs must be sorted to be in non-increasing vector size order.");

  // We can't generate a shuffle node with mismatched input and output types.
  // Try to make the types match the type of the output.
  if (InVT1 != VT || InVT2 != VT) {
    if ((VTSize % InVT1Size == 0) && InVT1 == InVT2) {
      // If the output vector length is a multiple of both input lengths,
      // we can concatenate them and pad the rest with undefs.
      unsigned NumConcats = VTSize / InVT1Size;
      assert(NumConcats >= 2 && "Concat needs at least two inputs!");
      SmallVector<SDValue, 2> ConcatOps(NumConcats, DAG.getUNDEF(InVT1));
      ConcatOps[0] = VecIn1;
      ConcatOps[1] = VecIn2 ? VecIn2 : DAG.getUNDEF(InVT1);
      VecIn1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
      VecIn2 = SDValue();
    } else if (InVT1Size == VTSize * 2) {
      if (!TLI.isExtractSubvectorCheap(VT, InVT1, NumElems))
        return SDValue();

      if (!VecIn2.getNode()) {
        // If we only have one input vector, and it's twice the size of the
        // output, split it in two.
        VecIn2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, VecIn1,
                             DAG.getVectorIdxConstant(NumElems, DL));
        VecIn1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, VecIn1, ZeroIdx);
        // Since we now have shorter input vectors, adjust the offset of the
        // second vector's start.
        Vec2Offset = NumElems;
      } else {
        assert(InVT2Size <= InVT1Size &&
               "Second input is not going to be larger than the first one.");

        // VecIn1 is wider than the output, and we have another, possibly
        // smaller input. Pad the smaller input with undefs, shuffle at the
        // input vector width, and extract the output.
        // The shuffle type is different than VT, so check legality again.
        if (LegalOperations &&
            !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, InVT1))
          return SDValue();

        // Legalizing INSERT_SUBVECTOR is tricky - you basically have to
        // lower it back into a BUILD_VECTOR. So if the inserted type is
        // illegal, don't even try.
        if (InVT1 != InVT2) {
          if (!TLI.isTypeLegal(InVT2))
            return SDValue();
          VecIn2 = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, InVT1,
                               DAG.getUNDEF(InVT1), VecIn2, ZeroIdx);
        }
        ShuffleNumElems = NumElems * 2;
      }
    } else if (InVT2Size * 2 == VTSize && InVT1Size == VTSize) {
      // Widen the half-sized second input with undefs to match VT.
      SmallVector<SDValue, 2> ConcatOps(2, DAG.getUNDEF(InVT2));
      ConcatOps[0] = VecIn2;
      VecIn2 = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
    } else {
      // TODO: Support cases where the length mismatch isn't exactly by a
      // factor of 2.
      // TODO: Move this check upwards, so that if we have bad type
      // mismatches, we don't create any DAG nodes.
      return SDValue();
    }
  }

  // Initialize mask to undef.
  SmallVector<int, 8> Mask(ShuffleNumElems, -1);

  // Only need to run up to the number of elements actually used, not the
  // total number of elements in the shuffle - if we are shuffling a wider
  // vector, the high lanes should be set to undef.
  for (unsigned i = 0; i != NumElems; ++i) {
    if (VectorMask[i] <= 0)
      continue;

    unsigned ExtIndex = N->getOperand(i).getConstantOperandVal(1);
    if (VectorMask[i] == (int)LeftIdx) {
      Mask[i] = ExtIndex;
    } else if (VectorMask[i] == (int)LeftIdx + 1) {
      Mask[i] = Vec2Offset + ExtIndex;
    }
  }

  // The type the input vectors may have changed above.
  InVT1 = VecIn1.getValueType();

  // If we already have a VecIn2, it should have the same type as VecIn1.
  // If we don't, get an undef/zero vector of the appropriate type.
  VecIn2 = VecIn2.getNode() ? VecIn2 : DAG.getUNDEF(InVT1);
  assert(InVT1 == VecIn2.getValueType() && "Unexpected second input type.");

  SDValue Shuffle = DAG.getVectorShuffle(InVT1, DL, VecIn1, VecIn2, Mask);
  if (ShuffleNumElems > NumElems)
    Shuffle = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Shuffle, ZeroIdx);

  return Shuffle;
}
20383 
/// Match a BUILD_VECTOR that is undef in every lane except one, where the
/// sole defined element is a ZERO_EXTEND of an EXTRACT_VECTOR_ELT (at a
/// constant index) from a vector of the same total size as the result, and
/// rewrite it as a bitcast of a shuffle of that source vector with a zero
/// vector. Returns an empty SDValue if the pattern does not match or a legal
/// shuffle cannot be built.
static SDValue reduceBuildVecToShuffleWithZero(SDNode *BV, SelectionDAG &DAG) {
  assert(BV->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");

  // First, determine where the build vector is not undef.
  // TODO: We could extend this to handle zero elements as well as undefs.
  int NumBVOps = BV->getNumOperands();
  int ZextElt = -1;
  for (int i = 0; i != NumBVOps; ++i) {
    SDValue Op = BV->getOperand(i);
    if (Op.isUndef())
      continue;
    if (ZextElt == -1)
      ZextElt = i;
    else
      // More than one defined element - pattern does not apply.
      return SDValue();
  }
  // Bail out if there's no non-undef element.
  if (ZextElt == -1)
    return SDValue();

  // The build vector contains some number of undef elements and exactly
  // one other element. That other element must be a zero-extended scalar
  // extracted from a vector at a constant index to turn this into a shuffle.
  // Also, require that the build vector does not implicitly truncate/extend
  // its elements.
  // TODO: This could be enhanced to allow ANY_EXTEND as well as ZERO_EXTEND.
  EVT VT = BV->getValueType(0);
  SDValue Zext = BV->getOperand(ZextElt);
  if (Zext.getOpcode() != ISD::ZERO_EXTEND || !Zext.hasOneUse() ||
      Zext.getOperand(0).getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
      !isa<ConstantSDNode>(Zext.getOperand(0).getOperand(1)) ||
      Zext.getValueSizeInBits() != VT.getScalarSizeInBits())
    return SDValue();

  // The zero-extend must be a multiple of the source size, and we must be
  // building a vector of the same size as the source of the extract element.
  SDValue Extract = Zext.getOperand(0);
  unsigned DestSize = Zext.getValueSizeInBits();
  unsigned SrcSize = Extract.getValueSizeInBits();
  if (DestSize % SrcSize != 0 ||
      Extract.getOperand(0).getValueSizeInBits() != VT.getSizeInBits())
    return SDValue();

  // Create a shuffle mask that will combine the extracted element with zeros
  // and undefs.
  int ZextRatio = DestSize / SrcSize;
  int NumMaskElts = NumBVOps * ZextRatio;
  SmallVector<int, 32> ShufMask(NumMaskElts, -1);
  for (int i = 0; i != NumMaskElts; ++i) {
    if (i / ZextRatio == ZextElt) {
      // The low bits of the (potentially translated) extracted element map to
      // the source vector. The high bits map to zero. We will use a zero vector
      // as the 2nd source operand of the shuffle, so use the 1st element of
      // that vector (mask value is number-of-elements) for the high bits.
      if (i % ZextRatio == 0)
        ShufMask[i] = Extract.getConstantOperandVal(1);
      else
        ShufMask[i] = NumMaskElts;
    }

    // Undef elements of the build vector remain undef because we initialize
    // the shuffle mask with -1.
  }

  // buildvec undef, ..., (zext (extractelt V, IndexC)), undef... -->
  // bitcast (shuffle V, ZeroVec, VectorMask)
  SDLoc DL(BV);
  EVT VecVT = Extract.getOperand(0).getValueType();
  SDValue ZeroVec = DAG.getConstant(0, DL, VecVT);
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  SDValue Shuf = TLI.buildLegalVectorShuffle(VecVT, DL, Extract.getOperand(0),
                                             ZeroVec, ShufMask, DAG);
  if (!Shuf)
    return SDValue();
  return DAG.getBitcast(VT, Shuf);
}
20460 
// FIXME: promote to STLExtras.
/// Return the zero-based position of the first occurrence of \p Val in
/// \p Range, or -1 (in the range's iterator difference type) if absent.
template <typename R, typename T>
static auto getFirstIndexOf(R &&Range, const T &Val) {
  auto Begin = Range.begin();
  auto End = Range.end();
  auto Pos = std::find(Begin, End, Val);
  // Use the range's difference type so the return type matches what
  // std::distance produces; -1 is the "not found" sentinel.
  if (Pos == End)
    return static_cast<decltype(std::distance(Begin, Pos))>(-1);
  return std::distance(Begin, Pos);
}
20469 
// Check to see if this is a BUILD_VECTOR of a bunch of EXTRACT_VECTOR_ELT
// operations. If the types of the vectors we're extracting from allow it,
// turn this into a vector_shuffle node.
SDValue DAGCombiner::reduceBuildVecToShuffle(SDNode *N) {
  SDLoc DL(N);
  EVT VT = N->getValueType(0);

  // Only type-legal BUILD_VECTOR nodes are converted to shuffle nodes.
  if (!isTypeLegal(VT))
    return SDValue();

  // First try the special case of blending a zero-extended extract with zeros.
  if (SDValue V = reduceBuildVecToShuffleWithZero(N, DAG))
    return V;

  // May only combine to shuffle after legalize if shuffle is legal.
  if (LegalOperations && !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, VT))
    return SDValue();

  bool UsesZeroVector = false;
  unsigned NumElems = N->getNumOperands();

  // Record, for each element of the newly built vector, which input vector
  // that element comes from. -1 stands for undef, 0 for the zero vector,
  // and positive values for the input vectors.
  // VectorMask maps each element to its vector number, and VecIn maps vector
  // numbers to their initial SDValues.

  SmallVector<int, 8> VectorMask(NumElems, -1);
  SmallVector<SDValue, 8> VecIn;
  // Slot 0 is reserved for the implicit zero vector (a null SDValue).
  VecIn.push_back(SDValue());

  for (unsigned i = 0; i != NumElems; ++i) {
    SDValue Op = N->getOperand(i);

    if (Op.isUndef())
      continue;

    // See if we can use a blend with a zero vector.
    // TODO: Should we generalize this to a blend with an arbitrary constant
    // vector?
    if (isNullConstant(Op) || isNullFPConstant(Op)) {
      UsesZeroVector = true;
      VectorMask[i] = 0;
      continue;
    }

    // Not an undef or zero. If the input is something other than an
    // EXTRACT_VECTOR_ELT with an in-range constant index, bail out.
    if (Op.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
        !isa<ConstantSDNode>(Op.getOperand(1)))
      return SDValue();
    SDValue ExtractedFromVec = Op.getOperand(0);

    // Fixed-width shuffle masks can't describe scalable-vector extracts.
    if (ExtractedFromVec.getValueType().isScalableVector())
      return SDValue();

    const APInt &ExtractIdx = Op.getConstantOperandAPInt(1);
    if (ExtractIdx.uge(ExtractedFromVec.getValueType().getVectorNumElements()))
      return SDValue();

    // All inputs must have the same element type as the output.
    if (VT.getVectorElementType() !=
        ExtractedFromVec.getValueType().getVectorElementType())
      return SDValue();

    // Have we seen this input vector before?
    // The vectors are expected to be tiny (usually 1 or 2 elements), so using
    // a map back from SDValues to numbers isn't worth it.
    int Idx = getFirstIndexOf(VecIn, ExtractedFromVec);
    if (Idx == -1) { // A new source vector?
      Idx = VecIn.size();
      VecIn.push_back(ExtractedFromVec);
    }

    VectorMask[i] = Idx;
  }

  // If we didn't find at least one input vector, bail out.
  if (VecIn.size() < 2)
    return SDValue();

  // If all the Operands of BUILD_VECTOR extract from same
  // vector, then split the vector efficiently based on the maximum
  // vector access index and adjust the VectorMask and
  // VecIn accordingly.
  bool DidSplitVec = false;
  if (VecIn.size() == 2) {
    unsigned MaxIndex = 0;
    unsigned NearestPow2 = 0;
    SDValue Vec = VecIn.back();
    EVT InVT = Vec.getValueType();
    SmallVector<unsigned, 8> IndexVec(NumElems, 0);

    // Collect the extract indices used and track the largest one.
    for (unsigned i = 0; i < NumElems; i++) {
      if (VectorMask[i] <= 0)
        continue;
      unsigned Index = N->getOperand(i).getConstantOperandVal(1);
      IndexVec[i] = Index;
      MaxIndex = std::max(MaxIndex, Index);
    }

    NearestPow2 = PowerOf2Ceil(MaxIndex);
    // Only split when the power-of-2 midpoint separates the accesses and both
    // legal-typed halves fit inside the wide source vector.
    if (InVT.isSimple() && NearestPow2 > 2 && MaxIndex < NearestPow2 &&
        NumElems * 2 < NearestPow2) {
      unsigned SplitSize = NearestPow2 / 2;
      EVT SplitVT = EVT::getVectorVT(*DAG.getContext(),
                                     InVT.getVectorElementType(), SplitSize);
      if (TLI.isTypeLegal(SplitVT) &&
          SplitSize + SplitVT.getVectorNumElements() <=
              InVT.getVectorNumElements()) {
        SDValue VecIn2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SplitVT, Vec,
                                     DAG.getVectorIdxConstant(SplitSize, DL));
        SDValue VecIn1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SplitVT, Vec,
                                     DAG.getVectorIdxConstant(0, DL));
        // Replace the single wide source with its two halves (slots 1 and 2).
        VecIn.pop_back();
        VecIn.push_back(VecIn1);
        VecIn.push_back(VecIn2);
        DidSplitVec = true;

        for (unsigned i = 0; i < NumElems; i++) {
          if (VectorMask[i] <= 0)
            continue;
          VectorMask[i] = (IndexVec[i] < SplitSize) ? 1 : 2;
        }
      }
    }
  }

  // Sort input vectors by decreasing vector element count,
  // while preserving the relative order of equally-sized vectors.
  // Note that we keep the first "implicit" zero vector as-is.
  SmallVector<SDValue, 8> SortedVecIn(VecIn);
  llvm::stable_sort(MutableArrayRef<SDValue>(SortedVecIn).drop_front(),
                    [](const SDValue &a, const SDValue &b) {
                      return a.getValueType().getVectorNumElements() >
                             b.getValueType().getVectorNumElements();
                    });

  // We now also need to rebuild the VectorMask, because it referenced element
  // order in VecIn, and we just sorted them.
  for (int &SourceVectorIndex : VectorMask) {
    if (SourceVectorIndex <= 0)
      continue;
    unsigned Idx = getFirstIndexOf(SortedVecIn, VecIn[SourceVectorIndex]);
    assert(Idx > 0 && Idx < SortedVecIn.size() &&
           VecIn[SourceVectorIndex] == SortedVecIn[Idx] && "Remapping failure");
    SourceVectorIndex = Idx;
  }

  VecIn = std::move(SortedVecIn);

  // TODO: Should this fire if some of the input vectors has illegal type (like
  // it does now), or should we let legalization run its course first?

  // Shuffle phase:
  // Take pairs of vectors, and shuffle them so that the result has elements
  // from these vectors in the correct places.
  // For example, given:
  // t10: i32 = extract_vector_elt t1, Constant:i64<0>
  // t11: i32 = extract_vector_elt t2, Constant:i64<0>
  // t12: i32 = extract_vector_elt t3, Constant:i64<0>
  // t13: i32 = extract_vector_elt t1, Constant:i64<1>
  // t14: v4i32 = BUILD_VECTOR t10, t11, t12, t13
  // We will generate:
  // t20: v4i32 = vector_shuffle<0,4,u,1> t1, t2
  // t21: v4i32 = vector_shuffle<u,u,0,u> t3, undef
  SmallVector<SDValue, 4> Shuffles;
  for (unsigned In = 0, Len = (VecIn.size() / 2); In < Len; ++In) {
    // VecIn[0] is the implicit zero slot, so real sources start at index 1.
    unsigned LeftIdx = 2 * In + 1;
    SDValue VecLeft = VecIn[LeftIdx];
    SDValue VecRight =
        (LeftIdx + 1) < VecIn.size() ? VecIn[LeftIdx + 1] : SDValue();

    if (SDValue Shuffle = createBuildVecShuffle(DL, N, VectorMask, VecLeft,
                                                VecRight, LeftIdx, DidSplitVec))
      Shuffles.push_back(Shuffle);
    else
      return SDValue();
  }

  // If we need the zero vector as an "ingredient" in the blend tree, add it
  // to the list of shuffles.
  if (UsesZeroVector)
    Shuffles.push_back(VT.isInteger() ? DAG.getConstant(0, DL, VT)
                                      : DAG.getConstantFP(0.0, DL, VT));

  // If we only have one shuffle, we're done.
  if (Shuffles.size() == 1)
    return Shuffles[0];

  // Update the vector mask to point to the post-shuffle vectors.
  for (int &Vec : VectorMask)
    if (Vec == 0)
      // Zero-vector elements now come from the last entry in Shuffles.
      Vec = Shuffles.size() - 1;
    else
      // Sources LeftIdx/LeftIdx+1 were combined into shuffle (Vec - 1) / 2.
      Vec = (Vec - 1) / 2;

  // More than one shuffle. Generate a binary tree of blends, e.g. if from
  // the previous step we got the set of shuffles t10, t11, t12, t13, we will
  // generate:
  // t10: v8i32 = vector_shuffle<0,8,u,u,u,u,u,u> t1, t2
  // t11: v8i32 = vector_shuffle<u,u,0,8,u,u,u,u> t3, t4
  // t12: v8i32 = vector_shuffle<u,u,u,u,0,8,u,u> t5, t6
  // t13: v8i32 = vector_shuffle<u,u,u,u,u,u,0,8> t7, t8
  // t20: v8i32 = vector_shuffle<0,1,10,11,u,u,u,u> t10, t11
  // t21: v8i32 = vector_shuffle<u,u,u,u,4,5,14,15> t12, t13
  // t30: v8i32 = vector_shuffle<0,1,2,3,12,13,14,15> t20, t21

  // Make sure the initial size of the shuffle list is even.
  if (Shuffles.size() % 2)
    Shuffles.push_back(DAG.getUNDEF(VT));

  for (unsigned CurSize = Shuffles.size(); CurSize > 1; CurSize /= 2) {
    if (CurSize % 2) {
      Shuffles[CurSize] = DAG.getUNDEF(VT);
      CurSize++;
    }
    for (unsigned In = 0, Len = CurSize / 2; In < Len; ++In) {
      int Left = 2 * In;
      int Right = 2 * In + 1;
      SmallVector<int, 8> Mask(NumElems, -1);
      SDValue L = Shuffles[Left];
      ArrayRef<int> LMask;
      // If an operand is itself an unused single-input shuffle of the right
      // type, fold its mask into this blend instead of keeping it as a
      // separate node.
      bool IsLeftShuffle = L.getOpcode() == ISD::VECTOR_SHUFFLE &&
                           L.use_empty() && L.getOperand(1).isUndef() &&
                           L.getOperand(0).getValueType() == L.getValueType();
      if (IsLeftShuffle) {
        LMask = cast<ShuffleVectorSDNode>(L.getNode())->getMask();
        L = L.getOperand(0);
      }
      SDValue R = Shuffles[Right];
      ArrayRef<int> RMask;
      bool IsRightShuffle = R.getOpcode() == ISD::VECTOR_SHUFFLE &&
                            R.use_empty() && R.getOperand(1).isUndef() &&
                            R.getOperand(0).getValueType() == R.getValueType();
      if (IsRightShuffle) {
        RMask = cast<ShuffleVectorSDNode>(R.getNode())->getMask();
        R = R.getOperand(0);
      }
      for (unsigned I = 0; I != NumElems; ++I) {
        if (VectorMask[I] == Left) {
          Mask[I] = I;
          if (IsLeftShuffle)
            Mask[I] = LMask[I];
          VectorMask[I] = In;
        } else if (VectorMask[I] == Right) {
          Mask[I] = I + NumElems;
          if (IsRightShuffle)
            Mask[I] = RMask[I] + NumElems;
          VectorMask[I] = In;
        }
      }

      Shuffles[In] = DAG.getVectorShuffle(VT, DL, L, R, Mask);
    }
  }
  // After the tree collapses, slot 0 holds the fully blended result.
  return Shuffles[0];
}
20728 
// Try to turn a build vector of zero extends of extract vector elts into a
// vector zero extend and possibly an extract subvector.
// TODO: Support sign extend?
// TODO: Allow undef elements?
SDValue DAGCombiner::convertBuildVecZextToZext(SDNode *N) {
  // This creates new (possibly type-illegal) nodes, so only run before
  // operation legalization.
  if (LegalOperations)
    return SDValue();

  EVT VT = N->getValueType(0);

  bool FoundZeroExtend = false;
  SDValue Op0 = N->getOperand(0);
  // Returns the constant source index if Op has the shape
  // (zext/aext (extract_vector_elt Src, C)) where Src matches operand 0's
  // source vector, or -1 otherwise. Also records whether any ZERO_EXTEND was
  // seen, so the combined node can use the stronger extension kind below.
  auto checkElem = [&](SDValue Op) -> int64_t {
    unsigned Opc = Op.getOpcode();
    FoundZeroExtend |= (Opc == ISD::ZERO_EXTEND);
    if ((Opc == ISD::ZERO_EXTEND || Opc == ISD::ANY_EXTEND) &&
        Op.getOperand(0).getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
        Op0.getOperand(0).getOperand(0) == Op.getOperand(0).getOperand(0))
      if (auto *C = dyn_cast<ConstantSDNode>(Op.getOperand(0).getOperand(1)))
        return C->getZExtValue();
    return -1;
  };

  // Make sure the first element matches
  // (zext (extract_vector_elt X, C))
  // Offset must be a constant multiple of the
  // known-minimum vector length of the result type.
  int64_t Offset = checkElem(Op0);
  if (Offset < 0 || (Offset % VT.getVectorNumElements()) != 0)
    return SDValue();

  unsigned NumElems = N->getNumOperands();
  SDValue In = Op0.getOperand(0).getOperand(0);
  EVT InSVT = In.getValueType().getScalarType();
  EVT InVT = EVT::getVectorVT(*DAG.getContext(), InSVT, NumElems);

  // Don't create an illegal input type after type legalization.
  if (LegalTypes && !TLI.isTypeLegal(InVT))
    return SDValue();

  // Ensure all the elements come from the same vector and are adjacent.
  for (unsigned i = 1; i != NumElems; ++i) {
    if ((Offset + i) != checkElem(N->getOperand(i)))
      return SDValue();
  }

  SDLoc DL(N);
  // Grab the contiguous source window starting at Offset, then extend it as
  // one whole-vector operation.
  In = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InVT, In,
                   Op0.getOperand(0).getOperand(1));
  return DAG.getNode(FoundZeroExtend ? ISD::ZERO_EXTEND : ISD::ANY_EXTEND, DL,
                     VT, In);
}
20781 
// Visitor for BUILD_VECTOR: folds all-undef vectors, splats of vector
// bitcasts, identity/subvector extracts, then delegates to the
// convertBuildVec*/reduceBuildVec* helpers, and finally tries SPLAT_VECTOR.
SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) {
  EVT VT = N->getValueType(0);

  // A vector built entirely of undefs is undef.
  if (ISD::allOperandsUndef(N))
    return DAG.getUNDEF(VT);

  // If this is a splat of a bitcast from another vector, change to a
  // concat_vector.
  // For example:
  //   (build_vector (i64 (bitcast (v2i32 X))), (i64 (bitcast (v2i32 X)))) ->
  //     (v2i64 (bitcast (concat_vectors (v2i32 X), (v2i32 X))))
  //
  // If X is a build_vector itself, the concat can become a larger build_vector.
  // TODO: Maybe this is useful for non-splat too?
  if (!LegalOperations) {
    if (SDValue Splat = cast<BuildVectorSDNode>(N)->getSplatValue()) {
      Splat = peekThroughBitcasts(Splat);
      EVT SrcVT = Splat.getValueType();
      if (SrcVT.isVector()) {
        unsigned NumElts = N->getNumOperands() * SrcVT.getVectorNumElements();
        EVT NewVT = EVT::getVectorVT(*DAG.getContext(),
                                     SrcVT.getVectorElementType(), NumElts);
        if (!LegalTypes || TLI.isTypeLegal(NewVT)) {
          SmallVector<SDValue, 8> Ops(N->getNumOperands(), Splat);
          SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N),
                                       NewVT, Ops);
          return DAG.getBitcast(VT, Concat);
        }
      }
    }
  }

  // Check if we can express BUILD VECTOR via subvector extract.
  if (!LegalTypes && (N->getNumOperands() > 1)) {
    SDValue Op0 = N->getOperand(0);
    // Returns the constant extract index if Op is an EXTRACT_VECTOR_ELT from
    // the same source vector as operand 0, otherwise (uint64_t)-1.
    auto checkElem = [&](SDValue Op) -> uint64_t {
      if ((Op.getOpcode() == ISD::EXTRACT_VECTOR_ELT) &&
          (Op0.getOperand(0) == Op.getOperand(0)))
        if (auto CNode = dyn_cast<ConstantSDNode>(Op.getOperand(1)))
          return CNode->getZExtValue();
      return -1;
    };

    // All operands must extract consecutive indices starting at Offset;
    // Offset becomes -1 on any mismatch.
    int Offset = checkElem(Op0);
    for (unsigned i = 0; i < N->getNumOperands(); ++i) {
      if (Offset + i != checkElem(N->getOperand(i))) {
        Offset = -1;
        break;
      }
    }

    // Whole-vector identity extract: just forward the source vector.
    if ((Offset == 0) &&
        (Op0.getOperand(0).getValueType() == N->getValueType(0)))
      return Op0.getOperand(0);
    if ((Offset != -1) &&
        ((Offset % N->getValueType(0).getVectorNumElements()) ==
         0)) // IDX must be multiple of output size.
      return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), N->getValueType(0),
                         Op0.getOperand(0), Op0.getOperand(1));
  }

  if (SDValue V = convertBuildVecZextToZext(N))
    return V;

  if (SDValue V = reduceBuildVecExtToExtBuildVec(N))
    return V;

  if (SDValue V = reduceBuildVecTruncToBitCast(N))
    return V;

  if (SDValue V = reduceBuildVecToShuffle(N))
    return V;

  // A splat of a single element is a SPLAT_VECTOR if supported on the target.
  // Do this late as some of the above may replace the splat.
  if (TLI.getOperationAction(ISD::SPLAT_VECTOR, VT) != TargetLowering::Expand)
    if (SDValue V = cast<BuildVectorSDNode>(N)->getSplatValue()) {
      assert(!V.isUndef() && "Splat of undef should have been handled earlier");
      return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), VT, V);
    }

  return SDValue();
}
20866 
// Fold a CONCAT_VECTORS whose operands are all bitcasts of scalars (or undef)
// into one BUILD_VECTOR of those scalars, bitcast to the concatenated type.
// Only fires while the operand vector type is still illegal.
static SDValue combineConcatVectorOfScalars(SDNode *N, SelectionDAG &DAG) {
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  EVT OpVT = N->getOperand(0).getValueType();

  // If the operands are legal vectors, leave them alone.
  if (TLI.isTypeLegal(OpVT))
    return SDValue();

  SDLoc DL(N);
  EVT VT = N->getValueType(0);
  SmallVector<SDValue, 8> Ops;

  // Default scalar type: an integer covering one whole operand's bits.
  EVT SVT = EVT::getIntegerVT(*DAG.getContext(), OpVT.getSizeInBits());
  SDValue ScalarUndef = DAG.getNode(ISD::UNDEF, DL, SVT);

  // Keep track of what we encounter.
  bool AnyInteger = false;
  bool AnyFP = false;
  for (const SDValue &Op : N->ops()) {
    if (ISD::BITCAST == Op.getOpcode() &&
        !Op.getOperand(0).getValueType().isVector())
      Ops.push_back(Op.getOperand(0));
    else if (ISD::UNDEF == Op.getOpcode())
      Ops.push_back(ScalarUndef);
    else
      return SDValue();

    // Note whether we encounter an integer or floating point scalar.
    // If it's neither, bail out, it could be something weird like x86mmx.
    EVT LastOpVT = Ops.back().getValueType();
    if (LastOpVT.isFloatingPoint())
      AnyFP = true;
    else if (LastOpVT.isInteger())
      AnyInteger = true;
    else
      return SDValue();
  }

  // If any of the operands is a floating point scalar bitcast to a vector,
  // use floating point types throughout, and bitcast everything.
  // Replace UNDEFs by another scalar UNDEF node, of the final desired type.
  if (AnyFP) {
    SVT = EVT::getFloatingPointVT(OpVT.getSizeInBits());
    ScalarUndef = DAG.getNode(ISD::UNDEF, DL, SVT);
    if (AnyInteger) {
      for (SDValue &Op : Ops) {
        if (Op.getValueType() == SVT)
          continue;
        if (Op.isUndef())
          Op = ScalarUndef;
        else
          Op = DAG.getBitcast(SVT, Op);
      }
    }
  }

  // Build one flat vector of the collected scalars, then bitcast it to the
  // original concatenated type.
  EVT VecVT = EVT::getVectorVT(*DAG.getContext(), SVT,
                               VT.getSizeInBits() / SVT.getSizeInBits());
  return DAG.getBitcast(VT, DAG.getBuildVector(VecVT, DL, Ops));
}
20927 
// Attempt to merge nested concat_vectors/undefs.
// Fold concat_vectors(concat_vectors(x,y,z,w),u,u,concat_vectors(a,b,c,d))
//  --> concat_vectors(x,y,z,w,u,u,u,u,u,u,u,u,a,b,c,d)
static SDValue combineConcatVectorOfConcatVectors(SDNode *N,
                                                  SelectionDAG &DAG) {
  EVT VT = N->getValueType(0);

  // Ensure we're concatenating UNDEF and CONCAT_VECTORS nodes of similar types.
  EVT SubVT;
  SDValue FirstConcat;
  for (const SDValue &Op : N->ops()) {
    if (Op.isUndef())
      continue;
    if (Op.getOpcode() != ISD::CONCAT_VECTORS)
      return SDValue();
    if (!FirstConcat) {
      // Remember the inner operand type from the first concat; every other
      // concat must use the same (legal) sub-vector type.
      SubVT = Op.getOperand(0).getValueType();
      if (!DAG.getTargetLoweringInfo().isTypeLegal(SubVT))
        return SDValue();
      FirstConcat = Op;
      continue;
    }
    if (SubVT != Op.getOperand(0).getValueType())
      return SDValue();
  }
  // Callers fold the all-undef concat before reaching this helper.
  assert(FirstConcat && "Concat of all-undefs found");

  // Flatten: inline each inner concat's operands, expanding each outer undef
  // into one sub-vector undef per inner operand.
  SmallVector<SDValue> ConcatOps;
  for (const SDValue &Op : N->ops()) {
    if (Op.isUndef()) {
      ConcatOps.append(FirstConcat->getNumOperands(), DAG.getUNDEF(SubVT));
      continue;
    }
    ConcatOps.append(Op->op_begin(), Op->op_end());
  }
  return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, ConcatOps);
}
20965 
// Check to see if this is a CONCAT_VECTORS of a bunch of EXTRACT_SUBVECTOR
// operations. If so, and if the EXTRACT_SUBVECTOR vector inputs come from at
// most two distinct vectors the same size as the result, attempt to turn this
// into a legal shuffle.
static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) {
  EVT VT = N->getValueType(0);
  EVT OpVT = N->getOperand(0).getValueType();

  // We currently can't generate an appropriate shuffle for a scalable vector.
  if (VT.isScalableVector())
    return SDValue();

  int NumElts = VT.getVectorNumElements();
  int NumOpElts = OpVT.getVectorNumElements();

  // SV0/SV1 are the (at most two) distinct source vectors of the shuffle.
  SDValue SV0 = DAG.getUNDEF(VT), SV1 = DAG.getUNDEF(VT);
  SmallVector<int, 8> Mask;

  for (SDValue Op : N->ops()) {
    Op = peekThroughBitcasts(Op);

    // UNDEF nodes convert to UNDEF shuffle mask values.
    if (Op.isUndef()) {
      Mask.append((unsigned)NumOpElts, -1);
      continue;
    }

    if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
      return SDValue();

    // What vector are we extracting the subvector from and at what index?
    SDValue ExtVec = Op.getOperand(0);
    int ExtIdx = Op.getConstantOperandVal(1);

    // We want the EVT of the original extraction to correctly scale the
    // extraction index.
    EVT ExtVT = ExtVec.getValueType();
    ExtVec = peekThroughBitcasts(ExtVec);

    // UNDEF nodes convert to UNDEF shuffle mask values.
    if (ExtVec.isUndef()) {
      Mask.append((unsigned)NumOpElts, -1);
      continue;
    }

    // Ensure that we are extracting a subvector from a vector the same
    // size as the result.
    if (ExtVT.getSizeInBits() != VT.getSizeInBits())
      return SDValue();

    // Scale the subvector index to account for any bitcast.
    int NumExtElts = ExtVT.getVectorNumElements();
    if (0 == (NumExtElts % NumElts))
      ExtIdx /= (NumExtElts / NumElts);
    else if (0 == (NumElts % NumExtElts))
      ExtIdx *= (NumElts / NumExtElts);
    else
      return SDValue();

    // At most we can reference 2 inputs in the final shuffle.
    if (SV0.isUndef() || SV0 == ExtVec) {
      SV0 = ExtVec;
      for (int i = 0; i != NumOpElts; ++i)
        Mask.push_back(i + ExtIdx);
    } else if (SV1.isUndef() || SV1 == ExtVec) {
      SV1 = ExtVec;
      for (int i = 0; i != NumOpElts; ++i)
        Mask.push_back(i + ExtIdx + NumElts);
    } else {
      return SDValue();
    }
  }

  // buildLegalVectorShuffle returns a null SDValue when the mask can't be
  // expressed as a legal shuffle for this target (see its use above in
  // reduceBuildVecToShuffleWithZero's pattern, which checks for null).
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  return TLI.buildLegalVectorShuffle(VT, SDLoc(N), DAG.getBitcast(VT, SV0),
                                     DAG.getBitcast(VT, SV1), Mask, DAG);
}
21043 
// Fold a concat_vectors of identical int<->fp conversions into one wider
// conversion: concat (cast X), (cast Y)... -> cast (concat X, Y...).
static SDValue combineConcatVectorOfCasts(SDNode *N, SelectionDAG &DAG) {
  unsigned CastOpcode = N->getOperand(0).getOpcode();
  switch (CastOpcode) {
  case ISD::SINT_TO_FP:
  case ISD::UINT_TO_FP:
  case ISD::FP_TO_SINT:
  case ISD::FP_TO_UINT:
    // TODO: Allow more opcodes?
    //  case ISD::BITCAST:
    //  case ISD::TRUNCATE:
    //  case ISD::ZERO_EXTEND:
    //  case ISD::SIGN_EXTEND:
    //  case ISD::FP_EXTEND:
    break;
  default:
    return SDValue();
  }

  EVT SrcVT = N->getOperand(0).getOperand(0).getValueType();
  if (!SrcVT.isVector())
    return SDValue();

  // All operands of the concat must be the same kind of cast from the same
  // source type.
  SmallVector<SDValue, 4> SrcOps;
  for (SDValue Op : N->ops()) {
    if (Op.getOpcode() != CastOpcode || !Op.hasOneUse() ||
        Op.getOperand(0).getValueType() != SrcVT)
      return SDValue();
    SrcOps.push_back(Op.getOperand(0));
  }

  // The wider cast must be supported by the target. This is unusual because
  // the operation support type parameter depends on the opcode. In addition,
  // check the other type in the cast to make sure this is really legal.
  EVT VT = N->getValueType(0);
  EVT SrcEltVT = SrcVT.getVectorElementType();
  ElementCount NumElts = SrcVT.getVectorElementCount() * N->getNumOperands();
  EVT ConcatSrcVT = EVT::getVectorVT(*DAG.getContext(), SrcEltVT, NumElts);
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  switch (CastOpcode) {
  case ISD::SINT_TO_FP:
  case ISD::UINT_TO_FP:
    // int-to-fp legality is keyed on the (integer) input type.
    if (!TLI.isOperationLegalOrCustom(CastOpcode, ConcatSrcVT) ||
        !TLI.isTypeLegal(VT))
      return SDValue();
    break;
  case ISD::FP_TO_SINT:
  case ISD::FP_TO_UINT:
    // fp-to-int legality is keyed on the (integer) result type.
    if (!TLI.isOperationLegalOrCustom(CastOpcode, VT) ||
        !TLI.isTypeLegal(ConcatSrcVT))
      return SDValue();
    break;
  default:
    llvm_unreachable("Unexpected cast opcode");
  }

  // concat (cast X), (cast Y)... -> cast (concat X, Y...)
  SDLoc DL(N);
  SDValue NewConcat = DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatSrcVT, SrcOps);
  return DAG.getNode(CastOpcode, DL, VT, NewConcat);
}
21106 
visitCONCAT_VECTORS(SDNode * N)21107 SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
21108   // If we only have one input vector, we don't need to do any concatenation.
21109   if (N->getNumOperands() == 1)
21110     return N->getOperand(0);
21111 
21112   // Check if all of the operands are undefs.
21113   EVT VT = N->getValueType(0);
21114   if (ISD::allOperandsUndef(N))
21115     return DAG.getUNDEF(VT);
21116 
21117   // Optimize concat_vectors where all but the first of the vectors are undef.
21118   if (all_of(drop_begin(N->ops()),
21119              [](const SDValue &Op) { return Op.isUndef(); })) {
21120     SDValue In = N->getOperand(0);
21121     assert(In.getValueType().isVector() && "Must concat vectors");
21122 
21123     // If the input is a concat_vectors, just make a larger concat by padding
21124     // with smaller undefs.
21125     if (In.getOpcode() == ISD::CONCAT_VECTORS && In.hasOneUse()) {
21126       unsigned NumOps = N->getNumOperands() * In.getNumOperands();
21127       SmallVector<SDValue, 4> Ops(In->op_begin(), In->op_end());
21128       Ops.resize(NumOps, DAG.getUNDEF(Ops[0].getValueType()));
21129       return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
21130     }
21131 
21132     SDValue Scalar = peekThroughOneUseBitcasts(In);
21133 
21134     // concat_vectors(scalar_to_vector(scalar), undef) ->
21135     //     scalar_to_vector(scalar)
21136     if (!LegalOperations && Scalar.getOpcode() == ISD::SCALAR_TO_VECTOR &&
21137          Scalar.hasOneUse()) {
21138       EVT SVT = Scalar.getValueType().getVectorElementType();
21139       if (SVT == Scalar.getOperand(0).getValueType())
21140         Scalar = Scalar.getOperand(0);
21141     }
21142 
21143     // concat_vectors(scalar, undef) -> scalar_to_vector(scalar)
21144     if (!Scalar.getValueType().isVector()) {
21145       // If the bitcast type isn't legal, it might be a trunc of a legal type;
21146       // look through the trunc so we can still do the transform:
21147       //   concat_vectors(trunc(scalar), undef) -> scalar_to_vector(scalar)
21148       if (Scalar->getOpcode() == ISD::TRUNCATE &&
21149           !TLI.isTypeLegal(Scalar.getValueType()) &&
21150           TLI.isTypeLegal(Scalar->getOperand(0).getValueType()))
21151         Scalar = Scalar->getOperand(0);
21152 
21153       EVT SclTy = Scalar.getValueType();
21154 
21155       if (!SclTy.isFloatingPoint() && !SclTy.isInteger())
21156         return SDValue();
21157 
21158       // Bail out if the vector size is not a multiple of the scalar size.
21159       if (VT.getSizeInBits() % SclTy.getSizeInBits())
21160         return SDValue();
21161 
21162       unsigned VNTNumElms = VT.getSizeInBits() / SclTy.getSizeInBits();
21163       if (VNTNumElms < 2)
21164         return SDValue();
21165 
21166       EVT NVT = EVT::getVectorVT(*DAG.getContext(), SclTy, VNTNumElms);
21167       if (!TLI.isTypeLegal(NVT) || !TLI.isTypeLegal(Scalar.getValueType()))
21168         return SDValue();
21169 
21170       SDValue Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), NVT, Scalar);
21171       return DAG.getBitcast(VT, Res);
21172     }
21173   }
21174 
21175   // Fold any combination of BUILD_VECTOR or UNDEF nodes into one BUILD_VECTOR.
21176   // We have already tested above for an UNDEF only concatenation.
21177   // fold (concat_vectors (BUILD_VECTOR A, B, ...), (BUILD_VECTOR C, D, ...))
21178   // -> (BUILD_VECTOR A, B, ..., C, D, ...)
21179   auto IsBuildVectorOrUndef = [](const SDValue &Op) {
21180     return ISD::UNDEF == Op.getOpcode() || ISD::BUILD_VECTOR == Op.getOpcode();
21181   };
21182   if (llvm::all_of(N->ops(), IsBuildVectorOrUndef)) {
21183     SmallVector<SDValue, 8> Opnds;
21184     EVT SVT = VT.getScalarType();
21185 
21186     EVT MinVT = SVT;
21187     if (!SVT.isFloatingPoint()) {
21188       // If BUILD_VECTOR are from built from integer, they may have different
21189       // operand types. Get the smallest type and truncate all operands to it.
21190       bool FoundMinVT = false;
21191       for (const SDValue &Op : N->ops())
21192         if (ISD::BUILD_VECTOR == Op.getOpcode()) {
21193           EVT OpSVT = Op.getOperand(0).getValueType();
21194           MinVT = (!FoundMinVT || OpSVT.bitsLE(MinVT)) ? OpSVT : MinVT;
21195           FoundMinVT = true;
21196         }
21197       assert(FoundMinVT && "Concat vector type mismatch");
21198     }
21199 
21200     for (const SDValue &Op : N->ops()) {
21201       EVT OpVT = Op.getValueType();
21202       unsigned NumElts = OpVT.getVectorNumElements();
21203 
21204       if (ISD::UNDEF == Op.getOpcode())
21205         Opnds.append(NumElts, DAG.getUNDEF(MinVT));
21206 
21207       if (ISD::BUILD_VECTOR == Op.getOpcode()) {
21208         if (SVT.isFloatingPoint()) {
21209           assert(SVT == OpVT.getScalarType() && "Concat vector type mismatch");
21210           Opnds.append(Op->op_begin(), Op->op_begin() + NumElts);
21211         } else {
21212           for (unsigned i = 0; i != NumElts; ++i)
21213             Opnds.push_back(
21214                 DAG.getNode(ISD::TRUNCATE, SDLoc(N), MinVT, Op.getOperand(i)));
21215         }
21216       }
21217     }
21218 
21219     assert(VT.getVectorNumElements() == Opnds.size() &&
21220            "Concat vector type mismatch");
21221     return DAG.getBuildVector(VT, SDLoc(N), Opnds);
21222   }
21223 
21224   // Fold CONCAT_VECTORS of only bitcast scalars (or undef) to BUILD_VECTOR.
21225   // FIXME: Add support for concat_vectors(bitcast(vec0),bitcast(vec1),...).
21226   if (SDValue V = combineConcatVectorOfScalars(N, DAG))
21227     return V;
21228 
21229   if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT)) {
21230     // Fold CONCAT_VECTORS of CONCAT_VECTORS (or undef) to VECTOR_SHUFFLE.
21231     if (SDValue V = combineConcatVectorOfConcatVectors(N, DAG))
21232       return V;
21233 
21234     // Fold CONCAT_VECTORS of EXTRACT_SUBVECTOR (or undef) to VECTOR_SHUFFLE.
21235     if (SDValue V = combineConcatVectorOfExtracts(N, DAG))
21236       return V;
21237   }
21238 
21239   if (SDValue V = combineConcatVectorOfCasts(N, DAG))
21240     return V;
21241 
21242   // Type legalization of vectors and DAG canonicalization of SHUFFLE_VECTOR
21243   // nodes often generate nop CONCAT_VECTOR nodes. Scan the CONCAT_VECTOR
21244   // operands and look for a CONCAT operations that place the incoming vectors
21245   // at the exact same location.
21246   //
21247   // For scalable vectors, EXTRACT_SUBVECTOR indexes are implicitly scaled.
21248   SDValue SingleSource = SDValue();
21249   unsigned PartNumElem =
21250       N->getOperand(0).getValueType().getVectorMinNumElements();
21251 
21252   for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
21253     SDValue Op = N->getOperand(i);
21254 
21255     if (Op.isUndef())
21256       continue;
21257 
21258     // Check if this is the identity extract:
21259     if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
21260       return SDValue();
21261 
21262     // Find the single incoming vector for the extract_subvector.
21263     if (SingleSource.getNode()) {
21264       if (Op.getOperand(0) != SingleSource)
21265         return SDValue();
21266     } else {
21267       SingleSource = Op.getOperand(0);
21268 
21269       // Check the source type is the same as the type of the result.
21270       // If not, this concat may extend the vector, so we can not
21271       // optimize it away.
21272       if (SingleSource.getValueType() != N->getValueType(0))
21273         return SDValue();
21274     }
21275 
21276     // Check that we are reading from the identity index.
21277     unsigned IdentityIndex = i * PartNumElem;
21278     if (Op.getConstantOperandAPInt(1) != IdentityIndex)
21279       return SDValue();
21280   }
21281 
21282   if (SingleSource.getNode())
21283     return SingleSource;
21284 
21285   return SDValue();
21286 }
21287 
21288 // Helper that peeks through INSERT_SUBVECTOR/CONCAT_VECTORS to find
21289 // if the subvector can be sourced for free.
getSubVectorSrc(SDValue V,SDValue Index,EVT SubVT)21290 static SDValue getSubVectorSrc(SDValue V, SDValue Index, EVT SubVT) {
21291   if (V.getOpcode() == ISD::INSERT_SUBVECTOR &&
21292       V.getOperand(1).getValueType() == SubVT && V.getOperand(2) == Index) {
21293     return V.getOperand(1);
21294   }
21295   auto *IndexC = dyn_cast<ConstantSDNode>(Index);
21296   if (IndexC && V.getOpcode() == ISD::CONCAT_VECTORS &&
21297       V.getOperand(0).getValueType() == SubVT &&
21298       (IndexC->getZExtValue() % SubVT.getVectorMinNumElements()) == 0) {
21299     uint64_t SubIdx = IndexC->getZExtValue() / SubVT.getVectorMinNumElements();
21300     return V.getOperand(SubIdx);
21301   }
21302   return SDValue();
21303 }
21304 
narrowInsertExtractVectorBinOp(SDNode * Extract,SelectionDAG & DAG,bool LegalOperations)21305 static SDValue narrowInsertExtractVectorBinOp(SDNode *Extract,
21306                                               SelectionDAG &DAG,
21307                                               bool LegalOperations) {
21308   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
21309   SDValue BinOp = Extract->getOperand(0);
21310   unsigned BinOpcode = BinOp.getOpcode();
21311   if (!TLI.isBinOp(BinOpcode) || BinOp->getNumValues() != 1)
21312     return SDValue();
21313 
21314   EVT VecVT = BinOp.getValueType();
21315   SDValue Bop0 = BinOp.getOperand(0), Bop1 = BinOp.getOperand(1);
21316   if (VecVT != Bop0.getValueType() || VecVT != Bop1.getValueType())
21317     return SDValue();
21318 
21319   SDValue Index = Extract->getOperand(1);
21320   EVT SubVT = Extract->getValueType(0);
21321   if (!TLI.isOperationLegalOrCustom(BinOpcode, SubVT, LegalOperations))
21322     return SDValue();
21323 
21324   SDValue Sub0 = getSubVectorSrc(Bop0, Index, SubVT);
21325   SDValue Sub1 = getSubVectorSrc(Bop1, Index, SubVT);
21326 
21327   // TODO: We could handle the case where only 1 operand is being inserted by
21328   //       creating an extract of the other operand, but that requires checking
21329   //       number of uses and/or costs.
21330   if (!Sub0 || !Sub1)
21331     return SDValue();
21332 
21333   // We are inserting both operands of the wide binop only to extract back
21334   // to the narrow vector size. Eliminate all of the insert/extract:
21335   // ext (binop (ins ?, X, Index), (ins ?, Y, Index)), Index --> binop X, Y
21336   return DAG.getNode(BinOpcode, SDLoc(Extract), SubVT, Sub0, Sub1,
21337                      BinOp->getFlags());
21338 }
21339 
/// If we are extracting a subvector produced by a wide binary operator try
/// to use a narrow binary operator and/or avoid concatenation and extraction.
/// Returns the narrowed value, or an empty SDValue if no fold applies.
static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG,
                                          bool LegalOperations) {
  // TODO: Refactor with the caller (visitEXTRACT_SUBVECTOR), so we can share
  // some of these bailouts with other transforms.

  // First try the stronger fold that removes matching insert/extract pairs
  // around the binop entirely.
  if (SDValue V = narrowInsertExtractVectorBinOp(Extract, DAG, LegalOperations))
    return V;

  // The extract index must be a constant, so we can map it to a concat operand.
  auto *ExtractIndexC = dyn_cast<ConstantSDNode>(Extract->getOperand(1));
  if (!ExtractIndexC)
    return SDValue();

  // We are looking for an optionally bitcasted wide vector binary operator
  // feeding an extract subvector.
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  SDValue BinOp = peekThroughBitcasts(Extract->getOperand(0));
  unsigned BOpcode = BinOp.getOpcode();
  if (!TLI.isBinOp(BOpcode) || BinOp->getNumValues() != 1)
    return SDValue();

  // Exclude the fake form of fneg (fsub -0.0, x) because that is likely to be
  // reduced to the unary fneg when it is visited, and we probably want to deal
  // with fneg in a target-specific way.
  if (BOpcode == ISD::FSUB) {
    auto *C = isConstOrConstSplatFP(BinOp.getOperand(0), /*AllowUndefs*/ true);
    if (C && C->getValueAPF().isNegZero())
      return SDValue();
  }

  // The binop must be a vector type, so we can extract some fraction of it.
  EVT WideBVT = BinOp.getValueType();
  // The optimisations below currently assume we are dealing with fixed length
  // vectors. It is possible to add support for scalable vectors, but at the
  // moment we've done no analysis to prove whether they are profitable or not.
  if (!WideBVT.isFixedLengthVector())
    return SDValue();

  EVT VT = Extract->getValueType(0);
  unsigned ExtractIndex = ExtractIndexC->getZExtValue();
  assert(ExtractIndex % VT.getVectorNumElements() == 0 &&
         "Extract index is not a multiple of the vector length.");

  // Bail out if this is not a proper multiple width extraction.
  // Note: widths are compared in bits because a bitcast may have changed the
  // element count between BinOp and the extract.
  unsigned WideWidth = WideBVT.getSizeInBits();
  unsigned NarrowWidth = VT.getSizeInBits();
  if (WideWidth % NarrowWidth != 0)
    return SDValue();

  // Bail out if we are extracting a fraction of a single operation. This can
  // occur because we potentially looked through a bitcast of the binop.
  unsigned NarrowingRatio = WideWidth / NarrowWidth;
  unsigned WideNumElts = WideBVT.getVectorNumElements();
  if (WideNumElts % NarrowingRatio != 0)
    return SDValue();

  // Bail out if the target does not support a narrower version of the binop.
  EVT NarrowBVT = EVT::getVectorVT(*DAG.getContext(), WideBVT.getScalarType(),
                                   WideNumElts / NarrowingRatio);
  if (!TLI.isOperationLegalOrCustomOrPromote(BOpcode, NarrowBVT))
    return SDValue();

  // If extraction is cheap, we don't need to look at the binop operands
  // for concat ops. The narrow binop alone makes this transform profitable.
  // We can't just reuse the original extract index operand because we may have
  // bitcasted.
  // ConcatOpNum: which NarrowBVT-sized chunk of the wide vector is extracted;
  // ExtBOIdx: that chunk's starting element index in WideBVT's element units.
  unsigned ConcatOpNum = ExtractIndex / VT.getVectorNumElements();
  unsigned ExtBOIdx = ConcatOpNum * NarrowBVT.getVectorNumElements();
  if (TLI.isExtractSubvectorCheap(NarrowBVT, WideBVT, ExtBOIdx) &&
      BinOp.hasOneUse() && Extract->getOperand(0)->hasOneUse()) {
    // extract (binop B0, B1), N --> binop (extract B0, N), (extract B1, N)
    SDLoc DL(Extract);
    SDValue NewExtIndex = DAG.getVectorIdxConstant(ExtBOIdx, DL);
    SDValue X = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
                            BinOp.getOperand(0), NewExtIndex);
    SDValue Y = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
                            BinOp.getOperand(1), NewExtIndex);
    SDValue NarrowBinOp =
        DAG.getNode(BOpcode, DL, NarrowBVT, X, Y, BinOp->getFlags());
    return DAG.getBitcast(VT, NarrowBinOp);
  }

  // Only handle the case where we are doubling and then halving. A larger ratio
  // may require more than two narrow binops to replace the wide binop.
  if (NarrowingRatio != 2)
    return SDValue();

  // TODO: The motivating case for this transform is an x86 AVX1 target. That
  // target has temptingly almost legal versions of bitwise logic ops in 256-bit
  // flavors, but no other 256-bit integer support. This could be extended to
  // handle any binop, but that may require fixing/adding other folds to avoid
  // codegen regressions.
  if (BOpcode != ISD::AND && BOpcode != ISD::OR && BOpcode != ISD::XOR)
    return SDValue();

  // We need at least one concatenation operation of a binop operand to make
  // this transform worthwhile. The concat must double the input vector sizes.
  auto GetSubVector = [ConcatOpNum](SDValue V) -> SDValue {
    if (V.getOpcode() == ISD::CONCAT_VECTORS && V.getNumOperands() == 2)
      return V.getOperand(ConcatOpNum);
    return SDValue();
  };
  SDValue SubVecL = GetSubVector(peekThroughBitcasts(BinOp.getOperand(0)));
  SDValue SubVecR = GetSubVector(peekThroughBitcasts(BinOp.getOperand(1)));

  if (SubVecL || SubVecR) {
    // If a binop operand was not the result of a concat, we must extract a
    // half-sized operand for our new narrow binop:
    // extract (binop (concat X1, X2), (concat Y1, Y2)), N --> binop XN, YN
    // extract (binop (concat X1, X2), Y), N --> binop XN, (extract Y, IndexC)
    // extract (binop X, (concat Y1, Y2)), N --> binop (extract X, IndexC), YN
    SDLoc DL(Extract);
    SDValue IndexC = DAG.getVectorIdxConstant(ExtBOIdx, DL);
    SDValue X = SubVecL ? DAG.getBitcast(NarrowBVT, SubVecL)
                        : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
                                      BinOp.getOperand(0), IndexC);

    SDValue Y = SubVecR ? DAG.getBitcast(NarrowBVT, SubVecR)
                        : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
                                      BinOp.getOperand(1), IndexC);

    SDValue NarrowBinOp = DAG.getNode(BOpcode, DL, NarrowBVT, X, Y);
    return DAG.getBitcast(VT, NarrowBinOp);
  }

  return SDValue();
}
21469 
/// If we are extracting a subvector from a wide vector load, convert to a
/// narrow load to eliminate the extraction:
/// (extract_subvector (load wide vector)) --> (load narrow vector)
static SDValue narrowExtractedVectorLoad(SDNode *Extract, SelectionDAG &DAG) {
  // TODO: Add support for big-endian. The offset calculation must be adjusted.
  if (DAG.getDataLayout().isBigEndian())
    return SDValue();

  // Only plain (non-extending) loads without volatile/atomic semantics are
  // safe to narrow.
  auto *Ld = dyn_cast<LoadSDNode>(Extract->getOperand(0));
  if (!Ld || Ld->getExtensionType() || !Ld->isSimple())
    return SDValue();

  // Allow targets to opt-out.
  EVT VT = Extract->getValueType(0);

  // We can only create byte sized loads.
  if (!VT.isByteSized())
    return SDValue();

  unsigned Index = Extract->getConstantOperandVal(1);
  unsigned NumElts = VT.getVectorMinNumElements();

  // The definition of EXTRACT_SUBVECTOR states that the index must be a
  // multiple of the minimum number of elements in the result type.
  assert(Index % NumElts == 0 && "The extract subvector index is not a "
                                 "multiple of the result's element count");

  // It's fine to use TypeSize here as we know the offset will not be negative.
  // For a scalable result type this is a scalable byte offset.
  TypeSize Offset = VT.getStoreSize() * (Index / NumElts);

  // Let the target veto the narrowing (e.g. if the wide load is cheaper).
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  if (!TLI.shouldReduceLoadWidth(Ld, Ld->getExtensionType(), VT))
    return SDValue();

  // The narrow load will be offset from the base address of the old load if
  // we are extracting from something besides index 0 (little-endian).
  SDLoc DL(Extract);

  // TODO: Use "BaseIndexOffset" to make this more effective.
  SDValue NewAddr = DAG.getMemBasePlusOffset(Ld->getBasePtr(), Offset, DL);

  uint64_t StoreSize = MemoryLocation::getSizeOrUnknown(VT.getStoreSize());
  MachineFunction &MF = DAG.getMachineFunction();
  MachineMemOperand *MMO;
  if (Offset.isScalable()) {
    // A scalable offset cannot be folded into the MMO's fixed byte offset,
    // so conservatively keep only the address space of the pointer info.
    MachinePointerInfo MPI =
        MachinePointerInfo(Ld->getPointerInfo().getAddrSpace());
    MMO = MF.getMachineMemOperand(Ld->getMemOperand(), MPI, StoreSize);
  } else
    MMO = MF.getMachineMemOperand(Ld->getMemOperand(), Offset.getFixedSize(),
                                  StoreSize);

  SDValue NewLd = DAG.getLoad(VT, DL, Ld->getChain(), NewAddr, MMO);
  // Preserve the chain ordering of memory ops that depended on the old load.
  DAG.makeEquivalentMemoryOrdering(Ld, NewLd);
  return NewLd;
}
21526 
/// Given  EXTRACT_SUBVECTOR(VECTOR_SHUFFLE(Op0, Op1, Mask)),
/// try to produce  VECTOR_SHUFFLE(EXTRACT_SUBVECTOR(Op?, ?),
///                                EXTRACT_SUBVECTOR(Op?, ?),
///                                Mask'))
/// iff it is legal and profitable to do so. Notably, the trimmed mask
/// (containing only the elements that are extracted)
/// must reference at most two subvectors.
static SDValue foldExtractSubvectorFromShuffleVector(SDNode *N,
                                                     SelectionDAG &DAG,
                                                     const TargetLowering &TLI,
                                                     bool LegalOperations) {
  assert(N->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
         "Must only be called on EXTRACT_SUBVECTOR's");

  SDValue N0 = N->getOperand(0);

  // Only deal with non-scalable vectors.
  EVT NarrowVT = N->getValueType(0);
  EVT WideVT = N0.getValueType();
  if (!NarrowVT.isFixedLengthVector() || !WideVT.isFixedLengthVector())
    return SDValue();

  // The operand must be a shufflevector.
  auto *WideShuffleVector = dyn_cast<ShuffleVectorSDNode>(N0);
  if (!WideShuffleVector)
    return SDValue();

  // The old shuffle needs to go away; it must have no other uses.
  if (!WideShuffleVector->hasOneUse())
    return SDValue();

  // And the narrow shufflevector that we'll form must be legal.
  if (LegalOperations &&
      !TLI.isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, NarrowVT))
    return SDValue();

  uint64_t FirstExtractedEltIdx = N->getConstantOperandVal(1);
  int NumEltsExtracted = NarrowVT.getVectorNumElements();
  assert((FirstExtractedEltIdx % NumEltsExtracted) == 0 &&
         "Extract index is not a multiple of the output vector length.");

  int WideNumElts = WideVT.getVectorNumElements();

  // NewMask: the trimmed mask; DemandedSubvectors: the (operand, subvector
  // index) pairs the trimmed mask actually reads from — at most two.
  SmallVector<int, 16> NewMask;
  NewMask.reserve(NumEltsExtracted);
  SmallSetVector<std::pair<SDValue /*Op*/, int /*SubvectorIndex*/>, 2>
      DemandedSubvectors;

  // Try to decode the wide mask into narrow mask from at most two subvectors.
  for (int M : WideShuffleVector->getMask().slice(FirstExtractedEltIdx,
                                                  NumEltsExtracted)) {
    assert((M >= -1) && (M < (2 * WideNumElts)) &&
           "Out-of-bounds shuffle mask?");

    if (M < 0) {
      // Does not depend on operands, does not require adjustment.
      NewMask.emplace_back(M);
      continue;
    }

    // From which operand of the shuffle does this shuffle mask element pick?
    int WideShufOpIdx = M / WideNumElts;
    // Which element of that operand is picked?
    int OpEltIdx = M % WideNumElts;

    assert((OpEltIdx + WideShufOpIdx * WideNumElts) == M &&
           "Shuffle mask vector decomposition failure.");

    // And which NumEltsExtracted-sized subvector of that operand is that?
    int OpSubvecIdx = OpEltIdx / NumEltsExtracted;
    // And which element within that subvector of that operand is that?
    int OpEltIdxInSubvec = OpEltIdx % NumEltsExtracted;

    assert((OpEltIdxInSubvec + OpSubvecIdx * NumEltsExtracted) == OpEltIdx &&
           "Shuffle mask subvector decomposition failure.");

    assert((OpEltIdxInSubvec + OpSubvecIdx * NumEltsExtracted +
            WideShufOpIdx * WideNumElts) == M &&
           "Shuffle mask full decomposition failure.");

    SDValue Op = WideShuffleVector->getOperand(WideShufOpIdx);

    if (Op.isUndef()) {
      // Picking from an undef operand. Let's adjust mask instead.
      NewMask.emplace_back(-1);
      continue;
    }

    // Profitability check: only deal with extractions from the first subvector.
    if (OpSubvecIdx != 0)
      return SDValue();

    const std::pair<SDValue, int> DemandedSubvector =
        std::make_pair(Op, OpSubvecIdx);

    if (DemandedSubvectors.insert(DemandedSubvector)) {
      if (DemandedSubvectors.size() > 2)
        return SDValue(); // We can't handle more than two subvectors.
      // How many elements into the WideVT does this subvector start?
      int Index = NumEltsExtracted * OpSubvecIdx;
      // Bail out if the extraction isn't going to be cheap.
      if (!TLI.isExtractSubvectorCheap(NarrowVT, WideVT, Index))
        return SDValue();
    }

    // Ok, but from which operand of the new shuffle will this element pick?
    int NewOpIdx =
        getFirstIndexOf(DemandedSubvectors.getArrayRef(), DemandedSubvector);
    assert((NewOpIdx == 0 || NewOpIdx == 1) && "Unexpected operand index.");

    // Re-base the mask element onto the new (narrow) operand.
    int AdjM = OpEltIdxInSubvec + NewOpIdx * NumEltsExtracted;
    NewMask.emplace_back(AdjM);
  }
  assert(NewMask.size() == (unsigned)NumEltsExtracted && "Produced bad mask.");
  assert(DemandedSubvectors.size() <= 2 &&
         "Should have ended up demanding at most two subvectors.");

  // Did we discover that the shuffle does not actually depend on operands?
  if (DemandedSubvectors.empty())
    return DAG.getUNDEF(NarrowVT);

  // We still perform the exact same EXTRACT_SUBVECTOR,  just on different
  // operand[s]/index[es], so there is no point in checking for its legality.

  // Do not turn a legal shuffle into an illegal one.
  if (TLI.isShuffleMaskLegal(WideShuffleVector->getMask(), WideVT) &&
      !TLI.isShuffleMaskLegal(NewMask, NarrowVT))
    return SDValue();

  SDLoc DL(N);

  // Materialize the narrow extracts for each demanded subvector.
  SmallVector<SDValue, 2> NewOps;
  for (const std::pair<SDValue /*Op*/, int /*SubvectorIndex*/>
           &DemandedSubvector : DemandedSubvectors) {
    // How many elements into the WideVT does this subvector start?
    int Index = NumEltsExtracted * DemandedSubvector.second;
    SDValue IndexC = DAG.getVectorIdxConstant(Index, DL);
    NewOps.emplace_back(DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowVT,
                                    DemandedSubvector.first, IndexC));
  }
  assert((NewOps.size() == 1 || NewOps.size() == 2) &&
         "Should end up with either one or two ops");

  // If we ended up with only one operand, pad with an undef.
  if (NewOps.size() == 1)
    NewOps.emplace_back(DAG.getUNDEF(NarrowVT));

  return DAG.getVectorShuffle(NarrowVT, DL, NewOps[0], NewOps[1], NewMask);
}
21676 
visitEXTRACT_SUBVECTOR(SDNode * N)21677 SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) {
21678   EVT NVT = N->getValueType(0);
21679   SDValue V = N->getOperand(0);
21680   uint64_t ExtIdx = N->getConstantOperandVal(1);
21681 
21682   // Extract from UNDEF is UNDEF.
21683   if (V.isUndef())
21684     return DAG.getUNDEF(NVT);
21685 
21686   if (TLI.isOperationLegalOrCustomOrPromote(ISD::LOAD, NVT))
21687     if (SDValue NarrowLoad = narrowExtractedVectorLoad(N, DAG))
21688       return NarrowLoad;
21689 
21690   // Combine an extract of an extract into a single extract_subvector.
21691   // ext (ext X, C), 0 --> ext X, C
21692   if (ExtIdx == 0 && V.getOpcode() == ISD::EXTRACT_SUBVECTOR && V.hasOneUse()) {
21693     if (TLI.isExtractSubvectorCheap(NVT, V.getOperand(0).getValueType(),
21694                                     V.getConstantOperandVal(1)) &&
21695         TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NVT)) {
21696       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), NVT, V.getOperand(0),
21697                          V.getOperand(1));
21698     }
21699   }
21700 
21701   // ty1 extract_vector(ty2 splat(V))) -> ty1 splat(V)
21702   if (V.getOpcode() == ISD::SPLAT_VECTOR)
21703     if (DAG.isConstantValueOfAnyType(V.getOperand(0)) || V.hasOneUse())
21704       if (!LegalOperations || TLI.isOperationLegal(ISD::SPLAT_VECTOR, NVT))
21705         return DAG.getSplatVector(NVT, SDLoc(N), V.getOperand(0));
21706 
21707   // Try to move vector bitcast after extract_subv by scaling extraction index:
21708   // extract_subv (bitcast X), Index --> bitcast (extract_subv X, Index')
21709   if (V.getOpcode() == ISD::BITCAST &&
21710       V.getOperand(0).getValueType().isVector() &&
21711       (!LegalOperations || TLI.isOperationLegal(ISD::BITCAST, NVT))) {
21712     SDValue SrcOp = V.getOperand(0);
21713     EVT SrcVT = SrcOp.getValueType();
21714     unsigned SrcNumElts = SrcVT.getVectorMinNumElements();
21715     unsigned DestNumElts = V.getValueType().getVectorMinNumElements();
21716     if ((SrcNumElts % DestNumElts) == 0) {
21717       unsigned SrcDestRatio = SrcNumElts / DestNumElts;
21718       ElementCount NewExtEC = NVT.getVectorElementCount() * SrcDestRatio;
21719       EVT NewExtVT = EVT::getVectorVT(*DAG.getContext(), SrcVT.getScalarType(),
21720                                       NewExtEC);
21721       if (TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NewExtVT)) {
21722         SDLoc DL(N);
21723         SDValue NewIndex = DAG.getVectorIdxConstant(ExtIdx * SrcDestRatio, DL);
21724         SDValue NewExtract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewExtVT,
21725                                          V.getOperand(0), NewIndex);
21726         return DAG.getBitcast(NVT, NewExtract);
21727       }
21728     }
21729     if ((DestNumElts % SrcNumElts) == 0) {
21730       unsigned DestSrcRatio = DestNumElts / SrcNumElts;
21731       if (NVT.getVectorElementCount().isKnownMultipleOf(DestSrcRatio)) {
21732         ElementCount NewExtEC =
21733             NVT.getVectorElementCount().divideCoefficientBy(DestSrcRatio);
21734         EVT ScalarVT = SrcVT.getScalarType();
21735         if ((ExtIdx % DestSrcRatio) == 0) {
21736           SDLoc DL(N);
21737           unsigned IndexValScaled = ExtIdx / DestSrcRatio;
21738           EVT NewExtVT =
21739               EVT::getVectorVT(*DAG.getContext(), ScalarVT, NewExtEC);
21740           if (TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NewExtVT)) {
21741             SDValue NewIndex = DAG.getVectorIdxConstant(IndexValScaled, DL);
21742             SDValue NewExtract =
21743                 DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewExtVT,
21744                             V.getOperand(0), NewIndex);
21745             return DAG.getBitcast(NVT, NewExtract);
21746           }
21747           if (NewExtEC.isScalar() &&
21748               TLI.isOperationLegalOrCustom(ISD::EXTRACT_VECTOR_ELT, ScalarVT)) {
21749             SDValue NewIndex = DAG.getVectorIdxConstant(IndexValScaled, DL);
21750             SDValue NewExtract =
21751                 DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT,
21752                             V.getOperand(0), NewIndex);
21753             return DAG.getBitcast(NVT, NewExtract);
21754           }
21755         }
21756       }
21757     }
21758   }
21759 
21760   if (V.getOpcode() == ISD::CONCAT_VECTORS) {
21761     unsigned ExtNumElts = NVT.getVectorMinNumElements();
21762     EVT ConcatSrcVT = V.getOperand(0).getValueType();
21763     assert(ConcatSrcVT.getVectorElementType() == NVT.getVectorElementType() &&
21764            "Concat and extract subvector do not change element type");
21765     assert((ExtIdx % ExtNumElts) == 0 &&
21766            "Extract index is not a multiple of the input vector length.");
21767 
21768     unsigned ConcatSrcNumElts = ConcatSrcVT.getVectorMinNumElements();
21769     unsigned ConcatOpIdx = ExtIdx / ConcatSrcNumElts;
21770 
21771     // If the concatenated source types match this extract, it's a direct
21772     // simplification:
21773     // extract_subvec (concat V1, V2, ...), i --> Vi
21774     if (NVT.getVectorElementCount() == ConcatSrcVT.getVectorElementCount())
21775       return V.getOperand(ConcatOpIdx);
21776 
21777     // If the concatenated source vectors are a multiple length of this extract,
21778     // then extract a fraction of one of those source vectors directly from a
21779     // concat operand. Example:
21780     //   v2i8 extract_subvec (v16i8 concat (v8i8 X), (v8i8 Y), 14 -->
21781     //   v2i8 extract_subvec v8i8 Y, 6
21782     if (NVT.isFixedLengthVector() && ConcatSrcVT.isFixedLengthVector() &&
21783         ConcatSrcNumElts % ExtNumElts == 0) {
21784       SDLoc DL(N);
21785       unsigned NewExtIdx = ExtIdx - ConcatOpIdx * ConcatSrcNumElts;
21786       assert(NewExtIdx + ExtNumElts <= ConcatSrcNumElts &&
21787              "Trying to extract from >1 concat operand?");
21788       assert(NewExtIdx % ExtNumElts == 0 &&
21789              "Extract index is not a multiple of the input vector length.");
21790       SDValue NewIndexC = DAG.getVectorIdxConstant(NewExtIdx, DL);
21791       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NVT,
21792                          V.getOperand(ConcatOpIdx), NewIndexC);
21793     }
21794   }
21795 
21796   if (SDValue V =
21797           foldExtractSubvectorFromShuffleVector(N, DAG, TLI, LegalOperations))
21798     return V;
21799 
21800   V = peekThroughBitcasts(V);
21801 
21802   // If the input is a build vector. Try to make a smaller build vector.
21803   if (V.getOpcode() == ISD::BUILD_VECTOR) {
21804     EVT InVT = V.getValueType();
21805     unsigned ExtractSize = NVT.getSizeInBits();
21806     unsigned EltSize = InVT.getScalarSizeInBits();
21807     // Only do this if we won't split any elements.
21808     if (ExtractSize % EltSize == 0) {
21809       unsigned NumElems = ExtractSize / EltSize;
21810       EVT EltVT = InVT.getVectorElementType();
21811       EVT ExtractVT =
21812           NumElems == 1 ? EltVT
21813                         : EVT::getVectorVT(*DAG.getContext(), EltVT, NumElems);
21814       if ((Level < AfterLegalizeDAG ||
21815            (NumElems == 1 ||
21816             TLI.isOperationLegal(ISD::BUILD_VECTOR, ExtractVT))) &&
21817           (!LegalTypes || TLI.isTypeLegal(ExtractVT))) {
21818         unsigned IdxVal = (ExtIdx * NVT.getScalarSizeInBits()) / EltSize;
21819 
21820         if (NumElems == 1) {
21821           SDValue Src = V->getOperand(IdxVal);
21822           if (EltVT != Src.getValueType())
21823             Src = DAG.getNode(ISD::TRUNCATE, SDLoc(N), InVT, Src);
21824           return DAG.getBitcast(NVT, Src);
21825         }
21826 
21827         // Extract the pieces from the original build_vector.
21828         SDValue BuildVec = DAG.getBuildVector(ExtractVT, SDLoc(N),
21829                                               V->ops().slice(IdxVal, NumElems));
21830         return DAG.getBitcast(NVT, BuildVec);
21831       }
21832     }
21833   }
21834 
21835   if (V.getOpcode() == ISD::INSERT_SUBVECTOR) {
21836     // Handle only simple case where vector being inserted and vector
21837     // being extracted are of same size.
21838     EVT SmallVT = V.getOperand(1).getValueType();
21839     if (!NVT.bitsEq(SmallVT))
21840       return SDValue();
21841 
21842     // Combine:
21843     //    (extract_subvec (insert_subvec V1, V2, InsIdx), ExtIdx)
21844     // Into:
21845     //    indices are equal or bit offsets are equal => V1
21846     //    otherwise => (extract_subvec V1, ExtIdx)
21847     uint64_t InsIdx = V.getConstantOperandVal(2);
21848     if (InsIdx * SmallVT.getScalarSizeInBits() ==
21849         ExtIdx * NVT.getScalarSizeInBits()) {
21850       if (LegalOperations && !TLI.isOperationLegal(ISD::BITCAST, NVT))
21851         return SDValue();
21852 
21853       return DAG.getBitcast(NVT, V.getOperand(1));
21854     }
21855     return DAG.getNode(
21856         ISD::EXTRACT_SUBVECTOR, SDLoc(N), NVT,
21857         DAG.getBitcast(N->getOperand(0).getValueType(), V.getOperand(0)),
21858         N->getOperand(1));
21859   }
21860 
21861   if (SDValue NarrowBOp = narrowExtractedVectorBinOp(N, DAG, LegalOperations))
21862     return NarrowBOp;
21863 
21864   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
21865     return SDValue(N, 0);
21866 
21867   return SDValue();
21868 }
21869 
/// Try to convert a wide shuffle of concatenated vectors into 2 narrow shuffles
/// followed by concatenation. Narrow vector ops may have better performance
/// than wide ops, and this can unlock further narrowing of other vector ops.
/// Targets can invert this transform later if it is not profitable.
static SDValue foldShuffleOfConcatUndefs(ShuffleVectorSDNode *Shuf,
                                         SelectionDAG &DAG) {
  SDValue N0 = Shuf->getOperand(0), N1 = Shuf->getOperand(1);
  // Match only the simple pattern where both operands are 2-operand concats
  // whose upper (second) subvector is undef:
  //   shuffle (concat X, undef), (concat Y, undef), Mask
  if (N0.getOpcode() != ISD::CONCAT_VECTORS || N0.getNumOperands() != 2 ||
      N1.getOpcode() != ISD::CONCAT_VECTORS || N1.getNumOperands() != 2 ||
      !N0.getOperand(1).isUndef() || !N1.getOperand(1).isUndef())
    return SDValue();

  // Split the wide shuffle mask into halves. Any mask element that is accessing
  // operand 1 is offset down to account for narrowing of the vectors.
  ArrayRef<int> Mask = Shuf->getMask();
  EVT VT = Shuf->getValueType(0);
  unsigned NumElts = VT.getVectorNumElements();
  unsigned HalfNumElts = NumElts / 2;
  SmallVector<int, 16> Mask0(HalfNumElts, -1);
  SmallVector<int, 16> Mask1(HalfNumElts, -1);
  for (unsigned i = 0; i != NumElts; ++i) {
    if (Mask[i] == -1)
      continue;
    // If we reference the upper (undef) subvector then the element is undef.
    // (Mask[i] % NumElts) maps indices from either operand into one operand's
    // lane numbering so the check covers both concats.
    if ((Mask[i] % NumElts) >= HalfNumElts)
      continue;
    // Elements selecting from N1 (index >= NumElts) are offset down by
    // HalfNumElts, not NumElts, because the narrow shuffle's second operand
    // (Y) starts right after X's HalfNumElts lanes.
    int M = Mask[i] < (int)NumElts ? Mask[i] : Mask[i] - (int)HalfNumElts;
    if (i < HalfNumElts)
      Mask0[i] = M;
    else
      Mask1[i - HalfNumElts] = M;
  }

  // Ask the target if this is a valid transform.
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  EVT HalfVT = EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(),
                                HalfNumElts);
  if (!TLI.isShuffleMaskLegal(Mask0, HalfVT) ||
      !TLI.isShuffleMaskLegal(Mask1, HalfVT))
    return SDValue();

  // shuffle (concat X, undef), (concat Y, undef), Mask -->
  // concat (shuffle X, Y, Mask0), (shuffle X, Y, Mask1)
  SDValue X = N0.getOperand(0), Y = N1.getOperand(0);
  SDLoc DL(Shuf);
  SDValue Shuf0 = DAG.getVectorShuffle(HalfVT, DL, X, Y, Mask0);
  SDValue Shuf1 = DAG.getVectorShuffle(HalfVT, DL, X, Y, Mask1);
  return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Shuf0, Shuf1);
}
21919 
// Tries to turn a shuffle of two CONCAT_VECTORS into a single concat,
// or turn a shuffle of a single concat into simpler shuffle then concat.
// Returns SDValue() if the shuffle does not decompose into whole-subvector
// copies from the concatenated operands.
static SDValue partitionShuffleOfConcats(SDNode *N, SelectionDAG &DAG) {
  EVT VT = N->getValueType(0);
  unsigned NumElts = VT.getVectorNumElements();

  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N);
  ArrayRef<int> Mask = SVN->getMask();

  SmallVector<SDValue, 4> Ops;
  // All concat operands have the same subvector type; use operand 0's type.
  EVT ConcatVT = N0.getOperand(0).getValueType();
  unsigned NumElemsPerConcat = ConcatVT.getVectorNumElements();
  unsigned NumConcats = NumElts / NumElemsPerConcat;

  auto IsUndefMaskElt = [](int i) { return i == -1; };

  // Special case: shuffle(concat(A,B)) can be more efficiently represented
  // as concat(shuffle(A,B),UNDEF) if the shuffle doesn't set any of the high
  // half vector elements.
  if (NumElemsPerConcat * 2 == NumElts && N1.isUndef() &&
      llvm::all_of(Mask.slice(NumElemsPerConcat, NumElemsPerConcat),
                   IsUndefMaskElt)) {
    N0 = DAG.getVectorShuffle(ConcatVT, SDLoc(N), N0.getOperand(0),
                              N0.getOperand(1),
                              Mask.slice(0, NumElemsPerConcat));
    N1 = DAG.getUNDEF(ConcatVT);
    return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, N0, N1);
  }

  // Look at every vector that's inserted. We're looking for exact
  // subvector-sized copies from a concatenated vector
  for (unsigned I = 0; I != NumConcats; ++I) {
    unsigned Begin = I * NumElemsPerConcat;
    ArrayRef<int> SubMask = Mask.slice(Begin, NumElemsPerConcat);

    // Make sure we're dealing with a copy.
    if (llvm::all_of(SubMask, IsUndefMaskElt)) {
      // Entirely-undef group -> keep an undef subvector in its place.
      Ops.push_back(DAG.getUNDEF(ConcatVT));
      continue;
    }

    // Determine which single concat operand this group copies from; every
    // defined mask element must be an identity lane within that operand.
    int OpIdx = -1;
    for (int i = 0; i != (int)NumElemsPerConcat; ++i) {
      if (IsUndefMaskElt(SubMask[i]))
        continue;
      // Element must stay in its lane within the subvector (identity copy).
      if ((SubMask[i] % (int)NumElemsPerConcat) != i)
        return SDValue();
      int EltOpIdx = SubMask[i] / NumElemsPerConcat;
      // All defined elements of the group must come from the same operand.
      if (0 <= OpIdx && EltOpIdx != OpIdx)
        return SDValue();
      OpIdx = EltOpIdx;
    }
    assert(0 <= OpIdx && "Unknown concat_vectors op");

    // Operand indices < N0's operand count come from N0; the rest from N1.
    if (OpIdx < (int)N0.getNumOperands())
      Ops.push_back(N0.getOperand(OpIdx));
    else
      Ops.push_back(N1.getOperand(OpIdx - N0.getNumOperands()));
  }

  return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
}
21984 
// Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
// BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
//
// SHUFFLE(BUILD_VECTOR(), BUILD_VECTOR()) -> BUILD_VECTOR() is always
// a simplification in some sense, but it isn't appropriate in general: some
// BUILD_VECTORs are substantially cheaper than others. The general case
// of a BUILD_VECTOR requires inserting each element individually (or
// performing the equivalent in a temporary stack variable). A BUILD_VECTOR of
// all constants is a single constant pool load.  A BUILD_VECTOR where each
// element is identical is a splat.  A BUILD_VECTOR where most of the operands
// are undef lowers to a small number of element insertions.
//
// To deal with this, we currently use a bunch of mostly arbitrary heuristics.
// We don't fold shuffles where one side is a non-zero constant, and we don't
// fold shuffles if the resulting (non-splat) BUILD_VECTOR would have duplicate
// non-constant operands. This seems to work out reasonably well in practice.
static SDValue combineShuffleOfScalars(ShuffleVectorSDNode *SVN,
                                       SelectionDAG &DAG,
                                       const TargetLowering &TLI) {
  EVT VT = SVN->getValueType(0);
  unsigned NumElts = VT.getVectorNumElements();
  SDValue N0 = SVN->getOperand(0);
  SDValue N1 = SVN->getOperand(1);

  // Only fold if the source vectors die with this shuffle; otherwise we'd
  // keep both the sources and the new BUILD_VECTOR alive.
  if (!N0->hasOneUse())
    return SDValue();

  // If only one of N1,N2 is constant, bail out if it is not ALL_ZEROS as
  // discussed above.
  if (!N1.isUndef()) {
    if (!N1->hasOneUse())
      return SDValue();

    bool N0AnyConst = isAnyConstantBuildVector(N0);
    bool N1AnyConst = isAnyConstantBuildVector(N1);
    if (N0AnyConst && !N1AnyConst && !ISD::isBuildVectorAllZeros(N0.getNode()))
      return SDValue();
    if (!N0AnyConst && N1AnyConst && !ISD::isBuildVectorAllZeros(N1.getNode()))
      return SDValue();
  }

  // If both inputs are splats of the same value then we can safely merge this
  // to a single BUILD_VECTOR with undef elements based on the shuffle mask.
  bool IsSplat = false;
  auto *BV0 = dyn_cast<BuildVectorSDNode>(N0);
  auto *BV1 = dyn_cast<BuildVectorSDNode>(N1);
  if (BV0 && BV1)
    if (SDValue Splat0 = BV0->getSplatValue())
      IsSplat = (Splat0 == BV1->getSplatValue());

  SmallVector<SDValue, 8> Ops;
  SmallSet<SDValue, 16> DuplicateOps;
  for (int M : SVN->getMask()) {
    // Undef mask elements become undef scalars of the result element type.
    SDValue Op = DAG.getUNDEF(VT.getScalarType());
    if (M >= 0) {
      // Mask values >= NumElts select from N1; normalize to a per-operand
      // element index.
      int Idx = M < (int)NumElts ? M : M - NumElts;
      SDValue &S = (M < (int)NumElts ? N0 : N1);
      if (S.getOpcode() == ISD::BUILD_VECTOR) {
        Op = S.getOperand(Idx);
      } else if (S.getOpcode() == ISD::SCALAR_TO_VECTOR) {
        // SCALAR_TO_VECTOR only defines element 0; other lanes are undef.
        SDValue Op0 = S.getOperand(0);
        Op = Idx == 0 ? Op0 : DAG.getUNDEF(Op0.getValueType());
      } else {
        // Operand can't be combined - bail out.
        return SDValue();
      }
    }

    // Don't duplicate a non-constant BUILD_VECTOR operand unless we're
    // generating a splat; semantically, this is fine, but it's likely to
    // generate low-quality code if the target can't reconstruct an appropriate
    // shuffle.
    if (!Op.isUndef() && !isIntOrFPConstant(Op))
      if (!IsSplat && !DuplicateOps.insert(Op).second)
        return SDValue();

    Ops.push_back(Op);
  }

  // BUILD_VECTOR requires all inputs to be of the same type, find the
  // maximum type and extend them all.
  EVT SVT = VT.getScalarType();
  if (SVT.isInteger())
    for (SDValue &Op : Ops)
      SVT = (SVT.bitsLT(Op.getValueType()) ? Op.getValueType() : SVT);
  if (SVT != VT.getScalarType())
    for (SDValue &Op : Ops)
      // Prefer zext when the target says it's free; otherwise sext, which is
      // the default scalar-extension choice here.
      Op = Op.isUndef() ? DAG.getUNDEF(SVT)
                        : (TLI.isZExtFree(Op.getValueType(), SVT)
                               ? DAG.getZExtOrTrunc(Op, SDLoc(SVN), SVT)
                               : DAG.getSExtOrTrunc(Op, SDLoc(SVN), SVT));
  return DAG.getBuildVector(VT, SDLoc(SVN), Ops);
}
22078 
// Match shuffles that can be converted to any_vector_extend_in_reg.
// This is often generated during legalization.
// e.g. v4i32 <0,u,1,u> -> (v2i64 any_vector_extend_in_reg(v4i32 src))
// TODO Add support for ZERO_EXTEND_VECTOR_INREG when we have a test case.
static SDValue combineShuffleToVectorExtend(ShuffleVectorSDNode *SVN,
                                            SelectionDAG &DAG,
                                            const TargetLowering &TLI,
                                            bool LegalOperations) {
  EVT VT = SVN->getValueType(0);
  bool IsBigEndian = DAG.getDataLayout().isBigEndian();

  // TODO Add support for big-endian when we have a test case.
  if (!VT.isInteger() || IsBigEndian)
    return SDValue();

  unsigned NumElts = VT.getVectorNumElements();
  unsigned EltSizeInBits = VT.getScalarSizeInBits();
  ArrayRef<int> Mask = SVN->getMask();
  SDValue N0 = SVN->getOperand(0);

  // shuffle<0,-1,1,-1> == (v2i64 anyextend_vector_inreg(v4i32))
  // A mask matches an any-extend by 'Scale' when every Scale-th lane holds
  // the next source element and all other lanes are undef (don't-care high
  // bits of the extended element).
  auto isAnyExtend = [&Mask, &NumElts](unsigned Scale) {
    for (unsigned i = 0; i != NumElts; ++i) {
      if (Mask[i] < 0)
        continue;
      if ((i % Scale) == 0 && Mask[i] == (int)(i / Scale))
        continue;
      return false;
    }
    return true;
  };

  // Attempt to match a '*_extend_vector_inreg' shuffle, we just search for
  // power-of-2 extensions as they are the most likely.
  for (unsigned Scale = 2; Scale < NumElts; Scale *= 2) {
    // Check for non power of 2 vector sizes
    if (NumElts % Scale != 0)
      continue;
    if (!isAnyExtend(Scale))
      continue;

    EVT OutSVT = EVT::getIntegerVT(*DAG.getContext(), EltSizeInBits * Scale);
    EVT OutVT = EVT::getVectorVT(*DAG.getContext(), OutSVT, NumElts / Scale);
    // Never create an illegal type. Only create unsupported operations if we
    // are pre-legalization.
    if (TLI.isTypeLegal(OutVT))
      if (!LegalOperations ||
          TLI.isOperationLegalOrCustom(ISD::ANY_EXTEND_VECTOR_INREG, OutVT))
        return DAG.getBitcast(VT,
                              DAG.getNode(ISD::ANY_EXTEND_VECTOR_INREG,
                                          SDLoc(SVN), OutVT, N0));
  }

  return SDValue();
}
22134 
// Detect 'truncate_vector_inreg' style shuffles that pack the lower parts of
// each source element of a large type into the lowest elements of a smaller
// destination type. This is often generated during legalization.
// If the source node itself was a '*_extend_vector_inreg' node then we should
// then be able to remove it.
static SDValue combineTruncationShuffle(ShuffleVectorSDNode *SVN,
                                        SelectionDAG &DAG) {
  EVT VT = SVN->getValueType(0);
  bool IsBigEndian = DAG.getDataLayout().isBigEndian();

  // TODO Add support for big-endian when we have a test case.
  if (!VT.isInteger() || IsBigEndian)
    return SDValue();

  // Look through bitcasts to find the extend node feeding the shuffle.
  SDValue N0 = peekThroughBitcasts(SVN->getOperand(0));

  unsigned Opcode = N0.getOpcode();
  if (Opcode != ISD::ANY_EXTEND_VECTOR_INREG &&
      Opcode != ISD::SIGN_EXTEND_VECTOR_INREG &&
      Opcode != ISD::ZERO_EXTEND_VECTOR_INREG)
    return SDValue();

  SDValue N00 = N0.getOperand(0);
  ArrayRef<int> Mask = SVN->getMask();
  unsigned NumElts = VT.getVectorNumElements();
  unsigned EltSizeInBits = VT.getScalarSizeInBits();
  unsigned ExtSrcSizeInBits = N00.getScalarValueSizeInBits();
  unsigned ExtDstSizeInBits = N0.getScalarValueSizeInBits();

  // The extend must widen by an integral factor for the lane math below.
  if (ExtDstSizeInBits % ExtSrcSizeInBits != 0)
    return SDValue();
  unsigned ExtScale = ExtDstSizeInBits / ExtSrcSizeInBits;

  // (v4i32 truncate_vector_inreg(v2i64)) == shuffle<0,2-1,-1>
  // (v8i16 truncate_vector_inreg(v4i32)) == shuffle<0,2,4,6,-1,-1,-1,-1>
  // (v8i16 truncate_vector_inreg(v2i64)) == shuffle<0,4,-1,-1,-1,-1,-1,-1>
  // A mask matches a truncation by 'Scale' when lane i picks source element
  // i*Scale (the low part of each wide element) or is undef.
  auto isTruncate = [&Mask, &NumElts](unsigned Scale) {
    for (unsigned i = 0; i != NumElts; ++i) {
      if (Mask[i] < 0)
        continue;
      if ((i * Scale) < NumElts && Mask[i] == (int)(i * Scale))
        continue;
      return false;
    }
    return true;
  };

  // At the moment we just handle the case where we've truncated back to the
  // same size as before the extension.
  // TODO: handle more extension/truncation cases as cases arise.
  if (EltSizeInBits != ExtSrcSizeInBits)
    return SDValue();

  // We can remove *extend_vector_inreg only if the truncation happens at
  // the same scale as the extension.
  if (isTruncate(ExtScale))
    return DAG.getBitcast(VT, N00);

  return SDValue();
}
22195 
// Combine shuffles of splat-shuffles of the form:
// shuffle (shuffle V, undef, splat-mask), undef, M
// If splat-mask contains undef elements, we need to be careful about
// introducing undef's in the folded mask which are not the result of composing
// the masks of the shuffles.
static SDValue combineShuffleOfSplatVal(ShuffleVectorSDNode *Shuf,
                                        SelectionDAG &DAG) {
  // Only unary shuffles are handled.
  if (!Shuf->getOperand(1).isUndef())
    return SDValue();

  // If the inner operand is a known splat with no undefs, just return that
  // directly: any shuffle of a full splat is the splat itself.
  // TODO: Create DemandedElts mask from Shuf's mask.
  // TODO: Allow undef elements and merge with the shuffle code below.
  if (DAG.isSplatValue(Shuf->getOperand(0), /*AllowUndefs*/ false))
    return Shuf->getOperand(0);

  auto *Splat = dyn_cast<ShuffleVectorSDNode>(Shuf->getOperand(0));
  if (!Splat || !Splat->isSplat())
    return SDValue();

  ArrayRef<int> ShufMask = Shuf->getMask();
  ArrayRef<int> SplatMask = Splat->getMask();
  assert(ShufMask.size() == SplatMask.size() && "Mask length mismatch");

  // Prefer simplifying to the splat-shuffle, if possible. This is legal if
  // every undef mask element in the splat-shuffle has a corresponding undef
  // element in the user-shuffle's mask or if the composition of mask elements
  // would result in undef.
  // Examples for (shuffle (shuffle v, undef, SplatMask), undef, UserMask):
  // * UserMask=[0,2,u,u], SplatMask=[2,u,2,u] -> [2,2,u,u]
  //   In this case it is not legal to simplify to the splat-shuffle because we
  //   may be exposing the users of the shuffle an undef element at index 1
  //   which was not there before the combine.
  // * UserMask=[0,u,2,u], SplatMask=[2,u,2,u] -> [2,u,2,u]
  //   In this case the composition of masks yields SplatMask, so it's ok to
  //   simplify to the splat-shuffle.
  // * UserMask=[3,u,2,u], SplatMask=[2,u,2,u] -> [u,u,2,u]
  //   In this case the composed mask includes all undef elements of SplatMask
  //   and in addition sets element zero to undef. It is safe to simplify to
  //   the splat-shuffle.
  auto CanSimplifyToExistingSplat = [](ArrayRef<int> UserMask,
                                       ArrayRef<int> SplatMask) {
    // Reject only lanes where the user demands a value (UserMask[i] != -1),
    // the splat-shuffle leaves that lane undef (SplatMask[i] == -1), and the
    // composed mask would NOT be undef (SplatMask[UserMask[i]] != -1).
    for (unsigned i = 0, e = UserMask.size(); i != e; ++i)
      if (UserMask[i] != -1 && SplatMask[i] == -1 &&
          SplatMask[UserMask[i]] != -1)
        return false;
    return true;
  };
  if (CanSimplifyToExistingSplat(ShufMask, SplatMask))
    return Shuf->getOperand(0);

  // Create a new shuffle with a mask that is composed of the two shuffles'
  // masks.
  SmallVector<int, 32> NewMask;
  for (int Idx : ShufMask)
    NewMask.push_back(Idx == -1 ? -1 : SplatMask[Idx]);

  return DAG.getVectorShuffle(Splat->getValueType(0), SDLoc(Splat),
                              Splat->getOperand(0), Splat->getOperand(1),
                              NewMask);
}
22257 
// Combine shuffles of bitcasts into a shuffle of the bitcast type, providing
// the mask can be treated as a larger type.
static SDValue combineShuffleOfBitcast(ShuffleVectorSDNode *SVN,
                                       SelectionDAG &DAG,
                                       const TargetLowering &TLI,
                                       bool LegalOperations) {
  SDValue Op0 = SVN->getOperand(0);
  SDValue Op1 = SVN->getOperand(1);
  EVT VT = SVN->getValueType(0);
  if (Op0.getOpcode() != ISD::BITCAST)
    return SDValue();
  EVT InVT = Op0.getOperand(0).getValueType();
  // Op1 must either be undef or a bitcast from the same pre-bitcast type so
  // both shuffle operands share one wider element type.
  if (!InVT.isVector() ||
      (!Op1.isUndef() && (Op1.getOpcode() != ISD::BITCAST ||
                          Op1.getOperand(0).getValueType() != InVT)))
    return SDValue();
  // Don't pull shuffles of constant build vectors through bitcasts; constant
  // folding handles those better.
  if (isAnyConstantBuildVector(Op0.getOperand(0)) &&
      (Op1.isUndef() || isAnyConstantBuildVector(Op1.getOperand(0))))
    return SDValue();

  // Only handle narrowing bitcasts (more lanes after the cast) with a whole
  // multiple of lanes per wider element.
  int VTLanes = VT.getVectorNumElements();
  int InLanes = InVT.getVectorNumElements();
  if (VTLanes <= InLanes || VTLanes % InLanes != 0 ||
      (LegalOperations &&
       !TLI.isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, InVT)))
    return SDValue();
  int Factor = VTLanes / InLanes;

  // Check that each group of lanes in the mask are either undef or make a valid
  // mask for the wider lane type.
  ArrayRef<int> Mask = SVN->getMask();
  SmallVector<int> NewMask;
  if (!widenShuffleMaskElts(Factor, Mask, NewMask))
    return SDValue();

  if (!TLI.isShuffleMaskLegal(NewMask, InVT))
    return SDValue();

  // Create the new shuffle with the new mask and bitcast it back to the
  // original type.
  SDLoc DL(SVN);
  Op0 = Op0.getOperand(0);
  Op1 = Op1.isUndef() ? DAG.getUNDEF(InVT) : Op1.getOperand(0);
  SDValue NewShuf = DAG.getVectorShuffle(InVT, DL, Op0, Op1, NewMask);
  return DAG.getBitcast(VT, NewShuf);
}
22304 
/// Combine shuffle of shuffle of the form:
/// shuf (shuf X, undef, InnerMask), undef, OuterMask --> splat X
static SDValue formSplatFromShuffles(ShuffleVectorSDNode *OuterShuf,
                                     SelectionDAG &DAG) {
  // Both shuffles must be unary (second operand undef).
  if (!OuterShuf->getOperand(1).isUndef())
    return SDValue();
  auto *InnerShuf = dyn_cast<ShuffleVectorSDNode>(OuterShuf->getOperand(0));
  if (!InnerShuf || !InnerShuf->getOperand(1).isUndef())
    return SDValue();

  ArrayRef<int> OuterMask = OuterShuf->getMask();
  ArrayRef<int> InnerMask = InnerShuf->getMask();
  unsigned NumElts = OuterMask.size();
  assert(NumElts == InnerMask.size() && "Mask length mismatch");
  SmallVector<int, 32> CombinedMask(NumElts, -1);
  // Source element index that every defined lane must resolve to.
  int SplatIndex = -1;
  for (unsigned i = 0; i != NumElts; ++i) {
    // Undef lanes remain undef.
    int OuterMaskElt = OuterMask[i];
    if (OuterMaskElt == -1)
      continue;

    // Peek through the shuffle masks to get the underlying source element.
    int InnerMaskElt = InnerMask[OuterMaskElt];
    if (InnerMaskElt == -1)
      continue;

    // Initialize the splatted element.
    if (SplatIndex == -1)
      SplatIndex = InnerMaskElt;

    // Non-matching index - this is not a splat.
    if (SplatIndex != InnerMaskElt)
      return SDValue();

    CombinedMask[i] = InnerMaskElt;
  }
  assert((all_of(CombinedMask, [](int M) { return M == -1; }) ||
          getSplatIndex(CombinedMask) != -1) &&
         "Expected a splat mask");

  // TODO: The transform may be a win even if the mask is not legal.
  EVT VT = OuterShuf->getValueType(0);
  assert(VT == InnerShuf->getValueType(0) && "Expected matching shuffle types");
  if (!DAG.getTargetLoweringInfo().isShuffleMaskLegal(CombinedMask, VT))
    return SDValue();

  // Build the splat directly on the inner shuffle's operands, replacing both
  // shuffles with one.
  return DAG.getVectorShuffle(VT, SDLoc(OuterShuf), InnerShuf->getOperand(0),
                              InnerShuf->getOperand(1), CombinedMask);
}
22355 
22356 /// If the shuffle mask is taking exactly one element from the first vector
22357 /// operand and passing through all other elements from the second vector
22358 /// operand, return the index of the mask element that is choosing an element
22359 /// from the first operand. Otherwise, return -1.
getShuffleMaskIndexOfOneElementFromOp0IntoOp1(ArrayRef<int> Mask)22360 static int getShuffleMaskIndexOfOneElementFromOp0IntoOp1(ArrayRef<int> Mask) {
22361   int MaskSize = Mask.size();
22362   int EltFromOp0 = -1;
22363   // TODO: This does not match if there are undef elements in the shuffle mask.
22364   // Should we ignore undefs in the shuffle mask instead? The trade-off is
22365   // removing an instruction (a shuffle), but losing the knowledge that some
22366   // vector lanes are not needed.
22367   for (int i = 0; i != MaskSize; ++i) {
22368     if (Mask[i] >= 0 && Mask[i] < MaskSize) {
22369       // We're looking for a shuffle of exactly one element from operand 0.
22370       if (EltFromOp0 != -1)
22371         return -1;
22372       EltFromOp0 = i;
22373     } else if (Mask[i] != i + MaskSize) {
22374       // Nothing from operand 1 can change lanes.
22375       return -1;
22376     }
22377   }
22378   return EltFromOp0;
22379 }
22380 
/// If a shuffle inserts exactly one element from a source vector operand into
/// another vector operand and we can access the specified element as a scalar,
/// then we can eliminate the shuffle.
static SDValue replaceShuffleOfInsert(ShuffleVectorSDNode *Shuf,
                                      SelectionDAG &DAG) {
  // First, check if we are taking one element of a vector and shuffling that
  // element into another vector.
  ArrayRef<int> Mask = Shuf->getMask();
  SmallVector<int, 16> CommutedMask(Mask.begin(), Mask.end());
  SDValue Op0 = Shuf->getOperand(0);
  SDValue Op1 = Shuf->getOperand(1);
  int ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(Mask);
  if (ShufOp0Index == -1) {
    // Commute mask and check again.
    ShuffleVectorSDNode::commuteMask(CommutedMask);
    ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(CommutedMask);
    if (ShufOp0Index == -1)
      return SDValue();
    // Commute operands to match the commuted shuffle mask.
    std::swap(Op0, Op1);
    Mask = CommutedMask;
  }

  // The shuffle inserts exactly one element from operand 0 into operand 1.
  // Now see if we can access that element as a scalar via a real insert element
  // instruction.
  // TODO: We can try harder to locate the element as a scalar. Examples: it
  // could be an operand of SCALAR_TO_VECTOR, BUILD_VECTOR, or a constant.
  assert(Mask[ShufOp0Index] >= 0 && Mask[ShufOp0Index] < (int)Mask.size() &&
         "Shuffle mask value must be from operand 0");
  if (Op0.getOpcode() != ISD::INSERT_VECTOR_ELT)
    return SDValue();

  // The insert index must be constant and must match the element the shuffle
  // mask reads from operand 0, i.e. the shuffle selects the inserted scalar.
  auto *InsIndexC = dyn_cast<ConstantSDNode>(Op0.getOperand(2));
  if (!InsIndexC || InsIndexC->getSExtValue() != Mask[ShufOp0Index])
    return SDValue();

  // There's an existing insertelement with constant insertion index, so we
  // don't need to check the legality/profitability of a replacement operation
  // that differs at most in the constant value. The target should be able to
  // lower any of those in a similar way. If not, legalization will expand this
  // to a scalar-to-vector plus shuffle.
  //
  // Note that the shuffle may move the scalar from the position that the insert
  // element used. Therefore, our new insert element occurs at the shuffle's
  // mask index value, not the insert's index value.
  // shuffle (insertelt v1, x, C), v2, mask --> insertelt v2, x, C'
  SDValue NewInsIndex = DAG.getVectorIdxConstant(ShufOp0Index, SDLoc(Shuf));
  return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(Shuf), Op0.getValueType(),
                     Op1, Op0.getOperand(1), NewInsIndex);
}
22432 
22433 /// If we have a unary shuffle of a shuffle, see if it can be folded away
22434 /// completely. This has the potential to lose undef knowledge because the first
22435 /// shuffle may not have an undef mask element where the second one does. So
22436 /// only call this after doing simplifications based on demanded elements.
simplifyShuffleOfShuffle(ShuffleVectorSDNode * Shuf)22437 static SDValue simplifyShuffleOfShuffle(ShuffleVectorSDNode *Shuf) {
22438   // shuf (shuf0 X, Y, Mask0), undef, Mask
22439   auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(Shuf->getOperand(0));
22440   if (!Shuf0 || !Shuf->getOperand(1).isUndef())
22441     return SDValue();
22442 
22443   ArrayRef<int> Mask = Shuf->getMask();
22444   ArrayRef<int> Mask0 = Shuf0->getMask();
22445   for (int i = 0, e = (int)Mask.size(); i != e; ++i) {
22446     // Ignore undef elements.
22447     if (Mask[i] == -1)
22448       continue;
22449     assert(Mask[i] >= 0 && Mask[i] < e && "Unexpected shuffle mask value");
22450 
22451     // Is the element of the shuffle operand chosen by this shuffle the same as
22452     // the element chosen by the shuffle operand itself?
22453     if (Mask0[Mask[i]] != Mask0[i])
22454       return SDValue();
22455   }
22456   // Every element of this shuffle is identical to the result of the previous
22457   // shuffle, so we can replace this value.
22458   return Shuf->getOperand(0);
22459 }
22460 
visitVECTOR_SHUFFLE(SDNode * N)22461 SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
22462   EVT VT = N->getValueType(0);
22463   unsigned NumElts = VT.getVectorNumElements();
22464 
22465   SDValue N0 = N->getOperand(0);
22466   SDValue N1 = N->getOperand(1);
22467 
22468   assert(N0.getValueType() == VT && "Vector shuffle must be normalized in DAG");
22469 
22470   // Canonicalize shuffle undef, undef -> undef
22471   if (N0.isUndef() && N1.isUndef())
22472     return DAG.getUNDEF(VT);
22473 
22474   ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N);
22475 
22476   // Canonicalize shuffle v, v -> v, undef
22477   if (N0 == N1)
22478     return DAG.getVectorShuffle(VT, SDLoc(N), N0, DAG.getUNDEF(VT),
22479                                 createUnaryMask(SVN->getMask(), NumElts));
22480 
22481   // Canonicalize shuffle undef, v -> v, undef.  Commute the shuffle mask.
22482   if (N0.isUndef())
22483     return DAG.getCommutedVectorShuffle(*SVN);
22484 
22485   // Remove references to rhs if it is undef
22486   if (N1.isUndef()) {
22487     bool Changed = false;
22488     SmallVector<int, 8> NewMask;
22489     for (unsigned i = 0; i != NumElts; ++i) {
22490       int Idx = SVN->getMaskElt(i);
22491       if (Idx >= (int)NumElts) {
22492         Idx = -1;
22493         Changed = true;
22494       }
22495       NewMask.push_back(Idx);
22496     }
22497     if (Changed)
22498       return DAG.getVectorShuffle(VT, SDLoc(N), N0, N1, NewMask);
22499   }
22500 
22501   if (SDValue InsElt = replaceShuffleOfInsert(SVN, DAG))
22502     return InsElt;
22503 
22504   // A shuffle of a single vector that is a splatted value can always be folded.
22505   if (SDValue V = combineShuffleOfSplatVal(SVN, DAG))
22506     return V;
22507 
22508   if (SDValue V = formSplatFromShuffles(SVN, DAG))
22509     return V;
22510 
22511   // If it is a splat, check if the argument vector is another splat or a
22512   // build_vector.
22513   if (SVN->isSplat() && SVN->getSplatIndex() < (int)NumElts) {
22514     int SplatIndex = SVN->getSplatIndex();
22515     if (N0.hasOneUse() && TLI.isExtractVecEltCheap(VT, SplatIndex) &&
22516         TLI.isBinOp(N0.getOpcode()) && N0->getNumValues() == 1) {
22517       // splat (vector_bo L, R), Index -->
22518       // splat (scalar_bo (extelt L, Index), (extelt R, Index))
22519       SDValue L = N0.getOperand(0), R = N0.getOperand(1);
22520       SDLoc DL(N);
22521       EVT EltVT = VT.getScalarType();
22522       SDValue Index = DAG.getVectorIdxConstant(SplatIndex, DL);
22523       SDValue ExtL = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, L, Index);
22524       SDValue ExtR = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, R, Index);
22525       SDValue NewBO =
22526           DAG.getNode(N0.getOpcode(), DL, EltVT, ExtL, ExtR, N0->getFlags());
22527       SDValue Insert = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT, NewBO);
22528       SmallVector<int, 16> ZeroMask(VT.getVectorNumElements(), 0);
22529       return DAG.getVectorShuffle(VT, DL, Insert, DAG.getUNDEF(VT), ZeroMask);
22530     }
22531 
22532     // splat(scalar_to_vector(x), 0) -> build_vector(x,...,x)
22533     // splat(insert_vector_elt(v, x, c), c) -> build_vector(x,...,x)
22534     if ((!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)) &&
22535         N0.hasOneUse()) {
22536       if (N0.getOpcode() == ISD::SCALAR_TO_VECTOR && SplatIndex == 0)
22537         return DAG.getSplatBuildVector(VT, SDLoc(N), N0.getOperand(0));
22538 
22539       if (N0.getOpcode() == ISD::INSERT_VECTOR_ELT)
22540         if (auto *Idx = dyn_cast<ConstantSDNode>(N0.getOperand(2)))
22541           if (Idx->getAPIntValue() == SplatIndex)
22542             return DAG.getSplatBuildVector(VT, SDLoc(N), N0.getOperand(1));
22543     }
22544 
22545     // If this is a bit convert that changes the element type of the vector but
22546     // not the number of vector elements, look through it.  Be careful not to
22547     // look though conversions that change things like v4f32 to v2f64.
22548     SDNode *V = N0.getNode();
22549     if (V->getOpcode() == ISD::BITCAST) {
22550       SDValue ConvInput = V->getOperand(0);
22551       if (ConvInput.getValueType().isVector() &&
22552           ConvInput.getValueType().getVectorNumElements() == NumElts)
22553         V = ConvInput.getNode();
22554     }
22555 
22556     if (V->getOpcode() == ISD::BUILD_VECTOR) {
22557       assert(V->getNumOperands() == NumElts &&
22558              "BUILD_VECTOR has wrong number of operands");
22559       SDValue Base;
22560       bool AllSame = true;
22561       for (unsigned i = 0; i != NumElts; ++i) {
22562         if (!V->getOperand(i).isUndef()) {
22563           Base = V->getOperand(i);
22564           break;
22565         }
22566       }
22567       // Splat of <u, u, u, u>, return <u, u, u, u>
22568       if (!Base.getNode())
22569         return N0;
22570       for (unsigned i = 0; i != NumElts; ++i) {
22571         if (V->getOperand(i) != Base) {
22572           AllSame = false;
22573           break;
22574         }
22575       }
22576       // Splat of <x, x, x, x>, return <x, x, x, x>
22577       if (AllSame)
22578         return N0;
22579 
22580       // Canonicalize any other splat as a build_vector.
22581       SDValue Splatted = V->getOperand(SplatIndex);
22582       SmallVector<SDValue, 8> Ops(NumElts, Splatted);
22583       SDValue NewBV = DAG.getBuildVector(V->getValueType(0), SDLoc(N), Ops);
22584 
22585       // We may have jumped through bitcasts, so the type of the
22586       // BUILD_VECTOR may not match the type of the shuffle.
22587       if (V->getValueType(0) != VT)
22588         NewBV = DAG.getBitcast(VT, NewBV);
22589       return NewBV;
22590     }
22591   }
22592 
22593   // Simplify source operands based on shuffle mask.
22594   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
22595     return SDValue(N, 0);
22596 
22597   // This is intentionally placed after demanded elements simplification because
22598   // it could eliminate knowledge of undef elements created by this shuffle.
22599   if (SDValue ShufOp = simplifyShuffleOfShuffle(SVN))
22600     return ShufOp;
22601 
22602   // Match shuffles that can be converted to any_vector_extend_in_reg.
22603   if (SDValue V = combineShuffleToVectorExtend(SVN, DAG, TLI, LegalOperations))
22604     return V;
22605 
22606   // Combine "truncate_vector_in_reg" style shuffles.
22607   if (SDValue V = combineTruncationShuffle(SVN, DAG))
22608     return V;
22609 
22610   if (N0.getOpcode() == ISD::CONCAT_VECTORS &&
22611       Level < AfterLegalizeVectorOps &&
22612       (N1.isUndef() ||
22613       (N1.getOpcode() == ISD::CONCAT_VECTORS &&
22614        N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType()))) {
22615     if (SDValue V = partitionShuffleOfConcats(N, DAG))
22616       return V;
22617   }
22618 
22619   // A shuffle of a concat of the same narrow vector can be reduced to use
22620   // only low-half elements of a concat with undef:
22621   // shuf (concat X, X), undef, Mask --> shuf (concat X, undef), undef, Mask'
22622   if (N0.getOpcode() == ISD::CONCAT_VECTORS && N1.isUndef() &&
22623       N0.getNumOperands() == 2 &&
22624       N0.getOperand(0) == N0.getOperand(1)) {
22625     int HalfNumElts = (int)NumElts / 2;
22626     SmallVector<int, 8> NewMask;
22627     for (unsigned i = 0; i != NumElts; ++i) {
22628       int Idx = SVN->getMaskElt(i);
22629       if (Idx >= HalfNumElts) {
22630         assert(Idx < (int)NumElts && "Shuffle mask chooses undef op");
22631         Idx -= HalfNumElts;
22632       }
22633       NewMask.push_back(Idx);
22634     }
22635     if (TLI.isShuffleMaskLegal(NewMask, VT)) {
22636       SDValue UndefVec = DAG.getUNDEF(N0.getOperand(0).getValueType());
22637       SDValue NewCat = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT,
22638                                    N0.getOperand(0), UndefVec);
22639       return DAG.getVectorShuffle(VT, SDLoc(N), NewCat, N1, NewMask);
22640     }
22641   }
22642 
22643   // See if we can replace a shuffle with an insert_subvector.
22644   // e.g. v2i32 into v8i32:
22645   // shuffle(lhs,concat(rhs0,rhs1,rhs2,rhs3),0,1,2,3,10,11,6,7).
22646   // --> insert_subvector(lhs,rhs1,4).
22647   if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT) &&
22648       TLI.isOperationLegalOrCustom(ISD::INSERT_SUBVECTOR, VT)) {
22649     auto ShuffleToInsert = [&](SDValue LHS, SDValue RHS, ArrayRef<int> Mask) {
22650       // Ensure RHS subvectors are legal.
22651       assert(RHS.getOpcode() == ISD::CONCAT_VECTORS && "Can't find subvectors");
22652       EVT SubVT = RHS.getOperand(0).getValueType();
22653       int NumSubVecs = RHS.getNumOperands();
22654       int NumSubElts = SubVT.getVectorNumElements();
22655       assert((NumElts % NumSubElts) == 0 && "Subvector mismatch");
22656       if (!TLI.isTypeLegal(SubVT))
22657         return SDValue();
22658 
22659       // Don't bother if we have an unary shuffle (matches undef + LHS elts).
22660       if (all_of(Mask, [NumElts](int M) { return M < (int)NumElts; }))
22661         return SDValue();
22662 
22663       // Search [NumSubElts] spans for RHS sequence.
22664       // TODO: Can we avoid nested loops to increase performance?
22665       SmallVector<int> InsertionMask(NumElts);
22666       for (int SubVec = 0; SubVec != NumSubVecs; ++SubVec) {
22667         for (int SubIdx = 0; SubIdx != (int)NumElts; SubIdx += NumSubElts) {
22668           // Reset mask to identity.
22669           std::iota(InsertionMask.begin(), InsertionMask.end(), 0);
22670 
22671           // Add subvector insertion.
22672           std::iota(InsertionMask.begin() + SubIdx,
22673                     InsertionMask.begin() + SubIdx + NumSubElts,
22674                     NumElts + (SubVec * NumSubElts));
22675 
22676           // See if the shuffle mask matches the reference insertion mask.
22677           bool MatchingShuffle = true;
22678           for (int i = 0; i != (int)NumElts; ++i) {
22679             int ExpectIdx = InsertionMask[i];
22680             int ActualIdx = Mask[i];
22681             if (0 <= ActualIdx && ExpectIdx != ActualIdx) {
22682               MatchingShuffle = false;
22683               break;
22684             }
22685           }
22686 
22687           if (MatchingShuffle)
22688             return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, LHS,
22689                                RHS.getOperand(SubVec),
22690                                DAG.getVectorIdxConstant(SubIdx, SDLoc(N)));
22691         }
22692       }
22693       return SDValue();
22694     };
22695     ArrayRef<int> Mask = SVN->getMask();
22696     if (N1.getOpcode() == ISD::CONCAT_VECTORS)
22697       if (SDValue InsertN1 = ShuffleToInsert(N0, N1, Mask))
22698         return InsertN1;
22699     if (N0.getOpcode() == ISD::CONCAT_VECTORS) {
22700       SmallVector<int> CommuteMask(Mask.begin(), Mask.end());
22701       ShuffleVectorSDNode::commuteMask(CommuteMask);
22702       if (SDValue InsertN0 = ShuffleToInsert(N1, N0, CommuteMask))
22703         return InsertN0;
22704     }
22705   }
22706 
22707   // If we're not performing a select/blend shuffle, see if we can convert the
22708   // shuffle into a AND node, with all the out-of-lane elements are known zero.
22709   if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
22710     bool IsInLaneMask = true;
22711     ArrayRef<int> Mask = SVN->getMask();
22712     SmallVector<int, 16> ClearMask(NumElts, -1);
22713     APInt DemandedLHS = APInt::getNullValue(NumElts);
22714     APInt DemandedRHS = APInt::getNullValue(NumElts);
22715     for (int I = 0; I != (int)NumElts; ++I) {
22716       int M = Mask[I];
22717       if (M < 0)
22718         continue;
22719       ClearMask[I] = M == I ? I : (I + NumElts);
22720       IsInLaneMask &= (M == I) || (M == (int)(I + NumElts));
22721       if (M != I) {
22722         APInt &Demanded = M < (int)NumElts ? DemandedLHS : DemandedRHS;
22723         Demanded.setBit(M % NumElts);
22724       }
22725     }
22726     // TODO: Should we try to mask with N1 as well?
22727     if (!IsInLaneMask &&
22728         (!DemandedLHS.isNullValue() || !DemandedRHS.isNullValue()) &&
22729         (DemandedLHS.isNullValue() ||
22730          DAG.MaskedVectorIsZero(N0, DemandedLHS)) &&
22731         (DemandedRHS.isNullValue() ||
22732          DAG.MaskedVectorIsZero(N1, DemandedRHS))) {
22733       SDLoc DL(N);
22734       EVT IntVT = VT.changeVectorElementTypeToInteger();
22735       EVT IntSVT = VT.getVectorElementType().changeTypeToInteger();
22736       // Transform the type to a legal type so that the buildvector constant
22737       // elements are not illegal. Make sure that the result is larger than the
      // original type, in case the value is split into two (e.g. i64->i32).
22739       if (!TLI.isTypeLegal(IntSVT) && LegalTypes)
22740         IntSVT = TLI.getTypeToTransformTo(*DAG.getContext(), IntSVT);
22741       if (IntSVT.getSizeInBits() >= IntVT.getScalarSizeInBits()) {
22742         SDValue ZeroElt = DAG.getConstant(0, DL, IntSVT);
22743         SDValue AllOnesElt = DAG.getAllOnesConstant(DL, IntSVT);
22744         SmallVector<SDValue, 16> AndMask(NumElts, DAG.getUNDEF(IntSVT));
22745         for (int I = 0; I != (int)NumElts; ++I)
22746           if (0 <= Mask[I])
22747             AndMask[I] = Mask[I] == I ? AllOnesElt : ZeroElt;
22748 
22749         // See if a clear mask is legal instead of going via
22750         // XformToShuffleWithZero which loses UNDEF mask elements.
22751         if (TLI.isVectorClearMaskLegal(ClearMask, IntVT))
22752           return DAG.getBitcast(
22753               VT, DAG.getVectorShuffle(IntVT, DL, DAG.getBitcast(IntVT, N0),
22754                                       DAG.getConstant(0, DL, IntVT), ClearMask));
22755 
22756         if (TLI.isOperationLegalOrCustom(ISD::AND, IntVT))
22757           return DAG.getBitcast(
22758               VT, DAG.getNode(ISD::AND, DL, IntVT, DAG.getBitcast(IntVT, N0),
22759                               DAG.getBuildVector(IntVT, DL, AndMask)));
22760       }
22761     }
22762   }
22763 
22764   // Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
22765   // BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
22766   if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT))
22767     if (SDValue Res = combineShuffleOfScalars(SVN, DAG, TLI))
22768       return Res;
22769 
22770   // If this shuffle only has a single input that is a bitcasted shuffle,
22771   // attempt to merge the 2 shuffles and suitably bitcast the inputs/output
22772   // back to their original types.
22773   if (N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
22774       N1.isUndef() && Level < AfterLegalizeVectorOps &&
22775       TLI.isTypeLegal(VT)) {
22776 
22777     SDValue BC0 = peekThroughOneUseBitcasts(N0);
22778     if (BC0.getOpcode() == ISD::VECTOR_SHUFFLE && BC0.hasOneUse()) {
22779       EVT SVT = VT.getScalarType();
22780       EVT InnerVT = BC0->getValueType(0);
22781       EVT InnerSVT = InnerVT.getScalarType();
22782 
22783       // Determine which shuffle works with the smaller scalar type.
22784       EVT ScaleVT = SVT.bitsLT(InnerSVT) ? VT : InnerVT;
22785       EVT ScaleSVT = ScaleVT.getScalarType();
22786 
22787       if (TLI.isTypeLegal(ScaleVT) &&
22788           0 == (InnerSVT.getSizeInBits() % ScaleSVT.getSizeInBits()) &&
22789           0 == (SVT.getSizeInBits() % ScaleSVT.getSizeInBits())) {
22790         int InnerScale = InnerSVT.getSizeInBits() / ScaleSVT.getSizeInBits();
22791         int OuterScale = SVT.getSizeInBits() / ScaleSVT.getSizeInBits();
22792 
22793         // Scale the shuffle masks to the smaller scalar type.
22794         ShuffleVectorSDNode *InnerSVN = cast<ShuffleVectorSDNode>(BC0);
22795         SmallVector<int, 8> InnerMask;
22796         SmallVector<int, 8> OuterMask;
22797         narrowShuffleMaskElts(InnerScale, InnerSVN->getMask(), InnerMask);
22798         narrowShuffleMaskElts(OuterScale, SVN->getMask(), OuterMask);
22799 
22800         // Merge the shuffle masks.
22801         SmallVector<int, 8> NewMask;
22802         for (int M : OuterMask)
22803           NewMask.push_back(M < 0 ? -1 : InnerMask[M]);
22804 
22805         // Test for shuffle mask legality over both commutations.
22806         SDValue SV0 = BC0->getOperand(0);
22807         SDValue SV1 = BC0->getOperand(1);
22808         bool LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
22809         if (!LegalMask) {
22810           std::swap(SV0, SV1);
22811           ShuffleVectorSDNode::commuteMask(NewMask);
22812           LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
22813         }
22814 
22815         if (LegalMask) {
22816           SV0 = DAG.getBitcast(ScaleVT, SV0);
22817           SV1 = DAG.getBitcast(ScaleVT, SV1);
22818           return DAG.getBitcast(
22819               VT, DAG.getVectorShuffle(ScaleVT, SDLoc(N), SV0, SV1, NewMask));
22820         }
22821       }
22822     }
22823   }
22824 
22825   // Match shuffles of bitcasts, so long as the mask can be treated as the
22826   // larger type.
22827   if (SDValue V = combineShuffleOfBitcast(SVN, DAG, TLI, LegalOperations))
22828     return V;
22829 
22830   // Compute the combined shuffle mask for a shuffle with SV0 as the first
22831   // operand, and SV1 as the second operand.
22832   // i.e. Merge SVN(OtherSVN, N1) -> shuffle(SV0, SV1, Mask) iff Commute = false
22833   //      Merge SVN(N1, OtherSVN) -> shuffle(SV0, SV1, Mask') iff Commute = true
22834   auto MergeInnerShuffle =
22835       [NumElts, &VT](bool Commute, ShuffleVectorSDNode *SVN,
22836                      ShuffleVectorSDNode *OtherSVN, SDValue N1,
22837                      const TargetLowering &TLI, SDValue &SV0, SDValue &SV1,
22838                      SmallVectorImpl<int> &Mask) -> bool {
22839     // Don't try to fold splats; they're likely to simplify somehow, or they
22840     // might be free.
22841     if (OtherSVN->isSplat())
22842       return false;
22843 
22844     SV0 = SV1 = SDValue();
22845     Mask.clear();
22846 
22847     for (unsigned i = 0; i != NumElts; ++i) {
22848       int Idx = SVN->getMaskElt(i);
22849       if (Idx < 0) {
22850         // Propagate Undef.
22851         Mask.push_back(Idx);
22852         continue;
22853       }
22854 
22855       if (Commute)
22856         Idx = (Idx < (int)NumElts) ? (Idx + NumElts) : (Idx - NumElts);
22857 
22858       SDValue CurrentVec;
22859       if (Idx < (int)NumElts) {
22860         // This shuffle index refers to the inner shuffle N0. Lookup the inner
22861         // shuffle mask to identify which vector is actually referenced.
22862         Idx = OtherSVN->getMaskElt(Idx);
22863         if (Idx < 0) {
22864           // Propagate Undef.
22865           Mask.push_back(Idx);
22866           continue;
22867         }
22868         CurrentVec = (Idx < (int)NumElts) ? OtherSVN->getOperand(0)
22869                                           : OtherSVN->getOperand(1);
22870       } else {
22871         // This shuffle index references an element within N1.
22872         CurrentVec = N1;
22873       }
22874 
22875       // Simple case where 'CurrentVec' is UNDEF.
22876       if (CurrentVec.isUndef()) {
22877         Mask.push_back(-1);
22878         continue;
22879       }
22880 
22881       // Canonicalize the shuffle index. We don't know yet if CurrentVec
22882       // will be the first or second operand of the combined shuffle.
22883       Idx = Idx % NumElts;
22884       if (!SV0.getNode() || SV0 == CurrentVec) {
22885         // Ok. CurrentVec is the left hand side.
22886         // Update the mask accordingly.
22887         SV0 = CurrentVec;
22888         Mask.push_back(Idx);
22889         continue;
22890       }
22891       if (!SV1.getNode() || SV1 == CurrentVec) {
22892         // Ok. CurrentVec is the right hand side.
22893         // Update the mask accordingly.
22894         SV1 = CurrentVec;
22895         Mask.push_back(Idx + NumElts);
22896         continue;
22897       }
22898 
22899       // Last chance - see if the vector is another shuffle and if it
22900       // uses one of the existing candidate shuffle ops.
22901       if (auto *CurrentSVN = dyn_cast<ShuffleVectorSDNode>(CurrentVec)) {
22902         int InnerIdx = CurrentSVN->getMaskElt(Idx);
22903         if (InnerIdx < 0) {
22904           Mask.push_back(-1);
22905           continue;
22906         }
22907         SDValue InnerVec = (InnerIdx < (int)NumElts)
22908                                ? CurrentSVN->getOperand(0)
22909                                : CurrentSVN->getOperand(1);
22910         if (InnerVec.isUndef()) {
22911           Mask.push_back(-1);
22912           continue;
22913         }
22914         InnerIdx %= NumElts;
22915         if (InnerVec == SV0) {
22916           Mask.push_back(InnerIdx);
22917           continue;
22918         }
22919         if (InnerVec == SV1) {
22920           Mask.push_back(InnerIdx + NumElts);
22921           continue;
22922         }
22923       }
22924 
22925       // Bail out if we cannot convert the shuffle pair into a single shuffle.
22926       return false;
22927     }
22928 
22929     if (llvm::all_of(Mask, [](int M) { return M < 0; }))
22930       return true;
22931 
22932     // Avoid introducing shuffles with illegal mask.
22933     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
22934     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
22935     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
22936     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, A, M2)
22937     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, A, M2)
22938     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, B, M2)
22939     if (TLI.isShuffleMaskLegal(Mask, VT))
22940       return true;
22941 
22942     std::swap(SV0, SV1);
22943     ShuffleVectorSDNode::commuteMask(Mask);
22944     return TLI.isShuffleMaskLegal(Mask, VT);
22945   };
22946 
22947   if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
22948     // Canonicalize shuffles according to rules:
22949     //  shuffle(A, shuffle(A, B)) -> shuffle(shuffle(A,B), A)
22950     //  shuffle(B, shuffle(A, B)) -> shuffle(shuffle(A,B), B)
22951     //  shuffle(B, shuffle(A, Undef)) -> shuffle(shuffle(A, Undef), B)
22952     if (N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
22953         N0.getOpcode() != ISD::VECTOR_SHUFFLE) {
22954       // The incoming shuffle must be of the same type as the result of the
22955       // current shuffle.
22956       assert(N1->getOperand(0).getValueType() == VT &&
22957              "Shuffle types don't match");
22958 
22959       SDValue SV0 = N1->getOperand(0);
22960       SDValue SV1 = N1->getOperand(1);
22961       bool HasSameOp0 = N0 == SV0;
22962       bool IsSV1Undef = SV1.isUndef();
22963       if (HasSameOp0 || IsSV1Undef || N0 == SV1)
22964         // Commute the operands of this shuffle so merging below will trigger.
22965         return DAG.getCommutedVectorShuffle(*SVN);
22966     }
22967 
22968     // Canonicalize splat shuffles to the RHS to improve merging below.
22969     //  shuffle(splat(A,u), shuffle(C,D)) -> shuffle'(shuffle(C,D), splat(A,u))
22970     if (N0.getOpcode() == ISD::VECTOR_SHUFFLE &&
22971         N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
22972         cast<ShuffleVectorSDNode>(N0)->isSplat() &&
22973         !cast<ShuffleVectorSDNode>(N1)->isSplat()) {
22974       return DAG.getCommutedVectorShuffle(*SVN);
22975     }
22976 
22977     // Try to fold according to rules:
22978     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
22979     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
22980     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
22981     // Don't try to fold shuffles with illegal type.
22982     // Only fold if this shuffle is the only user of the other shuffle.
    // Try matching shuffle(C,shuffle(A,B)) commuted patterns as well.
22984     for (int i = 0; i != 2; ++i) {
22985       if (N->getOperand(i).getOpcode() == ISD::VECTOR_SHUFFLE &&
22986           N->isOnlyUserOf(N->getOperand(i).getNode())) {
22987         // The incoming shuffle must be of the same type as the result of the
22988         // current shuffle.
22989         auto *OtherSV = cast<ShuffleVectorSDNode>(N->getOperand(i));
22990         assert(OtherSV->getOperand(0).getValueType() == VT &&
22991                "Shuffle types don't match");
22992 
22993         SDValue SV0, SV1;
22994         SmallVector<int, 4> Mask;
22995         if (MergeInnerShuffle(i != 0, SVN, OtherSV, N->getOperand(1 - i), TLI,
22996                               SV0, SV1, Mask)) {
22997           // Check if all indices in Mask are Undef. In case, propagate Undef.
22998           if (llvm::all_of(Mask, [](int M) { return M < 0; }))
22999             return DAG.getUNDEF(VT);
23000 
23001           return DAG.getVectorShuffle(VT, SDLoc(N),
23002                                       SV0 ? SV0 : DAG.getUNDEF(VT),
23003                                       SV1 ? SV1 : DAG.getUNDEF(VT), Mask);
23004         }
23005       }
23006     }
23007 
23008     // Merge shuffles through binops if we are able to merge it with at least
23009     // one other shuffles.
23010     // shuffle(bop(shuffle(x,y),shuffle(z,w)),undef)
23011     // shuffle(bop(shuffle(x,y),shuffle(z,w)),bop(shuffle(a,b),shuffle(c,d)))
23012     unsigned SrcOpcode = N0.getOpcode();
23013     if (TLI.isBinOp(SrcOpcode) && N->isOnlyUserOf(N0.getNode()) &&
23014         (N1.isUndef() ||
23015          (SrcOpcode == N1.getOpcode() && N->isOnlyUserOf(N1.getNode())))) {
23016       // Get binop source ops, or just pass on the undef.
23017       SDValue Op00 = N0.getOperand(0);
23018       SDValue Op01 = N0.getOperand(1);
23019       SDValue Op10 = N1.isUndef() ? N1 : N1.getOperand(0);
23020       SDValue Op11 = N1.isUndef() ? N1 : N1.getOperand(1);
23021       // TODO: We might be able to relax the VT check but we don't currently
23022       // have any isBinOp() that has different result/ops VTs so play safe until
23023       // we have test coverage.
23024       if (Op00.getValueType() == VT && Op10.getValueType() == VT &&
23025           Op01.getValueType() == VT && Op11.getValueType() == VT &&
23026           (Op00.getOpcode() == ISD::VECTOR_SHUFFLE ||
23027            Op10.getOpcode() == ISD::VECTOR_SHUFFLE ||
23028            Op01.getOpcode() == ISD::VECTOR_SHUFFLE ||
23029            Op11.getOpcode() == ISD::VECTOR_SHUFFLE)) {
23030         auto CanMergeInnerShuffle = [&](SDValue &SV0, SDValue &SV1,
23031                                         SmallVectorImpl<int> &Mask, bool LeftOp,
23032                                         bool Commute) {
23033           SDValue InnerN = Commute ? N1 : N0;
23034           SDValue Op0 = LeftOp ? Op00 : Op01;
23035           SDValue Op1 = LeftOp ? Op10 : Op11;
23036           if (Commute)
23037             std::swap(Op0, Op1);
23038           // Only accept the merged shuffle if we don't introduce undef elements,
23039           // or the inner shuffle already contained undef elements.
23040           auto *SVN0 = dyn_cast<ShuffleVectorSDNode>(Op0);
23041           return SVN0 && InnerN->isOnlyUserOf(SVN0) &&
23042                  MergeInnerShuffle(Commute, SVN, SVN0, Op1, TLI, SV0, SV1,
23043                                    Mask) &&
23044                  (llvm::any_of(SVN0->getMask(), [](int M) { return M < 0; }) ||
23045                   llvm::none_of(Mask, [](int M) { return M < 0; }));
23046         };
23047 
23048         // Ensure we don't increase the number of shuffles - we must merge a
23049         // shuffle from at least one of the LHS and RHS ops.
23050         bool MergedLeft = false;
23051         SDValue LeftSV0, LeftSV1;
23052         SmallVector<int, 4> LeftMask;
23053         if (CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, false) ||
23054             CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, true)) {
23055           MergedLeft = true;
23056         } else {
23057           LeftMask.assign(SVN->getMask().begin(), SVN->getMask().end());
23058           LeftSV0 = Op00, LeftSV1 = Op10;
23059         }
23060 
23061         bool MergedRight = false;
23062         SDValue RightSV0, RightSV1;
23063         SmallVector<int, 4> RightMask;
23064         if (CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, false) ||
23065             CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, true)) {
23066           MergedRight = true;
23067         } else {
23068           RightMask.assign(SVN->getMask().begin(), SVN->getMask().end());
23069           RightSV0 = Op01, RightSV1 = Op11;
23070         }
23071 
23072         if (MergedLeft || MergedRight) {
23073           SDLoc DL(N);
23074           SDValue LHS = DAG.getVectorShuffle(
23075               VT, DL, LeftSV0 ? LeftSV0 : DAG.getUNDEF(VT),
23076               LeftSV1 ? LeftSV1 : DAG.getUNDEF(VT), LeftMask);
23077           SDValue RHS = DAG.getVectorShuffle(
23078               VT, DL, RightSV0 ? RightSV0 : DAG.getUNDEF(VT),
23079               RightSV1 ? RightSV1 : DAG.getUNDEF(VT), RightMask);
23080           return DAG.getNode(SrcOpcode, DL, VT, LHS, RHS);
23081         }
23082       }
23083     }
23084   }
23085 
23086   if (SDValue V = foldShuffleOfConcatUndefs(SVN, DAG))
23087     return V;
23088 
23089   return SDValue();
23090 }
23091 
visitSCALAR_TO_VECTOR(SDNode * N)23092 SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) {
23093   SDValue InVal = N->getOperand(0);
23094   EVT VT = N->getValueType(0);
23095 
23096   // Replace a SCALAR_TO_VECTOR(EXTRACT_VECTOR_ELT(V,C0)) pattern
23097   // with a VECTOR_SHUFFLE and possible truncate.
23098   if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
23099       VT.isFixedLengthVector() &&
23100       InVal->getOperand(0).getValueType().isFixedLengthVector()) {
23101     SDValue InVec = InVal->getOperand(0);
23102     SDValue EltNo = InVal->getOperand(1);
23103     auto InVecT = InVec.getValueType();
23104     if (ConstantSDNode *C0 = dyn_cast<ConstantSDNode>(EltNo)) {
23105       SmallVector<int, 8> NewMask(InVecT.getVectorNumElements(), -1);
23106       int Elt = C0->getZExtValue();
23107       NewMask[0] = Elt;
23108       // If we have an implict truncate do truncate here as long as it's legal.
23109       // if it's not legal, this should
23110       if (VT.getScalarType() != InVal.getValueType() &&
23111           InVal.getValueType().isScalarInteger() &&
23112           isTypeLegal(VT.getScalarType())) {
23113         SDValue Val =
23114             DAG.getNode(ISD::TRUNCATE, SDLoc(InVal), VT.getScalarType(), InVal);
23115         return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), VT, Val);
23116       }
23117       if (VT.getScalarType() == InVecT.getScalarType() &&
23118           VT.getVectorNumElements() <= InVecT.getVectorNumElements()) {
23119         SDValue LegalShuffle =
23120           TLI.buildLegalVectorShuffle(InVecT, SDLoc(N), InVec,
23121                                       DAG.getUNDEF(InVecT), NewMask, DAG);
23122         if (LegalShuffle) {
23123           // If the initial vector is the correct size this shuffle is a
23124           // valid result.
23125           if (VT == InVecT)
23126             return LegalShuffle;
23127           // If not we must truncate the vector.
23128           if (VT.getVectorNumElements() != InVecT.getVectorNumElements()) {
23129             SDValue ZeroIdx = DAG.getVectorIdxConstant(0, SDLoc(N));
23130             EVT SubVT = EVT::getVectorVT(*DAG.getContext(),
23131                                          InVecT.getVectorElementType(),
23132                                          VT.getVectorNumElements());
23133             return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), SubVT,
23134                                LegalShuffle, ZeroIdx);
23135           }
23136         }
23137       }
23138     }
23139   }
23140 
23141   return SDValue();
23142 }
23143 
visitINSERT_SUBVECTOR(SDNode * N)23144 SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) {
23145   EVT VT = N->getValueType(0);
23146   SDValue N0 = N->getOperand(0);
23147   SDValue N1 = N->getOperand(1);
23148   SDValue N2 = N->getOperand(2);
23149   uint64_t InsIdx = N->getConstantOperandVal(2);
23150 
23151   // If inserting an UNDEF, just return the original vector.
23152   if (N1.isUndef())
23153     return N0;
23154 
23155   // If this is an insert of an extracted vector into an undef vector, we can
23156   // just use the input to the extract.
23157   if (N0.isUndef() && N1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
23158       N1.getOperand(1) == N2 && N1.getOperand(0).getValueType() == VT)
23159     return N1.getOperand(0);
23160 
23161   // Simplify scalar inserts into an undef vector:
23162   // insert_subvector undef, (splat X), N2 -> splat X
23163   if (N0.isUndef() && N1.getOpcode() == ISD::SPLAT_VECTOR)
23164     return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), VT, N1.getOperand(0));
23165 
23166   // If we are inserting a bitcast value into an undef, with the same
23167   // number of elements, just use the bitcast input of the extract.
23168   // i.e. INSERT_SUBVECTOR UNDEF (BITCAST N1) N2 ->
23169   //        BITCAST (INSERT_SUBVECTOR UNDEF N1 N2)
23170   if (N0.isUndef() && N1.getOpcode() == ISD::BITCAST &&
23171       N1.getOperand(0).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
23172       N1.getOperand(0).getOperand(1) == N2 &&
23173       N1.getOperand(0).getOperand(0).getValueType().getVectorElementCount() ==
23174           VT.getVectorElementCount() &&
23175       N1.getOperand(0).getOperand(0).getValueType().getSizeInBits() ==
23176           VT.getSizeInBits()) {
23177     return DAG.getBitcast(VT, N1.getOperand(0).getOperand(0));
23178   }
23179 
23180   // If both N1 and N2 are bitcast values on which insert_subvector
  // would make sense, pull the bitcast through.
23182   // i.e. INSERT_SUBVECTOR (BITCAST N0) (BITCAST N1) N2 ->
23183   //        BITCAST (INSERT_SUBVECTOR N0 N1 N2)
23184   if (N0.getOpcode() == ISD::BITCAST && N1.getOpcode() == ISD::BITCAST) {
23185     SDValue CN0 = N0.getOperand(0);
23186     SDValue CN1 = N1.getOperand(0);
23187     EVT CN0VT = CN0.getValueType();
23188     EVT CN1VT = CN1.getValueType();
23189     if (CN0VT.isVector() && CN1VT.isVector() &&
23190         CN0VT.getVectorElementType() == CN1VT.getVectorElementType() &&
23191         CN0VT.getVectorElementCount() == VT.getVectorElementCount()) {
23192       SDValue NewINSERT = DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N),
23193                                       CN0.getValueType(), CN0, CN1, N2);
23194       return DAG.getBitcast(VT, NewINSERT);
23195     }
23196   }
23197 
23198   // Combine INSERT_SUBVECTORs where we are inserting to the same index.
23199   // INSERT_SUBVECTOR( INSERT_SUBVECTOR( Vec, SubOld, Idx ), SubNew, Idx )
23200   // --> INSERT_SUBVECTOR( Vec, SubNew, Idx )
23201   if (N0.getOpcode() == ISD::INSERT_SUBVECTOR &&
23202       N0.getOperand(1).getValueType() == N1.getValueType() &&
23203       N0.getOperand(2) == N2)
23204     return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0.getOperand(0),
23205                        N1, N2);
23206 
23207   // Eliminate an intermediate insert into an undef vector:
23208   // insert_subvector undef, (insert_subvector undef, X, 0), N2 -->
23209   // insert_subvector undef, X, N2
23210   if (N0.isUndef() && N1.getOpcode() == ISD::INSERT_SUBVECTOR &&
23211       N1.getOperand(0).isUndef() && isNullConstant(N1.getOperand(2)))
23212     return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0,
23213                        N1.getOperand(1), N2);
23214 
23215   // Push subvector bitcasts to the output, adjusting the index as we go.
23216   // insert_subvector(bitcast(v), bitcast(s), c1)
23217   // -> bitcast(insert_subvector(v, s, c2))
23218   if ((N0.isUndef() || N0.getOpcode() == ISD::BITCAST) &&
23219       N1.getOpcode() == ISD::BITCAST) {
23220     SDValue N0Src = peekThroughBitcasts(N0);
23221     SDValue N1Src = peekThroughBitcasts(N1);
23222     EVT N0SrcSVT = N0Src.getValueType().getScalarType();
23223     EVT N1SrcSVT = N1Src.getValueType().getScalarType();
23224     if ((N0.isUndef() || N0SrcSVT == N1SrcSVT) &&
23225         N0Src.getValueType().isVector() && N1Src.getValueType().isVector()) {
23226       EVT NewVT;
23227       SDLoc DL(N);
23228       SDValue NewIdx;
23229       LLVMContext &Ctx = *DAG.getContext();
23230       ElementCount NumElts = VT.getVectorElementCount();
23231       unsigned EltSizeInBits = VT.getScalarSizeInBits();
23232       if ((EltSizeInBits % N1SrcSVT.getSizeInBits()) == 0) {
23233         unsigned Scale = EltSizeInBits / N1SrcSVT.getSizeInBits();
23234         NewVT = EVT::getVectorVT(Ctx, N1SrcSVT, NumElts * Scale);
23235         NewIdx = DAG.getVectorIdxConstant(InsIdx * Scale, DL);
23236       } else if ((N1SrcSVT.getSizeInBits() % EltSizeInBits) == 0) {
23237         unsigned Scale = N1SrcSVT.getSizeInBits() / EltSizeInBits;
23238         if (NumElts.isKnownMultipleOf(Scale) && (InsIdx % Scale) == 0) {
23239           NewVT = EVT::getVectorVT(Ctx, N1SrcSVT,
23240                                    NumElts.divideCoefficientBy(Scale));
23241           NewIdx = DAG.getVectorIdxConstant(InsIdx / Scale, DL);
23242         }
23243       }
23244       if (NewIdx && hasOperation(ISD::INSERT_SUBVECTOR, NewVT)) {
23245         SDValue Res = DAG.getBitcast(NewVT, N0Src);
23246         Res = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, NewVT, Res, N1Src, NewIdx);
23247         return DAG.getBitcast(VT, Res);
23248       }
23249     }
23250   }
23251 
23252   // Canonicalize insert_subvector dag nodes.
23253   // Example:
23254   // (insert_subvector (insert_subvector A, Idx0), Idx1)
23255   // -> (insert_subvector (insert_subvector A, Idx1), Idx0)
23256   if (N0.getOpcode() == ISD::INSERT_SUBVECTOR && N0.hasOneUse() &&
23257       N1.getValueType() == N0.getOperand(1).getValueType()) {
23258     unsigned OtherIdx = N0.getConstantOperandVal(2);
23259     if (InsIdx < OtherIdx) {
23260       // Swap nodes.
23261       SDValue NewOp = DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT,
23262                                   N0.getOperand(0), N1, N2);
23263       AddToWorklist(NewOp.getNode());
23264       return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N0.getNode()),
23265                          VT, NewOp, N0.getOperand(1), N0.getOperand(2));
23266     }
23267   }
23268 
23269   // If the input vector is a concatenation, and the insert replaces
23270   // one of the pieces, we can optimize into a single concat_vectors.
23271   if (N0.getOpcode() == ISD::CONCAT_VECTORS && N0.hasOneUse() &&
23272       N0.getOperand(0).getValueType() == N1.getValueType() &&
23273       N0.getOperand(0).getValueType().isScalableVector() ==
23274           N1.getValueType().isScalableVector()) {
23275     unsigned Factor = N1.getValueType().getVectorMinNumElements();
23276     SmallVector<SDValue, 8> Ops(N0->op_begin(), N0->op_end());
23277     Ops[InsIdx / Factor] = N1;
23278     return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
23279   }
23280 
23281   // Simplify source operands based on insertion.
23282   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
23283     return SDValue(N, 0);
23284 
23285   return SDValue();
23286 }
23287 
visitFP_TO_FP16(SDNode * N)23288 SDValue DAGCombiner::visitFP_TO_FP16(SDNode *N) {
23289   SDValue N0 = N->getOperand(0);
23290 
23291   // fold (fp_to_fp16 (fp16_to_fp op)) -> op
23292   if (N0->getOpcode() == ISD::FP16_TO_FP)
23293     return N0->getOperand(0);
23294 
23295   return SDValue();
23296 }
23297 
visitFP16_TO_FP(SDNode * N)23298 SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
23299   SDValue N0 = N->getOperand(0);
23300 
23301   // fold fp16_to_fp(op & 0xffff) -> fp16_to_fp(op)
23302   if (!TLI.shouldKeepZExtForFP16Conv() && N0->getOpcode() == ISD::AND) {
23303     ConstantSDNode *AndConst = getAsNonOpaqueConstant(N0.getOperand(1));
23304     if (AndConst && AndConst->getAPIntValue() == 0xffff) {
23305       return DAG.getNode(ISD::FP16_TO_FP, SDLoc(N), N->getValueType(0),
23306                          N0.getOperand(0));
23307     }
23308   }
23309 
23310   return SDValue();
23311 }
23312 
visitFP_TO_BF16(SDNode * N)23313 SDValue DAGCombiner::visitFP_TO_BF16(SDNode *N) {
23314   SDValue N0 = N->getOperand(0);
23315 
23316   // fold (fp_to_bf16 (bf16_to_fp op)) -> op
23317   if (N0->getOpcode() == ISD::BF16_TO_FP)
23318     return N0->getOperand(0);
23319 
23320   return SDValue();
23321 }
23322 
visitVECREDUCE(SDNode * N)23323 SDValue DAGCombiner::visitVECREDUCE(SDNode *N) {
23324   SDValue N0 = N->getOperand(0);
23325   EVT VT = N0.getValueType();
23326   unsigned Opcode = N->getOpcode();
23327 
23328   // VECREDUCE over 1-element vector is just an extract.
23329   if (VT.getVectorElementCount().isScalar()) {
23330     SDLoc dl(N);
23331     SDValue Res =
23332         DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT.getVectorElementType(), N0,
23333                     DAG.getVectorIdxConstant(0, dl));
23334     if (Res.getValueType() != N->getValueType(0))
23335       Res = DAG.getNode(ISD::ANY_EXTEND, dl, N->getValueType(0), Res);
23336     return Res;
23337   }
23338 
23339   // On an boolean vector an and/or reduction is the same as a umin/umax
23340   // reduction. Convert them if the latter is legal while the former isn't.
23341   if (Opcode == ISD::VECREDUCE_AND || Opcode == ISD::VECREDUCE_OR) {
23342     unsigned NewOpcode = Opcode == ISD::VECREDUCE_AND
23343         ? ISD::VECREDUCE_UMIN : ISD::VECREDUCE_UMAX;
23344     if (!TLI.isOperationLegalOrCustom(Opcode, VT) &&
23345         TLI.isOperationLegalOrCustom(NewOpcode, VT) &&
23346         DAG.ComputeNumSignBits(N0) == VT.getScalarSizeInBits())
23347       return DAG.getNode(NewOpcode, SDLoc(N), N->getValueType(0), N0);
23348   }
23349 
23350   // vecreduce_or(insert_subvector(zero or undef, val)) -> vecreduce_or(val)
23351   // vecreduce_and(insert_subvector(ones or undef, val)) -> vecreduce_and(val)
23352   if (N0.getOpcode() == ISD::INSERT_SUBVECTOR &&
23353       TLI.isTypeLegal(N0.getOperand(1).getValueType())) {
23354     SDValue Vec = N0.getOperand(0);
23355     SDValue Subvec = N0.getOperand(1);
23356     if ((Opcode == ISD::VECREDUCE_OR &&
23357          (N0.getOperand(0).isUndef() || isNullOrNullSplat(Vec))) ||
23358         (Opcode == ISD::VECREDUCE_AND &&
23359          (N0.getOperand(0).isUndef() || isAllOnesOrAllOnesSplat(Vec))))
23360       return DAG.getNode(Opcode, SDLoc(N), N->getValueType(0), Subvec);
23361   }
23362 
23363   return SDValue();
23364 }
23365 
visitVPOp(SDNode * N)23366 SDValue DAGCombiner::visitVPOp(SDNode *N) {
23367   // VP operations in which all vector elements are disabled - either by
23368   // determining that the mask is all false or that the EVL is 0 - can be
23369   // eliminated.
23370   bool AreAllEltsDisabled = false;
23371   if (auto EVLIdx = ISD::getVPExplicitVectorLengthIdx(N->getOpcode()))
23372     AreAllEltsDisabled |= isNullConstant(N->getOperand(*EVLIdx));
23373   if (auto MaskIdx = ISD::getVPMaskIdx(N->getOpcode()))
23374     AreAllEltsDisabled |=
23375         ISD::isConstantSplatVectorAllZeros(N->getOperand(*MaskIdx).getNode());
23376 
23377   // This is the only generic VP combine we support for now.
23378   if (!AreAllEltsDisabled)
23379     return SDValue();
23380 
23381   // Binary operations can be replaced by UNDEF.
23382   if (ISD::isVPBinaryOp(N->getOpcode()))
23383     return DAG.getUNDEF(N->getValueType(0));
23384 
23385   // VP Memory operations can be replaced by either the chain (stores) or the
23386   // chain + undef (loads).
23387   if (const auto *MemSD = dyn_cast<MemSDNode>(N)) {
23388     if (MemSD->writeMem())
23389       return MemSD->getChain();
23390     return CombineTo(N, DAG.getUNDEF(N->getValueType(0)), MemSD->getChain());
23391   }
23392 
23393   // Reduction operations return the start operand when no elements are active.
23394   if (ISD::isVPReduction(N->getOpcode()))
23395     return N->getOperand(0);
23396 
23397   return SDValue();
23398 }
23399 
/// Returns a vector_shuffle if it able to transform an AND to a vector_shuffle
/// with the destination vector and a zero vector.
/// e.g. AND V, <0xffffffff, 0, 0xffffffff, 0>. ==>
///      vector_shuffle V, Zero, <0, 4, 2, 4>
SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) {
  assert(N->getOpcode() == ISD::AND && "Unexpected opcode!");

  EVT VT = N->getValueType(0);
  SDValue LHS = N->getOperand(0);
  // Look through bitcasts on the mask operand so a constant build_vector of a
  // different element width is still recognized.
  SDValue RHS = peekThroughBitcasts(N->getOperand(1));
  SDLoc DL(N);

  // Make sure we're not running after operation legalization where it
  // may have custom lowered the vector shuffles.
  if (LegalOperations)
    return SDValue();

  // The mask must be an all-constant build_vector for this analysis.
  if (RHS.getOpcode() != ISD::BUILD_VECTOR)
    return SDValue();

  EVT RVT = RHS.getValueType();
  unsigned NumElts = RHS.getNumOperands();

  // Attempt to create a valid clear mask, splitting the mask into
  // sub elements and checking to see if each is
  // all zeros or all ones - suitable for shuffle masking.
  auto BuildClearMask = [&](int Split) {
    // Each original element is viewed as 'Split' sub-elements of NumSubBits
    // bits each.
    int NumSubElts = NumElts * Split;
    int NumSubBits = RVT.getScalarSizeInBits() / Split;

    SmallVector<int, 8> Indices;
    for (int i = 0; i != NumSubElts; ++i) {
      int EltIdx = i / Split;
      int SubIdx = i % Split;
      SDValue Elt = RHS.getOperand(EltIdx);
      // X & undef --> 0 (not undef). So this lane must be converted to choose
      // from the zero constant vector (same as if the element had all 0-bits).
      if (Elt.isUndef()) {
        Indices.push_back(i + NumSubElts);
        continue;
      }

      // Only integer and FP constants can contribute a known bit pattern.
      APInt Bits;
      if (isa<ConstantSDNode>(Elt))
        Bits = cast<ConstantSDNode>(Elt)->getAPIntValue();
      else if (isa<ConstantFPSDNode>(Elt))
        Bits = cast<ConstantFPSDNode>(Elt)->getValueAPF().bitcastToAPInt();
      else
        return SDValue();

      // Extract the sub element from the constant bit mask.
      // On big-endian targets the most-significant bits correspond to the
      // lowest-numbered sub-element, so the extraction offset is mirrored.
      if (DAG.getDataLayout().isBigEndian())
        Bits = Bits.extractBits(NumSubBits, (Split - SubIdx - 1) * NumSubBits);
      else
        Bits = Bits.extractBits(NumSubBits, SubIdx * NumSubBits);

      // All-ones keeps the LHS lane (index i); all-zeros selects the matching
      // lane of the zero vector (index i + NumSubElts). Any mixed bit pattern
      // cannot be expressed as a shuffle - give up.
      if (Bits.isAllOnes())
        Indices.push_back(i);
      else if (Bits == 0)
        Indices.push_back(i + NumSubElts);
      else
        return SDValue();
    }

    // Let's see if the target supports this vector_shuffle.
    EVT ClearSVT = EVT::getIntegerVT(*DAG.getContext(), NumSubBits);
    EVT ClearVT = EVT::getVectorVT(*DAG.getContext(), ClearSVT, NumSubElts);
    if (!TLI.isVectorClearMaskLegal(Indices, ClearVT))
      return SDValue();

    SDValue Zero = DAG.getConstant(0, DL, ClearVT);
    return DAG.getBitcast(VT, DAG.getVectorShuffle(ClearVT, DL,
                                                   DAG.getBitcast(ClearVT, LHS),
                                                   Zero, Indices));
  };

  // Determine maximum split level (byte level masking).
  int MaxSplit = 1;
  if (RVT.getScalarSizeInBits() % 8 == 0)
    MaxSplit = RVT.getScalarSizeInBits() / 8;

  // Try whole-element granularity first, then progressively finer splits
  // down to byte granularity.
  for (int Split = 1; Split <= MaxSplit; ++Split)
    if (RVT.getScalarSizeInBits() % Split == 0)
      if (SDValue S = BuildClearMask(Split))
        return S;

  return SDValue();
}
23488 
/// If a vector binop is performed on splat values, it may be profitable to
/// extract, scalarize, and insert/splat.
/// i.e. bo (splat X, Index), (splat Y, Index) --> splat (bo X, Y), Index
static SDValue scalarizeBinOpOfSplats(SDNode *N, SelectionDAG &DAG,
                                      const SDLoc &DL) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  unsigned Opcode = N->getOpcode();
  EVT VT = N->getValueType(0);
  EVT EltVT = VT.getVectorElementType();
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();

  // TODO: Remove/replace the extract cost check? If the elements are available
  //       as scalars, then there may be no extract cost. Should we ask if
  //       inserting a scalar back into a vector is cheap instead?
  int Index0, Index1;
  SDValue Src0 = DAG.getSplatSourceVector(N0, Index0);
  SDValue Src1 = DAG.getSplatSourceVector(N1, Index1);
  // Extract element from splat_vector should be free.
  // TODO: use DAG.isSplatValue instead?
  bool IsBothSplatVector = N0.getOpcode() == ISD::SPLAT_VECTOR &&
                           N1.getOpcode() == ISD::SPLAT_VECTOR;
  // Bail out unless both operands splat the same lane of sources with the
  // expected element type, the extract is cheap (or free, for splat_vector),
  // and the scalar form of the operation is supported.
  if (!Src0 || !Src1 || Index0 != Index1 ||
      Src0.getValueType().getVectorElementType() != EltVT ||
      Src1.getValueType().getVectorElementType() != EltVT ||
      !(IsBothSplatVector || TLI.isExtractVecEltCheap(VT, Index0)) ||
      !TLI.isOperationLegalOrCustom(Opcode, EltVT))
    return SDValue();

  // Extract the splatted lane from both sources and apply the operation once,
  // preserving the original node's flags.
  SDValue IndexC = DAG.getVectorIdxConstant(Index0, DL);
  SDValue X = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Src0, IndexC);
  SDValue Y = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Src1, IndexC);
  SDValue ScalarBO = DAG.getNode(Opcode, DL, EltVT, X, Y, N->getFlags());

  // If all lanes but 1 are undefined, no need to splat the scalar result.
  // TODO: Keep track of undefs and use that info in the general case.
  if (N0.getOpcode() == ISD::BUILD_VECTOR && N0.getOpcode() == N1.getOpcode() &&
      count_if(N0->ops(), [](SDValue V) { return !V.isUndef(); }) == 1 &&
      count_if(N1->ops(), [](SDValue V) { return !V.isUndef(); }) == 1) {
    // bo (build_vec ..undef, X, undef...), (build_vec ..undef, Y, undef...) -->
    // build_vec ..undef, (bo X, Y), undef...
    SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), DAG.getUNDEF(EltVT));
    Ops[Index0] = ScalarBO;
    return DAG.getBuildVector(VT, DL, Ops);
  }

  // bo (splat X, Index), (splat Y, Index) --> splat (bo X, Y), Index
  if (VT.isScalableVector())
    return DAG.getSplatVector(VT, DL, ScalarBO);
  SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), ScalarBO);
  return DAG.getBuildVector(VT, DL, Ops);
}
23540 
/// Visit a binary vector operation, like ADD. Tries a sequence of folds that
/// move shuffles/insertions/concats past the binop, then attempts to
/// scalarize a binop of splats. Returns the replacement value, or an empty
/// SDValue if no fold applies.
SDValue DAGCombiner::SimplifyVBinOp(SDNode *N, const SDLoc &DL) {
  EVT VT = N->getValueType(0);
  assert(VT.isVector() && "SimplifyVBinOp only works on vectors!");

  SDValue LHS = N->getOperand(0);
  SDValue RHS = N->getOperand(1);
  unsigned Opcode = N->getOpcode();
  SDNodeFlags Flags = N->getFlags();

  // Move unary shuffles with identical masks after a vector binop:
  // VBinOp (shuffle A, Undef, Mask), (shuffle B, Undef, Mask))
  //   --> shuffle (VBinOp A, B), Undef, Mask
  // This does not require type legality checks because we are creating the
  // same types of operations that are in the original sequence. We do have to
  // restrict ops like integer div that have immediate UB (eg, div-by-zero)
  // though. This code is adapted from the identical transform in instcombine.
  if (Opcode != ISD::UDIV && Opcode != ISD::SDIV &&
      Opcode != ISD::UREM && Opcode != ISD::SREM &&
      Opcode != ISD::UDIVREM && Opcode != ISD::SDIVREM) {
    auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(LHS);
    auto *Shuf1 = dyn_cast<ShuffleVectorSDNode>(RHS);
    if (Shuf0 && Shuf1 && Shuf0->getMask().equals(Shuf1->getMask()) &&
        LHS.getOperand(1).isUndef() && RHS.getOperand(1).isUndef() &&
        (LHS.hasOneUse() || RHS.hasOneUse() || LHS == RHS)) {
      SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, LHS.getOperand(0),
                                     RHS.getOperand(0), Flags);
      SDValue UndefV = LHS.getOperand(1);
      return DAG.getVectorShuffle(VT, DL, NewBinOp, UndefV, Shuf0->getMask());
    }

    // Try to sink a splat shuffle after a binop with a uniform constant.
    // This is limited to cases where neither the shuffle nor the constant have
    // undefined elements because that could be poison-unsafe or inhibit
    // demanded elements analysis. It is further limited to not change a splat
    // of an inserted scalar because that may be optimized better by
    // load-folding or other target-specific behaviors.
    if (isConstOrConstSplat(RHS) && Shuf0 && is_splat(Shuf0->getMask()) &&
        Shuf0->hasOneUse() && Shuf0->getOperand(1).isUndef() &&
        Shuf0->getOperand(0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
      // binop (splat X), (splat C) --> splat (binop X, C)
      SDValue X = Shuf0->getOperand(0);
      SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, X, RHS, Flags);
      return DAG.getVectorShuffle(VT, DL, NewBinOp, DAG.getUNDEF(VT),
                                  Shuf0->getMask());
    }
    // Mirror image of the fold above for a constant on the left-hand side.
    if (isConstOrConstSplat(LHS) && Shuf1 && is_splat(Shuf1->getMask()) &&
        Shuf1->hasOneUse() && Shuf1->getOperand(1).isUndef() &&
        Shuf1->getOperand(0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
      // binop (splat C), (splat X) --> splat (binop C, X)
      SDValue X = Shuf1->getOperand(0);
      SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, LHS, X, Flags);
      return DAG.getVectorShuffle(VT, DL, NewBinOp, DAG.getUNDEF(VT),
                                  Shuf1->getMask());
    }
  }

  // The following pattern is likely to emerge with vector reduction ops. Moving
  // the binary operation ahead of insertion may allow using a narrower vector
  // instruction that has better performance than the wide version of the op:
  // VBinOp (ins undef, X, Z), (ins undef, Y, Z) --> ins VecC, (VBinOp X, Y), Z
  if (LHS.getOpcode() == ISD::INSERT_SUBVECTOR && LHS.getOperand(0).isUndef() &&
      RHS.getOpcode() == ISD::INSERT_SUBVECTOR && RHS.getOperand(0).isUndef() &&
      LHS.getOperand(2) == RHS.getOperand(2) &&
      (LHS.hasOneUse() || RHS.hasOneUse())) {
    SDValue X = LHS.getOperand(1);
    SDValue Y = RHS.getOperand(1);
    SDValue Z = LHS.getOperand(2);
    EVT NarrowVT = X.getValueType();
    if (NarrowVT == Y.getValueType() &&
        TLI.isOperationLegalOrCustomOrPromote(Opcode, NarrowVT,
                                              LegalOperations)) {
      // (binop undef, undef) may not return undef, so compute that result.
      SDValue VecC =
          DAG.getNode(Opcode, DL, VT, DAG.getUNDEF(VT), DAG.getUNDEF(VT));
      SDValue NarrowBO = DAG.getNode(Opcode, DL, NarrowVT, X, Y);
      return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, VecC, NarrowBO, Z);
    }
  }

  // Make sure all but the first op are undef or constant.
  auto ConcatWithConstantOrUndef = [](SDValue Concat) {
    return Concat.getOpcode() == ISD::CONCAT_VECTORS &&
           all_of(drop_begin(Concat->ops()), [](const SDValue &Op) {
             return Op.isUndef() ||
                    ISD::isBuildVectorOfConstantSDNodes(Op.getNode());
           });
  };

  // The following pattern is likely to emerge with vector reduction ops. Moving
  // the binary operation ahead of the concat may allow using a narrower vector
  // instruction that has better performance than the wide version of the op:
  // VBinOp (concat X, undef/constant), (concat Y, undef/constant) -->
  //   concat (VBinOp X, Y), VecC
  if (ConcatWithConstantOrUndef(LHS) && ConcatWithConstantOrUndef(RHS) &&
      (LHS.hasOneUse() || RHS.hasOneUse())) {
    EVT NarrowVT = LHS.getOperand(0).getValueType();
    if (NarrowVT == RHS.getOperand(0).getValueType() &&
        TLI.isOperationLegalOrCustomOrPromote(Opcode, NarrowVT)) {
      unsigned NumOperands = LHS.getNumOperands();
      SmallVector<SDValue, 4> ConcatOps;
      for (unsigned i = 0; i != NumOperands; ++i) {
        // For operands 1 and up this constant folds (they are undef or
        // constant build_vectors by the check above).
        ConcatOps.push_back(DAG.getNode(Opcode, DL, NarrowVT, LHS.getOperand(i),
                                        RHS.getOperand(i)));
      }

      return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
    }
  }

  // Finally, try scalarizing a binop of two splats.
  if (SDValue V = scalarizeBinOpOfSplats(N, DAG, DL))
    return V;

  return SDValue();
}
23657 
SimplifySelect(const SDLoc & DL,SDValue N0,SDValue N1,SDValue N2)23658 SDValue DAGCombiner::SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1,
23659                                     SDValue N2) {
23660   assert(N0.getOpcode() ==ISD::SETCC && "First argument must be a SetCC node!");
23661 
23662   SDValue SCC = SimplifySelectCC(DL, N0.getOperand(0), N0.getOperand(1), N1, N2,
23663                                  cast<CondCodeSDNode>(N0.getOperand(2))->get());
23664 
23665   // If we got a simplified select_cc node back from SimplifySelectCC, then
23666   // break it down into a new SETCC node, and a new SELECT node, and then return
23667   // the SELECT node, since we were called with a SELECT node.
23668   if (SCC.getNode()) {
23669     // Check to see if we got a select_cc back (to turn into setcc/select).
23670     // Otherwise, just return whatever node we got back, like fabs.
23671     if (SCC.getOpcode() == ISD::SELECT_CC) {
23672       const SDNodeFlags Flags = N0->getFlags();
23673       SDValue SETCC = DAG.getNode(ISD::SETCC, SDLoc(N0),
23674                                   N0.getValueType(),
23675                                   SCC.getOperand(0), SCC.getOperand(1),
23676                                   SCC.getOperand(4), Flags);
23677       AddToWorklist(SETCC.getNode());
23678       SDValue SelectNode = DAG.getSelect(SDLoc(SCC), SCC.getValueType(), SETCC,
23679                                          SCC.getOperand(2), SCC.getOperand(3));
23680       SelectNode->setFlags(Flags);
23681       return SelectNode;
23682     }
23683 
23684     return SCC;
23685   }
23686   return SDValue();
23687 }
23688 
/// Given a SELECT or a SELECT_CC node, where LHS and RHS are the two values
/// being selected between, see if we can simplify the select.  Callers of this
/// should assume that TheSelect is deleted if this returns true.  As such, they
/// should return the appropriate thing (e.g. the node) back to the top-level of
/// the DAG combiner loop to avoid it being looked at.
bool DAGCombiner::SimplifySelectOps(SDNode *TheSelect, SDValue LHS,
                                    SDValue RHS) {
  // fold (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
  // The select + setcc is redundant, because fsqrt returns NaN for X < 0.
  if (const ConstantFPSDNode *NaN = isConstOrConstSplatFP(LHS)) {
    if (NaN->isNaN() && RHS.getOpcode() == ISD::FSQRT) {
      // We have: (select (setcc ?, ?, ?), NaN, (fsqrt ?))
      SDValue Sqrt = RHS;
      ISD::CondCode CC;
      SDValue CmpLHS;
      const ConstantFPSDNode *Zero = nullptr;

      // Pull the compare operands out of either a SELECT_CC or a
      // SELECT/VSELECT fed by a SETCC.
      if (TheSelect->getOpcode() == ISD::SELECT_CC) {
        CC = cast<CondCodeSDNode>(TheSelect->getOperand(4))->get();
        CmpLHS = TheSelect->getOperand(0);
        Zero = isConstOrConstSplatFP(TheSelect->getOperand(1));
      } else {
        // SELECT or VSELECT
        SDValue Cmp = TheSelect->getOperand(0);
        if (Cmp.getOpcode() == ISD::SETCC) {
          CC = cast<CondCodeSDNode>(Cmp.getOperand(2))->get();
          CmpLHS = Cmp.getOperand(0);
          Zero = isConstOrConstSplatFP(Cmp.getOperand(1));
        }
      }
      // The compare must be a signed/unsigned/ordered less-than against
      // [+-]0.0 of the same value the fsqrt consumes.
      if (Zero && Zero->isZero() &&
          Sqrt.getOperand(0) == CmpLHS && (CC == ISD::SETOLT ||
          CC == ISD::SETULT || CC == ISD::SETLT)) {
        // We have: (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
        CombineTo(TheSelect, Sqrt);
        return true;
      }
    }
  }
  // Cannot simplify select with vector condition
  if (TheSelect->getOperand(0).getValueType().isVector()) return false;

  // If this is a select from two identical things, try to pull the operation
  // through the select.
  if (LHS.getOpcode() != RHS.getOpcode() ||
      !LHS.hasOneUse() || !RHS.hasOneUse())
    return false;

  // If this is a load and the token chain is identical, replace the select
  // of two loads with a load through a select of the address to load from.
  // This triggers in things like "select bool X, 10.0, 123.0" after the FP
  // constants have been dropped into the constant pool.
  if (LHS.getOpcode() == ISD::LOAD) {
    LoadSDNode *LLD = cast<LoadSDNode>(LHS);
    LoadSDNode *RLD = cast<LoadSDNode>(RHS);

    // Token chains must be identical.
    if (LHS.getOperand(0) != RHS.getOperand(0) ||
        // Do not let this transformation reduce the number of volatile loads.
        // Be conservative for atomics for the moment
        // TODO: This does appear to be legal for unordered atomics (see D66309)
        !LLD->isSimple() || !RLD->isSimple() ||
        // FIXME: If either is a pre/post inc/dec load,
        // we'd need to split out the address adjustment.
        LLD->isIndexed() || RLD->isIndexed() ||
        // If this is an EXTLOAD, the VT's must match.
        LLD->getMemoryVT() != RLD->getMemoryVT() ||
        // If this is an EXTLOAD, the kind of extension must match.
        (LLD->getExtensionType() != RLD->getExtensionType() &&
         // The only exception is if one of the extensions is anyext.
         LLD->getExtensionType() != ISD::EXTLOAD &&
         RLD->getExtensionType() != ISD::EXTLOAD) ||
        // FIXME: this discards src value information.  This is
        // over-conservative. It would be beneficial to be able to remember
        // both potential memory locations.  Since we are discarding
        // src value info, don't do the transformation if the memory
        // locations are not in the default address space.
        LLD->getPointerInfo().getAddrSpace() != 0 ||
        RLD->getPointerInfo().getAddrSpace() != 0 ||
        // We can't produce a CMOV of a TargetFrameIndex since we won't
        // generate the address generation required.
        LLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
        RLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
        !TLI.isOperationLegalOrCustom(TheSelect->getOpcode(),
                                      LLD->getBasePtr().getValueType()))
      return false;

    // The loads must not depend on one another.
    if (LLD->isPredecessorOf(RLD) || RLD->isPredecessorOf(LLD))
      return false;

    // Check that the select condition doesn't reach either load.  If so,
    // folding this will induce a cycle into the DAG.  If not, this is safe to
    // xform, so create a select of the addresses.

    SmallPtrSet<const SDNode *, 32> Visited;
    SmallVector<const SDNode *, 16> Worklist;

    // Always fail if LLD and RLD are not independent. TheSelect is a
    // predecessor to all Nodes in question so we need not search past it.

    Visited.insert(TheSelect);
    Worklist.push_back(LLD);
    Worklist.push_back(RLD);

    if (SDNode::hasPredecessorHelper(LLD, Visited, Worklist) ||
        SDNode::hasPredecessorHelper(RLD, Visited, Worklist))
      return false;

    SDValue Addr;
    if (TheSelect->getOpcode() == ISD::SELECT) {
      // We cannot do this optimization if any pair of {RLD, LLD} is a
      // predecessor to {RLD, LLD, CondNode}. As we've already compared the
      // Loads, we only need to check if CondNode is a successor to one of the
      // loads. We can further avoid this if there's no use of their chain
      // value.
      SDNode *CondNode = TheSelect->getOperand(0).getNode();
      Worklist.push_back(CondNode);

      if ((LLD->hasAnyUseOfValue(1) &&
           SDNode::hasPredecessorHelper(LLD, Visited, Worklist)) ||
          (RLD->hasAnyUseOfValue(1) &&
           SDNode::hasPredecessorHelper(RLD, Visited, Worklist)))
        return false;

      // Select between the two base pointers with the original condition.
      Addr = DAG.getSelect(SDLoc(TheSelect),
                           LLD->getBasePtr().getValueType(),
                           TheSelect->getOperand(0), LLD->getBasePtr(),
                           RLD->getBasePtr());
    } else {  // Otherwise SELECT_CC
      // We cannot do this optimization if any pair of {RLD, LLD} is a
      // predecessor to {RLD, LLD, CondLHS, CondRHS}. As we've already compared
      // the Loads, we only need to check if CondLHS/CondRHS is a successor to
      // one of the loads. We can further avoid this if there's no use of their
      // chain value.

      SDNode *CondLHS = TheSelect->getOperand(0).getNode();
      SDNode *CondRHS = TheSelect->getOperand(1).getNode();
      Worklist.push_back(CondLHS);
      Worklist.push_back(CondRHS);

      if ((LLD->hasAnyUseOfValue(1) &&
           SDNode::hasPredecessorHelper(LLD, Visited, Worklist)) ||
          (RLD->hasAnyUseOfValue(1) &&
           SDNode::hasPredecessorHelper(RLD, Visited, Worklist)))
        return false;

      // Select between the two base pointers, keeping the original compare.
      Addr = DAG.getNode(ISD::SELECT_CC, SDLoc(TheSelect),
                         LLD->getBasePtr().getValueType(),
                         TheSelect->getOperand(0),
                         TheSelect->getOperand(1),
                         LLD->getBasePtr(), RLD->getBasePtr(),
                         TheSelect->getOperand(4));
    }

    SDValue Load;
    // It is safe to replace the two loads if they have different alignments,
    // but the new load must be the minimum (most restrictive) alignment of the
    // inputs.
    Align Alignment = std::min(LLD->getAlign(), RLD->getAlign());
    MachineMemOperand::Flags MMOFlags = LLD->getMemOperand()->getFlags();
    // The merged load can only carry a memory-operand flag if BOTH inputs
    // guarantee it, so drop flags the RHS load does not have.
    if (!RLD->isInvariant())
      MMOFlags &= ~MachineMemOperand::MOInvariant;
    if (!RLD->isDereferenceable())
      MMOFlags &= ~MachineMemOperand::MODereferenceable;
    if (LLD->getExtensionType() == ISD::NON_EXTLOAD) {
      // FIXME: Discards pointer and AA info.
      Load = DAG.getLoad(TheSelect->getValueType(0), SDLoc(TheSelect),
                         LLD->getChain(), Addr, MachinePointerInfo(), Alignment,
                         MMOFlags);
    } else {
      // FIXME: Discards pointer and AA info.
      Load = DAG.getExtLoad(
          LLD->getExtensionType() == ISD::EXTLOAD ? RLD->getExtensionType()
                                                  : LLD->getExtensionType(),
          SDLoc(TheSelect), TheSelect->getValueType(0), LLD->getChain(), Addr,
          MachinePointerInfo(), LLD->getMemoryVT(), Alignment, MMOFlags);
    }

    // Users of the select now use the result of the load.
    CombineTo(TheSelect, Load);

    // Users of the old loads now use the new load's chain.  We know the
    // old-load value is dead now.
    CombineTo(LHS.getNode(), Load.getValue(0), Load.getValue(1));
    CombineTo(RHS.getNode(), Load.getValue(0), Load.getValue(1));
    return true;
  }

  return false;
}
23880 
23881 /// Try to fold an expression of the form (N0 cond N1) ? N2 : N3 to a shift and
23882 /// bitwise 'and'.
foldSelectCCToShiftAnd(const SDLoc & DL,SDValue N0,SDValue N1,SDValue N2,SDValue N3,ISD::CondCode CC)23883 SDValue DAGCombiner::foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0,
23884                                             SDValue N1, SDValue N2, SDValue N3,
23885                                             ISD::CondCode CC) {
23886   // If this is a select where the false operand is zero and the compare is a
23887   // check of the sign bit, see if we can perform the "gzip trick":
23888   // select_cc setlt X, 0, A, 0 -> and (sra X, size(X)-1), A
23889   // select_cc setgt X, 0, A, 0 -> and (not (sra X, size(X)-1)), A
23890   EVT XType = N0.getValueType();
23891   EVT AType = N2.getValueType();
23892   if (!isNullConstant(N3) || !XType.bitsGE(AType))
23893     return SDValue();
23894 
23895   // If the comparison is testing for a positive value, we have to invert
23896   // the sign bit mask, so only do that transform if the target has a bitwise
23897   // 'and not' instruction (the invert is free).
23898   if (CC == ISD::SETGT && TLI.hasAndNot(N2)) {
23899     // (X > -1) ? A : 0
23900     // (X >  0) ? X : 0 <-- This is canonical signed max.
23901     if (!(isAllOnesConstant(N1) || (isNullConstant(N1) && N0 == N2)))
23902       return SDValue();
23903   } else if (CC == ISD::SETLT) {
23904     // (X <  0) ? A : 0
23905     // (X <  1) ? X : 0 <-- This is un-canonicalized signed min.
23906     if (!(isNullConstant(N1) || (isOneConstant(N1) && N0 == N2)))
23907       return SDValue();
23908   } else {
23909     return SDValue();
23910   }
23911 
23912   // and (sra X, size(X)-1), A -> "and (srl X, C2), A" iff A is a single-bit
23913   // constant.
23914   EVT ShiftAmtTy = getShiftAmountTy(N0.getValueType());
23915   auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
23916   if (N2C && ((N2C->getAPIntValue() & (N2C->getAPIntValue() - 1)) == 0)) {
23917     unsigned ShCt = XType.getSizeInBits() - N2C->getAPIntValue().logBase2() - 1;
23918     if (!TLI.shouldAvoidTransformToShift(XType, ShCt)) {
23919       SDValue ShiftAmt = DAG.getConstant(ShCt, DL, ShiftAmtTy);
23920       SDValue Shift = DAG.getNode(ISD::SRL, DL, XType, N0, ShiftAmt);
23921       AddToWorklist(Shift.getNode());
23922 
23923       if (XType.bitsGT(AType)) {
23924         Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift);
23925         AddToWorklist(Shift.getNode());
23926       }
23927 
23928       if (CC == ISD::SETGT)
23929         Shift = DAG.getNOT(DL, Shift, AType);
23930 
23931       return DAG.getNode(ISD::AND, DL, AType, Shift, N2);
23932     }
23933   }
23934 
23935   unsigned ShCt = XType.getSizeInBits() - 1;
23936   if (TLI.shouldAvoidTransformToShift(XType, ShCt))
23937     return SDValue();
23938 
23939   SDValue ShiftAmt = DAG.getConstant(ShCt, DL, ShiftAmtTy);
23940   SDValue Shift = DAG.getNode(ISD::SRA, DL, XType, N0, ShiftAmt);
23941   AddToWorklist(Shift.getNode());
23942 
23943   if (XType.bitsGT(AType)) {
23944     Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift);
23945     AddToWorklist(Shift.getNode());
23946   }
23947 
23948   if (CC == ISD::SETGT)
23949     Shift = DAG.getNOT(DL, Shift, AType);
23950 
23951   return DAG.getNode(ISD::AND, DL, AType, Shift, N2);
23952 }
23953 
23954 // Fold select(cc, binop(), binop()) -> binop(select(), select()) etc.
foldSelectOfBinops(SDNode * N)23955 SDValue DAGCombiner::foldSelectOfBinops(SDNode *N) {
23956   SDValue N0 = N->getOperand(0);
23957   SDValue N1 = N->getOperand(1);
23958   SDValue N2 = N->getOperand(2);
23959   EVT VT = N->getValueType(0);
23960   SDLoc DL(N);
23961 
23962   unsigned BinOpc = N1.getOpcode();
23963   if (!TLI.isBinOp(BinOpc) || (N2.getOpcode() != BinOpc))
23964     return SDValue();
23965 
23966   // The use checks are intentionally on SDNode because we may be dealing
23967   // with opcodes that produce more than one SDValue.
23968   // TODO: Do we really need to check N0 (the condition operand of the select)?
23969   //       But removing that clause could cause an infinite loop...
23970   if (!N0->hasOneUse() || !N1->hasOneUse() || !N2->hasOneUse())
23971     return SDValue();
23972 
23973   // Binops may include opcodes that return multiple values, so all values
23974   // must be created/propagated from the newly created binops below.
23975   SDVTList OpVTs = N1->getVTList();
23976 
23977   // Fold select(cond, binop(x, y), binop(z, y))
23978   //  --> binop(select(cond, x, z), y)
23979   if (N1.getOperand(1) == N2.getOperand(1)) {
23980     SDValue NewSel =
23981         DAG.getSelect(DL, VT, N0, N1.getOperand(0), N2.getOperand(0));
23982     SDValue NewBinOp = DAG.getNode(BinOpc, DL, OpVTs, NewSel, N1.getOperand(1));
23983     NewBinOp->setFlags(N1->getFlags());
23984     NewBinOp->intersectFlagsWith(N2->getFlags());
23985     return NewBinOp;
23986   }
23987 
23988   // Fold select(cond, binop(x, y), binop(x, z))
23989   //  --> binop(x, select(cond, y, z))
23990   // Second op VT might be different (e.g. shift amount type)
23991   if (N1.getOperand(0) == N2.getOperand(0) &&
23992       VT == N1.getOperand(1).getValueType() &&
23993       VT == N2.getOperand(1).getValueType()) {
23994     SDValue NewSel =
23995         DAG.getSelect(DL, VT, N0, N1.getOperand(1), N2.getOperand(1));
23996     SDValue NewBinOp = DAG.getNode(BinOpc, DL, OpVTs, N1.getOperand(0), NewSel);
23997     NewBinOp->setFlags(N1->getFlags());
23998     NewBinOp->intersectFlagsWith(N2->getFlags());
23999     return NewBinOp;
24000   }
24001 
24002   // TODO: Handle isCommutativeBinOp patterns as well?
24003   return SDValue();
24004 }
24005 
24006 // Transform (fneg/fabs (bitconvert x)) to avoid loading constant pool values.
foldSignChangeInBitcast(SDNode * N)24007 SDValue DAGCombiner::foldSignChangeInBitcast(SDNode *N) {
24008   SDValue N0 = N->getOperand(0);
24009   EVT VT = N->getValueType(0);
24010   bool IsFabs = N->getOpcode() == ISD::FABS;
24011   bool IsFree = IsFabs ? TLI.isFAbsFree(VT) : TLI.isFNegFree(VT);
24012 
24013   if (IsFree || N0.getOpcode() != ISD::BITCAST || !N0.hasOneUse())
24014     return SDValue();
24015 
24016   SDValue Int = N0.getOperand(0);
24017   EVT IntVT = Int.getValueType();
24018 
24019   // The operand to cast should be integer.
24020   if (!IntVT.isInteger() || IntVT.isVector())
24021     return SDValue();
24022 
24023   // (fneg (bitconvert x)) -> (bitconvert (xor x sign))
24024   // (fabs (bitconvert x)) -> (bitconvert (and x ~sign))
24025   APInt SignMask;
24026   if (N0.getValueType().isVector()) {
24027     // For vector, create a sign mask (0x80...) or its inverse (for fabs,
24028     // 0x7f...) per element and splat it.
24029     SignMask = APInt::getSignMask(N0.getScalarValueSizeInBits());
24030     if (IsFabs)
24031       SignMask = ~SignMask;
24032     SignMask = APInt::getSplat(IntVT.getSizeInBits(), SignMask);
24033   } else {
24034     // For scalar, just use the sign mask (0x80... or the inverse, 0x7f...)
24035     SignMask = APInt::getSignMask(IntVT.getSizeInBits());
24036     if (IsFabs)
24037       SignMask = ~SignMask;
24038   }
24039   SDLoc DL(N0);
24040   Int = DAG.getNode(IsFabs ? ISD::AND : ISD::XOR, DL, IntVT, Int,
24041                     DAG.getConstant(SignMask, DL, IntVT));
24042   AddToWorklist(Int.getNode());
24043   return DAG.getBitcast(VT, Int);
24044 }
24045 
/// Turn "(a cond b) ? 1.0f : 2.0f" into "load (tmp + ((a cond b) ? 0 : 4)"
/// where "tmp" is a constant pool entry containing an array with 1.0 and 2.0
/// in it. This may be a win when the constant is not otherwise available
/// because it replaces two constant pool loads with one.
SDValue DAGCombiner::convertSelectOfFPConstantsToLoadOffset(
    const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
    ISD::CondCode CC) {
  // Bail if the target does not want this transform for this compare type.
  if (!TLI.reduceSelectOfFPConstantLoads(N0.getValueType()))
    return SDValue();

  // If we are before legalize types, we want the other legalization to happen
  // first (for example, to avoid messing with soft float).
  auto *TV = dyn_cast<ConstantFPSDNode>(N2); // "true" arm constant
  auto *FV = dyn_cast<ConstantFPSDNode>(N3); // "false" arm constant
  EVT VT = N2.getValueType();
  if (!TV || !FV || !TLI.isTypeLegal(VT))
    return SDValue();

  // If a constant can be materialized without loads, this does not make sense.
  if (TLI.getOperationAction(ISD::ConstantFP, VT) == TargetLowering::Legal ||
      TLI.isFPImmLegal(TV->getValueAPF(), TV->getValueType(0), ForCodeSize) ||
      TLI.isFPImmLegal(FV->getValueAPF(), FV->getValueType(0), ForCodeSize))
    return SDValue();

  // If both constants have multiple uses, then we won't need to do an extra
  // load. The values are likely around in registers for other users.
  if (!TV->hasOneUse() && !FV->hasOneUse())
    return SDValue();

  // Element 0 is the "false" value and element 1 the "true" value, so a true
  // condition selects the nonzero offset below.
  Constant *Elts[] = { const_cast<ConstantFP*>(FV->getConstantFPValue()),
                       const_cast<ConstantFP*>(TV->getConstantFPValue()) };
  Type *FPTy = Elts[0]->getType();
  const DataLayout &TD = DAG.getDataLayout();

  // Create a ConstantArray of the two constants.
  Constant *CA = ConstantArray::get(ArrayType::get(FPTy, 2), Elts);
  SDValue CPIdx = DAG.getConstantPool(CA, TLI.getPointerTy(DAG.getDataLayout()),
                                      TD.getPrefTypeAlign(FPTy));
  Align Alignment = cast<ConstantPoolSDNode>(CPIdx)->getAlign();

  // Get offsets to the 0 and 1 elements of the array, so we can select between
  // them.
  SDValue Zero = DAG.getIntPtrConstant(0, DL);
  unsigned EltSize = (unsigned)TD.getTypeAllocSize(Elts[0]->getType());
  SDValue One = DAG.getIntPtrConstant(EltSize, SDLoc(FV));
  // Materialize the compare, pick the matching element offset, add it to the
  // constant-pool base, and load the chosen constant.
  SDValue Cond =
      DAG.getSetCC(DL, getSetCCResultType(N0.getValueType()), N0, N1, CC);
  AddToWorklist(Cond.getNode());
  SDValue CstOffset = DAG.getSelect(DL, Zero.getValueType(), Cond, One, Zero);
  AddToWorklist(CstOffset.getNode());
  CPIdx = DAG.getNode(ISD::ADD, DL, CPIdx.getValueType(), CPIdx, CstOffset);
  AddToWorklist(CPIdx.getNode());
  return DAG.getLoad(TV->getValueType(0), DL, DAG.getEntryNode(), CPIdx,
                     MachinePointerInfo::getConstantPool(
                         DAG.getMachineFunction()), Alignment);
}
24102 
/// Simplify an expression of the form (N0 cond N1) ? N2 : N3
/// where 'cond' is the comparison specified by CC.
SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
                                      SDValue N2, SDValue N3, ISD::CondCode CC,
                                      bool NotExtCompare) {
  // (x ? y : y) -> y.
  if (N2 == N3) return N2;

  EVT CmpOpVT = N0.getValueType();
  EVT CmpResVT = getSetCCResultType(CmpOpVT);
  EVT VT = N2.getValueType();
  auto *N1C = dyn_cast<ConstantSDNode>(N1.getNode());
  auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
  auto *N3C = dyn_cast<ConstantSDNode>(N3.getNode());

  // Determine if the condition we're dealing with is constant.
  if (SDValue SCC = DAG.FoldSetCC(CmpResVT, N0, N1, CC, DL)) {
    AddToWorklist(SCC.getNode());
    if (auto *SCCC = dyn_cast<ConstantSDNode>(SCC)) {
      // fold select_cc true, x, y -> x
      // fold select_cc false, x, y -> y
      return !(SCCC->isZero()) ? N2 : N3;
    }
  }

  if (SDValue V =
          convertSelectOfFPConstantsToLoadOffset(DL, N0, N1, N2, N3, CC))
    return V;

  if (SDValue V = foldSelectCCToShiftAnd(DL, N0, N1, N2, N3, CC))
    return V;

  // fold (select_cc seteq (and x, y), 0, 0, A) -> (and (shr (shl x)) A)
  // where y has a single bit set.
  // A plaintext description would be, we can turn the SELECT_CC into an AND
  // when the condition can be materialized as an all-ones register.  Any
  // single bit-test can be materialized as an all-ones register with
  // shift-left and shift-right-arith.
  if (CC == ISD::SETEQ && N0->getOpcode() == ISD::AND &&
      N0->getValueType(0) == VT && isNullConstant(N1) && isNullConstant(N2)) {
    SDValue AndLHS = N0->getOperand(0);
    auto *ConstAndRHS = dyn_cast<ConstantSDNode>(N0->getOperand(1));
    if (ConstAndRHS && ConstAndRHS->getAPIntValue().countPopulation() == 1) {
      // Shift the tested bit over the sign bit.
      const APInt &AndMask = ConstAndRHS->getAPIntValue();
      unsigned ShCt = AndMask.getBitWidth() - 1;
      if (!TLI.shouldAvoidTransformToShift(VT, ShCt)) {
        SDValue ShlAmt =
          DAG.getConstant(AndMask.countLeadingZeros(), SDLoc(AndLHS),
                          getShiftAmountTy(AndLHS.getValueType()));
        SDValue Shl = DAG.getNode(ISD::SHL, SDLoc(N0), VT, AndLHS, ShlAmt);

        // Now arithmetic right shift it all the way over, so the result is
        // either all-ones, or zero.
        SDValue ShrAmt =
          DAG.getConstant(ShCt, SDLoc(Shl),
                          getShiftAmountTy(Shl.getValueType()));
        SDValue Shr = DAG.getNode(ISD::SRA, SDLoc(N0), VT, Shl, ShrAmt);

        return DAG.getNode(ISD::AND, DL, VT, Shr, N3);
      }
    }
  }

  // fold select C, 16, 0 -> shl C, 4
  bool Fold = N2C && isNullConstant(N3) && N2C->getAPIntValue().isPowerOf2();
  bool Swap = N3C && isNullConstant(N2) && N3C->getAPIntValue().isPowerOf2();

  if ((Fold || Swap) &&
      TLI.getBooleanContents(CmpOpVT) ==
          TargetLowering::ZeroOrOneBooleanContent &&
      (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, CmpOpVT))) {

    // For the Swap form, invert the condition so the power-of-2 constant
    // moves into the "true" position; N2C names it from here on.
    if (Swap) {
      CC = ISD::getSetCCInverse(CC, CmpOpVT);
      std::swap(N2C, N3C);
    }

    // If the caller doesn't want us to simplify this into a zext of a compare,
    // don't do it.
    if (NotExtCompare && N2C->isOne())
      return SDValue();

    SDValue Temp, SCC;
    // zext (setcc n0, n1)
    if (LegalTypes) {
      // Use the target's setcc result type; widen or mask as required.
      SCC = DAG.getSetCC(DL, CmpResVT, N0, N1, CC);
      if (VT.bitsLT(SCC.getValueType()))
        Temp = DAG.getZeroExtendInReg(SCC, SDLoc(N2), VT);
      else
        Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2), VT, SCC);
    } else {
      // Before type legalization an i1 setcc is always fine.
      SCC = DAG.getSetCC(SDLoc(N0), MVT::i1, N0, N1, CC);
      Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2), VT, SCC);
    }

    AddToWorklist(SCC.getNode());
    AddToWorklist(Temp.getNode());

    if (N2C->isOne())
      return Temp;

    unsigned ShCt = N2C->getAPIntValue().logBase2();
    if (TLI.shouldAvoidTransformToShift(VT, ShCt))
      return SDValue();

    // shl setcc result by log2 n2c
    return DAG.getNode(ISD::SHL, DL, N2.getValueType(), Temp,
                       DAG.getConstant(ShCt, SDLoc(Temp),
                                       getShiftAmountTy(Temp.getValueType())));
  }

  // select_cc seteq X, 0, sizeof(X), ctlz(X) -> ctlz(X)
  // select_cc seteq X, 0, sizeof(X), ctlz_zero_undef(X) -> ctlz(X)
  // select_cc seteq X, 0, sizeof(X), cttz(X) -> cttz(X)
  // select_cc seteq X, 0, sizeof(X), cttz_zero_undef(X) -> cttz(X)
  // select_cc setne X, 0, ctlz(X), sizeof(X) -> ctlz(X)
  // select_cc setne X, 0, ctlz_zero_undef(X), sizeof(X) -> ctlz(X)
  // select_cc setne X, 0, cttz(X), sizeof(X) -> cttz(X)
  // select_cc setne X, 0, cttz_zero_undef(X), sizeof(X) -> cttz(X)
  if (N1C && N1C->isZero() && (CC == ISD::SETEQ || CC == ISD::SETNE)) {
    SDValue ValueOnZero = N2;
    SDValue Count = N3;
    // If the condition is NE instead of E, swap the operands.
    if (CC == ISD::SETNE)
      std::swap(ValueOnZero, Count);
    // Check if the value on zero is a constant equal to the bits in the type.
    if (auto *ValueOnZeroC = dyn_cast<ConstantSDNode>(ValueOnZero)) {
      if (ValueOnZeroC->getAPIntValue() == VT.getSizeInBits()) {
        // If the other operand is cttz/cttz_zero_undef of N0, and cttz is
        // legal, combine to just cttz.
        if ((Count.getOpcode() == ISD::CTTZ ||
             Count.getOpcode() == ISD::CTTZ_ZERO_UNDEF) &&
            N0 == Count.getOperand(0) &&
            (!LegalOperations || TLI.isOperationLegal(ISD::CTTZ, VT)))
          return DAG.getNode(ISD::CTTZ, DL, VT, N0);
        // If the other operand is ctlz/ctlz_zero_undef of N0, and ctlz is
        // legal, combine to just ctlz.
        if ((Count.getOpcode() == ISD::CTLZ ||
             Count.getOpcode() == ISD::CTLZ_ZERO_UNDEF) &&
            N0 == Count.getOperand(0) &&
            (!LegalOperations || TLI.isOperationLegal(ISD::CTLZ, VT)))
          return DAG.getNode(ISD::CTLZ, DL, VT, N0);
      }
    }
  }

  // Fold select_cc setgt X, -1, C, ~C -> xor (ashr X, BW-1), C
  // Fold select_cc setlt X, 0, C, ~C -> xor (ashr X, BW-1), ~C
  // The sign-bit smear (ashr X, BW-1) is all-ones exactly when the "false"
  // arm would be taken, so xor-ing it with one constant yields the other.
  if (!NotExtCompare && N1C && N2C && N3C &&
      N2C->getAPIntValue() == ~N3C->getAPIntValue() &&
      ((N1C->isAllOnes() && CC == ISD::SETGT) ||
       (N1C->isZero() && CC == ISD::SETLT)) &&
      !TLI.shouldAvoidTransformToShift(VT, CmpOpVT.getScalarSizeInBits() - 1)) {
    SDValue ASR = DAG.getNode(
        ISD::SRA, DL, CmpOpVT, N0,
        DAG.getConstant(CmpOpVT.getScalarSizeInBits() - 1, DL, CmpOpVT));
    return DAG.getNode(ISD::XOR, DL, VT, DAG.getSExtOrTrunc(ASR, DL, VT),
                       DAG.getSExtOrTrunc(CC == ISD::SETLT ? N3 : N2, DL, VT));
  }

  if (SDValue S = PerformMinMaxFpToSatCombine(N0, N1, N2, N3, CC, DAG))
    return S;
  if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N2, N3, CC, DAG))
    return S;

  return SDValue();
}
24271 
24272 /// This is a stub for TargetLowering::SimplifySetCC.
SimplifySetCC(EVT VT,SDValue N0,SDValue N1,ISD::CondCode Cond,const SDLoc & DL,bool foldBooleans)24273 SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
24274                                    ISD::CondCode Cond, const SDLoc &DL,
24275                                    bool foldBooleans) {
24276   TargetLowering::DAGCombinerInfo
24277     DagCombineInfo(DAG, Level, false, this);
24278   return TLI.SimplifySetCC(VT, N0, N1, Cond, foldBooleans, DagCombineInfo, DL);
24279 }
24280 
24281 /// Given an ISD::SDIV node expressing a divide by constant, return
24282 /// a DAG expression to select that will generate the same value by multiplying
24283 /// by a magic number.
24284 /// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
BuildSDIV(SDNode * N)24285 SDValue DAGCombiner::BuildSDIV(SDNode *N) {
24286   // when optimising for minimum size, we don't want to expand a div to a mul
24287   // and a shift.
24288   if (DAG.getMachineFunction().getFunction().hasMinSize())
24289     return SDValue();
24290 
24291   SmallVector<SDNode *, 8> Built;
24292   if (SDValue S = TLI.BuildSDIV(N, DAG, LegalOperations, Built)) {
24293     for (SDNode *N : Built)
24294       AddToWorklist(N);
24295     return S;
24296   }
24297 
24298   return SDValue();
24299 }
24300 
24301 /// Given an ISD::SDIV node expressing a divide by constant power of 2, return a
24302 /// DAG expression that will generate the same value by right shifting.
BuildSDIVPow2(SDNode * N)24303 SDValue DAGCombiner::BuildSDIVPow2(SDNode *N) {
24304   ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1));
24305   if (!C)
24306     return SDValue();
24307 
24308   // Avoid division by zero.
24309   if (C->isZero())
24310     return SDValue();
24311 
24312   SmallVector<SDNode *, 8> Built;
24313   if (SDValue S = TLI.BuildSDIVPow2(N, C->getAPIntValue(), DAG, Built)) {
24314     for (SDNode *N : Built)
24315       AddToWorklist(N);
24316     return S;
24317   }
24318 
24319   return SDValue();
24320 }
24321 
24322 /// Given an ISD::UDIV node expressing a divide by constant, return a DAG
24323 /// expression that will generate the same value by multiplying by a magic
24324 /// number.
24325 /// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
BuildUDIV(SDNode * N)24326 SDValue DAGCombiner::BuildUDIV(SDNode *N) {
24327   // when optimising for minimum size, we don't want to expand a div to a mul
24328   // and a shift.
24329   if (DAG.getMachineFunction().getFunction().hasMinSize())
24330     return SDValue();
24331 
24332   SmallVector<SDNode *, 8> Built;
24333   if (SDValue S = TLI.BuildUDIV(N, DAG, LegalOperations, Built)) {
24334     for (SDNode *N : Built)
24335       AddToWorklist(N);
24336     return S;
24337   }
24338 
24339   return SDValue();
24340 }
24341 
24342 /// Given an ISD::SREM node expressing a remainder by constant power of 2,
24343 /// return a DAG expression that will generate the same value.
BuildSREMPow2(SDNode * N)24344 SDValue DAGCombiner::BuildSREMPow2(SDNode *N) {
24345   ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1));
24346   if (!C)
24347     return SDValue();
24348 
24349   // Avoid division by zero.
24350   if (C->isZero())
24351     return SDValue();
24352 
24353   SmallVector<SDNode *, 8> Built;
24354   if (SDValue S = TLI.BuildSREMPow2(N, C->getAPIntValue(), DAG, Built)) {
24355     for (SDNode *N : Built)
24356       AddToWorklist(N);
24357     return S;
24358   }
24359 
24360   return SDValue();
24361 }
24362 
24363 /// Determines the LogBase2 value for a non-null input value using the
24364 /// transform: LogBase2(V) = (EltBits - 1) - ctlz(V).
BuildLogBase2(SDValue V,const SDLoc & DL)24365 SDValue DAGCombiner::BuildLogBase2(SDValue V, const SDLoc &DL) {
24366   EVT VT = V.getValueType();
24367   SDValue Ctlz = DAG.getNode(ISD::CTLZ, DL, VT, V);
24368   SDValue Base = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
24369   SDValue LogBase2 = DAG.getNode(ISD::SUB, DL, VT, Base, Ctlz);
24370   return LogBase2;
24371 }
24372 
/// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
/// For the reciprocal, we need to find the zero of the function:
///   F(X) = 1/X - A [which has a zero at X = 1/A]
///     =>
///   X_{i+1} = X_i (2 - A X_i) = X_i + X_i (1 - A X_i) [this second form
///     does not require additional intermediate precision]
/// For the last iteration, put numerator N into it to gain more precision:
///   Result = N X_i + X_i (N - N A X_i)
SDValue DAGCombiner::BuildDivEstimate(SDValue N, SDValue Op,
                                      SDNodeFlags Flags) {
  // Once the whole DAG is legalized we may no longer introduce these nodes.
  if (LegalDAG)
    return SDValue();

  // TODO: Handle extended types?
  EVT VT = Op.getValueType();
  if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
      VT.getScalarType() != MVT::f64)
    return SDValue();

  // If estimates are explicitly disabled for this function, we're done.
  MachineFunction &MF = DAG.getMachineFunction();
  int Enabled = TLI.getRecipEstimateDivEnabled(VT, MF);
  if (Enabled == TLI.ReciprocalEstimate::Disabled)
    return SDValue();

  // Estimates may be explicitly enabled for this type with a custom number of
  // refinement steps.
  int Iterations = TLI.getDivRefinementSteps(VT, MF);
  if (SDValue Est = TLI.getRecipEstimate(Op, DAG, Enabled, Iterations)) {
    AddToWorklist(Est.getNode());

    SDLoc DL(Op);
    if (Iterations) {
      SDValue FPOne = DAG.getConstantFP(1.0, DL, VT);

      // Newton iterations: Est = Est + Est (N - Arg * Est)
      // If this is the last iteration, also multiply by the numerator.
      for (int i = 0; i < Iterations; ++i) {
        SDValue MulEst = Est;

        // On the last step, fold the numerator in up front: MulEst = N * Est.
        if (i == Iterations - 1) {
          MulEst = DAG.getNode(ISD::FMUL, DL, VT, N, Est, Flags);
          AddToWorklist(MulEst.getNode());
        }

        // Arg * Est (or Arg * N * Est on the last step).
        SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Op, MulEst, Flags);
        AddToWorklist(NewEst.getNode());

        // Residual: (1 - Arg*Est), or (N - N*Arg*Est) on the last step.
        NewEst = DAG.getNode(ISD::FSUB, DL, VT,
                             (i == Iterations - 1 ? N : FPOne), NewEst, Flags);
        AddToWorklist(NewEst.getNode());

        NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags);
        AddToWorklist(NewEst.getNode());

        Est = DAG.getNode(ISD::FADD, DL, VT, MulEst, NewEst, Flags);
        AddToWorklist(Est.getNode());
      }
    } else {
      // If no iterations are available, multiply with N.
      Est = DAG.getNode(ISD::FMUL, DL, VT, Est, N, Flags);
      AddToWorklist(Est.getNode());
    }

    return Est;
  }

  return SDValue();
}
24442 
24443 /// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
24444 /// For the reciprocal sqrt, we need to find the zero of the function:
24445 ///   F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
24446 ///     =>
24447 ///   X_{i+1} = X_i (1.5 - A X_i^2 / 2)
24448 /// As a result, we precompute A/2 prior to the iteration loop.
buildSqrtNROneConst(SDValue Arg,SDValue Est,unsigned Iterations,SDNodeFlags Flags,bool Reciprocal)24449 SDValue DAGCombiner::buildSqrtNROneConst(SDValue Arg, SDValue Est,
24450                                          unsigned Iterations,
24451                                          SDNodeFlags Flags, bool Reciprocal) {
24452   EVT VT = Arg.getValueType();
24453   SDLoc DL(Arg);
24454   SDValue ThreeHalves = DAG.getConstantFP(1.5, DL, VT);
24455 
24456   // We now need 0.5 * Arg which we can write as (1.5 * Arg - Arg) so that
24457   // this entire sequence requires only one FP constant.
24458   SDValue HalfArg = DAG.getNode(ISD::FMUL, DL, VT, ThreeHalves, Arg, Flags);
24459   HalfArg = DAG.getNode(ISD::FSUB, DL, VT, HalfArg, Arg, Flags);
24460 
24461   // Newton iterations: Est = Est * (1.5 - HalfArg * Est * Est)
24462   for (unsigned i = 0; i < Iterations; ++i) {
24463     SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, Est, Flags);
24464     NewEst = DAG.getNode(ISD::FMUL, DL, VT, HalfArg, NewEst, Flags);
24465     NewEst = DAG.getNode(ISD::FSUB, DL, VT, ThreeHalves, NewEst, Flags);
24466     Est = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags);
24467   }
24468 
24469   // If non-reciprocal square root is requested, multiply the result by Arg.
24470   if (!Reciprocal)
24471     Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Arg, Flags);
24472 
24473   return Est;
24474 }
24475 
24476 /// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
24477 /// For the reciprocal sqrt, we need to find the zero of the function:
24478 ///   F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
24479 ///     =>
24480 ///   X_{i+1} = (-0.5 * X_i) * (A * X_i * X_i + (-3.0))
buildSqrtNRTwoConst(SDValue Arg,SDValue Est,unsigned Iterations,SDNodeFlags Flags,bool Reciprocal)24481 SDValue DAGCombiner::buildSqrtNRTwoConst(SDValue Arg, SDValue Est,
24482                                          unsigned Iterations,
24483                                          SDNodeFlags Flags, bool Reciprocal) {
24484   EVT VT = Arg.getValueType();
24485   SDLoc DL(Arg);
24486   SDValue MinusThree = DAG.getConstantFP(-3.0, DL, VT);
24487   SDValue MinusHalf = DAG.getConstantFP(-0.5, DL, VT);
24488 
24489   // This routine must enter the loop below to work correctly
24490   // when (Reciprocal == false).
24491   assert(Iterations > 0);
24492 
24493   // Newton iterations for reciprocal square root:
24494   // E = (E * -0.5) * ((A * E) * E + -3.0)
24495   for (unsigned i = 0; i < Iterations; ++i) {
24496     SDValue AE = DAG.getNode(ISD::FMUL, DL, VT, Arg, Est, Flags);
24497     SDValue AEE = DAG.getNode(ISD::FMUL, DL, VT, AE, Est, Flags);
24498     SDValue RHS = DAG.getNode(ISD::FADD, DL, VT, AEE, MinusThree, Flags);
24499 
24500     // When calculating a square root at the last iteration build:
24501     // S = ((A * E) * -0.5) * ((A * E) * E + -3.0)
24502     // (notice a common subexpression)
24503     SDValue LHS;
24504     if (Reciprocal || (i + 1) < Iterations) {
24505       // RSQRT: LHS = (E * -0.5)
24506       LHS = DAG.getNode(ISD::FMUL, DL, VT, Est, MinusHalf, Flags);
24507     } else {
24508       // SQRT: LHS = (A * E) * -0.5
24509       LHS = DAG.getNode(ISD::FMUL, DL, VT, AE, MinusHalf, Flags);
24510     }
24511 
24512     Est = DAG.getNode(ISD::FMUL, DL, VT, LHS, RHS, Flags);
24513   }
24514 
24515   return Est;
24516 }
24517 
/// Build code to calculate either rsqrt(Op) or sqrt(Op). In the latter case
/// Op*rsqrt(Op) is actually computed, so additional postprocessing is needed if
/// Op can be zero.
SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
                                           bool Reciprocal) {
  // Once the whole DAG is legalized we may no longer introduce new nodes here.
  if (LegalDAG)
    return SDValue();

  // TODO: Handle extended types?
  EVT VT = Op.getValueType();
  if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
      VT.getScalarType() != MVT::f64)
    return SDValue();

  // If estimates are explicitly disabled for this function, we're done.
  MachineFunction &MF = DAG.getMachineFunction();
  int Enabled = TLI.getRecipEstimateSqrtEnabled(VT, MF);
  if (Enabled == TLI.ReciprocalEstimate::Disabled)
    return SDValue();

  // Estimates may be explicitly enabled for this type with a custom number of
  // refinement steps.
  int Iterations = TLI.getSqrtRefinementSteps(VT, MF);

  bool UseOneConstNR = false;
  if (SDValue Est =
      TLI.getSqrtEstimate(Op, DAG, Enabled, Iterations, UseOneConstNR,
                          Reciprocal)) {
    AddToWorklist(Est.getNode());

    // Refine the initial estimate with Newton-Raphson steps; the target chose
    // between the one-constant and two-constant formulation via UseOneConstNR.
    if (Iterations)
      Est = UseOneConstNR
            ? buildSqrtNROneConst(Op, Est, Iterations, Flags, Reciprocal)
            : buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal);
    if (!Reciprocal) {
      SDLoc DL(Op);
      // Try the target specific test first.
      SDValue Test = TLI.getSqrtInputTest(Op, DAG, DAG.getDenormalMode(VT));

      // The estimate is now completely wrong if the input was exactly 0.0 or
      // possibly a denormal. Force the answer to 0.0 or value provided by
      // target for those cases.
      Est = DAG.getNode(
          Test.getValueType().isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT,
          Test, TLI.getSqrtResultForDenormInput(Op, DAG), Est);
    }
    return Est;
  }

  return SDValue();
}
24569 
/// Build an estimate of 1/sqrt(Op) (reciprocal square root); thin wrapper
/// around buildSqrtEstimateImpl with Reciprocal = true.
SDValue DAGCombiner::buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags) {
  return buildSqrtEstimateImpl(Op, Flags, true);
}
24573 
/// Build an estimate of sqrt(Op); thin wrapper around buildSqrtEstimateImpl
/// with Reciprocal = false (the impl computes Op*rsqrt(Op), with a zero/denorm
/// fixup select — see buildSqrtEstimateImpl).
SDValue DAGCombiner::buildSqrtEstimate(SDValue Op, SDNodeFlags Flags) {
  return buildSqrtEstimateImpl(Op, Flags, false);
}
24577 
/// Return true if there is any possibility that the two addresses overlap.
bool DAGCombiner::mayAlias(SDNode *Op0, SDNode *Op1) const {

  // A uniform summary of one memory access so both operands can be compared
  // whether they are loads/stores or lifetime markers.
  struct MemUseCharacteristics {
    bool IsVolatile;
    bool IsAtomic;
    SDValue BasePtr;
    int64_t Offset;             // Byte adjustment applied before the access.
    Optional<int64_t> NumBytes; // Access size in bytes; None if unknown.
    MachineMemOperand *MMO;     // May be null (e.g. lifetime nodes).
  };

  auto getCharacteristics = [](SDNode *N) -> MemUseCharacteristics {
    if (const auto *LSN = dyn_cast<LSBaseSDNode>(N)) {
      int64_t Offset = 0;
      // Only pre-inc/pre-dec addressing changes the address actually
      // accessed; post-indexed forms update the pointer after the access, so
      // their effective offset here is 0.
      if (auto *C = dyn_cast<ConstantSDNode>(LSN->getOffset()))
        Offset = (LSN->getAddressingMode() == ISD::PRE_INC)
                     ? C->getSExtValue()
                     : (LSN->getAddressingMode() == ISD::PRE_DEC)
                           ? -1 * C->getSExtValue()
                           : 0;
      uint64_t Size =
          MemoryLocation::getSizeOrUnknown(LSN->getMemoryVT().getStoreSize());
      return {LSN->isVolatile(), LSN->isAtomic(), LSN->getBasePtr(),
              Offset /*base offset*/,
              Optional<int64_t>(Size),
              LSN->getMemOperand()};
    }
    // NOTE(review): cast<> asserts and never returns null, so this condition
    // is always true when reached — presumably only load/store and lifetime
    // nodes are ever passed in; confirm against call sites.
    if (const auto *LN = cast<LifetimeSDNode>(N))
      return {false /*isVolatile*/, /*isAtomic*/ false, LN->getOperand(1),
              (LN->hasOffset()) ? LN->getOffset() : 0,
              (LN->hasOffset()) ? Optional<int64_t>(LN->getSize())
                                : Optional<int64_t>(),
              (MachineMemOperand *)nullptr};
    // Default.
    return {false /*isvolatile*/, /*isAtomic*/ false, SDValue(),
            (int64_t)0 /*offset*/,
            Optional<int64_t>() /*size*/, (MachineMemOperand *)nullptr};
  };

  MemUseCharacteristics MUC0 = getCharacteristics(Op0),
                        MUC1 = getCharacteristics(Op1);

  // If they are to the same address, then they must be aliases.
  if (MUC0.BasePtr.getNode() && MUC0.BasePtr == MUC1.BasePtr &&
      MUC0.Offset == MUC1.Offset)
    return true;

  // If they are both volatile then they cannot be reordered.
  if (MUC0.IsVolatile && MUC1.IsVolatile)
    return true;

  // Be conservative about atomics for the moment
  // TODO: This is way overconservative for unordered atomics (see D66309)
  if (MUC0.IsAtomic && MUC1.IsAtomic)
    return true;

  // A read from invariant memory cannot conflict with any store (this same
  // check is repeated below once both MMOs are known valid).
  if (MUC0.MMO && MUC1.MMO) {
    if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
        (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
      return false;
  }

  // Try to prove that there is aliasing, or that there is no aliasing. Either
  // way, we can return now. If nothing can be proved, proceed with more tests.
  bool IsAlias;
  if (BaseIndexOffset::computeAliasing(Op0, MUC0.NumBytes, Op1, MUC1.NumBytes,
                                       DAG, IsAlias))
    return IsAlias;

  // The following all rely on MMO0 and MMO1 being valid. Fail conservatively if
  // either are not known.
  if (!MUC0.MMO || !MUC1.MMO)
    return true;

  // If one operation reads from invariant memory, and the other may store, they
  // cannot alias. These should really be checking the equivalent of mayWrite,
  // but it only matters for memory nodes other than load /store.
  if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
      (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
    return false;

  // If we know required SrcValue1 and SrcValue2 have relatively large
  // alignment compared to the size and offset of the access, we may be able
  // to prove they do not alias. This check is conservative for now to catch
  // cases created by splitting vector types, it only works when the offsets are
  // multiples of the size of the data.
  int64_t SrcValOffset0 = MUC0.MMO->getOffset();
  int64_t SrcValOffset1 = MUC1.MMO->getOffset();
  Align OrigAlignment0 = MUC0.MMO->getBaseAlign();
  Align OrigAlignment1 = MUC1.MMO->getBaseAlign();
  auto &Size0 = MUC0.NumBytes;
  auto &Size1 = MUC1.NumBytes;
  if (OrigAlignment0 == OrigAlignment1 && SrcValOffset0 != SrcValOffset1 &&
      Size0.has_value() && Size1.has_value() && *Size0 == *Size1 &&
      OrigAlignment0 > *Size0 && SrcValOffset0 % *Size0 == 0 &&
      SrcValOffset1 % *Size1 == 0) {
    // Offsets of each access within its (shared) alignment period.
    int64_t OffAlign0 = SrcValOffset0 % OrigAlignment0.value();
    int64_t OffAlign1 = SrcValOffset1 % OrigAlignment1.value();

    // There is no overlap between these relatively aligned accesses of
    // similar size. Return no alias.
    if ((OffAlign0 + *Size0) <= OffAlign1 || (OffAlign1 + *Size1) <= OffAlign0)
      return false;
  }

  // Decide whether to consult full alias analysis: explicit flag wins,
  // otherwise defer to the subtarget's preference.
  bool UseAA = CombinerGlobalAA.getNumOccurrences() > 0
                   ? CombinerGlobalAA
                   : DAG.getSubtarget().useAA();
#ifndef NDEBUG
  if (CombinerAAOnlyFunc.getNumOccurrences() &&
      CombinerAAOnlyFunc != DAG.getMachineFunction().getName())
    UseAA = false;
#endif

  if (UseAA && AA && MUC0.MMO->getValue() && MUC1.MMO->getValue() && Size0 &&
      Size1) {
    // Use alias analysis information.
    int64_t MinOffset = std::min(SrcValOffset0, SrcValOffset1);
    int64_t Overlap0 = *Size0 + SrcValOffset0 - MinOffset;
    int64_t Overlap1 = *Size1 + SrcValOffset1 - MinOffset;
    if (AA->isNoAlias(
            MemoryLocation(MUC0.MMO->getValue(), Overlap0,
                           UseTBAA ? MUC0.MMO->getAAInfo() : AAMDNodes()),
            MemoryLocation(MUC1.MMO->getValue(), Overlap1,
                           UseTBAA ? MUC1.MMO->getAAInfo() : AAMDNodes())))
      return false;
  }

  // Otherwise we have to assume they alias.
  return true;
}
24710 
/// Walk up chain skipping non-aliasing memory nodes,
/// looking for aliasing nodes and adding them to the Aliases vector.
void DAGCombiner::GatherAllAliases(SDNode *N, SDValue OriginalChain,
                                   SmallVectorImpl<SDValue> &Aliases) {
  SmallVector<SDValue, 8> Chains;     // List of chains to visit.
  SmallPtrSet<SDNode *, 16> Visited;  // Visited node set.

  // Get alias information for node.
  // TODO: relax aliasing for unordered atomics (see D66309)
  const bool IsLoad = isa<LoadSDNode>(N) && cast<LoadSDNode>(N)->isSimple();

  // Starting off.
  Chains.push_back(OriginalChain);
  unsigned Depth = 0;

  // Attempt to improve chain by a single step: returns true (and updates C in
  // place) if the walk may continue past C, false if C must be treated as an
  // alias of N.
  auto ImproveChain = [&](SDValue &C) -> bool {
    switch (C.getOpcode()) {
    case ISD::EntryToken:
      // No need to mark EntryToken.
      C = SDValue();
      return true;
    case ISD::LOAD:
    case ISD::STORE: {
      // Get alias information for C.
      // TODO: Relax aliasing for unordered atomics (see D66309)
      bool IsOpLoad = isa<LoadSDNode>(C.getNode()) &&
                      cast<LSBaseSDNode>(C.getNode())->isSimple();
      // Two simple loads never conflict; otherwise skip C only if it is
      // provably non-aliasing with N.
      if ((IsLoad && IsOpLoad) || !mayAlias(N, C.getNode())) {
        // Look further up the chain.
        C = C.getOperand(0);
        return true;
      }
      // Alias, so stop here.
      return false;
    }

    case ISD::CopyFromReg:
      // Always forward past CopyFromReg.
      C = C.getOperand(0);
      return true;

    case ISD::LIFETIME_START:
    case ISD::LIFETIME_END: {
      // We can forward past any lifetime start/end that can be proven not to
      // alias the memory access.
      if (!mayAlias(N, C.getNode())) {
        // Look further up the chain.
        C = C.getOperand(0);
        return true;
      }
      return false;
    }
    default:
      // Unknown chain producers are conservatively treated as aliases.
      return false;
    }
  };

  // Look at each chain and determine if it is an alias.  If so, add it to the
  // aliases list.  If not, then continue up the chain looking for the next
  // candidate.
  while (!Chains.empty()) {
    SDValue Chain = Chains.pop_back_val();

    // Don't bother if we've seen Chain before.
    if (!Visited.insert(Chain.getNode()).second)
      continue;

    // For TokenFactor nodes, look at each operand and only continue up the
    // chain until we reach the depth limit.
    //
    // FIXME: The depth check could be made to return the last non-aliasing
    // chain we found before we hit a tokenfactor rather than the original
    // chain.
    if (Depth > TLI.getGatherAllAliasesMaxDepth()) {
      // Give up: report the original chain as the single, conservative alias.
      Aliases.clear();
      Aliases.push_back(OriginalChain);
      return;
    }

    if (Chain.getOpcode() == ISD::TokenFactor) {
      // We have to check each of the operands of the token factor for "small"
      // token factors, so we queue them up.  Adding the operands to the queue
      // (stack) in reverse order maintains the original order and increases the
      // likelihood that getNode will find a matching token factor (CSE.)
      if (Chain.getNumOperands() > 16) {
        Aliases.push_back(Chain);
        continue;
      }
      for (unsigned n = Chain.getNumOperands(); n;)
        Chains.push_back(Chain.getOperand(--n));
      ++Depth;
      continue;
    }
    // Everything else
    if (ImproveChain(Chain)) {
      // Updated Chain Found, Consider new chain if one exists.
      if (Chain.getNode())
        Chains.push_back(Chain);
      ++Depth;
      continue;
    }
    // No Improved Chain Possible, treat as Alias.
    Aliases.push_back(Chain);
  }
}
24817 
24818 /// Walk up chain skipping non-aliasing memory nodes, looking for a better chain
24819 /// (aliasing node.)
FindBetterChain(SDNode * N,SDValue OldChain)24820 SDValue DAGCombiner::FindBetterChain(SDNode *N, SDValue OldChain) {
24821   if (OptLevel == CodeGenOpt::None)
24822     return OldChain;
24823 
24824   // Ops for replacing token factor.
24825   SmallVector<SDValue, 8> Aliases;
24826 
24827   // Accumulate all the aliases to this node.
24828   GatherAllAliases(N, OldChain, Aliases);
24829 
24830   // If no operands then chain to entry token.
24831   if (Aliases.size() == 0)
24832     return DAG.getEntryNode();
24833 
24834   // If a single operand then chain to it.  We don't need to revisit it.
24835   if (Aliases.size() == 1)
24836     return Aliases[0];
24837 
24838   // Construct a custom tailored token factor.
24839   return DAG.getTokenFactor(SDLoc(N), Aliases);
24840 }
24841 
namespace {
// Empty payload type for the IntervalMap below — we only care about interval
// coverage, not mapped values.
// TODO: Replace with std::monostate when we move to C++17.
struct UnitT { } Unit;
bool operator==(const UnitT &, const UnitT &) { return true; }
bool operator!=(const UnitT &, const UnitT &) { return false; }
} // namespace
24848 
// This function tries to collect a bunch of potentially interesting
// nodes to improve the chains of, all at once. This might seem
// redundant, as this function gets called when visiting every store
// node, so why not let the work be done on each store as it's visited?
//
// I believe this is mainly important because mergeConsecutiveStores
// is unable to deal with merging stores of different sizes, so unless
// we improve the chains of all the potential candidates up-front
// before running mergeConsecutiveStores, it might only see some of
// the nodes that will eventually be candidates, and then not be able
// to go from a partially-merged state to the desired final
// fully-merged state.

bool DAGCombiner::parallelizeChainedStores(StoreSDNode *St) {
  SmallVector<StoreSDNode *, 8> ChainedStores;
  StoreSDNode *STChain = St;
  // Intervals records which offsets from BaseIndex have been covered. In
  // the common case, every store writes to the immediately preceding address
  // and is thus merged with the previous interval at insertion time.

  using IMap =
      llvm::IntervalMap<int64_t, UnitT, 8, IntervalMapHalfOpenInfo<int64_t>>;
  IMap::Allocator A;
  IMap Intervals(A);

  // This holds the base pointer, index, and the offset in bytes from the base
  // pointer.
  const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);

  // We must have a base and an offset.
  if (!BasePtr.getBase().getNode())
    return false;

  // Do not handle stores to undef base pointers.
  if (BasePtr.getBase().isUndef())
    return false;

  // Do not handle stores to opaque types
  if (St->getMemoryVT().isZeroSized())
    return false;

  // BaseIndexOffset assumes that offsets are fixed-size, which
  // is not valid for scalable vectors where the offsets are
  // scaled by `vscale`, so bail out early.
  if (St->getMemoryVT().isScalableVector())
    return false;

  // Add ST's interval.
  Intervals.insert(0, (St->getMemoryVT().getSizeInBits() + 7) / 8, Unit);

  // Walk up St's chain collecting earlier stores whose byte ranges are
  // provably disjoint from everything collected so far.
  while (StoreSDNode *Chain = dyn_cast<StoreSDNode>(STChain->getChain())) {
    if (Chain->getMemoryVT().isScalableVector())
      return false;

    // If the chain has more than one use, then we can't reorder the mem ops.
    if (!SDValue(Chain, 0)->hasOneUse())
      break;
    // TODO: Relax for unordered atomics (see D66309)
    if (!Chain->isSimple() || Chain->isIndexed())
      break;

    // Find the base pointer and offset for this memory node.
    const BaseIndexOffset Ptr = BaseIndexOffset::match(Chain, DAG);
    // Check that the base pointer is the same as the original one.
    int64_t Offset;
    if (!BasePtr.equalBaseIndex(Ptr, DAG, Offset))
      break;
    int64_t Length = (Chain->getMemoryVT().getSizeInBits() + 7) / 8;
    // Make sure we don't overlap with other intervals by checking the ones to
    // the left or right before inserting.
    auto I = Intervals.find(Offset);
    // If there's a next interval, we should end before it.
    if (I != Intervals.end() && I.start() < (Offset + Length))
      break;
    // If there's a previous interval, we should start after it.
    // NOTE(review): with a half-open IntervalMap, the interval preceding
    // find(Offset) appears to always satisfy stop() <= Offset, which would
    // make this bail whenever any interval lies below Offset — conservative
    // (loses merging opportunities) rather than incorrect; confirm intended.
    if (I != Intervals.begin() && (--I).stop() <= Offset)
      break;
    Intervals.insert(Offset, Offset + Length, Unit);

    ChainedStores.push_back(Chain);
    STChain = Chain;
  }

  // If we didn't find a chained store, exit.
  if (ChainedStores.size() == 0)
    return false;

  // Improve all chained stores (St and ChainedStores members) starting from
  // where the store chain ended and return single TokenFactor.
  SDValue NewChain = STChain->getChain();
  SmallVector<SDValue, 8> TFOps;
  for (unsigned I = ChainedStores.size(); I;) {
    StoreSDNode *S = ChainedStores[--I];
    SDValue BetterChain = FindBetterChain(S, NewChain);
    // Rewire S onto its improved chain; UpdateNodeOperands may CSE to an
    // existing node, so keep the (possibly new) pointer.
    S = cast<StoreSDNode>(DAG.UpdateNodeOperands(
        S, BetterChain, S->getOperand(1), S->getOperand(2), S->getOperand(3)));
    TFOps.push_back(SDValue(S, 0));
    ChainedStores[I] = S;
  }

  // Improve St's chain. Use a new node to avoid creating a loop from CombineTo.
  SDValue BetterChain = FindBetterChain(St, NewChain);
  SDValue NewST;
  if (St->isTruncatingStore())
    NewST = DAG.getTruncStore(BetterChain, SDLoc(St), St->getValue(),
                              St->getBasePtr(), St->getMemoryVT(),
                              St->getMemOperand());
  else
    NewST = DAG.getStore(BetterChain, SDLoc(St), St->getValue(),
                         St->getBasePtr(), St->getMemOperand());

  TFOps.push_back(NewST);

  // If we improved every element of TFOps, then we've lost the dependence on
  // NewChain to successors of St and we need to add it back to TFOps. Do so at
  // the beginning to keep relative order consistent with FindBetterChains.
  auto hasImprovedChain = [&](SDValue ST) -> bool {
    return ST->getOperand(0) != NewChain;
  };
  bool AddNewChain = llvm::all_of(TFOps, hasImprovedChain);
  if (AddNewChain)
    TFOps.insert(TFOps.begin(), NewChain);

  SDValue TF = DAG.getTokenFactor(SDLoc(STChain), TFOps);
  CombineTo(St, TF);

  // Add TF and its operands to the worklist.
  AddToWorklist(TF.getNode());
  for (const SDValue &Op : TF->ops())
    AddToWorklist(Op.getNode());
  AddToWorklist(STChain);
  return true;
}
24982 
findBetterNeighborChains(StoreSDNode * St)24983 bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) {
24984   if (OptLevel == CodeGenOpt::None)
24985     return false;
24986 
24987   const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
24988 
24989   // We must have a base and an offset.
24990   if (!BasePtr.getBase().getNode())
24991     return false;
24992 
24993   // Do not handle stores to undef base pointers.
24994   if (BasePtr.getBase().isUndef())
24995     return false;
24996 
24997   // Directly improve a chain of disjoint stores starting at St.
24998   if (parallelizeChainedStores(St))
24999     return true;
25000 
25001   // Improve St's Chain..
25002   SDValue BetterChain = FindBetterChain(St, St->getChain());
25003   if (St->getChain() != BetterChain) {
25004     replaceStoreChain(St, BetterChain);
25005     return true;
25006   }
25007   return false;
25008 }
25009 
/// This is the entry point for the file: construct a DAGCombiner over this
/// SelectionDAG and run it at the given combine level (before or after
/// legalization).
void SelectionDAG::Combine(CombineLevel Level, AliasAnalysis *AA,
                           CodeGenOpt::Level OptLevel) {
  DAGCombiner(*this, AA, OptLevel).Run(Level);
}
25016