//===- DAGCombiner.cpp - Implement a DAG node combiner --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass combines dag nodes to form fewer, simpler DAG nodes. It can be run
// both before and after the DAG is legalized.
//
// This pass is not a substitute for the LLVM IR instcombine pass. This pass is
// primarily intended to handle simplification opportunities that are implicit
// in the LLVM IR and exposed by the various codegen lowering phases.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/DAGCombine.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineMemOperand.h"
#include "llvm/CodeGen/RuntimeLibcalls.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/SelectionDAGTargetInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/MachineValueType.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <iterator>
#include <string>
#include <tuple>
#include <utility>

76 using namespace llvm;
77
78 #define DEBUG_TYPE "dagcombine"
79
80 STATISTIC(NodesCombined , "Number of dag nodes combined");
81 STATISTIC(PreIndexedNodes , "Number of pre-indexed nodes created");
82 STATISTIC(PostIndexedNodes, "Number of post-indexed nodes created");
83 STATISTIC(OpsNarrowed , "Number of load/op/store narrowed");
84 STATISTIC(LdStFP2Int , "Number of fp load/store pairs transformed to int");
85 STATISTIC(SlicedLoads, "Number of load sliced");
86 STATISTIC(NumFPLogicOpsConv, "Number of logic ops converted to fp ops");
87
88 static cl::opt<bool>
89 CombinerGlobalAA("combiner-global-alias-analysis", cl::Hidden,
90 cl::desc("Enable DAG combiner's use of IR alias analysis"));
91
92 static cl::opt<bool>
93 UseTBAA("combiner-use-tbaa", cl::Hidden, cl::init(true),
94 cl::desc("Enable DAG combiner's use of TBAA"));
95
96 #ifndef NDEBUG
97 static cl::opt<std::string>
98 CombinerAAOnlyFunc("combiner-aa-only-func", cl::Hidden,
99 cl::desc("Only use DAG-combiner alias analysis in this"
100 " function"));
101 #endif
102
103 /// Hidden option to stress test load slicing, i.e., when this option
104 /// is enabled, load slicing bypasses most of its profitability guards.
105 static cl::opt<bool>
106 StressLoadSlicing("combiner-stress-load-slicing", cl::Hidden,
107 cl::desc("Bypass the profitability model of load slicing"),
108 cl::init(false));
109
110 static cl::opt<bool>
111 MaySplitLoadIndex("combiner-split-load-index", cl::Hidden, cl::init(true),
112 cl::desc("DAG combiner may split indexing from loads"));
113
114 static cl::opt<bool>
115 EnableStoreMerging("combiner-store-merging", cl::Hidden, cl::init(true),
116 cl::desc("DAG combiner enable merging multiple stores "
117 "into a wider store"));
118
119 static cl::opt<unsigned> TokenFactorInlineLimit(
120 "combiner-tokenfactor-inline-limit", cl::Hidden, cl::init(2048),
121 cl::desc("Limit the number of operands to inline for Token Factors"));
122
123 static cl::opt<unsigned> StoreMergeDependenceLimit(
124 "combiner-store-merge-dependence-limit", cl::Hidden, cl::init(10),
125 cl::desc("Limit the number of times for the same StoreNode and RootNode "
126 "to bail out in store merging dependence check"));
127
128 static cl::opt<bool> EnableReduceLoadOpStoreWidth(
129 "combiner-reduce-load-op-store-width", cl::Hidden, cl::init(true),
130 cl::desc("DAG combiner enable reducing the width of load/op/store "
131 "sequence"));
132
133 static cl::opt<bool> EnableShrinkLoadReplaceStoreWithStore(
134 "combiner-shrink-load-replace-store-with-store", cl::Hidden, cl::init(true),
135 cl::desc("DAG combiner enable load/<replace bytes>/store with "
136 "a narrower store"));
137
138 namespace {
139
140 class DAGCombiner {
141 SelectionDAG &DAG;
142 const TargetLowering &TLI;
143 const SelectionDAGTargetInfo *STI;
144 CombineLevel Level = BeforeLegalizeTypes;
145 CodeGenOpt::Level OptLevel;
146 bool LegalDAG = false;
147 bool LegalOperations = false;
148 bool LegalTypes = false;
149 bool ForCodeSize;
150 bool DisableGenericCombines;
151
152 /// Worklist of all of the nodes that need to be simplified.
153 ///
154 /// This must behave as a stack -- new nodes to process are pushed onto the
155 /// back and when processing we pop off of the back.
156 ///
157 /// The worklist will not contain duplicates but may contain null entries
158 /// due to nodes being deleted from the underlying DAG.
159 SmallVector<SDNode *, 64> Worklist;
160
161 /// Mapping from an SDNode to its position on the worklist.
162 ///
163 /// This is used to find and remove nodes from the worklist (by nulling
164 /// them) when they are deleted from the underlying DAG. It relies on
165 /// stable indices of nodes within the worklist.
166 DenseMap<SDNode *, unsigned> WorklistMap;
167 /// This records all nodes attempted to add to the worklist since we
168 /// considered a new worklist entry. As we keep do not add duplicate nodes
169 /// in the worklist, this is different from the tail of the worklist.
170 SmallSetVector<SDNode *, 32> PruningList;
171
172 /// Set of nodes which have been combined (at least once).
173 ///
174 /// This is used to allow us to reliably add any operands of a DAG node
175 /// which have not yet been combined to the worklist.
176 SmallPtrSet<SDNode *, 32> CombinedNodes;
177
178 /// Map from candidate StoreNode to the pair of RootNode and count.
179 /// The count is used to track how many times we have seen the StoreNode
180 /// with the same RootNode bail out in dependence check. If we have seen
181 /// the bail out for the same pair many times over a limit, we won't
182 /// consider the StoreNode with the same RootNode as store merging
183 /// candidate again.
184 DenseMap<SDNode *, std::pair<SDNode *, unsigned>> StoreRootCountMap;
185
186 // AA - Used for DAG load/store alias analysis.
187 AliasAnalysis *AA;
188
189 /// When an instruction is simplified, add all users of the instruction to
190 /// the work lists because they might get more simplified now.
AddUsersToWorklist(SDNode * N)191 void AddUsersToWorklist(SDNode *N) {
192 for (SDNode *Node : N->uses())
193 AddToWorklist(Node);
194 }
195
196 /// Convenient shorthand to add a node and all of its user to the worklist.
AddToWorklistWithUsers(SDNode * N)197 void AddToWorklistWithUsers(SDNode *N) {
198 AddUsersToWorklist(N);
199 AddToWorklist(N);
200 }
201
202 // Prune potentially dangling nodes. This is called after
203 // any visit to a node, but should also be called during a visit after any
204 // failed combine which may have created a DAG node.
clearAddedDanglingWorklistEntries()205 void clearAddedDanglingWorklistEntries() {
206 // Check any nodes added to the worklist to see if they are prunable.
207 while (!PruningList.empty()) {
208 auto *N = PruningList.pop_back_val();
209 if (N->use_empty())
210 recursivelyDeleteUnusedNodes(N);
211 }
212 }
213
getNextWorklistEntry()214 SDNode *getNextWorklistEntry() {
215 // Before we do any work, remove nodes that are not in use.
216 clearAddedDanglingWorklistEntries();
217 SDNode *N = nullptr;
218 // The Worklist holds the SDNodes in order, but it may contain null
219 // entries.
220 while (!N && !Worklist.empty()) {
221 N = Worklist.pop_back_val();
222 }
223
224 if (N) {
225 bool GoodWorklistEntry = WorklistMap.erase(N);
226 (void)GoodWorklistEntry;
227 assert(GoodWorklistEntry &&
228 "Found a worklist entry without a corresponding map entry!");
229 }
230 return N;
231 }
232
233 /// Call the node-specific routine that folds each particular type of node.
234 SDValue visit(SDNode *N);
235
236 public:
DAGCombiner(SelectionDAG & D,AliasAnalysis * AA,CodeGenOpt::Level OL)237 DAGCombiner(SelectionDAG &D, AliasAnalysis *AA, CodeGenOpt::Level OL)
238 : DAG(D), TLI(D.getTargetLoweringInfo()),
239 STI(D.getSubtarget().getSelectionDAGInfo()), OptLevel(OL), AA(AA) {
240 ForCodeSize = DAG.shouldOptForSize();
241 DisableGenericCombines = STI && STI->disableGenericCombines(OptLevel);
242
243 MaximumLegalStoreInBits = 0;
244 // We use the minimum store size here, since that's all we can guarantee
245 // for the scalable vector types.
246 for (MVT VT : MVT::all_valuetypes())
247 if (EVT(VT).isSimple() && VT != MVT::Other &&
248 TLI.isTypeLegal(EVT(VT)) &&
249 VT.getSizeInBits().getKnownMinSize() >= MaximumLegalStoreInBits)
250 MaximumLegalStoreInBits = VT.getSizeInBits().getKnownMinSize();
251 }
252
ConsiderForPruning(SDNode * N)253 void ConsiderForPruning(SDNode *N) {
254 // Mark this for potential pruning.
255 PruningList.insert(N);
256 }
257
258 /// Add to the worklist making sure its instance is at the back (next to be
259 /// processed.)
AddToWorklist(SDNode * N)260 void AddToWorklist(SDNode *N) {
261 assert(N->getOpcode() != ISD::DELETED_NODE &&
262 "Deleted Node added to Worklist");
263
264 // Skip handle nodes as they can't usefully be combined and confuse the
265 // zero-use deletion strategy.
266 if (N->getOpcode() == ISD::HANDLENODE)
267 return;
268
269 ConsiderForPruning(N);
270
271 if (WorklistMap.insert(std::make_pair(N, Worklist.size())).second)
272 Worklist.push_back(N);
273 }
274
275 /// Remove all instances of N from the worklist.
removeFromWorklist(SDNode * N)276 void removeFromWorklist(SDNode *N) {
277 CombinedNodes.erase(N);
278 PruningList.remove(N);
279 StoreRootCountMap.erase(N);
280
281 auto It = WorklistMap.find(N);
282 if (It == WorklistMap.end())
283 return; // Not in the worklist.
284
285 // Null out the entry rather than erasing it to avoid a linear operation.
286 Worklist[It->second] = nullptr;
287 WorklistMap.erase(It);
288 }
289
290 void deleteAndRecombine(SDNode *N);
291 bool recursivelyDeleteUnusedNodes(SDNode *N);
292
293 /// Replaces all uses of the results of one DAG node with new values.
294 SDValue CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
295 bool AddTo = true);
296
297 /// Replaces all uses of the results of one DAG node with new values.
CombineTo(SDNode * N,SDValue Res,bool AddTo=true)298 SDValue CombineTo(SDNode *N, SDValue Res, bool AddTo = true) {
299 return CombineTo(N, &Res, 1, AddTo);
300 }
301
302 /// Replaces all uses of the results of one DAG node with new values.
CombineTo(SDNode * N,SDValue Res0,SDValue Res1,bool AddTo=true)303 SDValue CombineTo(SDNode *N, SDValue Res0, SDValue Res1,
304 bool AddTo = true) {
305 SDValue To[] = { Res0, Res1 };
306 return CombineTo(N, To, 2, AddTo);
307 }
308
309 void CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO);
310
311 private:
312 unsigned MaximumLegalStoreInBits;
313
314 /// Check the specified integer node value to see if it can be simplified or
315 /// if things it uses can be simplified by bit propagation.
316 /// If so, return true.
SimplifyDemandedBits(SDValue Op)317 bool SimplifyDemandedBits(SDValue Op) {
318 unsigned BitWidth = Op.getScalarValueSizeInBits();
319 APInt DemandedBits = APInt::getAllOnes(BitWidth);
320 return SimplifyDemandedBits(Op, DemandedBits);
321 }
322
SimplifyDemandedBits(SDValue Op,const APInt & DemandedBits)323 bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits) {
324 TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
325 KnownBits Known;
326 if (!TLI.SimplifyDemandedBits(Op, DemandedBits, Known, TLO, 0, false))
327 return false;
328
329 // Revisit the node.
330 AddToWorklist(Op.getNode());
331
332 CommitTargetLoweringOpt(TLO);
333 return true;
334 }
335
336 /// Check the specified vector node value to see if it can be simplified or
337 /// if things it uses can be simplified as it only uses some of the
338 /// elements. If so, return true.
SimplifyDemandedVectorElts(SDValue Op)339 bool SimplifyDemandedVectorElts(SDValue Op) {
340 // TODO: For now just pretend it cannot be simplified.
341 if (Op.getValueType().isScalableVector())
342 return false;
343
344 unsigned NumElts = Op.getValueType().getVectorNumElements();
345 APInt DemandedElts = APInt::getAllOnes(NumElts);
346 return SimplifyDemandedVectorElts(Op, DemandedElts);
347 }
348
349 bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
350 const APInt &DemandedElts,
351 bool AssumeSingleUse = false);
352 bool SimplifyDemandedVectorElts(SDValue Op, const APInt &DemandedElts,
353 bool AssumeSingleUse = false);
354
355 bool CombineToPreIndexedLoadStore(SDNode *N);
356 bool CombineToPostIndexedLoadStore(SDNode *N);
357 SDValue SplitIndexingFromLoad(LoadSDNode *LD);
358 bool SliceUpLoad(SDNode *N);
359
360 // Scalars have size 0 to distinguish from singleton vectors.
361 SDValue ForwardStoreValueToDirectLoad(LoadSDNode *LD);
362 bool getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val);
363 bool extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val);
364
365 /// Replace an ISD::EXTRACT_VECTOR_ELT of a load with a narrowed
366 /// load.
367 ///
368 /// \param EVE ISD::EXTRACT_VECTOR_ELT to be replaced.
369 /// \param InVecVT type of the input vector to EVE with bitcasts resolved.
370 /// \param EltNo index of the vector element to load.
371 /// \param OriginalLoad load that EVE came from to be replaced.
372 /// \returns EVE on success SDValue() on failure.
373 SDValue scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
374 SDValue EltNo,
375 LoadSDNode *OriginalLoad);
376 void ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad);
377 SDValue PromoteOperand(SDValue Op, EVT PVT, bool &Replace);
378 SDValue SExtPromoteOperand(SDValue Op, EVT PVT);
379 SDValue ZExtPromoteOperand(SDValue Op, EVT PVT);
380 SDValue PromoteIntBinOp(SDValue Op);
381 SDValue PromoteIntShiftOp(SDValue Op);
382 SDValue PromoteExtend(SDValue Op);
383 bool PromoteLoad(SDValue Op);
384
385 /// Call the node-specific routine that knows how to fold each
386 /// particular type of node. If that doesn't do anything, try the
387 /// target-specific DAG combines.
388 SDValue combine(SDNode *N);
389
390 // Visitation implementation - Implement dag node combining for different
391 // node types. The semantics are as follows:
392 // Return Value:
393 // SDValue.getNode() == 0 - No change was made
394 // SDValue.getNode() == N - N was replaced, is dead and has been handled.
395 // otherwise - N should be replaced by the returned Operand.
396 //
397 SDValue visitTokenFactor(SDNode *N);
398 SDValue visitMERGE_VALUES(SDNode *N);
399 SDValue visitADD(SDNode *N);
400 SDValue visitADDLike(SDNode *N);
401 SDValue visitADDLikeCommutative(SDValue N0, SDValue N1, SDNode *LocReference);
402 SDValue visitSUB(SDNode *N);
403 SDValue visitADDSAT(SDNode *N);
404 SDValue visitSUBSAT(SDNode *N);
405 SDValue visitADDC(SDNode *N);
406 SDValue visitADDO(SDNode *N);
407 SDValue visitUADDOLike(SDValue N0, SDValue N1, SDNode *N);
408 SDValue visitSUBC(SDNode *N);
409 SDValue visitSUBO(SDNode *N);
410 SDValue visitADDE(SDNode *N);
411 SDValue visitADDCARRY(SDNode *N);
412 SDValue visitSADDO_CARRY(SDNode *N);
413 SDValue visitADDCARRYLike(SDValue N0, SDValue N1, SDValue CarryIn, SDNode *N);
414 SDValue visitSUBE(SDNode *N);
415 SDValue visitSUBCARRY(SDNode *N);
416 SDValue visitSSUBO_CARRY(SDNode *N);
417 SDValue visitMUL(SDNode *N);
418 SDValue visitMULFIX(SDNode *N);
419 SDValue useDivRem(SDNode *N);
420 SDValue visitSDIV(SDNode *N);
421 SDValue visitSDIVLike(SDValue N0, SDValue N1, SDNode *N);
422 SDValue visitUDIV(SDNode *N);
423 SDValue visitUDIVLike(SDValue N0, SDValue N1, SDNode *N);
424 SDValue visitREM(SDNode *N);
425 SDValue visitMULHU(SDNode *N);
426 SDValue visitMULHS(SDNode *N);
427 SDValue visitAVG(SDNode *N);
428 SDValue visitSMUL_LOHI(SDNode *N);
429 SDValue visitUMUL_LOHI(SDNode *N);
430 SDValue visitMULO(SDNode *N);
431 SDValue visitIMINMAX(SDNode *N);
432 SDValue visitAND(SDNode *N);
433 SDValue visitANDLike(SDValue N0, SDValue N1, SDNode *N);
434 SDValue visitOR(SDNode *N);
435 SDValue visitORLike(SDValue N0, SDValue N1, SDNode *N);
436 SDValue visitXOR(SDNode *N);
437 SDValue SimplifyVBinOp(SDNode *N, const SDLoc &DL);
438 SDValue visitSHL(SDNode *N);
439 SDValue visitSRA(SDNode *N);
440 SDValue visitSRL(SDNode *N);
441 SDValue visitFunnelShift(SDNode *N);
442 SDValue visitSHLSAT(SDNode *N);
443 SDValue visitRotate(SDNode *N);
444 SDValue visitABS(SDNode *N);
445 SDValue visitBSWAP(SDNode *N);
446 SDValue visitBITREVERSE(SDNode *N);
447 SDValue visitCTLZ(SDNode *N);
448 SDValue visitCTLZ_ZERO_UNDEF(SDNode *N);
449 SDValue visitCTTZ(SDNode *N);
450 SDValue visitCTTZ_ZERO_UNDEF(SDNode *N);
451 SDValue visitCTPOP(SDNode *N);
452 SDValue visitSELECT(SDNode *N);
453 SDValue visitVSELECT(SDNode *N);
454 SDValue visitSELECT_CC(SDNode *N);
455 SDValue visitSETCC(SDNode *N);
456 SDValue visitSETCCCARRY(SDNode *N);
457 SDValue visitSIGN_EXTEND(SDNode *N);
458 SDValue visitZERO_EXTEND(SDNode *N);
459 SDValue visitANY_EXTEND(SDNode *N);
460 SDValue visitAssertExt(SDNode *N);
461 SDValue visitAssertAlign(SDNode *N);
462 SDValue visitSIGN_EXTEND_INREG(SDNode *N);
463 SDValue visitEXTEND_VECTOR_INREG(SDNode *N);
464 SDValue visitTRUNCATE(SDNode *N);
465 SDValue visitBITCAST(SDNode *N);
466 SDValue visitFREEZE(SDNode *N);
467 SDValue visitBUILD_PAIR(SDNode *N);
468 SDValue visitFADD(SDNode *N);
469 SDValue visitSTRICT_FADD(SDNode *N);
470 SDValue visitFSUB(SDNode *N);
471 SDValue visitFMUL(SDNode *N);
472 SDValue visitFMA(SDNode *N);
473 SDValue visitFDIV(SDNode *N);
474 SDValue visitFREM(SDNode *N);
475 SDValue visitFSQRT(SDNode *N);
476 SDValue visitFCOPYSIGN(SDNode *N);
477 SDValue visitFPOW(SDNode *N);
478 SDValue visitSINT_TO_FP(SDNode *N);
479 SDValue visitUINT_TO_FP(SDNode *N);
480 SDValue visitFP_TO_SINT(SDNode *N);
481 SDValue visitFP_TO_UINT(SDNode *N);
482 SDValue visitFP_ROUND(SDNode *N);
483 SDValue visitFP_EXTEND(SDNode *N);
484 SDValue visitFNEG(SDNode *N);
485 SDValue visitFABS(SDNode *N);
486 SDValue visitFCEIL(SDNode *N);
487 SDValue visitFTRUNC(SDNode *N);
488 SDValue visitFFLOOR(SDNode *N);
489 SDValue visitFMinMax(SDNode *N);
490 SDValue visitBRCOND(SDNode *N);
491 SDValue visitBR_CC(SDNode *N);
492 SDValue visitLOAD(SDNode *N);
493
494 SDValue replaceStoreChain(StoreSDNode *ST, SDValue BetterChain);
495 SDValue replaceStoreOfFPConstant(StoreSDNode *ST);
496
497 SDValue visitSTORE(SDNode *N);
498 SDValue visitLIFETIME_END(SDNode *N);
499 SDValue visitINSERT_VECTOR_ELT(SDNode *N);
500 SDValue visitEXTRACT_VECTOR_ELT(SDNode *N);
501 SDValue visitBUILD_VECTOR(SDNode *N);
502 SDValue visitCONCAT_VECTORS(SDNode *N);
503 SDValue visitEXTRACT_SUBVECTOR(SDNode *N);
504 SDValue visitVECTOR_SHUFFLE(SDNode *N);
505 SDValue visitSCALAR_TO_VECTOR(SDNode *N);
506 SDValue visitINSERT_SUBVECTOR(SDNode *N);
507 SDValue visitMLOAD(SDNode *N);
508 SDValue visitMSTORE(SDNode *N);
509 SDValue visitMGATHER(SDNode *N);
510 SDValue visitMSCATTER(SDNode *N);
511 SDValue visitFP_TO_FP16(SDNode *N);
512 SDValue visitFP16_TO_FP(SDNode *N);
513 SDValue visitFP_TO_BF16(SDNode *N);
514 SDValue visitVECREDUCE(SDNode *N);
515 SDValue visitVPOp(SDNode *N);
516
517 SDValue visitFADDForFMACombine(SDNode *N);
518 SDValue visitFSUBForFMACombine(SDNode *N);
519 SDValue visitFMULForFMADistributiveCombine(SDNode *N);
520
521 SDValue XformToShuffleWithZero(SDNode *N);
522 bool reassociationCanBreakAddressingModePattern(unsigned Opc,
523 const SDLoc &DL,
524 SDNode *N,
525 SDValue N0,
526 SDValue N1);
527 SDValue reassociateOpsCommutative(unsigned Opc, const SDLoc &DL, SDValue N0,
528 SDValue N1);
529 SDValue reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
530 SDValue N1, SDNodeFlags Flags);
531
532 SDValue visitShiftByConstant(SDNode *N);
533
534 SDValue foldSelectOfConstants(SDNode *N);
535 SDValue foldVSelectOfConstants(SDNode *N);
536 SDValue foldBinOpIntoSelect(SDNode *BO);
537 bool SimplifySelectOps(SDNode *SELECT, SDValue LHS, SDValue RHS);
538 SDValue hoistLogicOpWithSameOpcodeHands(SDNode *N);
539 SDValue SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2);
540 SDValue SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
541 SDValue N2, SDValue N3, ISD::CondCode CC,
542 bool NotExtCompare = false);
543 SDValue convertSelectOfFPConstantsToLoadOffset(
544 const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
545 ISD::CondCode CC);
546 SDValue foldSignChangeInBitcast(SDNode *N);
547 SDValue foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0, SDValue N1,
548 SDValue N2, SDValue N3, ISD::CondCode CC);
549 SDValue foldSelectOfBinops(SDNode *N);
550 SDValue foldSextSetcc(SDNode *N);
551 SDValue foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
552 const SDLoc &DL);
553 SDValue foldSubToUSubSat(EVT DstVT, SDNode *N);
554 SDValue unfoldMaskedMerge(SDNode *N);
555 SDValue unfoldExtremeBitClearingToShifts(SDNode *N);
556 SDValue SimplifySetCC(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
557 const SDLoc &DL, bool foldBooleans);
558 SDValue rebuildSetCC(SDValue N);
559
560 bool isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
561 SDValue &CC, bool MatchStrict = false) const;
562 bool isOneUseSetCC(SDValue N) const;
563
564 SDValue SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
565 unsigned HiOp);
566 SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
567 SDValue CombineExtLoad(SDNode *N);
568 SDValue CombineZExtLogicopShiftLoad(SDNode *N);
569 SDValue combineRepeatedFPDivisors(SDNode *N);
570 SDValue combineInsertEltToShuffle(SDNode *N, unsigned InsIndex);
571 SDValue ConstantFoldBITCASTofBUILD_VECTOR(SDNode *, EVT);
572 SDValue BuildSDIV(SDNode *N);
573 SDValue BuildSDIVPow2(SDNode *N);
574 SDValue BuildUDIV(SDNode *N);
575 SDValue BuildSREMPow2(SDNode *N);
576 SDValue buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N);
577 SDValue BuildLogBase2(SDValue V, const SDLoc &DL);
578 SDValue BuildDivEstimate(SDValue N, SDValue Op, SDNodeFlags Flags);
579 SDValue buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags);
580 SDValue buildSqrtEstimate(SDValue Op, SDNodeFlags Flags);
581 SDValue buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags, bool Recip);
582 SDValue buildSqrtNROneConst(SDValue Arg, SDValue Est, unsigned Iterations,
583 SDNodeFlags Flags, bool Reciprocal);
584 SDValue buildSqrtNRTwoConst(SDValue Arg, SDValue Est, unsigned Iterations,
585 SDNodeFlags Flags, bool Reciprocal);
586 SDValue MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
587 bool DemandHighBits = true);
588 SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
589 SDValue MatchRotatePosNeg(SDValue Shifted, SDValue Pos, SDValue Neg,
590 SDValue InnerPos, SDValue InnerNeg, bool HasPos,
591 unsigned PosOpcode, unsigned NegOpcode,
592 const SDLoc &DL);
593 SDValue MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos, SDValue Neg,
594 SDValue InnerPos, SDValue InnerNeg, bool HasPos,
595 unsigned PosOpcode, unsigned NegOpcode,
596 const SDLoc &DL);
597 SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
598 SDValue MatchLoadCombine(SDNode *N);
599 SDValue mergeTruncStores(StoreSDNode *N);
600 SDValue reduceLoadWidth(SDNode *N);
601 SDValue ReduceLoadOpStoreWidth(SDNode *N);
602 SDValue splitMergedValStore(StoreSDNode *ST);
603 SDValue TransformFPLoadStorePair(SDNode *N);
604 SDValue convertBuildVecZextToZext(SDNode *N);
605 SDValue reduceBuildVecExtToExtBuildVec(SDNode *N);
606 SDValue reduceBuildVecTruncToBitCast(SDNode *N);
607 SDValue reduceBuildVecToShuffle(SDNode *N);
608 SDValue createBuildVecShuffle(const SDLoc &DL, SDNode *N,
609 ArrayRef<int> VectorMask, SDValue VecIn1,
610 SDValue VecIn2, unsigned LeftIdx,
611 bool DidSplitVec);
612 SDValue matchVSelectOpSizesWithSetCC(SDNode *Cast);
613
614 /// Walk up chain skipping non-aliasing memory nodes,
615 /// looking for aliasing nodes and adding them to the Aliases vector.
616 void GatherAllAliases(SDNode *N, SDValue OriginalChain,
617 SmallVectorImpl<SDValue> &Aliases);
618
619 /// Return true if there is any possibility that the two addresses overlap.
620 bool mayAlias(SDNode *Op0, SDNode *Op1) const;
621
622 /// Walk up chain skipping non-aliasing memory nodes, looking for a better
623 /// chain (aliasing node.)
624 SDValue FindBetterChain(SDNode *N, SDValue Chain);
625
626 /// Try to replace a store and any possibly adjacent stores on
627 /// consecutive chains with better chains. Return true only if St is
628 /// replaced.
629 ///
630 /// Notice that other chains may still be replaced even if the function
631 /// returns false.
632 bool findBetterNeighborChains(StoreSDNode *St);
633
634 // Helper for findBetterNeighborChains. Walk up store chain add additional
635 // chained stores that do not overlap and can be parallelized.
636 bool parallelizeChainedStores(StoreSDNode *St);
637
638 /// Holds a pointer to an LSBaseSDNode as well as information on where it
639 /// is located in a sequence of memory operations connected by a chain.
640 struct MemOpLink {
641 // Ptr to the mem node.
642 LSBaseSDNode *MemNode;
643
644 // Offset from the base ptr.
645 int64_t OffsetFromBase;
646
MemOpLink__anon54f00e400111::DAGCombiner::MemOpLink647 MemOpLink(LSBaseSDNode *N, int64_t Offset)
648 : MemNode(N), OffsetFromBase(Offset) {}
649 };
650
651 // Classify the origin of a stored value.
652 enum class StoreSource { Unknown, Constant, Extract, Load };
getStoreSource(SDValue StoreVal)653 StoreSource getStoreSource(SDValue StoreVal) {
654 switch (StoreVal.getOpcode()) {
655 case ISD::Constant:
656 case ISD::ConstantFP:
657 return StoreSource::Constant;
658 case ISD::EXTRACT_VECTOR_ELT:
659 case ISD::EXTRACT_SUBVECTOR:
660 return StoreSource::Extract;
661 case ISD::LOAD:
662 return StoreSource::Load;
663 default:
664 return StoreSource::Unknown;
665 }
666 }
667
668 /// This is a helper function for visitMUL to check the profitability
669 /// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
670 /// MulNode is the original multiply, AddNode is (add x, c1),
671 /// and ConstNode is c2.
672 bool isMulAddWithConstProfitable(SDNode *MulNode, SDValue AddNode,
673 SDValue ConstNode);
674
675 /// This is a helper function for visitAND and visitZERO_EXTEND. Returns
676 /// true if the (and (load x) c) pattern matches an extload. ExtVT returns
677 /// the type of the loaded value to be extended.
678 bool isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
679 EVT LoadResultTy, EVT &ExtVT);
680
681 /// Helper function to calculate whether the given Load/Store can have its
682 /// width reduced to ExtVT.
683 bool isLegalNarrowLdSt(LSBaseSDNode *LDSTN, ISD::LoadExtType ExtType,
684 EVT &MemVT, unsigned ShAmt = 0);
685
686 /// Used by BackwardsPropagateMask to find suitable loads.
687 bool SearchForAndLoads(SDNode *N, SmallVectorImpl<LoadSDNode*> &Loads,
688 SmallPtrSetImpl<SDNode*> &NodesWithConsts,
689 ConstantSDNode *Mask, SDNode *&NodeToMask);
690 /// Attempt to propagate a given AND node back to load leaves so that they
691 /// can be combined into narrow loads.
692 bool BackwardsPropagateMask(SDNode *N);
693
694 /// Helper function for mergeConsecutiveStores which merges the component
695 /// store chains.
696 SDValue getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
697 unsigned NumStores);
698
699 /// This is a helper function for mergeConsecutiveStores. When the source
700 /// elements of the consecutive stores are all constants or all extracted
701 /// vector elements, try to merge them into one larger store introducing
702 /// bitcasts if necessary. \return True if a merged store was created.
703 bool mergeStoresOfConstantsOrVecElts(SmallVectorImpl<MemOpLink> &StoreNodes,
704 EVT MemVT, unsigned NumStores,
705 bool IsConstantSrc, bool UseVector,
706 bool UseTrunc);
707
708 /// This is a helper function for mergeConsecutiveStores. Stores that
709 /// potentially may be merged with St are placed in StoreNodes. RootNode is
710 /// a chain predecessor to all store candidates.
711 void getStoreMergeCandidates(StoreSDNode *St,
712 SmallVectorImpl<MemOpLink> &StoreNodes,
713 SDNode *&Root);
714
715 /// Helper function for mergeConsecutiveStores. Checks if candidate stores
716 /// have indirect dependency through their operands. RootNode is the
717 /// predecessor to all stores calculated by getStoreMergeCandidates and is
718 /// used to prune the dependency check. \return True if safe to merge.
/// Check whether merging these candidate stores is safe with respect to
/// other dependencies in the DAG rooted at \p RootNode (presumably guarding
/// against cycles through the merged node — confirm in the definition).
bool checkMergeStoreCandidatesForDependencies(
    SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
    SDNode *RootNode);

/// This is a helper function for mergeConsecutiveStores. Given a list of
/// store candidates, find the first N that are consecutive in memory.
/// Returns 0 if there are not at least 2 consecutive stores to try merging.
unsigned getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
                              int64_t ElementSizeBytes) const;

/// This is a helper function for mergeConsecutiveStores. It is used for
/// store chains that are composed entirely of constant values.
bool tryStoreMergeOfConstants(SmallVectorImpl<MemOpLink> &StoreNodes,
                              unsigned NumConsecutiveStores,
                              EVT MemVT, SDNode *Root, bool AllowVectors);

/// This is a helper function for mergeConsecutiveStores. It is used for
/// store chains that are composed entirely of extracted vector elements.
/// When extracting multiple vector elements, try to store them in one
/// vector store rather than a sequence of scalar stores.
bool tryStoreMergeOfExtracts(SmallVectorImpl<MemOpLink> &StoreNodes,
                             unsigned NumConsecutiveStores, EVT MemVT,
                             SDNode *Root);

/// This is a helper function for mergeConsecutiveStores. It is used for
/// store chains that are composed entirely of loaded values.
bool tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
                          unsigned NumConsecutiveStores, EVT MemVT,
                          SDNode *Root, bool AllowVectors,
                          bool IsNonTemporalStore, bool IsNonTemporalLoad);

/// Merge consecutive store operations into a wide store.
/// This optimization uses wide integers or vectors when possible.
/// \return true if stores were merged.
bool mergeConsecutiveStores(StoreSDNode *St);

/// Try to transform a truncation where C is a constant:
///   (trunc (and X, C)) -> (and (trunc X), (trunc C))
///
/// \p N needs to be a truncation and its first operand an AND. Other
/// requirements are checked by the function (e.g. that trunc is
/// single-use) and if missed an empty SDValue is returned.
SDValue distributeTruncateThroughAnd(SDNode *N);
762
/// Helper function to determine whether the target supports the operation
/// given by \p Opcode for type \p VT, that is, whether the operation
/// is legal or custom before legalizing operations, and whether it is
/// legal (but not custom) after legalization.
bool hasOperation(unsigned Opcode, EVT VT) {
  return TLI.isOperationLegalOrCustom(Opcode, VT, LegalOperations);
}
770
public:
  /// Runs the dag combiner on all nodes in the work list
  void Run(CombineLevel AtLevel);

  /// Accessor for the SelectionDAG this combiner operates on.
  SelectionDAG &getDAG() const { return DAG; }

  /// Returns a type large enough to hold any valid shift amount - before type
  /// legalization these can be huge.
  EVT getShiftAmountTy(EVT LHSTy) {
    assert(LHSTy.isInteger() && "Shift amount is not an integer type!");
    return TLI.getShiftAmountTy(LHSTy, DAG.getDataLayout(), LegalTypes);
  }

  /// This method returns true if we are running before type legalization or
  /// if the specified VT is legal.
  bool isTypeLegal(const EVT &VT) {
    if (!LegalTypes) return true;
    return TLI.isTypeLegal(VT);
  }

  /// Convenience wrapper around TargetLowering::getSetCCResultType
  EVT getSetCCResultType(EVT VT) const {
    return TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
  }

  /// Update the setcc users in \p SetCCs when \p OrigLoad is replaced by the
  /// extending load \p ExtLoad (extension kind \p ExtType) — see the
  /// definition for the exact rewrite performed.
  void ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
                       SDValue OrigLoad, SDValue ExtLoad,
                       ISD::NodeType ExtType);
799 };
800
801 /// This class is a DAGUpdateListener that removes any deleted
802 /// nodes from the worklist.
803 class WorklistRemover : public SelectionDAG::DAGUpdateListener {
804 DAGCombiner &DC;
805
806 public:
WorklistRemover(DAGCombiner & dc)807 explicit WorklistRemover(DAGCombiner &dc)
808 : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}
809
NodeDeleted(SDNode * N,SDNode * E)810 void NodeDeleted(SDNode *N, SDNode *E) override {
811 DC.removeFromWorklist(N);
812 }
813 };
814
815 class WorklistInserter : public SelectionDAG::DAGUpdateListener {
816 DAGCombiner &DC;
817
818 public:
WorklistInserter(DAGCombiner & dc)819 explicit WorklistInserter(DAGCombiner &dc)
820 : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}
821
822 // FIXME: Ideally we could add N to the worklist, but this causes exponential
823 // compile time costs in large DAGs, e.g. Halide.
NodeInserted(SDNode * N)824 void NodeInserted(SDNode *N) override { DC.ConsiderForPruning(N); }
825 };
826
827 } // end anonymous namespace
828
829 //===----------------------------------------------------------------------===//
830 // TargetLowering::DAGCombinerInfo implementation
831 //===----------------------------------------------------------------------===//
832
AddToWorklist(SDNode * N)833 void TargetLowering::DAGCombinerInfo::AddToWorklist(SDNode *N) {
834 ((DAGCombiner*)DC)->AddToWorklist(N);
835 }
836
837 SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode * N,ArrayRef<SDValue> To,bool AddTo)838 CombineTo(SDNode *N, ArrayRef<SDValue> To, bool AddTo) {
839 return ((DAGCombiner*)DC)->CombineTo(N, &To[0], To.size(), AddTo);
840 }
841
842 SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode * N,SDValue Res,bool AddTo)843 CombineTo(SDNode *N, SDValue Res, bool AddTo) {
844 return ((DAGCombiner*)DC)->CombineTo(N, Res, AddTo);
845 }
846
847 SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode * N,SDValue Res0,SDValue Res1,bool AddTo)848 CombineTo(SDNode *N, SDValue Res0, SDValue Res1, bool AddTo) {
849 return ((DAGCombiner*)DC)->CombineTo(N, Res0, Res1, AddTo);
850 }
851
852 bool TargetLowering::DAGCombinerInfo::
recursivelyDeleteUnusedNodes(SDNode * N)853 recursivelyDeleteUnusedNodes(SDNode *N) {
854 return ((DAGCombiner*)DC)->recursivelyDeleteUnusedNodes(N);
855 }
856
857 void TargetLowering::DAGCombinerInfo::
CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt & TLO)858 CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
859 return ((DAGCombiner*)DC)->CommitTargetLoweringOpt(TLO);
860 }
861
862 //===----------------------------------------------------------------------===//
863 // Helper Functions
864 //===----------------------------------------------------------------------===//
865
/// Remove \p N from the worklist and delete it from the DAG, after queueing
/// any of its operands that may have become dead (or newly simplifiable) for
/// another combine visit.
void DAGCombiner::deleteAndRecombine(SDNode *N) {
  removeFromWorklist(N);

  // If the operands of this node are only used by the node, they will now be
  // dead. Make sure to re-visit them and recursively delete dead nodes.
  for (const SDValue &Op : N->ops())
    // For an operand generating multiple values, one of the values may
    // become dead allowing further simplification (e.g. split index
    // arithmetic from an indexed load).
    if (Op->hasOneUse() || Op->getNumValues() > 1)
      AddToWorklist(Op.getNode());

  DAG.DeleteNode(N);
}
880
881 // APInts must be the same size for most operations, this helper
882 // function zero extends the shorter of the pair so that they match.
883 // We provide an Offset so that we can create bitwidths that won't overflow.
zeroExtendToMatch(APInt & LHS,APInt & RHS,unsigned Offset=0)884 static void zeroExtendToMatch(APInt &LHS, APInt &RHS, unsigned Offset = 0) {
885 unsigned Bits = Offset + std::max(LHS.getBitWidth(), RHS.getBitWidth());
886 LHS = LHS.zext(Bits);
887 RHS = RHS.zext(Bits);
888 }
889
// Return true if this node is a setcc, or is a select_cc
// that selects between the target values used for true and false, making it
// equivalent to a setcc. Also, set the incoming LHS, RHS, and CC references to
// the appropriate nodes based on the type of node we are checking. This
// simplifies life a bit for the callers.
bool DAGCombiner::isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
                                    SDValue &CC, bool MatchStrict) const {
  // Plain SETCC: operands are (lhs, rhs, condcode).
  if (N.getOpcode() == ISD::SETCC) {
    LHS = N.getOperand(0);
    RHS = N.getOperand(1);
    CC = N.getOperand(2);
    return true;
  }

  // Strict FP variants carry a chain in operand 0, so the comparison
  // operands are shifted by one.
  if (MatchStrict &&
      (N.getOpcode() == ISD::STRICT_FSETCC ||
       N.getOpcode() == ISD::STRICT_FSETCCS)) {
    LHS = N.getOperand(1);
    RHS = N.getOperand(2);
    CC = N.getOperand(3);
    return true;
  }

  // Otherwise, only a select_cc that produces the canonical true/false
  // values is setcc-equivalent.
  if (N.getOpcode() != ISD::SELECT_CC || !TLI.isConstTrueVal(N.getOperand(2)) ||
      !TLI.isConstFalseVal(N.getOperand(3)))
    return false;

  // With undefined boolean contents we cannot rely on the true/false values.
  if (TLI.getBooleanContents(N.getValueType()) ==
      TargetLowering::UndefinedBooleanContent)
    return false;

  // SELECT_CC operands are (lhs, rhs, trueval, falseval, condcode).
  LHS = N.getOperand(0);
  RHS = N.getOperand(1);
  CC = N.getOperand(4);
  return true;
}
926
927 /// Return true if this is a SetCC-equivalent operation with only one use.
928 /// If this is true, it allows the users to invert the operation for free when
929 /// it is profitable to do so.
isOneUseSetCC(SDValue N) const930 bool DAGCombiner::isOneUseSetCC(SDValue N) const {
931 SDValue N0, N1, N2;
932 if (isSetCCEquivalent(N, N0, N1, N2) && N->hasOneUse())
933 return true;
934 return false;
935 }
936
isConstantSplatVectorMaskForType(SDNode * N,EVT ScalarTy)937 static bool isConstantSplatVectorMaskForType(SDNode *N, EVT ScalarTy) {
938 if (!ScalarTy.isSimple())
939 return false;
940
941 uint64_t MaskForTy = 0ULL;
942 switch (ScalarTy.getSimpleVT().SimpleTy) {
943 case MVT::i8:
944 MaskForTy = 0xFFULL;
945 break;
946 case MVT::i16:
947 MaskForTy = 0xFFFFULL;
948 break;
949 case MVT::i32:
950 MaskForTy = 0xFFFFFFFFULL;
951 break;
952 default:
953 return false;
954 break;
955 }
956
957 APInt Val;
958 if (ISD::isConstantSplatVector(N, Val))
959 return Val.getLimitedValue() == MaskForTy;
960
961 return false;
962 }
963
964 // Determines if it is a constant integer or a splat/build vector of constant
965 // integers (and undefs).
966 // Do not permit build vector implicit truncation.
isConstantOrConstantVector(SDValue N,bool NoOpaques=false)967 static bool isConstantOrConstantVector(SDValue N, bool NoOpaques = false) {
968 if (ConstantSDNode *Const = dyn_cast<ConstantSDNode>(N))
969 return !(Const->isOpaque() && NoOpaques);
970 if (N.getOpcode() != ISD::BUILD_VECTOR && N.getOpcode() != ISD::SPLAT_VECTOR)
971 return false;
972 unsigned BitWidth = N.getScalarValueSizeInBits();
973 for (const SDValue &Op : N->op_values()) {
974 if (Op.isUndef())
975 continue;
976 ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Op);
977 if (!Const || Const->getAPIntValue().getBitWidth() != BitWidth ||
978 (Const->isOpaque() && NoOpaques))
979 return false;
980 }
981 return true;
982 }
983
984 // Determines if a BUILD_VECTOR is composed of all-constants possibly mixed with
985 // undef's.
isAnyConstantBuildVector(SDValue V,bool NoOpaques=false)986 static bool isAnyConstantBuildVector(SDValue V, bool NoOpaques = false) {
987 if (V.getOpcode() != ISD::BUILD_VECTOR)
988 return false;
989 return isConstantOrConstantVector(V, NoOpaques) ||
990 ISD::isBuildVectorOfConstantFPSDNodes(V.getNode());
991 }
992
993 // Determine if this an indexed load with an opaque target constant index.
canSplitIdx(LoadSDNode * LD)994 static bool canSplitIdx(LoadSDNode *LD) {
995 return MaySplitLoadIndex &&
996 (LD->getOperand(2).getOpcode() != ISD::TargetConstant ||
997 !cast<ConstantSDNode>(LD->getOperand(2))->isOpaque());
998 }
999
/// Return true if reassociating (Opc N0, N1) would break an addressing-mode
/// pattern that a memory user of \p N currently relies on.
bool DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc,
                                                             const SDLoc &DL,
                                                             SDNode *N,
                                                             SDValue N0,
                                                             SDValue N1) {
  // Currently this only tries to ensure we don't undo the GEP splits done by
  // CodeGenPrepare when shouldConsiderGEPOffsetSplit is true. To ensure this,
  // we check if the following transformation would be problematic:
  // (load/store (add, (add, x, offset1), offset2)) ->
  // (load/store (add, x, offset1+offset2)).

  // (load/store (add, (add, x, y), offset2)) ->
  // (load/store (add, (add, x, offset2), y)).

  // Only (add (add ...), constant) shapes are of interest.
  if (Opc != ISD::ADD || N0.getOpcode() != ISD::ADD)
    return false;

  auto *C2 = dyn_cast<ConstantSDNode>(N1);
  if (!C2)
    return false;

  // Offsets wider than 64 bits cannot be fed to the addressing-mode query.
  const APInt &C2APIntVal = C2->getAPIntValue();
  if (C2APIntVal.getSignificantBits() > 64)
    return false;

  if (auto *C1 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
    // If the inner add has a single use, folding the constants cannot
    // duplicate any computation, so the reassociation is harmless.
    if (N0.hasOneUse())
      return false;

    const APInt &C1APIntVal = C1->getAPIntValue();
    const APInt CombinedValueIntVal = C1APIntVal + C2APIntVal;
    if (CombinedValueIntVal.getSignificantBits() > 64)
      return false;
    const int64_t CombinedValue = CombinedValueIntVal.getSExtValue();

    for (SDNode *Node : N->uses()) {
      if (auto *LoadStore = dyn_cast<MemSDNode>(Node)) {
        // Is x[offset2] already not a legal addressing mode? If so then
        // reassociating the constants breaks nothing (we test offset2 because
        // that's the one we hope to fold into the load or store).
        TargetLoweringBase::AddrMode AM;
        AM.HasBaseReg = true;
        AM.BaseOffs = C2APIntVal.getSExtValue();
        EVT VT = LoadStore->getMemoryVT();
        unsigned AS = LoadStore->getAddressSpace();
        Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
        if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
          continue;

        // Would x[offset1+offset2] still be a legal addressing mode?
        AM.BaseOffs = CombinedValue;
        if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
          return true;
      }
    }
  } else {
    // Non-constant inner operand: the candidate rewrite is
    // (add (add x, y), offset2) -> (add (add x, offset2), y).
    if (auto *GA = dyn_cast<GlobalAddressSDNode>(N0.getOperand(1)))
      if (GA->getOpcode() == ISD::GlobalAddress && TLI.isOffsetFoldingLegal(GA))
        return false;

    for (SDNode *Node : N->uses()) {
      auto *LoadStore = dyn_cast<MemSDNode>(Node);
      if (!LoadStore)
        return false;

      // Is x[offset2] a legal addressing mode? If so then
      // reassociating the constants breaks address pattern
      TargetLoweringBase::AddrMode AM;
      AM.HasBaseReg = true;
      AM.BaseOffs = C2APIntVal.getSExtValue();
      EVT VT = LoadStore->getMemoryVT();
      unsigned AS = LoadStore->getAddressSpace();
      Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
      if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
        return false;
    }
    // All memory users can fold offset2 directly, so reassociating would
    // break their addressing patterns.
    return true;
  }

  return false;
}
1081
// Helper for DAGCombiner::reassociateOps. Try to reassociate an expression
// such as (Opc N0, N1), if \p N0 is the same kind of operation as \p Opc.
SDValue DAGCombiner::reassociateOpsCommutative(unsigned Opc, const SDLoc &DL,
                                               SDValue N0, SDValue N1) {
  EVT VT = N0.getValueType();

  if (N0.getOpcode() != Opc)
    return SDValue();

  SDValue N00 = N0.getOperand(0);
  SDValue N01 = N0.getOperand(1);

  if (DAG.isConstantIntBuildVectorOrConstantInt(peekThroughBitcasts(N01))) {
    if (DAG.isConstantIntBuildVectorOrConstantInt(peekThroughBitcasts(N1))) {
      // Reassociate: (op (op x, c1), c2) -> (op x, (op c1, c2))
      if (SDValue OpNode = DAG.FoldConstantArithmetic(Opc, DL, VT, {N01, N1}))
        return DAG.getNode(Opc, DL, VT, N00, OpNode);
      // Constant folding failed; do not reassociate without it.
      return SDValue();
    }
    if (TLI.isReassocProfitable(DAG, N0, N1)) {
      // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1)
      // iff (op x, c1) has one use
      SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N00, N1);
      return DAG.getNode(Opc, DL, VT, OpNode, N01);
    }
  }

  // Check for repeated operand logic simplifications.
  if (Opc == ISD::AND || Opc == ISD::OR) {
    // AND and OR are idempotent, so re-combining with either inner operand
    // is a no-op:
    // (N00 & N01) & N00 --> N00 & N01
    // (N00 & N01) & N01 --> N00 & N01
    // (N00 | N01) | N00 --> N00 | N01
    // (N00 | N01) | N01 --> N00 | N01
    if (N1 == N00 || N1 == N01)
      return N0;
  }
  if (Opc == ISD::XOR) {
    // XOR with a repeated operand cancels it:
    // (N00 ^ N01) ^ N00 --> N01
    if (N1 == N00)
      return N01;
    // (N00 ^ N01) ^ N01 --> N00
    if (N1 == N01)
      return N00;
  }

  if (TLI.isReassocProfitable(DAG, N0, N1)) {
    if (N1 != N01) {
      // Reassociate if (op N00, N1) already exists
      if (SDNode *NE = DAG.getNodeIfExists(Opc, DAG.getVTList(VT), {N00, N1})) {
        // if Op (Op N00, N1), N01 already exists
        // we need to stop reassociating to avoid a dead loop
        if (!DAG.doesNodeExist(Opc, DAG.getVTList(VT), {SDValue(NE, 0), N01}))
          return DAG.getNode(Opc, DL, VT, SDValue(NE, 0), N01);
      }
    }

    if (N1 != N00) {
      // Reassociate if (op N01, N1) already exists
      if (SDNode *NE = DAG.getNodeIfExists(Opc, DAG.getVTList(VT), {N01, N1})) {
        // if Op (Op N01, N1), N00 already exists
        // we need to stop reassociating to avoid a dead loop
        if (!DAG.doesNodeExist(Opc, DAG.getVTList(VT), {SDValue(NE, 0), N00}))
          return DAG.getNode(Opc, DL, VT, SDValue(NE, 0), N00);
      }
    }
  }

  return SDValue();
}
1151
1152 // Try to reassociate commutative binops.
reassociateOps(unsigned Opc,const SDLoc & DL,SDValue N0,SDValue N1,SDNodeFlags Flags)1153 SDValue DAGCombiner::reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
1154 SDValue N1, SDNodeFlags Flags) {
1155 assert(TLI.isCommutativeBinOp(Opc) && "Operation not commutative.");
1156
1157 // Floating-point reassociation is not allowed without loose FP math.
1158 if (N0.getValueType().isFloatingPoint() ||
1159 N1.getValueType().isFloatingPoint())
1160 if (!Flags.hasAllowReassociation() || !Flags.hasNoSignedZeros())
1161 return SDValue();
1162
1163 if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N0, N1))
1164 return Combined;
1165 if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N1, N0))
1166 return Combined;
1167 return SDValue();
1168 }
1169
/// Replace all of \p N's result values with the \p NumTo values in \p To,
/// updating the worklist and deleting \p N if it becomes dead. Returns
/// SDValue(N, 0) so callers can signal "handled" to the combine driver.
SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
                               bool AddTo) {
  assert(N->getNumValues() == NumTo && "Broken CombineTo call!");
  ++NodesCombined;
  LLVM_DEBUG(dbgs() << "\nReplacing.1 "; N->dump(&DAG); dbgs() << "\nWith: ";
             To[0].dump(&DAG);
             dbgs() << " and " << NumTo - 1 << " other values\n");
  // Each replacement value (when present) must have the same type as the
  // value it replaces.
  for (unsigned i = 0, e = NumTo; i != e; ++i)
    assert((!To[i].getNode() ||
            N->getValueType(i) == To[i].getValueType()) &&
           "Cannot combine value to value of different type!");

  WorklistRemover DeadNodes(*this);
  DAG.ReplaceAllUsesWith(N, To);
  if (AddTo) {
    // Push the new nodes and any users onto the worklist
    for (unsigned i = 0, e = NumTo; i != e; ++i) {
      if (To[i].getNode())
        AddToWorklistWithUsers(To[i].getNode());
    }
  }

  // Finally, if the node is now dead, remove it from the graph. The node
  // may not be dead if the replacement process recursively simplified to
  // something else needing this node.
  if (N->use_empty())
    deleteAndRecombine(N);
  return SDValue(N, 0);
}
1199
/// Commit a replacement computed by TargetLowering (TLO.Old -> TLO.New),
/// keeping the worklist consistent and deleting the old value if it dies.
void DAGCombiner::
CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
  // Replace the old value with the new one.
  ++NodesCombined;
  LLVM_DEBUG(dbgs() << "\nReplacing.2 "; TLO.Old.dump(&DAG);
             dbgs() << "\nWith: "; TLO.New.dump(&DAG); dbgs() << '\n');

  // Replace all uses. If any nodes become isomorphic to other nodes and
  // are deleted, make sure to remove them from our worklist.
  WorklistRemover DeadNodes(*this);
  DAG.ReplaceAllUsesOfValueWith(TLO.Old, TLO.New);

  // Push the new node and any (possibly new) users onto the worklist.
  AddToWorklistWithUsers(TLO.New.getNode());

  // Finally, if the node is now dead, remove it from the graph. The node
  // may not be dead if the replacement process recursively simplified to
  // something else needing this node.
  if (TLO.Old->use_empty())
    deleteAndRecombine(TLO.Old.getNode());
}
1221
/// Check the specified integer node value to see if it can be simplified or if
/// things it uses can be simplified by bit propagation. If so, return true.
bool DAGCombiner::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
                                       const APInt &DemandedElts,
                                       bool AssumeSingleUse) {
  TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
  KnownBits Known;
  // Depth 0: let TargetLowering recurse from the top of Op's expression tree.
  if (!TLI.SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO, 0,
                                AssumeSingleUse))
    return false;

  // Revisit the node.
  AddToWorklist(Op.getNode());

  // Apply the replacement computed by TargetLowering and clean up dead nodes.
  CommitTargetLoweringOpt(TLO);
  return true;
}
1239
/// Check the specified vector node value to see if it can be simplified or
/// if things it uses can be simplified as it only uses some of the elements.
/// If so, return true.
bool DAGCombiner::SimplifyDemandedVectorElts(SDValue Op,
                                             const APInt &DemandedElts,
                                             bool AssumeSingleUse) {
  TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
  APInt KnownUndef, KnownZero;
  // Depth 0: let TargetLowering recurse from the top of Op's expression tree.
  if (!TLI.SimplifyDemandedVectorElts(Op, DemandedElts, KnownUndef, KnownZero,
                                      TLO, 0, AssumeSingleUse))
    return false;

  // Revisit the node.
  AddToWorklist(Op.getNode());

  // Apply the replacement computed by TargetLowering and clean up dead nodes.
  CommitTargetLoweringOpt(TLO);
  return true;
}
1258
/// Replace \p Load with \p ExtLoad: value uses are rewritten to a truncate of
/// the extended result, chain uses to ExtLoad's chain, then the old load is
/// deleted.
void DAGCombiner::ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad) {
  SDLoc DL(Load);
  EVT VT = Load->getValueType(0);
  // Narrow the extended result back to the original load's type.
  SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, VT, SDValue(ExtLoad, 0));

  LLVM_DEBUG(dbgs() << "\nReplacing.9 "; Load->dump(&DAG); dbgs() << "\nWith: ";
             Trunc.dump(&DAG); dbgs() << '\n');
  WorklistRemover DeadNodes(*this);
  // Value (result 0) and chain (result 1) are replaced separately.
  DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), Trunc);
  DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), SDValue(ExtLoad, 1));
  deleteAndRecombine(Load);
  AddToWorklist(Trunc.getNode());
}
1272
/// Promote \p Op to the wider type \p PVT, returning the promoted value or a
/// null SDValue on failure. \p Replace is set to true when a new extending
/// load was created and the caller must retire the original load via
/// ReplaceLoadWithPromotedLoad.
SDValue DAGCombiner::PromoteOperand(SDValue Op, EVT PVT, bool &Replace) {
  Replace = false;
  SDLoc DL(Op);
  if (ISD::isUNINDEXEDLoad(Op.getNode())) {
    LoadSDNode *LD = cast<LoadSDNode>(Op);
    EVT MemVT = LD->getMemoryVT();
    // A plain (non-extending) load is promoted as an any-extending load;
    // an extending load keeps its original extension kind.
    ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) ? ISD::EXTLOAD
                                                      : LD->getExtensionType();
    Replace = true;
    return DAG.getExtLoad(ExtType, DL, PVT,
                          LD->getChain(), LD->getBasePtr(),
                          MemVT, LD->getMemOperand());
  }

  unsigned Opc = Op.getOpcode();
  switch (Opc) {
  default: break;
  case ISD::AssertSext:
    // Promote the asserted value and keep the assertion on the wide value.
    if (SDValue Op0 = SExtPromoteOperand(Op.getOperand(0), PVT))
      return DAG.getNode(ISD::AssertSext, DL, PVT, Op0, Op.getOperand(1));
    break;
  case ISD::AssertZext:
    if (SDValue Op0 = ZExtPromoteOperand(Op.getOperand(0), PVT))
      return DAG.getNode(ISD::AssertZext, DL, PVT, Op0, Op.getOperand(1));
    break;
  case ISD::Constant: {
    // NOTE(review): byte-sized constants are sign-extended while others are
    // zero-extended — presumably matching how non-byte-sized constants are
    // represented; confirm before relying on this distinction.
    unsigned ExtOpc =
        Op.getValueType().isByteSized() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
    return DAG.getNode(ExtOpc, DL, PVT, Op);
  }
  }

  // Fall back to any-extend when the target supports it in the wide type.
  if (!TLI.isOperationLegal(ISD::ANY_EXTEND, PVT))
    return SDValue();
  return DAG.getNode(ISD::ANY_EXTEND, DL, PVT, Op);
}
1309
SExtPromoteOperand(SDValue Op,EVT PVT)1310 SDValue DAGCombiner::SExtPromoteOperand(SDValue Op, EVT PVT) {
1311 if (!TLI.isOperationLegal(ISD::SIGN_EXTEND_INREG, PVT))
1312 return SDValue();
1313 EVT OldVT = Op.getValueType();
1314 SDLoc DL(Op);
1315 bool Replace = false;
1316 SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1317 if (!NewOp.getNode())
1318 return SDValue();
1319 AddToWorklist(NewOp.getNode());
1320
1321 if (Replace)
1322 ReplaceLoadWithPromotedLoad(Op.getNode(), NewOp.getNode());
1323 return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, NewOp.getValueType(), NewOp,
1324 DAG.getValueType(OldVT));
1325 }
1326
ZExtPromoteOperand(SDValue Op,EVT PVT)1327 SDValue DAGCombiner::ZExtPromoteOperand(SDValue Op, EVT PVT) {
1328 EVT OldVT = Op.getValueType();
1329 SDLoc DL(Op);
1330 bool Replace = false;
1331 SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1332 if (!NewOp.getNode())
1333 return SDValue();
1334 AddToWorklist(NewOp.getNode());
1335
1336 if (Replace)
1337 ReplaceLoadWithPromotedLoad(Op.getNode(), NewOp.getNode());
1338 return DAG.getZeroExtendInReg(NewOp, DL, OldVT);
1339 }
1340
/// Promote the specified integer binary operation if the target indicates it is
/// beneficial. e.g. On x86, it's usually better to promote i16 operations to
/// i32 since i16 instructions are longer.
SDValue DAGCombiner::PromoteIntBinOp(SDValue Op) {
  if (!LegalOperations)
    return SDValue();

  // Only scalar integer operations are candidates for promotion.
  EVT VT = Op.getValueType();
  if (VT.isVector() || !VT.isInteger())
    return SDValue();

  // If operation type is 'undesirable', e.g. i16 on x86, consider
  // promoting it.
  unsigned Opc = Op.getOpcode();
  if (TLI.isTypeDesirableForOp(Opc, VT))
    return SDValue();

  EVT PVT = VT;
  // Consult target whether it is a good idea to promote this operation and
  // what's the right type to promote it to.
  if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
    assert(PVT != VT && "Don't know what type to promote to!");

    LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));

    bool Replace0 = false;
    SDValue N0 = Op.getOperand(0);
    SDValue NN0 = PromoteOperand(N0, PVT, Replace0);

    bool Replace1 = false;
    SDValue N1 = Op.getOperand(1);
    SDValue NN1 = PromoteOperand(N1, PVT, Replace1);
    SDLoc DL(Op);

    // Perform the operation in the wide type, then truncate back.
    SDValue RV =
        DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, NN0, NN1));

    // We are always replacing N0/N1's use in N and only need additional
    // replacements if there are additional uses.
    // Note: We are checking uses of the *nodes* (SDNode) rather than values
    // (SDValue) here because the node may reference multiple values
    // (for example, the chain value of a load node).
    Replace0 &= !N0->hasOneUse();
    Replace1 &= (N0 != N1) && !N1->hasOneUse();

    // Combine Op here so it is preserved past replacements.
    CombineTo(Op.getNode(), RV);

    // If operands have a use ordering, make sure we deal with
    // predecessor first.
    if (Replace0 && Replace1 && N0->isPredecessorOf(N1.getNode())) {
      std::swap(N0, N1);
      std::swap(NN0, NN1);
    }

    if (Replace0) {
      AddToWorklist(NN0.getNode());
      ReplaceLoadWithPromotedLoad(N0.getNode(), NN0.getNode());
    }
    if (Replace1) {
      AddToWorklist(NN1.getNode());
      ReplaceLoadWithPromotedLoad(N1.getNode(), NN1.getNode());
    }
    // Returning Op (not RV) tells the caller the node was handled in place.
    return Op;
  }
  return SDValue();
}
1408
/// Promote the specified integer shift operation if the target indicates it is
/// beneficial. e.g. On x86, it's usually better to promote i16 operations to
/// i32 since i16 instructions are longer.
SDValue DAGCombiner::PromoteIntShiftOp(SDValue Op) {
  if (!LegalOperations)
    return SDValue();

  // Only scalar integer shifts are candidates for promotion.
  EVT VT = Op.getValueType();
  if (VT.isVector() || !VT.isInteger())
    return SDValue();

  // If operation type is 'undesirable', e.g. i16 on x86, consider
  // promoting it.
  unsigned Opc = Op.getOpcode();
  if (TLI.isTypeDesirableForOp(Opc, VT))
    return SDValue();

  EVT PVT = VT;
  // Consult target whether it is a good idea to promote this operation and
  // what's the right type to promote it to.
  if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
    assert(PVT != VT && "Don't know what type to promote to!");

    LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));

    bool Replace = false;
    SDValue N0 = Op.getOperand(0);
    // The high bits of the shifted value must match the shift kind:
    // arithmetic right shifts need the sign replicated, logical right
    // shifts need zeros; left shifts don't care (any-extend).
    if (Opc == ISD::SRA)
      N0 = SExtPromoteOperand(N0, PVT);
    else if (Opc == ISD::SRL)
      N0 = ZExtPromoteOperand(N0, PVT);
    else
      N0 = PromoteOperand(N0, PVT, Replace);

    if (!N0.getNode())
      return SDValue();

    SDLoc DL(Op);
    SDValue N1 = Op.getOperand(1);
    // Shift in the wide type, then truncate back to the original type.
    SDValue RV =
        DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, N0, N1));

    if (Replace)
      ReplaceLoadWithPromotedLoad(Op.getOperand(0).getNode(), N0.getNode());

    // Deal with Op being deleted.
    if (Op && Op.getOpcode() != ISD::DELETED_NODE)
      return RV;
  }
  return SDValue();
}
1460
PromoteExtend(SDValue Op)1461 SDValue DAGCombiner::PromoteExtend(SDValue Op) {
1462 if (!LegalOperations)
1463 return SDValue();
1464
1465 EVT VT = Op.getValueType();
1466 if (VT.isVector() || !VT.isInteger())
1467 return SDValue();
1468
1469 // If operation type is 'undesirable', e.g. i16 on x86, consider
1470 // promoting it.
1471 unsigned Opc = Op.getOpcode();
1472 if (TLI.isTypeDesirableForOp(Opc, VT))
1473 return SDValue();
1474
1475 EVT PVT = VT;
1476 // Consult target whether it is a good idea to promote this operation and
1477 // what's the right type to promote it to.
1478 if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1479 assert(PVT != VT && "Don't know what type to promote to!");
1480 // fold (aext (aext x)) -> (aext x)
1481 // fold (aext (zext x)) -> (zext x)
1482 // fold (aext (sext x)) -> (sext x)
1483 LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1484 return DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, Op.getOperand(0));
1485 }
1486 return SDValue();
1487 }
1488
/// Promote an unindexed integer load to a wider type when the target finds the
/// current type undesirable. Returns true if the load was replaced.
bool DAGCombiner::PromoteLoad(SDValue Op) {
  if (!LegalOperations)
    return false;

  if (!ISD::isUNINDEXEDLoad(Op.getNode()))
    return false;

  // Only scalar integer loads are candidates for promotion.
  EVT VT = Op.getValueType();
  if (VT.isVector() || !VT.isInteger())
    return false;

  // If operation type is 'undesirable', e.g. i16 on x86, consider
  // promoting it.
  unsigned Opc = Op.getOpcode();
  if (TLI.isTypeDesirableForOp(Opc, VT))
    return false;

  EVT PVT = VT;
  // Consult target whether it is a good idea to promote this operation and
  // what's the right type to promote it to.
  if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
    assert(PVT != VT && "Don't know what type to promote to!");

    SDLoc DL(Op);
    SDNode *N = Op.getNode();
    LoadSDNode *LD = cast<LoadSDNode>(N);
    EVT MemVT = LD->getMemoryVT();
    // A plain load becomes an any-extending load; an extending load keeps
    // its original extension kind.
    ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) ? ISD::EXTLOAD
                                                      : LD->getExtensionType();
    SDValue NewLD = DAG.getExtLoad(ExtType, DL, PVT,
                                   LD->getChain(), LD->getBasePtr(),
                                   MemVT, LD->getMemOperand());
    // Narrow the wide result back to the original type for existing users.
    SDValue Result = DAG.getNode(ISD::TRUNCATE, DL, VT, NewLD);

    LLVM_DEBUG(dbgs() << "\nPromoting "; N->dump(&DAG); dbgs() << "\nTo: ";
               Result.dump(&DAG); dbgs() << '\n');
    WorklistRemover DeadNodes(*this);
    // Replace the value (result 0) and the chain (result 1) separately.
    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result);
    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), NewLD.getValue(1));
    deleteAndRecombine(N);
    AddToWorklist(Result.getNode());
    return true;
  }
  return false;
}
1534
/// Recursively delete a node which has no uses and any operands for
/// which it is the only use.
///
/// Note that this both deletes the nodes and removes them from the worklist.
/// It also adds any nodes who have had a user deleted to the worklist as they
/// may now have only one use and subject to other combines.
bool DAGCombiner::recursivelyDeleteUnusedNodes(SDNode *N) {
  if (!N->use_empty())
    return false;

  // Worklist of nodes to examine; SetVector avoids visiting a node twice.
  SmallSetVector<SDNode *, 16> Nodes;
  Nodes.insert(N);
  do {
    N = Nodes.pop_back_val();
    // Defensive null check — entries are operand nodes and should be
    // non-null, but skip rather than crash if one is not.
    if (!N)
      continue;

    if (N->use_empty()) {
      // Dead: queue its operands (they may become dead too), then delete.
      for (const SDValue &ChildN : N->op_values())
        Nodes.insert(ChildN.getNode());

      removeFromWorklist(N);
      DAG.DeleteNode(N);
    } else {
      // Still used: it lost a user, so revisit it for further combines.
      AddToWorklist(N);
    }
  } while (!Nodes.empty());
  return true;
}
1564
1565 //===----------------------------------------------------------------------===//
1566 // Main DAG Combiner implementation
1567 //===----------------------------------------------------------------------===//
1568
/// Top-level driver: repeatedly pull nodes off the worklist and attempt to
/// combine each one until no work remains. \p AtLevel records where in the
/// legalization pipeline this run occurs; the Legal* flags derived from it
/// constrain which nodes the visit routines are allowed to create.
void DAGCombiner::Run(CombineLevel AtLevel) {
  // set the instance variables, so that the various visit routines may use it.
  Level = AtLevel;
  LegalDAG = Level >= AfterLegalizeDAG;
  LegalOperations = Level >= AfterLegalizeVectorOps;
  LegalTypes = Level >= AfterLegalizeTypes;

  // While this is in scope, any node created in the DAG is automatically
  // added to the worklist.
  WorklistInserter AddNodes(*this);

  // Add all the dag nodes to the worklist.
  for (SDNode &Node : DAG.allnodes())
    AddToWorklist(&Node);

  // Create a dummy node (which is not added to allnodes), that adds a reference
  // to the root node, preventing it from being deleted, and tracking any
  // changes of the root.
  HandleSDNode Dummy(DAG.getRoot());

  // While we have a valid worklist entry node, try to combine it.
  while (SDNode *N = getNextWorklistEntry()) {
    // If N has no uses, it is dead. Make sure to revisit all N's operands once
    // N is deleted from the DAG, since they too may now be dead or may have a
    // reduced number of uses, allowing other xforms.
    if (recursivelyDeleteUnusedNodes(N))
      continue;

    WorklistRemover DeadNodes(*this);

    // If this combine is running after legalizing the DAG, re-legalize any
    // nodes pulled off the worklist.
    if (LegalDAG) {
      SmallSetVector<SDNode *, 16> UpdatedNodes;
      bool NIsValid = DAG.LegalizeOp(N, UpdatedNodes);

      for (SDNode *LN : UpdatedNodes)
        AddToWorklistWithUsers(LN);

      // Legalization replaced N entirely; nothing left to combine here.
      if (!NIsValid)
        continue;
    }

    LLVM_DEBUG(dbgs() << "\nCombining: "; N->dump(&DAG));

    // Add any operands of the new node which have not yet been combined to the
    // worklist as well. Because the worklist uniques things already, this
    // won't repeatedly process the same operand.
    CombinedNodes.insert(N);
    for (const SDValue &ChildN : N->op_values())
      if (!CombinedNodes.count(ChildN.getNode()))
        AddToWorklist(ChildN.getNode());

    SDValue RV = combine(N);

    // An empty result means no combine fired; move on.
    if (!RV.getNode())
      continue;

    ++NodesCombined;

    // If we get back the same node we passed in, rather than a new node or
    // zero, we know that the node must have defined multiple values and
    // CombineTo was used. Since CombineTo takes care of the worklist
    // mechanics for us, we have no work to do in this case.
    if (RV.getNode() == N)
      continue;

    assert(N->getOpcode() != ISD::DELETED_NODE &&
           RV.getOpcode() != ISD::DELETED_NODE &&
           "Node was deleted but visit returned new node!");

    LLVM_DEBUG(dbgs() << " ... into: "; RV.dump(&DAG));

    // Transfer uses of N to the replacement. When the value counts differ,
    // N must be a single-result node whose one result matches RV's type.
    if (N->getNumValues() == RV->getNumValues())
      DAG.ReplaceAllUsesWith(N, RV.getNode());
    else {
      assert(N->getValueType(0) == RV.getValueType() &&
             N->getNumValues() == 1 && "Type mismatch");
      DAG.ReplaceAllUsesWith(N, &RV);
    }

    // Push the new node and any users onto the worklist. Omit this if the
    // new node is the EntryToken (e.g. if a store managed to get optimized
    // out), because re-visiting the EntryToken and its users will not uncover
    // any additional opportunities, but there may be a large number of such
    // users, potentially causing compile time explosion.
    if (RV.getOpcode() != ISD::EntryToken) {
      AddToWorklist(RV.getNode());
      AddUsersToWorklist(RV.getNode());
    }

    // Finally, if the node is now dead, remove it from the graph. The node
    // may not be dead if the replacement process recursively simplified to
    // something else needing this node. This will also take care of adding any
    // operands which have lost a user to the worklist.
    recursivelyDeleteUnusedNodes(N);
  }

  // If the root changed (e.g. it was a dead load, update the root).
  DAG.setRoot(Dummy.getValue());
  DAG.RemoveDeadNodes();
}
1669
/// Dispatch on \p N's opcode to the corresponding generic combine routine.
/// Returns the replacement value if a combine fired, or an empty SDValue if
/// no generic combine exists (or applies) for this opcode.
SDValue DAGCombiner::visit(SDNode *N) {
  switch (N->getOpcode()) {
  default: break;
  case ISD::TokenFactor:        return visitTokenFactor(N);
  case ISD::MERGE_VALUES:       return visitMERGE_VALUES(N);
  case ISD::ADD:                return visitADD(N);
  case ISD::SUB:                return visitSUB(N);
  case ISD::SADDSAT:
  case ISD::UADDSAT:            return visitADDSAT(N);
  case ISD::SSUBSAT:
  case ISD::USUBSAT:            return visitSUBSAT(N);
  case ISD::ADDC:               return visitADDC(N);
  case ISD::SADDO:
  case ISD::UADDO:              return visitADDO(N);
  case ISD::SUBC:               return visitSUBC(N);
  case ISD::SSUBO:
  case ISD::USUBO:              return visitSUBO(N);
  case ISD::ADDE:               return visitADDE(N);
  case ISD::ADDCARRY:           return visitADDCARRY(N);
  case ISD::SADDO_CARRY:        return visitSADDO_CARRY(N);
  case ISD::SUBE:               return visitSUBE(N);
  case ISD::SUBCARRY:           return visitSUBCARRY(N);
  case ISD::SSUBO_CARRY:        return visitSSUBO_CARRY(N);
  case ISD::SMULFIX:
  case ISD::SMULFIXSAT:
  case ISD::UMULFIX:
  case ISD::UMULFIXSAT:         return visitMULFIX(N);
  case ISD::MUL:                return visitMUL(N);
  case ISD::SDIV:               return visitSDIV(N);
  case ISD::UDIV:               return visitUDIV(N);
  case ISD::SREM:
  case ISD::UREM:               return visitREM(N);
  case ISD::MULHU:              return visitMULHU(N);
  case ISD::MULHS:              return visitMULHS(N);
  case ISD::AVGFLOORS:
  case ISD::AVGFLOORU:
  case ISD::AVGCEILS:
  case ISD::AVGCEILU:           return visitAVG(N);
  case ISD::SMUL_LOHI:          return visitSMUL_LOHI(N);
  case ISD::UMUL_LOHI:          return visitUMUL_LOHI(N);
  case ISD::SMULO:
  case ISD::UMULO:              return visitMULO(N);
  case ISD::SMIN:
  case ISD::SMAX:
  case ISD::UMIN:
  case ISD::UMAX:               return visitIMINMAX(N);
  case ISD::AND:                return visitAND(N);
  case ISD::OR:                 return visitOR(N);
  case ISD::XOR:                return visitXOR(N);
  case ISD::SHL:                return visitSHL(N);
  case ISD::SRA:                return visitSRA(N);
  case ISD::SRL:                return visitSRL(N);
  case ISD::ROTR:
  case ISD::ROTL:               return visitRotate(N);
  case ISD::FSHL:
  case ISD::FSHR:               return visitFunnelShift(N);
  case ISD::SSHLSAT:
  case ISD::USHLSAT:            return visitSHLSAT(N);
  case ISD::ABS:                return visitABS(N);
  case ISD::BSWAP:              return visitBSWAP(N);
  case ISD::BITREVERSE:         return visitBITREVERSE(N);
  case ISD::CTLZ:               return visitCTLZ(N);
  case ISD::CTLZ_ZERO_UNDEF:    return visitCTLZ_ZERO_UNDEF(N);
  case ISD::CTTZ:               return visitCTTZ(N);
  case ISD::CTTZ_ZERO_UNDEF:    return visitCTTZ_ZERO_UNDEF(N);
  case ISD::CTPOP:              return visitCTPOP(N);
  case ISD::SELECT:             return visitSELECT(N);
  case ISD::VSELECT:            return visitVSELECT(N);
  case ISD::SELECT_CC:          return visitSELECT_CC(N);
  case ISD::SETCC:              return visitSETCC(N);
  case ISD::SETCCCARRY:         return visitSETCCCARRY(N);
  case ISD::SIGN_EXTEND:        return visitSIGN_EXTEND(N);
  case ISD::ZERO_EXTEND:        return visitZERO_EXTEND(N);
  case ISD::ANY_EXTEND:         return visitANY_EXTEND(N);
  case ISD::AssertSext:
  case ISD::AssertZext:         return visitAssertExt(N);
  case ISD::AssertAlign:        return visitAssertAlign(N);
  case ISD::SIGN_EXTEND_INREG:  return visitSIGN_EXTEND_INREG(N);
  case ISD::SIGN_EXTEND_VECTOR_INREG:
  case ISD::ZERO_EXTEND_VECTOR_INREG: return visitEXTEND_VECTOR_INREG(N);
  case ISD::TRUNCATE:           return visitTRUNCATE(N);
  case ISD::BITCAST:            return visitBITCAST(N);
  case ISD::BUILD_PAIR:         return visitBUILD_PAIR(N);
  case ISD::FADD:               return visitFADD(N);
  case ISD::STRICT_FADD:        return visitSTRICT_FADD(N);
  case ISD::FSUB:               return visitFSUB(N);
  case ISD::FMUL:               return visitFMUL(N);
  case ISD::FMA:                return visitFMA(N);
  case ISD::FDIV:               return visitFDIV(N);
  case ISD::FREM:               return visitFREM(N);
  case ISD::FSQRT:              return visitFSQRT(N);
  case ISD::FCOPYSIGN:          return visitFCOPYSIGN(N);
  case ISD::FPOW:               return visitFPOW(N);
  case ISD::SINT_TO_FP:         return visitSINT_TO_FP(N);
  case ISD::UINT_TO_FP:         return visitUINT_TO_FP(N);
  case ISD::FP_TO_SINT:         return visitFP_TO_SINT(N);
  case ISD::FP_TO_UINT:         return visitFP_TO_UINT(N);
  case ISD::FP_ROUND:           return visitFP_ROUND(N);
  case ISD::FP_EXTEND:          return visitFP_EXTEND(N);
  case ISD::FNEG:               return visitFNEG(N);
  case ISD::FABS:               return visitFABS(N);
  case ISD::FFLOOR:             return visitFFLOOR(N);
  case ISD::FMINNUM:
  case ISD::FMAXNUM:
  case ISD::FMINIMUM:
  case ISD::FMAXIMUM:           return visitFMinMax(N);
  case ISD::FCEIL:              return visitFCEIL(N);
  case ISD::FTRUNC:             return visitFTRUNC(N);
  case ISD::BRCOND:             return visitBRCOND(N);
  case ISD::BR_CC:              return visitBR_CC(N);
  case ISD::LOAD:               return visitLOAD(N);
  case ISD::STORE:              return visitSTORE(N);
  case ISD::INSERT_VECTOR_ELT:  return visitINSERT_VECTOR_ELT(N);
  case ISD::EXTRACT_VECTOR_ELT: return visitEXTRACT_VECTOR_ELT(N);
  case ISD::BUILD_VECTOR:       return visitBUILD_VECTOR(N);
  case ISD::CONCAT_VECTORS:     return visitCONCAT_VECTORS(N);
  case ISD::EXTRACT_SUBVECTOR:  return visitEXTRACT_SUBVECTOR(N);
  case ISD::VECTOR_SHUFFLE:     return visitVECTOR_SHUFFLE(N);
  case ISD::SCALAR_TO_VECTOR:   return visitSCALAR_TO_VECTOR(N);
  case ISD::INSERT_SUBVECTOR:   return visitINSERT_SUBVECTOR(N);
  case ISD::MGATHER:            return visitMGATHER(N);
  case ISD::MLOAD:              return visitMLOAD(N);
  case ISD::MSCATTER:           return visitMSCATTER(N);
  case ISD::MSTORE:             return visitMSTORE(N);
  case ISD::LIFETIME_END:       return visitLIFETIME_END(N);
  case ISD::FP_TO_FP16:         return visitFP_TO_FP16(N);
  case ISD::FP16_TO_FP:         return visitFP16_TO_FP(N);
  case ISD::FP_TO_BF16:         return visitFP_TO_BF16(N);
  case ISD::FREEZE:             return visitFREEZE(N);
  case ISD::VECREDUCE_FADD:
  case ISD::VECREDUCE_FMUL:
  case ISD::VECREDUCE_ADD:
  case ISD::VECREDUCE_MUL:
  case ISD::VECREDUCE_AND:
  case ISD::VECREDUCE_OR:
  case ISD::VECREDUCE_XOR:
  case ISD::VECREDUCE_SMAX:
  case ISD::VECREDUCE_SMIN:
  case ISD::VECREDUCE_UMAX:
  case ISD::VECREDUCE_UMIN:
  case ISD::VECREDUCE_FMAX:
  case ISD::VECREDUCE_FMIN:     return visitVECREDUCE(N);
  // Expand one case label per VP intrinsic opcode declared in the .def file.
#define BEGIN_REGISTER_VP_SDNODE(SDOPC, ...) case ISD::SDOPC:
#include "llvm/IR/VPIntrinsics.def"
    return visitVPOp(N);
  }
  return SDValue();
}
1818
/// Try to simplify \p N, attempting strategies in order: the generic
/// per-opcode visit routines, target-specific combines, integer-operation
/// promotion, and finally CSE against an already-existing commuted form of
/// the node. Returns the replacement value, or an empty SDValue if no
/// strategy changed anything.
SDValue DAGCombiner::combine(SDNode *N) {
  SDValue RV;
  if (!DisableGenericCombines)
    RV = visit(N);

  // If nothing happened, try a target-specific DAG combine.
  if (!RV.getNode()) {
    assert(N->getOpcode() != ISD::DELETED_NODE &&
           "Node was deleted but visit returned NULL!");

    // Target nodes (>= BUILTIN_OP_END) always go to the target; generic
    // nodes only if the target registered a combine for that opcode.
    if (N->getOpcode() >= ISD::BUILTIN_OP_END ||
        TLI.hasTargetDAGCombine((ISD::NodeType)N->getOpcode())) {

      // Expose the DAG combiner to the target combiner impls.
      TargetLowering::DAGCombinerInfo
        DagCombineInfo(DAG, Level, false, this);

      RV = TLI.PerformDAGCombine(N, DagCombineInfo);
    }
  }

  // If nothing happened still, try promoting the operation.
  if (!RV.getNode()) {
    switch (N->getOpcode()) {
    default: break;
    case ISD::ADD:
    case ISD::SUB:
    case ISD::MUL:
    case ISD::AND:
    case ISD::OR:
    case ISD::XOR:
      RV = PromoteIntBinOp(SDValue(N, 0));
      break;
    case ISD::SHL:
    case ISD::SRA:
    case ISD::SRL:
      RV = PromoteIntShiftOp(SDValue(N, 0));
      break;
    case ISD::SIGN_EXTEND:
    case ISD::ZERO_EXTEND:
    case ISD::ANY_EXTEND:
      RV = PromoteExtend(SDValue(N, 0));
      break;
    case ISD::LOAD:
      // PromoteLoad rewrites in place; report N itself as the "new" value.
      if (PromoteLoad(SDValue(N, 0)))
        RV = SDValue(N, 0);
      break;
    }
  }

  // If N is a commutative binary node, try to eliminate it if the commuted
  // version is already present in the DAG.
  if (!RV.getNode() && TLI.isCommutativeBinOp(N->getOpcode())) {
    SDValue N0 = N->getOperand(0);
    SDValue N1 = N->getOperand(1);

    // Constant operands are canonicalized to RHS.
    if (N0 != N1 && (isa<ConstantSDNode>(N0) || !isa<ConstantSDNode>(N1))) {
      SDValue Ops[] = {N1, N0};
      SDNode *CSENode = DAG.getNodeIfExists(N->getOpcode(), N->getVTList(), Ops,
                                            N->getFlags());
      if (CSENode)
        return SDValue(CSENode, 0);
    }
  }

  return RV;
}
1887
1888 /// Given a node, return its input chain if it has one, otherwise return a null
1889 /// sd operand.
getInputChainForNode(SDNode * N)1890 static SDValue getInputChainForNode(SDNode *N) {
1891 if (unsigned NumOps = N->getNumOperands()) {
1892 if (N->getOperand(0).getValueType() == MVT::Other)
1893 return N->getOperand(0);
1894 if (N->getOperand(NumOps-1).getValueType() == MVT::Other)
1895 return N->getOperand(NumOps-1);
1896 for (unsigned i = 1; i < NumOps-1; ++i)
1897 if (N->getOperand(i).getValueType() == MVT::Other)
1898 return N->getOperand(i);
1899 }
1900 return SDValue();
1901 }
1902
/// Combine a TokenFactor node: drop redundant chain operands, inline
/// single-use TokenFactor operands into this one, and prune operands that
/// are already transitively reachable through another operand's chain.
SDValue DAGCombiner::visitTokenFactor(SDNode *N) {
  // If N has two operands, where one has an input chain equal to the other,
  // the 'other' chain is redundant.
  if (N->getNumOperands() == 2) {
    if (getInputChainForNode(N->getOperand(0).getNode()) == N->getOperand(1))
      return N->getOperand(0);
    if (getInputChainForNode(N->getOperand(1).getNode()) == N->getOperand(0))
      return N->getOperand(1);
  }

  // Don't simplify token factors if optnone.
  if (OptLevel == CodeGenOpt::None)
    return SDValue();

  // Don't simplify the token factor if the node itself has too many operands.
  if (N->getNumOperands() > TokenFactorInlineLimit)
    return SDValue();

  // If the sole user is a token factor, we should make sure we have a
  // chance to merge them together. This prevents TF chains from inhibiting
  // optimizations.
  if (N->hasOneUse() && N->use_begin()->getOpcode() == ISD::TokenFactor)
    AddToWorklist(*(N->use_begin()));

  SmallVector<SDNode *, 8> TFs;     // List of token factors to visit.
  SmallVector<SDValue, 8> Ops;      // Ops for replacing token factor.
  SmallPtrSet<SDNode*, 16> SeenOps;
  bool Changed = false;             // If we should replace this token factor.

  // Start out with this token factor.
  TFs.push_back(N);

  // Iterate through token factors. The TFs grows when new token factors are
  // encountered.
  for (unsigned i = 0; i < TFs.size(); ++i) {
    // Limit number of nodes to inline, to avoid quadratic compile times.
    // We have to add the outstanding Token Factors to Ops, otherwise we might
    // drop Ops from the resulting Token Factors.
    if (Ops.size() > TokenFactorInlineLimit) {
      for (unsigned j = i; j < TFs.size(); j++)
        Ops.emplace_back(TFs[j], 0);
      // Drop unprocessed Token Factors from TFs, so we do not add them to the
      // combiner worklist later.
      TFs.resize(i);
      break;
    }

    SDNode *TF = TFs[i];
    // Check each of the operands.
    for (const SDValue &Op : TF->op_values()) {
      switch (Op.getOpcode()) {
      case ISD::EntryToken:
        // Entry tokens don't need to be added to the list. They are
        // redundant.
        Changed = true;
        break;

      case ISD::TokenFactor:
        if (Op.hasOneUse() && !is_contained(TFs, Op.getNode())) {
          // Queue up for processing.
          TFs.push_back(Op.getNode());
          Changed = true;
          break;
        }
        LLVM_FALLTHROUGH;

      default:
        // Only add if it isn't already in the list.
        if (SeenOps.insert(Op.getNode()).second)
          Ops.push_back(Op);
        else
          Changed = true;
        break;
      }
    }
  }

  // Re-visit inlined Token Factors, to clean them up in case they have been
  // removed. Skip the first Token Factor, as this is the current node.
  for (unsigned i = 1, e = TFs.size(); i < e; i++)
    AddToWorklist(TFs[i]);

  // Remove Nodes that are chained to another node in the list. Do so
  // by walking up chains breadth-first, stopping when we've seen
  // another operand. In general we must climb to the EntryNode, but we can exit
  // early if we find all remaining work is associated with just one operand as
  // no further pruning is possible.

  // List of nodes to search through and original Ops from which they originate.
  SmallVector<std::pair<SDNode *, unsigned>, 8> Worklist;
  SmallVector<unsigned, 8> OpWorkCount; // Count of work for each Op.
  SmallPtrSet<SDNode *, 16> SeenChains;
  bool DidPruneOps = false;

  unsigned NumLeftToConsider = 0;
  for (const SDValue &Op : Ops) {
    Worklist.push_back(std::make_pair(Op.getNode(), NumLeftToConsider++));
    OpWorkCount.push_back(1);
  }

  // NOTE: this lambda shadows DAGCombiner::AddToWorklist within its scope.
  auto AddToWorklist = [&](unsigned CurIdx, SDNode *Op, unsigned OpNumber) {
    // If this is an Op, we can remove the op from the list. Remark any
    // search associated with it as from the current OpNumber.
    if (SeenOps.contains(Op)) {
      Changed = true;
      DidPruneOps = true;
      unsigned OrigOpNumber = 0;
      while (OrigOpNumber < Ops.size() && Ops[OrigOpNumber].getNode() != Op)
        OrigOpNumber++;
      assert((OrigOpNumber != Ops.size()) &&
             "expected to find TokenFactor Operand");
      // Re-mark worklist from OrigOpNumber to OpNumber
      for (unsigned i = CurIdx + 1; i < Worklist.size(); ++i) {
        if (Worklist[i].second == OrigOpNumber) {
          Worklist[i].second = OpNumber;
        }
      }
      OpWorkCount[OpNumber] += OpWorkCount[OrigOpNumber];
      OpWorkCount[OrigOpNumber] = 0;
      NumLeftToConsider--;
    }
    // Add if it's a new chain
    if (SeenChains.insert(Op).second) {
      OpWorkCount[OpNumber]++;
      Worklist.push_back(std::make_pair(Op, OpNumber));
    }
  };

  // Cap the search at 1024 worklist entries to bound compile time.
  for (unsigned i = 0; i < Worklist.size() && i < 1024; ++i) {
    // We need to consider at least 2 Ops for pruning to be possible.
    if (NumLeftToConsider <= 1)
      break;
    auto CurNode = Worklist[i].first;
    auto CurOpNumber = Worklist[i].second;
    assert((OpWorkCount[CurOpNumber] > 0) &&
           "Node should not appear in worklist");
    switch (CurNode->getOpcode()) {
    case ISD::EntryToken:
      // Hitting EntryToken is the only way for the search to terminate
      // without hitting another operand's search. Prevent us from marking
      // this operand considered.
      NumLeftToConsider++;
      break;
    case ISD::TokenFactor:
      for (const SDValue &Op : CurNode->op_values())
        AddToWorklist(i, Op.getNode(), CurOpNumber);
      break;
    case ISD::LIFETIME_START:
    case ISD::LIFETIME_END:
    case ISD::CopyFromReg:
    case ISD::CopyToReg:
      AddToWorklist(i, CurNode->getOperand(0).getNode(), CurOpNumber);
      break;
    default:
      if (auto *MemNode = dyn_cast<MemSDNode>(CurNode))
        AddToWorklist(i, MemNode->getChain().getNode(), CurOpNumber);
      break;
    }
    OpWorkCount[CurOpNumber]--;
    if (OpWorkCount[CurOpNumber] == 0)
      NumLeftToConsider--;
  }

  // If we've changed things around then replace token factor.
  if (Changed) {
    SDValue Result;
    if (Ops.empty()) {
      // The entry token is the only possible outcome.
      Result = DAG.getEntryNode();
    } else {
      if (DidPruneOps) {
        SmallVector<SDValue, 8> PrunedOps;
        // Keep only the Ops that were not reached while walking up another
        // Op's chain; the rest are ordered by those chains already.
        for (const SDValue &Op : Ops) {
          if (SeenChains.count(Op.getNode()) == 0)
            PrunedOps.push_back(Op);
        }
        Result = DAG.getTokenFactor(SDLoc(N), PrunedOps);
      } else {
        Result = DAG.getTokenFactor(SDLoc(N), Ops);
      }
    }
    return Result;
  }
  return SDValue();
}
2090
2091 /// MERGE_VALUES can always be eliminated.
visitMERGE_VALUES(SDNode * N)2092 SDValue DAGCombiner::visitMERGE_VALUES(SDNode *N) {
2093 WorklistRemover DeadNodes(*this);
2094 // Replacing results may cause a different MERGE_VALUES to suddenly
2095 // be CSE'd with N, and carry its uses with it. Iterate until no
2096 // uses remain, to ensure that the node can be safely deleted.
2097 // First add the users of this node to the work list so that they
2098 // can be tried again once they have new operands.
2099 AddUsersToWorklist(N);
2100 do {
2101 // Do as a single replacement to avoid rewalking use lists.
2102 SmallVector<SDValue, 8> Ops;
2103 for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
2104 Ops.push_back(N->getOperand(i));
2105 DAG.ReplaceAllUsesWith(N, Ops.data());
2106 } while (!N->use_empty());
2107 deleteAndRecombine(N);
2108 return SDValue(N, 0); // Return N so it doesn't get rechecked!
2109 }
2110
2111 /// If \p N is a ConstantSDNode with isOpaque() == false return it casted to a
2112 /// ConstantSDNode pointer else nullptr.
getAsNonOpaqueConstant(SDValue N)2113 static ConstantSDNode *getAsNonOpaqueConstant(SDValue N) {
2114 ConstantSDNode *Const = dyn_cast<ConstantSDNode>(N);
2115 return Const != nullptr && !Const->isOpaque() ? Const : nullptr;
2116 }
2117
2118 /// Return true if 'Use' is a load or a store that uses N as its base pointer
2119 /// and that N may be folded in the load / store addressing mode.
canFoldInAddressingMode(SDNode * N,SDNode * Use,SelectionDAG & DAG,const TargetLowering & TLI)2120 static bool canFoldInAddressingMode(SDNode *N, SDNode *Use, SelectionDAG &DAG,
2121 const TargetLowering &TLI) {
2122 EVT VT;
2123 unsigned AS;
2124
2125 if (LoadSDNode *LD = dyn_cast<LoadSDNode>(Use)) {
2126 if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2127 return false;
2128 VT = LD->getMemoryVT();
2129 AS = LD->getAddressSpace();
2130 } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(Use)) {
2131 if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2132 return false;
2133 VT = ST->getMemoryVT();
2134 AS = ST->getAddressSpace();
2135 } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(Use)) {
2136 if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2137 return false;
2138 VT = LD->getMemoryVT();
2139 AS = LD->getAddressSpace();
2140 } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(Use)) {
2141 if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2142 return false;
2143 VT = ST->getMemoryVT();
2144 AS = ST->getAddressSpace();
2145 } else {
2146 return false;
2147 }
2148
2149 TargetLowering::AddrMode AM;
2150 if (N->getOpcode() == ISD::ADD) {
2151 AM.HasBaseReg = true;
2152 ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(N->getOperand(1));
2153 if (Offset)
2154 // [reg +/- imm]
2155 AM.BaseOffs = Offset->getSExtValue();
2156 else
2157 // [reg +/- reg]
2158 AM.Scale = 1;
2159 } else if (N->getOpcode() == ISD::SUB) {
2160 AM.HasBaseReg = true;
2161 ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(N->getOperand(1));
2162 if (Offset)
2163 // [reg +/- imm]
2164 AM.BaseOffs = -Offset->getSExtValue();
2165 else
2166 // [reg +/- reg]
2167 AM.Scale = 1;
2168 } else {
2169 return false;
2170 }
2171
2172 return TLI.isLegalAddressingMode(DAG.getDataLayout(), AM,
2173 VT.getTypeForEVT(*DAG.getContext()), AS);
2174 }
2175
2176 /// This inverts a canonicalization in IR that replaces a variable select arm
2177 /// with an identity constant. Codegen improves if we re-use the variable
2178 /// operand rather than load a constant. This can also be converted into a
2179 /// masked vector operation if the target supports it.
foldSelectWithIdentityConstant(SDNode * N,SelectionDAG & DAG,bool ShouldCommuteOperands)2180 static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
2181 bool ShouldCommuteOperands) {
2182 // Match a select as operand 1. The identity constant that we are looking for
2183 // is only valid as operand 1 of a non-commutative binop.
2184 SDValue N0 = N->getOperand(0);
2185 SDValue N1 = N->getOperand(1);
2186 if (ShouldCommuteOperands)
2187 std::swap(N0, N1);
2188
2189 // TODO: Should this apply to scalar select too?
2190 if (!N1.hasOneUse() || N1.getOpcode() != ISD::VSELECT)
2191 return SDValue();
2192
2193 unsigned Opcode = N->getOpcode();
2194 EVT VT = N->getValueType(0);
2195 SDValue Cond = N1.getOperand(0);
2196 SDValue TVal = N1.getOperand(1);
2197 SDValue FVal = N1.getOperand(2);
2198
2199 // TODO: The cases should match with IR's ConstantExpr::getBinOpIdentity().
2200 // TODO: Target-specific opcodes could be added. Ex: "isCommutativeBinOp()".
2201 // TODO: With fast-math (NSZ), allow the opposite-sign form of zero?
2202 auto isIdentityConstantForOpcode = [](unsigned Opcode, SDValue V) {
2203 if (ConstantFPSDNode *C = isConstOrConstSplatFP(V)) {
2204 switch (Opcode) {
2205 case ISD::FADD: // X + -0.0 --> X
2206 return C->isZero() && C->isNegative();
2207 case ISD::FSUB: // X - 0.0 --> X
2208 return C->isZero() && !C->isNegative();
2209 case ISD::FMUL: // X * 1.0 --> X
2210 case ISD::FDIV: // X / 1.0 --> X
2211 return C->isExactlyValue(1.0);
2212 }
2213 }
2214 if (ConstantSDNode *C = isConstOrConstSplat(V)) {
2215 switch (Opcode) {
2216 case ISD::ADD: // X + 0 --> X
2217 case ISD::SUB: // X - 0 --> X
2218 case ISD::SHL: // X << 0 --> X
2219 case ISD::SRA: // X s>> 0 --> X
2220 case ISD::SRL: // X u>> 0 --> X
2221 return C->isZero();
2222 case ISD::MUL: // X * 1 --> X
2223 return C->isOne();
2224 }
2225 }
2226 return false;
2227 };
2228
2229 // This transform increases uses of N0, so freeze it to be safe.
2230 // binop N0, (vselect Cond, IDC, FVal) --> vselect Cond, N0, (binop N0, FVal)
2231 if (isIdentityConstantForOpcode(Opcode, TVal)) {
2232 SDValue F0 = DAG.getFreeze(N0);
2233 SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, FVal, N->getFlags());
2234 return DAG.getSelect(SDLoc(N), VT, Cond, F0, NewBO);
2235 }
2236 // binop N0, (vselect Cond, TVal, IDC) --> vselect Cond, (binop N0, TVal), N0
2237 if (isIdentityConstantForOpcode(Opcode, FVal)) {
2238 SDValue F0 = DAG.getFreeze(N0);
2239 SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, TVal, N->getFlags());
2240 return DAG.getSelect(SDLoc(N), VT, Cond, NewBO, F0);
2241 }
2242
2243 return SDValue();
2244 }
2245
/// Try to eliminate a binary operator with a select operand. Either fold an
/// identity-constant select arm (via foldSelectWithIdentityConstant), or push
/// the binop into both arms of a single-use select-of-constants:
///   binop (select Cond, CT, CF), CBO --> select Cond, (binop CT, CBO),
///                                                     (binop CF, CBO)
SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {
  assert(TLI.isBinOp(BO->getOpcode()) && BO->getNumValues() == 1 &&
         "Unexpected binary operator");

  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  auto BinOpcode = BO->getOpcode();
  EVT VT = BO->getValueType(0);
  if (TLI.shouldFoldSelectWithIdentityConstant(BinOpcode, VT)) {
    if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, false))
      return Sel;

    // For commutative ops, also try with the operands swapped.
    if (TLI.isCommutativeBinOp(BO->getOpcode()))
      if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, true))
        return Sel;
  }

  // Don't do this unless the old select is going away. We want to eliminate the
  // binary operator, not replace a binop with a select.
  // TODO: Handle ISD::SELECT_CC.
  unsigned SelOpNo = 0;
  SDValue Sel = BO->getOperand(0);
  if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse()) {
    SelOpNo = 1;
    Sel = BO->getOperand(1);
  }

  if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse())
    return SDValue();

  // Both select arms must be constants (int or FP) for the general fold.
  SDValue CT = Sel.getOperand(1);
  if (!isConstantOrConstantVector(CT, true) &&
      !DAG.isConstantFPBuildVectorOrConstantFP(CT))
    return SDValue();

  SDValue CF = Sel.getOperand(2);
  if (!isConstantOrConstantVector(CF, true) &&
      !DAG.isConstantFPBuildVectorOrConstantFP(CF))
    return SDValue();

  // Bail out if any constants are opaque because we can't constant fold those.
  // The exception is "and" and "or" with either 0 or -1 in which case we can
  // propagate non constant operands into select. I.e.:
  // and (select Cond, 0, -1), X --> select Cond, 0, X
  // or X, (select Cond, -1, 0) --> select Cond, -1, X
  bool CanFoldNonConst =
      (BinOpcode == ISD::AND || BinOpcode == ISD::OR) &&
      (isNullOrNullSplat(CT) || isAllOnesOrAllOnesSplat(CT)) &&
      (isNullOrNullSplat(CF) || isAllOnesOrAllOnesSplat(CF));

  // The non-select operand of the binop.
  SDValue CBO = BO->getOperand(SelOpNo ^ 1);
  if (!CanFoldNonConst &&
      !isConstantOrConstantVector(CBO, true) &&
      !DAG.isConstantFPBuildVectorOrConstantFP(CBO))
    return SDValue();

  // We have a select-of-constants followed by a binary operator with a
  // constant. Eliminate the binop by pulling the constant math into the select.
  // Example: add (select Cond, CT, CF), CBO --> select Cond, CT + CBO, CF + CBO
  // SelOpNo determines operand order so non-commutative ops stay correct.
  SDLoc DL(Sel);
  SDValue NewCT = SelOpNo ? DAG.getNode(BinOpcode, DL, VT, CBO, CT)
                          : DAG.getNode(BinOpcode, DL, VT, CT, CBO);
  if (!CanFoldNonConst && !NewCT.isUndef() &&
      !isConstantOrConstantVector(NewCT, true) &&
      !DAG.isConstantFPBuildVectorOrConstantFP(NewCT))
    return SDValue();

  SDValue NewCF = SelOpNo ? DAG.getNode(BinOpcode, DL, VT, CBO, CF)
                          : DAG.getNode(BinOpcode, DL, VT, CF, CBO);
  if (!CanFoldNonConst && !NewCF.isUndef() &&
      !isConstantOrConstantVector(NewCF, true) &&
      !DAG.isConstantFPBuildVectorOrConstantFP(NewCF))
    return SDValue();

  SDValue SelectOp = DAG.getSelect(DL, VT, Sel.getOperand(0), NewCT, NewCF);
  SelectOp->setFlags(BO->getFlags());
  return SelectOp;
}
2323
/// Fold an add/sub of a constant with a zext'd *inverted* low-bit test into
/// an add/sub of the low bit itself with an adjusted constant:
///   add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1))
///   sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1))
static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, SelectionDAG &DAG) {
  assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
         "Expecting add or sub");

  // Match a constant operand and a zext operand for the math instruction:
  // add Z, C
  // sub C, Z
  bool IsAdd = N->getOpcode() == ISD::ADD;
  SDValue C = IsAdd ? N->getOperand(1) : N->getOperand(0);
  SDValue Z = IsAdd ? N->getOperand(0) : N->getOperand(1);
  auto *CN = dyn_cast<ConstantSDNode>(C);
  if (!CN || Z.getOpcode() != ISD::ZERO_EXTEND)
    return SDValue();

  // Match the zext operand as a setcc of a boolean.
  if (Z.getOperand(0).getOpcode() != ISD::SETCC ||
      Z.getOperand(0).getValueType() != MVT::i1)
    return SDValue();

  // Match the compare as: setcc (X & 1), 0, eq.
  SDValue SetCC = Z.getOperand(0);
  ISD::CondCode CC = cast<CondCodeSDNode>(SetCC->getOperand(2))->get();
  if (CC != ISD::SETEQ || !isNullConstant(SetCC.getOperand(1)) ||
      SetCC.getOperand(0).getOpcode() != ISD::AND ||
      !isOneConstant(SetCC.getOperand(0).getOperand(1)))
    return SDValue();

  // We are adding/subtracting a constant and an inverted low bit. Turn that
  // into a subtract/add of the low bit with incremented/decremented constant:
  // add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1))
  // sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1))
  EVT VT = C.getValueType();
  SDLoc DL(N);
  // Reuse the (X & 1) value directly, extended/truncated to the result type.
  SDValue LowBit = DAG.getZExtOrTrunc(SetCC.getOperand(0), DL, VT);
  SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT) :
                       DAG.getConstant(CN->getAPIntValue() - 1, DL, VT);
  return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, C1, LowBit);
}
2362
/// Try to fold a 'not' shifted sign-bit with add/sub with constant operand
/// into a shift and add with a different constant:
///   add (srl (not X), 31), C --> add (sra X, 31), (C + 1)
///   sub C, (srl (not X), 31) --> add (srl X, 31), (C - 1)
static SDValue foldAddSubOfSignBit(SDNode *N, SelectionDAG &DAG) {
  assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
         "Expecting add or sub");

  // We need a constant operand for the add/sub, and the other operand is a
  // logical shift right: add (srl), C or sub C, (srl).
  bool IsAdd = N->getOpcode() == ISD::ADD;
  SDValue ConstantOp = IsAdd ? N->getOperand(1) : N->getOperand(0);
  SDValue ShiftOp = IsAdd ? N->getOperand(0) : N->getOperand(1);
  if (!DAG.isConstantIntBuildVectorOrConstantInt(ConstantOp) ||
      ShiftOp.getOpcode() != ISD::SRL)
    return SDValue();

  // The shift must be of a 'not' value with no other uses (the 'not' is
  // eliminated by this transform).
  SDValue Not = ShiftOp.getOperand(0);
  if (!Not.hasOneUse() || !isBitwiseNot(Not))
    return SDValue();

  // The shift must be moving the sign bit to the least-significant-bit.
  EVT VT = ShiftOp.getValueType();
  SDValue ShAmt = ShiftOp.getOperand(1);
  ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt);
  if (!ShAmtC || ShAmtC->getAPIntValue() != (VT.getScalarSizeInBits() - 1))
    return SDValue();

  // Eliminate the 'not' by adjusting the shift and add/sub constant:
  // add (srl (not X), 31), C --> add (sra X, 31), (C + 1)
  // sub C, (srl (not X), 31) --> add (srl X, 31), (C - 1)
  SDLoc DL(N);
  if (SDValue NewC = DAG.FoldConstantArithmetic(
          IsAdd ? ISD::ADD : ISD::SUB, DL, VT,
          {ConstantOp, DAG.getConstant(1, DL, VT)})) {
    SDValue NewShift = DAG.getNode(IsAdd ? ISD::SRA : ISD::SRL, DL, VT,
                                   Not.getOperand(0), ShAmt);
    return DAG.getNode(ISD::ADD, DL, VT, NewShift, NewC);
  }

  return SDValue();
}
2404
isADDLike(SDValue V,const SelectionDAG & DAG)2405 static bool isADDLike(SDValue V, const SelectionDAG &DAG) {
2406 unsigned Opcode = V.getOpcode();
2407 if (Opcode == ISD::OR)
2408 return DAG.haveNoCommonBitsSet(V.getOperand(0), V.getOperand(1));
2409 if (Opcode == ISD::XOR)
2410 return isMinSignedConstant(V.getOperand(1));
2411 return false;
2412 }
2413
2414 /// Try to fold a node that behaves like an ADD (note that N isn't necessarily
2415 /// an ISD::ADD here, it could for example be an ISD::OR if we know that there
2416 /// are no common bits set in the operands).
SDValue DAGCombiner::visitADDLike(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  SDLoc DL(N);

  // fold (add x, undef) -> undef
  if (N0.isUndef())
    return N0;
  if (N1.isUndef())
    return N1;

  // fold (add c1, c2) -> c1+c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS
  // (later folds below assume any constant operand is N1)
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(ISD::ADD, DL, VT, N1, N0);

  // fold vector ops
  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

    // fold (add x, 0) -> x, vector edition
    if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
      return N0;
  }

  // fold (add x, 0) -> x
  if (isNullConstant(N1))
    return N0;

  if (N0.getOpcode() == ISD::SUB) {
    SDValue N00 = N0.getOperand(0);
    SDValue N01 = N0.getOperand(1);

    // fold ((A-c1)+c2) -> (A+(c2-c1))
    // FoldConstantArithmetic only succeeds when both N1 and N01 are constant.
    if (SDValue Sub = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N1, N01}))
      return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Sub);

    // fold ((c1-A)+c2) -> (c1+c2)-A
    if (SDValue Add = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N1, N00}))
      return DAG.getNode(ISD::SUB, DL, VT, Add, N0.getOperand(1));
  }

  // add (sext i1 X), 1 -> zext (not i1 X)
  // We don't transform this pattern:
  //   add (zext i1 X), -1 -> sext (not i1 X)
  // because most (?) targets generate better code for the zext form.
  if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
      isOneOrOneSplat(N1)) {
    SDValue X = N0.getOperand(0);
    if ((!LegalOperations ||
         (TLI.isOperationLegal(ISD::XOR, X.getValueType()) &&
          TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) &&
        X.getScalarValueSizeInBits() == 1) {
      SDValue Not = DAG.getNOT(DL, X, X.getValueType());
      return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Not);
    }
  }

  // Fold (add (or x, c0), c1) -> (add x, (c0 + c1))
  // iff (or x, c0) is equivalent to (add x, c0).
  // Fold (add (xor x, c0), c1) -> (add x, (c0 + c1))
  // iff (xor x, c0) is equivalent to (add x, c0).
  if (isADDLike(N0, DAG)) {
    SDValue N01 = N0.getOperand(1);
    if (SDValue Add = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N1, N01}))
      return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Add);
  }

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // reassociate add
  // (skipped when reassociation would break a (base + offset) pattern that
  // could otherwise fold into an addressing mode)
  if (!reassociationCanBreakAddressingModePattern(ISD::ADD, DL, N, N0, N1)) {
    if (SDValue RADD = reassociateOps(ISD::ADD, DL, N0, N1, N->getFlags()))
      return RADD;

    // Reassociate (add (or x, c), y) -> (add add(x, y), c)) if (or x, c) is
    // equivalent to (add x, c).
    // Reassociate (add (xor x, c), y) -> (add add(x, y), c)) if (xor x, c) is
    // equivalent to (add x, c).
    auto ReassociateAddOr = [&](SDValue N0, SDValue N1) {
      if (isADDLike(N0, DAG) && N0.hasOneUse() &&
          isConstantOrConstantVector(N0.getOperand(1), /* NoOpaque */ true)) {
        return DAG.getNode(ISD::ADD, DL, VT,
                           DAG.getNode(ISD::ADD, DL, VT, N1, N0.getOperand(0)),
                           N0.getOperand(1));
      }
      return SDValue();
    };
    // Try both operand orders; the helper only matches the add-like node in
    // its first argument.
    if (SDValue Add = ReassociateAddOr(N0, N1))
      return Add;
    if (SDValue Add = ReassociateAddOr(N1, N0))
      return Add;
  }
  // fold ((0-A) + B) -> B-A
  if (N0.getOpcode() == ISD::SUB && isNullOrNullSplat(N0.getOperand(0)))
    return DAG.getNode(ISD::SUB, DL, VT, N1, N0.getOperand(1));

  // fold (A + (0-B)) -> A-B
  if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0)))
    return DAG.getNode(ISD::SUB, DL, VT, N0, N1.getOperand(1));

  // fold (A+(B-A)) -> B
  if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(1))
    return N1.getOperand(0);

  // fold ((B-A)+A) -> B
  if (N0.getOpcode() == ISD::SUB && N1 == N0.getOperand(1))
    return N0.getOperand(0);

  // fold ((A-B)+(C-A)) -> (C-B)
  if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
      N0.getOperand(0) == N1.getOperand(1))
    return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
                       N0.getOperand(1));

  // fold ((A-B)+(B-C)) -> (A-C)
  if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
      N0.getOperand(1) == N1.getOperand(0))
    return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0),
                       N1.getOperand(1));

  // fold (A+(B-(A+C))) to (B-C)
  if (N1.getOpcode() == ISD::SUB && N1.getOperand(1).getOpcode() == ISD::ADD &&
      N0 == N1.getOperand(1).getOperand(0))
    return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
                       N1.getOperand(1).getOperand(1));

  // fold (A+(B-(C+A))) to (B-C)
  if (N1.getOpcode() == ISD::SUB && N1.getOperand(1).getOpcode() == ISD::ADD &&
      N0 == N1.getOperand(1).getOperand(1))
    return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
                       N1.getOperand(1).getOperand(0));

  // fold (A+((B-A)+or-C)) to (B+or-C)
  if ((N1.getOpcode() == ISD::SUB || N1.getOpcode() == ISD::ADD) &&
      N1.getOperand(0).getOpcode() == ISD::SUB &&
      N0 == N1.getOperand(0).getOperand(1))
    return DAG.getNode(N1.getOpcode(), DL, VT, N1.getOperand(0).getOperand(0),
                       N1.getOperand(1));

  // fold (A-B)+(C-D) to (A+C)-(B+D) when A or C is constant
  if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
      N0->hasOneUse() && N1->hasOneUse()) {
    SDValue N00 = N0.getOperand(0);
    SDValue N01 = N0.getOperand(1);
    SDValue N10 = N1.getOperand(0);
    SDValue N11 = N1.getOperand(1);

    if (isConstantOrConstantVector(N00) || isConstantOrConstantVector(N10))
      return DAG.getNode(ISD::SUB, DL, VT,
                         DAG.getNode(ISD::ADD, SDLoc(N0), VT, N00, N10),
                         DAG.getNode(ISD::ADD, SDLoc(N1), VT, N01, N11));
  }

  // fold (add (umax X, C), -C) --> (usubsat X, C)
  if (N0.getOpcode() == ISD::UMAX && hasOperation(ISD::USUBSAT, VT)) {
    // Match element-wise: the umax constant must be the negation of the add
    // constant (undef lanes are permitted in either operand).
    auto MatchUSUBSAT = [](ConstantSDNode *Max, ConstantSDNode *Op) {
      return (!Max && !Op) ||
             (Max && Op && Max->getAPIntValue() == (-Op->getAPIntValue()));
    };
    if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchUSUBSAT,
                                  /*AllowUndefs*/ true))
      return DAG.getNode(ISD::USUBSAT, DL, VT, N0.getOperand(0),
                         N0.getOperand(1));
  }

  // If any demanded bits analysis simplified an operand, N was updated in
  // place; return it to signal a change.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  if (isOneOrOneSplat(N1)) {
    // fold (add (xor a, -1), 1) -> (sub 0, a)
    // (~a + 1 == -a by two's complement)
    if (isBitwiseNot(N0))
      return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
                         N0.getOperand(0));

    // fold (add (add (xor a, -1), b), 1) -> (sub b, a)
    if (N0.getOpcode() == ISD::ADD) {
      SDValue A, Xor;

      // The 'not' may be either operand of the inner add.
      if (isBitwiseNot(N0.getOperand(0))) {
        A = N0.getOperand(1);
        Xor = N0.getOperand(0);
      } else if (isBitwiseNot(N0.getOperand(1))) {
        A = N0.getOperand(0);
        Xor = N0.getOperand(1);
      }

      // Xor is only non-null if one of the branches above matched.
      if (Xor)
        return DAG.getNode(ISD::SUB, DL, VT, A, Xor.getOperand(0));
    }

    // Look for:
    //   add (add x, y), 1
    // And if the target does not like this form then turn into:
    //   sub y, (xor x, -1)
    if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD &&
        N0.hasOneUse()) {
      SDValue Not = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(0),
                                DAG.getAllOnesConstant(DL, VT));
      return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(1), Not);
    }
  }

  // (x - y) + -1  ->  add (xor y, -1), x
  if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
      isAllOnesOrAllOnesSplat(N1)) {
    SDValue Xor = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(1), N1);
    return DAG.getNode(ISD::ADD, DL, VT, Xor, N0.getOperand(0));
  }

  // Try the commutative helper with operands in both orders.
  if (SDValue Combined = visitADDLikeCommutative(N0, N1, N))
    return Combined;

  if (SDValue Combined = visitADDLikeCommutative(N1, N0, N))
    return Combined;

  return SDValue();
}
2642
SDValue DAGCombiner::visitADD(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  SDLoc DL(N);

  // First try the folds shared with other add-like opcodes.
  if (SDValue Combined = visitADDLike(N))
    return Combined;

  if (SDValue V = foldAddSubBoolOfMaskedVal(N, DAG))
    return V;

  if (SDValue V = foldAddSubOfSignBit(N, DAG))
    return V;

  // fold (a+b) -> (a|b) iff a and b share no bits.
  if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
      DAG.haveNoCommonBitsSet(N0, N1))
    return DAG.getNode(ISD::OR, DL, VT, N0, N1);

  // Fold (add (vscale * C0), (vscale * C1)) to (vscale * (C0 + C1)).
  if (N0.getOpcode() == ISD::VSCALE && N1.getOpcode() == ISD::VSCALE) {
    const APInt &C0 = N0->getConstantOperandAPInt(0);
    const APInt &C1 = N1->getConstantOperandAPInt(0);
    return DAG.getVScale(DL, VT, C0 + C1);
  }

  // fold a+vscale(c1)+vscale(c2) -> a+vscale(c1+c2)
  // (relies on the constant-to-RHS canonicalization done in visitADDLike)
  if ((N0.getOpcode() == ISD::ADD) &&
      (N0.getOperand(1).getOpcode() == ISD::VSCALE) &&
      (N1.getOpcode() == ISD::VSCALE)) {
    const APInt &VS0 = N0.getOperand(1)->getConstantOperandAPInt(0);
    const APInt &VS1 = N1->getConstantOperandAPInt(0);
    SDValue VS = DAG.getVScale(DL, VT, VS0 + VS1);
    return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), VS);
  }

  // Fold (add step_vector(c1), step_vector(c2) to step_vector(c1+c2))
  if (N0.getOpcode() == ISD::STEP_VECTOR &&
      N1.getOpcode() == ISD::STEP_VECTOR) {
    const APInt &C0 = N0->getConstantOperandAPInt(0);
    const APInt &C1 = N1->getConstantOperandAPInt(0);
    APInt NewStep = C0 + C1;
    return DAG.getStepVector(DL, VT, NewStep);
  }

  // Fold a + step_vector(c1) + step_vector(c2) to a + step_vector(c1+c2)
  if ((N0.getOpcode() == ISD::ADD) &&
      (N0.getOperand(1).getOpcode() == ISD::STEP_VECTOR) &&
      (N1.getOpcode() == ISD::STEP_VECTOR)) {
    const APInt &SV0 = N0.getOperand(1)->getConstantOperandAPInt(0);
    const APInt &SV1 = N1->getConstantOperandAPInt(0);
    APInt NewStep = SV0 + SV1;
    SDValue SV = DAG.getStepVector(DL, VT, NewStep);
    return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), SV);
  }

  return SDValue();
}
2702
SDValue DAGCombiner::visitADDSAT(SDNode *N) {
  // Handles both SADDSAT and UADDSAT; Opcode preserves which one.
  unsigned Opcode = N->getOpcode();
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  SDLoc DL(N);

  // fold (add_sat x, undef) -> -1
  // An undef operand lets us pick any result; all-ones is chosen here.
  if (N0.isUndef() || N1.isUndef())
    return DAG.getAllOnesConstant(DL, VT);

  // fold (add_sat c1, c2) -> c3
  if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(Opcode, DL, VT, N1, N0);

  // fold vector ops
  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

    // fold (add_sat x, 0) -> x, vector edition
    if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
      return N0;
  }

  // fold (add_sat x, 0) -> x
  if (isNullConstant(N1))
    return N0;

  // If it cannot overflow, transform into an add.
  // (only proven here for the unsigned flavor)
  if (Opcode == ISD::UADDSAT)
    if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never)
      return DAG.getNode(ISD::ADD, DL, VT, N0, N1);

  return SDValue();
}
2744
getAsCarry(const TargetLowering & TLI,SDValue V)2745 static SDValue getAsCarry(const TargetLowering &TLI, SDValue V) {
2746 bool Masked = false;
2747
2748 // First, peel away TRUNCATE/ZERO_EXTEND/AND nodes due to legalization.
2749 while (true) {
2750 if (V.getOpcode() == ISD::TRUNCATE || V.getOpcode() == ISD::ZERO_EXTEND) {
2751 V = V.getOperand(0);
2752 continue;
2753 }
2754
2755 if (V.getOpcode() == ISD::AND && isOneConstant(V.getOperand(1))) {
2756 Masked = true;
2757 V = V.getOperand(0);
2758 continue;
2759 }
2760
2761 break;
2762 }
2763
2764 // If this is not a carry, return.
2765 if (V.getResNo() != 1)
2766 return SDValue();
2767
2768 if (V.getOpcode() != ISD::ADDCARRY && V.getOpcode() != ISD::SUBCARRY &&
2769 V.getOpcode() != ISD::UADDO && V.getOpcode() != ISD::USUBO)
2770 return SDValue();
2771
2772 EVT VT = V->getValueType(0);
2773 if (!TLI.isOperationLegalOrCustom(V.getOpcode(), VT))
2774 return SDValue();
2775
2776 // If the result is masked, then no matter what kind of bool it is we can
2777 // return. If it isn't, then we need to make sure the bool type is either 0 or
2778 // 1 and not other values.
2779 if (Masked ||
2780 TLI.getBooleanContents(V.getValueType()) ==
2781 TargetLoweringBase::ZeroOrOneBooleanContent)
2782 return V;
2783
2784 return SDValue();
2785 }
2786
2787 /// Given the operands of an add/sub operation, see if the 2nd operand is a
2788 /// masked 0/1 whose source operand is actually known to be 0/-1. If so, invert
2789 /// the opcode and bypass the mask operation.
foldAddSubMasked1(bool IsAdd,SDValue N0,SDValue N1,SelectionDAG & DAG,const SDLoc & DL)2790 static SDValue foldAddSubMasked1(bool IsAdd, SDValue N0, SDValue N1,
2791 SelectionDAG &DAG, const SDLoc &DL) {
2792 if (N1.getOpcode() != ISD::AND || !isOneOrOneSplat(N1->getOperand(1)))
2793 return SDValue();
2794
2795 EVT VT = N0.getValueType();
2796 if (DAG.ComputeNumSignBits(N1.getOperand(0)) != VT.getScalarSizeInBits())
2797 return SDValue();
2798
2799 // add N0, (and (AssertSext X, i1), 1) --> sub N0, X
2800 // sub N0, (and (AssertSext X, i1), 1) --> add N0, X
2801 return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, N0, N1.getOperand(0));
2802 }
2803
2804 /// Helper for doing combines based on N0 and N1 being added to each other.
// Called twice by visitADDLike with the operands swapped, so each pattern
// here only needs to match one operand order.
SDValue DAGCombiner::visitADDLikeCommutative(SDValue N0, SDValue N1,
                                             SDNode *LocReference) {
  EVT VT = N0.getValueType();
  SDLoc DL(LocReference);

  // fold (add x, shl(0 - y, n)) -> sub(x, shl(y, n))
  if (N1.getOpcode() == ISD::SHL && N1.getOperand(0).getOpcode() == ISD::SUB &&
      isNullOrNullSplat(N1.getOperand(0).getOperand(0)))
    return DAG.getNode(ISD::SUB, DL, VT, N0,
                       DAG.getNode(ISD::SHL, DL, VT,
                                   N1.getOperand(0).getOperand(1),
                                   N1.getOperand(1)));

  // add N0, (and X, 1) -> sub N0, X when X is known all-sign-bits.
  if (SDValue V = foldAddSubMasked1(true, N0, N1, DAG, DL))
    return V;

  // Look for:
  //   add (add x, 1), y
  // And if the target does not like this form then turn into:
  //   sub y, (xor x, -1)
  if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD &&
      N0.hasOneUse() && isOneOrOneSplat(N0.getOperand(1))) {
    SDValue Not = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(0),
                              DAG.getAllOnesConstant(DL, VT));
    return DAG.getNode(ISD::SUB, DL, VT, N1, Not);
  }

  if (N0.getOpcode() == ISD::SUB && N0.hasOneUse()) {
    // Hoist one-use subtraction by non-opaque constant:
    //   (x - C) + y  ->  (x + y) - C
    // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
    if (isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
      SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), N1);
      return DAG.getNode(ISD::SUB, DL, VT, Add, N0.getOperand(1));
    }
    // Hoist one-use subtraction from non-opaque constant:
    //   (C - x) + y  ->  (y - x) + C
    if (isConstantOrConstantVector(N0.getOperand(0), /*NoOpaques=*/true)) {
      SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N1, N0.getOperand(1));
      return DAG.getNode(ISD::ADD, DL, VT, Sub, N0.getOperand(0));
    }
  }

  // If the target's bool is represented as 0/1, prefer to make this 'sub 0/1'
  // rather than 'add 0/-1' (the zext should get folded).
  // add (sext i1 Y), X --> sub X, (zext i1 Y)
  if (N0.getOpcode() == ISD::SIGN_EXTEND &&
      N0.getOperand(0).getScalarValueSizeInBits() == 1 &&
      TLI.getBooleanContents(VT) == TargetLowering::ZeroOrOneBooleanContent) {
    SDValue ZExt = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0));
    return DAG.getNode(ISD::SUB, DL, VT, N1, ZExt);
  }

  // add X, (sextinreg Y i1) -> sub X, (and Y 1)
  // (sext_inreg from i1 yields 0/-1, so adding it equals subtracting 0/1)
  if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
    VTSDNode *TN = cast<VTSDNode>(N1.getOperand(1));
    if (TN->getVT() == MVT::i1) {
      SDValue ZExt = DAG.getNode(ISD::AND, DL, VT, N1.getOperand(0),
                                 DAG.getConstant(1, DL, VT));
      return DAG.getNode(ISD::SUB, DL, VT, N0, ZExt);
    }
  }

  // (add X, (addcarry Y, 0, Carry)) -> (addcarry X, Y, Carry)
  if (N1.getOpcode() == ISD::ADDCARRY && isNullConstant(N1.getOperand(1)) &&
      N1.getResNo() == 0)
    return DAG.getNode(ISD::ADDCARRY, DL, N1->getVTList(),
                       N0, N1.getOperand(0), N1.getOperand(2));

  // (add X, Carry) -> (addcarry X, 0, Carry)
  if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT))
    if (SDValue Carry = getAsCarry(TLI, N1))
      return DAG.getNode(ISD::ADDCARRY, DL,
                         DAG.getVTList(VT, Carry.getValueType()), N0,
                         DAG.getConstant(0, DL, VT), Carry);

  return SDValue();
}
2883
visitADDC(SDNode * N)2884 SDValue DAGCombiner::visitADDC(SDNode *N) {
2885 SDValue N0 = N->getOperand(0);
2886 SDValue N1 = N->getOperand(1);
2887 EVT VT = N0.getValueType();
2888 SDLoc DL(N);
2889
2890 // If the flag result is dead, turn this into an ADD.
2891 if (!N->hasAnyUseOfValue(1))
2892 return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
2893 DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
2894
2895 // canonicalize constant to RHS.
2896 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
2897 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
2898 if (N0C && !N1C)
2899 return DAG.getNode(ISD::ADDC, DL, N->getVTList(), N1, N0);
2900
2901 // fold (addc x, 0) -> x + no carry out
2902 if (isNullConstant(N1))
2903 return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE,
2904 DL, MVT::Glue));
2905
2906 // If it cannot overflow, transform into an add.
2907 if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never)
2908 return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
2909 DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
2910
2911 return SDValue();
2912 }
2913
2914 /**
2915 * Flips a boolean if it is cheaper to compute. If the Force parameters is set,
2916 * then the flip also occurs if computing the inverse is the same cost.
2917 * This function returns an empty SDValue in case it cannot flip the boolean
2918 * without increasing the cost of the computation. If you want to flip a boolean
2919 * no matter what, use DAG.getLogicalNOT.
2920 */
extractBooleanFlip(SDValue V,SelectionDAG & DAG,const TargetLowering & TLI,bool Force)2921 static SDValue extractBooleanFlip(SDValue V, SelectionDAG &DAG,
2922 const TargetLowering &TLI,
2923 bool Force) {
2924 if (Force && isa<ConstantSDNode>(V))
2925 return DAG.getLogicalNOT(SDLoc(V), V, V.getValueType());
2926
2927 if (V.getOpcode() != ISD::XOR)
2928 return SDValue();
2929
2930 ConstantSDNode *Const = isConstOrConstSplat(V.getOperand(1), false);
2931 if (!Const)
2932 return SDValue();
2933
2934 EVT VT = V.getValueType();
2935
2936 bool IsFlip = false;
2937 switch(TLI.getBooleanContents(VT)) {
2938 case TargetLowering::ZeroOrOneBooleanContent:
2939 IsFlip = Const->isOne();
2940 break;
2941 case TargetLowering::ZeroOrNegativeOneBooleanContent:
2942 IsFlip = Const->isAllOnes();
2943 break;
2944 case TargetLowering::UndefinedBooleanContent:
2945 IsFlip = (Const->getAPIntValue() & 0x01) == 1;
2946 break;
2947 }
2948
2949 if (IsFlip)
2950 return V.getOperand(0);
2951 if (Force)
2952 return DAG.getLogicalNOT(SDLoc(V), V, V.getValueType());
2953 return SDValue();
2954 }
2955
// Shared visitor for SADDO and UADDO (add with overflow flag as result #1).
SDValue DAGCombiner::visitADDO(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  bool IsSigned = (ISD::SADDO == N->getOpcode());

  EVT CarryVT = N->getValueType(1);
  SDLoc DL(N);

  // If the flag result is dead, turn this into an ADD.
  if (!N->hasAnyUseOfValue(1))
    return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
                     DAG.getUNDEF(CarryVT));

  // canonicalize constant to RHS.
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(N->getOpcode(), DL, N->getVTList(), N1, N0);

  // fold (addo x, 0) -> x + no carry out
  if (isNullOrNullSplat(N1))
    return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT));

  // The remaining folds are only valid for the unsigned flavor.
  if (!IsSigned) {
    // If it cannot overflow, transform into an add.
    if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never)
      return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
                       DAG.getConstant(0, DL, CarryVT));

    // fold (uaddo (xor a, -1), 1) -> (usub 0, a) and flip carry.
    // ~a + 1 == 0 - a; the uaddo carries iff a == 0 while the usubo borrows
    // iff a != 0, so the carry output must be logically inverted.
    if (isBitwiseNot(N0) && isOneOrOneSplat(N1)) {
      SDValue Sub = DAG.getNode(ISD::USUBO, DL, N->getVTList(),
                                DAG.getConstant(0, DL, VT), N0.getOperand(0));
      return CombineTo(
          N, Sub, DAG.getLogicalNOT(DL, Sub.getValue(1), Sub->getValueType(1)));
    }

    // Try the carry-combining patterns with operands in both orders.
    if (SDValue Combined = visitUADDOLike(N0, N1, N))
      return Combined;

    if (SDValue Combined = visitUADDOLike(N1, N0, N))
      return Combined;
  }

  return SDValue();
}
3002
visitUADDOLike(SDValue N0,SDValue N1,SDNode * N)3003 SDValue DAGCombiner::visitUADDOLike(SDValue N0, SDValue N1, SDNode *N) {
3004 EVT VT = N0.getValueType();
3005 if (VT.isVector())
3006 return SDValue();
3007
3008 // (uaddo X, (addcarry Y, 0, Carry)) -> (addcarry X, Y, Carry)
3009 // If Y + 1 cannot overflow.
3010 if (N1.getOpcode() == ISD::ADDCARRY && isNullConstant(N1.getOperand(1))) {
3011 SDValue Y = N1.getOperand(0);
3012 SDValue One = DAG.getConstant(1, SDLoc(N), Y.getValueType());
3013 if (DAG.computeOverflowKind(Y, One) == SelectionDAG::OFK_Never)
3014 return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(), N0, Y,
3015 N1.getOperand(2));
3016 }
3017
3018 // (uaddo X, Carry) -> (addcarry X, 0, Carry)
3019 if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT))
3020 if (SDValue Carry = getAsCarry(TLI, N1))
3021 return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(), N0,
3022 DAG.getConstant(0, SDLoc(N), VT), Carry);
3023
3024 return SDValue();
3025 }
3026
visitADDE(SDNode * N)3027 SDValue DAGCombiner::visitADDE(SDNode *N) {
3028 SDValue N0 = N->getOperand(0);
3029 SDValue N1 = N->getOperand(1);
3030 SDValue CarryIn = N->getOperand(2);
3031
3032 // canonicalize constant to RHS
3033 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
3034 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
3035 if (N0C && !N1C)
3036 return DAG.getNode(ISD::ADDE, SDLoc(N), N->getVTList(),
3037 N1, N0, CarryIn);
3038
3039 // fold (adde x, y, false) -> (addc x, y)
3040 if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
3041 return DAG.getNode(ISD::ADDC, SDLoc(N), N->getVTList(), N0, N1);
3042
3043 return SDValue();
3044 }
3045
visitADDCARRY(SDNode * N)3046 SDValue DAGCombiner::visitADDCARRY(SDNode *N) {
3047 SDValue N0 = N->getOperand(0);
3048 SDValue N1 = N->getOperand(1);
3049 SDValue CarryIn = N->getOperand(2);
3050 SDLoc DL(N);
3051
3052 // canonicalize constant to RHS
3053 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
3054 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
3055 if (N0C && !N1C)
3056 return DAG.getNode(ISD::ADDCARRY, DL, N->getVTList(), N1, N0, CarryIn);
3057
3058 // fold (addcarry x, y, false) -> (uaddo x, y)
3059 if (isNullConstant(CarryIn)) {
3060 if (!LegalOperations ||
3061 TLI.isOperationLegalOrCustom(ISD::UADDO, N->getValueType(0)))
3062 return DAG.getNode(ISD::UADDO, DL, N->getVTList(), N0, N1);
3063 }
3064
3065 // fold (addcarry 0, 0, X) -> (and (ext/trunc X), 1) and no carry.
3066 if (isNullConstant(N0) && isNullConstant(N1)) {
3067 EVT VT = N0.getValueType();
3068 EVT CarryVT = CarryIn.getValueType();
3069 SDValue CarryExt = DAG.getBoolExtOrTrunc(CarryIn, DL, VT, CarryVT);
3070 AddToWorklist(CarryExt.getNode());
3071 return CombineTo(N, DAG.getNode(ISD::AND, DL, VT, CarryExt,
3072 DAG.getConstant(1, DL, VT)),
3073 DAG.getConstant(0, DL, CarryVT));
3074 }
3075
3076 if (SDValue Combined = visitADDCARRYLike(N0, N1, CarryIn, N))
3077 return Combined;
3078
3079 if (SDValue Combined = visitADDCARRYLike(N1, N0, CarryIn, N))
3080 return Combined;
3081
3082 return SDValue();
3083 }
3084
visitSADDO_CARRY(SDNode * N)3085 SDValue DAGCombiner::visitSADDO_CARRY(SDNode *N) {
3086 SDValue N0 = N->getOperand(0);
3087 SDValue N1 = N->getOperand(1);
3088 SDValue CarryIn = N->getOperand(2);
3089 SDLoc DL(N);
3090
3091 // canonicalize constant to RHS
3092 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
3093 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
3094 if (N0C && !N1C)
3095 return DAG.getNode(ISD::SADDO_CARRY, DL, N->getVTList(), N1, N0, CarryIn);
3096
3097 // fold (saddo_carry x, y, false) -> (saddo x, y)
3098 if (isNullConstant(CarryIn)) {
3099 if (!LegalOperations ||
3100 TLI.isOperationLegalOrCustom(ISD::SADDO, N->getValueType(0)))
3101 return DAG.getNode(ISD::SADDO, DL, N->getVTList(), N0, N1);
3102 }
3103
3104 return SDValue();
3105 }
3106
3107 /**
 * If we are facing some sort of diamond carry propagation pattern try to
3109 * break it up to generate something like:
3110 * (addcarry X, 0, (addcarry A, B, Z):Carry)
3111 *
 * The end result is usually an increase in operations required, but because
 * the carry is now linearized, other transforms can kick in and optimize the
 * DAG.
3114 *
3115 * Patterns typically look something like
3116 * (uaddo A, B)
3117 * / \
3118 * Carry Sum
3119 * | \
3120 * | (addcarry *, 0, Z)
3121 * | /
3122 * \ Carry
3123 * | /
3124 * (addcarry X, *, *)
3125 *
 * But numerous variations exist. Our goal is to identify A, B, X and Z and
3127 * produce a combine with a single path for carry propagation.
3128 */
static SDValue combineADDCARRYDiamond(DAGCombiner &Combiner, SelectionDAG &DAG,
                                      SDValue X, SDValue Carry0, SDValue Carry1,
                                      SDNode *N) {
  // Both inputs must be the carry result (result #1) of their producers, and
  // the outer producer must be a plain unsigned overflowing add.
  if (Carry1.getResNo() != 1 || Carry0.getResNo() != 1)
    return SDValue();
  if (Carry1.getOpcode() != ISD::UADDO)
    return SDValue();

  SDValue Z;

  /**
   * First look for a suitable Z. It will present itself in the form of
   * (addcarry Y, 0, Z) or its equivalent (uaddo Y, 1) for Z=true
   */
  if (Carry0.getOpcode() == ISD::ADDCARRY &&
      isNullConstant(Carry0.getOperand(1))) {
    Z = Carry0.getOperand(2);
  } else if (Carry0.getOpcode() == ISD::UADDO &&
             isOneConstant(Carry0.getOperand(1))) {
    EVT VT = Combiner.getSetCCResultType(Carry0.getValueType());
    // (uaddo Y, 1) adds a constant-true carry, so materialize Z = 1.
    Z = DAG.getConstant(1, SDLoc(Carry0.getOperand(1)), VT);
  } else {
    // We couldn't find a suitable Z.
    return SDValue();
  }

  // Rebuild the diamond as a linear chain: (addcarry A, B, Z) whose carry
  // feeds (addcarry X, 0, ...), replacing node N.
  auto cancelDiamond = [&](SDValue A, SDValue B) {
    SDLoc DL(N);
    SDValue NewY = DAG.getNode(ISD::ADDCARRY, DL, Carry0->getVTList(), A, B, Z);
    Combiner.AddToWorklist(NewY.getNode());
    return DAG.getNode(ISD::ADDCARRY, DL, N->getVTList(), X,
                       DAG.getConstant(0, DL, X.getValueType()),
                       NewY.getValue(1));
  };

  /**
   * (uaddo A, B)
   *  |
   *  Sum
   *  |
   * (addcarry *, 0, Z)
   */
  if (Carry0.getOperand(0) == Carry1.getValue(0)) {
    return cancelDiamond(Carry1.getOperand(0), Carry1.getOperand(1));
  }

  /**
   * (addcarry A, 0, Z)
   *  |
   *  Sum
   *  |
   * (uaddo *, B)
   */
  if (Carry1.getOperand(0) == Carry0.getValue(0)) {
    return cancelDiamond(Carry0.getOperand(0), Carry1.getOperand(1));
  }

  // Same shape as above, but the inner sum feeds the uaddo's second operand.
  if (Carry1.getOperand(1) == Carry0.getValue(0)) {
    return cancelDiamond(Carry1.getOperand(0), Carry0.getOperand(0));
  }

  return SDValue();
}
3193
3194 // If we are facing some sort of diamond carry/borrow in/out pattern try to
3195 // match patterns like:
3196 //
3197 // (uaddo A, B) CarryIn
3198 // | \ |
3199 // | \ |
3200 // PartialSum PartialCarryOutX /
3201 // | | /
3202 // | ____|____________/
3203 // | / |
3204 // (uaddo *, *) \________
3205 // | \ \
3206 // | \ |
3207 // | PartialCarryOutY |
3208 // | \ |
3209 // | \ /
3210 // AddCarrySum | ______/
3211 // | /
3212 // CarryOut = (or *, *)
3213 //
3214 // And generate ADDCARRY (or SUBCARRY) with two result values:
3215 //
3216 // {AddCarrySum, CarryOut} = (addcarry A, B, CarryIn)
3217 //
3218 // Our goal is to identify A, B, and CarryIn and produce ADDCARRY/SUBCARRY with
3219 // a single path for carry/borrow out propagation:
static SDValue combineCarryDiamond(SelectionDAG &DAG, const TargetLowering &TLI,
                                   SDValue N0, SDValue N1, SDNode *N) {
  // Both operands of N (an OR/XOR/AND merging partial carries) must
  // themselves be recognizable carry/borrow bits.
  SDValue Carry0 = getAsCarry(TLI, N0);
  if (!Carry0)
    return SDValue();
  SDValue Carry1 = getAsCarry(TLI, N1);
  if (!Carry1)
    return SDValue();

  // Both carries must come from the same kind of unsigned overflow op.
  unsigned Opcode = Carry0.getOpcode();
  if (Opcode != Carry1.getOpcode())
    return SDValue();
  if (Opcode != ISD::UADDO && Opcode != ISD::USUBO)
    return SDValue();

  // Canonicalize the add/sub of A and B (the top node in the above ASCII art)
  // as Carry0 and the add/sub of the carry in as Carry1 (the middle node).
  if (Carry1.getNode()->isOperandOf(Carry0.getNode()))
    std::swap(Carry0, Carry1);

  // Check if nodes are connected in expected way.
  if (Carry1.getOperand(0) != Carry0.getValue(0) &&
      Carry1.getOperand(1) != Carry0.getValue(0))
    return SDValue();

  // The carry in value must be on the righthand side for subtraction.
  unsigned CarryInOperandNum =
      Carry1.getOperand(0) == Carry0.getValue(0) ? 1 : 0;
  if (Opcode == ISD::USUBO && CarryInOperandNum != 1)
    return SDValue();
  SDValue CarryIn = Carry1.getOperand(CarryInOperandNum);

  unsigned NewOp = Opcode == ISD::UADDO ? ISD::ADDCARRY : ISD::SUBCARRY;
  if (!TLI.isOperationLegalOrCustom(NewOp, Carry0.getValue(0).getValueType()))
    return SDValue();

  // Verify that the carry/borrow in is plausibly a carry/borrow bit.
  // TODO: make getAsCarry() aware of how partial carries are merged.
  if (CarryIn.getOpcode() != ISD::ZERO_EXTEND)
    return SDValue();
  CarryIn = CarryIn.getOperand(0);
  if (CarryIn.getValueType() != MVT::i1)
    return SDValue();

  // Build the single carry-chained replacement node.
  SDLoc DL(N);
  SDValue Merged =
      DAG.getNode(NewOp, DL, Carry1->getVTList(), Carry0.getOperand(0),
                  Carry0.getOperand(1), CarryIn);

  // Please note that because we have proven that the result of the UADDO/USUBO
  // of A and B feeds into the UADDO/USUBO that does the carry/borrow in, we can
  // therefore prove that if the first UADDO/USUBO overflows, the second
  // UADDO/USUBO cannot. For example consider 8-bit numbers where 0xFF is the
  // maximum value.
  //
  //   0xFF + 0xFF == 0xFE with carry but 0xFE + 1 does not carry
  //   0x00 - 0xFF == 1 with a carry/borrow but 1 - 1 == 0 (no carry/borrow)
  //
  // This is important because it means that OR and XOR can be used to merge
  // carry flags; and that AND can return a constant zero.
  //
  // TODO: match other operations that can merge flags (ADD, etc)
  DAG.ReplaceAllUsesOfValueWith(Carry1.getValue(0), Merged.getValue(0));
  if (N->getOpcode() == ISD::AND)
    return DAG.getConstant(0, DL, MVT::i1);
  return Merged.getValue(1);
}
3287
/// Shared combine logic for ADDCARRY-style nodes: an add producing a
/// (value, carry-out) pair that also consumes a carry-in.
/// \p N0 and \p N1 are the addends, \p CarryIn the incoming carry, \p N the
/// node being combined. Returns the replacement or an empty SDValue.
SDValue DAGCombiner::visitADDCARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
                                       SDNode *N) {
  // fold (addcarry (xor a, -1), b, c) -> (subcarry b, a, !c) and flip carry.
  if (isBitwiseNot(N0))
    if (SDValue NotC = extractBooleanFlip(CarryIn, DAG, TLI, true)) {
      SDLoc DL(N);
      SDValue Sub = DAG.getNode(ISD::SUBCARRY, DL, N->getVTList(), N1,
                                N0.getOperand(0), NotC);
      // The borrow produced by the SUBCARRY is the logical inverse of the
      // carry the original ADDCARRY would have produced, so flip it back.
      return CombineTo(
          N, Sub, DAG.getLogicalNOT(DL, Sub.getValue(1), Sub->getValueType(1)));
    }

  // Iff the flag result is dead:
  // (addcarry (add|uaddo X, Y), 0, Carry) -> (addcarry X, Y, Carry)
  // Don't do this if the Carry comes from the uaddo. It won't remove the uaddo
  // or the dependency between the instructions.
  if ((N0.getOpcode() == ISD::ADD ||
       (N0.getOpcode() == ISD::UADDO && N0.getResNo() == 0 &&
        N0.getValue(1) != CarryIn)) &&
      isNullConstant(N1) && !N->hasAnyUseOfValue(1))
    return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(),
                       N0.getOperand(0), N0.getOperand(1), CarryIn);

  /**
   * When one of the addcarry argument is itself a carry, we may be facing
   * a diamond carry propagation. In which case we try to transform the DAG
   * to ensure linear carry propagation if that is possible.
   */
  if (auto Y = getAsCarry(TLI, N1)) {
    // Because both are carries, Y and Z can be swapped.
    if (auto R = combineADDCARRYDiamond(*this, DAG, N0, Y, CarryIn, N))
      return R;
    if (auto R = combineADDCARRYDiamond(*this, DAG, N0, CarryIn, Y, N))
      return R;
  }

  return SDValue();
}
3326
3327 // Attempt to create a USUBSAT(LHS, RHS) node with DstVT, performing a
3328 // clamp/truncation if necessary.
getTruncatedUSUBSAT(EVT DstVT,EVT SrcVT,SDValue LHS,SDValue RHS,SelectionDAG & DAG,const SDLoc & DL)3329 static SDValue getTruncatedUSUBSAT(EVT DstVT, EVT SrcVT, SDValue LHS,
3330 SDValue RHS, SelectionDAG &DAG,
3331 const SDLoc &DL) {
3332 assert(DstVT.getScalarSizeInBits() <= SrcVT.getScalarSizeInBits() &&
3333 "Illegal truncation");
3334
3335 if (DstVT == SrcVT)
3336 return DAG.getNode(ISD::USUBSAT, DL, DstVT, LHS, RHS);
3337
3338 // If the LHS is zero-extended then we can perform the USUBSAT as DstVT by
3339 // clamping RHS.
3340 APInt UpperBits = APInt::getBitsSetFrom(SrcVT.getScalarSizeInBits(),
3341 DstVT.getScalarSizeInBits());
3342 if (!DAG.MaskedValueIsZero(LHS, UpperBits))
3343 return SDValue();
3344
3345 SDValue SatLimit =
3346 DAG.getConstant(APInt::getLowBitsSet(SrcVT.getScalarSizeInBits(),
3347 DstVT.getScalarSizeInBits()),
3348 DL, SrcVT);
3349 RHS = DAG.getNode(ISD::UMIN, DL, SrcVT, RHS, SatLimit);
3350 RHS = DAG.getNode(ISD::TRUNCATE, DL, DstVT, RHS);
3351 LHS = DAG.getNode(ISD::TRUNCATE, DL, DstVT, LHS);
3352 return DAG.getNode(ISD::USUBSAT, DL, DstVT, LHS, RHS);
3353 }
3354
// Try to find umax(a,b) - b or a - umin(a,b) patterns that may be converted to
// usubsat(a,b), optionally as a truncated type.
SDValue DAGCombiner::foldSubToUSubSat(EVT DstVT, SDNode *N) {
  // Only handle SUB, and only when USUBSAT is available for DstVT (or we are
  // pre-legalization and still free to introduce it).
  if (N->getOpcode() != ISD::SUB ||
      !(!LegalOperations || hasOperation(ISD::USUBSAT, DstVT)))
    return SDValue();

  EVT SubVT = N->getValueType(0);
  SDValue Op0 = N->getOperand(0);
  SDValue Op1 = N->getOperand(1);

  // Try to find umax(a,b) - b or a - umin(a,b) patterns
  // they may be converted to usubsat(a,b).
  if (Op0.getOpcode() == ISD::UMAX && Op0.hasOneUse()) {
    SDValue MaxLHS = Op0.getOperand(0);
    SDValue MaxRHS = Op0.getOperand(1);
    // umax(a,b) - b --> usubsat(a,b); both operand orders are checked.
    if (MaxLHS == Op1)
      return getTruncatedUSUBSAT(DstVT, SubVT, MaxRHS, Op1, DAG, SDLoc(N));
    if (MaxRHS == Op1)
      return getTruncatedUSUBSAT(DstVT, SubVT, MaxLHS, Op1, DAG, SDLoc(N));
  }

  if (Op1.getOpcode() == ISD::UMIN && Op1.hasOneUse()) {
    SDValue MinLHS = Op1.getOperand(0);
    SDValue MinRHS = Op1.getOperand(1);
    // a - umin(a,b) --> usubsat(a,b); both operand orders are checked.
    if (MinLHS == Op0)
      return getTruncatedUSUBSAT(DstVT, SubVT, Op0, MinRHS, DAG, SDLoc(N));
    if (MinRHS == Op0)
      return getTruncatedUSUBSAT(DstVT, SubVT, Op0, MinLHS, DAG, SDLoc(N));
  }

  // sub(a,trunc(umin(zext(a),b))) -> usubsat(a,trunc(umin(b,SatLimit)))
  // Here the USUBSAT is built at the wider (zext) type; getTruncatedUSUBSAT
  // narrows it back to DstVT.
  if (Op1.getOpcode() == ISD::TRUNCATE &&
      Op1.getOperand(0).getOpcode() == ISD::UMIN &&
      Op1.getOperand(0).hasOneUse()) {
    SDValue MinLHS = Op1.getOperand(0).getOperand(0);
    SDValue MinRHS = Op1.getOperand(0).getOperand(1);
    if (MinLHS.getOpcode() == ISD::ZERO_EXTEND && MinLHS.getOperand(0) == Op0)
      return getTruncatedUSUBSAT(DstVT, MinLHS.getValueType(), MinLHS, MinRHS,
                                 DAG, SDLoc(N));
    if (MinRHS.getOpcode() == ISD::ZERO_EXTEND && MinRHS.getOperand(0) == Op0)
      return getTruncatedUSUBSAT(DstVT, MinLHS.getValueType(), MinRHS, MinLHS,
                                 DAG, SDLoc(N));
  }

  return SDValue();
}
3402
3403 // Since it may not be valid to emit a fold to zero for vector initializers
3404 // check if we can before folding.
tryFoldToZero(const SDLoc & DL,const TargetLowering & TLI,EVT VT,SelectionDAG & DAG,bool LegalOperations)3405 static SDValue tryFoldToZero(const SDLoc &DL, const TargetLowering &TLI, EVT VT,
3406 SelectionDAG &DAG, bool LegalOperations) {
3407 if (!VT.isVector())
3408 return DAG.getConstant(0, DL, VT);
3409 if (!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT))
3410 return DAG.getConstant(0, DL, VT);
3411 return SDValue();
3412 }
3413
/// Combine an integer SUB node. Applies constant folding, algebraic
/// identities and canonicalizations in a fixed order; the first fold that
/// matches returns a replacement, otherwise an empty SDValue is returned.
/// NOTE(review): the ordering of these folds is deliberate — later folds
/// assume earlier ones did not fire.
SDValue DAGCombiner::visitSUB(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  SDLoc DL(N);

  // Look through a one-use FREEZE so (sub (freeze x), x) also folds to 0.
  auto PeekThroughFreeze = [](SDValue N) {
    if (N->getOpcode() == ISD::FREEZE && N.hasOneUse())
      return N->getOperand(0);
    return N;
  };

  // fold (sub x, x) -> 0
  // FIXME: Refactor this and xor and other similar operations together.
  if (PeekThroughFreeze(N0) == PeekThroughFreeze(N1))
    return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);

  // fold (sub c1, c2) -> c3
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0, N1}))
    return C;

  // fold vector ops
  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

    // fold (sub x, 0) -> x, vector edition
    if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
      return N0;
  }

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  ConstantSDNode *N1C = getAsNonOpaqueConstant(N1);

  // fold (sub x, c) -> (add x, -c)
  if (N1C) {
    return DAG.getNode(ISD::ADD, DL, VT, N0,
                       DAG.getConstant(-N1C->getAPIntValue(), DL, VT));
  }

  // Negation folds: everything in this scope handles (sub 0, X).
  if (isNullOrNullSplat(N0)) {
    unsigned BitWidth = VT.getScalarSizeInBits();
    // Right-shifting everything out but the sign bit followed by negation is
    // the same as flipping arithmetic/logical shift type without the negation:
    // -(X >>u 31) -> (X >>s 31)
    // -(X >>s 31) -> (X >>u 31)
    if (N1->getOpcode() == ISD::SRA || N1->getOpcode() == ISD::SRL) {
      ConstantSDNode *ShiftAmt = isConstOrConstSplat(N1.getOperand(1));
      if (ShiftAmt && ShiftAmt->getAPIntValue() == (BitWidth - 1)) {
        auto NewSh = N1->getOpcode() == ISD::SRA ? ISD::SRL : ISD::SRA;
        if (!LegalOperations || TLI.isOperationLegal(NewSh, VT))
          return DAG.getNode(NewSh, DL, VT, N1.getOperand(0), N1.getOperand(1));
      }
    }

    // 0 - X --> 0 if the sub is NUW.
    if (N->getFlags().hasNoUnsignedWrap())
      return N0;

    if (DAG.MaskedValueIsZero(N1, ~APInt::getSignMask(BitWidth))) {
      // N1 is either 0 or the minimum signed value. If the sub is NSW, then
      // N1 must be 0 because negating the minimum signed value is undefined.
      if (N->getFlags().hasNoSignedWrap())
        return N0;

      // 0 - X --> X if X is 0 or the minimum signed value.
      return N1;
    }

    // Convert 0 - abs(x).
    if (N1.getOpcode() == ISD::ABS && N1.hasOneUse() &&
        !TLI.isOperationLegalOrCustom(ISD::ABS, VT))
      if (SDValue Result = TLI.expandABS(N1.getNode(), DAG, true))
        return Result;

    // Fold neg(splat(neg(x)) -> splat(x)
    if (VT.isVector()) {
      SDValue N1S = DAG.getSplatValue(N1, true);
      if (N1S && N1S.getOpcode() == ISD::SUB &&
          isNullConstant(N1S.getOperand(0))) {
        if (VT.isScalableVector())
          return DAG.getSplatVector(VT, DL, N1S.getOperand(1));
        return DAG.getSplatBuildVector(VT, DL, N1S.getOperand(1));
      }
    }
  }

  // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1)
  if (isAllOnesOrAllOnesSplat(N0))
    return DAG.getNode(ISD::XOR, DL, VT, N1, N0);

  // fold (A - (0-B)) -> A+B
  if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0)))
    return DAG.getNode(ISD::ADD, DL, VT, N0, N1.getOperand(1));

  // fold A-(A-B) -> B
  if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(0))
    return N1.getOperand(1);

  // fold (A+B)-A -> B
  if (N0.getOpcode() == ISD::ADD && N0.getOperand(0) == N1)
    return N0.getOperand(1);

  // fold (A+B)-B -> A
  if (N0.getOpcode() == ISD::ADD && N0.getOperand(1) == N1)
    return N0.getOperand(0);

  // fold (A+C1)-C2 -> A+(C1-C2)
  if (N0.getOpcode() == ISD::ADD) {
    SDValue N01 = N0.getOperand(1);
    if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N01, N1}))
      return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), NewC);
  }

  // fold C2-(A+C1) -> (C2-C1)-A
  if (N1.getOpcode() == ISD::ADD) {
    SDValue N11 = N1.getOperand(1);
    if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0, N11}))
      return DAG.getNode(ISD::SUB, DL, VT, NewC, N1.getOperand(0));
  }

  // fold (A-C1)-C2 -> A-(C1+C2)
  if (N0.getOpcode() == ISD::SUB) {
    SDValue N01 = N0.getOperand(1);
    if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N01, N1}))
      return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), NewC);
  }

  // fold (c1-A)-c2 -> (c1-c2)-A
  if (N0.getOpcode() == ISD::SUB) {
    SDValue N00 = N0.getOperand(0);
    if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N00, N1}))
      return DAG.getNode(ISD::SUB, DL, VT, NewC, N0.getOperand(1));
  }

  // fold ((A+(B+or-C))-B) -> A+or-C
  if (N0.getOpcode() == ISD::ADD &&
      (N0.getOperand(1).getOpcode() == ISD::SUB ||
       N0.getOperand(1).getOpcode() == ISD::ADD) &&
      N0.getOperand(1).getOperand(0) == N1)
    return DAG.getNode(N0.getOperand(1).getOpcode(), DL, VT, N0.getOperand(0),
                       N0.getOperand(1).getOperand(1));

  // fold ((A+(C+B))-B) -> A+C
  if (N0.getOpcode() == ISD::ADD && N0.getOperand(1).getOpcode() == ISD::ADD &&
      N0.getOperand(1).getOperand(1) == N1)
    return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0),
                       N0.getOperand(1).getOperand(0));

  // fold ((A-(B-C))-C) -> A-B
  if (N0.getOpcode() == ISD::SUB && N0.getOperand(1).getOpcode() == ISD::SUB &&
      N0.getOperand(1).getOperand(1) == N1)
    return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0),
                       N0.getOperand(1).getOperand(0));

  // fold (A-(B-C)) -> A+(C-B)
  if (N1.getOpcode() == ISD::SUB && N1.hasOneUse())
    return DAG.getNode(ISD::ADD, DL, VT, N0,
                       DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(1),
                                   N1.getOperand(0)));

  // A - (A & B) -> A & (~B)
  if (N1.getOpcode() == ISD::AND) {
    SDValue A = N1.getOperand(0);
    SDValue B = N1.getOperand(1);
    if (A != N0)
      std::swap(A, B);
    if (A == N0 &&
        (N1.hasOneUse() || isConstantOrConstantVector(B, /*NoOpaques=*/true))) {
      SDValue InvB =
          DAG.getNode(ISD::XOR, DL, VT, B, DAG.getAllOnesConstant(DL, VT));
      return DAG.getNode(ISD::AND, DL, VT, A, InvB);
    }
  }

  // fold (X - (-Y * Z)) -> (X + (Y * Z))
  if (N1.getOpcode() == ISD::MUL && N1.hasOneUse()) {
    if (N1.getOperand(0).getOpcode() == ISD::SUB &&
        isNullOrNullSplat(N1.getOperand(0).getOperand(0))) {
      SDValue Mul = DAG.getNode(ISD::MUL, DL, VT,
                                N1.getOperand(0).getOperand(1),
                                N1.getOperand(1));
      return DAG.getNode(ISD::ADD, DL, VT, N0, Mul);
    }
    if (N1.getOperand(1).getOpcode() == ISD::SUB &&
        isNullOrNullSplat(N1.getOperand(1).getOperand(0))) {
      SDValue Mul = DAG.getNode(ISD::MUL, DL, VT,
                                N1.getOperand(0),
                                N1.getOperand(1).getOperand(1));
      return DAG.getNode(ISD::ADD, DL, VT, N0, Mul);
    }
  }

  // If either operand of a sub is undef, the result is undef
  if (N0.isUndef())
    return N0;
  if (N1.isUndef())
    return N1;

  if (SDValue V = foldAddSubBoolOfMaskedVal(N, DAG))
    return V;

  if (SDValue V = foldAddSubOfSignBit(N, DAG))
    return V;

  if (SDValue V = foldAddSubMasked1(false, N0, N1, DAG, SDLoc(N)))
    return V;

  if (SDValue V = foldSubToUSubSat(VT, N))
    return V;

  // (x - y) - 1  ->  add (xor y, -1), x
  if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB && isOneOrOneSplat(N1)) {
    SDValue Xor = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(1),
                              DAG.getAllOnesConstant(DL, VT));
    return DAG.getNode(ISD::ADD, DL, VT, Xor, N0.getOperand(0));
  }

  // Look for:
  //   sub y, (xor x, -1)
  // And if the target does not like this form then turn into:
  //   add (add x, y), 1
  if (TLI.preferIncOfAddToSubOfNot(VT) && N1.hasOneUse() && isBitwiseNot(N1)) {
    SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, N1.getOperand(0));
    return DAG.getNode(ISD::ADD, DL, VT, Add, DAG.getConstant(1, DL, VT));
  }

  // Hoist one-use addition by non-opaque constant:
  //   (x + C) - y  ->  (x - y) + C
  if (N0.hasOneUse() && N0.getOpcode() == ISD::ADD &&
      isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
    SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), N1);
    return DAG.getNode(ISD::ADD, DL, VT, Sub, N0.getOperand(1));
  }
  // y - (x + C)  ->  (y - x) - C
  if (N1.hasOneUse() && N1.getOpcode() == ISD::ADD &&
      isConstantOrConstantVector(N1.getOperand(1), /*NoOpaques=*/true)) {
    SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, N1.getOperand(0));
    return DAG.getNode(ISD::SUB, DL, VT, Sub, N1.getOperand(1));
  }
  // (x - C) - y  ->  (x - y) - C
  // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
  if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB &&
      isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
    SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), N1);
    return DAG.getNode(ISD::SUB, DL, VT, Sub, N0.getOperand(1));
  }
  // (C - x) - y  ->  C - (x + y)
  if (N0.hasOneUse() && N0.getOpcode() == ISD::SUB &&
      isConstantOrConstantVector(N0.getOperand(0), /*NoOpaques=*/true)) {
    SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1), N1);
    return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), Add);
  }

  // If the target's bool is represented as 0/-1, prefer to make this 'add 0/-1'
  // rather than 'sub 0/1' (the sext should get folded).
  // sub X, (zext i1 Y) --> add X, (sext i1 Y)
  if (N1.getOpcode() == ISD::ZERO_EXTEND &&
      N1.getOperand(0).getScalarValueSizeInBits() == 1 &&
      TLI.getBooleanContents(VT) ==
          TargetLowering::ZeroOrNegativeOneBooleanContent) {
    SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N1.getOperand(0));
    return DAG.getNode(ISD::ADD, DL, VT, N0, SExt);
  }

  // fold Y = sra (X, size(X)-1); sub (xor (X, Y), Y) -> (abs X)
  if (TLI.isOperationLegalOrCustom(ISD::ABS, VT)) {
    if (N0.getOpcode() == ISD::XOR && N1.getOpcode() == ISD::SRA) {
      SDValue X0 = N0.getOperand(0), X1 = N0.getOperand(1);
      SDValue S0 = N1.getOperand(0);
      if ((X0 == S0 && X1 == N1) || (X0 == N1 && X1 == S0))
        if (ConstantSDNode *C = isConstOrConstSplat(N1.getOperand(1)))
          if (C->getAPIntValue() == (VT.getScalarSizeInBits() - 1))
            return DAG.getNode(ISD::ABS, SDLoc(N), VT, S0);
    }
  }

  // If the relocation model supports it, consider symbol offsets.
  if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(N0))
    if (!LegalOperations && TLI.isOffsetFoldingLegal(GA)) {
      // fold (sub Sym, c) -> Sym-c
      if (N1C && GA->getOpcode() == ISD::GlobalAddress)
        return DAG.getGlobalAddress(GA->getGlobal(), SDLoc(N1C), VT,
                                    GA->getOffset() -
                                        (uint64_t)N1C->getSExtValue());
      // fold (sub Sym+c1, Sym+c2) -> c1-c2
      if (GlobalAddressSDNode *GB = dyn_cast<GlobalAddressSDNode>(N1))
        if (GA->getGlobal() == GB->getGlobal())
          return DAG.getConstant((uint64_t)GA->getOffset() - GB->getOffset(),
                                 DL, VT);
    }

  // sub X, (sextinreg Y i1) -> add X, (and Y 1)
  if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
    VTSDNode *TN = cast<VTSDNode>(N1.getOperand(1));
    if (TN->getVT() == MVT::i1) {
      SDValue ZExt = DAG.getNode(ISD::AND, DL, VT, N1.getOperand(0),
                                 DAG.getConstant(1, DL, VT));
      return DAG.getNode(ISD::ADD, DL, VT, N0, ZExt);
    }
  }

  // canonicalize (sub X, (vscale * C)) to (add X, (vscale * -C))
  if (N1.getOpcode() == ISD::VSCALE) {
    const APInt &IntVal = N1.getConstantOperandAPInt(0);
    return DAG.getNode(ISD::ADD, DL, VT, N0, DAG.getVScale(DL, VT, -IntVal));
  }

  // canonicalize (sub X, step_vector(C)) to (add X, step_vector(-C))
  if (N1.getOpcode() == ISD::STEP_VECTOR && N1.hasOneUse()) {
    APInt NewStep = -N1.getConstantOperandAPInt(0);
    return DAG.getNode(ISD::ADD, DL, VT, N0,
                       DAG.getStepVector(DL, VT, NewStep));
  }

  // Prefer an add for more folding potential and possibly better codegen:
  // sub N0, (lshr N10, width-1) --> add N0, (ashr N10, width-1)
  if (!LegalOperations && N1.getOpcode() == ISD::SRL && N1.hasOneUse()) {
    SDValue ShAmt = N1.getOperand(1);
    ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt);
    if (ShAmtC &&
        ShAmtC->getAPIntValue() == (N1.getScalarValueSizeInBits() - 1)) {
      SDValue SRA = DAG.getNode(ISD::SRA, DL, VT, N1.getOperand(0), ShAmt);
      return DAG.getNode(ISD::ADD, DL, VT, N0, SRA);
    }
  }

  // As with the previous fold, prefer add for more folding potential.
  // Subtracting SMIN/0 is the same as adding SMIN/0:
  // N0 - (X << BW-1) --> N0 + (X << BW-1)
  if (N1.getOpcode() == ISD::SHL) {
    ConstantSDNode *ShlC = isConstOrConstSplat(N1.getOperand(1));
    if (ShlC && ShlC->getAPIntValue() == VT.getScalarSizeInBits() - 1)
      return DAG.getNode(ISD::ADD, DL, VT, N1, N0);
  }

  if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT)) {
    // (sub Carry, X)  ->  (addcarry (sub 0, X), 0, Carry)
    if (SDValue Carry = getAsCarry(TLI, N0)) {
      SDValue X = N1;
      SDValue Zero = DAG.getConstant(0, DL, VT);
      SDValue NegX = DAG.getNode(ISD::SUB, DL, VT, Zero, X);
      return DAG.getNode(ISD::ADDCARRY, DL,
                         DAG.getVTList(VT, Carry.getValueType()), NegX, Zero,
                         Carry);
    }
  }

  // If there's no chance of borrowing from adjacent bits, then sub is xor:
  // sub C0, X --> xor X, C0
  if (ConstantSDNode *C0 = isConstOrConstSplat(N0)) {
    if (!C0->isOpaque()) {
      const APInt &C0Val = C0->getAPIntValue();
      const APInt &MaybeOnes = ~DAG.computeKnownBits(N1).Zero;
      if ((C0Val - MaybeOnes) == (C0Val ^ MaybeOnes))
        return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
    }
  }

  return SDValue();
}
3777
visitSUBSAT(SDNode * N)3778 SDValue DAGCombiner::visitSUBSAT(SDNode *N) {
3779 SDValue N0 = N->getOperand(0);
3780 SDValue N1 = N->getOperand(1);
3781 EVT VT = N0.getValueType();
3782 SDLoc DL(N);
3783
3784 // fold (sub_sat x, undef) -> 0
3785 if (N0.isUndef() || N1.isUndef())
3786 return DAG.getConstant(0, DL, VT);
3787
3788 // fold (sub_sat x, x) -> 0
3789 if (N0 == N1)
3790 return DAG.getConstant(0, DL, VT);
3791
3792 // fold (sub_sat c1, c2) -> c3
3793 if (SDValue C = DAG.FoldConstantArithmetic(N->getOpcode(), DL, VT, {N0, N1}))
3794 return C;
3795
3796 // fold vector ops
3797 if (VT.isVector()) {
3798 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
3799 return FoldedVOp;
3800
3801 // fold (sub_sat x, 0) -> x, vector edition
3802 if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
3803 return N0;
3804 }
3805
3806 // fold (sub_sat x, 0) -> x
3807 if (isNullConstant(N1))
3808 return N0;
3809
3810 return SDValue();
3811 }
3812
visitSUBC(SDNode * N)3813 SDValue DAGCombiner::visitSUBC(SDNode *N) {
3814 SDValue N0 = N->getOperand(0);
3815 SDValue N1 = N->getOperand(1);
3816 EVT VT = N0.getValueType();
3817 SDLoc DL(N);
3818
3819 // If the flag result is dead, turn this into an SUB.
3820 if (!N->hasAnyUseOfValue(1))
3821 return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
3822 DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3823
3824 // fold (subc x, x) -> 0 + no borrow
3825 if (N0 == N1)
3826 return CombineTo(N, DAG.getConstant(0, DL, VT),
3827 DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3828
3829 // fold (subc x, 0) -> x + no borrow
3830 if (isNullConstant(N1))
3831 return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3832
3833 // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1) + no borrow
3834 if (isAllOnesConstant(N0))
3835 return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0),
3836 DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3837
3838 return SDValue();
3839 }
3840
/// Combine SSUBO/USUBO (subtract with overflow/borrow result) nodes.
SDValue DAGCombiner::visitSUBO(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  // Signed and unsigned overflow variants share this visitor.
  bool IsSigned = (ISD::SSUBO == N->getOpcode());

  EVT CarryVT = N->getValueType(1);
  SDLoc DL(N);

  // If the flag result is dead, turn this into an SUB.
  if (!N->hasAnyUseOfValue(1))
    return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
                     DAG.getUNDEF(CarryVT));

  // fold (subo x, x) -> 0 + no borrow
  if (N0 == N1)
    return CombineTo(N, DAG.getConstant(0, DL, VT),
                     DAG.getConstant(0, DL, CarryVT));

  ConstantSDNode *N1C = getAsNonOpaqueConstant(N1);

  // fold (ssubo x, c) -> (saddo x, -c)
  // Signed only, and INT_MIN is excluded because -INT_MIN is not
  // representable and the negation would change the overflow condition.
  if (IsSigned && N1C && !N1C->getAPIntValue().isMinSignedValue()) {
    return DAG.getNode(ISD::SADDO, DL, N->getVTList(), N0,
                       DAG.getConstant(-N1C->getAPIntValue(), DL, VT));
  }

  // fold (subo x, 0) -> x + no borrow
  if (isNullOrNullSplat(N1))
    return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT));

  // Canonicalize (usubo -1, x) -> ~x, i.e. (xor x, -1) + no borrow
  if (!IsSigned && isAllOnesOrAllOnesSplat(N0))
    return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0),
                     DAG.getConstant(0, DL, CarryVT));

  return SDValue();
}
3879
visitSUBE(SDNode * N)3880 SDValue DAGCombiner::visitSUBE(SDNode *N) {
3881 SDValue N0 = N->getOperand(0);
3882 SDValue N1 = N->getOperand(1);
3883 SDValue CarryIn = N->getOperand(2);
3884
3885 // fold (sube x, y, false) -> (subc x, y)
3886 if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
3887 return DAG.getNode(ISD::SUBC, SDLoc(N), N->getVTList(), N0, N1);
3888
3889 return SDValue();
3890 }
3891
visitSUBCARRY(SDNode * N)3892 SDValue DAGCombiner::visitSUBCARRY(SDNode *N) {
3893 SDValue N0 = N->getOperand(0);
3894 SDValue N1 = N->getOperand(1);
3895 SDValue CarryIn = N->getOperand(2);
3896
3897 // fold (subcarry x, y, false) -> (usubo x, y)
3898 if (isNullConstant(CarryIn)) {
3899 if (!LegalOperations ||
3900 TLI.isOperationLegalOrCustom(ISD::USUBO, N->getValueType(0)))
3901 return DAG.getNode(ISD::USUBO, SDLoc(N), N->getVTList(), N0, N1);
3902 }
3903
3904 return SDValue();
3905 }
3906
visitSSUBO_CARRY(SDNode * N)3907 SDValue DAGCombiner::visitSSUBO_CARRY(SDNode *N) {
3908 SDValue N0 = N->getOperand(0);
3909 SDValue N1 = N->getOperand(1);
3910 SDValue CarryIn = N->getOperand(2);
3911
3912 // fold (ssubo_carry x, y, false) -> (ssubo x, y)
3913 if (isNullConstant(CarryIn)) {
3914 if (!LegalOperations ||
3915 TLI.isOperationLegalOrCustom(ISD::SSUBO, N->getValueType(0)))
3916 return DAG.getNode(ISD::SSUBO, SDLoc(N), N->getVTList(), N0, N1);
3917 }
3918
3919 return SDValue();
3920 }
3921
3922 // Notice that "mulfix" can be any of SMULFIX, SMULFIXSAT, UMULFIX and
3923 // UMULFIXSAT here.
visitMULFIX(SDNode * N)3924 SDValue DAGCombiner::visitMULFIX(SDNode *N) {
3925 SDValue N0 = N->getOperand(0);
3926 SDValue N1 = N->getOperand(1);
3927 SDValue Scale = N->getOperand(2);
3928 EVT VT = N0.getValueType();
3929
3930 // fold (mulfix x, undef, scale) -> 0
3931 if (N0.isUndef() || N1.isUndef())
3932 return DAG.getConstant(0, SDLoc(N), VT);
3933
3934 // Canonicalize constant to RHS (vector doesn't have to splat)
3935 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
3936 !DAG.isConstantIntBuildVectorOrConstantInt(N1))
3937 return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0, Scale);
3938
3939 // fold (mulfix x, 0, scale) -> 0
3940 if (isNullConstant(N1))
3941 return DAG.getConstant(0, SDLoc(N), VT);
3942
3943 return SDValue();
3944 }
3945
/// Combine an integer MUL node: constant folding, canonicalization of
/// constants to the RHS, strength reduction of multiplies by (near)
/// powers of two into shifts/adds/subs, and demanded-bits simplification.
/// Returns the replacement value or an empty SDValue if nothing folded.
SDValue DAGCombiner::visitMUL(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  SDLoc DL(N);

  // fold (mul x, undef) -> 0
  if (N0.isUndef() || N1.isUndef())
    return DAG.getConstant(0, DL, VT);

  // fold (mul c1, c2) -> c1*c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::MUL, DL, VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS (vector doesn't have to splat)
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(ISD::MUL, DL, VT, N1, N0);

  bool N1IsConst = false;
  bool N1IsOpaqueConst = false;
  APInt ConstValue1;

  // fold vector ops
  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

    N1IsConst = ISD::isConstantSplatVector(N1.getNode(), ConstValue1);
    assert((!N1IsConst ||
            ConstValue1.getBitWidth() == VT.getScalarSizeInBits()) &&
           "Splat APInt should be element width");
  } else {
    N1IsConst = isa<ConstantSDNode>(N1);
    if (N1IsConst) {
      ConstValue1 = cast<ConstantSDNode>(N1)->getAPIntValue();
      N1IsOpaqueConst = cast<ConstantSDNode>(N1)->isOpaque();
    }
  }

  // fold (mul x, 0) -> 0
  if (N1IsConst && ConstValue1.isZero())
    return N1;

  // fold (mul x, 1) -> x
  if (N1IsConst && ConstValue1.isOne())
    return N0;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // fold (mul x, -1) -> 0-x
  if (N1IsConst && ConstValue1.isAllOnes())
    return DAG.getNode(ISD::SUB, DL, VT,
                       DAG.getConstant(0, DL, VT), N0);

  // fold (mul x, (1 << c)) -> x << c
  if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
      DAG.isKnownToBeAPowerOfTwo(N1) &&
      (!VT.isVector() || Level <= AfterLegalizeVectorOps)) {
    SDValue LogBase2 = BuildLogBase2(N1, DL);
    EVT ShiftVT = getShiftAmountTy(N0.getValueType());
    SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
    return DAG.getNode(ISD::SHL, DL, VT, N0, Trunc);
  }

  // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
  if (N1IsConst && !N1IsOpaqueConst && ConstValue1.isNegatedPowerOf2()) {
    unsigned Log2Val = (-ConstValue1).logBase2();
    // FIXME: If the input is something that is easily negated (e.g. a
    // single-use add), we should put the negate there.
    return DAG.getNode(ISD::SUB, DL, VT,
                       DAG.getConstant(0, DL, VT),
                       DAG.getNode(ISD::SHL, DL, VT, N0,
                                   DAG.getConstant(Log2Val, DL,
                                                   getShiftAmountTy(N0.getValueType()))));
  }

  // Try to transform:
  // (1) multiply-by-(power-of-2 +/- 1) into shift and add/sub.
  // mul x, (2^N + 1) --> add (shl x, N), x
  // mul x, (2^N - 1) --> sub (shl x, N), x
  // Examples: x * 33 --> (x << 5) + x
  //           x * 15 --> (x << 4) - x
  //           x * -33 --> -((x << 5) + x)
  //           x * -15 --> -((x << 4) - x) ; this reduces --> x - (x << 4)
  // (2) multiply-by-(power-of-2 +/- power-of-2) into shifts and add/sub.
  // mul x, (2^N + 2^M) --> (add (shl x, N), (shl x, M))
  // mul x, (2^N - 2^M) --> (sub (shl x, N), (shl x, M))
  // Examples: x * 0x8800 --> (x << 15) + (x << 11)
  //           x * 0xf800 --> (x << 16) - (x << 11)
  //           x * -0x8800 --> -((x << 15) + (x << 11))
  //           x * -0xf800 --> -((x << 16) - (x << 11)) ; (x << 11) - (x << 16)
  if (N1IsConst && TLI.decomposeMulByConstant(*DAG.getContext(), VT, N1)) {
    // TODO: We could handle more general decomposition of any constant by
    //       having the target set a limit on number of ops and making a
    //       callback to determine that sequence (similar to sqrt expansion).
    unsigned MathOp = ISD::DELETED_NODE;
    APInt MulC = ConstValue1.abs();
    // The constant `2` should be treated as (2^0 + 1).
    unsigned TZeros = MulC == 2 ? 0 : MulC.countTrailingZeros();
    MulC.lshrInPlace(TZeros);
    if ((MulC - 1).isPowerOf2())
      MathOp = ISD::ADD;
    else if ((MulC + 1).isPowerOf2())
      MathOp = ISD::SUB;

    if (MathOp != ISD::DELETED_NODE) {
      unsigned ShAmt =
          MathOp == ISD::ADD ? (MulC - 1).logBase2() : (MulC + 1).logBase2();
      ShAmt += TZeros;
      assert(ShAmt < VT.getScalarSizeInBits() &&
             "multiply-by-constant generated out of bounds shift");
      SDValue Shl =
          DAG.getNode(ISD::SHL, DL, VT, N0, DAG.getConstant(ShAmt, DL, VT));
      SDValue R =
          TZeros ? DAG.getNode(MathOp, DL, VT, Shl,
                               DAG.getNode(ISD::SHL, DL, VT, N0,
                                           DAG.getConstant(TZeros, DL, VT)))
                 : DAG.getNode(MathOp, DL, VT, Shl, N0);
      // Negative constants were handled via abs(); restore the sign here.
      if (ConstValue1.isNegative())
        R = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), R);
      return R;
    }
  }

  // (mul (shl X, c1), c2) -> (mul X, c2 << c1)
  if (N0.getOpcode() == ISD::SHL) {
    SDValue N01 = N0.getOperand(1);
    if (SDValue C3 = DAG.FoldConstantArithmetic(ISD::SHL, DL, VT, {N1, N01}))
      return DAG.getNode(ISD::MUL, DL, VT, N0.getOperand(0), C3);
  }

  // Change (mul (shl X, C), Y) -> (shl (mul X, Y), C) when the shift has one
  // use.
  {
    SDValue Sh, Y;

    // Check for both (mul (shl X, C), Y) and (mul Y, (shl X, C)).
    if (N0.getOpcode() == ISD::SHL &&
        isConstantOrConstantVector(N0.getOperand(1)) && N0->hasOneUse()) {
      Sh = N0; Y = N1;
    } else if (N1.getOpcode() == ISD::SHL &&
               isConstantOrConstantVector(N1.getOperand(1)) &&
               N1->hasOneUse()) {
      Sh = N1; Y = N0;
    }

    if (Sh.getNode()) {
      SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, Sh.getOperand(0), Y);
      return DAG.getNode(ISD::SHL, DL, VT, Mul, Sh.getOperand(1));
    }
  }

  // fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2)
  if (DAG.isConstantIntBuildVectorOrConstantInt(N1) &&
      N0.getOpcode() == ISD::ADD &&
      DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1)) &&
      isMulAddWithConstProfitable(N, N0, N1))
    return DAG.getNode(
        ISD::ADD, DL, VT,
        DAG.getNode(ISD::MUL, SDLoc(N0), VT, N0.getOperand(0), N1),
        DAG.getNode(ISD::MUL, SDLoc(N1), VT, N0.getOperand(1), N1));

  // Fold (mul (vscale * C0), C1) to (vscale * (C0 * C1)).
  if (N0.getOpcode() == ISD::VSCALE)
    if (ConstantSDNode *NC1 = isConstOrConstSplat(N1)) {
      const APInt &C0 = N0.getConstantOperandAPInt(0);
      const APInt &C1 = NC1->getAPIntValue();
      return DAG.getVScale(DL, VT, C0 * C1);
    }

  // Fold (mul step_vector(C0), C1) to (step_vector(C0 * C1)).
  APInt MulVal;
  if (N0.getOpcode() == ISD::STEP_VECTOR)
    if (ISD::isConstantSplatVector(N1.getNode(), MulVal)) {
      const APInt &C0 = N0.getConstantOperandAPInt(0);
      APInt NewStep = C0 * MulVal;
      return DAG.getStepVector(DL, VT, NewStep);
    }

  // Fold ((mul x, 0/undef) -> 0,
  //       (mul x, 1) -> x) -> x)
  //      -> and(x, mask)
  // We can replace vectors with '0' and '1' factors with a clearing mask.
  if (VT.isFixedLengthVector()) {
    unsigned NumElts = VT.getVectorNumElements();
    SmallBitVector ClearMask;
    ClearMask.reserve(NumElts);
    // Accepts only factors of 0/undef (lane cleared) or 1 (lane kept).
    auto IsClearMask = [&ClearMask](ConstantSDNode *V) {
      if (!V || V->isZero()) {
        ClearMask.push_back(true);
        return true;
      }
      ClearMask.push_back(false);
      return V->isOne();
    };
    if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::AND, VT)) &&
        ISD::matchUnaryPredicate(N1, IsClearMask, /*AllowUndefs*/ true)) {
      assert(N1.getOpcode() == ISD::BUILD_VECTOR && "Unknown constant vector");
      EVT LegalSVT = N1.getOperand(0).getValueType();
      SDValue Zero = DAG.getConstant(0, DL, LegalSVT);
      SDValue AllOnes = DAG.getAllOnesConstant(DL, LegalSVT);
      SmallVector<SDValue, 16> Mask(NumElts, AllOnes);
      for (unsigned I = 0; I != NumElts; ++I)
        if (ClearMask[I])
          Mask[I] = Zero;
      return DAG.getNode(ISD::AND, DL, VT, N0, DAG.getBuildVector(VT, DL, Mask));
    }
  }

  // reassociate mul
  if (SDValue RMUL = reassociateOps(ISD::MUL, DL, N0, N1, N->getFlags()))
    return RMUL;

  // Simplify the operands using demanded-bits information.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  return SDValue();
}
4167
4168 /// Return true if divmod libcall is available.
isDivRemLibcallAvailable(SDNode * Node,bool isSigned,const TargetLowering & TLI)4169 static bool isDivRemLibcallAvailable(SDNode *Node, bool isSigned,
4170 const TargetLowering &TLI) {
4171 RTLIB::Libcall LC;
4172 EVT NodeType = Node->getValueType(0);
4173 if (!NodeType.isSimple())
4174 return false;
4175 switch (NodeType.getSimpleVT().SimpleTy) {
4176 default: return false; // No libcall for vector types.
4177 case MVT::i8: LC= isSigned ? RTLIB::SDIVREM_I8 : RTLIB::UDIVREM_I8; break;
4178 case MVT::i16: LC= isSigned ? RTLIB::SDIVREM_I16 : RTLIB::UDIVREM_I16; break;
4179 case MVT::i32: LC= isSigned ? RTLIB::SDIVREM_I32 : RTLIB::UDIVREM_I32; break;
4180 case MVT::i64: LC= isSigned ? RTLIB::SDIVREM_I64 : RTLIB::UDIVREM_I64; break;
4181 case MVT::i128: LC= isSigned ? RTLIB::SDIVREM_I128:RTLIB::UDIVREM_I128; break;
4182 }
4183
4184 return TLI.getLibcallName(LC) != nullptr;
4185 }
4186
/// Issue divrem if both quotient and remainder are needed.
///
/// Scans the other users of the dividend for a div/rem/divrem with the same
/// operands and, when found, rewrites all of them to share a single
/// S/UDIVREM node. Returns the combined node's quotient value (or an empty
/// SDValue if no combine happened).
SDValue DAGCombiner::useDivRem(SDNode *Node) {
  if (Node->use_empty())
    return SDValue(); // This is a dead node, leave it alone.

  unsigned Opcode = Node->getOpcode();
  bool isSigned = (Opcode == ISD::SDIV) || (Opcode == ISD::SREM);
  unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM;

  // DivMod lib calls can still work on non-legal types if using lib-calls.
  EVT VT = Node->getValueType(0);
  if (VT.isVector() || !VT.isInteger())
    return SDValue();

  if (!TLI.isTypeLegal(VT) && !TLI.isOperationCustom(DivRemOpc, VT))
    return SDValue();

  // If DIVREM is going to get expanded into a libcall,
  // but there is no libcall available, then don't combine.
  if (!TLI.isOperationLegalOrCustom(DivRemOpc, VT) &&
      !isDivRemLibcallAvailable(Node, isSigned, TLI))
    return SDValue();

  // If div is legal, it's better to do the normal expansion
  unsigned OtherOpcode = 0;
  if ((Opcode == ISD::SDIV) || (Opcode == ISD::UDIV)) {
    OtherOpcode = isSigned ? ISD::SREM : ISD::UREM;
    if (TLI.isOperationLegalOrCustom(Opcode, VT))
      return SDValue();
  } else {
    OtherOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
    if (TLI.isOperationLegalOrCustom(OtherOpcode, VT))
      return SDValue();
  }

  // Walk all other users of the dividend looking for a matching div, rem, or
  // divrem over the exact same operand pair.
  SDValue Op0 = Node->getOperand(0);
  SDValue Op1 = Node->getOperand(1);
  SDValue combined;
  for (SDNode *User : Op0->uses()) {
    if (User == Node || User->getOpcode() == ISD::DELETED_NODE ||
        User->use_empty())
      continue;
    // Convert the other matching node(s), too;
    // otherwise, the DIVREM may get target-legalized into something
    // target-specific that we won't be able to recognize.
    unsigned UserOpc = User->getOpcode();
    if ((UserOpc == Opcode || UserOpc == OtherOpcode || UserOpc == DivRemOpc) &&
        User->getOperand(0) == Op0 &&
        User->getOperand(1) == Op1) {
      if (!combined) {
        // Lazily create the DIVREM node (or reuse an existing one) when the
        // first matching user is found.
        if (UserOpc == OtherOpcode) {
          SDVTList VTs = DAG.getVTList(VT, VT);
          combined = DAG.getNode(DivRemOpc, SDLoc(Node), VTs, Op0, Op1);
        } else if (UserOpc == DivRemOpc) {
          combined = SDValue(User, 0);
        } else {
          // Another instance of this same div/rem; keep scanning for a node
          // that lets us materialize the combined value first.
          assert(UserOpc == Opcode);
          continue;
        }
      }
      // Div users take result 0 (quotient); rem users take result 1
      // (remainder).
      if (UserOpc == ISD::SDIV || UserOpc == ISD::UDIV)
        CombineTo(User, combined);
      else if (UserOpc == ISD::SREM || UserOpc == ISD::UREM)
        CombineTo(User, combined.getValue(1));
    }
  }
  return combined;
}
4255
simplifyDivRem(SDNode * N,SelectionDAG & DAG)4256 static SDValue simplifyDivRem(SDNode *N, SelectionDAG &DAG) {
4257 SDValue N0 = N->getOperand(0);
4258 SDValue N1 = N->getOperand(1);
4259 EVT VT = N->getValueType(0);
4260 SDLoc DL(N);
4261
4262 unsigned Opc = N->getOpcode();
4263 bool IsDiv = (ISD::SDIV == Opc) || (ISD::UDIV == Opc);
4264 ConstantSDNode *N1C = isConstOrConstSplat(N1);
4265
4266 // X / undef -> undef
4267 // X % undef -> undef
4268 // X / 0 -> undef
4269 // X % 0 -> undef
4270 // NOTE: This includes vectors where any divisor element is zero/undef.
4271 if (DAG.isUndef(Opc, {N0, N1}))
4272 return DAG.getUNDEF(VT);
4273
4274 // undef / X -> 0
4275 // undef % X -> 0
4276 if (N0.isUndef())
4277 return DAG.getConstant(0, DL, VT);
4278
4279 // 0 / X -> 0
4280 // 0 % X -> 0
4281 ConstantSDNode *N0C = isConstOrConstSplat(N0);
4282 if (N0C && N0C->isZero())
4283 return N0;
4284
4285 // X / X -> 1
4286 // X % X -> 0
4287 if (N0 == N1)
4288 return DAG.getConstant(IsDiv ? 1 : 0, DL, VT);
4289
4290 // X / 1 -> X
4291 // X % 1 -> 0
4292 // If this is a boolean op (single-bit element type), we can't have
4293 // division-by-zero or remainder-by-zero, so assume the divisor is 1.
4294 // TODO: Similarly, if we're zero-extending a boolean divisor, then assume
4295 // it's a 1.
4296 if ((N1C && N1C->isOne()) || (VT.getScalarType() == MVT::i1))
4297 return IsDiv ? N0 : DAG.getConstant(0, DL, VT);
4298
4299 return SDValue();
4300 }
4301
/// Combine an ISD::SDIV node: constant folding, trivial simplifications,
/// strength reduction, and div+rem -> divrem merging.
SDValue DAGCombiner::visitSDIV(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  EVT CCVT = getSetCCResultType(VT);
  SDLoc DL(N);

  // fold (sdiv c1, c2) -> c1/c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::SDIV, DL, VT, {N0, N1}))
    return C;

  // fold vector ops
  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

  // fold (sdiv X, -1) -> 0-X
  ConstantSDNode *N1C = isConstOrConstSplat(N1);
  if (N1C && N1C->isAllOnes())
    return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), N0);

  // fold (sdiv X, MIN_SIGNED) -> select(X == MIN_SIGNED, 1, 0)
  // Only X == MIN_SIGNED produces a nonzero quotient here.
  if (N1C && N1C->getAPIntValue().isMinSignedValue())
    return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
                         DAG.getConstant(1, DL, VT),
                         DAG.getConstant(0, DL, VT));

  if (SDValue V = simplifyDivRem(N, DAG))
    return V;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // If we know the sign bits of both operands are zero, strength reduce to a
  // udiv instead. Handles (X&15) /s 4 -> X&15 >> 2
  if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
    return DAG.getNode(ISD::UDIV, DL, N1.getValueType(), N0, N1);

  if (SDValue V = visitSDIVLike(N0, N1, N)) {
    // If the corresponding remainder node exists, update its users with
    // (Dividend - (Quotient * Divisor).
    if (SDNode *RemNode = DAG.getNodeIfExists(ISD::SREM, N->getVTList(),
                                              { N0, N1 })) {
      SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1);
      SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
      AddToWorklist(Mul.getNode());
      AddToWorklist(Sub.getNode());
      CombineTo(RemNode, Sub);
    }
    return V;
  }

  // sdiv, srem -> sdivrem
  // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
  // true. Otherwise, we break the simplification logic in visitREM().
  AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
  if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
    if (SDValue DivRem = useDivRem(N))
      return DivRem;

  return SDValue();
}
4364
isDivisorPowerOfTwo(SDValue Divisor)4365 static bool isDivisorPowerOfTwo(SDValue Divisor) {
4366 // Helper for determining whether a value is a power-2 constant scalar or a
4367 // vector of such elements.
4368 auto IsPowerOfTwo = [](ConstantSDNode *C) {
4369 if (C->isZero() || C->isOpaque())
4370 return false;
4371 if (C->getAPIntValue().isPowerOf2())
4372 return true;
4373 if (C->getAPIntValue().isNegatedPowerOf2())
4374 return true;
4375 return false;
4376 };
4377
4378 return ISD::matchUnaryPredicate(Divisor, IsPowerOfTwo);
4379 }
4380
/// Shared sdiv lowering used by visitSDIV and visitREM: try to replace an
/// sdiv of (N0, N1) with cheaper operations. Returns the replacement
/// quotient value, or an empty SDValue if no transform applies.
SDValue DAGCombiner::visitSDIVLike(SDValue N0, SDValue N1, SDNode *N) {
  SDLoc DL(N);
  EVT VT = N->getValueType(0);
  EVT CCVT = getSetCCResultType(VT);
  unsigned BitWidth = VT.getScalarSizeInBits();

  // fold (sdiv X, pow2) -> simple ops after legalize
  // FIXME: We check for the exact bit here because the generic lowering gives
  // better results in that case. The target-specific lowering should learn how
  // to handle exact sdivs efficiently.
  if (!N->getFlags().hasExact() && isDivisorPowerOfTwo(N1)) {
    // Target-specific implementation of sdiv x, pow2.
    if (SDValue Res = BuildSDIVPow2(N))
      return Res;

    // Create constants that are functions of the shift amount value.
    // C1 = log2(|divisor|): CTTZ of a power-of-two magnitude.
    EVT ShiftAmtTy = getShiftAmountTy(N0.getValueType());
    SDValue Bits = DAG.getConstant(BitWidth, DL, ShiftAmtTy);
    SDValue C1 = DAG.getNode(ISD::CTTZ, DL, VT, N1);
    C1 = DAG.getZExtOrTrunc(C1, DL, ShiftAmtTy);
    SDValue Inexact = DAG.getNode(ISD::SUB, DL, ShiftAmtTy, Bits, C1);
    // Only proceed if the shift amounts constant-folded; otherwise the
    // sequence below would not be profitable.
    if (!isConstantOrConstantVector(Inexact))
      return SDValue();

    // Splat the sign bit into the register
    SDValue Sign = DAG.getNode(ISD::SRA, DL, VT, N0,
                               DAG.getConstant(BitWidth - 1, DL, ShiftAmtTy));
    AddToWorklist(Sign.getNode());

    // Add (N0 < 0) ? abs2 - 1 : 0;
    // This biases negative dividends so the arithmetic shift rounds toward
    // zero, matching sdiv semantics.
    SDValue Srl = DAG.getNode(ISD::SRL, DL, VT, Sign, Inexact);
    AddToWorklist(Srl.getNode());
    SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, Srl);
    AddToWorklist(Add.getNode());
    SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Add, C1);
    AddToWorklist(Sra.getNode());

    // Special case: (sdiv X, 1) -> X
    // Special Case: (sdiv X, -1) -> 0-X
    // (shift-by-zero would otherwise mangle these divisor values)
    SDValue One = DAG.getConstant(1, DL, VT);
    SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
    SDValue IsOne = DAG.getSetCC(DL, CCVT, N1, One, ISD::SETEQ);
    SDValue IsAllOnes = DAG.getSetCC(DL, CCVT, N1, AllOnes, ISD::SETEQ);
    SDValue IsOneOrAllOnes = DAG.getNode(ISD::OR, DL, CCVT, IsOne, IsAllOnes);
    Sra = DAG.getSelect(DL, VT, IsOneOrAllOnes, N0, Sra);

    // If dividing by a positive value, we're done. Otherwise, the result must
    // be negated.
    SDValue Zero = DAG.getConstant(0, DL, VT);
    SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, Zero, Sra);

    // FIXME: Use SELECT_CC once we improve SELECT_CC constant-folding.
    SDValue IsNeg = DAG.getSetCC(DL, CCVT, N1, Zero, ISD::SETLT);
    SDValue Res = DAG.getSelect(DL, VT, IsNeg, Sub, Sra);
    return Res;
  }

  // If integer divide is expensive and we satisfy the requirements, emit an
  // alternate sequence. Targets may check function attributes for size/speed
  // trade-offs.
  AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
  if (isConstantOrConstantVector(N1) &&
      !TLI.isIntDivCheap(N->getValueType(0), Attr))
    if (SDValue Op = BuildSDIV(N))
      return Op;

  return SDValue();
}
4449
/// Combine an ISD::UDIV node: constant folding, trivial simplifications,
/// strength reduction, and div+rem -> divrem merging.
SDValue DAGCombiner::visitUDIV(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  EVT CCVT = getSetCCResultType(VT);
  SDLoc DL(N);

  // fold (udiv c1, c2) -> c1/c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::UDIV, DL, VT, {N0, N1}))
    return C;

  // fold vector ops
  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

  // fold (udiv X, -1) -> select(X == -1, 1, 0)
  // Unsigned -1 is the max value, so only X == -1 gives a nonzero quotient.
  ConstantSDNode *N1C = isConstOrConstSplat(N1);
  if (N1C && N1C->isAllOnes())
    return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
                         DAG.getConstant(1, DL, VT),
                         DAG.getConstant(0, DL, VT));

  if (SDValue V = simplifyDivRem(N, DAG))
    return V;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  if (SDValue V = visitUDIVLike(N0, N1, N)) {
    // If the corresponding remainder node exists, update its users with
    // (Dividend - (Quotient * Divisor).
    if (SDNode *RemNode = DAG.getNodeIfExists(ISD::UREM, N->getVTList(),
                                              { N0, N1 })) {
      SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1);
      SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
      AddToWorklist(Mul.getNode());
      AddToWorklist(Sub.getNode());
      CombineTo(RemNode, Sub);
    }
    return V;
  }

  // udiv, urem -> udivrem
  // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
  // true. Otherwise, we break the simplification logic in visitREM().
  AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
  if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
    if (SDValue DivRem = useDivRem(N))
      return DivRem;

  return SDValue();
}
4503
/// Shared udiv lowering used by visitUDIV and visitREM: try to replace a
/// udiv of (N0, N1) with cheaper operations (shifts or a multiply-based
/// expansion). Returns the replacement quotient, or an empty SDValue.
SDValue DAGCombiner::visitUDIVLike(SDValue N0, SDValue N1, SDNode *N) {
  SDLoc DL(N);
  EVT VT = N->getValueType(0);

  // fold (udiv x, (1 << c)) -> x >>u c
  if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
      DAG.isKnownToBeAPowerOfTwo(N1)) {
    SDValue LogBase2 = BuildLogBase2(N1, DL);
    AddToWorklist(LogBase2.getNode());

    EVT ShiftVT = getShiftAmountTy(N0.getValueType());
    SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
    AddToWorklist(Trunc.getNode());
    return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
  }

  // fold (udiv x, (shl c, y)) -> x >>u (log2(c)+y) iff c is power of 2
  if (N1.getOpcode() == ISD::SHL) {
    SDValue N10 = N1.getOperand(0);
    if (isConstantOrConstantVector(N10, /*NoOpaques*/ true) &&
        DAG.isKnownToBeAPowerOfTwo(N10)) {
      SDValue LogBase2 = BuildLogBase2(N10, DL);
      AddToWorklist(LogBase2.getNode());

      // The add of the two shift amounts happens in the shift-amount type of
      // the original SHL.
      EVT ADDVT = N1.getOperand(1).getValueType();
      SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT);
      AddToWorklist(Trunc.getNode());
      SDValue Add = DAG.getNode(ISD::ADD, DL, ADDVT, N1.getOperand(1), Trunc);
      AddToWorklist(Add.getNode());
      return DAG.getNode(ISD::SRL, DL, VT, N0, Add);
    }
  }

  // fold (udiv x, c) -> alternate
  // (multiply-by-magic-constant expansion when real division is expensive)
  AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
  if (isConstantOrConstantVector(N1) &&
      !TLI.isIntDivCheap(N->getValueType(0), Attr))
    if (SDValue Op = BuildUDIV(N))
      return Op;

  return SDValue();
}
4546
buildOptimizedSREM(SDValue N0,SDValue N1,SDNode * N)4547 SDValue DAGCombiner::buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N) {
4548 if (!N->getFlags().hasExact() && isDivisorPowerOfTwo(N1) &&
4549 !DAG.doesNodeExist(ISD::SDIV, N->getVTList(), {N0, N1})) {
4550 // Target-specific implementation of srem x, pow2.
4551 if (SDValue Res = BuildSREMPow2(N))
4552 return Res;
4553 }
4554 return SDValue();
4555 }
4556
// handles ISD::SREM and ISD::UREM
SDValue DAGCombiner::visitREM(SDNode *N) {
  unsigned Opcode = N->getOpcode();
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  EVT CCVT = getSetCCResultType(VT);

  bool isSigned = (Opcode == ISD::SREM);
  SDLoc DL(N);

  // fold (rem c1, c2) -> c1%c2
  if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
    return C;

  // fold (urem X, -1) -> select(FX == -1, 0, FX)
  // Freeze the numerator to avoid a miscompile with an undefined value.
  if (!isSigned && llvm::isAllOnesOrAllOnesSplat(N1, /*AllowUndefs*/ false)) {
    SDValue F0 = DAG.getFreeze(N0);
    SDValue EqualsNeg1 = DAG.getSetCC(DL, CCVT, F0, N1, ISD::SETEQ);
    return DAG.getSelect(DL, VT, EqualsNeg1, DAG.getConstant(0, DL, VT), F0);
  }

  if (SDValue V = simplifyDivRem(N, DAG))
    return V;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  if (isSigned) {
    // If we know the sign bits of both operands are zero, strength reduce to a
    // urem instead. Handles (X & 0x0FFFFFFF) %s 16 -> X&15
    if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
      return DAG.getNode(ISD::UREM, DL, VT, N0, N1);
  } else {
    if (DAG.isKnownToBeAPowerOfTwo(N1)) {
      // fold (urem x, pow2) -> (and x, pow2-1)
      SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
      SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, NegOne);
      AddToWorklist(Add.getNode());
      return DAG.getNode(ISD::AND, DL, VT, N0, Add);
    }
    // fold (urem x, (shl pow2, y)) -> (and x, (add (shl pow2, y), -1))
    // fold (urem x, (lshr pow2, y)) -> (and x, (add (lshr pow2, y), -1))
    // TODO: We should sink the following into isKnownToBePowerOfTwo
    // using a OrZero parameter analogous to our handling in ValueTracking.
    if ((N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) &&
        DAG.isKnownToBeAPowerOfTwo(N1.getOperand(0))) {
      SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
      SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, NegOne);
      AddToWorklist(Add.getNode());
      return DAG.getNode(ISD::AND, DL, VT, N0, Add);
    }
  }

  AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();

  // If X/C can be simplified by the division-by-constant logic, lower
  // X%C to the equivalent of X-X/C*C.
  // Reuse the SDIVLike/UDIVLike combines - to avoid mangling nodes, the
  // speculative DIV must not cause a DIVREM conversion. We guard against this
  // by skipping the simplification if isIntDivCheap(). When div is not cheap,
  // combine will not return a DIVREM. Regardless, checking cheapness here
  // makes sense since the simplification results in fatter code.
  if (DAG.isKnownNeverZero(N1) && !TLI.isIntDivCheap(VT, Attr)) {
    if (isSigned) {
      // check if we can build faster implementation for srem
      if (SDValue OptimizedRem = buildOptimizedSREM(N0, N1, N))
        return OptimizedRem;
    }

    SDValue OptimizedDiv =
        isSigned ? visitSDIVLike(N0, N1, N) : visitUDIVLike(N0, N1, N);
    if (OptimizedDiv.getNode() && OptimizedDiv.getNode() != N) {
      // If the equivalent Div node also exists, update its users.
      unsigned DivOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
      if (SDNode *DivNode = DAG.getNodeIfExists(DivOpcode, N->getVTList(),
                                                { N0, N1 }))
        CombineTo(DivNode, OptimizedDiv);
      // rem = dividend - quotient * divisor
      SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, OptimizedDiv, N1);
      SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
      AddToWorklist(OptimizedDiv.getNode());
      AddToWorklist(Mul.getNode());
      return Sub;
    }
  }

  // sdiv, srem -> sdivrem / udiv, urem -> udivrem
  if (SDValue DivRem = useDivRem(N))
    return DivRem.getValue(1);

  return SDValue();
}
4650
/// Combine an ISD::MULHS node (high half of a signed multiply).
SDValue DAGCombiner::visitMULHS(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  // fold (mulhs c1, c2)
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::MULHS, DL, VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS.
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(ISD::MULHS, DL, N->getVTList(), N1, N0);

  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

    // fold (mulhs x, 0) -> 0
    // do not return N1, because undef node may exist.
    if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
      return DAG.getConstant(0, DL, VT);
  }

  // fold (mulhs x, 0) -> 0
  if (isNullConstant(N1))
    return N1;

  // fold (mulhs x, 1) -> (sra x, size(x)-1)
  // The high half of x*1 is just the sign-extension of x, i.e. its sign bit
  // splatted across the result.
  if (isOneConstant(N1))
    return DAG.getNode(ISD::SRA, DL, N0.getValueType(), N0,
                       DAG.getConstant(N0.getScalarValueSizeInBits() - 1, DL,
                                       getShiftAmountTy(N0.getValueType())));

  // fold (mulhs x, undef) -> 0
  if (N0.isUndef() || N1.isUndef())
    return DAG.getConstant(0, DL, VT);

  // If the type twice as wide is legal, transform the mulhs to a wider multiply
  // plus a shift.
  if (!TLI.isOperationLegalOrCustom(ISD::MULHS, VT) && VT.isSimple() &&
      !VT.isVector()) {
    MVT Simple = VT.getSimpleVT();
    unsigned SimpleSize = Simple.getSizeInBits();
    EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
    if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
      // Sign-extend both operands, multiply in the wide type, shift the high
      // half down, then truncate back to VT.
      N0 = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N0);
      N1 = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N1);
      N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
      N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
                       DAG.getConstant(SimpleSize, DL,
                                       getShiftAmountTy(N1.getValueType())));
      return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
    }
  }

  return SDValue();
}
4710
/// Combine an ISD::MULHU node (high half of an unsigned multiply).
SDValue DAGCombiner::visitMULHU(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  // fold (mulhu c1, c2)
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::MULHU, DL, VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS.
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(ISD::MULHU, DL, N->getVTList(), N1, N0);

  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

    // fold (mulhu x, 0) -> 0
    // do not return N1, because undef node may exist.
    if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
      return DAG.getConstant(0, DL, VT);
  }

  // fold (mulhu x, 0) -> 0
  if (isNullConstant(N1))
    return N1;

  // fold (mulhu x, 1) -> 0
  // (the high half of x*1 is always zero for unsigned values)
  if (isOneConstant(N1))
    return DAG.getConstant(0, DL, N0.getValueType());

  // fold (mulhu x, undef) -> 0
  if (N0.isUndef() || N1.isUndef())
    return DAG.getConstant(0, DL, VT);

  // fold (mulhu x, (1 << c)) -> x >> (bitwidth - c)
  if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
      DAG.isKnownToBeAPowerOfTwo(N1) && hasOperation(ISD::SRL, VT)) {
    unsigned NumEltBits = VT.getScalarSizeInBits();
    SDValue LogBase2 = BuildLogBase2(N1, DL);
    SDValue SRLAmt = DAG.getNode(
        ISD::SUB, DL, VT, DAG.getConstant(NumEltBits, DL, VT), LogBase2);
    EVT ShiftVT = getShiftAmountTy(N0.getValueType());
    SDValue Trunc = DAG.getZExtOrTrunc(SRLAmt, DL, ShiftVT);
    return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
  }

  // If the type twice as wide is legal, transform the mulhu to a wider multiply
  // plus a shift.
  if (!TLI.isOperationLegalOrCustom(ISD::MULHU, VT) && VT.isSimple() &&
      !VT.isVector()) {
    MVT Simple = VT.getSimpleVT();
    unsigned SimpleSize = Simple.getSizeInBits();
    EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
    if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
      // Zero-extend both operands, multiply in the wide type, shift the high
      // half down, then truncate back to VT.
      N0 = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N0);
      N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N1);
      N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
      N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
                       DAG.getConstant(SimpleSize, DL,
                                       getShiftAmountTy(N1.getValueType())));
      return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
    }
  }

  // Simplify the operands using demanded-bits information.
  // We don't have demanded bits support for MULHU so this just enables constant
  // folding based on known bits.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  return SDValue();
}
4786
visitAVG(SDNode * N)4787 SDValue DAGCombiner::visitAVG(SDNode *N) {
4788 unsigned Opcode = N->getOpcode();
4789 SDValue N0 = N->getOperand(0);
4790 SDValue N1 = N->getOperand(1);
4791 EVT VT = N->getValueType(0);
4792 SDLoc DL(N);
4793
4794 // fold (avg c1, c2)
4795 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
4796 return C;
4797
4798 // canonicalize constant to RHS.
4799 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4800 !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4801 return DAG.getNode(Opcode, DL, N->getVTList(), N1, N0);
4802
4803 if (VT.isVector()) {
4804 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4805 return FoldedVOp;
4806
4807 // fold (avgfloor x, 0) -> x >> 1
4808 if (ISD::isConstantSplatVectorAllZeros(N1.getNode())) {
4809 if (Opcode == ISD::AVGFLOORS)
4810 return DAG.getNode(ISD::SRA, DL, VT, N0, DAG.getConstant(1, DL, VT));
4811 if (Opcode == ISD::AVGFLOORU)
4812 return DAG.getNode(ISD::SRL, DL, VT, N0, DAG.getConstant(1, DL, VT));
4813 }
4814 }
4815
4816 // fold (avg x, undef) -> x
4817 if (N0.isUndef())
4818 return N1;
4819 if (N1.isUndef())
4820 return N0;
4821
4822 // TODO If we use avg for scalars anywhere, we can add (avgfl x, 0) -> x >> 1
4823
4824 return SDValue();
4825 }
4826
/// Perform optimizations common to nodes that compute two values. LoOp and HiOp
/// give the opcodes for the two computations that are being performed. Return
/// true if a simplification was made.
SDValue DAGCombiner::SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
                                                unsigned HiOp) {
  // If the high half is not needed, just compute the low half.
  bool HiExists = N->hasAnyUseOfValue(1);
  if (!HiExists && (!LegalOperations ||
                    TLI.isOperationLegalOrCustom(LoOp, N->getValueType(0)))) {
    SDValue Res = DAG.getNode(LoOp, SDLoc(N), N->getValueType(0), N->ops());
    // Both results are replaced with Res; the dead high half has no users.
    return CombineTo(N, Res, Res);
  }

  // If the low half is not needed, just compute the high half.
  bool LoExists = N->hasAnyUseOfValue(0);
  if (!LoExists && (!LegalOperations ||
                    TLI.isOperationLegalOrCustom(HiOp, N->getValueType(1)))) {
    SDValue Res = DAG.getNode(HiOp, SDLoc(N), N->getValueType(1), N->ops());
    return CombineTo(N, Res, Res);
  }

  // If both halves are used, return as it is.
  if (LoExists && HiExists)
    return SDValue();

  // If the two computed results can be simplified separately, separate them.
  // Speculatively build the single-result node and see if combine() can
  // improve it; only commit when the combined form is legal.
  if (LoExists) {
    SDValue Lo = DAG.getNode(LoOp, SDLoc(N), N->getValueType(0), N->ops());
    AddToWorklist(Lo.getNode());
    SDValue LoOpt = combine(Lo.getNode());
    if (LoOpt.getNode() && LoOpt.getNode() != Lo.getNode() &&
        (!LegalOperations ||
         TLI.isOperationLegalOrCustom(LoOpt.getOpcode(), LoOpt.getValueType())))
      return CombineTo(N, LoOpt, LoOpt);
  }

  if (HiExists) {
    SDValue Hi = DAG.getNode(HiOp, SDLoc(N), N->getValueType(1), N->ops());
    AddToWorklist(Hi.getNode());
    SDValue HiOpt = combine(Hi.getNode());
    if (HiOpt.getNode() && HiOpt != Hi &&
        (!LegalOperations ||
         TLI.isOperationLegalOrCustom(HiOpt.getOpcode(), HiOpt.getValueType())))
      return CombineTo(N, HiOpt, HiOpt);
  }

  return SDValue();
}
4875
/// Combine an ISD::SMUL_LOHI node (signed multiply producing low and high
/// halves as two results).
SDValue DAGCombiner::visitSMUL_LOHI(SDNode *N) {
  // If only one half is used, degrade to a plain MUL or MULHS.
  if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHS))
    return Res;

  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  // canonicalize constant to RHS (vector doesn't have to splat)
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(ISD::SMUL_LOHI, DL, N->getVTList(), N1, N0);

  // If the type twice as wide is legal, transform the node to a wider
  // multiply plus a shift.
  if (VT.isSimple() && !VT.isVector()) {
    MVT Simple = VT.getSimpleVT();
    unsigned SimpleSize = Simple.getSizeInBits();
    EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
    if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
      SDValue Lo = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N0);
      SDValue Hi = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N1);
      Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
      // The high result is the wide product shifted down by the narrow width.
      Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
                       DAG.getConstant(SimpleSize, DL,
                                       getShiftAmountTy(Lo.getValueType())));
      Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
      // The low result is simply the truncated wide product.
      Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
      return CombineTo(N, Lo, Hi);
    }
  }

  return SDValue();
}
4913
/// Combine an ISD::UMUL_LOHI node (unsigned multiply producing low and high
/// halves as two results).
SDValue DAGCombiner::visitUMUL_LOHI(SDNode *N) {
  // If only one half is used, degrade to a plain MUL or MULHU.
  if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHU))
    return Res;

  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  // canonicalize constant to RHS (vector doesn't have to splat)
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(ISD::UMUL_LOHI, DL, N->getVTList(), N1, N0);

  // (umul_lohi N0, 0) -> (0, 0)
  if (isNullConstant(N1)) {
    SDValue Zero = DAG.getConstant(0, DL, VT);
    return CombineTo(N, Zero, Zero);
  }

  // (umul_lohi N0, 1) -> (N0, 0)
  // (the high half of an unsigned multiply by one is always zero)
  if (isOneConstant(N1)) {
    SDValue Zero = DAG.getConstant(0, DL, VT);
    return CombineTo(N, N0, Zero);
  }

  // If the type twice as wide is legal, transform the node to a wider
  // multiply plus a shift.
  if (VT.isSimple() && !VT.isVector()) {
    MVT Simple = VT.getSimpleVT();
    unsigned SimpleSize = Simple.getSizeInBits();
    EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
    if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
      SDValue Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N0);
      SDValue Hi = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N1);
      Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
      // The high result is the wide product shifted down by the narrow width.
      Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
                       DAG.getConstant(SimpleSize, DL,
                                       getShiftAmountTy(Lo.getValueType())));
      Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
      // The low result is simply the truncated wide product.
      Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
      return CombineTo(N, Lo, Hi);
    }
  }

  return SDValue();
}
4963
// Combine SMULO/UMULO: a multiply whose second result is an overflow flag.
SDValue DAGCombiner::visitMULO(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  bool IsSigned = (ISD::SMULO == N->getOpcode());

  EVT CarryVT = N->getValueType(1);
  SDLoc DL(N);

  ConstantSDNode *N0C = isConstOrConstSplat(N0);
  ConstantSDNode *N1C = isConstOrConstSplat(N1);

  // fold operation with constant operands.
  // TODO: Move this to FoldConstantArithmetic when it supports nodes with
  // multiple results.
  if (N0C && N1C) {
    bool Overflow;
    APInt Result =
        IsSigned ? N0C->getAPIntValue().smul_ov(N1C->getAPIntValue(), Overflow)
                 : N0C->getAPIntValue().umul_ov(N1C->getAPIntValue(), Overflow);
    return CombineTo(N, DAG.getConstant(Result, DL, VT),
                     DAG.getBoolConstant(Overflow, DL, CarryVT, CarryVT));
  }

  // canonicalize constant to RHS.
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(N->getOpcode(), DL, N->getVTList(), N1, N0);

  // fold (mulo x, 0) -> 0 + no carry out
  if (isNullOrNullSplat(N1))
    return CombineTo(N, DAG.getConstant(0, DL, VT),
                     DAG.getConstant(0, DL, CarryVT));

  // (mulo x, 2) -> (addo x, x)
  // For signed, x*2 and x+x overflow identically only when the type is wider
  // than 2 bits (in i2, -2*2 and -2+-2 differ).
  // FIXME: This needs a freeze.
  if (N1C && N1C->getAPIntValue() == 2 &&
      (!IsSigned || VT.getScalarSizeInBits() > 2))
    return DAG.getNode(IsSigned ? ISD::SADDO : ISD::UADDO, DL,
                       N->getVTList(), N0, N0);

  if (IsSigned) {
    // A 1 bit SMULO overflows if both inputs are 1.
    if (VT.getScalarSizeInBits() == 1) {
      SDValue And = DAG.getNode(ISD::AND, DL, VT, N0, N1);
      return CombineTo(N, And,
                       DAG.getSetCC(DL, CarryVT, And,
                                    DAG.getConstant(0, DL, VT), ISD::SETNE));
    }

    // Multiplying n * m significant bits yields a result of n + m significant
    // bits. If the total number of significant bits does not exceed the
    // result bit width (minus 1), there is no overflow.
    unsigned SignBits = DAG.ComputeNumSignBits(N0);
    if (SignBits > 1)
      SignBits += DAG.ComputeNumSignBits(N1);
    if (SignBits > VT.getScalarSizeInBits() + 1)
      return CombineTo(N, DAG.getNode(ISD::MUL, DL, VT, N0, N1),
                       DAG.getConstant(0, DL, CarryVT));
  } else {
    // For unsigned multiply, prove no overflow by multiplying the known
    // maximum values of the operands.
    KnownBits N1Known = DAG.computeKnownBits(N1);
    KnownBits N0Known = DAG.computeKnownBits(N0);
    bool Overflow;
    (void)N0Known.getMaxValue().umul_ov(N1Known.getMaxValue(), Overflow);
    if (!Overflow)
      return CombineTo(N, DAG.getNode(ISD::MUL, DL, VT, N0, N1),
                       DAG.getConstant(0, DL, CarryVT));
  }

  return SDValue();
}
5035
// Function to calculate whether the Min/Max pair of SDNodes (potentially
// swapped around) make a signed saturate pattern, clamping to between a signed
// saturate of -2^(BW-1) and 2^(BW-1)-1, or an unsigned saturate of 0 and
// 2^BW-1. Returns the node being clamped and the bitwidth of the clamp in BW.
// Should work with both SMIN/SMAX nodes and setcc/select combo. The operands
// are the same as SimplifySelectCC. N0<N1 ? N2 : N3.
static SDValue isSaturatingMinMax(SDValue N0, SDValue N1, SDValue N2,
                                  SDValue N3, ISD::CondCode CC, unsigned &BW,
                                  bool &Unsigned) {
  // Returns ISD::SMIN/ISD::SMAX if the select-like operands form an
  // equivalent signed min/max with a constant, or 0 otherwise.
  auto isSignedMinMax = [&](SDValue N0, SDValue N1, SDValue N2, SDValue N3,
                            ISD::CondCode CC) {
    // The compare and select operand should be the same or the select operands
    // should be truncated versions of the comparison.
    if (N0 != N2 && (N2.getOpcode() != ISD::TRUNCATE || N0 != N2.getOperand(0)))
      return 0;
    // The constants need to be the same or a truncated version of each other.
    ConstantSDNode *N1C = isConstOrConstSplat(N1);
    ConstantSDNode *N3C = isConstOrConstSplat(N3);
    if (!N1C || !N3C)
      return 0;
    const APInt &C1 = N1C->getAPIntValue();
    const APInt &C2 = N3C->getAPIntValue();
    if (C1.getBitWidth() < C2.getBitWidth() || C1 != C2.sext(C1.getBitWidth()))
      return 0;
    return CC == ISD::SETLT ? ISD::SMIN : (CC == ISD::SETGT ? ISD::SMAX : 0);
  };

  // Check the initial value is a SMIN/SMAX equivalent.
  unsigned Opcode0 = isSignedMinMax(N0, N1, N2, N3, CC);
  if (!Opcode0)
    return SDValue();

  // Decompose the inner (clamping) operation into select-like operands so it
  // can be checked with the same lambda.
  SDValue N00, N01, N02, N03;
  ISD::CondCode N0CC;
  switch (N0.getOpcode()) {
  case ISD::SMIN:
  case ISD::SMAX:
    N00 = N02 = N0.getOperand(0);
    N01 = N03 = N0.getOperand(1);
    N0CC = N0.getOpcode() == ISD::SMIN ? ISD::SETLT : ISD::SETGT;
    break;
  case ISD::SELECT_CC:
    N00 = N0.getOperand(0);
    N01 = N0.getOperand(1);
    N02 = N0.getOperand(2);
    N03 = N0.getOperand(3);
    N0CC = cast<CondCodeSDNode>(N0.getOperand(4))->get();
    break;
  case ISD::SELECT:
  case ISD::VSELECT:
    if (N0.getOperand(0).getOpcode() != ISD::SETCC)
      return SDValue();
    N00 = N0.getOperand(0).getOperand(0);
    N01 = N0.getOperand(0).getOperand(1);
    N02 = N0.getOperand(1);
    N03 = N0.getOperand(2);
    N0CC = cast<CondCodeSDNode>(N0.getOperand(0).getOperand(2))->get();
    break;
  default:
    return SDValue();
  }

  // The inner operation must be the opposite min/max of the outer one.
  unsigned Opcode1 = isSignedMinMax(N00, N01, N02, N03, N0CC);
  if (!Opcode1 || Opcode0 == Opcode1)
    return SDValue();

  // Pick out which constant clamps from below (SMAX side) and which clamps
  // from above (SMIN side).
  ConstantSDNode *MinCOp = isConstOrConstSplat(Opcode0 == ISD::SMIN ? N1 : N01);
  ConstantSDNode *MaxCOp = isConstOrConstSplat(Opcode0 == ISD::SMIN ? N01 : N1);
  if (!MinCOp || !MaxCOp || MinCOp->getValueType(0) != MaxCOp->getValueType(0))
    return SDValue();

  const APInt &MinC = MinCOp->getAPIntValue();
  const APInt &MaxC = MaxCOp->getAPIntValue();
  APInt MinCPlus1 = MinC + 1;
  // Signed saturate: clamp to [-2^(BW-1), 2^(BW-1)-1].
  if (-MaxC == MinCPlus1 && MinCPlus1.isPowerOf2()) {
    BW = MinCPlus1.exactLogBase2() + 1;
    Unsigned = false;
    return N02;
  }

  // Unsigned saturate: clamp to [0, 2^BW-1].
  if (MaxC == 0 && MinCPlus1.isPowerOf2()) {
    BW = MinCPlus1.exactLogBase2();
    Unsigned = true;
    return N02;
  }

  return SDValue();
}
5124
// Fold a saturating min/max clamp around an FP_TO_SINT into a single
// FP_TO_SINT_SAT / FP_TO_UINT_SAT node of the clamped bitwidth. The operands
// follow the SimplifySelectCC convention: N0<N1 ? N2 : N3.
static SDValue PerformMinMaxFpToSatCombine(SDValue N0, SDValue N1, SDValue N2,
                                           SDValue N3, ISD::CondCode CC,
                                           SelectionDAG &DAG) {
  unsigned BW;
  bool Unsigned;
  SDValue Fp = isSaturatingMinMax(N0, N1, N2, N3, CC, BW, Unsigned);
  if (!Fp || Fp.getOpcode() != ISD::FP_TO_SINT)
    return SDValue();
  EVT FPVT = Fp.getOperand(0).getValueType();
  EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), BW);
  if (FPVT.isVector())
    NewVT = EVT::getVectorVT(*DAG.getContext(), NewVT,
                             FPVT.getVectorElementCount());
  unsigned NewOpc = Unsigned ? ISD::FP_TO_UINT_SAT : ISD::FP_TO_SINT_SAT;
  // Let the target veto the conversion (it may be slower than the clamp).
  if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(NewOpc, FPVT, NewVT))
    return SDValue();
  SDLoc DL(Fp);
  SDValue Sat = DAG.getNode(NewOpc, DL, NewVT, Fp.getOperand(0),
                            DAG.getValueType(NewVT.getScalarType()));
  // Extend/truncate the saturated value back to the expected result type.
  return Unsigned ? DAG.getZExtOrTrunc(Sat, DL, N2->getValueType(0))
                  : DAG.getSExtOrTrunc(Sat, DL, N2->getValueType(0));
}
5147
PerformUMinFpToSatCombine(SDValue N0,SDValue N1,SDValue N2,SDValue N3,ISD::CondCode CC,SelectionDAG & DAG)5148 static SDValue PerformUMinFpToSatCombine(SDValue N0, SDValue N1, SDValue N2,
5149 SDValue N3, ISD::CondCode CC,
5150 SelectionDAG &DAG) {
5151 // We are looking for UMIN(FPTOUI(X), (2^n)-1), which may have come via a
5152 // select/vselect/select_cc. The two operands pairs for the select (N2/N3) may
5153 // be truncated versions of the the setcc (N0/N1).
5154 if ((N0 != N2 &&
5155 (N2.getOpcode() != ISD::TRUNCATE || N0 != N2.getOperand(0))) ||
5156 N0.getOpcode() != ISD::FP_TO_UINT || CC != ISD::SETULT)
5157 return SDValue();
5158 ConstantSDNode *N1C = isConstOrConstSplat(N1);
5159 ConstantSDNode *N3C = isConstOrConstSplat(N3);
5160 if (!N1C || !N3C)
5161 return SDValue();
5162 const APInt &C1 = N1C->getAPIntValue();
5163 const APInt &C3 = N3C->getAPIntValue();
5164 if (!(C1 + 1).isPowerOf2() || C1.getBitWidth() < C3.getBitWidth() ||
5165 C1 != C3.zext(C1.getBitWidth()))
5166 return SDValue();
5167
5168 unsigned BW = (C1 + 1).exactLogBase2();
5169 EVT FPVT = N0.getOperand(0).getValueType();
5170 EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), BW);
5171 if (FPVT.isVector())
5172 NewVT = EVT::getVectorVT(*DAG.getContext(), NewVT,
5173 FPVT.getVectorElementCount());
5174 if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(ISD::FP_TO_UINT_SAT,
5175 FPVT, NewVT))
5176 return SDValue();
5177
5178 SDValue Sat =
5179 DAG.getNode(ISD::FP_TO_UINT_SAT, SDLoc(N0), NewVT, N0.getOperand(0),
5180 DAG.getValueType(NewVT.getScalarType()));
5181 return DAG.getZExtOrTrunc(Sat, SDLoc(N0), N3.getValueType());
5182 }
5183
// Combine SMIN/SMAX/UMIN/UMAX nodes.
SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  unsigned Opcode = N->getOpcode();
  SDLoc DL(N);

  // fold operation with constant operands.
  if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
    return C;

  // If the operands are the same, this is a no-op.
  if (N0 == N1)
    return N0;

  // canonicalize constant to RHS
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(Opcode, DL, VT, N1, N0);

  // fold vector ops
  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

  // If the sign bits are zero, flip between UMIN/UMAX and SMIN/SMAX.
  // Only do this if the current op isn't legal and the flipped is.
  if (!TLI.isOperationLegal(Opcode, VT) &&
      (N0.isUndef() || DAG.SignBitIsZero(N0)) &&
      (N1.isUndef() || DAG.SignBitIsZero(N1))) {
    unsigned AltOpcode;
    switch (Opcode) {
    case ISD::SMIN: AltOpcode = ISD::UMIN; break;
    case ISD::SMAX: AltOpcode = ISD::UMAX; break;
    case ISD::UMIN: AltOpcode = ISD::SMIN; break;
    case ISD::UMAX: AltOpcode = ISD::SMAX; break;
    default: llvm_unreachable("Unknown MINMAX opcode");
    }
    if (TLI.isOperationLegal(AltOpcode, VT))
      return DAG.getNode(AltOpcode, DL, VT, N0, N1);
  }

  // Try to fold saturating clamps of fp-to-int conversions into the
  // dedicated saturating conversion nodes.
  if (Opcode == ISD::SMIN || Opcode == ISD::SMAX)
    if (SDValue S = PerformMinMaxFpToSatCombine(
            N0, N1, N0, N1, Opcode == ISD::SMIN ? ISD::SETLT : ISD::SETGT, DAG))
      return S;
  if (Opcode == ISD::UMIN)
    if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N0, N1, ISD::SETULT, DAG))
      return S;

  // Simplify the operands using demanded-bits information.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  return SDValue();
}
5240
/// If this is a bitwise logic instruction and both operands have the same
/// opcode, try to sink the other opcode after the logic instruction.
/// i.e. logic_op (hand_op X), (hand_op Y) --> hand_op (logic_op X, Y).
SDValue DAGCombiner::hoistLogicOpWithSameOpcodeHands(SDNode *N) {
  SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  unsigned LogicOpcode = N->getOpcode();
  unsigned HandOpcode = N0.getOpcode();
  assert((LogicOpcode == ISD::AND || LogicOpcode == ISD::OR ||
          LogicOpcode == ISD::XOR) && "Expected logic opcode");
  assert(HandOpcode == N1.getOpcode() && "Bad input!");

  // Bail early if none of these transforms apply.
  if (N0.getNumOperands() == 0)
    return SDValue();

  // FIXME: We should check number of uses of the operands to not increase
  // the instruction count for all transforms.

  // Handle size-changing casts.
  SDValue X = N0.getOperand(0);
  SDValue Y = N1.getOperand(0);
  EVT XVT = X.getValueType();
  SDLoc DL(N);
  if (HandOpcode == ISD::ANY_EXTEND || HandOpcode == ISD::ZERO_EXTEND ||
      HandOpcode == ISD::SIGN_EXTEND) {
    // If both operands have other uses, this transform would create extra
    // instructions without eliminating anything.
    if (!N0.hasOneUse() && !N1.hasOneUse())
      return SDValue();
    // We need matching integer source types.
    if (XVT != Y.getValueType())
      return SDValue();
    // Don't create an illegal op during or after legalization. Don't ever
    // create an unsupported vector op.
    if ((VT.isVector() || LegalOperations) &&
        !TLI.isOperationLegalOrCustom(LogicOpcode, XVT))
      return SDValue();
    // Avoid infinite looping with PromoteIntBinOp.
    // TODO: Should we apply desirable/legal constraints to all opcodes?
    if (HandOpcode == ISD::ANY_EXTEND && LegalTypes &&
        !TLI.isTypeDesirableForOp(LogicOpcode, XVT))
      return SDValue();
    // logic_op (hand_op X), (hand_op Y) --> hand_op (logic_op X, Y)
    SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
    return DAG.getNode(HandOpcode, DL, VT, Logic);
  }

  // logic_op (truncate x), (truncate y) --> truncate (logic_op x, y)
  if (HandOpcode == ISD::TRUNCATE) {
    // If both operands have other uses, this transform would create extra
    // instructions without eliminating anything.
    if (!N0.hasOneUse() && !N1.hasOneUse())
      return SDValue();
    // We need matching source types.
    if (XVT != Y.getValueType())
      return SDValue();
    // Don't create an illegal op during or after legalization.
    if (LegalOperations && !TLI.isOperationLegal(LogicOpcode, XVT))
      return SDValue();
    // Be extra careful sinking truncate. If it's free, there's no benefit in
    // widening a binop. Also, don't create a logic op on an illegal type.
    if (TLI.isZExtFree(VT, XVT) && TLI.isTruncateFree(XVT, VT))
      return SDValue();
    if (!TLI.isTypeLegal(XVT))
      return SDValue();
    SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
    return DAG.getNode(HandOpcode, DL, VT, Logic);
  }

  // For binops SHL/SRL/SRA/AND:
  // logic_op (OP x, z), (OP y, z) --> OP (logic_op x, y), z
  if ((HandOpcode == ISD::SHL || HandOpcode == ISD::SRL ||
       HandOpcode == ISD::SRA || HandOpcode == ISD::AND) &&
      N0.getOperand(1) == N1.getOperand(1)) {
    // If either operand has other uses, this transform is not an improvement.
    if (!N0.hasOneUse() || !N1.hasOneUse())
      return SDValue();
    SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
    return DAG.getNode(HandOpcode, DL, VT, Logic, N0.getOperand(1));
  }

  // Unary ops: logic_op (bswap x), (bswap y) --> bswap (logic_op x, y)
  if (HandOpcode == ISD::BSWAP) {
    // If either operand has other uses, this transform is not an improvement.
    if (!N0.hasOneUse() || !N1.hasOneUse())
      return SDValue();
    SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
    return DAG.getNode(HandOpcode, DL, VT, Logic);
  }

  // Simplify xor/and/or (bitcast(A), bitcast(B)) -> bitcast(op (A,B))
  // Only perform this optimization up until type legalization, before
  // LegalizeVectorOprs. LegalizeVectorOprs promotes vector operations by
  // adding bitcasts. For example (xor v4i32) is promoted to (v2i64), and
  // we don't want to undo this promotion.
  // We also handle SCALAR_TO_VECTOR because xor/or/and operations are cheaper
  // on scalars.
  if ((HandOpcode == ISD::BITCAST || HandOpcode == ISD::SCALAR_TO_VECTOR) &&
      Level <= AfterLegalizeTypes) {
    // Input types must be integer and the same.
    if (XVT.isInteger() && XVT == Y.getValueType() &&
        !(VT.isVector() && TLI.isTypeLegal(VT) &&
          !XVT.isVector() && !TLI.isTypeLegal(XVT))) {
      SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
      return DAG.getNode(HandOpcode, DL, VT, Logic);
    }
  }

  // Xor/and/or are indifferent to the swizzle operation (shuffle of one value).
  // Simplify xor/and/or (shuff(A), shuff(B)) -> shuff(op (A,B))
  // If both shuffles use the same mask, and both shuffle within a single
  // vector, then it is worthwhile to move the swizzle after the operation.
  // The type-legalizer generates this pattern when loading illegal
  // vector types from memory. In many cases this allows additional shuffle
  // optimizations.
  // There are other cases where moving the shuffle after the xor/and/or
  // is profitable even if shuffles don't perform a swizzle.
  // If both shuffles use the same mask, and both shuffles have the same first
  // or second operand, then it might still be profitable to move the shuffle
  // after the xor/and/or operation.
  if (HandOpcode == ISD::VECTOR_SHUFFLE && Level < AfterLegalizeDAG) {
    auto *SVN0 = cast<ShuffleVectorSDNode>(N0);
    auto *SVN1 = cast<ShuffleVectorSDNode>(N1);
    assert(X.getValueType() == Y.getValueType() &&
           "Inputs to shuffles are not the same type");

    // Check that both shuffles use the same mask. The masks are known to be of
    // the same length because the result vector type is the same.
    // Check also that shuffles have only one use to avoid introducing extra
    // instructions.
    if (!SVN0->hasOneUse() || !SVN1->hasOneUse() ||
        !SVN0->getMask().equals(SVN1->getMask()))
      return SDValue();

    // Don't try to fold this node if it requires introducing a
    // build vector of all zeros that might be illegal at this stage.
    SDValue ShOp = N0.getOperand(1);
    if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
      ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);

    // (logic_op (shuf (A, C), shuf (B, C))) --> shuf (logic_op (A, B), C)
    if (N0.getOperand(1) == N1.getOperand(1) && ShOp.getNode()) {
      SDValue Logic = DAG.getNode(LogicOpcode, DL, VT,
                                  N0.getOperand(0), N1.getOperand(0));
      return DAG.getVectorShuffle(VT, DL, Logic, ShOp, SVN0->getMask());
    }

    // Don't try to fold this node if it requires introducing a
    // build vector of all zeros that might be illegal at this stage.
    ShOp = N0.getOperand(0);
    if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
      ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);

    // (logic_op (shuf (C, A), shuf (C, B))) --> shuf (C, logic_op (A, B))
    if (N0.getOperand(0) == N1.getOperand(0) && ShOp.getNode()) {
      SDValue Logic = DAG.getNode(LogicOpcode, DL, VT, N0.getOperand(1),
                                  N1.getOperand(1));
      return DAG.getVectorShuffle(VT, DL, ShOp, Logic, SVN0->getMask());
    }
  }

  return SDValue();
}
5404
/// Try to make (and/or setcc (LL, LR), setcc (RL, RR)) more efficient.
SDValue DAGCombiner::foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
                                       const SDLoc &DL) {
  SDValue LL, LR, RL, RR, N0CC, N1CC;
  if (!isSetCCEquivalent(N0, LL, LR, N0CC) ||
      !isSetCCEquivalent(N1, RL, RR, N1CC))
    return SDValue();

  assert(N0.getValueType() == N1.getValueType() &&
         "Unexpected operand types for bitwise logic op");
  assert(LL.getValueType() == LR.getValueType() &&
         RL.getValueType() == RR.getValueType() &&
         "Unexpected operand types for setcc");

  // If we're here post-legalization or the logic op type is not i1, the logic
  // op type must match a setcc result type. Also, all folds require new
  // operations on the left and right operands, so those types must match.
  EVT VT = N0.getValueType();
  EVT OpVT = LL.getValueType();
  if (LegalOperations || VT.getScalarType() != MVT::i1)
    if (VT != getSetCCResultType(OpVT))
      return SDValue();
  if (OpVT != RL.getValueType())
    return SDValue();

  ISD::CondCode CC0 = cast<CondCodeSDNode>(N0CC)->get();
  ISD::CondCode CC1 = cast<CondCodeSDNode>(N1CC)->get();
  bool IsInteger = OpVT.isInteger();
  // Both compares against the same constant with the same predicate: many
  // zero/all-ones tests can be merged into a single compare of an and/or.
  if (LR == RR && CC0 == CC1 && IsInteger) {
    bool IsZero = isNullOrNullSplat(LR);
    bool IsNeg1 = isAllOnesOrAllOnesSplat(LR);

    // All bits clear?
    bool AndEqZero = IsAnd && CC1 == ISD::SETEQ && IsZero;
    // All sign bits clear?
    bool AndGtNeg1 = IsAnd && CC1 == ISD::SETGT && IsNeg1;
    // Any bits set?
    bool OrNeZero = !IsAnd && CC1 == ISD::SETNE && IsZero;
    // Any sign bits set?
    bool OrLtZero = !IsAnd && CC1 == ISD::SETLT && IsZero;

    // (and (seteq X, 0), (seteq Y, 0)) --> (seteq (or X, Y), 0)
    // (and (setgt X, -1), (setgt Y, -1)) --> (setgt (or X, Y), -1)
    // (or (setne X, 0), (setne Y, 0)) --> (setne (or X, Y), 0)
    // (or (setlt X, 0), (setlt Y, 0)) --> (setlt (or X, Y), 0)
    if (AndEqZero || AndGtNeg1 || OrNeZero || OrLtZero) {
      SDValue Or = DAG.getNode(ISD::OR, SDLoc(N0), OpVT, LL, RL);
      AddToWorklist(Or.getNode());
      return DAG.getSetCC(DL, VT, Or, LR, CC1);
    }

    // All bits set?
    bool AndEqNeg1 = IsAnd && CC1 == ISD::SETEQ && IsNeg1;
    // All sign bits set?
    bool AndLtZero = IsAnd && CC1 == ISD::SETLT && IsZero;
    // Any bits clear?
    bool OrNeNeg1 = !IsAnd && CC1 == ISD::SETNE && IsNeg1;
    // Any sign bits clear?
    bool OrGtNeg1 = !IsAnd && CC1 == ISD::SETGT && IsNeg1;

    // (and (seteq X, -1), (seteq Y, -1)) --> (seteq (and X, Y), -1)
    // (and (setlt X, 0), (setlt Y, 0)) --> (setlt (and X, Y), 0)
    // (or (setne X, -1), (setne Y, -1)) --> (setne (and X, Y), -1)
    // (or (setgt X, -1), (setgt Y -1)) --> (setgt (and X, Y), -1)
    if (AndEqNeg1 || AndLtZero || OrNeNeg1 || OrGtNeg1) {
      SDValue And = DAG.getNode(ISD::AND, SDLoc(N0), OpVT, LL, RL);
      AddToWorklist(And.getNode());
      return DAG.getSetCC(DL, VT, And, LR, CC1);
    }
  }

  // TODO: What is the 'or' equivalent of this fold?
  // (and (setne X, 0), (setne X, -1)) --> (setuge (add X, 1), 2)
  if (IsAnd && LL == RL && CC0 == CC1 && OpVT.getScalarSizeInBits() > 1 &&
      IsInteger && CC0 == ISD::SETNE &&
      ((isNullConstant(LR) && isAllOnesConstant(RR)) ||
       (isAllOnesConstant(LR) && isNullConstant(RR)))) {
    SDValue One = DAG.getConstant(1, DL, OpVT);
    SDValue Two = DAG.getConstant(2, DL, OpVT);
    SDValue Add = DAG.getNode(ISD::ADD, SDLoc(N0), OpVT, LL, One);
    AddToWorklist(Add.getNode());
    return DAG.getSetCC(DL, VT, Add, Two, ISD::SETUGE);
  }

  // Try more general transforms if the predicates match and the only user of
  // the compares is the 'and' or 'or'.
  if (IsInteger && TLI.convertSetCCLogicToBitwiseLogic(OpVT) && CC0 == CC1 &&
      N0.hasOneUse() && N1.hasOneUse()) {
    // and (seteq A, B), (seteq C, D) --> seteq (or (xor A, B), (xor C, D)), 0
    // or (setne A, B), (setne C, D) --> setne (or (xor A, B), (xor C, D)), 0
    if ((IsAnd && CC1 == ISD::SETEQ) || (!IsAnd && CC1 == ISD::SETNE)) {
      SDValue XorL = DAG.getNode(ISD::XOR, SDLoc(N0), OpVT, LL, LR);
      SDValue XorR = DAG.getNode(ISD::XOR, SDLoc(N1), OpVT, RL, RR);
      SDValue Or = DAG.getNode(ISD::OR, DL, OpVT, XorL, XorR);
      SDValue Zero = DAG.getConstant(0, DL, OpVT);
      return DAG.getSetCC(DL, VT, Or, Zero, CC1);
    }

    // Turn compare of constants whose difference is 1 bit into add+and+setcc.
    if ((IsAnd && CC1 == ISD::SETNE) || (!IsAnd && CC1 == ISD::SETEQ)) {
      // Match a shared variable operand and 2 non-opaque constant operands.
      auto MatchDiffPow2 = [&](ConstantSDNode *C0, ConstantSDNode *C1) {
        // The difference of the constants must be a single bit.
        const APInt &CMax =
            APIntOps::umax(C0->getAPIntValue(), C1->getAPIntValue());
        const APInt &CMin =
            APIntOps::umin(C0->getAPIntValue(), C1->getAPIntValue());
        return !C0->isOpaque() && !C1->isOpaque() && (CMax - CMin).isPowerOf2();
      };
      if (LL == RL && ISD::matchBinaryPredicate(LR, RR, MatchDiffPow2)) {
        // and/or (setcc X, CMax, ne), (setcc X, CMin, ne/eq) -->
        // setcc ((sub X, CMin), ~(CMax - CMin)), 0, ne/eq
        SDValue Max = DAG.getNode(ISD::UMAX, DL, OpVT, LR, RR);
        SDValue Min = DAG.getNode(ISD::UMIN, DL, OpVT, LR, RR);
        SDValue Offset = DAG.getNode(ISD::SUB, DL, OpVT, LL, Min);
        SDValue Diff = DAG.getNode(ISD::SUB, DL, OpVT, Max, Min);
        SDValue Mask = DAG.getNOT(DL, Diff, OpVT);
        SDValue And = DAG.getNode(ISD::AND, DL, OpVT, Offset, Mask);
        SDValue Zero = DAG.getConstant(0, DL, OpVT);
        return DAG.getSetCC(DL, VT, And, Zero, CC0);
      }
    }
  }

  // Canonicalize equivalent operands to LL == RL.
  if (LL == RR && LR == RL) {
    CC1 = ISD::getSetCCSwappedOperands(CC1);
    std::swap(RL, RR);
  }

  // (and (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
  // (or (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
  if (LL == RL && LR == RR) {
    ISD::CondCode NewCC = IsAnd ? ISD::getSetCCAndOperation(CC0, CC1, OpVT)
                                : ISD::getSetCCOrOperation(CC0, CC1, OpVT);
    if (NewCC != ISD::SETCC_INVALID &&
        (!LegalOperations ||
         (TLI.isCondCodeLegal(NewCC, LL.getSimpleValueType()) &&
          TLI.isOperationLegal(ISD::SETCC, OpVT))))
      return DAG.getSetCC(DL, VT, LL, LR, NewCC);
  }

  return SDValue();
}
5549
/// This contains all DAGCombine rules which reduce two values combined by
/// an And operation to a single value. This makes them reusable in the context
/// of visitSELECT(). Rules involving constants are not included as
/// visitSELECT() already handles those cases.
SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1, SDNode *N) {
  EVT VT = N1.getValueType();
  SDLoc DL(N);

  // fold (and x, undef) -> 0
  if (N0.isUndef() || N1.isUndef())
    return DAG.getConstant(0, DL, VT);

  if (SDValue V = foldLogicOfSetCCs(true, N0, N1, DL))
    return V;

  // TODO: Rewrite this to return a new 'AND' instead of using CombineTo.
  if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::SRL &&
      VT.getSizeInBits() <= 64 && N0->hasOneUse()) {
    if (ConstantSDNode *ADDI = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
      if (ConstantSDNode *SRLI = dyn_cast<ConstantSDNode>(N1.getOperand(1))) {
        // Look for (and (add x, c1), (lshr y, c2)). If C1 wasn't a legal
        // immediate for an add, but it is legal if its top c2 bits are set,
        // transform the ADD so the immediate doesn't need to be materialized
        // in a register.
        APInt ADDC = ADDI->getAPIntValue();
        APInt SRLC = SRLI->getAPIntValue();
        if (ADDC.getMinSignedBits() <= 64 &&
            SRLC.ult(VT.getSizeInBits()) &&
            !TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
          APInt Mask = APInt::getHighBitsSet(VT.getSizeInBits(),
                                             SRLC.getZExtValue());
          // The shifted-right operand zeros the high bits, so setting them in
          // the add immediate doesn't change the result of the AND.
          if (DAG.MaskedValueIsZero(N0.getOperand(1), Mask)) {
            ADDC |= Mask;
            if (TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
              SDLoc DL0(N0);
              SDValue NewAdd =
                DAG.getNode(ISD::ADD, DL0, VT,
                            N0.getOperand(0), DAG.getConstant(ADDC, DL, VT));
              CombineTo(N0.getNode(), NewAdd);
              // Return N so it doesn't get rechecked!
              return SDValue(N, 0);
            }
          }
        }
      }
    }
  }

  // Reduce bit extract of low half of an integer to the narrower type.
  // (and (srl i64:x, K), KMask) ->
  //   (i64 zero_extend (and (srl (i32 (trunc i64:x)), K)), KMask)
  if (N0.getOpcode() == ISD::SRL && N0.hasOneUse()) {
    if (ConstantSDNode *CAnd = dyn_cast<ConstantSDNode>(N1)) {
      if (ConstantSDNode *CShift = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
        unsigned Size = VT.getSizeInBits();
        const APInt &AndMask = CAnd->getAPIntValue();
        unsigned ShiftBits = CShift->getZExtValue();

        // Bail out, this node will probably disappear anyway.
        if (ShiftBits == 0)
          return SDValue();

        unsigned MaskBits = AndMask.countTrailingOnes();
        EVT HalfVT = EVT::getIntegerVT(*DAG.getContext(), Size / 2);

        if (AndMask.isMask() &&
            // Required bits must not span the two halves of the integer and
            // must fit in the half size type.
            (ShiftBits + MaskBits <= Size / 2) &&
            TLI.isNarrowingProfitable(VT, HalfVT) &&
            TLI.isTypeDesirableForOp(ISD::AND, HalfVT) &&
            TLI.isTypeDesirableForOp(ISD::SRL, HalfVT) &&
            TLI.isTruncateFree(VT, HalfVT) &&
            TLI.isZExtFree(HalfVT, VT)) {
          // The isNarrowingProfitable is to avoid regressions on PPC and
          // AArch64 which match a few 64-bit bit insert / bit extract patterns
          // on downstream users of this. Those patterns could probably be
          // extended to handle extensions mixed in.

          SDValue SL(N0);
          assert(MaskBits <= Size);

          // Extracting the highest bit of the low half.
          EVT ShiftVT = TLI.getShiftAmountTy(HalfVT, DAG.getDataLayout());
          SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, HalfVT,
                                      N0.getOperand(0));

          SDValue NewMask = DAG.getConstant(AndMask.trunc(Size / 2), SL, HalfVT);
          SDValue ShiftK = DAG.getConstant(ShiftBits, SL, ShiftVT);
          SDValue Shift = DAG.getNode(ISD::SRL, SL, HalfVT, Trunc, ShiftK);
          SDValue And = DAG.getNode(ISD::AND, SL, HalfVT, Shift, NewMask);
          return DAG.getNode(ISD::ZERO_EXTEND, SL, VT, And);
        }
      }
    }
  }

  return SDValue();
}
5649
isAndLoadExtLoad(ConstantSDNode * AndC,LoadSDNode * LoadN,EVT LoadResultTy,EVT & ExtVT)5650 bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
5651 EVT LoadResultTy, EVT &ExtVT) {
5652 if (!AndC->getAPIntValue().isMask())
5653 return false;
5654
5655 unsigned ActiveBits = AndC->getAPIntValue().countTrailingOnes();
5656
5657 ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
5658 EVT LoadedVT = LoadN->getMemoryVT();
5659
5660 if (ExtVT == LoadedVT &&
5661 (!LegalOperations ||
5662 TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))) {
5663 // ZEXTLOAD will match without needing to change the size of the value being
5664 // loaded.
5665 return true;
5666 }
5667
5668 // Do not change the width of a volatile or atomic loads.
5669 if (!LoadN->isSimple())
5670 return false;
5671
5672 // Do not generate loads of non-round integer types since these can
5673 // be expensive (and would be wrong if the type is not byte sized).
5674 if (!LoadedVT.bitsGT(ExtVT) || !ExtVT.isRound())
5675 return false;
5676
5677 if (LegalOperations &&
5678 !TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))
5679 return false;
5680
5681 if (!TLI.shouldReduceLoadWidth(LoadN, ISD::ZEXTLOAD, ExtVT))
5682 return false;
5683
5684 return true;
5685 }
5686
/// Return true if it is legal and profitable to narrow the memory access of
/// \p LDST to \p MemVT at a byte offset of \p ShAmt / 8. For loads,
/// \p ExtType is the extension kind the narrowed load would use.
bool DAGCombiner::isLegalNarrowLdSt(LSBaseSDNode *LDST,
                                    ISD::LoadExtType ExtType, EVT &MemVT,
                                    unsigned ShAmt) {
  if (!LDST)
    return false;
  // Only allow byte offsets.
  if (ShAmt % 8)
    return false;

  // Do not generate loads of non-round integer types since these can
  // be expensive (and would be wrong if the type is not byte sized).
  if (!MemVT.isRound())
    return false;

  // Don't change the width of a volatile or atomic loads.
  if (!LDST->isSimple())
    return false;

  EVT LdStMemVT = LDST->getMemoryVT();

  // Bail out when changing the scalable property, since we can't be sure that
  // we're actually narrowing here.
  if (LdStMemVT.isScalableVector() != MemVT.isScalableVector())
    return false;

  // Verify that we are actually reducing a load width here.
  if (LdStMemVT.bitsLT(MemVT))
    return false;

  // Ensure that this isn't going to produce an unsupported memory access.
  if (ShAmt) {
    assert(ShAmt % 8 == 0 && "ShAmt is byte offset");
    const unsigned ByteShAmt = ShAmt / 8;
    // The effective alignment of the narrowed access shrinks with the byte
    // offset; check the target still allows it.
    const Align LDSTAlign = LDST->getAlign();
    const Align NarrowAlign = commonAlignment(LDSTAlign, ByteShAmt);
    if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
                                LDST->getAddressSpace(), NarrowAlign,
                                LDST->getMemOperand()->getFlags()))
      return false;
  }

  // It's not possible to generate a constant of extended or untyped type.
  EVT PtrType = LDST->getBasePtr().getValueType();
  if (PtrType == MVT::Untyped || PtrType.isExtended())
    return false;

  if (isa<LoadSDNode>(LDST)) {
    LoadSDNode *Load = cast<LoadSDNode>(LDST);
    // Don't transform one with multiple uses, this would require adding a new
    // load.
    if (!SDValue(Load, 0).hasOneUse())
      return false;

    if (LegalOperations &&
        !TLI.isLoadExtLegal(ExtType, Load->getValueType(0), MemVT))
      return false;

    // For the transform to be legal, the load must produce only two values
    // (the value loaded and the chain). Don't transform a pre-increment
    // load, for example, which produces an extra value. Otherwise the
    // transformation is not equivalent, and the downstream logic to replace
    // uses gets things wrong.
    if (Load->getNumValues() > 2)
      return false;

    // If the load that we're shrinking is an extload and we're not just
    // discarding the extension we can't simply shrink the load. Bail.
    // TODO: It would be possible to merge the extensions in some cases.
    if (Load->getExtensionType() != ISD::NON_EXTLOAD &&
        Load->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
      return false;

    // Give the target a final veto on profitability.
    if (!TLI.shouldReduceLoadWidth(Load, ExtType, MemVT))
      return false;
  } else {
    assert(isa<StoreSDNode>(LDST) && "It is not a Load nor a Store SDNode");
    StoreSDNode *Store = cast<StoreSDNode>(LDST);
    // Can't write outside the original store
    if (Store->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
      return false;

    if (LegalOperations &&
        !TLI.isTruncStoreLegal(Store->getValue().getValueType(), MemVT))
      return false;
  }
  return true;
}
5774
/// Recursively walk the operand tree of \p N (AND/OR/XOR nodes) looking for
/// loads that the constant \p Mask allows to be narrowed. Qualifying loads are
/// collected in \p Loads; logic nodes whose constant operands have bits
/// outside the mask are recorded in \p NodesWithConsts for later fixup; at
/// most one other leaf that must be explicitly masked is returned via
/// \p NodeToMask. Returns false if anything in the tree blocks the transform.
bool DAGCombiner::SearchForAndLoads(SDNode *N,
                                    SmallVectorImpl<LoadSDNode*> &Loads,
                                    SmallPtrSetImpl<SDNode*> &NodesWithConsts,
                                    ConstantSDNode *Mask,
                                    SDNode *&NodeToMask) {
  // Recursively search for the operands, looking for loads which can be
  // narrowed.
  for (SDValue Op : N->op_values()) {
    if (Op.getValueType().isVector())
      return false;

    // Some constants may need fixing up later if they are too large.
    if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
      // For OR/XOR, constant bits outside the mask would change the result
      // once operands are narrowed, so remember this node for fixup.
      if ((N->getOpcode() == ISD::OR || N->getOpcode() == ISD::XOR) &&
          (Mask->getAPIntValue() & C->getAPIntValue()) != C->getAPIntValue())
        NodesWithConsts.insert(N);
      continue;
    }

    if (!Op.hasOneUse())
      return false;

    switch(Op.getOpcode()) {
    case ISD::LOAD: {
      auto *Load = cast<LoadSDNode>(Op);
      EVT ExtVT;
      if (isAndLoadExtLoad(Mask, Load, Load->getValueType(0), ExtVT) &&
          isLegalNarrowLdSt(Load, ISD::ZEXTLOAD, ExtVT)) {

        // ZEXTLOAD is already small enough.
        if (Load->getExtensionType() == ISD::ZEXTLOAD &&
            ExtVT.bitsGE(Load->getMemoryVT()))
          continue;

        // Use LE to convert equal sized loads to zext.
        if (ExtVT.bitsLE(Load->getMemoryVT()))
          Loads.push_back(Load);

        continue;
      }
      return false;
    }
    case ISD::ZERO_EXTEND:
    case ISD::AssertZext: {
      unsigned ActiveBits = Mask->getAPIntValue().countTrailingOnes();
      EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
      // For AssertZext the asserted width is operand 1; for ZERO_EXTEND it is
      // the type of the value being extended.
      EVT VT = Op.getOpcode() == ISD::AssertZext ?
        cast<VTSDNode>(Op.getOperand(1))->getVT() :
        Op.getOperand(0).getValueType();

      // We can accept extending nodes if the mask is wider or an equal
      // width to the original type.
      if (ExtVT.bitsGE(VT))
        continue;
      break;
    }
    case ISD::OR:
    case ISD::XOR:
    case ISD::AND:
      if (!SearchForAndLoads(Op.getNode(), Loads, NodesWithConsts, Mask,
                             NodeToMask))
        return false;
      continue;
    }

    // Allow one node which will be masked along with any loads found.
    if (NodeToMask)
      return false;

    // Also ensure that the node to be masked only produces one data result.
    NodeToMask = Op.getNode();
    if (NodeToMask->getNumValues() > 1) {
      bool HasValue = false;
      for (unsigned i = 0, e = NodeToMask->getNumValues(); i < e; ++i) {
        MVT VT = SDValue(NodeToMask, i).getSimpleValueType();
        if (VT != MVT::Glue && VT != MVT::Other) {
          if (HasValue) {
            // More than one non-chain/non-glue result; cannot mask this node.
            NodeToMask = nullptr;
            return false;
          }
          HasValue = true;
        }
      }
      assert(HasValue && "Node to be masked has no data result?");
    }
  }
  return true;
}
5863
/// Attempt to eliminate the AND node \p N (whose RHS must be a contiguous
/// low-bit mask constant) by propagating the mask back through its operand
/// tree and narrowing the loads found there. Returns true and rewires the DAG
/// on success.
bool DAGCombiner::BackwardsPropagateMask(SDNode *N) {
  auto *Mask = dyn_cast<ConstantSDNode>(N->getOperand(1));
  if (!Mask)
    return false;

  // Only contiguous low-bit masks can be expressed as zext narrow loads.
  if (!Mask->getAPIntValue().isMask())
    return false;

  // No need to do anything if the and directly uses a load.
  if (isa<LoadSDNode>(N->getOperand(0)))
    return false;

  SmallVector<LoadSDNode*, 8> Loads;
  SmallPtrSet<SDNode*, 2> NodesWithConsts;
  SDNode *FixupNode = nullptr;
  if (SearchForAndLoads(N, Loads, NodesWithConsts, Mask, FixupNode)) {
    if (Loads.size() == 0)
      return false;

    LLVM_DEBUG(dbgs() << "Backwards propagate AND: "; N->dump());
    SDValue MaskOp = N->getOperand(1);

    // If it exists, fixup the single node we allow in the tree that needs
    // masking.
    if (FixupNode) {
      LLVM_DEBUG(dbgs() << "First, need to fix up: "; FixupNode->dump());
      SDValue And = DAG.getNode(ISD::AND, SDLoc(FixupNode),
                                FixupNode->getValueType(0),
                                SDValue(FixupNode, 0), MaskOp);
      DAG.ReplaceAllUsesOfValueWith(SDValue(FixupNode, 0), And);
      // The RAUW above also replaced the new AND's own operand with itself;
      // restore the intended operand (guard the opcode in case getNode folded
      // the AND into something else).
      if (And.getOpcode() == ISD ::AND)
        DAG.UpdateNodeOperands(And.getNode(), SDValue(FixupNode, 0), MaskOp);
    }

    // Narrow any constants that need it.
    for (auto *LogicN : NodesWithConsts) {
      SDValue Op0 = LogicN->getOperand(0);
      SDValue Op1 = LogicN->getOperand(1);

      // Canonicalize so the constant operand ends up in Op1.
      if (isa<ConstantSDNode>(Op0))
        std::swap(Op0, Op1);

      SDValue And = DAG.getNode(ISD::AND, SDLoc(Op1), Op1.getValueType(),
                                Op1, MaskOp);

      DAG.UpdateNodeOperands(LogicN, Op0, And);
    }

    // Create narrow loads.
    for (auto *Load : Loads) {
      LLVM_DEBUG(dbgs() << "Propagate AND back to: "; Load->dump());
      SDValue And = DAG.getNode(ISD::AND, SDLoc(Load), Load->getValueType(0),
                                SDValue(Load, 0), MaskOp);
      DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), And);
      // Same self-replacement fixup as for FixupNode above.
      if (And.getOpcode() == ISD ::AND)
        And = SDValue(
            DAG.UpdateNodeOperands(And.getNode(), SDValue(Load, 0), MaskOp), 0);
      SDValue NewLoad = reduceLoadWidth(And.getNode());
      assert(NewLoad &&
             "Shouldn't be masking the load if it can't be narrowed");
      CombineTo(Load, NewLoad, NewLoad.getValue(1));
    }
    // All leaves are now masked/narrowed, so the root AND is redundant.
    DAG.ReplaceAllUsesWith(N, N->getOperand(0).getNode());
    return true;
  }
  return false;
}
5931
5932 // Unfold
5933 // x & (-1 'logical shift' y)
5934 // To
5935 // (x 'opposite logical shift' y) 'logical shift' y
5936 // if it is better for performance.
unfoldExtremeBitClearingToShifts(SDNode * N)5937 SDValue DAGCombiner::unfoldExtremeBitClearingToShifts(SDNode *N) {
5938 assert(N->getOpcode() == ISD::AND);
5939
5940 SDValue N0 = N->getOperand(0);
5941 SDValue N1 = N->getOperand(1);
5942
5943 // Do we actually prefer shifts over mask?
5944 if (!TLI.shouldFoldMaskToVariableShiftPair(N0))
5945 return SDValue();
5946
5947 // Try to match (-1 '[outer] logical shift' y)
5948 unsigned OuterShift;
5949 unsigned InnerShift; // The opposite direction to the OuterShift.
5950 SDValue Y; // Shift amount.
5951 auto matchMask = [&OuterShift, &InnerShift, &Y](SDValue M) -> bool {
5952 if (!M.hasOneUse())
5953 return false;
5954 OuterShift = M->getOpcode();
5955 if (OuterShift == ISD::SHL)
5956 InnerShift = ISD::SRL;
5957 else if (OuterShift == ISD::SRL)
5958 InnerShift = ISD::SHL;
5959 else
5960 return false;
5961 if (!isAllOnesConstant(M->getOperand(0)))
5962 return false;
5963 Y = M->getOperand(1);
5964 return true;
5965 };
5966
5967 SDValue X;
5968 if (matchMask(N1))
5969 X = N0;
5970 else if (matchMask(N0))
5971 X = N1;
5972 else
5973 return SDValue();
5974
5975 SDLoc DL(N);
5976 EVT VT = N->getValueType(0);
5977
5978 // tmp = x 'opposite logical shift' y
5979 SDValue T0 = DAG.getNode(InnerShift, DL, VT, X, Y);
5980 // ret = tmp 'logical shift' y
5981 SDValue T1 = DAG.getNode(OuterShift, DL, VT, T0, Y);
5982
5983 return T1;
5984 }
5985
/// Try to replace shift/logic that tests if a bit is clear with mask + setcc.
/// For a target with a bit test, this is expected to become test + set and save
/// at least 1 instruction. Returns the replacement value or a null SDValue.
static SDValue combineShiftAnd1ToBitTest(SDNode *And, SelectionDAG &DAG) {
  assert(And->getOpcode() == ISD::AND && "Expected an 'and' op");

  // This is probably not worthwhile without a supported type.
  EVT VT = And->getValueType(0);
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  if (!TLI.isTypeLegal(VT))
    return SDValue();

  // Look through an optional extension.
  SDValue And0 = And->getOperand(0), And1 = And->getOperand(1);
  if (And0.getOpcode() == ISD::ANY_EXTEND && And0.hasOneUse())
    And0 = And0.getOperand(0);
  // Only 'and X, 1' qualifies: the AND must isolate the low bit.
  if (!isOneConstant(And1) || !And0.hasOneUse())
    return SDValue();

  SDValue Src = And0;

  // Attempt to find a 'not' op.
  // TODO: Should we favor test+set even without the 'not' op?
  bool FoundNot = false;
  if (isBitwiseNot(Src)) {
    FoundNot = true;
    Src = Src.getOperand(0);

    // Look though an optional truncation. The source operand may not be the
    // same type as the original 'and', but that is ok because we are masking
    // off everything but the low bit.
    if (Src.getOpcode() == ISD::TRUNCATE && Src.hasOneUse())
      Src = Src.getOperand(0);
  }

  // Match a shift-right by constant.
  if (Src.getOpcode() != ISD::SRL || !Src.hasOneUse())
    return SDValue();

  // We might have looked through casts that make this transform invalid.
  // TODO: If the source type is wider than the result type, do the mask and
  // compare in the source type.
  unsigned VTBitWidth = VT.getScalarSizeInBits();
  SDValue ShiftAmt = Src.getOperand(1);
  auto *ShiftAmtC = dyn_cast<ConstantSDNode>(ShiftAmt);
  if (!ShiftAmtC || !ShiftAmtC->getAPIntValue().ult(VTBitWidth))
    return SDValue();

  // Set source to shift source.
  Src = Src.getOperand(0);

  // Try again to find a 'not' op.
  // TODO: Should we favor test+set even with two 'not' ops?
  if (!FoundNot) {
    if (!isBitwiseNot(Src))
      return SDValue();
    Src = Src.getOperand(0);
  }

  // Let the target veto the rewrite if it has no suitable bit-test.
  if (!TLI.hasBitTest(Src, ShiftAmt))
    return SDValue();

  // Turn this into a bit-test pattern using mask op + setcc:
  // and (not (srl X, C)), 1 --> (and X, 1<<C) == 0
  // and (srl (not X), C)), 1 --> (and X, 1<<C) == 0
  SDLoc DL(And);
  SDValue X = DAG.getZExtOrTrunc(Src, DL, VT);
  EVT CCVT = TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
  SDValue Mask = DAG.getConstant(
      APInt::getOneBitSet(VTBitWidth, ShiftAmtC->getZExtValue()), DL, VT);
  SDValue NewAnd = DAG.getNode(ISD::AND, DL, VT, X, Mask);
  SDValue Zero = DAG.getConstant(0, DL, VT);
  SDValue Setcc = DAG.getSetCC(DL, CCVT, NewAnd, Zero, ISD::SETEQ);
  return DAG.getZExtOrTrunc(Setcc, DL, VT);
}
6061
6062 /// For targets that support usubsat, match a bit-hack form of that operation
6063 /// that ends in 'and' and convert it.
foldAndToUsubsat(SDNode * N,SelectionDAG & DAG)6064 static SDValue foldAndToUsubsat(SDNode *N, SelectionDAG &DAG) {
6065 SDValue N0 = N->getOperand(0);
6066 SDValue N1 = N->getOperand(1);
6067 EVT VT = N1.getValueType();
6068
6069 // Canonicalize SRA as operand 1.
6070 if (N0.getOpcode() == ISD::SRA)
6071 std::swap(N0, N1);
6072
6073 // xor/add with SMIN (signmask) are logically equivalent.
6074 if (N0.getOpcode() != ISD::XOR && N0.getOpcode() != ISD::ADD)
6075 return SDValue();
6076
6077 if (N1.getOpcode() != ISD::SRA || !N0.hasOneUse() || !N1.hasOneUse() ||
6078 N0.getOperand(0) != N1.getOperand(0))
6079 return SDValue();
6080
6081 unsigned BitWidth = VT.getScalarSizeInBits();
6082 ConstantSDNode *XorC = isConstOrConstSplat(N0.getOperand(1), true);
6083 ConstantSDNode *SraC = isConstOrConstSplat(N1.getOperand(1), true);
6084 if (!XorC || !XorC->getAPIntValue().isSignMask() ||
6085 !SraC || SraC->getAPIntValue() != BitWidth - 1)
6086 return SDValue();
6087
6088 // (i8 X ^ 128) & (i8 X s>> 7) --> usubsat X, 128
6089 // (i8 X + 128) & (i8 X s>> 7) --> usubsat X, 128
6090 SDLoc DL(N);
6091 SDValue SignMask = DAG.getConstant(XorC->getAPIntValue(), DL, VT);
6092 return DAG.getNode(ISD::USUBSAT, DL, VT, N0.getOperand(0), SignMask);
6093 }
6094
6095 /// Given a bitwise logic operation N with a matching bitwise logic operand,
6096 /// fold a pattern where 2 of the source operands are identically shifted
6097 /// values. For example:
6098 /// ((X0 << Y) | Z) | (X1 << Y) --> ((X0 | X1) << Y) | Z
foldLogicOfShifts(SDNode * N,SDValue LogicOp,SDValue ShiftOp,SelectionDAG & DAG)6099 static SDValue foldLogicOfShifts(SDNode *N, SDValue LogicOp, SDValue ShiftOp,
6100 SelectionDAG &DAG) {
6101 unsigned LogicOpcode = N->getOpcode();
6102 assert((LogicOpcode == ISD::AND || LogicOpcode == ISD::OR ||
6103 LogicOpcode == ISD::XOR)
6104 && "Expected bitwise logic operation");
6105
6106 if (!LogicOp.hasOneUse() || !ShiftOp.hasOneUse())
6107 return SDValue();
6108
6109 // Match another bitwise logic op and a shift.
6110 unsigned ShiftOpcode = ShiftOp.getOpcode();
6111 if (LogicOp.getOpcode() != LogicOpcode ||
6112 !(ShiftOpcode == ISD::SHL || ShiftOpcode == ISD::SRL ||
6113 ShiftOpcode == ISD::SRA))
6114 return SDValue();
6115
6116 // Match another shift op inside the first logic operand. Handle both commuted
6117 // possibilities.
6118 // LOGIC (LOGIC (SH X0, Y), Z), (SH X1, Y) --> LOGIC (SH (LOGIC X0, X1), Y), Z
6119 // LOGIC (LOGIC Z, (SH X0, Y)), (SH X1, Y) --> LOGIC (SH (LOGIC X0, X1), Y), Z
6120 SDValue X1 = ShiftOp.getOperand(0);
6121 SDValue Y = ShiftOp.getOperand(1);
6122 SDValue X0, Z;
6123 if (LogicOp.getOperand(0).getOpcode() == ShiftOpcode &&
6124 LogicOp.getOperand(0).getOperand(1) == Y) {
6125 X0 = LogicOp.getOperand(0).getOperand(0);
6126 Z = LogicOp.getOperand(1);
6127 } else if (LogicOp.getOperand(1).getOpcode() == ShiftOpcode &&
6128 LogicOp.getOperand(1).getOperand(1) == Y) {
6129 X0 = LogicOp.getOperand(1).getOperand(0);
6130 Z = LogicOp.getOperand(0);
6131 } else {
6132 return SDValue();
6133 }
6134
6135 EVT VT = N->getValueType(0);
6136 SDLoc DL(N);
6137 SDValue LogicX = DAG.getNode(LogicOpcode, DL, VT, X0, X1);
6138 SDValue NewShift = DAG.getNode(ShiftOpcode, DL, VT, LogicX, Y);
6139 return DAG.getNode(LogicOpcode, DL, VT, NewShift, Z);
6140 }
6141
/// Combine an ISD::AND node. Applies, in order: trivial identities, constant
/// folding, canonicalization, vector-specific folds, load-narrowing folds,
/// and a series of target-gated pattern rewrites. Returns the replacement
/// value, SDValue(N, 0) when N was combined in place, or a null SDValue.
SDValue DAGCombiner::visitAND(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N1.getValueType();

  // x & x --> x
  if (N0 == N1)
    return N0;

  // fold (and c1, c2) -> c1&c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::AND, SDLoc(N), VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(ISD::AND, SDLoc(N), VT, N1, N0);

  // fold vector ops
  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N)))
      return FoldedVOp;

    // fold (and x, 0) -> 0, vector edition
    if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
      // do not return N1, because undef node may exist in N1
      return DAG.getConstant(APInt::getZero(N1.getScalarValueSizeInBits()),
                             SDLoc(N), N1.getValueType());

    // fold (and x, -1) -> x, vector edition
    if (ISD::isConstantSplatVectorAllOnes(N1.getNode()))
      return N0;

    // fold (and (masked_load) (splat_vec (x, ...))) to zext_masked_load
    auto *MLoad = dyn_cast<MaskedLoadSDNode>(N0);
    ConstantSDNode *Splat = isConstOrConstSplat(N1, true, true);
    if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD && N0.hasOneUse() &&
        Splat && N1.hasOneUse()) {
      EVT LoadVT = MLoad->getMemoryVT();
      EVT ExtVT = VT;
      if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, ExtVT, LoadVT)) {
        // For this AND to be a zero extension of the masked load the elements
        // of the BuildVec must mask the bottom bits of the extended element
        // type
        uint64_t ElementSize =
            LoadVT.getVectorElementType().getScalarSizeInBits();
        if (Splat->getAPIntValue().isMask(ElementSize)) {
          // Rebuild the masked load with a ZEXTLOAD extension; the AND is
          // then redundant.
          return DAG.getMaskedLoad(
              ExtVT, SDLoc(N), MLoad->getChain(), MLoad->getBasePtr(),
              MLoad->getOffset(), MLoad->getMask(), MLoad->getPassThru(),
              LoadVT, MLoad->getMemOperand(), MLoad->getAddressingMode(),
              ISD::ZEXTLOAD, MLoad->isExpandingLoad());
        }
      }
    }
  }

  // fold (and x, -1) -> x
  if (isAllOnesConstant(N1))
    return N0;

  // if (and x, c) is known to be zero, return 0
  unsigned BitWidth = VT.getScalarSizeInBits();
  ConstantSDNode *N1C = isConstOrConstSplat(N1);
  if (N1C && DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnes(BitWidth)))
    return DAG.getConstant(0, SDLoc(N), VT);

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // reassociate and
  if (SDValue RAND = reassociateOps(ISD::AND, SDLoc(N), N0, N1, N->getFlags()))
    return RAND;

  // Try to convert a constant mask AND into a shuffle clear mask.
  if (VT.isVector())
    if (SDValue Shuffle = XformToShuffleWithZero(N))
      return Shuffle;

  if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
    return Combined;

  // fold (and (or x, C), D) -> D if (C & D) == D
  auto MatchSubset = [](ConstantSDNode *LHS, ConstantSDNode *RHS) {
    return RHS->getAPIntValue().isSubsetOf(LHS->getAPIntValue());
  };
  if (N0.getOpcode() == ISD::OR &&
      ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchSubset))
    return N1;
  // fold (and (any_ext V), c) -> (zero_ext V) if 'and' only clears top bits.
  if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
    SDValue N0Op0 = N0.getOperand(0);
    APInt Mask = ~N1C->getAPIntValue();
    Mask = Mask.trunc(N0Op0.getScalarValueSizeInBits());
    if (DAG.MaskedValueIsZero(N0Op0, Mask)) {
      SDValue Zext = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N),
                                 N0.getValueType(), N0Op0);

      // Replace uses of the AND with uses of the Zero extend node.
      CombineTo(N, Zext);

      // We actually want to replace all uses of the any_extend with the
      // zero_extend, to avoid duplicating things. This will later cause this
      // AND to be folded.
      CombineTo(N0.getNode(), Zext);
      return SDValue(N, 0); // Return N so it doesn't get rechecked!
    }
  }

  // similarly fold (and (X (load ([non_ext|any_ext|zero_ext] V))), c) ->
  // (X (load ([non_ext|zero_ext] V))) if 'and' only clears top bits which must
  // already be zero by virtue of the width of the base type of the load.
  //
  // the 'X' node here can either be nothing or an extract_vector_elt to catch
  // more cases.
  if ((N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
       N0.getValueSizeInBits() == N0.getOperand(0).getScalarValueSizeInBits() &&
       N0.getOperand(0).getOpcode() == ISD::LOAD &&
       N0.getOperand(0).getResNo() == 0) ||
      (N0.getOpcode() == ISD::LOAD && N0.getResNo() == 0)) {
    LoadSDNode *Load = cast<LoadSDNode>( (N0.getOpcode() == ISD::LOAD) ?
                                         N0 : N0.getOperand(0) );

    // Get the constant (if applicable) the zero'th operand is being ANDed with.
    // This can be a pure constant or a vector splat, in which case we treat the
    // vector as a scalar and use the splat value.
    APInt Constant = APInt::getZero(1);
    if (const ConstantSDNode *C = isConstOrConstSplat(
            N1, /*AllowUndef=*/false, /*AllowTruncation=*/true)) {
      Constant = C->getAPIntValue();
    } else if (BuildVectorSDNode *Vector = dyn_cast<BuildVectorSDNode>(N1)) {
      APInt SplatValue, SplatUndef;
      unsigned SplatBitSize;
      bool HasAnyUndefs;
      bool IsSplat = Vector->isConstantSplat(SplatValue, SplatUndef,
                                             SplatBitSize, HasAnyUndefs);
      if (IsSplat) {
        // Undef bits can contribute to a possible optimisation if set, so
        // set them.
        SplatValue |= SplatUndef;

        // The splat value may be something like "0x00FFFFFF", which means 0 for
        // the first vector value and FF for the rest, repeating. We need a mask
        // that will apply equally to all members of the vector, so AND all the
        // lanes of the constant together.
        unsigned EltBitWidth = Vector->getValueType(0).getScalarSizeInBits();

        // If the splat value has been compressed to a bitlength lower
        // than the size of the vector lane, we need to re-expand it to
        // the lane size.
        if (EltBitWidth > SplatBitSize)
          for (SplatValue = SplatValue.zextOrTrunc(EltBitWidth);
               SplatBitSize < EltBitWidth; SplatBitSize = SplatBitSize * 2)
            SplatValue |= SplatValue.shl(SplatBitSize);

        // Make sure that variable 'Constant' is only set if 'SplatBitSize' is a
        // multiple of 'BitWidth'. Otherwise, we could propagate a wrong value.
        if ((SplatBitSize % EltBitWidth) == 0) {
          Constant = APInt::getAllOnes(EltBitWidth);
          for (unsigned i = 0, n = (SplatBitSize / EltBitWidth); i < n; ++i)
            Constant &= SplatValue.extractBits(EltBitWidth, i * EltBitWidth);
        }
      }
    }

    // If we want to change an EXTLOAD to a ZEXTLOAD, ensure a ZEXTLOAD is
    // actually legal and isn't going to get expanded, else this is a false
    // optimisation.
    bool CanZextLoadProfitably = TLI.isLoadExtLegal(ISD::ZEXTLOAD,
                                                    Load->getValueType(0),
                                                    Load->getMemoryVT());

    // Resize the constant to the same size as the original memory access before
    // extension. If it is still the AllOnesValue then this AND is completely
    // unneeded.
    Constant = Constant.zextOrTrunc(Load->getMemoryVT().getScalarSizeInBits());

    bool B;
    switch (Load->getExtensionType()) {
    default: B = false; break;
    case ISD::EXTLOAD: B = CanZextLoadProfitably; break;
    case ISD::ZEXTLOAD:
    case ISD::NON_EXTLOAD: B = true; break;
    }

    if (B && Constant.isAllOnes()) {
      // If the load type was an EXTLOAD, convert to ZEXTLOAD in order to
      // preserve semantics once we get rid of the AND.
      SDValue NewLoad(Load, 0);

      // Fold the AND away. NewLoad may get replaced immediately.
      CombineTo(N, (N0.getNode() == Load) ? NewLoad : N0);

      if (Load->getExtensionType() == ISD::EXTLOAD) {
        NewLoad = DAG.getLoad(Load->getAddressingMode(), ISD::ZEXTLOAD,
                              Load->getValueType(0), SDLoc(Load),
                              Load->getChain(), Load->getBasePtr(),
                              Load->getOffset(), Load->getMemoryVT(),
                              Load->getMemOperand());
        // Replace uses of the EXTLOAD with the new ZEXTLOAD.
        if (Load->getNumValues() == 3) {
          // PRE/POST_INC loads have 3 values.
          SDValue To[] = { NewLoad.getValue(0), NewLoad.getValue(1),
                           NewLoad.getValue(2) };
          CombineTo(Load, To, 3, true);
        } else {
          CombineTo(Load, NewLoad.getValue(0), NewLoad.getValue(1));
        }
      }

      return SDValue(N, 0); // Return N so it doesn't get rechecked!
    }
  }

  // fold (and (extract_subvector (ext V)) iN_mask) where the mask exactly
  // covers the pre-extension scalar width.
  if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR && N0.hasOneUse() && N1C &&
      ISD::isExtOpcode(N0.getOperand(0).getOpcode())) {
    SDValue Ext = N0.getOperand(0);
    EVT ExtVT = Ext->getValueType(0);
    SDValue Extendee = Ext->getOperand(0);

    unsigned ScalarWidth = Extendee.getValueType().getScalarSizeInBits();
    if (N1C->getAPIntValue().isMask(ScalarWidth) &&
        (!LegalOperations || TLI.isOperationLegal(ISD::ZERO_EXTEND, ExtVT))) {
      // (and (extract_subvector (zext|anyext|sext v) _) iN_mask)
      //   => (extract_subvector (iN_zeroext v))
      SDValue ZeroExtExtendee =
          DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), ExtVT, Extendee);

      return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), VT, ZeroExtExtendee,
                         N0.getOperand(1));
    }
  }

  // fold (and (masked_gather x)) -> (zext_masked_gather x)
  if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(N0)) {
    EVT MemVT = GN0->getMemoryVT();
    EVT ScalarVT = MemVT.getScalarType();

    if (SDValue(GN0, 0).hasOneUse() &&
        isConstantSplatVectorMaskForType(N1.getNode(), ScalarVT) &&
        TLI.isVectorLoadExtDesirable(SDValue(SDValue(GN0, 0)))) {
      SDValue Ops[] = {GN0->getChain(), GN0->getPassThru(), GN0->getMask(),
                       GN0->getBasePtr(), GN0->getIndex(), GN0->getScale()};

      SDValue ZExtLoad = DAG.getMaskedGather(
          DAG.getVTList(VT, MVT::Other), MemVT, SDLoc(N), Ops,
          GN0->getMemOperand(), GN0->getIndexType(), ISD::ZEXTLOAD);

      CombineTo(N, ZExtLoad);
      AddToWorklist(ZExtLoad.getNode());
      // Avoid recheck of N.
      return SDValue(N, 0);
    }
  }

  // fold (and (load x), 255) -> (zextload x, i8)
  // fold (and (extload x, i16), 255) -> (zextload x, i8)
  if (N1C && N0.getOpcode() == ISD::LOAD && !VT.isVector())
    if (SDValue Res = reduceLoadWidth(N))
      return Res;

  if (LegalTypes) {
    // Attempt to propagate the AND back up to the leaves which, if they're
    // loads, can be combined to narrow loads and the AND node can be removed.
    // Perform after legalization so that extend nodes will already be
    // combined into the loads.
    if (BackwardsPropagateMask(N))
      return SDValue(N, 0);
  }

  if (SDValue Combined = visitANDLike(N0, N1, N))
    return Combined;

  // Simplify: (and (op x...), (op y...)) -> (op (and x, y))
  if (N0.getOpcode() == N1.getOpcode())
    if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
      return V;

  // Try both commutations of the identically-shifted-operands fold.
  if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG))
    return R;
  if (SDValue R = foldLogicOfShifts(N, N1, N0, DAG))
    return R;

  // Masking the negated extension of a boolean is just the zero-extended
  // boolean:
  // and (sub 0, zext(bool X)), 1 --> zext(bool X)
  // and (sub 0, sext(bool X)), 1 --> zext(bool X)
  //
  // Note: the SimplifyDemandedBits fold below can make an information-losing
  // transform, and then we have no way to find this better fold.
  if (N1C && N1C->isOne() && N0.getOpcode() == ISD::SUB) {
    if (isNullOrNullSplat(N0.getOperand(0))) {
      SDValue SubRHS = N0.getOperand(1);
      if (SubRHS.getOpcode() == ISD::ZERO_EXTEND &&
          SubRHS.getOperand(0).getScalarValueSizeInBits() == 1)
        return SubRHS;
      if (SubRHS.getOpcode() == ISD::SIGN_EXTEND &&
          SubRHS.getOperand(0).getScalarValueSizeInBits() == 1)
        return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, SubRHS.getOperand(0));
    }
  }

  // fold (and (sign_extend_inreg x, i16 to i32), 1) -> (and x, 1)
  // fold (and (sra)) -> (and (srl)) when possible.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  // fold (zext_inreg (extload x)) -> (zextload x)
  // fold (zext_inreg (sextload x)) -> (zextload x) iff load has one use
  if (ISD::isUNINDEXEDLoad(N0.getNode()) &&
      (ISD::isEXTLoad(N0.getNode()) ||
       (ISD::isSEXTLoad(N0.getNode()) && N0.hasOneUse()))) {
    LoadSDNode *LN0 = cast<LoadSDNode>(N0);
    EVT MemVT = LN0->getMemoryVT();
    // If we zero all the possible extended bits, then we can turn this into
    // a zextload if we are running before legalize or the operation is legal.
    unsigned ExtBitSize = N1.getScalarValueSizeInBits();
    unsigned MemBitSize = MemVT.getScalarSizeInBits();
    APInt ExtBits = APInt::getHighBitsSet(ExtBitSize, ExtBitSize - MemBitSize);
    if (DAG.MaskedValueIsZero(N1, ExtBits) &&
        ((!LegalOperations && LN0->isSimple()) ||
         TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT))) {
      SDValue ExtLoad =
          DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(N0), VT, LN0->getChain(),
                         LN0->getBasePtr(), MemVT, LN0->getMemOperand());
      AddToWorklist(N);
      CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
      return SDValue(N, 0); // Return N so it doesn't get rechecked!
    }
  }

  // fold (and (or (srl N, 8), (shl N, 8)), 0xffff) -> (srl (bswap N), const)
  if (N1C && N1C->getAPIntValue() == 0xffff && N0.getOpcode() == ISD::OR) {
    if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
                                           N0.getOperand(1), false))
      return BSwap;
  }

  if (SDValue Shifts = unfoldExtremeBitClearingToShifts(N))
    return Shifts;

  if (SDValue V = combineShiftAnd1ToBitTest(N, DAG))
    return V;

  // Recognize the following pattern:
  //
  //   AndVT = (and (sign_extend NarrowVT to AndVT) #bitmask)
  //
  // where bitmask is a mask that clears the upper bits of AndVT. The
  // number of bits in bitmask must be a power of two.
  auto IsAndZeroExtMask = [](SDValue LHS, SDValue RHS) {
    if (LHS->getOpcode() != ISD::SIGN_EXTEND)
      return false;

    auto *C = dyn_cast<ConstantSDNode>(RHS);
    if (!C)
      return false;

    if (!C->getAPIntValue().isMask(
            LHS.getOperand(0).getValueType().getFixedSizeInBits()))
      return false;

    return true;
  };

  // Replace (and (sign_extend ...) #bitmask) with (zero_extend ...).
  if (IsAndZeroExtMask(N0, N1))
    return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, N0.getOperand(0));

  if (hasOperation(ISD::USUBSAT, VT))
    if (SDValue V = foldAndToUsubsat(N, DAG))
      return V;

  return SDValue();
}
6517
/// Match (a >> 8) | (a << 8) as (bswap a) >> 16.
///
/// Recognizes the many spellings (with/without explicit AND masks on either
/// side of the shifts) of a byteswap of the low halfword of a value, and
/// rewrites it as (srl (bswap a), bitwidth-16). \p DemandHighBits is false
/// when the caller knows bits above the low halfword are not demanded, which
/// relaxes the masking requirements.
SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
                                        bool DemandHighBits) {
  // Only attempted once operations are legal, and only if BSWAP is available.
  if (!LegalOperations)
    return SDValue();

  EVT VT = N->getValueType(0);
  if (VT != MVT::i64 && VT != MVT::i32 && VT != MVT::i16)
    return SDValue();
  if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
    return SDValue();

  // Recognize (and (shl a, 8), 0xff00), (and (srl a, 8), 0xff)
  bool LookPassAnd0 = false;
  bool LookPassAnd1 = false;
  // Canonicalize: put any AND-of-SHL in N0 and any AND-of-SRL in N1 before
  // peeling off the masks.
  if (N0.getOpcode() == ISD::AND && N0.getOperand(0).getOpcode() == ISD::SRL)
    std::swap(N0, N1);
  if (N1.getOpcode() == ISD::AND && N1.getOperand(0).getOpcode() == ISD::SHL)
    std::swap(N0, N1);
  if (N0.getOpcode() == ISD::AND) {
    if (!N0->hasOneUse())
      return SDValue();
    ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
    // Also handle 0xffff since the LHS is guaranteed to have zeros there.
    // This is needed for X86.
    if (!N01C || (N01C->getZExtValue() != 0xFF00 &&
                  N01C->getZExtValue() != 0xFFFF))
      return SDValue();
    N0 = N0.getOperand(0);
    LookPassAnd0 = true;
  }

  if (N1.getOpcode() == ISD::AND) {
    if (!N1->hasOneUse())
      return SDValue();
    ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(N1.getOperand(1));
    if (!N11C || N11C->getZExtValue() != 0xFF)
      return SDValue();
    N1 = N1.getOperand(0);
    LookPassAnd1 = true;
  }

  // Canonicalize the SHL side into N0 and the SRL side into N1.
  if (N0.getOpcode() == ISD::SRL && N1.getOpcode() == ISD::SHL)
    std::swap(N0, N1);
  if (N0.getOpcode() != ISD::SHL || N1.getOpcode() != ISD::SRL)
    return SDValue();
  if (!N0->hasOneUse() || !N1->hasOneUse())
    return SDValue();

  // Both shift amounts must be the constant 8.
  ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
  ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(N1.getOperand(1));
  if (!N01C || !N11C)
    return SDValue();
  if (N01C->getZExtValue() != 8 || N11C->getZExtValue() != 8)
    return SDValue();

  // Look for (shl (and a, 0xff), 8), (srl (and a, 0xff00), 8)
  SDValue N00 = N0->getOperand(0);
  if (!LookPassAnd0 && N00.getOpcode() == ISD::AND) {
    if (!N00->hasOneUse())
      return SDValue();
    ConstantSDNode *N001C = dyn_cast<ConstantSDNode>(N00.getOperand(1));
    if (!N001C || N001C->getZExtValue() != 0xFF)
      return SDValue();
    N00 = N00.getOperand(0);
    LookPassAnd0 = true;
  }

  SDValue N10 = N1->getOperand(0);
  if (!LookPassAnd1 && N10.getOpcode() == ISD::AND) {
    if (!N10->hasOneUse())
      return SDValue();
    ConstantSDNode *N101C = dyn_cast<ConstantSDNode>(N10.getOperand(1));
    // Also allow 0xFFFF since the bits will be shifted out. This is needed
    // for X86.
    if (!N101C || (N101C->getZExtValue() != 0xFF00 &&
                   N101C->getZExtValue() != 0xFFFF))
      return SDValue();
    N10 = N10.getOperand(0);
    LookPassAnd1 = true;
  }

  // Both halves must originate from the same source value.
  if (N00 != N10)
    return SDValue();

  // Make sure everything beyond the low halfword gets set to zero since the SRL
  // 16 will clear the top bits.
  unsigned OpSizeInBits = VT.getSizeInBits();
  if (OpSizeInBits > 16) {
    // If the left-shift isn't masked out then the only way this is a bswap is
    // if all bits beyond the low 8 are 0. In that case the entire pattern
    // reduces to a left shift anyway: leave it for other parts of the combiner.
    if (DemandHighBits && !LookPassAnd0)
      return SDValue();

    // However, if the right shift isn't masked out then it might be because
    // it's not needed. See if we can spot that too. If the high bits aren't
    // demanded, we only need bits 23:16 to be zero. Otherwise, we need all
    // upper bits to be zero.
    if (!LookPassAnd1) {
      unsigned HighBit = DemandHighBits ? OpSizeInBits : 24;
      if (!DAG.MaskedValueIsZero(N10,
                                 APInt::getBitsSet(OpSizeInBits, 16, HighBit)))
        return SDValue();
    }
  }

  SDValue Res = DAG.getNode(ISD::BSWAP, SDLoc(N), VT, N00);
  if (OpSizeInBits > 16) {
    // The swapped halfword lands in the top 16 bits; shift it back down.
    SDLoc DL(N);
    Res = DAG.getNode(ISD::SRL, DL, VT, Res,
                      DAG.getConstant(OpSizeInBits - 16, DL,
                                      getShiftAmountTy(VT)));
  }
  return Res;
}
6634
/// Return true if the specified node is an element that makes up a 32-bit
/// packed halfword byteswap.
/// ((x & 0x000000ff) << 8) |
/// ((x & 0x0000ff00) >> 8) |
/// ((x & 0x00ff0000) << 8) |
/// ((x & 0xff000000) >> 8)
///
/// On success, records the node supplying the matched byte in \p Parts at the
/// index corresponding to the byte position of the mask. Fails if that slot
/// has already been claimed by a previous element.
static bool isBSwapHWordElement(SDValue N, MutableArrayRef<SDNode *> Parts) {
  if (!N->hasOneUse())
    return false;

  // Each element is an AND/SHL/SRL whose operand is itself an AND/SHL/SRL.
  unsigned Opc = N.getOpcode();
  if (Opc != ISD::AND && Opc != ISD::SHL && Opc != ISD::SRL)
    return false;

  SDValue N0 = N.getOperand(0);
  unsigned Opc0 = N0.getOpcode();
  if (Opc0 != ISD::AND && Opc0 != ISD::SHL && Opc0 != ISD::SRL)
    return false;

  ConstantSDNode *N1C = nullptr;
  // SHL or SRL: look upstream for AND mask operand
  if (Opc == ISD::AND)
    N1C = dyn_cast<ConstantSDNode>(N.getOperand(1));
  else if (Opc0 == ISD::AND)
    N1C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
  if (!N1C)
    return false;

  // Decode which source byte this mask selects.
  unsigned MaskByteOffset;
  switch (N1C->getZExtValue()) {
  default:
    return false;
  case 0xFF: MaskByteOffset = 0; break;
  case 0xFF00: MaskByteOffset = 1; break;
  case 0xFFFF:
    // In case demanded bits didn't clear the bits that will be shifted out.
    // This is needed for X86.
    if (Opc == ISD::SRL || (Opc == ISD::AND && Opc0 == ISD::SHL)) {
      MaskByteOffset = 1;
      break;
    }
    return false;
  case 0xFF0000: MaskByteOffset = 2; break;
  case 0xFF000000: MaskByteOffset = 3; break;
  }

  // Look for (x & 0xff) << 8 as well as ((x << 8) & 0xff00).
  if (Opc == ISD::AND) {
    if (MaskByteOffset == 0 || MaskByteOffset == 2) {
      // (x >> 8) & 0xff
      // (x >> 8) & 0xff0000
      if (Opc0 != ISD::SRL)
        return false;
      ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
      if (!C || C->getZExtValue() != 8)
        return false;
    } else {
      // (x << 8) & 0xff00
      // (x << 8) & 0xff000000
      if (Opc0 != ISD::SHL)
        return false;
      ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
      if (!C || C->getZExtValue() != 8)
        return false;
    }
  } else if (Opc == ISD::SHL) {
    // (x & 0xff) << 8
    // (x & 0xff0000) << 8
    if (MaskByteOffset != 0 && MaskByteOffset != 2)
      return false;
    ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1));
    if (!C || C->getZExtValue() != 8)
      return false;
  } else { // Opc == ISD::SRL
    // (x & 0xff00) >> 8
    // (x & 0xff000000) >> 8
    if (MaskByteOffset != 1 && MaskByteOffset != 3)
      return false;
    ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1));
    if (!C || C->getZExtValue() != 8)
      return false;
  }

  // Each byte slot may only be supplied once across the whole pattern.
  if (Parts[MaskByteOffset])
    return false;

  Parts[MaskByteOffset] = N0.getOperand(0).getNode();
  return true;
}
6724
6725 // Match 2 elements of a packed halfword bswap.
isBSwapHWordPair(SDValue N,MutableArrayRef<SDNode * > Parts)6726 static bool isBSwapHWordPair(SDValue N, MutableArrayRef<SDNode *> Parts) {
6727 if (N.getOpcode() == ISD::OR)
6728 return isBSwapHWordElement(N.getOperand(0), Parts) &&
6729 isBSwapHWordElement(N.getOperand(1), Parts);
6730
6731 if (N.getOpcode() == ISD::SRL && N.getOperand(0).getOpcode() == ISD::BSWAP) {
6732 ConstantSDNode *C = isConstOrConstSplat(N.getOperand(1));
6733 if (!C || C->getAPIntValue() != 16)
6734 return false;
6735 Parts[0] = Parts[1] = N.getOperand(0).getOperand(0).getNode();
6736 return true;
6737 }
6738
6739 return false;
6740 }
6741
6742 // Match this pattern:
6743 // (or (and (shl (A, 8)), 0xff00ff00), (and (srl (A, 8)), 0x00ff00ff))
6744 // And rewrite this to:
6745 // (rotr (bswap A), 16)
matchBSwapHWordOrAndAnd(const TargetLowering & TLI,SelectionDAG & DAG,SDNode * N,SDValue N0,SDValue N1,EVT VT,EVT ShiftAmountTy)6746 static SDValue matchBSwapHWordOrAndAnd(const TargetLowering &TLI,
6747 SelectionDAG &DAG, SDNode *N, SDValue N0,
6748 SDValue N1, EVT VT, EVT ShiftAmountTy) {
6749 assert(N->getOpcode() == ISD::OR && VT == MVT::i32 &&
6750 "MatchBSwapHWordOrAndAnd: expecting i32");
6751 if (!TLI.isOperationLegalOrCustom(ISD::ROTR, VT))
6752 return SDValue();
6753 if (N0.getOpcode() != ISD::AND || N1.getOpcode() != ISD::AND)
6754 return SDValue();
6755 // TODO: this is too restrictive; lifting this restriction requires more tests
6756 if (!N0->hasOneUse() || !N1->hasOneUse())
6757 return SDValue();
6758 ConstantSDNode *Mask0 = isConstOrConstSplat(N0.getOperand(1));
6759 ConstantSDNode *Mask1 = isConstOrConstSplat(N1.getOperand(1));
6760 if (!Mask0 || !Mask1)
6761 return SDValue();
6762 if (Mask0->getAPIntValue() != 0xff00ff00 ||
6763 Mask1->getAPIntValue() != 0x00ff00ff)
6764 return SDValue();
6765 SDValue Shift0 = N0.getOperand(0);
6766 SDValue Shift1 = N1.getOperand(0);
6767 if (Shift0.getOpcode() != ISD::SHL || Shift1.getOpcode() != ISD::SRL)
6768 return SDValue();
6769 ConstantSDNode *ShiftAmt0 = isConstOrConstSplat(Shift0.getOperand(1));
6770 ConstantSDNode *ShiftAmt1 = isConstOrConstSplat(Shift1.getOperand(1));
6771 if (!ShiftAmt0 || !ShiftAmt1)
6772 return SDValue();
6773 if (ShiftAmt0->getAPIntValue() != 8 || ShiftAmt1->getAPIntValue() != 8)
6774 return SDValue();
6775 if (Shift0.getOperand(0) != Shift1.getOperand(0))
6776 return SDValue();
6777
6778 SDLoc DL(N);
6779 SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT, Shift0.getOperand(0));
6780 SDValue ShAmt = DAG.getConstant(16, DL, ShiftAmountTy);
6781 return DAG.getNode(ISD::ROTR, DL, VT, BSwap, ShAmt);
6782 }
6783
/// Match a 32-bit packed halfword bswap. That is
/// ((x & 0x000000ff) << 8) |
/// ((x & 0x0000ff00) >> 8) |
/// ((x & 0x00ff0000) << 8) |
/// ((x & 0xff000000) >> 8)
/// => (rotl (bswap x), 16)
SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
  // Only after legalization, and only if BSWAP is available for i32.
  if (!LegalOperations)
    return SDValue();

  EVT VT = N->getValueType(0);
  if (VT != MVT::i32)
    return SDValue();
  if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
    return SDValue();

  // First try the compact AND-of-shifts form handled by the helper.
  if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N0, N1, VT,
                                              getShiftAmountTy(VT)))
    return BSwap;

  // Try again with commuted operands.
  if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N1, N0, VT,
                                              getShiftAmountTy(VT)))
    return BSwap;


  // Look for either
  // (or (bswaphpair), (bswaphpair))
  // (or (or (bswaphpair), (and)), (and))
  // (or (or (and), (bswaphpair)), (and))
  // Parts[i] receives the node that supplies source byte i; the matchers
  // fill it in as a side effect.
  SDNode *Parts[4] = {};

  if (isBSwapHWordPair(N0, Parts)) {
    // (or (or (and), (and)), (or (and), (and)))
    if (!isBSwapHWordPair(N1, Parts))
      return SDValue();
  } else if (N0.getOpcode() == ISD::OR) {
    // (or (or (or (and), (and)), (and)), (and))
    if (!isBSwapHWordElement(N1, Parts))
      return SDValue();
    SDValue N00 = N0.getOperand(0);
    SDValue N01 = N0.getOperand(1);
    if (!(isBSwapHWordElement(N01, Parts) && isBSwapHWordPair(N00, Parts)) &&
        !(isBSwapHWordElement(N00, Parts) && isBSwapHWordPair(N01, Parts)))
      return SDValue();
  } else {
    return SDValue();
  }

  // Make sure the parts are all coming from the same node.
  if (Parts[0] != Parts[1] || Parts[0] != Parts[2] || Parts[0] != Parts[3])
    return SDValue();

  SDLoc DL(N);
  SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT,
                              SDValue(Parts[0], 0));

  // Result of the bswap should be rotated by 16. If it's not legal, then
  // do (x << 16) | (x >> 16).
  SDValue ShAmt = DAG.getConstant(16, DL, getShiftAmountTy(VT));
  if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT))
    return DAG.getNode(ISD::ROTL, DL, VT, BSwap, ShAmt);
  if (TLI.isOperationLegalOrCustom(ISD::ROTR, VT))
    return DAG.getNode(ISD::ROTR, DL, VT, BSwap, ShAmt);
  return DAG.getNode(ISD::OR, DL, VT,
                     DAG.getNode(ISD::SHL, DL, VT, BSwap, ShAmt),
                     DAG.getNode(ISD::SRL, DL, VT, BSwap, ShAmt));
}
6852
/// This contains all DAGCombine rules which reduce two values combined by
/// an Or operation to a single value \see visitANDLike().
///
/// Called with the operands in both orders by visitOR's caller-side
/// canonicalization; returns an empty SDValue if no rule applies.
SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, SDNode *N) {
  EVT VT = N1.getValueType();
  SDLoc DL(N);

  // fold (or x, undef) -> -1
  if (!LegalOperations && (N0.isUndef() || N1.isUndef()))
    return DAG.getAllOnesConstant(DL, VT);

  // Try to merge (setcc) | (setcc) into a single comparison.
  if (SDValue V = foldLogicOfSetCCs(false, N0, N1, DL))
    return V;

  // (or (and X, C1), (and Y, C2)) -> (and (or X, Y), C3) if possible.
  if (N0.getOpcode() == ISD::AND && N1.getOpcode() == ISD::AND &&
      // Don't increase # computations.
      (N0->hasOneUse() || N1->hasOneUse())) {
    // We can only do this xform if we know that bits from X that are set in C2
    // but not in C1 are already zero. Likewise for Y.
    if (const ConstantSDNode *N0O1C =
        getAsNonOpaqueConstant(N0.getOperand(1))) {
      if (const ConstantSDNode *N1O1C =
          getAsNonOpaqueConstant(N1.getOperand(1))) {
        // We can only do this xform if we know that bits from X that are set in
        // C2 but not in C1 are already zero. Likewise for Y.
        const APInt &LHSMask = N0O1C->getAPIntValue();
        const APInt &RHSMask = N1O1C->getAPIntValue();

        if (DAG.MaskedValueIsZero(N0.getOperand(0), RHSMask&~LHSMask) &&
            DAG.MaskedValueIsZero(N1.getOperand(0), LHSMask&~RHSMask)) {
          SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
                                  N0.getOperand(0), N1.getOperand(0));
          return DAG.getNode(ISD::AND, DL, VT, X,
                             DAG.getConstant(LHSMask | RHSMask, DL, VT));
        }
      }
    }
  }

  // (or (and X, M), (and X, N)) -> (and X, (or M, N))
  if (N0.getOpcode() == ISD::AND &&
      N1.getOpcode() == ISD::AND &&
      N0.getOperand(0) == N1.getOperand(0) &&
      // Don't increase # computations.
      (N0->hasOneUse() || N1->hasOneUse())) {
    SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
                            N0.getOperand(1), N1.getOperand(1));
    return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), X);
  }

  return SDValue();
}
6905
6906 /// OR combines for which the commuted variant will be tried as well.
visitORCommutative(SelectionDAG & DAG,SDValue N0,SDValue N1,SDNode * N)6907 static SDValue visitORCommutative(SelectionDAG &DAG, SDValue N0, SDValue N1,
6908 SDNode *N) {
6909 EVT VT = N0.getValueType();
6910 if (N0.getOpcode() == ISD::AND) {
6911 SDValue N00 = N0.getOperand(0);
6912 SDValue N01 = N0.getOperand(1);
6913
6914 // fold (or (and X, (xor Y, -1)), Y) -> (or X, Y)
6915 // TODO: Set AllowUndefs = true.
6916 if (getBitwiseNotOperand(N01, N00,
6917 /* AllowUndefs */ false) == N1)
6918 return DAG.getNode(ISD::OR, SDLoc(N), VT, N00, N1);
6919
6920 // fold (or (and (xor Y, -1), X), Y) -> (or X, Y)
6921 if (getBitwiseNotOperand(N00, N01,
6922 /* AllowUndefs */ false) == N1)
6923 return DAG.getNode(ISD::OR, SDLoc(N), VT, N01, N1);
6924 }
6925
6926 if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG))
6927 return R;
6928
6929 auto peekThroughZext = [](SDValue V) {
6930 if (V->getOpcode() == ISD::ZERO_EXTEND)
6931 return V->getOperand(0);
6932 return V;
6933 };
6934
6935 // (fshl X, ?, Y) | (shl X, Y) --> fshl X, ?, Y
6936 if (N0.getOpcode() == ISD::FSHL && N1.getOpcode() == ISD::SHL &&
6937 N0.getOperand(0) == N1.getOperand(0) &&
6938 peekThroughZext(N0.getOperand(2)) == peekThroughZext(N1.getOperand(1)))
6939 return N0;
6940
6941 // (fshr ?, X, Y) | (srl X, Y) --> fshr ?, X, Y
6942 if (N0.getOpcode() == ISD::FSHR && N1.getOpcode() == ISD::SRL &&
6943 N0.getOperand(1) == N1.getOperand(0) &&
6944 peekThroughZext(N0.getOperand(2)) == peekThroughZext(N1.getOperand(1)))
6945 return N0;
6946
6947 return SDValue();
6948 }
6949
/// Combine an ISD::OR node: constant folding, algebraic identities, vector
/// shuffle-with-zero merging, bswap/rotate/load pattern matching, and
/// OR->ADD rewrites. The fold order below is deliberate; cheap identities
/// run before the expensive pattern matchers.
SDValue DAGCombiner::visitOR(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N1.getValueType();

  // x | x --> x
  if (N0 == N1)
    return N0;

  // fold (or c1, c2) -> c1|c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N), VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(ISD::OR, SDLoc(N), VT, N1, N0);

  // fold vector ops
  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N)))
      return FoldedVOp;

    // fold (or x, 0) -> x, vector edition
    if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
      return N0;

    // fold (or x, -1) -> -1, vector edition
    if (ISD::isConstantSplatVectorAllOnes(N1.getNode()))
      // do not return N1, because undef node may exist in N1
      return DAG.getAllOnesConstant(SDLoc(N), N1.getValueType());

    // fold (or (shuf A, V_0, MA), (shuf B, V_0, MB)) -> (shuf A, B, Mask)
    // Do this only if the resulting type / shuffle is legal.
    auto *SV0 = dyn_cast<ShuffleVectorSDNode>(N0);
    auto *SV1 = dyn_cast<ShuffleVectorSDNode>(N1);
    if (SV0 && SV1 && TLI.isTypeLegal(VT)) {
      bool ZeroN00 = ISD::isBuildVectorAllZeros(N0.getOperand(0).getNode());
      bool ZeroN01 = ISD::isBuildVectorAllZeros(N0.getOperand(1).getNode());
      bool ZeroN10 = ISD::isBuildVectorAllZeros(N1.getOperand(0).getNode());
      bool ZeroN11 = ISD::isBuildVectorAllZeros(N1.getOperand(1).getNode());
      // Ensure both shuffles have a zero input.
      if ((ZeroN00 != ZeroN01) && (ZeroN10 != ZeroN11)) {
        assert((!ZeroN00 || !ZeroN01) && "Both inputs zero!");
        assert((!ZeroN10 || !ZeroN11) && "Both inputs zero!");
        bool CanFold = true;
        int NumElts = VT.getVectorNumElements();
        SmallVector<int, 4> Mask(NumElts, -1);

        for (int i = 0; i != NumElts; ++i) {
          int M0 = SV0->getMaskElt(i);
          int M1 = SV1->getMaskElt(i);

          // Determine if either index is pointing to a zero vector.
          bool M0Zero = M0 < 0 || (ZeroN00 == (M0 < NumElts));
          bool M1Zero = M1 < 0 || (ZeroN10 == (M1 < NumElts));

          // If one element is zero and the otherside is undef, keep undef.
          // This also handles the case that both are undef.
          if ((M0Zero && M1 < 0) || (M1Zero && M0 < 0))
            continue;

          // Make sure only one of the elements is zero.
          if (M0Zero == M1Zero) {
            CanFold = false;
            break;
          }

          assert((M0 >= 0 || M1 >= 0) && "Undef index!");

          // We have a zero and non-zero element. If the non-zero came from
          // SV0 make the index a LHS index. If it came from SV1, make it
          // a RHS index. We need to mod by NumElts because we don't care
          // which operand it came from in the original shuffles.
          Mask[i] = M1Zero ? M0 % NumElts : (M1 % NumElts) + NumElts;
        }

        if (CanFold) {
          SDValue NewLHS = ZeroN00 ? N0.getOperand(1) : N0.getOperand(0);
          SDValue NewRHS = ZeroN10 ? N1.getOperand(1) : N1.getOperand(0);

          SDValue LegalShuffle =
              TLI.buildLegalVectorShuffle(VT, SDLoc(N), NewLHS, NewRHS,
                                          Mask, DAG);
          if (LegalShuffle)
            return LegalShuffle;
        }
      }
    }
  }

  // fold (or x, 0) -> x
  if (isNullConstant(N1))
    return N0;

  // fold (or x, -1) -> -1
  if (isAllOnesConstant(N1))
    return N1;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // fold (or x, c) -> c iff (x & ~c) == 0
  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
  if (N1C && DAG.MaskedValueIsZero(N0, ~N1C->getAPIntValue()))
    return N1;

  if (SDValue Combined = visitORLike(N0, N1, N))
    return Combined;

  if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
    return Combined;

  // Recognize halfword bswaps as (bswap + rotl 16) or (bswap + shl 16)
  if (SDValue BSwap = MatchBSwapHWord(N, N0, N1))
    return BSwap;
  if (SDValue BSwap = MatchBSwapHWordLow(N, N0, N1))
    return BSwap;

  // reassociate or
  if (SDValue ROR = reassociateOps(ISD::OR, SDLoc(N), N0, N1, N->getFlags()))
    return ROR;

  // Canonicalize (or (and X, c1), c2) -> (and (or X, c2), c1|c2)
  // iff (c1 & c2) != 0 or c1/c2 are undef.
  auto MatchIntersect = [](ConstantSDNode *C1, ConstantSDNode *C2) {
    return !C1 || !C2 || C1->getAPIntValue().intersects(C2->getAPIntValue());
  };
  if (N0.getOpcode() == ISD::AND && N0->hasOneUse() &&
      ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchIntersect, true)) {
    if (SDValue COR = DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N1), VT,
                                                 {N1, N0.getOperand(1)})) {
      SDValue IOR = DAG.getNode(ISD::OR, SDLoc(N0), VT, N0.getOperand(0), N1);
      AddToWorklist(IOR.getNode());
      return DAG.getNode(ISD::AND, SDLoc(N), VT, COR, IOR);
    }
  }

  // Order-sensitive OR folds, tried with the operands in both orders.
  if (SDValue Combined = visitORCommutative(DAG, N0, N1, N))
    return Combined;
  if (SDValue Combined = visitORCommutative(DAG, N1, N0, N))
    return Combined;

  // Simplify: (or (op x...), (op y...)) -> (op (or x, y))
  if (N0.getOpcode() == N1.getOpcode())
    if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
      return V;

  // See if this is some rotate idiom.
  if (SDValue Rot = MatchRotate(N0, N1, SDLoc(N)))
    return Rot;

  if (SDValue Load = MatchLoadCombine(N))
    return Load;

  // Simplify the operands using demanded-bits information.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  // If OR can be rewritten into ADD, try combines based on ADD.
  if ((!LegalOperations || TLI.isOperationLegal(ISD::ADD, VT)) &&
      DAG.haveNoCommonBitsSet(N0, N1))
    if (SDValue Combined = visitADDLike(N))
      return Combined;

  return SDValue();
}
7117
stripConstantMask(SelectionDAG & DAG,SDValue Op,SDValue & Mask)7118 static SDValue stripConstantMask(SelectionDAG &DAG, SDValue Op, SDValue &Mask) {
7119 if (Op.getOpcode() == ISD::AND &&
7120 DAG.isConstantIntBuildVectorOrConstantInt(Op.getOperand(1))) {
7121 Mask = Op.getOperand(1);
7122 return Op.getOperand(0);
7123 }
7124 return Op;
7125 }
7126
7127 /// Match "(X shl/srl V1) & V2" where V2 may not be present.
matchRotateHalf(SelectionDAG & DAG,SDValue Op,SDValue & Shift,SDValue & Mask)7128 static bool matchRotateHalf(SelectionDAG &DAG, SDValue Op, SDValue &Shift,
7129 SDValue &Mask) {
7130 Op = stripConstantMask(DAG, Op, Mask);
7131 if (Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SHL) {
7132 Shift = Op;
7133 return true;
7134 }
7135 return false;
7136 }
7137
/// Helper function for visitOR to extract the needed side of a rotate idiom
/// from a shl/srl/mul/udiv. This is meant to handle cases where
/// InstCombine merged some outside op with one of the shifts from
/// the rotate pattern.
/// \returns An empty \c SDValue if the needed shift couldn't be extracted.
/// Otherwise, returns an expansion of \p ExtractFrom based on the following
/// patterns:
///
///   (or (add v v) (shrl v bitwidth-1)):
///     expands (add v v) -> (shl v 1)
///
///   (or (mul v c0) (shrl (mul v c1) c2)):
///     expands (mul v c0) -> (shl (mul v c1) c3)
///
///   (or (udiv v c0) (shl (udiv v c1) c2)):
///     expands (udiv v c0) -> (shrl (udiv v c1) c3)
///
///   (or (shl v c0) (shrl (shl v c1) c2)):
///     expands (shl v c0) -> (shl (shl v c1) c3)
///
///   (or (shrl v c0) (shl (shrl v c1) c2)):
///     expands (shrl v c0) -> (shrl (shrl v c1) c3)
///
/// Such that in all cases, c3+c2==bitwidth(op v c1).
static SDValue extractShiftForRotate(SelectionDAG &DAG, SDValue OppShift,
                                     SDValue ExtractFrom, SDValue &Mask,
                                     const SDLoc &DL) {
  assert(OppShift && ExtractFrom && "Empty SDValue");
  // The shift opposite to the one we need to extract must itself be a
  // constant-amount SHL or SRL.
  if (OppShift.getOpcode() != ISD::SHL && OppShift.getOpcode() != ISD::SRL)
    return SDValue();

  ExtractFrom = stripConstantMask(DAG, ExtractFrom, Mask);

  // Value and Type of the shift.
  SDValue OppShiftLHS = OppShift.getOperand(0);
  EVT ShiftedVT = OppShiftLHS.getValueType();

  // Amount of the existing shift.
  ConstantSDNode *OppShiftCst = isConstOrConstSplat(OppShift.getOperand(1));

  // (add v v) -> (shl v 1)
  // TODO: Should this be a general DAG canonicalization?
  if (OppShift.getOpcode() == ISD::SRL && OppShiftCst &&
      ExtractFrom.getOpcode() == ISD::ADD &&
      ExtractFrom.getOperand(0) == ExtractFrom.getOperand(1) &&
      ExtractFrom.getOperand(0) == OppShiftLHS &&
      OppShiftCst->getAPIntValue() == ShiftedVT.getScalarSizeInBits() - 1)
    return DAG.getNode(ISD::SHL, DL, ShiftedVT, OppShiftLHS,
                       DAG.getShiftAmountConstant(1, ShiftedVT, DL));

  // Preconditions:
  //    (or (op0 v c0) (shiftl/r (op0 v c1) c2))
  //
  // Find opcode of the needed shift to be extracted from (op0 v c0).
  unsigned Opcode = ISD::DELETED_NODE;
  bool IsMulOrDiv = false;
  // Set Opcode and IsMulOrDiv if the extract opcode matches the needed shift
  // opcode or its arithmetic (mul or udiv) variant.
  auto SelectOpcode = [&](unsigned NeededShift, unsigned MulOrDivVariant) {
    IsMulOrDiv = ExtractFrom.getOpcode() == MulOrDivVariant;
    if (!IsMulOrDiv && ExtractFrom.getOpcode() != NeededShift)
      return false;
    Opcode = NeededShift;
    return true;
  };
  // op0 must be either the needed shift opcode or the mul/udiv equivalent
  // that the needed shift can be extracted from.
  if ((OppShift.getOpcode() != ISD::SRL || !SelectOpcode(ISD::SHL, ISD::MUL)) &&
      (OppShift.getOpcode() != ISD::SHL || !SelectOpcode(ISD::SRL, ISD::UDIV)))
    return SDValue();

  // op0 must be the same opcode on both sides, have the same LHS argument,
  // and produce the same value type.
  if (OppShiftLHS.getOpcode() != ExtractFrom.getOpcode() ||
      OppShiftLHS.getOperand(0) != ExtractFrom.getOperand(0) ||
      ShiftedVT != ExtractFrom.getValueType())
    return SDValue();

  // Constant mul/udiv/shift amount from the RHS of the shift's LHS op.
  ConstantSDNode *OppLHSCst = isConstOrConstSplat(OppShiftLHS.getOperand(1));
  // Constant mul/udiv/shift amount from the RHS of the ExtractFrom op.
  ConstantSDNode *ExtractFromCst =
      isConstOrConstSplat(ExtractFrom.getOperand(1));
  // TODO: We should be able to handle non-uniform constant vectors for these values
  // Check that we have constant values.
  if (!OppShiftCst || !OppShiftCst->getAPIntValue() ||
      !OppLHSCst || !OppLHSCst->getAPIntValue() ||
      !ExtractFromCst || !ExtractFromCst->getAPIntValue())
    return SDValue();

  // Compute the shift amount we need to extract to complete the rotate.
  const unsigned VTWidth = ShiftedVT.getScalarSizeInBits();
  if (OppShiftCst->getAPIntValue().ugt(VTWidth))
    return SDValue();
  APInt NeededShiftAmt = VTWidth - OppShiftCst->getAPIntValue();
  // Normalize the bitwidth of the two mul/udiv/shift constant operands.
  APInt ExtractFromAmt = ExtractFromCst->getAPIntValue();
  APInt OppLHSAmt = OppLHSCst->getAPIntValue();
  zeroExtendToMatch(ExtractFromAmt, OppLHSAmt);

  // Now try extract the needed shift from the ExtractFrom op and see if the
  // result matches up with the existing shift's LHS op.
  if (IsMulOrDiv) {
    // Op to extract from is a mul or udiv by a constant.
    // Check:
    //     c2 / (1 << (bitwidth(op0 v c0) - c1)) == c0
    //     c2 % (1 << (bitwidth(op0 v c0) - c1)) == 0
    const APInt ExtractDiv = APInt::getOneBitSet(ExtractFromAmt.getBitWidth(),
                                                 NeededShiftAmt.getZExtValue());
    APInt ResultAmt;
    APInt Rem;
    APInt::udivrem(ExtractFromAmt, ExtractDiv, ResultAmt, Rem);
    if (Rem != 0 || ResultAmt != OppLHSAmt)
      return SDValue();
  } else {
    // Op to extract from is a shift by a constant.
    // Check:
    //      c2 - (bitwidth(op0 v c0) - c1) == c0
    if (OppLHSAmt != ExtractFromAmt - NeededShiftAmt.zextOrTrunc(
                         ExtractFromAmt.getBitWidth()))
      return SDValue();
  }

  // Return the expanded shift op that should allow a rotate to be formed.
  EVT ShiftVT = OppShift.getOperand(1).getValueType();
  EVT ResVT = ExtractFrom.getValueType();
  SDValue NewShiftNode = DAG.getConstant(NeededShiftAmt, DL, ShiftVT);
  return DAG.getNode(Opcode, DL, ResVT, OppShiftLHS, NewShiftNode);
}
7267
// Return true if we can prove that, whenever Neg and Pos are both in the
// range [0, EltSize), Neg == (Pos == 0 ? 0 : EltSize - Pos). This means that
// for two opposing shifts shift1 and shift2 and a value X with OpBits bits:
//
// (or (shift1 X, Neg), (shift2 X, Pos))
//
// reduces to a rotate in direction shift2 by Pos or (equivalently) a rotate
// in direction shift1 by Neg. The range [0, EltSize) means that we only need
// to consider shift amounts with defined behavior.
//
// The IsRotate flag should be set when the LHS of both shifts is the same.
// Otherwise if matching a general funnel shift, it should be clear.
static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
                           SelectionDAG &DAG, bool IsRotate) {
  const auto &TLI = DAG.getTargetLoweringInfo();
  // If EltSize is a power of 2 then:
  //
  //  (a) (Pos == 0 ? 0 : EltSize - Pos) == (EltSize - Pos) & (EltSize - 1)
  //  (b) Neg == Neg & (EltSize - 1) whenever Neg is in [0, EltSize).
  //
  // So if EltSize is a power of 2 and Neg is (and Neg', EltSize-1), we check
  // for the stronger condition:
  //
  //     Neg & (EltSize - 1) == (EltSize - Pos) & (EltSize - 1)    [A]
  //
  // for all Neg and Pos. Since Neg & (EltSize - 1) == Neg' & (EltSize - 1)
  // we can just replace Neg with Neg' for the rest of the function.
  //
  // In other cases we check for the even stronger condition:
  //
  //     Neg == EltSize - Pos                                      [B]
  //
  // for all Neg and Pos. Note that the (or ...) then invokes undefined
  // behavior if Pos == 0 (and consequently Neg == EltSize).
  //
  // We could actually use [A] whenever EltSize is a power of 2, but the
  // only extra cases that it would match are those uninteresting ones
  // where Neg and Pos are never in range at the same time. E.g. for
  // EltSize == 32, using [A] would allow a Neg of the form (sub 64, Pos)
  // as well as (sub 32, Pos), but:
  //
  //     (or (shift1 X, (sub 64, Pos)), (shift2 X, Pos))
  //
  // always invokes undefined behavior for 32-bit X.
  //
  // Below, Mask == EltSize - 1 when using [A] and is all-ones otherwise.
  // This allows us to peek through any operations that only affect Mask's
  // un-demanded bits.
  //
  // NOTE: We can only do this when matching operations which won't modify the
  // least Log2(EltSize) significant bits and not a general funnel shift.
  unsigned MaskLoBits = 0;
  if (IsRotate && isPowerOf2_64(EltSize)) {
    unsigned Bits = Log2_64(EltSize);
    unsigned NegBits = Neg.getScalarValueSizeInBits();
    if (NegBits >= Bits) {
      APInt DemandedBits = APInt::getLowBitsSet(NegBits, Bits);
      // Strip operations on Neg that don't affect the low Log2(EltSize) bits.
      if (SDValue Inner =
              TLI.SimplifyMultipleUseDemandedBits(Neg, DemandedBits, DAG)) {
        Neg = Inner;
        MaskLoBits = Bits;
      }
    }
  }

  // Check whether Neg has the form (sub NegC, NegOp1) for some NegC and NegOp1.
  if (Neg.getOpcode() != ISD::SUB)
    return false;
  ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(0));
  if (!NegC)
    return false;
  SDValue NegOp1 = Neg.getOperand(1);

  // On the RHS of [A], if Pos is the result of operation on Pos' that won't
  // affect Mask's demanded bits, just replace Pos with Pos'. These operations
  // are redundant for the purpose of the equality.
  if (MaskLoBits) {
    unsigned PosBits = Pos.getScalarValueSizeInBits();
    if (PosBits >= MaskLoBits) {
      APInt DemandedBits = APInt::getLowBitsSet(PosBits, MaskLoBits);
      if (SDValue Inner =
              TLI.SimplifyMultipleUseDemandedBits(Pos, DemandedBits, DAG)) {
        Pos = Inner;
      }
    }
  }

  // The condition we need is now:
  //
  //     (NegC - NegOp1) & Mask == (EltSize - Pos) & Mask
  //
  // If NegOp1 == Pos then we need:
  //
  //     EltSize & Mask == NegC & Mask
  //
  // (because "x & Mask" is a truncation and distributes through subtraction).
  //
  // We also need to account for a potential truncation of NegOp1 if the amount
  // has already been legalized to a shift amount type.
  APInt Width;
  if ((Pos == NegOp1) ||
      (NegOp1.getOpcode() == ISD::TRUNCATE && Pos == NegOp1.getOperand(0)))
    Width = NegC->getAPIntValue();

  // Check for cases where Pos has the form (add NegOp1, PosC) for some PosC.
  // Then the condition we want to prove becomes:
  //
  //     (NegC - NegOp1) & Mask == (EltSize - (NegOp1 + PosC)) & Mask
  //
  // which, again because "x & Mask" is a truncation, becomes:
  //
  //     NegC & Mask == (EltSize - PosC) & Mask
  //  EltSize & Mask == (NegC + PosC) & Mask
  else if (Pos.getOpcode() == ISD::ADD && Pos.getOperand(0) == NegOp1) {
    if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1)))
      Width = PosC->getAPIntValue() + NegC->getAPIntValue();
    else
      return false;
  } else
    return false;

  // Now we just need to check that EltSize & Mask == Width & Mask.
  if (MaskLoBits)
    // EltSize & Mask is 0 since Mask is EltSize - 1.
    return Width.getLoBits(MaskLoBits) == 0;
  return Width == EltSize;
}
7395
7396 // A subroutine of MatchRotate used once we have found an OR of two opposite
7397 // shifts of Shifted. If Neg == <operand size> - Pos then the OR reduces
7398 // to both (PosOpcode Shifted, Pos) and (NegOpcode Shifted, Neg), with the
7399 // former being preferred if supported. InnerPos and InnerNeg are Pos and
7400 // Neg with outer conversions stripped away.
MatchRotatePosNeg(SDValue Shifted,SDValue Pos,SDValue Neg,SDValue InnerPos,SDValue InnerNeg,bool HasPos,unsigned PosOpcode,unsigned NegOpcode,const SDLoc & DL)7401 SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
7402 SDValue Neg, SDValue InnerPos,
7403 SDValue InnerNeg, bool HasPos,
7404 unsigned PosOpcode, unsigned NegOpcode,
7405 const SDLoc &DL) {
7406 // fold (or (shl x, (*ext y)),
7407 // (srl x, (*ext (sub 32, y)))) ->
7408 // (rotl x, y) or (rotr x, (sub 32, y))
7409 //
7410 // fold (or (shl x, (*ext (sub 32, y))),
7411 // (srl x, (*ext y))) ->
7412 // (rotr x, y) or (rotl x, (sub 32, y))
7413 EVT VT = Shifted.getValueType();
7414 if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits(), DAG,
7415 /*IsRotate*/ true)) {
7416 return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, Shifted,
7417 HasPos ? Pos : Neg);
7418 }
7419
7420 return SDValue();
7421 }
7422
// A subroutine of MatchRotate used once we have found an OR of two opposite
// shifts of N0 + N1. If Neg == <operand size> - Pos then the OR reduces
// to both (PosOpcode N0, N1, Pos) and (NegOpcode N0, N1, Neg), with the
// former being preferred if supported. InnerPos and InnerNeg are Pos and
// Neg with outer conversions stripped away.
// Returns the funnel-shift node on success, or an empty SDValue otherwise.
// TODO: Merge with MatchRotatePosNeg.
SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
                                       SDValue Neg, SDValue InnerPos,
                                       SDValue InnerNeg, bool HasPos,
                                       unsigned PosOpcode, unsigned NegOpcode,
                                       const SDLoc &DL) {
  EVT VT = N0.getValueType();
  unsigned EltBits = VT.getScalarSizeInBits();

  // fold (or (shl x0, (*ext y)),
  //          (srl x1, (*ext (sub 32, y)))) ->
  //   (fshl x0, x1, y) or (fshr x0, x1, (sub 32, y))
  //
  // fold (or (shl x0, (*ext (sub 32, y))),
  //          (srl x1, (*ext y))) ->
  //   (fshr x0, x1, y) or (fshl x0, x1, (sub 32, y))
  // Only treat this as a rotate (IsRotate) when both shifted operands are the
  // same value; that relaxes the matching done inside matchRotateSub.
  if (matchRotateSub(InnerPos, InnerNeg, EltBits, DAG, /*IsRotate*/ N0 == N1)) {
    return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, N0, N1,
                       HasPos ? Pos : Neg);
  }

  // Matching the shift+xor cases, we can't easily use the xor'd shift amount
  // so for now just use the PosOpcode case if its legal.
  // TODO: When can we use the NegOpcode case?
  // The (xor y, EltBits-1) trick below only equals (EltBits-1) - y when
  // EltBits is a power of two, hence the isPowerOf2_32 guard.
  if (PosOpcode == ISD::FSHL && isPowerOf2_32(EltBits)) {
    // Returns true iff Op is (BinOpc ?, C) where C is a constant/splat
    // equal to Imm.
    auto IsBinOpImm = [](SDValue Op, unsigned BinOpc, unsigned Imm) {
      if (Op.getOpcode() != BinOpc)
        return false;
      ConstantSDNode *Cst = isConstOrConstSplat(Op.getOperand(1));
      return Cst && (Cst->getAPIntValue() == Imm);
    };

    // fold (or (shl x0, y), (srl (srl x1, 1), (xor y, 31)))
    //   -> (fshl x0, x1, y)
    if (IsBinOpImm(N1, ISD::SRL, 1) &&
        IsBinOpImm(InnerNeg, ISD::XOR, EltBits - 1) &&
        InnerPos == InnerNeg.getOperand(0) &&
        TLI.isOperationLegalOrCustom(ISD::FSHL, VT)) {
      return DAG.getNode(ISD::FSHL, DL, VT, N0, N1.getOperand(0), Pos);
    }

    // fold (or (shl (shl x0, 1), (xor y, 31)), (srl x1, y))
    //   -> (fshr x0, x1, y)
    if (IsBinOpImm(N0, ISD::SHL, 1) &&
        IsBinOpImm(InnerPos, ISD::XOR, EltBits - 1) &&
        InnerNeg == InnerPos.getOperand(0) &&
        TLI.isOperationLegalOrCustom(ISD::FSHR, VT)) {
      return DAG.getNode(ISD::FSHR, DL, VT, N0.getOperand(0), N1, Neg);
    }

    // fold (or (shl (add x0, x0), (xor y, 31)), (srl x1, y))
    //   -> (fshr x0, x1, y)
    // Same as the previous pattern, with (add x0, x0) standing in for
    // (shl x0, 1).
    // TODO: Should add(x,x) -> shl(x,1) be a general DAG canonicalization?
    if (N0.getOpcode() == ISD::ADD && N0.getOperand(0) == N0.getOperand(1) &&
        IsBinOpImm(InnerPos, ISD::XOR, EltBits - 1) &&
        InnerNeg == InnerPos.getOperand(0) &&
        TLI.isOperationLegalOrCustom(ISD::FSHR, VT)) {
      return DAG.getNode(ISD::FSHR, DL, VT, N0.getOperand(0), N1, Neg);
    }
  }

  return SDValue();
}
7491
// MatchRotate - Handle an 'or' of two operands. If this is one of the many
// idioms for rotate, and if the target supports rotation instructions, generate
// a rot[lr]. This also matches funnel shift patterns, similar to rotation but
// with different shifted sources.
// Returns the replacement node on success, or an empty SDValue if no rotate
// or funnel-shift pattern was matched.
SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
  EVT VT = LHS.getValueType();

  // The target must have at least one rotate/funnel flavor.
  // We still try to match rotate by constant pre-legalization.
  // TODO: Support pre-legalization funnel-shift by constant.
  bool HasROTL = hasOperation(ISD::ROTL, VT);
  bool HasROTR = hasOperation(ISD::ROTR, VT);
  bool HasFSHL = hasOperation(ISD::FSHL, VT);
  bool HasFSHR = hasOperation(ISD::FSHR, VT);

  // If the type is going to be promoted and the target has enabled custom
  // lowering for rotate, allow matching rotate by non-constants. Only allow
  // this for scalar types.
  if (VT.isScalarInteger() && TLI.getTypeAction(*DAG.getContext(), VT) ==
                                  TargetLowering::TypePromoteInteger) {
    HasROTL |= TLI.getOperationAction(ISD::ROTL, VT) == TargetLowering::Custom;
    HasROTR |= TLI.getOperationAction(ISD::ROTR, VT) == TargetLowering::Custom;
  }

  if (LegalOperations && !HasROTL && !HasROTR && !HasFSHL && !HasFSHR)
    return SDValue();

  // Check for truncated rotate: recurse on the wider pre-truncate operands
  // and truncate the resulting rotate back down.
  if (LHS.getOpcode() == ISD::TRUNCATE && RHS.getOpcode() == ISD::TRUNCATE &&
      LHS.getOperand(0).getValueType() == RHS.getOperand(0).getValueType()) {
    assert(LHS.getValueType() == RHS.getValueType());
    if (SDValue Rot = MatchRotate(LHS.getOperand(0), RHS.getOperand(0), DL)) {
      return DAG.getNode(ISD::TRUNCATE, SDLoc(LHS), LHS.getValueType(), Rot);
    }
  }

  // Match "(X shl/srl V1) & V2" where V2 may not be present.
  SDValue LHSShift;   // The shift.
  SDValue LHSMask;    // AND value if any.
  matchRotateHalf(DAG, LHS, LHSShift, LHSMask);

  SDValue RHSShift;   // The shift.
  SDValue RHSMask;    // AND value if any.
  matchRotateHalf(DAG, RHS, RHSShift, RHSMask);

  // If neither side matched a rotate half, bail
  if (!LHSShift && !RHSShift)
    return SDValue();

  // InstCombine may have combined a constant shl, srl, mul, or udiv with one
  // side of the rotate, so try to handle that here. In all cases we need to
  // pass the matched shift from the opposite side to compute the opcode and
  // needed shift amount to extract. We still want to do this if both sides
  // matched a rotate half because one half may be a potential overshift that
  // can be broken down (ie if InstCombine merged two shl or srl ops into a
  // single one).

  // Have LHS side of the rotate, try to extract the needed shift from the RHS.
  if (LHSShift)
    if (SDValue NewRHSShift =
            extractShiftForRotate(DAG, LHSShift, RHS, RHSMask, DL))
      RHSShift = NewRHSShift;
  // Have RHS side of the rotate, try to extract the needed shift from the LHS.
  if (RHSShift)
    if (SDValue NewLHSShift =
            extractShiftForRotate(DAG, RHSShift, LHS, LHSMask, DL))
      LHSShift = NewLHSShift;

  // If a side is still missing, nothing else we can do.
  if (!RHSShift || !LHSShift)
    return SDValue();

  // At this point we've matched or extracted a shift op on each side.

  if (LHSShift.getOpcode() == RHSShift.getOpcode())
    return SDValue(); // Shifts must disagree.

  // Canonicalize shl to left side in a shl/srl pair.
  if (RHSShift.getOpcode() == ISD::SHL) {
    std::swap(LHS, RHS);
    std::swap(LHSShift, RHSShift);
    std::swap(LHSMask, RHSMask);
  }

  // Something has gone wrong - we've lost the shl/srl pair - bail.
  if (LHSShift.getOpcode() != ISD::SHL || RHSShift.getOpcode() != ISD::SRL)
    return SDValue();

  unsigned EltSizeInBits = VT.getScalarSizeInBits();
  SDValue LHSShiftArg = LHSShift.getOperand(0);
  SDValue LHSShiftAmt = LHSShift.getOperand(1);
  SDValue RHSShiftArg = RHSShift.getOperand(0);
  SDValue RHSShiftAmt = RHSShift.getOperand(1);

  // True when the two constant shift amounts cover the element exactly
  // (C1 + C2 == EltSizeInBits), the precondition for rotate/funnel folds.
  auto MatchRotateSum = [EltSizeInBits](ConstantSDNode *LHS,
                                        ConstantSDNode *RHS) {
    return (LHS->getAPIntValue() + RHS->getAPIntValue()) == EltSizeInBits;
  };

  auto ApplyMasks = [&](SDValue Res) {
    // If there is an AND of either shifted operand, apply it to the result.
    if (LHSMask.getNode() || RHSMask.getNode()) {
      SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
      SDValue Mask = AllOnes;

      // Widen each operand's mask with the bits contributed by the opposite
      // shift, so the combined mask preserves those bits in the result.
      if (LHSMask.getNode()) {
        SDValue RHSBits = DAG.getNode(ISD::SRL, DL, VT, AllOnes, RHSShiftAmt);
        Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
                           DAG.getNode(ISD::OR, DL, VT, LHSMask, RHSBits));
      }
      if (RHSMask.getNode()) {
        SDValue LHSBits = DAG.getNode(ISD::SHL, DL, VT, AllOnes, LHSShiftAmt);
        Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
                           DAG.getNode(ISD::OR, DL, VT, RHSMask, LHSBits));
      }

      Res = DAG.getNode(ISD::AND, DL, VT, Res, Mask);
    }

    return Res;
  };

  // TODO: Support pre-legalization funnel-shift by constant.
  bool IsRotate = LHSShift.getOperand(0) == RHSShift.getOperand(0);
  if (!IsRotate && !(HasFSHL || HasFSHR)) {
    if (TLI.isTypeLegal(VT) && LHS.hasOneUse() && RHS.hasOneUse() &&
        ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) {
      // Look for a disguised rotate by constant.
      // The common shifted operand X may be hidden inside another 'or'.
      SDValue X, Y;
      // Sets X/Y when Or is a one-use (or X, Y) with CommonOp on either side.
      auto matchOr = [&X, &Y](SDValue Or, SDValue CommonOp) {
        if (!Or.hasOneUse() || Or.getOpcode() != ISD::OR)
          return false;
        if (CommonOp == Or.getOperand(0)) {
          X = CommonOp;
          Y = Or.getOperand(1);
          return true;
        }
        if (CommonOp == Or.getOperand(1)) {
          X = CommonOp;
          Y = Or.getOperand(0);
          return true;
        }
        return false;
      };

      SDValue Res;
      if (matchOr(LHSShiftArg, RHSShiftArg)) {
        // (shl (X | Y), C1) | (srl X, C2) --> (rotl X, C1) | (shl Y, C1)
        SDValue RotX = DAG.getNode(ISD::ROTL, DL, VT, X, LHSShiftAmt);
        SDValue ShlY = DAG.getNode(ISD::SHL, DL, VT, Y, LHSShiftAmt);
        Res = DAG.getNode(ISD::OR, DL, VT, RotX, ShlY);
      } else if (matchOr(RHSShiftArg, LHSShiftArg)) {
        // (shl X, C1) | (srl (X | Y), C2) --> (rotl X, C1) | (srl Y, C2)
        SDValue RotX = DAG.getNode(ISD::ROTL, DL, VT, X, LHSShiftAmt);
        SDValue SrlY = DAG.getNode(ISD::SRL, DL, VT, Y, RHSShiftAmt);
        Res = DAG.getNode(ISD::OR, DL, VT, RotX, SrlY);
      } else {
        return SDValue();
      }

      return ApplyMasks(Res);
    }

    return SDValue(); // Requires funnel shift support.
  }

  // fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1)
  // fold (or (shl x, C1), (srl x, C2)) -> (rotr x, C2)
  // fold (or (shl x, C1), (srl y, C2)) -> (fshl x, y, C1)
  // fold (or (shl x, C1), (srl y, C2)) -> (fshr x, y, C2)
  // iff C1+C2 == EltSizeInBits
  if (ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) {
    SDValue Res;
    if (IsRotate && (HasROTL || HasROTR || !(HasFSHL || HasFSHR))) {
      bool UseROTL = !LegalOperations || HasROTL;
      Res = DAG.getNode(UseROTL ? ISD::ROTL : ISD::ROTR, DL, VT, LHSShiftArg,
                        UseROTL ? LHSShiftAmt : RHSShiftAmt);
    } else {
      bool UseFSHL = !LegalOperations || HasFSHL;
      Res = DAG.getNode(UseFSHL ? ISD::FSHL : ISD::FSHR, DL, VT, LHSShiftArg,
                        RHSShiftArg, UseFSHL ? LHSShiftAmt : RHSShiftAmt);
    }

    return ApplyMasks(Res);
  }

  // Even pre-legalization, we can't easily rotate/funnel-shift by a variable
  // shift.
  if (!HasROTL && !HasROTR && !HasFSHL && !HasFSHR)
    return SDValue();

  // If there is a mask here, and we have a variable shift, we can't be sure
  // that we're masking out the right stuff.
  if (LHSMask.getNode() || RHSMask.getNode())
    return SDValue();

  // If the shift amount is sign/zext/any-extended just peel it off.
  SDValue LExtOp0 = LHSShiftAmt;
  SDValue RExtOp0 = RHSShiftAmt;
  if ((LHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
       LHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
       LHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
       LHSShiftAmt.getOpcode() == ISD::TRUNCATE) &&
      (RHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
       RHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
       RHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
       RHSShiftAmt.getOpcode() == ISD::TRUNCATE)) {
    LExtOp0 = LHSShiftAmt.getOperand(0);
    RExtOp0 = RHSShiftAmt.getOperand(0);
  }

  // Try rotate matching first (when both shifts read the same source), in
  // both pos/neg orientations.
  if (IsRotate && (HasROTL || HasROTR)) {
    SDValue TryL =
        MatchRotatePosNeg(LHSShiftArg, LHSShiftAmt, RHSShiftAmt, LExtOp0,
                          RExtOp0, HasROTL, ISD::ROTL, ISD::ROTR, DL);
    if (TryL)
      return TryL;

    SDValue TryR =
        MatchRotatePosNeg(RHSShiftArg, RHSShiftAmt, LHSShiftAmt, RExtOp0,
                          LExtOp0, HasROTR, ISD::ROTR, ISD::ROTL, DL);
    if (TryR)
      return TryR;
  }

  // Fall back to funnel-shift matching, again in both orientations.
  SDValue TryL =
      MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, LHSShiftAmt, RHSShiftAmt,
                        LExtOp0, RExtOp0, HasFSHL, ISD::FSHL, ISD::FSHR, DL);
  if (TryL)
    return TryL;

  SDValue TryR =
      MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, RHSShiftAmt, LHSShiftAmt,
                        RExtOp0, LExtOp0, HasFSHR, ISD::FSHR, ISD::FSHL, DL);
  if (TryR)
    return TryR;

  return SDValue();
}
7732
namespace {

/// Represents known origin of an individual byte in load combine pattern. The
/// value of the byte is either constant zero or comes from memory.
struct ByteProvider {
  // For constant zero providers Load is set to nullptr. For memory providers
  // Load represents the node which loads the byte from memory.
  // ByteOffset is the offset of the byte in the value produced by the load.
  LoadSDNode *Load = nullptr;
  unsigned ByteOffset = 0;

  ByteProvider() = default;

  /// Build a provider for a byte read from memory: byte \p ByteOffset of the
  /// value produced by \p Load.
  static ByteProvider getMemory(LoadSDNode *Load, unsigned ByteOffset) {
    return ByteProvider(Load, ByteOffset);
  }

  /// Build a provider for a byte known to be constant zero.
  static ByteProvider getConstantZero() { return ByteProvider(nullptr, 0); }

  bool isConstantZero() const { return !Load; }
  bool isMemory() const { return Load; }

  // Providers are equal when they name the same byte of the same load (or
  // are both constant zero).
  bool operator==(const ByteProvider &Other) const {
    return Other.Load == Load && Other.ByteOffset == ByteOffset;
  }

private:
  // Private so clients must go through the named factory functions above.
  ByteProvider(LoadSDNode *Load, unsigned ByteOffset)
      : Load(Load), ByteOffset(ByteOffset) {}
};

} // end anonymous namespace
7765
/// Recursively traverses the expression calculating the origin of the requested
/// byte of the given value. Returns None if the provider can't be calculated.
///
/// For all the values except the root of the expression verifies that the value
/// has exactly one use and if it's not true return None. This way if the origin
/// of the byte is returned it's guaranteed that the values which contribute to
/// the byte are not used outside of this expression.
///
/// Because the parts of the expression are not allowed to have more than one
/// use this function iterates over trees, not DAGs. So it never visits the same
/// node more than once.
///
/// \param Op    the (scalar integer) value being analyzed.
/// \param Index the byte of Op whose provider is requested (0 = LSB).
/// \param Depth current recursion depth, used only to bound the search.
/// \param Root  true only for the top-level call, which is allowed to have
///              multiple uses.
static const Optional<ByteProvider>
calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
                      bool Root = false) {
  // Typical i64 by i8 pattern requires recursion up to 8 calls depth
  if (Depth == 10)
    return None;

  // Interior nodes must be single-use so the whole tree can be replaced.
  if (!Root && !Op.hasOneUse())
    return None;

  assert(Op.getValueType().isScalarInteger() && "can't handle other types");
  unsigned BitWidth = Op.getValueSizeInBits();
  if (BitWidth % 8 != 0)
    return None;
  unsigned ByteWidth = BitWidth / 8;
  assert(Index < ByteWidth && "invalid index requested");
  (void) ByteWidth;

  switch (Op.getOpcode()) {
  case ISD::OR: {
    // An OR can only merge disjoint bytes: exactly one side may provide a
    // non-zero byte at this index.
    auto LHS = calculateByteProvider(Op->getOperand(0), Index, Depth + 1);
    if (!LHS)
      return None;
    auto RHS = calculateByteProvider(Op->getOperand(1), Index, Depth + 1);
    if (!RHS)
      return None;

    if (LHS->isConstantZero())
      return RHS;
    if (RHS->isConstantZero())
      return LHS;
    return None;
  }
  case ISD::SHL: {
    // A byte-aligned constant left shift moves providers up; bytes below the
    // shift amount are zero.
    auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
    if (!ShiftOp)
      return None;

    uint64_t BitShift = ShiftOp->getZExtValue();
    if (BitShift % 8 != 0)
      return None;
    uint64_t ByteShift = BitShift / 8;

    return Index < ByteShift
               ? ByteProvider::getConstantZero()
               : calculateByteProvider(Op->getOperand(0), Index - ByteShift,
                                       Depth + 1);
  }
  case ISD::ANY_EXTEND:
  case ISD::SIGN_EXTEND:
  case ISD::ZERO_EXTEND: {
    // Bytes inside the narrow source come from it; bytes above it are zero
    // only for a zero-extend (sign/any extend bytes are unknown).
    SDValue NarrowOp = Op->getOperand(0);
    unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
    if (NarrowBitWidth % 8 != 0)
      return None;
    uint64_t NarrowByteWidth = NarrowBitWidth / 8;

    if (Index >= NarrowByteWidth)
      return Op.getOpcode() == ISD::ZERO_EXTEND
                 ? Optional<ByteProvider>(ByteProvider::getConstantZero())
                 : None;
    return calculateByteProvider(NarrowOp, Index, Depth + 1);
  }
  case ISD::BSWAP:
    // BSWAP mirrors the byte order, so look up the mirrored index.
    return calculateByteProvider(Op->getOperand(0), ByteWidth - Index - 1,
                                 Depth + 1);
  case ISD::LOAD: {
    auto L = cast<LoadSDNode>(Op.getNode());
    // Only simple (non-atomic, non-volatile), unindexed loads are safe to
    // fold into a wider load.
    if (!L->isSimple() || L->isIndexed())
      return None;

    unsigned NarrowBitWidth = L->getMemoryVT().getSizeInBits();
    if (NarrowBitWidth % 8 != 0)
      return None;
    uint64_t NarrowByteWidth = NarrowBitWidth / 8;

    // Bytes beyond the memory width are zero only for a zext load.
    if (Index >= NarrowByteWidth)
      return L->getExtensionType() == ISD::ZEXTLOAD
                 ? Optional<ByteProvider>(ByteProvider::getConstantZero())
                 : None;
    return ByteProvider::getMemory(L, Index);
  }
  }

  return None;
}
7863
/// Memory position of the i-th least significant byte of a BW-byte value
/// laid out little-endian: it is simply byte i. BW is unused here and kept
/// only for symmetry with bigEndianByteAt.
static unsigned littleEndianByteAt(unsigned BW, unsigned i) {
  (void)BW;
  return i;
}
7867
/// Memory position of the i-th least significant byte of a BW-byte value
/// laid out big-endian: the byte order is mirrored.
static unsigned bigEndianByteAt(unsigned BW, unsigned i) {
  return BW - 1 - i;
}
7871
7872 // Check if the bytes offsets we are looking at match with either big or
7873 // little endian value loaded. Return true for big endian, false for little
7874 // endian, and None if match failed.
isBigEndian(const ArrayRef<int64_t> ByteOffsets,int64_t FirstOffset)7875 static Optional<bool> isBigEndian(const ArrayRef<int64_t> ByteOffsets,
7876 int64_t FirstOffset) {
7877 // The endian can be decided only when it is 2 bytes at least.
7878 unsigned Width = ByteOffsets.size();
7879 if (Width < 2)
7880 return None;
7881
7882 bool BigEndian = true, LittleEndian = true;
7883 for (unsigned i = 0; i < Width; i++) {
7884 int64_t CurrentByteOffset = ByteOffsets[i] - FirstOffset;
7885 LittleEndian &= CurrentByteOffset == littleEndianByteAt(Width, i);
7886 BigEndian &= CurrentByteOffset == bigEndianByteAt(Width, i);
7887 if (!BigEndian && !LittleEndian)
7888 return None;
7889 }
7890
7891 assert((BigEndian != LittleEndian) && "It should be either big endian or"
7892 "little endian");
7893 return BigEndian;
7894 }
7895
stripTruncAndExt(SDValue Value)7896 static SDValue stripTruncAndExt(SDValue Value) {
7897 switch (Value.getOpcode()) {
7898 case ISD::TRUNCATE:
7899 case ISD::ZERO_EXTEND:
7900 case ISD::SIGN_EXTEND:
7901 case ISD::ANY_EXTEND:
7902 return stripTruncAndExt(Value.getOperand(0));
7903 }
7904 return Value;
7905 }
7906
/// Match a pattern where a wide type scalar value is stored by several narrow
/// stores. Fold it into a single store or a BSWAP and a store if the targets
/// supports it.
///
/// Assuming little endian target:
///  i8 *p = ...
///  i32 val = ...
///  p[0] = (val >> 0) & 0xFF;
///  p[1] = (val >> 8) & 0xFF;
///  p[2] = (val >> 16) & 0xFF;
///  p[3] = (val >> 24) & 0xFF;
/// =>
///  *((i32)p) = val;
///
///  i8 *p = ...
///  i32 val = ...
///  p[0] = (val >> 24) & 0xFF;
///  p[1] = (val >> 16) & 0xFF;
///  p[2] = (val >> 8) & 0xFF;
///  p[3] = (val >> 0) & 0xFF;
/// =>
///  *((i32)p) = BSWAP(val);
///
/// \param N the most recent (head-of-chain) narrow store.
/// \returns the merged wide store on success, otherwise an empty SDValue.
SDValue DAGCombiner::mergeTruncStores(StoreSDNode *N) {
  // The matching looks for "store (trunc x)" patterns that appear early but are
  // likely to be replaced by truncating store nodes during combining.
  // TODO: If there is evidence that running this later would help, this
  //       limitation could be removed. Legality checks may need to be added
  //       for the created store and optional bswap/rotate.
  if (LegalOperations || OptLevel == CodeGenOpt::None)
    return SDValue();

  // We only handle merging simple stores of 1-4 bytes.
  // TODO: Allow unordered atomics when wider type is legal (see D66309)
  EVT MemVT = N->getMemoryVT();
  if (!(MemVT == MVT::i8 || MemVT == MVT::i16 || MemVT == MVT::i32) ||
      !N->isSimple() || N->isIndexed())
    return SDValue();

  // Collect all of the stores in the chain, walking up through the chain
  // operands starting from N.
  SDValue Chain = N->getChain();
  SmallVector<StoreSDNode *, 8> Stores = {N};
  while (auto *Store = dyn_cast<StoreSDNode>(Chain)) {
    // All stores must be the same size to ensure that we are writing all of the
    // bytes in the wide value.
    // TODO: We could allow multiple sizes by tracking each stored byte.
    if (Store->getMemoryVT() != MemVT || !Store->isSimple() ||
        Store->isIndexed())
      return SDValue();
    Stores.push_back(Store);
    Chain = Store->getChain();
  }
  // There is no reason to continue if we do not have at least a pair of stores.
  if (Stores.size() < 2)
    return SDValue();

  // Handle simple types only: the combined width must form i16/i32/i64.
  LLVMContext &Context = *DAG.getContext();
  unsigned NumStores = Stores.size();
  unsigned NarrowNumBits = N->getMemoryVT().getScalarSizeInBits();
  unsigned WideNumBits = NumStores * NarrowNumBits;
  EVT WideVT = EVT::getIntegerVT(Context, WideNumBits);
  if (WideVT != MVT::i16 && WideVT != MVT::i32 && WideVT != MVT::i64)
    return SDValue();

  // Check if all bytes of the source value that we are looking at are stored
  // to the same base address. Collect offsets from Base address into OffsetMap.
  SDValue SourceValue;
  // OffsetMap[i] = memory byte offset of the store that writes the i-th
  // narrow chunk of the source value; INT64_MAX marks "unset".
  SmallVector<int64_t, 8> OffsetMap(NumStores, INT64_MAX);
  int64_t FirstOffset = INT64_MAX;
  StoreSDNode *FirstStore = nullptr;
  Optional<BaseIndexOffset> Base;
  for (auto *Store : Stores) {
    // All the stores store different parts of the CombinedValue. A truncate is
    // required to get the partial value.
    SDValue Trunc = Store->getValue();
    if (Trunc.getOpcode() != ISD::TRUNCATE)
      return SDValue();
    // Other than the first/last part, a shift operation is required to get the
    // offset.
    int64_t Offset = 0;
    SDValue WideVal = Trunc.getOperand(0);
    if ((WideVal.getOpcode() == ISD::SRL || WideVal.getOpcode() == ISD::SRA) &&
        isa<ConstantSDNode>(WideVal.getOperand(1))) {
      // The shift amount must be a constant multiple of the narrow type.
      // It is translated to the offset address in the wide source value "y".
      //
      // x = srl y, ShiftAmtC
      // i8 z = trunc x
      // store z, ...
      uint64_t ShiftAmtC = WideVal.getConstantOperandVal(1);
      if (ShiftAmtC % NarrowNumBits != 0)
        return SDValue();

      Offset = ShiftAmtC / NarrowNumBits;
      WideVal = WideVal.getOperand(0);
    }

    // Stores must share the same source value with different offsets.
    // Truncate and extends should be stripped to get the single source value.
    if (!SourceValue)
      SourceValue = WideVal;
    else if (stripTruncAndExt(SourceValue) != stripTruncAndExt(WideVal))
      return SDValue();
    else if (SourceValue.getValueType() != WideVT) {
      // Prefer the widest value seen so far as the canonical source.
      if (WideVal.getValueType() == WideVT ||
          WideVal.getScalarValueSizeInBits() >
              SourceValue.getScalarValueSizeInBits())
        SourceValue = WideVal;
      // Give up if the source value type is smaller than the store size.
      if (SourceValue.getScalarValueSizeInBits() < WideVT.getScalarSizeInBits())
        return SDValue();
    }

    // Stores must share the same base address.
    BaseIndexOffset Ptr = BaseIndexOffset::match(Store, DAG);
    int64_t ByteOffsetFromBase = 0;
    if (!Base)
      Base = Ptr;
    else if (!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
      return SDValue();

    // Remember the first store: the one at the lowest memory address.
    if (ByteOffsetFromBase < FirstOffset) {
      FirstStore = Store;
      FirstOffset = ByteOffsetFromBase;
    }
    // Map the offset in the store and the offset in the combined value, and
    // early return if it has been set before.
    if (Offset < 0 || Offset >= NumStores || OffsetMap[Offset] != INT64_MAX)
      return SDValue();
    OffsetMap[Offset] = ByteOffsetFromBase;
  }

  assert(FirstOffset != INT64_MAX && "First byte offset must be set");
  assert(FirstStore && "First store must be set");

  // Check that a store of the wide type is both allowed and fast on the target
  const DataLayout &Layout = DAG.getDataLayout();
  bool Fast = false;
  bool Allowed = TLI.allowsMemoryAccess(Context, Layout, WideVT,
                                        *FirstStore->getMemOperand(), &Fast);
  if (!Allowed || !Fast)
    return SDValue();

  // Check if the pieces of the value are going to the expected places in memory
  // to merge the stores.
  auto checkOffsets = [&](bool MatchLittleEndian) {
    if (MatchLittleEndian) {
      // Little endian: chunk i lands (i * chunk bytes) past the first store.
      for (unsigned i = 0; i != NumStores; ++i)
        if (OffsetMap[i] != i * (NarrowNumBits / 8) + FirstOffset)
          return false;
    } else { // MatchBigEndian by reversing loop counter.
      for (unsigned i = 0, j = NumStores - 1; i != NumStores; ++i, --j)
        if (OffsetMap[j] != i * (NarrowNumBits / 8) + FirstOffset)
          return false;
    }
    return true;
  };

  // Check if the offsets line up for the native data layout of this target.
  bool NeedBswap = false;
  bool NeedRotate = false;
  if (!checkOffsets(Layout.isLittleEndian())) {
    // Special-case: check if byte offsets line up for the opposite endian.
    if (NarrowNumBits == 8 && checkOffsets(Layout.isBigEndian()))
      NeedBswap = true;
    else if (NumStores == 2 && checkOffsets(Layout.isBigEndian()))
      NeedRotate = true;
    else
      return SDValue();
  }

  SDLoc DL(N);
  // Truncate an over-wide source down to the store width.
  if (WideVT != SourceValue.getValueType()) {
    assert(SourceValue.getValueType().getScalarSizeInBits() > WideNumBits &&
           "Unexpected store value to merge");
    SourceValue = DAG.getNode(ISD::TRUNCATE, DL, WideVT, SourceValue);
  }

  // Before legalize we can introduce illegal bswaps/rotates which will be later
  // converted to an explicit bswap sequence. This way we end up with a single
  // store and byte shuffling instead of several stores and byte shuffling.
  if (NeedBswap) {
    SourceValue = DAG.getNode(ISD::BSWAP, DL, WideVT, SourceValue);
  } else if (NeedRotate) {
    // Two halves in reversed order == rotate by half the width.
    assert(WideNumBits % 2 == 0 && "Unexpected type for rotate");
    SDValue RotAmt = DAG.getConstant(WideNumBits / 2, DL, WideVT);
    SourceValue = DAG.getNode(ISD::ROTR, DL, WideVT, SourceValue, RotAmt);
  }

  SDValue NewStore =
      DAG.getStore(Chain, DL, SourceValue, FirstStore->getBasePtr(),
                   FirstStore->getPointerInfo(), FirstStore->getAlign());

  // Rely on other DAG combine rules to remove the other individual stores.
  DAG.ReplaceAllUsesWith(N, NewStore.getNode());
  return NewStore;
}
8105
8106 /// Match a pattern where a wide type scalar value is loaded by several narrow
8107 /// loads and combined by shifts and ors. Fold it into a single load or a load
8108 /// and a BSWAP if the targets supports it.
8109 ///
8110 /// Assuming little endian target:
8111 /// i8 *a = ...
8112 /// i32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
8113 /// =>
8114 /// i32 val = *((i32)a)
8115 ///
8116 /// i8 *a = ...
8117 /// i32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
8118 /// =>
8119 /// i32 val = BSWAP(*((i32)a))
8120 ///
8121 /// TODO: This rule matches complex patterns with OR node roots and doesn't
8122 /// interact well with the worklist mechanism. When a part of the pattern is
8123 /// updated (e.g. one of the loads) its direct users are put into the worklist,
8124 /// but the root node of the pattern which triggers the load combine is not
8125 /// necessarily a direct user of the changed node. For example, once the address
8126 /// of t28 load is reassociated load combine won't be triggered:
8127 /// t25: i32 = add t4, Constant:i32<2>
8128 /// t26: i64 = sign_extend t25
8129 /// t27: i64 = add t2, t26
8130 /// t28: i8,ch = load<LD1[%tmp9]> t0, t27, undef:i64
8131 /// t29: i32 = zero_extend t28
8132 /// t32: i32 = shl t29, Constant:i8<8>
8133 /// t33: i32 = or t23, t32
8134 /// As a possible fix visitLoad can check if the load can be a part of a load
8135 /// combine pattern and add corresponding OR roots to the worklist.
MatchLoadCombine(SDNode * N)8136 SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
8137 assert(N->getOpcode() == ISD::OR &&
8138 "Can only match load combining against OR nodes");
8139
8140 // Handles simple types only
8141 EVT VT = N->getValueType(0);
8142 if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
8143 return SDValue();
8144 unsigned ByteWidth = VT.getSizeInBits() / 8;
8145
8146 bool IsBigEndianTarget = DAG.getDataLayout().isBigEndian();
8147 auto MemoryByteOffset = [&] (ByteProvider P) {
8148 assert(P.isMemory() && "Must be a memory byte provider");
8149 unsigned LoadBitWidth = P.Load->getMemoryVT().getSizeInBits();
8150 assert(LoadBitWidth % 8 == 0 &&
8151 "can only analyze providers for individual bytes not bit");
8152 unsigned LoadByteWidth = LoadBitWidth / 8;
8153 return IsBigEndianTarget
8154 ? bigEndianByteAt(LoadByteWidth, P.ByteOffset)
8155 : littleEndianByteAt(LoadByteWidth, P.ByteOffset);
8156 };
8157
8158 Optional<BaseIndexOffset> Base;
8159 SDValue Chain;
8160
8161 SmallPtrSet<LoadSDNode *, 8> Loads;
8162 Optional<ByteProvider> FirstByteProvider;
8163 int64_t FirstOffset = INT64_MAX;
8164
8165 // Check if all the bytes of the OR we are looking at are loaded from the same
8166 // base address. Collect bytes offsets from Base address in ByteOffsets.
8167 SmallVector<int64_t, 8> ByteOffsets(ByteWidth);
8168 unsigned ZeroExtendedBytes = 0;
8169 for (int i = ByteWidth - 1; i >= 0; --i) {
8170 auto P = calculateByteProvider(SDValue(N, 0), i, 0, /*Root=*/true);
8171 if (!P)
8172 return SDValue();
8173
8174 if (P->isConstantZero()) {
8175 // It's OK for the N most significant bytes to be 0, we can just
8176 // zero-extend the load.
8177 if (++ZeroExtendedBytes != (ByteWidth - static_cast<unsigned>(i)))
8178 return SDValue();
8179 continue;
8180 }
8181 assert(P->isMemory() && "provenance should either be memory or zero");
8182
8183 LoadSDNode *L = P->Load;
8184 assert(L->hasNUsesOfValue(1, 0) && L->isSimple() &&
8185 !L->isIndexed() &&
8186 "Must be enforced by calculateByteProvider");
8187 assert(L->getOffset().isUndef() && "Unindexed load must have undef offset");
8188
8189 // All loads must share the same chain
8190 SDValue LChain = L->getChain();
8191 if (!Chain)
8192 Chain = LChain;
8193 else if (Chain != LChain)
8194 return SDValue();
8195
8196 // Loads must share the same base address
8197 BaseIndexOffset Ptr = BaseIndexOffset::match(L, DAG);
8198 int64_t ByteOffsetFromBase = 0;
8199 if (!Base)
8200 Base = Ptr;
8201 else if (!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
8202 return SDValue();
8203
8204 // Calculate the offset of the current byte from the base address
8205 ByteOffsetFromBase += MemoryByteOffset(*P);
8206 ByteOffsets[i] = ByteOffsetFromBase;
8207
8208 // Remember the first byte load
8209 if (ByteOffsetFromBase < FirstOffset) {
8210 FirstByteProvider = P;
8211 FirstOffset = ByteOffsetFromBase;
8212 }
8213
8214 Loads.insert(L);
8215 }
8216 assert(!Loads.empty() && "All the bytes of the value must be loaded from "
8217 "memory, so there must be at least one load which produces the value");
8218 assert(Base && "Base address of the accessed memory location must be set");
8219 assert(FirstOffset != INT64_MAX && "First byte offset must be set");
8220
8221 bool NeedsZext = ZeroExtendedBytes > 0;
8222
8223 EVT MemVT =
8224 EVT::getIntegerVT(*DAG.getContext(), (ByteWidth - ZeroExtendedBytes) * 8);
8225
8226 if (!MemVT.isSimple())
8227 return SDValue();
8228
8229 // Before legalize we can introduce too wide illegal loads which will be later
8230 // split into legal sized loads. This enables us to combine i64 load by i8
8231 // patterns to a couple of i32 loads on 32 bit targets.
8232 if (LegalOperations &&
8233 !TLI.isOperationLegal(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD,
8234 MemVT))
8235 return SDValue();
8236
8237 // Check if the bytes of the OR we are looking at match with either big or
8238 // little endian value load
8239 Optional<bool> IsBigEndian = isBigEndian(
8240 makeArrayRef(ByteOffsets).drop_back(ZeroExtendedBytes), FirstOffset);
8241 if (!IsBigEndian)
8242 return SDValue();
8243
8244 assert(FirstByteProvider && "must be set");
8245
8246 // Ensure that the first byte is loaded from zero offset of the first load.
8247 // So the combined value can be loaded from the first load address.
8248 if (MemoryByteOffset(*FirstByteProvider) != 0)
8249 return SDValue();
8250 LoadSDNode *FirstLoad = FirstByteProvider->Load;
8251
8252 // The node we are looking at matches with the pattern, check if we can
8253 // replace it with a single (possibly zero-extended) load and bswap + shift if
8254 // needed.
8255
8256 // If the load needs byte swap check if the target supports it
8257 bool NeedsBswap = IsBigEndianTarget != *IsBigEndian;
8258
8259 // Before legalize we can introduce illegal bswaps which will be later
8260 // converted to an explicit bswap sequence. This way we end up with a single
8261 // load and byte shuffling instead of several loads and byte shuffling.
8262 // We do not introduce illegal bswaps when zero-extending as this tends to
8263 // introduce too many arithmetic instructions.
8264 if (NeedsBswap && (LegalOperations || NeedsZext) &&
8265 !TLI.isOperationLegal(ISD::BSWAP, VT))
8266 return SDValue();
8267
8268 // If we need to bswap and zero extend, we have to insert a shift. Check that
8269 // it is legal.
8270 if (NeedsBswap && NeedsZext && LegalOperations &&
8271 !TLI.isOperationLegal(ISD::SHL, VT))
8272 return SDValue();
8273
8274 // Check that a load of the wide type is both allowed and fast on the target
8275 bool Fast = false;
8276 bool Allowed =
8277 TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
8278 *FirstLoad->getMemOperand(), &Fast);
8279 if (!Allowed || !Fast)
8280 return SDValue();
8281
8282 SDValue NewLoad =
8283 DAG.getExtLoad(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, SDLoc(N), VT,
8284 Chain, FirstLoad->getBasePtr(),
8285 FirstLoad->getPointerInfo(), MemVT, FirstLoad->getAlign());
8286
8287 // Transfer chain users from old loads to the new load.
8288 for (LoadSDNode *L : Loads)
8289 DAG.ReplaceAllUsesOfValueWith(SDValue(L, 1), SDValue(NewLoad.getNode(), 1));
8290
8291 if (!NeedsBswap)
8292 return NewLoad;
8293
8294 SDValue ShiftedLoad =
8295 NeedsZext
8296 ? DAG.getNode(ISD::SHL, SDLoc(N), VT, NewLoad,
8297 DAG.getShiftAmountConstant(ZeroExtendedBytes * 8, VT,
8298 SDLoc(N), LegalOperations))
8299 : NewLoad;
8300 return DAG.getNode(ISD::BSWAP, SDLoc(N), VT, ShiftedLoad);
8301 }
8302
// If the target has andn, bsl, or a similar bit-select instruction,
// we want to unfold masked merge, with canonical pattern of:
//   |        A        |  |B|
//   ((x ^ y) & m) ^ y
//    |  D  |
// Into:
//   (x & m) | (y & ~m)
// If y is a constant, m is not a 'not', and the 'andn' does not work with
// immediates, we unfold into a different pattern:
//   ~(~x & m) & (m | y)
// If x is a constant, m is a 'not', and the 'andn' does not work with
// immediates, we unfold into a different pattern:
//   (x | ~m) & ~(~m & ~y)
// NOTE: we don't unfold the pattern if 'xor' is actually a 'not', because at
//       the very least that breaks andnpd / andnps patterns, and because those
//       patterns are simplified in IR and shouldn't be created in the DAG
unfoldMaskedMerge(SDNode * N)8319 SDValue DAGCombiner::unfoldMaskedMerge(SDNode *N) {
8320 assert(N->getOpcode() == ISD::XOR);
8321
8322 // Don't touch 'not' (i.e. where y = -1).
8323 if (isAllOnesOrAllOnesSplat(N->getOperand(1)))
8324 return SDValue();
8325
8326 EVT VT = N->getValueType(0);
8327
8328 // There are 3 commutable operators in the pattern,
8329 // so we have to deal with 8 possible variants of the basic pattern.
8330 SDValue X, Y, M;
8331 auto matchAndXor = [&X, &Y, &M](SDValue And, unsigned XorIdx, SDValue Other) {
8332 if (And.getOpcode() != ISD::AND || !And.hasOneUse())
8333 return false;
8334 SDValue Xor = And.getOperand(XorIdx);
8335 if (Xor.getOpcode() != ISD::XOR || !Xor.hasOneUse())
8336 return false;
8337 SDValue Xor0 = Xor.getOperand(0);
8338 SDValue Xor1 = Xor.getOperand(1);
8339 // Don't touch 'not' (i.e. where y = -1).
8340 if (isAllOnesOrAllOnesSplat(Xor1))
8341 return false;
8342 if (Other == Xor0)
8343 std::swap(Xor0, Xor1);
8344 if (Other != Xor1)
8345 return false;
8346 X = Xor0;
8347 Y = Xor1;
8348 M = And.getOperand(XorIdx ? 0 : 1);
8349 return true;
8350 };
8351
8352 SDValue N0 = N->getOperand(0);
8353 SDValue N1 = N->getOperand(1);
8354 if (!matchAndXor(N0, 0, N1) && !matchAndXor(N0, 1, N1) &&
8355 !matchAndXor(N1, 0, N0) && !matchAndXor(N1, 1, N0))
8356 return SDValue();
8357
8358 // Don't do anything if the mask is constant. This should not be reachable.
8359 // InstCombine should have already unfolded this pattern, and DAGCombiner
8360 // probably shouldn't produce it, too.
8361 if (isa<ConstantSDNode>(M.getNode()))
8362 return SDValue();
8363
8364 // We can transform if the target has AndNot
8365 if (!TLI.hasAndNot(M))
8366 return SDValue();
8367
8368 SDLoc DL(N);
8369
8370 // If Y is a constant, check that 'andn' works with immediates. Unless M is
8371 // a bitwise not that would already allow ANDN to be used.
8372 if (!TLI.hasAndNot(Y) && !isBitwiseNot(M)) {
8373 assert(TLI.hasAndNot(X) && "Only mask is a variable? Unreachable.");
8374 // If not, we need to do a bit more work to make sure andn is still used.
8375 SDValue NotX = DAG.getNOT(DL, X, VT);
8376 SDValue LHS = DAG.getNode(ISD::AND, DL, VT, NotX, M);
8377 SDValue NotLHS = DAG.getNOT(DL, LHS, VT);
8378 SDValue RHS = DAG.getNode(ISD::OR, DL, VT, M, Y);
8379 return DAG.getNode(ISD::AND, DL, VT, NotLHS, RHS);
8380 }
8381
8382 // If X is a constant and M is a bitwise not, check that 'andn' works with
8383 // immediates.
8384 if (!TLI.hasAndNot(X) && isBitwiseNot(M)) {
8385 assert(TLI.hasAndNot(Y) && "Only mask is a variable? Unreachable.");
8386 // If not, we need to do a bit more work to make sure andn is still used.
8387 SDValue NotM = M.getOperand(0);
8388 SDValue LHS = DAG.getNode(ISD::OR, DL, VT, X, NotM);
8389 SDValue NotY = DAG.getNOT(DL, Y, VT);
8390 SDValue RHS = DAG.getNode(ISD::AND, DL, VT, NotM, NotY);
8391 SDValue NotRHS = DAG.getNOT(DL, RHS, VT);
8392 return DAG.getNode(ISD::AND, DL, VT, LHS, NotRHS);
8393 }
8394
8395 SDValue LHS = DAG.getNode(ISD::AND, DL, VT, X, M);
8396 SDValue NotM = DAG.getNOT(DL, M, VT);
8397 SDValue RHS = DAG.getNode(ISD::AND, DL, VT, Y, NotM);
8398
8399 return DAG.getNode(ISD::OR, DL, VT, LHS, RHS);
8400 }
8401
visitXOR(SDNode * N)8402 SDValue DAGCombiner::visitXOR(SDNode *N) {
8403 SDValue N0 = N->getOperand(0);
8404 SDValue N1 = N->getOperand(1);
8405 EVT VT = N0.getValueType();
8406 SDLoc DL(N);
8407
8408 // fold (xor undef, undef) -> 0. This is a common idiom (misuse).
8409 if (N0.isUndef() && N1.isUndef())
8410 return DAG.getConstant(0, DL, VT);
8411
8412 // fold (xor x, undef) -> undef
8413 if (N0.isUndef())
8414 return N0;
8415 if (N1.isUndef())
8416 return N1;
8417
8418 // fold (xor c1, c2) -> c1^c2
8419 if (SDValue C = DAG.FoldConstantArithmetic(ISD::XOR, DL, VT, {N0, N1}))
8420 return C;
8421
8422 // canonicalize constant to RHS
8423 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
8424 !DAG.isConstantIntBuildVectorOrConstantInt(N1))
8425 return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
8426
8427 // fold vector ops
8428 if (VT.isVector()) {
8429 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
8430 return FoldedVOp;
8431
8432 // fold (xor x, 0) -> x, vector edition
8433 if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
8434 return N0;
8435 }
8436
8437 // fold (xor x, 0) -> x
8438 if (isNullConstant(N1))
8439 return N0;
8440
8441 if (SDValue NewSel = foldBinOpIntoSelect(N))
8442 return NewSel;
8443
8444 // reassociate xor
8445 if (SDValue RXOR = reassociateOps(ISD::XOR, DL, N0, N1, N->getFlags()))
8446 return RXOR;
8447
8448 // look for 'add-like' folds:
8449 // XOR(N0,MIN_SIGNED_VALUE) == ADD(N0,MIN_SIGNED_VALUE)
8450 if ((!LegalOperations || TLI.isOperationLegal(ISD::ADD, VT)) &&
8451 isMinSignedConstant(N1))
8452 if (SDValue Combined = visitADDLike(N))
8453 return Combined;
8454
8455 // fold !(x cc y) -> (x !cc y)
8456 unsigned N0Opcode = N0.getOpcode();
8457 SDValue LHS, RHS, CC;
8458 if (TLI.isConstTrueVal(N1) &&
8459 isSetCCEquivalent(N0, LHS, RHS, CC, /*MatchStrict*/ true)) {
8460 ISD::CondCode NotCC = ISD::getSetCCInverse(cast<CondCodeSDNode>(CC)->get(),
8461 LHS.getValueType());
8462 if (!LegalOperations ||
8463 TLI.isCondCodeLegal(NotCC, LHS.getSimpleValueType())) {
8464 switch (N0Opcode) {
8465 default:
8466 llvm_unreachable("Unhandled SetCC Equivalent!");
8467 case ISD::SETCC:
8468 return DAG.getSetCC(SDLoc(N0), VT, LHS, RHS, NotCC);
8469 case ISD::SELECT_CC:
8470 return DAG.getSelectCC(SDLoc(N0), LHS, RHS, N0.getOperand(2),
8471 N0.getOperand(3), NotCC);
8472 case ISD::STRICT_FSETCC:
8473 case ISD::STRICT_FSETCCS: {
8474 if (N0.hasOneUse()) {
8475 // FIXME Can we handle multiple uses? Could we token factor the chain
8476 // results from the new/old setcc?
8477 SDValue SetCC =
8478 DAG.getSetCC(SDLoc(N0), VT, LHS, RHS, NotCC,
8479 N0.getOperand(0), N0Opcode == ISD::STRICT_FSETCCS);
8480 CombineTo(N, SetCC);
8481 DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), SetCC.getValue(1));
8482 recursivelyDeleteUnusedNodes(N0.getNode());
8483 return SDValue(N, 0); // Return N so it doesn't get rechecked!
8484 }
8485 break;
8486 }
8487 }
8488 }
8489 }
8490
8491 // fold (not (zext (setcc x, y))) -> (zext (not (setcc x, y)))
8492 if (isOneConstant(N1) && N0Opcode == ISD::ZERO_EXTEND && N0.hasOneUse() &&
8493 isSetCCEquivalent(N0.getOperand(0), LHS, RHS, CC)){
8494 SDValue V = N0.getOperand(0);
8495 SDLoc DL0(N0);
8496 V = DAG.getNode(ISD::XOR, DL0, V.getValueType(), V,
8497 DAG.getConstant(1, DL0, V.getValueType()));
8498 AddToWorklist(V.getNode());
8499 return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, V);
8500 }
8501
8502 // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are setcc
8503 if (isOneConstant(N1) && VT == MVT::i1 && N0.hasOneUse() &&
8504 (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
8505 SDValue N00 = N0.getOperand(0), N01 = N0.getOperand(1);
8506 if (isOneUseSetCC(N01) || isOneUseSetCC(N00)) {
8507 unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
8508 N00 = DAG.getNode(ISD::XOR, SDLoc(N00), VT, N00, N1); // N00 = ~N00
8509 N01 = DAG.getNode(ISD::XOR, SDLoc(N01), VT, N01, N1); // N01 = ~N01
8510 AddToWorklist(N00.getNode()); AddToWorklist(N01.getNode());
8511 return DAG.getNode(NewOpcode, DL, VT, N00, N01);
8512 }
8513 }
8514 // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are constants
8515 if (isAllOnesConstant(N1) && N0.hasOneUse() &&
8516 (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
8517 SDValue N00 = N0.getOperand(0), N01 = N0.getOperand(1);
8518 if (isa<ConstantSDNode>(N01) || isa<ConstantSDNode>(N00)) {
8519 unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
8520 N00 = DAG.getNode(ISD::XOR, SDLoc(N00), VT, N00, N1); // N00 = ~N00
8521 N01 = DAG.getNode(ISD::XOR, SDLoc(N01), VT, N01, N1); // N01 = ~N01
8522 AddToWorklist(N00.getNode()); AddToWorklist(N01.getNode());
8523 return DAG.getNode(NewOpcode, DL, VT, N00, N01);
8524 }
8525 }
8526
8527 // fold (not (neg x)) -> (add X, -1)
8528 // FIXME: This can be generalized to (not (sub Y, X)) -> (add X, ~Y) if
8529 // Y is a constant or the subtract has a single use.
8530 if (isAllOnesConstant(N1) && N0.getOpcode() == ISD::SUB &&
8531 isNullConstant(N0.getOperand(0))) {
8532 return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1),
8533 DAG.getAllOnesConstant(DL, VT));
8534 }
8535
8536 // fold (not (add X, -1)) -> (neg X)
8537 if (isAllOnesConstant(N1) && N0.getOpcode() == ISD::ADD &&
8538 isAllOnesOrAllOnesSplat(N0.getOperand(1))) {
8539 return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
8540 N0.getOperand(0));
8541 }
8542
8543 // fold (xor (and x, y), y) -> (and (not x), y)
8544 if (N0Opcode == ISD::AND && N0.hasOneUse() && N0->getOperand(1) == N1) {
8545 SDValue X = N0.getOperand(0);
8546 SDValue NotX = DAG.getNOT(SDLoc(X), X, VT);
8547 AddToWorklist(NotX.getNode());
8548 return DAG.getNode(ISD::AND, DL, VT, NotX, N1);
8549 }
8550
8551 // fold Y = sra (X, size(X)-1); xor (add (X, Y), Y) -> (abs X)
8552 if (TLI.isOperationLegalOrCustom(ISD::ABS, VT)) {
8553 SDValue A = N0Opcode == ISD::ADD ? N0 : N1;
8554 SDValue S = N0Opcode == ISD::SRA ? N0 : N1;
8555 if (A.getOpcode() == ISD::ADD && S.getOpcode() == ISD::SRA) {
8556 SDValue A0 = A.getOperand(0), A1 = A.getOperand(1);
8557 SDValue S0 = S.getOperand(0);
8558 if ((A0 == S && A1 == S0) || (A1 == S && A0 == S0))
8559 if (ConstantSDNode *C = isConstOrConstSplat(S.getOperand(1)))
8560 if (C->getAPIntValue() == (VT.getScalarSizeInBits() - 1))
8561 return DAG.getNode(ISD::ABS, DL, VT, S0);
8562 }
8563 }
8564
8565 // fold (xor x, x) -> 0
8566 if (N0 == N1)
8567 return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
8568
8569 // fold (xor (shl 1, x), -1) -> (rotl ~1, x)
8570 // Here is a concrete example of this equivalence:
8571 // i16 x == 14
8572 // i16 shl == 1 << 14 == 16384 == 0b0100000000000000
8573 // i16 xor == ~(1 << 14) == 49151 == 0b1011111111111111
8574 //
8575 // =>
8576 //
8577 // i16 ~1 == 0b1111111111111110
8578 // i16 rol(~1, 14) == 0b1011111111111111
8579 //
8580 // Some additional tips to help conceptualize this transform:
8581 // - Try to see the operation as placing a single zero in a value of all ones.
8582 // - There exists no value for x which would allow the result to contain zero.
8583 // - Values of x larger than the bitwidth are undefined and do not require a
8584 // consistent result.
8585 // - Pushing the zero left requires shifting one bits in from the right.
8586 // A rotate left of ~1 is a nice way of achieving the desired result.
8587 if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT) && N0Opcode == ISD::SHL &&
8588 isAllOnesConstant(N1) && isOneConstant(N0.getOperand(0))) {
8589 return DAG.getNode(ISD::ROTL, DL, VT, DAG.getConstant(~1, DL, VT),
8590 N0.getOperand(1));
8591 }
8592
8593 // Simplify: xor (op x...), (op y...) -> (op (xor x, y))
8594 if (N0Opcode == N1.getOpcode())
8595 if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
8596 return V;
8597
8598 if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG))
8599 return R;
8600 if (SDValue R = foldLogicOfShifts(N, N1, N0, DAG))
8601 return R;
8602
8603 // Unfold ((x ^ y) & m) ^ y into (x & m) | (y & ~m) if profitable
8604 if (SDValue MM = unfoldMaskedMerge(N))
8605 return MM;
8606
8607 // Simplify the expression using non-local knowledge.
8608 if (SimplifyDemandedBits(SDValue(N, 0)))
8609 return SDValue(N, 0);
8610
8611 if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
8612 return Combined;
8613
8614 return SDValue();
8615 }
8616
8617 /// If we have a shift-by-constant of a bitwise logic op that itself has a
8618 /// shift-by-constant operand with identical opcode, we may be able to convert
8619 /// that into 2 independent shifts followed by the logic op. This is a
8620 /// throughput improvement.
combineShiftOfShiftedLogic(SDNode * Shift,SelectionDAG & DAG)8621 static SDValue combineShiftOfShiftedLogic(SDNode *Shift, SelectionDAG &DAG) {
8622 // Match a one-use bitwise logic op.
8623 SDValue LogicOp = Shift->getOperand(0);
8624 if (!LogicOp.hasOneUse())
8625 return SDValue();
8626
8627 unsigned LogicOpcode = LogicOp.getOpcode();
8628 if (LogicOpcode != ISD::AND && LogicOpcode != ISD::OR &&
8629 LogicOpcode != ISD::XOR)
8630 return SDValue();
8631
8632 // Find a matching one-use shift by constant.
8633 unsigned ShiftOpcode = Shift->getOpcode();
8634 SDValue C1 = Shift->getOperand(1);
8635 ConstantSDNode *C1Node = isConstOrConstSplat(C1);
8636 assert(C1Node && "Expected a shift with constant operand");
8637 const APInt &C1Val = C1Node->getAPIntValue();
8638 auto matchFirstShift = [&](SDValue V, SDValue &ShiftOp,
8639 const APInt *&ShiftAmtVal) {
8640 if (V.getOpcode() != ShiftOpcode || !V.hasOneUse())
8641 return false;
8642
8643 ConstantSDNode *ShiftCNode = isConstOrConstSplat(V.getOperand(1));
8644 if (!ShiftCNode)
8645 return false;
8646
8647 // Capture the shifted operand and shift amount value.
8648 ShiftOp = V.getOperand(0);
8649 ShiftAmtVal = &ShiftCNode->getAPIntValue();
8650
8651 // Shift amount types do not have to match their operand type, so check that
8652 // the constants are the same width.
8653 if (ShiftAmtVal->getBitWidth() != C1Val.getBitWidth())
8654 return false;
8655
8656 // The fold is not valid if the sum of the shift values exceeds bitwidth.
8657 if ((*ShiftAmtVal + C1Val).uge(V.getScalarValueSizeInBits()))
8658 return false;
8659
8660 return true;
8661 };
8662
8663 // Logic ops are commutative, so check each operand for a match.
8664 SDValue X, Y;
8665 const APInt *C0Val;
8666 if (matchFirstShift(LogicOp.getOperand(0), X, C0Val))
8667 Y = LogicOp.getOperand(1);
8668 else if (matchFirstShift(LogicOp.getOperand(1), X, C0Val))
8669 Y = LogicOp.getOperand(0);
8670 else
8671 return SDValue();
8672
8673 // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1)
8674 SDLoc DL(Shift);
8675 EVT VT = Shift->getValueType(0);
8676 EVT ShiftAmtVT = Shift->getOperand(1).getValueType();
8677 SDValue ShiftSumC = DAG.getConstant(*C0Val + C1Val, DL, ShiftAmtVT);
8678 SDValue NewShift1 = DAG.getNode(ShiftOpcode, DL, VT, X, ShiftSumC);
8679 SDValue NewShift2 = DAG.getNode(ShiftOpcode, DL, VT, Y, C1);
8680 return DAG.getNode(LogicOpcode, DL, VT, NewShift1, NewShift2);
8681 }
8682
8683 /// Handle transforms common to the three shifts, when the shift amount is a
8684 /// constant.
8685 /// We are looking for: (shift being one of shl/sra/srl)
8686 /// shift (binop X, C0), C1
8687 /// And want to transform into:
8688 /// binop (shift X, C1), (shift C0, C1)
visitShiftByConstant(SDNode * N)8689 SDValue DAGCombiner::visitShiftByConstant(SDNode *N) {
8690 assert(isConstOrConstSplat(N->getOperand(1)) && "Expected constant operand");
8691
8692 // Do not turn a 'not' into a regular xor.
8693 if (isBitwiseNot(N->getOperand(0)))
8694 return SDValue();
8695
8696 // The inner binop must be one-use, since we want to replace it.
8697 SDValue LHS = N->getOperand(0);
8698 if (!LHS.hasOneUse() || !TLI.isDesirableToCommuteWithShift(N, Level))
8699 return SDValue();
8700
8701 // TODO: This is limited to early combining because it may reveal regressions
8702 // otherwise. But since we just checked a target hook to see if this is
8703 // desirable, that should have filtered out cases where this interferes
8704 // with some other pattern matching.
8705 if (!LegalTypes)
8706 if (SDValue R = combineShiftOfShiftedLogic(N, DAG))
8707 return R;
8708
8709 // We want to pull some binops through shifts, so that we have (and (shift))
8710 // instead of (shift (and)), likewise for add, or, xor, etc. This sort of
8711 // thing happens with address calculations, so it's important to canonicalize
8712 // it.
8713 switch (LHS.getOpcode()) {
8714 default:
8715 return SDValue();
8716 case ISD::OR:
8717 case ISD::XOR:
8718 case ISD::AND:
8719 break;
8720 case ISD::ADD:
8721 if (N->getOpcode() != ISD::SHL)
8722 return SDValue(); // only shl(add) not sr[al](add).
8723 break;
8724 }
8725
8726 // We require the RHS of the binop to be a constant and not opaque as well.
8727 ConstantSDNode *BinOpCst = getAsNonOpaqueConstant(LHS.getOperand(1));
8728 if (!BinOpCst)
8729 return SDValue();
8730
8731 // FIXME: disable this unless the input to the binop is a shift by a constant
8732 // or is copy/select. Enable this in other cases when figure out it's exactly
8733 // profitable.
8734 SDValue BinOpLHSVal = LHS.getOperand(0);
8735 bool IsShiftByConstant = (BinOpLHSVal.getOpcode() == ISD::SHL ||
8736 BinOpLHSVal.getOpcode() == ISD::SRA ||
8737 BinOpLHSVal.getOpcode() == ISD::SRL) &&
8738 isa<ConstantSDNode>(BinOpLHSVal.getOperand(1));
8739 bool IsCopyOrSelect = BinOpLHSVal.getOpcode() == ISD::CopyFromReg ||
8740 BinOpLHSVal.getOpcode() == ISD::SELECT;
8741
8742 if (!IsShiftByConstant && !IsCopyOrSelect)
8743 return SDValue();
8744
8745 if (IsCopyOrSelect && N->hasOneUse())
8746 return SDValue();
8747
8748 // Fold the constants, shifting the binop RHS by the shift amount.
8749 SDLoc DL(N);
8750 EVT VT = N->getValueType(0);
8751 SDValue NewRHS = DAG.getNode(N->getOpcode(), DL, VT, LHS.getOperand(1),
8752 N->getOperand(1));
8753 assert(isa<ConstantSDNode>(NewRHS) && "Folding was not successful!");
8754
8755 SDValue NewShift = DAG.getNode(N->getOpcode(), DL, VT, LHS.getOperand(0),
8756 N->getOperand(1));
8757 return DAG.getNode(LHS.getOpcode(), DL, VT, NewShift, NewRHS);
8758 }
8759
distributeTruncateThroughAnd(SDNode * N)8760 SDValue DAGCombiner::distributeTruncateThroughAnd(SDNode *N) {
8761 assert(N->getOpcode() == ISD::TRUNCATE);
8762 assert(N->getOperand(0).getOpcode() == ISD::AND);
8763
8764 // (truncate:TruncVT (and N00, N01C)) -> (and (truncate:TruncVT N00), TruncC)
8765 EVT TruncVT = N->getValueType(0);
8766 if (N->hasOneUse() && N->getOperand(0).hasOneUse() &&
8767 TLI.isTypeDesirableForOp(ISD::AND, TruncVT)) {
8768 SDValue N01 = N->getOperand(0).getOperand(1);
8769 if (isConstantOrConstantVector(N01, /* NoOpaques */ true)) {
8770 SDLoc DL(N);
8771 SDValue N00 = N->getOperand(0).getOperand(0);
8772 SDValue Trunc00 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N00);
8773 SDValue Trunc01 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N01);
8774 AddToWorklist(Trunc00.getNode());
8775 AddToWorklist(Trunc01.getNode());
8776 return DAG.getNode(ISD::AND, DL, TruncVT, Trunc00, Trunc01);
8777 }
8778 }
8779
8780 return SDValue();
8781 }
8782
visitRotate(SDNode * N)8783 SDValue DAGCombiner::visitRotate(SDNode *N) {
8784 SDLoc dl(N);
8785 SDValue N0 = N->getOperand(0);
8786 SDValue N1 = N->getOperand(1);
8787 EVT VT = N->getValueType(0);
8788 unsigned Bitsize = VT.getScalarSizeInBits();
8789
8790 // fold (rot x, 0) -> x
8791 if (isNullOrNullSplat(N1))
8792 return N0;
8793
8794 // fold (rot x, c) -> x iff (c % BitSize) == 0
8795 if (isPowerOf2_32(Bitsize) && Bitsize > 1) {
8796 APInt ModuloMask(N1.getScalarValueSizeInBits(), Bitsize - 1);
8797 if (DAG.MaskedValueIsZero(N1, ModuloMask))
8798 return N0;
8799 }
8800
8801 // fold (rot x, c) -> (rot x, c % BitSize)
8802 bool OutOfRange = false;
8803 auto MatchOutOfRange = [Bitsize, &OutOfRange](ConstantSDNode *C) {
8804 OutOfRange |= C->getAPIntValue().uge(Bitsize);
8805 return true;
8806 };
8807 if (ISD::matchUnaryPredicate(N1, MatchOutOfRange) && OutOfRange) {
8808 EVT AmtVT = N1.getValueType();
8809 SDValue Bits = DAG.getConstant(Bitsize, dl, AmtVT);
8810 if (SDValue Amt =
8811 DAG.FoldConstantArithmetic(ISD::UREM, dl, AmtVT, {N1, Bits}))
8812 return DAG.getNode(N->getOpcode(), dl, VT, N0, Amt);
8813 }
8814
8815 // rot i16 X, 8 --> bswap X
8816 auto *RotAmtC = isConstOrConstSplat(N1);
8817 if (RotAmtC && RotAmtC->getAPIntValue() == 8 &&
8818 VT.getScalarSizeInBits() == 16 && hasOperation(ISD::BSWAP, VT))
8819 return DAG.getNode(ISD::BSWAP, dl, VT, N0);
8820
8821 // Simplify the operands using demanded-bits information.
8822 if (SimplifyDemandedBits(SDValue(N, 0)))
8823 return SDValue(N, 0);
8824
8825 // fold (rot* x, (trunc (and y, c))) -> (rot* x, (and (trunc y), (trunc c))).
8826 if (N1.getOpcode() == ISD::TRUNCATE &&
8827 N1.getOperand(0).getOpcode() == ISD::AND) {
8828 if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
8829 return DAG.getNode(N->getOpcode(), dl, VT, N0, NewOp1);
8830 }
8831
8832 unsigned NextOp = N0.getOpcode();
8833
8834 // fold (rot* (rot* x, c2), c1)
8835 // -> (rot* x, ((c1 % bitsize) +- (c2 % bitsize)) % bitsize)
8836 if (NextOp == ISD::ROTL || NextOp == ISD::ROTR) {
8837 SDNode *C1 = DAG.isConstantIntBuildVectorOrConstantInt(N1);
8838 SDNode *C2 = DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1));
8839 if (C1 && C2 && C1->getValueType(0) == C2->getValueType(0)) {
8840 EVT ShiftVT = C1->getValueType(0);
8841 bool SameSide = (N->getOpcode() == NextOp);
8842 unsigned CombineOp = SameSide ? ISD::ADD : ISD::SUB;
8843 SDValue BitsizeC = DAG.getConstant(Bitsize, dl, ShiftVT);
8844 SDValue Norm1 = DAG.FoldConstantArithmetic(ISD::UREM, dl, ShiftVT,
8845 {N1, BitsizeC});
8846 SDValue Norm2 = DAG.FoldConstantArithmetic(ISD::UREM, dl, ShiftVT,
8847 {N0.getOperand(1), BitsizeC});
8848 if (Norm1 && Norm2)
8849 if (SDValue CombinedShift = DAG.FoldConstantArithmetic(
8850 CombineOp, dl, ShiftVT, {Norm1, Norm2})) {
8851 SDValue CombinedShiftNorm = DAG.FoldConstantArithmetic(
8852 ISD::UREM, dl, ShiftVT, {CombinedShift, BitsizeC});
8853 return DAG.getNode(N->getOpcode(), dl, VT, N0->getOperand(0),
8854 CombinedShiftNorm);
8855 }
8856 }
8857 }
8858 return SDValue();
8859 }
8860
visitSHL(SDNode * N)8861 SDValue DAGCombiner::visitSHL(SDNode *N) {
8862 SDValue N0 = N->getOperand(0);
8863 SDValue N1 = N->getOperand(1);
8864 if (SDValue V = DAG.simplifyShift(N0, N1))
8865 return V;
8866
8867 EVT VT = N0.getValueType();
8868 EVT ShiftVT = N1.getValueType();
8869 unsigned OpSizeInBits = VT.getScalarSizeInBits();
8870
8871 // fold (shl c1, c2) -> c1<<c2
8872 if (SDValue C = DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N), VT, {N0, N1}))
8873 return C;
8874
8875 // fold vector ops
8876 if (VT.isVector()) {
8877 if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N)))
8878 return FoldedVOp;
8879
8880 BuildVectorSDNode *N1CV = dyn_cast<BuildVectorSDNode>(N1);
8881 // If setcc produces all-one true value then:
8882 // (shl (and (setcc) N01CV) N1CV) -> (and (setcc) N01CV<<N1CV)
8883 if (N1CV && N1CV->isConstant()) {
8884 if (N0.getOpcode() == ISD::AND) {
8885 SDValue N00 = N0->getOperand(0);
8886 SDValue N01 = N0->getOperand(1);
8887 BuildVectorSDNode *N01CV = dyn_cast<BuildVectorSDNode>(N01);
8888
8889 if (N01CV && N01CV->isConstant() && N00.getOpcode() == ISD::SETCC &&
8890 TLI.getBooleanContents(N00.getOperand(0).getValueType()) ==
8891 TargetLowering::ZeroOrNegativeOneBooleanContent) {
8892 if (SDValue C =
8893 DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N), VT, {N01, N1}))
8894 return DAG.getNode(ISD::AND, SDLoc(N), VT, N00, C);
8895 }
8896 }
8897 }
8898 }
8899
8900 if (SDValue NewSel = foldBinOpIntoSelect(N))
8901 return NewSel;
8902
8903 // if (shl x, c) is known to be zero, return 0
8904 if (DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnes(OpSizeInBits)))
8905 return DAG.getConstant(0, SDLoc(N), VT);
8906
8907 // fold (shl x, (trunc (and y, c))) -> (shl x, (and (trunc y), (trunc c))).
8908 if (N1.getOpcode() == ISD::TRUNCATE &&
8909 N1.getOperand(0).getOpcode() == ISD::AND) {
8910 if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
8911 return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, NewOp1);
8912 }
8913
8914 if (SimplifyDemandedBits(SDValue(N, 0)))
8915 return SDValue(N, 0);
8916
8917 // fold (shl (shl x, c1), c2) -> 0 or (shl x, (add c1, c2))
8918 if (N0.getOpcode() == ISD::SHL) {
8919 auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
8920 ConstantSDNode *RHS) {
8921 APInt c1 = LHS->getAPIntValue();
8922 APInt c2 = RHS->getAPIntValue();
8923 zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
8924 return (c1 + c2).uge(OpSizeInBits);
8925 };
8926 if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
8927 return DAG.getConstant(0, SDLoc(N), VT);
8928
8929 auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
8930 ConstantSDNode *RHS) {
8931 APInt c1 = LHS->getAPIntValue();
8932 APInt c2 = RHS->getAPIntValue();
8933 zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
8934 return (c1 + c2).ult(OpSizeInBits);
8935 };
8936 if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
8937 SDLoc DL(N);
8938 SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
8939 return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Sum);
8940 }
8941 }
8942
8943 // fold (shl (ext (shl x, c1)), c2) -> (shl (ext x), (add c1, c2))
8944 // For this to be valid, the second form must not preserve any of the bits
8945 // that are shifted out by the inner shift in the first form. This means
8946 // the outer shift size must be >= the number of bits added by the ext.
8947 // As a corollary, we don't care what kind of ext it is.
8948 if ((N0.getOpcode() == ISD::ZERO_EXTEND ||
8949 N0.getOpcode() == ISD::ANY_EXTEND ||
8950 N0.getOpcode() == ISD::SIGN_EXTEND) &&
8951 N0.getOperand(0).getOpcode() == ISD::SHL) {
8952 SDValue N0Op0 = N0.getOperand(0);
8953 SDValue InnerShiftAmt = N0Op0.getOperand(1);
8954 EVT InnerVT = N0Op0.getValueType();
8955 uint64_t InnerBitwidth = InnerVT.getScalarSizeInBits();
8956
8957 auto MatchOutOfRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
8958 ConstantSDNode *RHS) {
8959 APInt c1 = LHS->getAPIntValue();
8960 APInt c2 = RHS->getAPIntValue();
8961 zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
8962 return c2.uge(OpSizeInBits - InnerBitwidth) &&
8963 (c1 + c2).uge(OpSizeInBits);
8964 };
8965 if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchOutOfRange,
8966 /*AllowUndefs*/ false,
8967 /*AllowTypeMismatch*/ true))
8968 return DAG.getConstant(0, SDLoc(N), VT);
8969
8970 auto MatchInRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
8971 ConstantSDNode *RHS) {
8972 APInt c1 = LHS->getAPIntValue();
8973 APInt c2 = RHS->getAPIntValue();
8974 zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
8975 return c2.uge(OpSizeInBits - InnerBitwidth) &&
8976 (c1 + c2).ult(OpSizeInBits);
8977 };
8978 if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchInRange,
8979 /*AllowUndefs*/ false,
8980 /*AllowTypeMismatch*/ true)) {
8981 SDLoc DL(N);
8982 SDValue Ext = DAG.getNode(N0.getOpcode(), DL, VT, N0Op0.getOperand(0));
8983 SDValue Sum = DAG.getZExtOrTrunc(InnerShiftAmt, DL, ShiftVT);
8984 Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, Sum, N1);
8985 return DAG.getNode(ISD::SHL, DL, VT, Ext, Sum);
8986 }
8987 }
8988
8989 // fold (shl (zext (srl x, C)), C) -> (zext (shl (srl x, C), C))
8990 // Only fold this if the inner zext has no other uses to avoid increasing
8991 // the total number of instructions.
8992 if (N0.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse() &&
8993 N0.getOperand(0).getOpcode() == ISD::SRL) {
8994 SDValue N0Op0 = N0.getOperand(0);
8995 SDValue InnerShiftAmt = N0Op0.getOperand(1);
8996
8997 auto MatchEqual = [VT](ConstantSDNode *LHS, ConstantSDNode *RHS) {
8998 APInt c1 = LHS->getAPIntValue();
8999 APInt c2 = RHS->getAPIntValue();
9000 zeroExtendToMatch(c1, c2);
9001 return c1.ult(VT.getScalarSizeInBits()) && (c1 == c2);
9002 };
9003 if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchEqual,
9004 /*AllowUndefs*/ false,
9005 /*AllowTypeMismatch*/ true)) {
9006 SDLoc DL(N);
9007 EVT InnerShiftAmtVT = N0Op0.getOperand(1).getValueType();
9008 SDValue NewSHL = DAG.getZExtOrTrunc(N1, DL, InnerShiftAmtVT);
9009 NewSHL = DAG.getNode(ISD::SHL, DL, N0Op0.getValueType(), N0Op0, NewSHL);
9010 AddToWorklist(NewSHL.getNode());
9011 return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N0), VT, NewSHL);
9012 }
9013 }
9014
9015 if (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SRA) {
9016 auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
9017 ConstantSDNode *RHS) {
9018 const APInt &LHSC = LHS->getAPIntValue();
9019 const APInt &RHSC = RHS->getAPIntValue();
9020 return LHSC.ult(OpSizeInBits) && RHSC.ult(OpSizeInBits) &&
9021 LHSC.getZExtValue() <= RHSC.getZExtValue();
9022 };
9023
9024 SDLoc DL(N);
9025
9026 // fold (shl (sr[la] exact X, C1), C2) -> (shl X, (C2-C1)) if C1 <= C2
9027 // fold (shl (sr[la] exact X, C1), C2) -> (sr[la] X, (C2-C1)) if C1 >= C2
9028 if (N0->getFlags().hasExact()) {
9029 if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount,
9030 /*AllowUndefs*/ false,
9031 /*AllowTypeMismatch*/ true)) {
9032 SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
9033 SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01);
9034 return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff);
9035 }
9036 if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
9037 /*AllowUndefs*/ false,
9038 /*AllowTypeMismatch*/ true)) {
9039 SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
9040 SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1);
9041 return DAG.getNode(N0.getOpcode(), DL, VT, N0.getOperand(0), Diff);
9042 }
9043 }
9044
9045 // fold (shl (srl x, c1), c2) -> (and (shl x, (sub c2, c1), MASK) or
9046 // (and (srl x, (sub c1, c2), MASK)
9047 // Only fold this if the inner shift has no other uses -- if it does,
9048 // folding this will increase the total number of instructions.
9049 if (N0.getOpcode() == ISD::SRL &&
9050 (N0.getOperand(1) == N1 || N0.hasOneUse()) &&
9051 TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
9052 if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
9053 /*AllowUndefs*/ false,
9054 /*AllowTypeMismatch*/ true)) {
9055 SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
9056 SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1);
9057 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
9058 Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, N01);
9059 Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, Diff);
9060 SDValue Shift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Diff);
9061 return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
9062 }
9063 if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount,
9064 /*AllowUndefs*/ false,
9065 /*AllowTypeMismatch*/ true)) {
9066 SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
9067 SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01);
9068 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
9069 Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, N1);
9070 SDValue Shift = DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff);
9071 return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
9072 }
9073 }
9074 }
9075
9076 // fold (shl (sra x, c1), c1) -> (and x, (shl -1, c1))
9077 if (N0.getOpcode() == ISD::SRA && N1 == N0.getOperand(1) &&
9078 isConstantOrConstantVector(N1, /* No Opaques */ true)) {
9079 SDLoc DL(N);
9080 SDValue AllBits = DAG.getAllOnesConstant(DL, VT);
9081 SDValue HiBitsMask = DAG.getNode(ISD::SHL, DL, VT, AllBits, N1);
9082 return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), HiBitsMask);
9083 }
9084
9085 // fold (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2)
9086 // fold (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2)
9087 // Variant of version done on multiply, except mul by a power of 2 is turned
9088 // into a shift.
9089 if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::OR) &&
9090 N0->hasOneUse() &&
9091 isConstantOrConstantVector(N1, /* No Opaques */ true) &&
9092 isConstantOrConstantVector(N0.getOperand(1), /* No Opaques */ true) &&
9093 TLI.isDesirableToCommuteWithShift(N, Level)) {
9094 SDValue Shl0 = DAG.getNode(ISD::SHL, SDLoc(N0), VT, N0.getOperand(0), N1);
9095 SDValue Shl1 = DAG.getNode(ISD::SHL, SDLoc(N1), VT, N0.getOperand(1), N1);
9096 AddToWorklist(Shl0.getNode());
9097 AddToWorklist(Shl1.getNode());
9098 return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, Shl0, Shl1);
9099 }
9100
9101 // fold (shl (mul x, c1), c2) -> (mul x, c1 << c2)
9102 if (N0.getOpcode() == ISD::MUL && N0->hasOneUse()) {
9103 SDValue N01 = N0.getOperand(1);
9104 if (SDValue Shl =
9105 DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N1), VT, {N01, N1}))
9106 return DAG.getNode(ISD::MUL, SDLoc(N), VT, N0.getOperand(0), Shl);
9107 }
9108
9109 ConstantSDNode *N1C = isConstOrConstSplat(N1);
9110 if (N1C && !N1C->isOpaque())
9111 if (SDValue NewSHL = visitShiftByConstant(N))
9112 return NewSHL;
9113
9114 // Fold (shl (vscale * C0), C1) to (vscale * (C0 << C1)).
9115 if (N0.getOpcode() == ISD::VSCALE)
9116 if (ConstantSDNode *NC1 = isConstOrConstSplat(N->getOperand(1))) {
9117 const APInt &C0 = N0.getConstantOperandAPInt(0);
9118 const APInt &C1 = NC1->getAPIntValue();
9119 return DAG.getVScale(SDLoc(N), VT, C0 << C1);
9120 }
9121
9122 // Fold (shl step_vector(C0), C1) to (step_vector(C0 << C1)).
9123 APInt ShlVal;
9124 if (N0.getOpcode() == ISD::STEP_VECTOR)
9125 if (ISD::isConstantSplatVector(N1.getNode(), ShlVal)) {
9126 const APInt &C0 = N0.getConstantOperandAPInt(0);
9127 if (ShlVal.ult(C0.getBitWidth())) {
9128 APInt NewStep = C0 << ShlVal;
9129 return DAG.getStepVector(SDLoc(N), VT, NewStep);
9130 }
9131 }
9132
9133 return SDValue();
9134 }
9135
// Transform a right shift of a multiply into a multiply-high.
// Examples:
// (srl (mul (zext i32:$a to i64), (zext i32:$b to i64)), 32) -> (mulhu $a, $b)
// (sra (mul (sext i32:$a to i64), (sext i32:$b to i64)), 32) -> (mulhs $a, $b)
static SDValue combineShiftToMULH(SDNode *N, SelectionDAG &DAG,
                                  const TargetLowering &TLI) {
  assert((N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) &&
         "SRL or SRA node is required here!");

  // Check the shift amount. Proceed with the transformation if the shift
  // amount is constant.
  ConstantSDNode *ShiftAmtSrc = isConstOrConstSplat(N->getOperand(1));
  if (!ShiftAmtSrc)
    return SDValue();

  SDLoc DL(N);

  // The operation feeding into the shift must be a multiply.
  SDValue ShiftOperand = N->getOperand(0);
  if (ShiftOperand.getOpcode() != ISD::MUL)
    return SDValue();

  // Both operands must be equivalent extend nodes.
  SDValue LeftOp = ShiftOperand.getOperand(0);
  SDValue RightOp = ShiftOperand.getOperand(1);

  bool IsSignExt = LeftOp.getOpcode() == ISD::SIGN_EXTEND;
  bool IsZeroExt = LeftOp.getOpcode() == ISD::ZERO_EXTEND;

  // The LHS must be an extend; without one there is no narrow source value.
  if (!IsSignExt && !IsZeroExt)
    return SDValue();

  EVT NarrowVT = LeftOp.getOperand(0).getValueType();
  unsigned NarrowVTSize = NarrowVT.getScalarSizeInBits();

  SDValue MulhRightOp;
  if (ConstantSDNode *Constant = isConstOrConstSplat(RightOp)) {
    // RHS is a constant (splat). It can act directly as the narrow MULH
    // operand if it fits in the narrow type, i.e. re-extending its
    // truncation would reproduce the original wide constant.
    unsigned ActiveBits = IsSignExt
                              ? Constant->getAPIntValue().getMinSignedBits()
                              : Constant->getAPIntValue().getActiveBits();
    if (ActiveBits > NarrowVTSize)
      return SDValue();
    MulhRightOp = DAG.getConstant(
        Constant->getAPIntValue().trunc(NarrowVT.getScalarSizeInBits()), DL,
        NarrowVT);
  } else {
    // Otherwise the RHS must be the same kind of extend as the LHS.
    if (LeftOp.getOpcode() != RightOp.getOpcode())
      return SDValue();
    // Check that the two extend nodes are the same type.
    if (NarrowVT != RightOp.getOperand(0).getValueType())
      return SDValue();
    MulhRightOp = RightOp.getOperand(0);
  }

  EVT WideVT = LeftOp.getValueType();
  // Proceed with the transformation if the wide types match.
  assert((WideVT == RightOp.getValueType()) &&
         "Cannot have a multiply node with two different operand types.");

  // Proceed with the transformation if the wide type is twice as large
  // as the narrow type.
  if (WideVT.getScalarSizeInBits() != 2 * NarrowVTSize)
    return SDValue();

  // Check the shift amount with the narrow type size.
  // Proceed with the transformation if the shift amount is the width
  // of the narrow type.
  unsigned ShiftAmt = ShiftAmtSrc->getZExtValue();
  if (ShiftAmt != NarrowVTSize)
    return SDValue();

  // If the operation feeding into the MUL is a sign extend (sext),
  // we use mulhs. Otherwise, zero extends (zext) use mulhu.
  unsigned MulhOpcode = IsSignExt ? ISD::MULHS : ISD::MULHU;

  // Combine to mulh if mulh is legal/custom for the narrow type on the target.
  if (!TLI.isOperationLegalOrCustom(MulhOpcode, NarrowVT))
    return SDValue();

  SDValue Result =
      DAG.getNode(MulhOpcode, DL, NarrowVT, LeftOp.getOperand(0), MulhRightOp);
  // The MULH result is in the narrow type; widen it back to the shift's
  // result type, matching the signedness of the shift being replaced.
  return (N->getOpcode() == ISD::SRA ? DAG.getSExtOrTrunc(Result, DL, WideVT)
                                     : DAG.getZExtOrTrunc(Result, DL, WideVT));
}
9220
visitSRA(SDNode * N)9221 SDValue DAGCombiner::visitSRA(SDNode *N) {
9222 SDValue N0 = N->getOperand(0);
9223 SDValue N1 = N->getOperand(1);
9224 if (SDValue V = DAG.simplifyShift(N0, N1))
9225 return V;
9226
9227 EVT VT = N0.getValueType();
9228 unsigned OpSizeInBits = VT.getScalarSizeInBits();
9229
9230 // fold (sra c1, c2) -> (sra c1, c2)
9231 if (SDValue C = DAG.FoldConstantArithmetic(ISD::SRA, SDLoc(N), VT, {N0, N1}))
9232 return C;
9233
9234 // Arithmetic shifting an all-sign-bit value is a no-op.
9235 // fold (sra 0, x) -> 0
9236 // fold (sra -1, x) -> -1
9237 if (DAG.ComputeNumSignBits(N0) == OpSizeInBits)
9238 return N0;
9239
9240 // fold vector ops
9241 if (VT.isVector())
9242 if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N)))
9243 return FoldedVOp;
9244
9245 if (SDValue NewSel = foldBinOpIntoSelect(N))
9246 return NewSel;
9247
9248 // fold (sra (shl x, c1), c1) -> sext_inreg for some c1 and target supports
9249 // sext_inreg.
9250 ConstantSDNode *N1C = isConstOrConstSplat(N1);
9251 if (N1C && N0.getOpcode() == ISD::SHL && N1 == N0.getOperand(1)) {
9252 unsigned LowBits = OpSizeInBits - (unsigned)N1C->getZExtValue();
9253 EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), LowBits);
9254 if (VT.isVector())
9255 ExtVT = EVT::getVectorVT(*DAG.getContext(), ExtVT,
9256 VT.getVectorElementCount());
9257 if (!LegalOperations ||
9258 TLI.getOperationAction(ISD::SIGN_EXTEND_INREG, ExtVT) ==
9259 TargetLowering::Legal)
9260 return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT,
9261 N0.getOperand(0), DAG.getValueType(ExtVT));
9262 // Even if we can't convert to sext_inreg, we might be able to remove
9263 // this shift pair if the input is already sign extended.
9264 if (DAG.ComputeNumSignBits(N0.getOperand(0)) > N1C->getZExtValue())
9265 return N0.getOperand(0);
9266 }
9267
9268 // fold (sra (sra x, c1), c2) -> (sra x, (add c1, c2))
9269 // clamp (add c1, c2) to max shift.
9270 if (N0.getOpcode() == ISD::SRA) {
9271 SDLoc DL(N);
9272 EVT ShiftVT = N1.getValueType();
9273 EVT ShiftSVT = ShiftVT.getScalarType();
9274 SmallVector<SDValue, 16> ShiftValues;
9275
9276 auto SumOfShifts = [&](ConstantSDNode *LHS, ConstantSDNode *RHS) {
9277 APInt c1 = LHS->getAPIntValue();
9278 APInt c2 = RHS->getAPIntValue();
9279 zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
9280 APInt Sum = c1 + c2;
9281 unsigned ShiftSum =
9282 Sum.uge(OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue();
9283 ShiftValues.push_back(DAG.getConstant(ShiftSum, DL, ShiftSVT));
9284 return true;
9285 };
9286 if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), SumOfShifts)) {
9287 SDValue ShiftValue;
9288 if (N1.getOpcode() == ISD::BUILD_VECTOR)
9289 ShiftValue = DAG.getBuildVector(ShiftVT, DL, ShiftValues);
9290 else if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
9291 assert(ShiftValues.size() == 1 &&
9292 "Expected matchBinaryPredicate to return one element for "
9293 "SPLAT_VECTORs");
9294 ShiftValue = DAG.getSplatVector(ShiftVT, DL, ShiftValues[0]);
9295 } else
9296 ShiftValue = ShiftValues[0];
9297 return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0), ShiftValue);
9298 }
9299 }
9300
9301 // fold (sra (shl X, m), (sub result_size, n))
9302 // -> (sign_extend (trunc (shl X, (sub (sub result_size, n), m)))) for
9303 // result_size - n != m.
9304 // If truncate is free for the target sext(shl) is likely to result in better
9305 // code.
9306 if (N0.getOpcode() == ISD::SHL && N1C) {
9307 // Get the two constanst of the shifts, CN0 = m, CN = n.
9308 const ConstantSDNode *N01C = isConstOrConstSplat(N0.getOperand(1));
9309 if (N01C) {
9310 LLVMContext &Ctx = *DAG.getContext();
9311 // Determine what the truncate's result bitsize and type would be.
9312 EVT TruncVT = EVT::getIntegerVT(Ctx, OpSizeInBits - N1C->getZExtValue());
9313
9314 if (VT.isVector())
9315 TruncVT = EVT::getVectorVT(Ctx, TruncVT, VT.getVectorElementCount());
9316
9317 // Determine the residual right-shift amount.
9318 int ShiftAmt = N1C->getZExtValue() - N01C->getZExtValue();
9319
9320 // If the shift is not a no-op (in which case this should be just a sign
9321 // extend already), the truncated to type is legal, sign_extend is legal
9322 // on that type, and the truncate to that type is both legal and free,
9323 // perform the transform.
9324 if ((ShiftAmt > 0) &&
9325 TLI.isOperationLegalOrCustom(ISD::SIGN_EXTEND, TruncVT) &&
9326 TLI.isOperationLegalOrCustom(ISD::TRUNCATE, VT) &&
9327 TLI.isTruncateFree(VT, TruncVT)) {
9328 SDLoc DL(N);
9329 SDValue Amt = DAG.getConstant(ShiftAmt, DL,
9330 getShiftAmountTy(N0.getOperand(0).getValueType()));
9331 SDValue Shift = DAG.getNode(ISD::SRL, DL, VT,
9332 N0.getOperand(0), Amt);
9333 SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, TruncVT,
9334 Shift);
9335 return DAG.getNode(ISD::SIGN_EXTEND, DL,
9336 N->getValueType(0), Trunc);
9337 }
9338 }
9339 }
9340
9341 // We convert trunc/ext to opposing shifts in IR, but casts may be cheaper.
9342 // sra (add (shl X, N1C), AddC), N1C -->
9343 // sext (add (trunc X to (width - N1C)), AddC')
9344 // sra (sub AddC, (shl X, N1C)), N1C -->
9345 // sext (sub AddC1',(trunc X to (width - N1C)))
9346 if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::SUB) && N1C &&
9347 N0.hasOneUse()) {
9348 bool IsAdd = N0.getOpcode() == ISD::ADD;
9349 SDValue Shl = N0.getOperand(IsAdd ? 0 : 1);
9350 if (Shl.getOpcode() == ISD::SHL && Shl.getOperand(1) == N1 &&
9351 Shl.hasOneUse()) {
9352 // TODO: AddC does not need to be a splat.
9353 if (ConstantSDNode *AddC =
9354 isConstOrConstSplat(N0.getOperand(IsAdd ? 1 : 0))) {
9355 // Determine what the truncate's type would be and ask the target if
9356 // that is a free operation.
9357 LLVMContext &Ctx = *DAG.getContext();
9358 unsigned ShiftAmt = N1C->getZExtValue();
9359 EVT TruncVT = EVT::getIntegerVT(Ctx, OpSizeInBits - ShiftAmt);
9360 if (VT.isVector())
9361 TruncVT = EVT::getVectorVT(Ctx, TruncVT, VT.getVectorElementCount());
9362
9363 // TODO: The simple type check probably belongs in the default hook
9364 // implementation and/or target-specific overrides (because
9365 // non-simple types likely require masking when legalized), but
9366 // that restriction may conflict with other transforms.
9367 if (TruncVT.isSimple() && isTypeLegal(TruncVT) &&
9368 TLI.isTruncateFree(VT, TruncVT)) {
9369 SDLoc DL(N);
9370 SDValue Trunc = DAG.getZExtOrTrunc(Shl.getOperand(0), DL, TruncVT);
9371 SDValue ShiftC =
9372 DAG.getConstant(AddC->getAPIntValue().lshr(ShiftAmt).trunc(
9373 TruncVT.getScalarSizeInBits()),
9374 DL, TruncVT);
9375 SDValue Add;
9376 if (IsAdd)
9377 Add = DAG.getNode(ISD::ADD, DL, TruncVT, Trunc, ShiftC);
9378 else
9379 Add = DAG.getNode(ISD::SUB, DL, TruncVT, ShiftC, Trunc);
9380 return DAG.getSExtOrTrunc(Add, DL, VT);
9381 }
9382 }
9383 }
9384 }
9385
9386 // fold (sra x, (trunc (and y, c))) -> (sra x, (and (trunc y), (trunc c))).
9387 if (N1.getOpcode() == ISD::TRUNCATE &&
9388 N1.getOperand(0).getOpcode() == ISD::AND) {
9389 if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
9390 return DAG.getNode(ISD::SRA, SDLoc(N), VT, N0, NewOp1);
9391 }
9392
9393 // fold (sra (trunc (sra x, c1)), c2) -> (trunc (sra x, c1 + c2))
9394 // fold (sra (trunc (srl x, c1)), c2) -> (trunc (sra x, c1 + c2))
9395 // if c1 is equal to the number of bits the trunc removes
9396 // TODO - support non-uniform vector shift amounts.
9397 if (N0.getOpcode() == ISD::TRUNCATE &&
9398 (N0.getOperand(0).getOpcode() == ISD::SRL ||
9399 N0.getOperand(0).getOpcode() == ISD::SRA) &&
9400 N0.getOperand(0).hasOneUse() &&
9401 N0.getOperand(0).getOperand(1).hasOneUse() && N1C) {
9402 SDValue N0Op0 = N0.getOperand(0);
9403 if (ConstantSDNode *LargeShift = isConstOrConstSplat(N0Op0.getOperand(1))) {
9404 EVT LargeVT = N0Op0.getValueType();
9405 unsigned TruncBits = LargeVT.getScalarSizeInBits() - OpSizeInBits;
9406 if (LargeShift->getAPIntValue() == TruncBits) {
9407 SDLoc DL(N);
9408 EVT LargeShiftVT = getShiftAmountTy(LargeVT);
9409 SDValue Amt = DAG.getZExtOrTrunc(N1, DL, LargeShiftVT);
9410 Amt = DAG.getNode(ISD::ADD, DL, LargeShiftVT, Amt,
9411 DAG.getConstant(TruncBits, DL, LargeShiftVT));
9412 SDValue SRA =
9413 DAG.getNode(ISD::SRA, DL, LargeVT, N0Op0.getOperand(0), Amt);
9414 return DAG.getNode(ISD::TRUNCATE, DL, VT, SRA);
9415 }
9416 }
9417 }
9418
9419 // Simplify, based on bits shifted out of the LHS.
9420 if (SimplifyDemandedBits(SDValue(N, 0)))
9421 return SDValue(N, 0);
9422
9423 // If the sign bit is known to be zero, switch this to a SRL.
9424 if (DAG.SignBitIsZero(N0))
9425 return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0, N1);
9426
9427 if (N1C && !N1C->isOpaque())
9428 if (SDValue NewSRA = visitShiftByConstant(N))
9429 return NewSRA;
9430
9431 // Try to transform this shift into a multiply-high if
9432 // it matches the appropriate pattern detected in combineShiftToMULH.
9433 if (SDValue MULH = combineShiftToMULH(N, DAG, TLI))
9434 return MULH;
9435
9436 // Attempt to convert a sra of a load into a narrower sign-extending load.
9437 if (SDValue NarrowLoad = reduceLoadWidth(N))
9438 return NarrowLoad;
9439
9440 return SDValue();
9441 }
9442
visitSRL(SDNode * N)9443 SDValue DAGCombiner::visitSRL(SDNode *N) {
9444 SDValue N0 = N->getOperand(0);
9445 SDValue N1 = N->getOperand(1);
9446 if (SDValue V = DAG.simplifyShift(N0, N1))
9447 return V;
9448
9449 EVT VT = N0.getValueType();
9450 EVT ShiftVT = N1.getValueType();
9451 unsigned OpSizeInBits = VT.getScalarSizeInBits();
9452
9453 // fold (srl c1, c2) -> c1 >>u c2
9454 if (SDValue C = DAG.FoldConstantArithmetic(ISD::SRL, SDLoc(N), VT, {N0, N1}))
9455 return C;
9456
9457 // fold vector ops
9458 if (VT.isVector())
9459 if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N)))
9460 return FoldedVOp;
9461
9462 if (SDValue NewSel = foldBinOpIntoSelect(N))
9463 return NewSel;
9464
9465 // if (srl x, c) is known to be zero, return 0
9466 ConstantSDNode *N1C = isConstOrConstSplat(N1);
9467 if (N1C &&
9468 DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnes(OpSizeInBits)))
9469 return DAG.getConstant(0, SDLoc(N), VT);
9470
9471 // fold (srl (srl x, c1), c2) -> 0 or (srl x, (add c1, c2))
9472 if (N0.getOpcode() == ISD::SRL) {
9473 auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
9474 ConstantSDNode *RHS) {
9475 APInt c1 = LHS->getAPIntValue();
9476 APInt c2 = RHS->getAPIntValue();
9477 zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
9478 return (c1 + c2).uge(OpSizeInBits);
9479 };
9480 if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
9481 return DAG.getConstant(0, SDLoc(N), VT);
9482
9483 auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
9484 ConstantSDNode *RHS) {
9485 APInt c1 = LHS->getAPIntValue();
9486 APInt c2 = RHS->getAPIntValue();
9487 zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
9488 return (c1 + c2).ult(OpSizeInBits);
9489 };
9490 if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
9491 SDLoc DL(N);
9492 SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
9493 return DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Sum);
9494 }
9495 }
9496
9497 if (N1C && N0.getOpcode() == ISD::TRUNCATE &&
9498 N0.getOperand(0).getOpcode() == ISD::SRL) {
9499 SDValue InnerShift = N0.getOperand(0);
9500 // TODO - support non-uniform vector shift amounts.
9501 if (auto *N001C = isConstOrConstSplat(InnerShift.getOperand(1))) {
9502 uint64_t c1 = N001C->getZExtValue();
9503 uint64_t c2 = N1C->getZExtValue();
9504 EVT InnerShiftVT = InnerShift.getValueType();
9505 EVT ShiftAmtVT = InnerShift.getOperand(1).getValueType();
9506 uint64_t InnerShiftSize = InnerShiftVT.getScalarSizeInBits();
9507 // srl (trunc (srl x, c1)), c2 --> 0 or (trunc (srl x, (add c1, c2)))
9508 // This is only valid if the OpSizeInBits + c1 = size of inner shift.
9509 if (c1 + OpSizeInBits == InnerShiftSize) {
9510 SDLoc DL(N);
9511 if (c1 + c2 >= InnerShiftSize)
9512 return DAG.getConstant(0, DL, VT);
9513 SDValue NewShiftAmt = DAG.getConstant(c1 + c2, DL, ShiftAmtVT);
9514 SDValue NewShift = DAG.getNode(ISD::SRL, DL, InnerShiftVT,
9515 InnerShift.getOperand(0), NewShiftAmt);
9516 return DAG.getNode(ISD::TRUNCATE, DL, VT, NewShift);
9517 }
9518 // In the more general case, we can clear the high bits after the shift:
9519 // srl (trunc (srl x, c1)), c2 --> trunc (and (srl x, (c1+c2)), Mask)
9520 if (N0.hasOneUse() && InnerShift.hasOneUse() &&
9521 c1 + c2 < InnerShiftSize) {
9522 SDLoc DL(N);
9523 SDValue NewShiftAmt = DAG.getConstant(c1 + c2, DL, ShiftAmtVT);
9524 SDValue NewShift = DAG.getNode(ISD::SRL, DL, InnerShiftVT,
9525 InnerShift.getOperand(0), NewShiftAmt);
9526 SDValue Mask = DAG.getConstant(APInt::getLowBitsSet(InnerShiftSize,
9527 OpSizeInBits - c2),
9528 DL, InnerShiftVT);
9529 SDValue And = DAG.getNode(ISD::AND, DL, InnerShiftVT, NewShift, Mask);
9530 return DAG.getNode(ISD::TRUNCATE, DL, VT, And);
9531 }
9532 }
9533 }
9534
9535 // fold (srl (shl x, c1), c2) -> (and (shl x, (sub c1, c2), MASK) or
9536 // (and (srl x, (sub c2, c1), MASK)
9537 if (N0.getOpcode() == ISD::SHL &&
9538 (N0.getOperand(1) == N1 || N0->hasOneUse()) &&
9539 TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
9540 auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
9541 ConstantSDNode *RHS) {
9542 const APInt &LHSC = LHS->getAPIntValue();
9543 const APInt &RHSC = RHS->getAPIntValue();
9544 return LHSC.ult(OpSizeInBits) && RHSC.ult(OpSizeInBits) &&
9545 LHSC.getZExtValue() <= RHSC.getZExtValue();
9546 };
9547 if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
9548 /*AllowUndefs*/ false,
9549 /*AllowTypeMismatch*/ true)) {
9550 SDLoc DL(N);
9551 SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
9552 SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1);
9553 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
9554 Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, N01);
9555 Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, Diff);
9556 SDValue Shift = DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff);
9557 return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
9558 }
9559 if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount,
9560 /*AllowUndefs*/ false,
9561 /*AllowTypeMismatch*/ true)) {
9562 SDLoc DL(N);
9563 SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
9564 SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01);
9565 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
9566 Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, N1);
9567 SDValue Shift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Diff);
9568 return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
9569 }
9570 }
9571
9572 // fold (srl (anyextend x), c) -> (and (anyextend (srl x, c)), mask)
9573 // TODO - support non-uniform vector shift amounts.
9574 if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
9575 // Shifting in all undef bits?
9576 EVT SmallVT = N0.getOperand(0).getValueType();
9577 unsigned BitSize = SmallVT.getScalarSizeInBits();
9578 if (N1C->getAPIntValue().uge(BitSize))
9579 return DAG.getUNDEF(VT);
9580
9581 if (!LegalTypes || TLI.isTypeDesirableForOp(ISD::SRL, SmallVT)) {
9582 uint64_t ShiftAmt = N1C->getZExtValue();
9583 SDLoc DL0(N0);
9584 SDValue SmallShift = DAG.getNode(ISD::SRL, DL0, SmallVT,
9585 N0.getOperand(0),
9586 DAG.getConstant(ShiftAmt, DL0,
9587 getShiftAmountTy(SmallVT)));
9588 AddToWorklist(SmallShift.getNode());
9589 APInt Mask = APInt::getLowBitsSet(OpSizeInBits, OpSizeInBits - ShiftAmt);
9590 SDLoc DL(N);
9591 return DAG.getNode(ISD::AND, DL, VT,
9592 DAG.getNode(ISD::ANY_EXTEND, DL, VT, SmallShift),
9593 DAG.getConstant(Mask, DL, VT));
9594 }
9595 }
9596
9597 // fold (srl (sra X, Y), 31) -> (srl X, 31). This srl only looks at the sign
9598 // bit, which is unmodified by sra.
9599 if (N1C && N1C->getAPIntValue() == (OpSizeInBits - 1)) {
9600 if (N0.getOpcode() == ISD::SRA)
9601 return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0.getOperand(0), N1);
9602 }
9603
9604 // fold (srl (ctlz x), "5") -> x iff x has one bit set (the low bit).
9605 if (N1C && N0.getOpcode() == ISD::CTLZ &&
9606 N1C->getAPIntValue() == Log2_32(OpSizeInBits)) {
9607 KnownBits Known = DAG.computeKnownBits(N0.getOperand(0));
9608
9609 // If any of the input bits are KnownOne, then the input couldn't be all
9610 // zeros, thus the result of the srl will always be zero.
9611 if (Known.One.getBoolValue()) return DAG.getConstant(0, SDLoc(N0), VT);
9612
9613 // If all of the bits input the to ctlz node are known to be zero, then
9614 // the result of the ctlz is "32" and the result of the shift is one.
9615 APInt UnknownBits = ~Known.Zero;
9616 if (UnknownBits == 0) return DAG.getConstant(1, SDLoc(N0), VT);
9617
9618 // Otherwise, check to see if there is exactly one bit input to the ctlz.
9619 if (UnknownBits.isPowerOf2()) {
9620 // Okay, we know that only that the single bit specified by UnknownBits
9621 // could be set on input to the CTLZ node. If this bit is set, the SRL
9622 // will return 0, if it is clear, it returns 1. Change the CTLZ/SRL pair
9623 // to an SRL/XOR pair, which is likely to simplify more.
9624 unsigned ShAmt = UnknownBits.countTrailingZeros();
9625 SDValue Op = N0.getOperand(0);
9626
9627 if (ShAmt) {
9628 SDLoc DL(N0);
9629 Op = DAG.getNode(ISD::SRL, DL, VT, Op,
9630 DAG.getConstant(ShAmt, DL,
9631 getShiftAmountTy(Op.getValueType())));
9632 AddToWorklist(Op.getNode());
9633 }
9634
9635 SDLoc DL(N);
9636 return DAG.getNode(ISD::XOR, DL, VT,
9637 Op, DAG.getConstant(1, DL, VT));
9638 }
9639 }
9640
9641 // fold (srl x, (trunc (and y, c))) -> (srl x, (and (trunc y), (trunc c))).
9642 if (N1.getOpcode() == ISD::TRUNCATE &&
9643 N1.getOperand(0).getOpcode() == ISD::AND) {
9644 if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
9645 return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0, NewOp1);
9646 }
9647
9648 // fold operands of srl based on knowledge that the low bits are not
9649 // demanded.
9650 if (SimplifyDemandedBits(SDValue(N, 0)))
9651 return SDValue(N, 0);
9652
9653 if (N1C && !N1C->isOpaque())
9654 if (SDValue NewSRL = visitShiftByConstant(N))
9655 return NewSRL;
9656
9657 // Attempt to convert a srl of a load into a narrower zero-extending load.
9658 if (SDValue NarrowLoad = reduceLoadWidth(N))
9659 return NarrowLoad;
9660
9661 // Here is a common situation. We want to optimize:
9662 //
9663 // %a = ...
9664 // %b = and i32 %a, 2
9665 // %c = srl i32 %b, 1
9666 // brcond i32 %c ...
9667 //
9668 // into
9669 //
9670 // %a = ...
9671 // %b = and %a, 2
9672 // %c = setcc eq %b, 0
9673 // brcond %c ...
9674 //
9675 // However when after the source operand of SRL is optimized into AND, the SRL
9676 // itself may not be optimized further. Look for it and add the BRCOND into
9677 // the worklist.
9678 if (N->hasOneUse()) {
9679 SDNode *Use = *N->use_begin();
9680 if (Use->getOpcode() == ISD::BRCOND)
9681 AddToWorklist(Use);
9682 else if (Use->getOpcode() == ISD::TRUNCATE && Use->hasOneUse()) {
9683 // Also look pass the truncate.
9684 Use = *Use->use_begin();
9685 if (Use->getOpcode() == ISD::BRCOND)
9686 AddToWorklist(Use);
9687 }
9688 }
9689
9690 // Try to transform this shift into a multiply-high if
9691 // it matches the appropriate pattern detected in combineShiftToMULH.
9692 if (SDValue MULH = combineShiftToMULH(N, DAG, TLI))
9693 return MULH;
9694
9695 return SDValue();
9696 }
9697
visitFunnelShift(SDNode * N)9698 SDValue DAGCombiner::visitFunnelShift(SDNode *N) {
9699 EVT VT = N->getValueType(0);
9700 SDValue N0 = N->getOperand(0);
9701 SDValue N1 = N->getOperand(1);
9702 SDValue N2 = N->getOperand(2);
9703 bool IsFSHL = N->getOpcode() == ISD::FSHL;
9704 unsigned BitWidth = VT.getScalarSizeInBits();
9705
9706 // fold (fshl N0, N1, 0) -> N0
9707 // fold (fshr N0, N1, 0) -> N1
9708 if (isPowerOf2_32(BitWidth))
9709 if (DAG.MaskedValueIsZero(
9710 N2, APInt(N2.getScalarValueSizeInBits(), BitWidth - 1)))
9711 return IsFSHL ? N0 : N1;
9712
9713 auto IsUndefOrZero = [](SDValue V) {
9714 return V.isUndef() || isNullOrNullSplat(V, /*AllowUndefs*/ true);
9715 };
9716
9717 // TODO - support non-uniform vector shift amounts.
9718 if (ConstantSDNode *Cst = isConstOrConstSplat(N2)) {
9719 EVT ShAmtTy = N2.getValueType();
9720
9721 // fold (fsh* N0, N1, c) -> (fsh* N0, N1, c % BitWidth)
9722 if (Cst->getAPIntValue().uge(BitWidth)) {
9723 uint64_t RotAmt = Cst->getAPIntValue().urem(BitWidth);
9724 return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N0, N1,
9725 DAG.getConstant(RotAmt, SDLoc(N), ShAmtTy));
9726 }
9727
9728 unsigned ShAmt = Cst->getZExtValue();
9729 if (ShAmt == 0)
9730 return IsFSHL ? N0 : N1;
9731
9732 // fold fshl(undef_or_zero, N1, C) -> lshr(N1, BW-C)
9733 // fold fshr(undef_or_zero, N1, C) -> lshr(N1, C)
9734 // fold fshl(N0, undef_or_zero, C) -> shl(N0, C)
9735 // fold fshr(N0, undef_or_zero, C) -> shl(N0, BW-C)
9736 if (IsUndefOrZero(N0))
9737 return DAG.getNode(ISD::SRL, SDLoc(N), VT, N1,
9738 DAG.getConstant(IsFSHL ? BitWidth - ShAmt : ShAmt,
9739 SDLoc(N), ShAmtTy));
9740 if (IsUndefOrZero(N1))
9741 return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0,
9742 DAG.getConstant(IsFSHL ? ShAmt : BitWidth - ShAmt,
9743 SDLoc(N), ShAmtTy));
9744
9745 // fold (fshl ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
9746 // fold (fshr ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
9747 // TODO - bigendian support once we have test coverage.
9748 // TODO - can we merge this with CombineConseutiveLoads/MatchLoadCombine?
9749 // TODO - permit LHS EXTLOAD if extensions are shifted out.
9750 if ((BitWidth % 8) == 0 && (ShAmt % 8) == 0 && !VT.isVector() &&
9751 !DAG.getDataLayout().isBigEndian()) {
9752 auto *LHS = dyn_cast<LoadSDNode>(N0);
9753 auto *RHS = dyn_cast<LoadSDNode>(N1);
9754 if (LHS && RHS && LHS->isSimple() && RHS->isSimple() &&
9755 LHS->getAddressSpace() == RHS->getAddressSpace() &&
9756 (LHS->hasOneUse() || RHS->hasOneUse()) && ISD::isNON_EXTLoad(RHS) &&
9757 ISD::isNON_EXTLoad(LHS)) {
9758 if (DAG.areNonVolatileConsecutiveLoads(LHS, RHS, BitWidth / 8, 1)) {
9759 SDLoc DL(RHS);
9760 uint64_t PtrOff =
9761 IsFSHL ? (((BitWidth - ShAmt) % BitWidth) / 8) : (ShAmt / 8);
9762 Align NewAlign = commonAlignment(RHS->getAlign(), PtrOff);
9763 bool Fast = false;
9764 if (TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
9765 RHS->getAddressSpace(), NewAlign,
9766 RHS->getMemOperand()->getFlags(), &Fast) &&
9767 Fast) {
9768 SDValue NewPtr = DAG.getMemBasePlusOffset(
9769 RHS->getBasePtr(), TypeSize::Fixed(PtrOff), DL);
9770 AddToWorklist(NewPtr.getNode());
9771 SDValue Load = DAG.getLoad(
9772 VT, DL, RHS->getChain(), NewPtr,
9773 RHS->getPointerInfo().getWithOffset(PtrOff), NewAlign,
9774 RHS->getMemOperand()->getFlags(), RHS->getAAInfo());
9775 // Replace the old load's chain with the new load's chain.
9776 WorklistRemover DeadNodes(*this);
9777 DAG.ReplaceAllUsesOfValueWith(N1.getValue(1), Load.getValue(1));
9778 return Load;
9779 }
9780 }
9781 }
9782 }
9783 }
9784
9785 // fold fshr(undef_or_zero, N1, N2) -> lshr(N1, N2)
9786 // fold fshl(N0, undef_or_zero, N2) -> shl(N0, N2)
9787 // iff We know the shift amount is in range.
9788 // TODO: when is it worth doing SUB(BW, N2) as well?
9789 if (isPowerOf2_32(BitWidth)) {
9790 APInt ModuloBits(N2.getScalarValueSizeInBits(), BitWidth - 1);
9791 if (IsUndefOrZero(N0) && !IsFSHL && DAG.MaskedValueIsZero(N2, ~ModuloBits))
9792 return DAG.getNode(ISD::SRL, SDLoc(N), VT, N1, N2);
9793 if (IsUndefOrZero(N1) && IsFSHL && DAG.MaskedValueIsZero(N2, ~ModuloBits))
9794 return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, N2);
9795 }
9796
9797 // fold (fshl N0, N0, N2) -> (rotl N0, N2)
9798 // fold (fshr N0, N0, N2) -> (rotr N0, N2)
9799 // TODO: Investigate flipping this rotate if only one is legal, if funnel shift
9800 // is legal as well we might be better off avoiding non-constant (BW - N2).
9801 unsigned RotOpc = IsFSHL ? ISD::ROTL : ISD::ROTR;
9802 if (N0 == N1 && hasOperation(RotOpc, VT))
9803 return DAG.getNode(RotOpc, SDLoc(N), VT, N0, N2);
9804
9805 // Simplify, based on bits shifted out of N0/N1.
9806 if (SimplifyDemandedBits(SDValue(N, 0)))
9807 return SDValue(N, 0);
9808
9809 return SDValue();
9810 }
9811
visitSHLSAT(SDNode * N)9812 SDValue DAGCombiner::visitSHLSAT(SDNode *N) {
9813 SDValue N0 = N->getOperand(0);
9814 SDValue N1 = N->getOperand(1);
9815 if (SDValue V = DAG.simplifyShift(N0, N1))
9816 return V;
9817
9818 EVT VT = N0.getValueType();
9819
9820 // fold (*shlsat c1, c2) -> c1<<c2
9821 if (SDValue C =
9822 DAG.FoldConstantArithmetic(N->getOpcode(), SDLoc(N), VT, {N0, N1}))
9823 return C;
9824
9825 ConstantSDNode *N1C = isConstOrConstSplat(N1);
9826
9827 if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::SHL, VT)) {
9828 // fold (sshlsat x, c) -> (shl x, c)
9829 if (N->getOpcode() == ISD::SSHLSAT && N1C &&
9830 N1C->getAPIntValue().ult(DAG.ComputeNumSignBits(N0)))
9831 return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, N1);
9832
9833 // fold (ushlsat x, c) -> (shl x, c)
9834 if (N->getOpcode() == ISD::USHLSAT && N1C &&
9835 N1C->getAPIntValue().ule(
9836 DAG.computeKnownBits(N0).countMinLeadingZeros()))
9837 return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, N1);
9838 }
9839
9840 return SDValue();
9841 }
9842
9843 // Given a ABS node, detect the following pattern:
9844 // (ABS (SUB (EXTEND a), (EXTEND b))).
9845 // Generates UABD/SABD instruction.
combineABSToABD(SDNode * N,SelectionDAG & DAG,const TargetLowering & TLI)9846 static SDValue combineABSToABD(SDNode *N, SelectionDAG &DAG,
9847 const TargetLowering &TLI) {
9848 SDValue AbsOp1 = N->getOperand(0);
9849 SDValue Op0, Op1;
9850
9851 if (AbsOp1.getOpcode() != ISD::SUB)
9852 return SDValue();
9853
9854 Op0 = AbsOp1.getOperand(0);
9855 Op1 = AbsOp1.getOperand(1);
9856
9857 unsigned Opc0 = Op0.getOpcode();
9858 // Check if the operands of the sub are (zero|sign)-extended.
9859 if (Opc0 != Op1.getOpcode() ||
9860 (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND))
9861 return SDValue();
9862
9863 EVT VT = N->getValueType(0);
9864 EVT VT1 = Op0.getOperand(0).getValueType();
9865 EVT VT2 = Op1.getOperand(0).getValueType();
9866 unsigned ABDOpcode = (Opc0 == ISD::SIGN_EXTEND) ? ISD::ABDS : ISD::ABDU;
9867
9868 // fold abs(sext(x) - sext(y)) -> zext(abds(x, y))
9869 // fold abs(zext(x) - zext(y)) -> zext(abdu(x, y))
9870 // NOTE: Extensions must be equivalent.
9871 if (VT1 == VT2 && TLI.isOperationLegalOrCustom(ABDOpcode, VT1)) {
9872 Op0 = Op0.getOperand(0);
9873 Op1 = Op1.getOperand(0);
9874 SDValue ABD = DAG.getNode(ABDOpcode, SDLoc(N), VT1, Op0, Op1);
9875 return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, ABD);
9876 }
9877
9878 // fold abs(sext(x) - sext(y)) -> abds(sext(x), sext(y))
9879 // fold abs(zext(x) - zext(y)) -> abdu(zext(x), zext(y))
9880 if (TLI.isOperationLegalOrCustom(ABDOpcode, VT))
9881 return DAG.getNode(ABDOpcode, SDLoc(N), VT, Op0, Op1);
9882
9883 return SDValue();
9884 }
9885
visitABS(SDNode * N)9886 SDValue DAGCombiner::visitABS(SDNode *N) {
9887 SDValue N0 = N->getOperand(0);
9888 EVT VT = N->getValueType(0);
9889
9890 // fold (abs c1) -> c2
9891 if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
9892 return DAG.getNode(ISD::ABS, SDLoc(N), VT, N0);
9893 // fold (abs (abs x)) -> (abs x)
9894 if (N0.getOpcode() == ISD::ABS)
9895 return N0;
9896 // fold (abs x) -> x iff not-negative
9897 if (DAG.SignBitIsZero(N0))
9898 return N0;
9899
9900 if (SDValue ABD = combineABSToABD(N, DAG, TLI))
9901 return ABD;
9902
9903 return SDValue();
9904 }
9905
SDValue DAGCombiner::visitBSWAP(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  // fold (bswap c1) -> c2
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
    return DAG.getNode(ISD::BSWAP, DL, VT, N0);
  // fold (bswap (bswap x)) -> x -- two byte swaps cancel out.
  if (N0.getOpcode() == ISD::BSWAP)
    return N0.getOperand(0);

  // Canonicalize bswap(bitreverse(x)) -> bitreverse(bswap(x)). If bitreverse
  // isn't supported, it will be expanded to bswap followed by a manual reversal
  // of bits in each byte. By placing bswaps before bitreverse, we can remove
  // the two bswaps if the bitreverse gets expanded.
  if (N0.getOpcode() == ISD::BITREVERSE && N0.hasOneUse()) {
    SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT, N0.getOperand(0));
    return DAG.getNode(ISD::BITREVERSE, DL, VT, BSwap);
  }

  // fold (bswap shl(x,c)) -> (zext(bswap(trunc(shl(x,sub(c,bw/2))))))
  // iff x >= bw/2 (i.e. lower half is known zero)
  unsigned BW = VT.getScalarSizeInBits();
  if (BW >= 32 && N0.getOpcode() == ISD::SHL && N0.hasOneUse()) {
    auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1));
    EVT HalfVT = EVT::getIntegerVT(*DAG.getContext(), BW / 2);
    // Preconditions: an in-range constant shift of at least BW/2 (so the low
    // half of the shifted value is all-zero), shift amount a multiple of 16
    // (NOTE(review): presumably to keep the residual pre-shift byte-aligned
    // for the narrow bswap -- confirm), a legal half-width type with a free
    // truncate, and a half-width BSWAP we are allowed to form.
    if (ShAmt && ShAmt->getAPIntValue().ult(BW) &&
        ShAmt->getZExtValue() >= (BW / 2) &&
        (ShAmt->getZExtValue() % 16) == 0 && TLI.isTypeLegal(HalfVT) &&
        TLI.isTruncateFree(VT, HalfVT) &&
        (!LegalOperations || hasOperation(ISD::BSWAP, HalfVT))) {
      SDValue Res = N0.getOperand(0);
      // Keep any shift beyond BW/2 as a smaller pre-shift of the source.
      if (uint64_t NewShAmt = (ShAmt->getZExtValue() - (BW / 2)))
        Res = DAG.getNode(ISD::SHL, DL, VT, Res,
                          DAG.getConstant(NewShAmt, DL, getShiftAmountTy(VT)));
      Res = DAG.getZExtOrTrunc(Res, DL, HalfVT);
      Res = DAG.getNode(ISD::BSWAP, DL, HalfVT, Res);
      return DAG.getZExtOrTrunc(Res, DL, VT);
    }
  }

  // Try to canonicalize bswap-of-logical-shift-by-8-bit-multiple as
  // inverse-shift-of-bswap:
  // bswap (X u<< C) --> (bswap X) u>> C
  // bswap (X u>> C) --> (bswap X) u<< C
  if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
      N0.hasOneUse()) {
    auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1));
    // The shift amount must be a whole number of bytes so the swap commutes
    // with the shift (in the inverse direction).
    if (ShAmt && ShAmt->getAPIntValue().ult(BW) &&
        ShAmt->getZExtValue() % 8 == 0) {
      SDValue NewSwap = DAG.getNode(ISD::BSWAP, DL, VT, N0.getOperand(0));
      unsigned InverseShift = N0.getOpcode() == ISD::SHL ? ISD::SRL : ISD::SHL;
      return DAG.getNode(InverseShift, DL, VT, NewSwap, N0.getOperand(1));
    }
  }

  return SDValue();
}
9965
visitBITREVERSE(SDNode * N)9966 SDValue DAGCombiner::visitBITREVERSE(SDNode *N) {
9967 SDValue N0 = N->getOperand(0);
9968 EVT VT = N->getValueType(0);
9969
9970 // fold (bitreverse c1) -> c2
9971 if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
9972 return DAG.getNode(ISD::BITREVERSE, SDLoc(N), VT, N0);
9973 // fold (bitreverse (bitreverse x)) -> x
9974 if (N0.getOpcode() == ISD::BITREVERSE)
9975 return N0.getOperand(0);
9976 return SDValue();
9977 }
9978
visitCTLZ(SDNode * N)9979 SDValue DAGCombiner::visitCTLZ(SDNode *N) {
9980 SDValue N0 = N->getOperand(0);
9981 EVT VT = N->getValueType(0);
9982
9983 // fold (ctlz c1) -> c2
9984 if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
9985 return DAG.getNode(ISD::CTLZ, SDLoc(N), VT, N0);
9986
9987 // If the value is known never to be zero, switch to the undef version.
9988 if (!LegalOperations || TLI.isOperationLegal(ISD::CTLZ_ZERO_UNDEF, VT)) {
9989 if (DAG.isKnownNeverZero(N0))
9990 return DAG.getNode(ISD::CTLZ_ZERO_UNDEF, SDLoc(N), VT, N0);
9991 }
9992
9993 return SDValue();
9994 }
9995
visitCTLZ_ZERO_UNDEF(SDNode * N)9996 SDValue DAGCombiner::visitCTLZ_ZERO_UNDEF(SDNode *N) {
9997 SDValue N0 = N->getOperand(0);
9998 EVT VT = N->getValueType(0);
9999
10000 // fold (ctlz_zero_undef c1) -> c2
10001 if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
10002 return DAG.getNode(ISD::CTLZ_ZERO_UNDEF, SDLoc(N), VT, N0);
10003 return SDValue();
10004 }
10005
visitCTTZ(SDNode * N)10006 SDValue DAGCombiner::visitCTTZ(SDNode *N) {
10007 SDValue N0 = N->getOperand(0);
10008 EVT VT = N->getValueType(0);
10009
10010 // fold (cttz c1) -> c2
10011 if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
10012 return DAG.getNode(ISD::CTTZ, SDLoc(N), VT, N0);
10013
10014 // If the value is known never to be zero, switch to the undef version.
10015 if (!LegalOperations || TLI.isOperationLegal(ISD::CTTZ_ZERO_UNDEF, VT)) {
10016 if (DAG.isKnownNeverZero(N0))
10017 return DAG.getNode(ISD::CTTZ_ZERO_UNDEF, SDLoc(N), VT, N0);
10018 }
10019
10020 return SDValue();
10021 }
10022
visitCTTZ_ZERO_UNDEF(SDNode * N)10023 SDValue DAGCombiner::visitCTTZ_ZERO_UNDEF(SDNode *N) {
10024 SDValue N0 = N->getOperand(0);
10025 EVT VT = N->getValueType(0);
10026
10027 // fold (cttz_zero_undef c1) -> c2
10028 if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
10029 return DAG.getNode(ISD::CTTZ_ZERO_UNDEF, SDLoc(N), VT, N0);
10030 return SDValue();
10031 }
10032
visitCTPOP(SDNode * N)10033 SDValue DAGCombiner::visitCTPOP(SDNode *N) {
10034 SDValue N0 = N->getOperand(0);
10035 EVT VT = N->getValueType(0);
10036
10037 // fold (ctpop c1) -> c2
10038 if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
10039 return DAG.getNode(ISD::CTPOP, SDLoc(N), VT, N0);
10040 return SDValue();
10041 }
10042
10043 // FIXME: This should be checking for no signed zeros on individual operands, as
10044 // well as no nans.
isLegalToCombineMinNumMaxNum(SelectionDAG & DAG,SDValue LHS,SDValue RHS,const TargetLowering & TLI)10045 static bool isLegalToCombineMinNumMaxNum(SelectionDAG &DAG, SDValue LHS,
10046 SDValue RHS,
10047 const TargetLowering &TLI) {
10048 const TargetOptions &Options = DAG.getTarget().Options;
10049 EVT VT = LHS.getValueType();
10050
10051 return Options.NoSignedZerosFPMath && VT.isFloatingPoint() &&
10052 TLI.isProfitableToCombineMinNumMaxNum(VT) &&
10053 DAG.isKnownNeverNaN(LHS) && DAG.isKnownNeverNaN(RHS);
10054 }
10055
10056 /// Generate Min/Max node
combineMinNumMaxNum(const SDLoc & DL,EVT VT,SDValue LHS,SDValue RHS,SDValue True,SDValue False,ISD::CondCode CC,const TargetLowering & TLI,SelectionDAG & DAG)10057 static SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
10058 SDValue RHS, SDValue True, SDValue False,
10059 ISD::CondCode CC, const TargetLowering &TLI,
10060 SelectionDAG &DAG) {
10061 if (!(LHS == True && RHS == False) && !(LHS == False && RHS == True))
10062 return SDValue();
10063
10064 EVT TransformVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
10065 switch (CC) {
10066 case ISD::SETOLT:
10067 case ISD::SETOLE:
10068 case ISD::SETLT:
10069 case ISD::SETLE:
10070 case ISD::SETULT:
10071 case ISD::SETULE: {
10072 // Since it's known never nan to get here already, either fminnum or
10073 // fminnum_ieee are OK. Try the ieee version first, since it's fminnum is
10074 // expanded in terms of it.
10075 unsigned IEEEOpcode = (LHS == True) ? ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE;
10076 if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT))
10077 return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS);
10078
10079 unsigned Opcode = (LHS == True) ? ISD::FMINNUM : ISD::FMAXNUM;
10080 if (TLI.isOperationLegalOrCustom(Opcode, TransformVT))
10081 return DAG.getNode(Opcode, DL, VT, LHS, RHS);
10082 return SDValue();
10083 }
10084 case ISD::SETOGT:
10085 case ISD::SETOGE:
10086 case ISD::SETGT:
10087 case ISD::SETGE:
10088 case ISD::SETUGT:
10089 case ISD::SETUGE: {
10090 unsigned IEEEOpcode = (LHS == True) ? ISD::FMAXNUM_IEEE : ISD::FMINNUM_IEEE;
10091 if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT))
10092 return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS);
10093
10094 unsigned Opcode = (LHS == True) ? ISD::FMAXNUM : ISD::FMINNUM;
10095 if (TLI.isOperationLegalOrCustom(Opcode, TransformVT))
10096 return DAG.getNode(Opcode, DL, VT, LHS, RHS);
10097 return SDValue();
10098 }
10099 default:
10100 return SDValue();
10101 }
10102 }
10103
10104 /// If a (v)select has a condition value that is a sign-bit test, try to smear
10105 /// the condition operand sign-bit across the value width and use it as a mask.
foldSelectOfConstantsUsingSra(SDNode * N,SelectionDAG & DAG)10106 static SDValue foldSelectOfConstantsUsingSra(SDNode *N, SelectionDAG &DAG) {
10107 SDValue Cond = N->getOperand(0);
10108 SDValue C1 = N->getOperand(1);
10109 SDValue C2 = N->getOperand(2);
10110 if (!isConstantOrConstantVector(C1) || !isConstantOrConstantVector(C2))
10111 return SDValue();
10112
10113 EVT VT = N->getValueType(0);
10114 if (Cond.getOpcode() != ISD::SETCC || !Cond.hasOneUse() ||
10115 VT != Cond.getOperand(0).getValueType())
10116 return SDValue();
10117
10118 // The inverted-condition + commuted-select variants of these patterns are
10119 // canonicalized to these forms in IR.
10120 SDValue X = Cond.getOperand(0);
10121 SDValue CondC = Cond.getOperand(1);
10122 ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
10123 if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(CondC) &&
10124 isAllOnesOrAllOnesSplat(C2)) {
10125 // i32 X > -1 ? C1 : -1 --> (X >>s 31) | C1
10126 SDLoc DL(N);
10127 SDValue ShAmtC = DAG.getConstant(X.getScalarValueSizeInBits() - 1, DL, VT);
10128 SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, X, ShAmtC);
10129 return DAG.getNode(ISD::OR, DL, VT, Sra, C1);
10130 }
10131 if (CC == ISD::SETLT && isNullOrNullSplat(CondC) && isNullOrNullSplat(C2)) {
10132 // i8 X < 0 ? C1 : 0 --> (X >>s 7) & C1
10133 SDLoc DL(N);
10134 SDValue ShAmtC = DAG.getConstant(X.getScalarValueSizeInBits() - 1, DL, VT);
10135 SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, X, ShAmtC);
10136 return DAG.getNode(ISD::AND, DL, VT, Sra, C1);
10137 }
10138 return SDValue();
10139 }
10140
/// Try to fold "select Cond, C1, C2", where both arms are scalar integer
/// constants, into boolean math on the condition (not / zext / sext / add /
/// shl / xor). Returns an empty SDValue if no fold applies.
SDValue DAGCombiner::foldSelectOfConstants(SDNode *N) {
  SDValue Cond = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue N2 = N->getOperand(2);
  EVT VT = N->getValueType(0);
  EVT CondVT = Cond.getValueType();
  SDLoc DL(N);

  // Only integer-typed selects are handled here.
  if (!VT.isInteger())
    return SDValue();

  // Both arms must be scalar integer constants.
  auto *C1 = dyn_cast<ConstantSDNode>(N1);
  auto *C2 = dyn_cast<ConstantSDNode>(N2);
  if (!C1 || !C2)
    return SDValue();

  // Only do this before legalization to avoid conflicting with target-specific
  // transforms in the other direction (create a select from a zext/sext). There
  // is also a target-independent combine here in DAGCombiner in the other
  // direction for (select Cond, -1, 0) when the condition is not i1.
  if (CondVT == MVT::i1 && !LegalOperations) {
    if (C1->isZero() && C2->isOne()) {
      // select Cond, 0, 1 --> zext (!Cond)
      SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
      if (VT != MVT::i1)
        NotCond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, NotCond);
      return NotCond;
    }
    if (C1->isZero() && C2->isAllOnes()) {
      // select Cond, 0, -1 --> sext (!Cond)
      SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
      if (VT != MVT::i1)
        NotCond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, NotCond);
      return NotCond;
    }
    if (C1->isOne() && C2->isZero()) {
      // select Cond, 1, 0 --> zext (Cond)
      if (VT != MVT::i1)
        Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Cond);
      return Cond;
    }
    if (C1->isAllOnes() && C2->isZero()) {
      // select Cond, -1, 0 --> sext (Cond)
      if (VT != MVT::i1)
        Cond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Cond);
      return Cond;
    }

    // Use a target hook because some targets may prefer to transform in the
    // other direction.
    if (TLI.convertSelectOfConstantsToMath(VT)) {
      // For any constants that differ by 1, we can transform the select into an
      // extend and add.
      const APInt &C1Val = C1->getAPIntValue();
      const APInt &C2Val = C2->getAPIntValue();
      if (C1Val - 1 == C2Val) {
        // select Cond, C1, C1-1 --> add (zext Cond), C1-1
        if (VT != MVT::i1)
          Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Cond);
        return DAG.getNode(ISD::ADD, DL, VT, Cond, N2);
      }
      if (C1Val + 1 == C2Val) {
        // select Cond, C1, C1+1 --> add (sext Cond), C1+1
        // (sext of i1 true is -1, so -1 + (C1+1) == C1.)
        if (VT != MVT::i1)
          Cond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Cond);
        return DAG.getNode(ISD::ADD, DL, VT, Cond, N2);
      }

      // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2)
      if (C1Val.isPowerOf2() && C2Val.isZero()) {
        if (VT != MVT::i1)
          Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Cond);
        SDValue ShAmtC =
            DAG.getShiftAmountConstant(C1Val.exactLogBase2(), VT, DL);
        return DAG.getNode(ISD::SHL, DL, VT, Cond, ShAmtC);
      }

      // Last resort: smear a sign-bit-test condition into a mask.
      if (SDValue V = foldSelectOfConstantsUsingSra(N, DAG))
        return V;
    }

    return SDValue();
  }

  // fold (select Cond, 0, 1) -> (xor Cond, 1)
  // We can't do this reliably if integer based booleans have different contents
  // to floating point based booleans. This is because we can't tell whether we
  // have an integer-based boolean or a floating-point-based boolean unless we
  // can find the SETCC that produced it and inspect its operands. This is
  // fairly easy if C is the SETCC node, but it can potentially be
  // undiscoverable (or not reasonably discoverable). For example, it could be
  // in another basic block or it could require searching a complicated
  // expression.
  if (CondVT.isInteger() &&
      TLI.getBooleanContents(/*isVec*/false, /*isFloat*/true) ==
          TargetLowering::ZeroOrOneBooleanContent &&
      TLI.getBooleanContents(/*isVec*/false, /*isFloat*/false) ==
          TargetLowering::ZeroOrOneBooleanContent &&
      C1->isZero() && C2->isOne()) {
    SDValue NotCond =
        DAG.getNode(ISD::XOR, DL, CondVT, Cond, DAG.getConstant(1, DL, CondVT));
    if (VT.bitsEq(CondVT))
      return NotCond;
    return DAG.getZExtOrTrunc(NotCond, DL, VT);
  }

  return SDValue();
}
10249
foldBoolSelectToLogic(SDNode * N,SelectionDAG & DAG)10250 static SDValue foldBoolSelectToLogic(SDNode *N, SelectionDAG &DAG) {
10251 assert((N->getOpcode() == ISD::SELECT || N->getOpcode() == ISD::VSELECT) &&
10252 "Expected a (v)select");
10253 SDValue Cond = N->getOperand(0);
10254 SDValue T = N->getOperand(1), F = N->getOperand(2);
10255 EVT VT = N->getValueType(0);
10256 if (VT != Cond.getValueType() || VT.getScalarSizeInBits() != 1)
10257 return SDValue();
10258
10259 // select Cond, Cond, F --> or Cond, F
10260 // select Cond, 1, F --> or Cond, F
10261 if (Cond == T || isOneOrOneSplat(T, /* AllowUndefs */ true))
10262 return DAG.getNode(ISD::OR, SDLoc(N), VT, Cond, F);
10263
10264 // select Cond, T, Cond --> and Cond, T
10265 // select Cond, T, 0 --> and Cond, T
10266 if (Cond == F || isNullOrNullSplat(F, /* AllowUndefs */ true))
10267 return DAG.getNode(ISD::AND, SDLoc(N), VT, Cond, T);
10268
10269 // select Cond, T, 1 --> or (not Cond), T
10270 if (isOneOrOneSplat(F, /* AllowUndefs */ true)) {
10271 SDValue NotCond = DAG.getNOT(SDLoc(N), Cond, VT);
10272 return DAG.getNode(ISD::OR, SDLoc(N), VT, NotCond, T);
10273 }
10274
10275 // select Cond, 0, F --> and (not Cond), F
10276 if (isNullOrNullSplat(T, /* AllowUndefs */ true)) {
10277 SDValue NotCond = DAG.getNOT(SDLoc(N), Cond, VT);
10278 return DAG.getNode(ISD::AND, SDLoc(N), VT, NotCond, F);
10279 }
10280
10281 return SDValue();
10282 }
10283
foldVSelectToSignBitSplatMask(SDNode * N,SelectionDAG & DAG)10284 static SDValue foldVSelectToSignBitSplatMask(SDNode *N, SelectionDAG &DAG) {
10285 SDValue N0 = N->getOperand(0);
10286 SDValue N1 = N->getOperand(1);
10287 SDValue N2 = N->getOperand(2);
10288 EVT VT = N->getValueType(0);
10289 if (N0.getOpcode() != ISD::SETCC || !N0.hasOneUse())
10290 return SDValue();
10291
10292 SDValue Cond0 = N0.getOperand(0);
10293 SDValue Cond1 = N0.getOperand(1);
10294 ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
10295 if (VT != Cond0.getValueType())
10296 return SDValue();
10297
10298 // Match a signbit check of Cond0 as "Cond0 s<0". Swap select operands if the
10299 // compare is inverted from that pattern ("Cond0 s> -1").
10300 if (CC == ISD::SETLT && isNullOrNullSplat(Cond1))
10301 ; // This is the pattern we are looking for.
10302 else if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(Cond1))
10303 std::swap(N1, N2);
10304 else
10305 return SDValue();
10306
10307 // (Cond0 s< 0) ? N1 : 0 --> (Cond0 s>> BW-1) & N1
10308 if (isNullOrNullSplat(N2)) {
10309 SDLoc DL(N);
10310 SDValue ShiftAmt = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
10311 SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
10312 return DAG.getNode(ISD::AND, DL, VT, Sra, N1);
10313 }
10314
10315 // (Cond0 s< 0) ? -1 : N2 --> (Cond0 s>> BW-1) | N2
10316 if (isAllOnesOrAllOnesSplat(N1)) {
10317 SDLoc DL(N);
10318 SDValue ShiftAmt = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
10319 SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
10320 return DAG.getNode(ISD::OR, DL, VT, Sra, N2);
10321 }
10322
10323 // If we have to invert the sign bit mask, only do that transform if the
10324 // target has a bitwise 'and not' instruction (the invert is free).
10325 // (Cond0 s< -0) ? 0 : N2 --> ~(Cond0 s>> BW-1) & N2
10326 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
10327 if (isNullOrNullSplat(N1) && TLI.hasAndNot(N1)) {
10328 SDLoc DL(N);
10329 SDValue ShiftAmt = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
10330 SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
10331 SDValue Not = DAG.getNOT(DL, Sra, VT);
10332 return DAG.getNode(ISD::AND, DL, VT, Not, N2);
10333 }
10334
10335 // TODO: There's another pattern in this family, but it may require
10336 // implementing hasOrNot() to check for profitability:
10337 // (Cond0 s> -1) ? -1 : N2 --> ~(Cond0 s>> BW-1) | N2
10338
10339 return SDValue();
10340 }
10341
SDValue DAGCombiner::visitSELECT(SDNode *N) {
  SDValue N0 = N->getOperand(0); // condition
  SDValue N1 = N->getOperand(1); // true value
  SDValue N2 = N->getOperand(2); // false value
  EVT VT = N->getValueType(0);
  EVT VT0 = N0.getValueType();
  SDLoc DL(N);
  SDNodeFlags Flags = N->getFlags();

  // Generic select simplifications handled by the DAG itself.
  if (SDValue V = DAG.simplifySelect(N0, N1, N2))
    return V;

  // select of two integer constants -> boolean math on the condition.
  if (SDValue V = foldSelectOfConstants(N))
    return V;

  // i1-typed select -> and/or logic on the condition.
  if (SDValue V = foldBoolSelectToLogic(N, DAG))
    return V;

  // If we can fold this based on the true/false value, do so.
  if (SimplifySelectOps(N, N1, N2))
    return SDValue(N, 0);  // Don't revisit N.

  if (VT0 == MVT::i1) {
    // The code in this block deals with the following 2 equivalences:
    //    select(C0|C1, x, y) <=> select(C0, x, select(C1, x, y))
    //    select(C0&C1, x, y) <=> select(C0, select(C1, x, y), y)
    // The target can specify its preferred form with the
    // shouldNormalizeToSelectSequence() callback. However we always transform
    // to the right anyway if we find the inner select exists in the DAG anyway
    // and we always transform to the left side if we know that we can further
    // optimize the combination of the conditions.
    bool normalizeToSequence =
        TLI.shouldNormalizeToSelectSequence(*DAG.getContext(), VT);
    // select (and Cond0, Cond1), X, Y
    //   -> select Cond0, (select Cond1, X, Y), Y
    if (N0->getOpcode() == ISD::AND && N0->hasOneUse()) {
      SDValue Cond0 = N0->getOperand(0);
      SDValue Cond1 = N0->getOperand(1);
      SDValue InnerSelect =
          DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond1, N1, N2, Flags);
      // Take the transform if the target wants the sequence form, or if the
      // inner select already existed in the DAG (use_empty is false).
      if (normalizeToSequence || !InnerSelect.use_empty())
        return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0,
                           InnerSelect, N2, Flags);
      // Cleanup on failure: drop the speculatively created inner select.
      if (InnerSelect.use_empty())
        recursivelyDeleteUnusedNodes(InnerSelect.getNode());
    }
    // select (or Cond0, Cond1), X, Y -> select Cond0, X, (select Cond1, X, Y)
    if (N0->getOpcode() == ISD::OR && N0->hasOneUse()) {
      SDValue Cond0 = N0->getOperand(0);
      SDValue Cond1 = N0->getOperand(1);
      SDValue InnerSelect = DAG.getNode(ISD::SELECT, DL, N1.getValueType(),
                                        Cond1, N1, N2, Flags);
      if (normalizeToSequence || !InnerSelect.use_empty())
        return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0, N1,
                           InnerSelect, Flags);
      // Cleanup on failure: drop the speculatively created inner select.
      if (InnerSelect.use_empty())
        recursivelyDeleteUnusedNodes(InnerSelect.getNode());
    }

    // select Cond0, (select Cond1, X, Y), Y -> select (and Cond0, Cond1), X, Y
    if (N1->getOpcode() == ISD::SELECT && N1->hasOneUse()) {
      SDValue N1_0 = N1->getOperand(0);
      SDValue N1_1 = N1->getOperand(1);
      SDValue N1_2 = N1->getOperand(2);
      if (N1_2 == N2 && N0.getValueType() == N1_0.getValueType()) {
        // Create the actual and node if we can generate good code for it.
        if (!normalizeToSequence) {
          SDValue And = DAG.getNode(ISD::AND, DL, N0.getValueType(), N0, N1_0);
          return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), And, N1_1,
                             N2, Flags);
        }
        // Otherwise see if we can optimize the "and" to a better pattern.
        if (SDValue Combined = visitANDLike(N0, N1_0, N)) {
          return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1_1,
                             N2, Flags);
        }
      }
    }
    // select Cond0, X, (select Cond1, X, Y) -> select (or Cond0, Cond1), X, Y
    if (N2->getOpcode() == ISD::SELECT && N2->hasOneUse()) {
      SDValue N2_0 = N2->getOperand(0);
      SDValue N2_1 = N2->getOperand(1);
      SDValue N2_2 = N2->getOperand(2);
      if (N2_1 == N1 && N0.getValueType() == N2_0.getValueType()) {
        // Create the actual or node if we can generate good code for it.
        if (!normalizeToSequence) {
          SDValue Or = DAG.getNode(ISD::OR, DL, N0.getValueType(), N0, N2_0);
          return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Or, N1,
                             N2_2, Flags);
        }
        // Otherwise see if we can optimize to a better pattern.
        if (SDValue Combined = visitORLike(N0, N2_0, N))
          return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1,
                             N2_2, Flags);
      }
    }
  }

  // select (not Cond), N1, N2 -> select Cond, N2, N1
  if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false)) {
    SDValue SelectOp = DAG.getSelect(DL, VT, F, N2, N1);
    SelectOp->setFlags(Flags);
    return SelectOp;
  }

  // Fold selects based on a setcc into other things, such as min/max/abs.
  if (N0.getOpcode() == ISD::SETCC) {
    SDValue Cond0 = N0.getOperand(0), Cond1 = N0.getOperand(1);
    ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();

    // select (fcmp lt x, y), x, y -> fminnum x, y
    // select (fcmp gt x, y), x, y -> fmaxnum x, y
    //
    // This is OK if we don't care what happens if either operand is a NaN.
    if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, N1, N2, TLI))
      if (SDValue FMinMax = combineMinNumMaxNum(DL, VT, Cond0, Cond1, N1, N2,
                                                CC, TLI, DAG))
        return FMinMax;

    // Use 'unsigned add with overflow' to optimize an unsigned saturating add.
    // This is conservatively limited to pre-legal-operations to give targets
    // a chance to reverse the transform if they want to do that. Also, it is
    // unlikely that the pattern would be formed late, so it's probably not
    // worth going through the other checks.
    if (!LegalOperations && TLI.isOperationLegalOrCustom(ISD::UADDO, VT) &&
        CC == ISD::SETUGT && N0.hasOneUse() && isAllOnesConstant(N1) &&
        N2.getOpcode() == ISD::ADD && Cond0 == N2.getOperand(0)) {
      auto *C = dyn_cast<ConstantSDNode>(N2.getOperand(1));
      auto *NotC = dyn_cast<ConstantSDNode>(Cond1);
      if (C && NotC && C->getAPIntValue() == ~NotC->getAPIntValue()) {
        // select (setcc Cond0, ~C, ugt), -1, (add Cond0, C) -->
        // uaddo Cond0, C; select uaddo.1, -1, uaddo.0
        //
        // The IR equivalent of this transform would have this form:
        //   %a = add %x, C
        //   %c = icmp ugt %x, ~C
        //   %r = select %c, -1, %a
        //   =>
        //   %u = call {iN,i1} llvm.uadd.with.overflow(%x, C)
        //   %u0 = extractvalue %u, 0
        //   %u1 = extractvalue %u, 1
        //   %r = select %u1, -1, %u0
        SDVTList VTs = DAG.getVTList(VT, VT0);
        SDValue UAO = DAG.getNode(ISD::UADDO, DL, VTs, Cond0, N2.getOperand(1));
        return DAG.getSelect(DL, VT, UAO.getValue(1), N1, UAO.getValue(0));
      }
    }

    // Prefer a single SELECT_CC node when the target supports it.
    if (TLI.isOperationLegal(ISD::SELECT_CC, VT) ||
        (!LegalOperations &&
         TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT))) {
      // Any flags available in a select/setcc fold will be on the setcc as they
      // migrated from fcmp
      Flags = N0->getFlags();
      SDValue SelectNode = DAG.getNode(ISD::SELECT_CC, DL, VT, Cond0, Cond1, N1,
                                       N2, N0.getOperand(2));
      SelectNode->setFlags(Flags);
      return SelectNode;
    }

    if (SDValue NewSel = SimplifySelect(DL, N0, N1, N2))
      return NewSel;
  }

  // Scalar selects of binops may be folded into one binop of selects.
  if (!VT.isVector())
    if (SDValue BinOp = foldSelectOfBinops(N))
      return BinOp;

  return SDValue();
}
10514
10515 // This function assumes all the vselect's arguments are CONCAT_VECTOR
10516 // nodes and that the condition is a BV of ConstantSDNodes (or undefs).
ConvertSelectToConcatVector(SDNode * N,SelectionDAG & DAG)10517 static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
10518 SDLoc DL(N);
10519 SDValue Cond = N->getOperand(0);
10520 SDValue LHS = N->getOperand(1);
10521 SDValue RHS = N->getOperand(2);
10522 EVT VT = N->getValueType(0);
10523 int NumElems = VT.getVectorNumElements();
10524 assert(LHS.getOpcode() == ISD::CONCAT_VECTORS &&
10525 RHS.getOpcode() == ISD::CONCAT_VECTORS &&
10526 Cond.getOpcode() == ISD::BUILD_VECTOR);
10527
10528 // CONCAT_VECTOR can take an arbitrary number of arguments. We only care about
10529 // binary ones here.
10530 if (LHS->getNumOperands() != 2 || RHS->getNumOperands() != 2)
10531 return SDValue();
10532
10533 // We're sure we have an even number of elements due to the
10534 // concat_vectors we have as arguments to vselect.
10535 // Skip BV elements until we find one that's not an UNDEF
10536 // After we find an UNDEF element, keep looping until we get to half the
10537 // length of the BV and see if all the non-undef nodes are the same.
10538 ConstantSDNode *BottomHalf = nullptr;
10539 for (int i = 0; i < NumElems / 2; ++i) {
10540 if (Cond->getOperand(i)->isUndef())
10541 continue;
10542
10543 if (BottomHalf == nullptr)
10544 BottomHalf = cast<ConstantSDNode>(Cond.getOperand(i));
10545 else if (Cond->getOperand(i).getNode() != BottomHalf)
10546 return SDValue();
10547 }
10548
10549 // Do the same for the second half of the BuildVector
10550 ConstantSDNode *TopHalf = nullptr;
10551 for (int i = NumElems / 2; i < NumElems; ++i) {
10552 if (Cond->getOperand(i)->isUndef())
10553 continue;
10554
10555 if (TopHalf == nullptr)
10556 TopHalf = cast<ConstantSDNode>(Cond.getOperand(i));
10557 else if (Cond->getOperand(i).getNode() != TopHalf)
10558 return SDValue();
10559 }
10560
10561 assert(TopHalf && BottomHalf &&
10562 "One half of the selector was all UNDEFs and the other was all the "
10563 "same value. This should have been addressed before this function.");
10564 return DAG.getNode(
10565 ISD::CONCAT_VECTORS, DL, VT,
10566 BottomHalf->isZero() ? RHS->getOperand(0) : LHS->getOperand(0),
10567 TopHalf->isZero() ? RHS->getOperand(1) : LHS->getOperand(1));
10568 }
10569
refineUniformBase(SDValue & BasePtr,SDValue & Index,bool IndexIsScaled,SelectionDAG & DAG)10570 bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
10571 SelectionDAG &DAG) {
10572 if (!isNullConstant(BasePtr) || Index.getOpcode() != ISD::ADD)
10573 return false;
10574
10575 // Only perform the transformation when existing operands can be reused.
10576 if (IndexIsScaled)
10577 return false;
10578
10579 // For now we check only the LHS of the add.
10580 SDValue LHS = Index.getOperand(0);
10581 SDValue SplatVal = DAG.getSplatValue(LHS);
10582 if (!SplatVal || SplatVal.getValueType() != BasePtr.getValueType())
10583 return false;
10584
10585 BasePtr = SplatVal;
10586 Index = Index.getOperand(1);
10587 return true;
10588 }
10589
// Fold sext/zext of index into index type.
// Returns true if Index and/or IndexType were updated in place.
bool refineIndexType(SDValue &Index, ISD::MemIndexType &IndexType, EVT DataVT,
                     SelectionDAG &DAG) {
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();

  // It's always safe to look through zero extends.
  if (Index.getOpcode() == ISD::ZERO_EXTEND) {
    SDValue Op = Index.getOperand(0);
    // If the target prefers the narrower index, drop the extend entirely and
    // mark the index as unsigned (the zext guarantees non-negative values).
    if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType(), DataVT)) {
      IndexType = ISD::UNSIGNED_SCALED;
      Index = Op;
      return true;
    }
    // Even when the extend is kept, a zero-extended index is known
    // non-negative, so a signed index type can be reinterpreted as unsigned.
    // NOTE(review): this picks UNSIGNED_SCALED unconditionally; if IndexType
    // was SIGNED_UNSCALED this also changes the scaling — confirm intended.
    if (ISD::isIndexTypeSigned(IndexType)) {
      IndexType = ISD::UNSIGNED_SCALED;
      return true;
    }
  }

  // It's only safe to look through sign extends when Index is signed.
  if (Index.getOpcode() == ISD::SIGN_EXTEND &&
      ISD::isIndexTypeSigned(IndexType)) {
    SDValue Op = Index.getOperand(0);
    if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType(), DataVT)) {
      Index = Op;
      return true;
    }
  }

  return false;
}
10621
SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
  MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(N);
  SDValue Mask = MSC->getMask();
  SDValue Chain = MSC->getChain();
  SDValue Index = MSC->getIndex();
  SDValue Scale = MSC->getScale();
  SDValue StoreVal = MSC->getValue();
  SDValue BasePtr = MSC->getBasePtr();
  ISD::MemIndexType IndexType = MSC->getIndexType();
  SDLoc DL(N);

  // Zap scatters with a zero mask.
  if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
    return Chain;

  // Fold base=0, index=(add splat(X), Y) into base=X, index=Y.
  // refineUniformBase updates BasePtr/Index in place on success.
  if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG)) {
    SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
    return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
                                DL, Ops, MSC->getMemOperand(), IndexType,
                                MSC->isTruncatingStore());
  }

  // Fold a sext/zext of the index into the index type where the target
  // allows; refineIndexType updates Index/IndexType in place on success.
  if (refineIndexType(Index, IndexType, StoreVal.getValueType(), DAG)) {
    SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
    return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
                                DL, Ops, MSC->getMemOperand(), IndexType,
                                MSC->isTruncatingStore());
  }

  return SDValue();
}
10653
SDValue DAGCombiner::visitMSTORE(SDNode *N) {
  MaskedStoreSDNode *MST = cast<MaskedStoreSDNode>(N);
  SDValue Mask = MST->getMask();
  SDValue Chain = MST->getChain();
  SDValue Value = MST->getValue();
  SDValue Ptr = MST->getBasePtr();
  SDLoc DL(N);

  // Zap masked stores with a zero mask.
  if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
    return Chain;

  // If this is a masked store with an all ones mask, we can use an unmasked
  // store.
  // FIXME: Can we do this for indexed, compressing, or truncating stores?
  if (ISD::isConstantSplatVectorAllOnes(Mask.getNode()) && MST->isUnindexed() &&
      !MST->isCompressingStore() && !MST->isTruncatingStore())
    return DAG.getStore(MST->getChain(), SDLoc(N), MST->getValue(),
                        MST->getBasePtr(), MST->getPointerInfo(),
                        MST->getOriginalAlign(), MachineMemOperand::MOStore,
                        MST->getAAInfo());

  // Try transforming N to an indexed store.
  if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
    return SDValue(N, 0);

  // For a truncating store only the low memory-width bits of the stored
  // value matter; try to simplify Value based on just those demanded bits.
  // Opaque constants are skipped because they must not be folded further.
  if (MST->isTruncatingStore() && MST->isUnindexed() &&
      Value.getValueType().isInteger() &&
      (!isa<ConstantSDNode>(Value) ||
       !cast<ConstantSDNode>(Value)->isOpaque())) {
    APInt TruncDemandedBits =
        APInt::getLowBitsSet(Value.getScalarValueSizeInBits(),
                             MST->getMemoryVT().getScalarSizeInBits());

    // See if we can simplify the operation with
    // SimplifyDemandedBits, which only works if the value has a single use.
    if (SimplifyDemandedBits(Value, TruncDemandedBits)) {
      // Re-visit the store if anything changed and the store hasn't been merged
      // with another node (N is deleted) SimplifyDemandedBits will add Value's
      // node back to the worklist if necessary, but we also need to re-visit
      // the Store node itself.
      if (N->getOpcode() != ISD::DELETED_NODE)
        AddToWorklist(N);
      return SDValue(N, 0);
    }
  }

  // If this is a TRUNC followed by a masked store, fold this into a masked
  // truncating store. We can do this even if this is already a masked
  // truncstore.
  if ((Value.getOpcode() == ISD::TRUNCATE) && Value->hasOneUse() &&
      MST->isUnindexed() &&
      TLI.canCombineTruncStore(Value.getOperand(0).getValueType(),
                               MST->getMemoryVT(), LegalOperations)) {
    auto Mask = TLI.promoteTargetBoolean(DAG, MST->getMask(),
                                         Value.getOperand(0).getValueType());
    return DAG.getMaskedStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
                              MST->getOffset(), Mask, MST->getMemoryVT(),
                              MST->getMemOperand(), MST->getAddressingMode(),
                              /*IsTruncating=*/true);
  }

  return SDValue();
}
10717
SDValue DAGCombiner::visitMGATHER(SDNode *N) {
  MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(N);
  SDValue Mask = MGT->getMask();
  SDValue Chain = MGT->getChain();
  SDValue Index = MGT->getIndex();
  SDValue Scale = MGT->getScale();
  SDValue PassThru = MGT->getPassThru();
  SDValue BasePtr = MGT->getBasePtr();
  ISD::MemIndexType IndexType = MGT->getIndexType();
  SDLoc DL(N);

  // Zap gathers with a zero mask: nothing is loaded, so the result is the
  // pass-through value and the chain bypasses the gather.
  if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
    return CombineTo(N, PassThru, MGT->getChain());

  // Fold base=0, index=(add splat(X), Y) into base=X, index=Y.
  // refineUniformBase updates BasePtr/Index in place on success.
  if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG)) {
    SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
    return DAG.getMaskedGather(
        DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
        Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
  }

  // Fold a sext/zext of the index into the index type where the target
  // allows; refineIndexType updates Index/IndexType in place on success.
  if (refineIndexType(Index, IndexType, N->getValueType(0), DAG)) {
    SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
    return DAG.getMaskedGather(
        DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
        Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
  }

  return SDValue();
}
10749
SDValue DAGCombiner::visitMLOAD(SDNode *N) {
  MaskedLoadSDNode *MLD = cast<MaskedLoadSDNode>(N);
  SDValue Mask = MLD->getMask();
  SDLoc DL(N);

  // Zap masked loads with a zero mask: the result is just the pass-through
  // value and the chain bypasses the load.
  if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
    return CombineTo(N, MLD->getPassThru(), MLD->getChain());

  // If this is a masked load with an all ones mask, we can use an unmasked load.
  // FIXME: Can we do this for indexed, expanding, or extending loads?
  if (ISD::isConstantSplatVectorAllOnes(Mask.getNode()) && MLD->isUnindexed() &&
      !MLD->isExpandingLoad() && MLD->getExtensionType() == ISD::NON_EXTLOAD) {
    SDValue NewLd = DAG.getLoad(
        N->getValueType(0), SDLoc(N), MLD->getChain(), MLD->getBasePtr(),
        MLD->getPointerInfo(), MLD->getOriginalAlign(),
        MachineMemOperand::MOLoad, MLD->getAAInfo(), MLD->getRanges());
    return CombineTo(N, NewLd, NewLd.getValue(1));
  }

  // Try transforming N to an indexed load.
  if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
    return SDValue(N, 0);

  return SDValue();
}
10776
/// A vector select of 2 constant vectors can be simplified to math/logic to
/// avoid a variable select instruction and possibly avoid constant loads.
SDValue DAGCombiner::foldVSelectOfConstants(SDNode *N) {
  SDValue Cond = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue N2 = N->getOperand(2);
  EVT VT = N->getValueType(0);
  // Only handle a one-use i1 condition selecting between two constant
  // build_vectors, and only if the target prefers math over a select.
  if (!Cond.hasOneUse() || Cond.getScalarValueSizeInBits() != 1 ||
      !TLI.convertSelectOfConstantsToMath(VT) ||
      !ISD::isBuildVectorOfConstantSDNodes(N1.getNode()) ||
      !ISD::isBuildVectorOfConstantSDNodes(N2.getNode()))
    return SDValue();

  // Check if we can use the condition value to increment/decrement a single
  // constant value. This simplifies a select to an add and removes a constant
  // load/materialization from the general case.
  bool AllAddOne = true;
  bool AllSubOne = true;
  unsigned Elts = VT.getVectorNumElements();
  for (unsigned i = 0; i != Elts; ++i) {
    SDValue N1Elt = N1.getOperand(i);
    SDValue N2Elt = N2.getOperand(i);
    // Undef lanes and mismatched element types impose no constraint.
    if (N1Elt.isUndef() || N2Elt.isUndef())
      continue;
    if (N1Elt.getValueType() != N2Elt.getValueType())
      continue;

    const APInt &C1 = cast<ConstantSDNode>(N1Elt)->getAPIntValue();
    const APInt &C2 = cast<ConstantSDNode>(N2Elt)->getAPIntValue();
    if (C1 != C2 + 1)
      AllAddOne = false;
    if (C1 != C2 - 1)
      AllSubOne = false;
  }

  // Further simplifications for the extra-special cases where the constants are
  // all 0 or all -1 should be implemented as folds of these patterns.
  SDLoc DL(N);
  if (AllAddOne || AllSubOne) {
    // vselect <N x i1> Cond, C+1, C --> add (zext Cond), C
    // vselect <N x i1> Cond, C-1, C --> add (sext Cond), C
    // (zext true = 1, sext true = -1, so the add selects C+1 or C-1.)
    auto ExtendOpcode = AllAddOne ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND;
    SDValue ExtendedCond = DAG.getNode(ExtendOpcode, DL, VT, Cond);
    return DAG.getNode(ISD::ADD, DL, VT, ExtendedCond, N2);
  }

  // select Cond, Pow2C, 0 --> (zext Cond) << log2(Pow2C)
  APInt Pow2C;
  if (ISD::isConstantSplatVector(N1.getNode(), Pow2C) && Pow2C.isPowerOf2() &&
      isNullOrNullSplat(N2)) {
    SDValue ZextCond = DAG.getZExtOrTrunc(Cond, DL, VT);
    SDValue ShAmtC = DAG.getConstant(Pow2C.exactLogBase2(), DL, VT);
    return DAG.getNode(ISD::SHL, DL, VT, ZextCond, ShAmtC);
  }

  if (SDValue V = foldSelectOfConstantsUsingSra(N, DAG))
    return V;

  // The general case for select-of-constants:
  // vselect <N x i1> Cond, C1, C2 --> xor (and (sext Cond), (C1^C2)), C2
  // ...but that only makes sense if a vselect is slower than 2 logic ops, so
  // leave that to a machine-specific pass.
  return SDValue();
}
10841
SDValue DAGCombiner::visitVSELECT(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue N2 = N->getOperand(2);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  // Generic select simplifications (constant condition, equal arms, ...).
  if (SDValue V = DAG.simplifySelect(N0, N1, N2))
    return V;

  if (SDValue V = foldBoolSelectToLogic(N, DAG))
    return V;

  // vselect (not Cond), N1, N2 -> vselect Cond, N2, N1
  if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false))
    return DAG.getSelect(DL, VT, F, N2, N1);

  // Canonicalize integer abs.
  // vselect (setg[te] X,  0),  X, -X ->
  // vselect (setgt    X, -1),  X, -X ->
  // vselect (setl[te] X,  0), -X,  X ->
  // Y = sra (X, size(X)-1); xor (add (X, Y), Y)
  if (N0.getOpcode() == ISD::SETCC) {
    SDValue LHS = N0.getOperand(0), RHS = N0.getOperand(1);
    ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
    bool isAbs = false;
    bool RHSIsAllZeros = ISD::isBuildVectorAllZeros(RHS.getNode());

    if (((RHSIsAllZeros && (CC == ISD::SETGT || CC == ISD::SETGE)) ||
         (ISD::isBuildVectorAllOnes(RHS.getNode()) && CC == ISD::SETGT)) &&
        N1 == LHS && N2.getOpcode() == ISD::SUB && N1 == N2.getOperand(1))
      isAbs = ISD::isBuildVectorAllZeros(N2.getOperand(0).getNode());
    else if ((RHSIsAllZeros && (CC == ISD::SETLT || CC == ISD::SETLE)) &&
             N2 == LHS && N1.getOpcode() == ISD::SUB && N2 == N1.getOperand(1))
      isAbs = ISD::isBuildVectorAllZeros(N1.getOperand(0).getNode());

    if (isAbs) {
      // Prefer the dedicated ABS node when the target supports it.
      if (TLI.isOperationLegalOrCustom(ISD::ABS, VT))
        return DAG.getNode(ISD::ABS, DL, VT, LHS);

      // Otherwise expand: Y = sra(X, bits-1); abs = (X + Y) ^ Y.
      SDValue Shift = DAG.getNode(ISD::SRA, DL, VT, LHS,
                                  DAG.getConstant(VT.getScalarSizeInBits() - 1,
                                                  DL, getShiftAmountTy(VT)));
      SDValue Add = DAG.getNode(ISD::ADD, DL, VT, LHS, Shift);
      AddToWorklist(Shift.getNode());
      AddToWorklist(Add.getNode());
      return DAG.getNode(ISD::XOR, DL, VT, Add, Shift);
    }

    // vselect x, y (fcmp lt x, y) -> fminnum x, y
    // vselect x, y (fcmp gt x, y) -> fmaxnum x, y
    //
    // This is OK if we don't care about what happens if either operand is a
    // NaN.
    //
    if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, LHS, RHS, TLI)) {
      if (SDValue FMinMax =
              combineMinNumMaxNum(DL, VT, LHS, RHS, N1, N2, CC, TLI, DAG))
        return FMinMax;
    }

    if (SDValue S = PerformMinMaxFpToSatCombine(LHS, RHS, N1, N2, CC, DAG))
      return S;
    if (SDValue S = PerformUMinFpToSatCombine(LHS, RHS, N1, N2, CC, DAG))
      return S;

    // If this select has a condition (setcc) with narrower operands than the
    // select, try to widen the compare to match the select width.
    // TODO: This should be extended to handle any constant.
    // TODO: This could be extended to handle non-loading patterns, but that
    //       requires thorough testing to avoid regressions.
    if (isNullOrNullSplat(RHS)) {
      EVT NarrowVT = LHS.getValueType();
      EVT WideVT = N1.getValueType().changeVectorElementTypeToInteger();
      EVT SetCCVT = getSetCCResultType(LHS.getValueType());
      unsigned SetCCWidth = SetCCVT.getScalarSizeInBits();
      unsigned WideWidth = WideVT.getScalarSizeInBits();
      bool IsSigned = isSignedIntSetCC(CC);
      auto LoadExtOpcode = IsSigned ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
      if (LHS.getOpcode() == ISD::LOAD && LHS.hasOneUse() &&
          SetCCWidth != 1 && SetCCWidth < WideWidth &&
          TLI.isLoadExtLegalOrCustom(LoadExtOpcode, WideVT, NarrowVT) &&
          TLI.isOperationLegalOrCustom(ISD::SETCC, WideVT)) {
        // Both compare operands can be widened for free. The LHS can use an
        // extended load, and the RHS is a constant:
        //   vselect (ext (setcc load(X), C)), N1, N2 -->
        //   vselect (setcc extload(X), C'), N1, N2
        auto ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
        SDValue WideLHS = DAG.getNode(ExtOpcode, DL, WideVT, LHS);
        SDValue WideRHS = DAG.getNode(ExtOpcode, DL, WideVT, RHS);
        EVT WideSetCCVT = getSetCCResultType(WideVT);
        SDValue WideSetCC = DAG.getSetCC(DL, WideSetCCVT, WideLHS, WideRHS, CC);
        return DAG.getSelect(DL, N1.getValueType(), WideSetCC, N1, N2);
      }
    }

    // Match VSELECTs into add with unsigned saturation.
    if (hasOperation(ISD::UADDSAT, VT)) {
      // Check if one of the arms of the VSELECT is vector with all bits set.
      // If it's on the left side invert the predicate to simplify logic below.
      SDValue Other;
      ISD::CondCode SatCC = CC;
      if (ISD::isConstantSplatVectorAllOnes(N1.getNode())) {
        Other = N2;
        SatCC = ISD::getSetCCInverse(SatCC, VT.getScalarType());
      } else if (ISD::isConstantSplatVectorAllOnes(N2.getNode())) {
        Other = N1;
      }

      if (Other && Other.getOpcode() == ISD::ADD) {
        SDValue CondLHS = LHS, CondRHS = RHS;
        SDValue OpLHS = Other.getOperand(0), OpRHS = Other.getOperand(1);

        // Canonicalize condition operands.
        if (SatCC == ISD::SETUGE) {
          std::swap(CondLHS, CondRHS);
          SatCC = ISD::SETULE;
        }

        // We can test against either of the addition operands.
        // x <= x+y ? x+y : ~0 --> uaddsat x, y
        // x+y >= x ? x+y : ~0 --> uaddsat x, y
        if (SatCC == ISD::SETULE && Other == CondRHS &&
            (OpLHS == CondLHS || OpRHS == CondLHS))
          return DAG.getNode(ISD::UADDSAT, DL, VT, OpLHS, OpRHS);

        if (OpRHS.getOpcode() == CondRHS.getOpcode() &&
            (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
             OpRHS.getOpcode() == ISD::SPLAT_VECTOR) &&
            CondLHS == OpLHS) {
          // If the RHS is a constant we have to reverse the const
          // canonicalization.
          // x >= ~C ? x+C : ~0 --> uaddsat x, C
          auto MatchUADDSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
            return Cond->getAPIntValue() == ~Op->getAPIntValue();
          };
          if (SatCC == ISD::SETULE &&
              ISD::matchBinaryPredicate(OpRHS, CondRHS, MatchUADDSAT))
            return DAG.getNode(ISD::UADDSAT, DL, VT, OpLHS, OpRHS);
        }
      }
    }

    // Match VSELECTs into sub with unsigned saturation.
    if (hasOperation(ISD::USUBSAT, VT)) {
      // Check if one of the arms of the VSELECT is a zero vector. If it's on
      // the left side invert the predicate to simplify logic below.
      SDValue Other;
      ISD::CondCode SatCC = CC;
      if (ISD::isConstantSplatVectorAllZeros(N1.getNode())) {
        Other = N2;
        SatCC = ISD::getSetCCInverse(SatCC, VT.getScalarType());
      } else if (ISD::isConstantSplatVectorAllZeros(N2.getNode())) {
        Other = N1;
      }

      // zext(x) >= y ? trunc(zext(x) - y) : 0
      // --> usubsat(trunc(zext(x)),trunc(umin(y,SatLimit)))
      // zext(x) >  y ? trunc(zext(x) - y) : 0
      // --> usubsat(trunc(zext(x)),trunc(umin(y,SatLimit)))
      if (Other && Other.getOpcode() == ISD::TRUNCATE &&
          Other.getOperand(0).getOpcode() == ISD::SUB &&
          (SatCC == ISD::SETUGE || SatCC == ISD::SETUGT)) {
        SDValue OpLHS = Other.getOperand(0).getOperand(0);
        SDValue OpRHS = Other.getOperand(0).getOperand(1);
        if (LHS == OpLHS && RHS == OpRHS && LHS.getOpcode() == ISD::ZERO_EXTEND)
          if (SDValue R = getTruncatedUSUBSAT(VT, LHS.getValueType(), LHS, RHS,
                                              DAG, DL))
            return R;
      }

      if (Other && Other.getNumOperands() == 2) {
        SDValue CondRHS = RHS;
        SDValue OpLHS = Other.getOperand(0), OpRHS = Other.getOperand(1);

        if (OpLHS == LHS) {
          // Look for a general sub with unsigned saturation first.
          // x >= y ? x-y : 0 --> usubsat x, y
          // x >  y ? x-y : 0 --> usubsat x, y
          if ((SatCC == ISD::SETUGE || SatCC == ISD::SETUGT) &&
              Other.getOpcode() == ISD::SUB && OpRHS == CondRHS)
            return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);

          if (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
              OpRHS.getOpcode() == ISD::SPLAT_VECTOR) {
            if (CondRHS.getOpcode() == ISD::BUILD_VECTOR ||
                CondRHS.getOpcode() == ISD::SPLAT_VECTOR) {
              // If the RHS is a constant we have to reverse the const
              // canonicalization.
              // x > C-1 ? x+-C : 0 --> usubsat x, C
              auto MatchUSUBSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
                return (!Op && !Cond) ||
                       (Op && Cond &&
                        Cond->getAPIntValue() == (-Op->getAPIntValue() - 1));
              };
              if (SatCC == ISD::SETUGT && Other.getOpcode() == ISD::ADD &&
                  ISD::matchBinaryPredicate(OpRHS, CondRHS, MatchUSUBSAT,
                                            /*AllowUndefs*/ true)) {
                // Rebuild C = 0 - (-C) so the usubsat gets the positive value.
                OpRHS = DAG.getNode(ISD::SUB, DL, VT,
                                    DAG.getConstant(0, DL, VT), OpRHS);
                return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
              }

              // Another special case: If C was a sign bit, the sub has been
              // canonicalized into a xor.
              // FIXME: Would it be better to use computeKnownBits to
              // determine whether it's safe to decanonicalize the xor?
              // x s< 0 ? x^C : 0 --> usubsat x, C
              APInt SplatValue;
              if (SatCC == ISD::SETLT && Other.getOpcode() == ISD::XOR &&
                  ISD::isConstantSplatVector(OpRHS.getNode(), SplatValue) &&
                  ISD::isConstantSplatVectorAllZeros(CondRHS.getNode()) &&
                  SplatValue.isSignMask()) {
                // Note that we have to rebuild the RHS constant here to
                // ensure we don't rely on particular values of undef lanes.
                OpRHS = DAG.getConstant(SplatValue, DL, VT);
                return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
              }
            }
          }
        }
      }
    }
  }

  if (SimplifySelectOps(N, N1, N2))
    return SDValue(N, 0);  // Don't revisit N.

  // Fold (vselect all_ones, N1, N2) -> N1
  if (ISD::isConstantSplatVectorAllOnes(N0.getNode()))
    return N1;
  // Fold (vselect all_zeros, N1, N2) -> N2
  if (ISD::isConstantSplatVectorAllZeros(N0.getNode()))
    return N2;

  // The ConvertSelectToConcatVector function is assuming both the above
  // checks for (vselect (build_vector all{ones,zeros) ...) have been made
  // and addressed.
  if (N1.getOpcode() == ISD::CONCAT_VECTORS &&
      N2.getOpcode() == ISD::CONCAT_VECTORS &&
      ISD::isBuildVectorOfConstantSDNodes(N0.getNode())) {
    if (SDValue CV = ConvertSelectToConcatVector(N, DAG))
      return CV;
  }

  if (SDValue V = foldVSelectOfConstants(N))
    return V;

  if (hasOperation(ISD::SRA, VT))
    if (SDValue V = foldVSelectToSignBitSplatMask(N, DAG))
      return V;

  if (SimplifyDemandedVectorElts(SDValue(N, 0)))
    return SDValue(N, 0);

  return SDValue();
}
11099
SDValue DAGCombiner::visitSELECT_CC(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue N2 = N->getOperand(2);
  SDValue N3 = N->getOperand(3);
  SDValue N4 = N->getOperand(4);
  ISD::CondCode CC = cast<CondCodeSDNode>(N4)->get();

  // fold select_cc lhs, rhs, x, x, cc -> x
  if (N2 == N3)
    return N2;

  // Determine if the condition we're dealing with is constant
  if (SDValue SCC = SimplifySetCC(getSetCCResultType(N0.getValueType()), N0, N1,
                                  CC, SDLoc(N), false)) {
    AddToWorklist(SCC.getNode());

    // cond always true  -> true val
    // cond always false -> false val
    if (auto *SCCC = dyn_cast<ConstantSDNode>(SCC.getNode()))
      return SCCC->isZero() ? N3 : N2;

    // When the condition is UNDEF, just return the first operand. This is
    // coherent with the DAG creation: no setcc node is created in this case.
    if (SCC->isUndef())
      return N2;

    // Fold to a simpler select_cc using the simplified comparison operands.
    if (SCC.getOpcode() == ISD::SETCC) {
      SDValue SelectOp = DAG.getNode(
          ISD::SELECT_CC, SDLoc(N), N2.getValueType(), SCC.getOperand(0),
          SCC.getOperand(1), N2, N3, SCC.getOperand(2));
      // Preserve fast-math/exactness flags from the simplified setcc.
      SelectOp->setFlags(SCC->getFlags());
      return SelectOp;
    }
  }

  // If we can fold this based on the true/false value, do so.
  if (SimplifySelectOps(N, N2, N3))
    return SDValue(N, 0);  // Don't revisit N.

  // fold select_cc into other things, such as min/max/abs
  return SimplifySelectCC(SDLoc(N), N0, N1, N2, N3, CC);
}
11144
SDValue DAGCombiner::visitSETCC(SDNode *N) {
  // setcc is very commonly used as an argument to brcond. This pattern
  // also lend itself to numerous combines and, as a result, it is desired
  // we keep the argument to a brcond as a setcc as much as possible.
  bool PreferSetCC =
      N->hasOneUse() && N->use_begin()->getOpcode() == ISD::BRCOND;

  ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get();
  EVT VT = N->getValueType(0);

  // SETCC(FREEZE(X), CONST, Cond)
  // =>
  // FREEZE(SETCC(X, CONST, Cond))
  // This is correct if FREEZE(X) has one use and SETCC(FREEZE(X), CONST, Cond)
  // isn't equivalent to true or false.
  // For example, SETCC(FREEZE(X), -128, SETULT) cannot be folded to
  // FREEZE(SETCC(X, -128, SETULT)) because X can be poison.
  //
  // This transformation is beneficial because visitBRCOND can fold
  // BRCOND(FREEZE(X)) to BRCOND(X).

  // Conservatively optimize integer comparisons only.
  if (PreferSetCC) {
    // Do this only when SETCC is going to be used by BRCOND.

    SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
    ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
    ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
    bool Updated = false;

    // Is 'X Cond C' always true or false?
    auto IsAlwaysTrueOrFalse = [](ISD::CondCode Cond, ConstantSDNode *C) {
      bool False = (Cond == ISD::SETULT && C->isZero()) ||
                   (Cond == ISD::SETLT && C->isMinSignedValue()) ||
                   (Cond == ISD::SETUGT && C->isAllOnes()) ||
                   (Cond == ISD::SETGT && C->isMaxSignedValue());
      bool True = (Cond == ISD::SETULE && C->isAllOnes()) ||
                  (Cond == ISD::SETLE && C->isMaxSignedValue()) ||
                  (Cond == ISD::SETUGE && C->isZero()) ||
                  (Cond == ISD::SETGE && C->isMinSignedValue());
      return True || False;
    };

    // Hoist FREEZE past the compare on either side, but only when the
    // comparison against the constant is not vacuously true/false (in which
    // case the fold would be unsound because X may be poison).
    if (N0->getOpcode() == ISD::FREEZE && N0.hasOneUse() && N1C) {
      if (!IsAlwaysTrueOrFalse(Cond, N1C)) {
        N0 = N0->getOperand(0);
        Updated = true;
      }
    }
    if (N1->getOpcode() == ISD::FREEZE && N1.hasOneUse() && N0C) {
      // Swap the condition to query the constant as if it were on the RHS.
      if (!IsAlwaysTrueOrFalse(ISD::getSetCCSwappedOperands(Cond),
                               N0C)) {
        N1 = N1->getOperand(0);
        Updated = true;
      }
    }

    if (Updated)
      return DAG.getFreeze(DAG.getSetCC(SDLoc(N), VT, N0, N1, Cond));
  }

  SDValue Combined = SimplifySetCC(VT, N->getOperand(0), N->getOperand(1), Cond,
                                   SDLoc(N), !PreferSetCC);

  if (!Combined)
    return SDValue();

  // If we prefer to have a setcc, and we don't, we'll try our best to
  // recreate one using rebuildSetCC.
  if (PreferSetCC && Combined.getOpcode() != ISD::SETCC) {
    SDValue NewSetCC = rebuildSetCC(Combined);

    // We don't have anything interesting to combine to.
    if (NewSetCC.getNode() == N)
      return SDValue();

    if (NewSetCC)
      return NewSetCC;
  }

  return Combined;
}
11227
visitSETCCCARRY(SDNode * N)11228 SDValue DAGCombiner::visitSETCCCARRY(SDNode *N) {
11229 SDValue LHS = N->getOperand(0);
11230 SDValue RHS = N->getOperand(1);
11231 SDValue Carry = N->getOperand(2);
11232 SDValue Cond = N->getOperand(3);
11233
11234 // If Carry is false, fold to a regular SETCC.
11235 if (isNullConstant(Carry))
11236 return DAG.getNode(ISD::SETCC, SDLoc(N), N->getVTList(), LHS, RHS, Cond);
11237
11238 return SDValue();
11239 }
11240
11241 /// Check if N satisfies:
11242 /// N is used once.
11243 /// N is a Load.
11244 /// The load is compatible with ExtOpcode. It means
11245 /// If load has explicit zero/sign extension, ExpOpcode must have the same
11246 /// extension.
11247 /// Otherwise returns true.
isCompatibleLoad(SDValue N,unsigned ExtOpcode)11248 static bool isCompatibleLoad(SDValue N, unsigned ExtOpcode) {
11249 if (!N.hasOneUse())
11250 return false;
11251
11252 if (!isa<LoadSDNode>(N))
11253 return false;
11254
11255 LoadSDNode *Load = cast<LoadSDNode>(N);
11256 ISD::LoadExtType LoadExt = Load->getExtensionType();
11257 if (LoadExt == ISD::NON_EXTLOAD || LoadExt == ISD::EXTLOAD)
11258 return true;
11259
11260 // Now LoadExt is either SEXTLOAD or ZEXTLOAD, ExtOpcode must have the same
11261 // extension.
11262 if ((LoadExt == ISD::SEXTLOAD && ExtOpcode != ISD::SIGN_EXTEND) ||
11263 (LoadExt == ISD::ZEXTLOAD && ExtOpcode != ISD::ZERO_EXTEND))
11264 return false;
11265
11266 return true;
11267 }
11268
11269 /// Fold
11270 /// (sext (select c, load x, load y)) -> (select c, sextload x, sextload y)
11271 /// (zext (select c, load x, load y)) -> (select c, zextload x, zextload y)
11272 /// (aext (select c, load x, load y)) -> (select c, extload x, extload y)
11273 /// This function is called by the DAGCombiner when visiting sext/zext/aext
11274 /// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
tryToFoldExtendSelectLoad(SDNode * N,const TargetLowering & TLI,SelectionDAG & DAG)11275 static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI,
11276 SelectionDAG &DAG) {
11277 unsigned Opcode = N->getOpcode();
11278 SDValue N0 = N->getOperand(0);
11279 EVT VT = N->getValueType(0);
11280 SDLoc DL(N);
11281
11282 assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
11283 Opcode == ISD::ANY_EXTEND) &&
11284 "Expected EXTEND dag node in input!");
11285
11286 if (!(N0->getOpcode() == ISD::SELECT || N0->getOpcode() == ISD::VSELECT) ||
11287 !N0.hasOneUse())
11288 return SDValue();
11289
11290 SDValue Op1 = N0->getOperand(1);
11291 SDValue Op2 = N0->getOperand(2);
11292 if (!isCompatibleLoad(Op1, Opcode) || !isCompatibleLoad(Op2, Opcode))
11293 return SDValue();
11294
11295 auto ExtLoadOpcode = ISD::EXTLOAD;
11296 if (Opcode == ISD::SIGN_EXTEND)
11297 ExtLoadOpcode = ISD::SEXTLOAD;
11298 else if (Opcode == ISD::ZERO_EXTEND)
11299 ExtLoadOpcode = ISD::ZEXTLOAD;
11300
11301 LoadSDNode *Load1 = cast<LoadSDNode>(Op1);
11302 LoadSDNode *Load2 = cast<LoadSDNode>(Op2);
11303 if (!TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load1->getMemoryVT()) ||
11304 !TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load2->getMemoryVT()))
11305 return SDValue();
11306
11307 SDValue Ext1 = DAG.getNode(Opcode, DL, VT, Op1);
11308 SDValue Ext2 = DAG.getNode(Opcode, DL, VT, Op2);
11309 return DAG.getSelect(DL, VT, N0->getOperand(0), Ext1, Ext2);
11310 }
11311
11312 /// Try to fold a sext/zext/aext dag node into a ConstantSDNode or
11313 /// a build_vector of constants.
11314 /// This function is called by the DAGCombiner when visiting sext/zext/aext
11315 /// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
11316 /// Vector extends are not folded if operations are legal; this is to
11317 /// avoid introducing illegal build_vector dag nodes.
static SDValue tryToFoldExtendOfConstant(SDNode *N, const TargetLowering &TLI,
                                         SelectionDAG &DAG, bool LegalTypes) {
  unsigned Opcode = N->getOpcode();
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
          Opcode == ISD::ANY_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG ||
          Opcode == ISD::ZERO_EXTEND_VECTOR_INREG)
         && "Expected EXTEND dag node in input!");

  // fold (sext c1) -> c1
  // fold (zext c1) -> c1
  // fold (aext c1) -> c1
  // Re-requesting the node from getNode is expected to constant-fold it.
  if (isa<ConstantSDNode>(N0))
    return DAG.getNode(Opcode, DL, VT, N0);

  // fold (sext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
  // fold (zext (select cond, c1, c2)) -> (select cond, zext c1, zext c2)
  // fold (aext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
  // Skipped for zero_extend when the target reports the zext as free —
  // presumably the extend costs nothing where it already is.
  if (N0->getOpcode() == ISD::SELECT) {
    SDValue Op1 = N0->getOperand(1);
    SDValue Op2 = N0->getOperand(2);
    if (isa<ConstantSDNode>(Op1) && isa<ConstantSDNode>(Op2) &&
        (Opcode != ISD::ZERO_EXTEND || !TLI.isZExtFree(N0.getValueType(), VT))) {
      // For any_extend, choose sign extension of the constants to allow a
      // possible further transform to sign_extend_inreg.i.e.
      //
      // t1: i8 = select t0, Constant:i8<-1>, Constant:i8<0>
      // t2: i64 = any_extend t1
      // -->
      // t3: i64 = select t0, Constant:i64<-1>, Constant:i64<0>
      // -->
      // t4: i64 = sign_extend_inreg t3
      unsigned FoldOpc = Opcode;
      if (FoldOpc == ISD::ANY_EXTEND)
        FoldOpc = ISD::SIGN_EXTEND;
      return DAG.getSelect(DL, VT, N0->getOperand(0),
                           DAG.getNode(FoldOpc, DL, VT, Op1),
                           DAG.getNode(FoldOpc, DL, VT, Op2));
    }
  }

  // fold (sext (build_vector AllConstants) -> (build_vector AllConstants)
  // fold (zext (build_vector AllConstants) -> (build_vector AllConstants)
  // fold (aext (build_vector AllConstants) -> (build_vector AllConstants)
  // Require a legal scalar type (when types must be legal) so we don't create
  // an illegal build_vector.
  EVT SVT = VT.getScalarType();
  if (!(VT.isVector() && (!LegalTypes || TLI.isTypeLegal(SVT)) &&
      ISD::isBuildVectorOfConstantSDNodes(N0.getNode())))
    return SDValue();

  // We can fold this node into a build_vector.
  unsigned VTBits = SVT.getSizeInBits();
  unsigned EVTBits = N0->getValueType(0).getScalarSizeInBits();
  SmallVector<SDValue, 8> Elts;
  unsigned NumElts = VT.getVectorNumElements();

  // For zero-extensions, UNDEF elements still guarantee to have the upper
  // bits set to zero.
  bool IsZext =
      Opcode == ISD::ZERO_EXTEND || Opcode == ISD::ZERO_EXTEND_VECTOR_INREG;

  for (unsigned i = 0; i != NumElts; ++i) {
    SDValue Op = N0.getOperand(i);
    if (Op.isUndef()) {
      // Zext must produce zeroed high bits, so materialize 0; other extends
      // may keep the element undef.
      Elts.push_back(IsZext ? DAG.getConstant(0, DL, SVT) : DAG.getUNDEF(SVT));
      continue;
    }

    SDLoc DL(Op);
    // Get the constant value and if needed trunc it to the size of the type.
    // Nodes like build_vector might have constants wider than the scalar type.
    APInt C = cast<ConstantSDNode>(Op)->getAPIntValue().zextOrTrunc(EVTBits);
    if (Opcode == ISD::SIGN_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG)
      Elts.push_back(DAG.getConstant(C.sext(VTBits), DL, SVT));
    else
      Elts.push_back(DAG.getConstant(C.zext(VTBits), DL, SVT));
  }

  return DAG.getBuildVector(VT, DL, Elts);
}
11400
11401 // ExtendUsesToFormExtLoad - Trying to extend uses of a load to enable this:
11402 // "fold ({s|z|a}ext (load x)) -> ({s|z|a}ext (truncate ({s|z|a}extload x)))"
11403 // transformation. Returns true if extension are possible and the above
11404 // mentioned transformation is profitable.
ExtendUsesToFormExtLoad(EVT VT,SDNode * N,SDValue N0,unsigned ExtOpc,SmallVectorImpl<SDNode * > & ExtendNodes,const TargetLowering & TLI)11405 static bool ExtendUsesToFormExtLoad(EVT VT, SDNode *N, SDValue N0,
11406 unsigned ExtOpc,
11407 SmallVectorImpl<SDNode *> &ExtendNodes,
11408 const TargetLowering &TLI) {
11409 bool HasCopyToRegUses = false;
11410 bool isTruncFree = TLI.isTruncateFree(VT, N0.getValueType());
11411 for (SDNode::use_iterator UI = N0->use_begin(), UE = N0->use_end(); UI != UE;
11412 ++UI) {
11413 SDNode *User = *UI;
11414 if (User == N)
11415 continue;
11416 if (UI.getUse().getResNo() != N0.getResNo())
11417 continue;
11418 // FIXME: Only extend SETCC N, N and SETCC N, c for now.
11419 if (ExtOpc != ISD::ANY_EXTEND && User->getOpcode() == ISD::SETCC) {
11420 ISD::CondCode CC = cast<CondCodeSDNode>(User->getOperand(2))->get();
11421 if (ExtOpc == ISD::ZERO_EXTEND && ISD::isSignedIntSetCC(CC))
11422 // Sign bits will be lost after a zext.
11423 return false;
11424 bool Add = false;
11425 for (unsigned i = 0; i != 2; ++i) {
11426 SDValue UseOp = User->getOperand(i);
11427 if (UseOp == N0)
11428 continue;
11429 if (!isa<ConstantSDNode>(UseOp))
11430 return false;
11431 Add = true;
11432 }
11433 if (Add)
11434 ExtendNodes.push_back(User);
11435 continue;
11436 }
11437 // If truncates aren't free and there are users we can't
11438 // extend, it isn't worthwhile.
11439 if (!isTruncFree)
11440 return false;
11441 // Remember if this value is live-out.
11442 if (User->getOpcode() == ISD::CopyToReg)
11443 HasCopyToRegUses = true;
11444 }
11445
11446 if (HasCopyToRegUses) {
11447 bool BothLiveOut = false;
11448 for (SDNode::use_iterator UI = N->use_begin(), UE = N->use_end();
11449 UI != UE; ++UI) {
11450 SDUse &Use = UI.getUse();
11451 if (Use.getResNo() == 0 && Use.getUser()->getOpcode() == ISD::CopyToReg) {
11452 BothLiveOut = true;
11453 break;
11454 }
11455 }
11456 if (BothLiveOut)
11457 // Both unextended and extended values are live out. There had better be
11458 // a good reason for the transformation.
11459 return ExtendNodes.size();
11460 }
11461 return true;
11462 }
11463
ExtendSetCCUses(const SmallVectorImpl<SDNode * > & SetCCs,SDValue OrigLoad,SDValue ExtLoad,ISD::NodeType ExtType)11464 void DAGCombiner::ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
11465 SDValue OrigLoad, SDValue ExtLoad,
11466 ISD::NodeType ExtType) {
11467 // Extend SetCC uses if necessary.
11468 SDLoc DL(ExtLoad);
11469 for (SDNode *SetCC : SetCCs) {
11470 SmallVector<SDValue, 4> Ops;
11471
11472 for (unsigned j = 0; j != 2; ++j) {
11473 SDValue SOp = SetCC->getOperand(j);
11474 if (SOp == OrigLoad)
11475 Ops.push_back(ExtLoad);
11476 else
11477 Ops.push_back(DAG.getNode(ExtType, DL, ExtLoad->getValueType(0), SOp));
11478 }
11479
11480 Ops.push_back(SetCC->getOperand(2));
11481 CombineTo(SetCC, DAG.getNode(ISD::SETCC, DL, SetCC->getValueType(0), Ops));
11482 }
11483 }
11484
11485 // FIXME: Bring more similar combines here, common to sext/zext (maybe aext?).
// Split (sext/zext (load x)) on an illegal-but-splittable vector type into a
// concat of legal extending loads; see the detailed example below.
SDValue DAGCombiner::CombineExtLoad(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT DstVT = N->getValueType(0);
  EVT SrcVT = N0.getValueType();

  assert((N->getOpcode() == ISD::SIGN_EXTEND ||
          N->getOpcode() == ISD::ZERO_EXTEND) &&
         "Unexpected node type (not an extend)!");

  // fold (sext (load x)) to multiple smaller sextloads; same for zext.
  // For example, on a target with legal v4i32, but illegal v8i32, turn:
  //   (v8i32 (sext (v8i16 (load x))))
  // into:
  //   (v8i32 (concat_vectors (v4i32 (sextload x)),
  //                          (v4i32 (sextload (x + 16)))))
  // Where uses of the original load, i.e.:
  //   (v8i16 (load x))
  // are replaced with:
  //   (v8i16 (truncate
  //     (v8i32 (concat_vectors (v4i32 (sextload x)),
  //                            (v4i32 (sextload (x + 16)))))))
  //
  // This combine is only applicable to illegal, but splittable, vectors.
  // All legal types, and illegal non-vector types, are handled elsewhere.
  // This combine is controlled by TargetLowering::isVectorLoadExtDesirable.
  //
  if (N0->getOpcode() != ISD::LOAD)
    return SDValue();

  LoadSDNode *LN0 = cast<LoadSDNode>(N0);

  // Require a plain, unindexed, simple (non-atomic/non-volatile) load with a
  // single user, a power-of-2 vector result, and target approval.
  if (!ISD::isNON_EXTLoad(LN0) || !ISD::isUNINDEXEDLoad(LN0) ||
      !N0.hasOneUse() || !LN0->isSimple() ||
      !DstVT.isVector() || !DstVT.isPow2VectorType() ||
      !TLI.isVectorLoadExtDesirable(SDValue(N, 0)))
    return SDValue();

  // Any SETCC users of the load must be extendable too (they are rewritten
  // below via ExtendSetCCUses).
  SmallVector<SDNode *, 4> SetCCs;
  if (!ExtendUsesToFormExtLoad(DstVT, N, N0, N->getOpcode(), SetCCs, TLI))
    return SDValue();

  ISD::LoadExtType ExtType =
      N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SEXTLOAD : ISD::ZEXTLOAD;

  // Try to split the vector types to get down to legal types.
  EVT SplitSrcVT = SrcVT;
  EVT SplitDstVT = DstVT;
  while (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT) &&
         SplitSrcVT.getVectorNumElements() > 1) {
    SplitDstVT = DAG.GetSplitDestVTs(SplitDstVT).first;
    SplitSrcVT = DAG.GetSplitDestVTs(SplitSrcVT).first;
  }

  if (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT))
    return SDValue();

  assert(!DstVT.isScalableVector() && "Unexpected scalable vector type");

  SDLoc DL(N);
  const unsigned NumSplits =
      DstVT.getVectorNumElements() / SplitDstVT.getVectorNumElements();
  // Byte distance between consecutive split loads.
  const unsigned Stride = SplitSrcVT.getStoreSize();
  SmallVector<SDValue, 4> Loads;
  SmallVector<SDValue, 4> Chains;

  // Emit one extending load per split, each at an increasing offset, with the
  // alignment reduced as needed for the offset.
  SDValue BasePtr = LN0->getBasePtr();
  for (unsigned Idx = 0; Idx < NumSplits; Idx++) {
    const unsigned Offset = Idx * Stride;
    const Align Align = commonAlignment(LN0->getAlign(), Offset);

    SDValue SplitLoad = DAG.getExtLoad(
        ExtType, SDLoc(LN0), SplitDstVT, LN0->getChain(), BasePtr,
        LN0->getPointerInfo().getWithOffset(Offset), SplitSrcVT, Align,
        LN0->getMemOperand()->getFlags(), LN0->getAAInfo());

    BasePtr = DAG.getMemBasePlusOffset(BasePtr, TypeSize::Fixed(Stride), DL);

    Loads.push_back(SplitLoad.getValue(0));
    Chains.push_back(SplitLoad.getValue(1));
  }

  // Merge the split chains and values.
  SDValue NewChain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Chains);
  SDValue NewValue = DAG.getNode(ISD::CONCAT_VECTORS, DL, DstVT, Loads);

  // Simplify TF.
  AddToWorklist(NewChain.getNode());

  CombineTo(N, NewValue);

  // Replace uses of the original load (before extension)
  // with a truncate of the concatenated sextloaded vectors.
  SDValue Trunc =
      DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), NewValue);
  ExtendSetCCUses(SetCCs, N0, NewValue, (ISD::NodeType)N->getOpcode());
  CombineTo(N0.getNode(), Trunc, NewChain);
  return SDValue(N, 0); // Return N so it doesn't get rechecked!
}
11583
11584 // fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
11585 // (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
SDValue DAGCombiner::CombineZExtLogicopShiftLoad(SDNode *N) {
  assert(N->getOpcode() == ISD::ZERO_EXTEND);
  EVT VT = N->getValueType(0);
  EVT OrigVT = N->getOperand(0).getValueType();
  // If the zext is free anyway there is nothing to gain from rewriting.
  if (TLI.isZExtFree(OrigVT, VT))
    return SDValue();

  // and/or/xor — the logic op must have a constant RHS and (post-legalize) be
  // legal in the wide type.
  SDValue N0 = N->getOperand(0);
  if (!(N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR ||
        N0.getOpcode() == ISD::XOR) ||
      N0.getOperand(1).getOpcode() != ISD::Constant ||
      (LegalOperations && !TLI.isOperationLegal(N0.getOpcode(), VT)))
    return SDValue();

  // shl/shr — same constraints on the shift feeding the logic op.
  SDValue N1 = N0->getOperand(0);
  if (!(N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) ||
      N1.getOperand(1).getOpcode() != ISD::Constant ||
      (LegalOperations && !TLI.isOperationLegal(N1.getOpcode(), VT)))
    return SDValue();

  // load — must be convertible to a zextload: not sign-extending, unindexed,
  // and a legal zextload for its memory type.
  if (!isa<LoadSDNode>(N1.getOperand(0)))
    return SDValue();
  LoadSDNode *Load = cast<LoadSDNode>(N1.getOperand(0));
  EVT MemVT = Load->getMemoryVT();
  if (!TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) ||
      Load->getExtensionType() == ISD::SEXTLOAD || Load->isIndexed())
    return SDValue();


  // If the shift op is SHL, the logic op must be AND, otherwise the result
  // will be wrong.
  if (N1.getOpcode() == ISD::SHL && N0.getOpcode() != ISD::AND)
    return SDValue();

  if (!N0.hasOneUse() || !N1.hasOneUse())
    return SDValue();

  // Any SETCC users of the load must be extendable to the wide type too.
  SmallVector<SDNode*, 4> SetCCs;
  if (!ExtendUsesToFormExtLoad(VT, N1.getNode(), N1.getOperand(0),
                               ISD::ZERO_EXTEND, SetCCs, TLI))
    return SDValue();

  // Actually do the transformation.
  SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(Load), VT,
                                   Load->getChain(), Load->getBasePtr(),
                                   Load->getMemoryVT(), Load->getMemOperand());

  // Rebuild the shift in the wide type over the zextload.
  SDLoc DL1(N1);
  SDValue Shift = DAG.getNode(N1.getOpcode(), DL1, VT, ExtLoad,
                              N1.getOperand(1));

  // Rebuild the logic op with its constant zero-extended to the wide type.
  APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
  SDLoc DL0(N0);
  SDValue And = DAG.getNode(N0.getOpcode(), DL0, VT, Shift,
                            DAG.getConstant(Mask, DL0, VT));

  ExtendSetCCUses(SetCCs, N1.getOperand(0), ExtLoad, ISD::ZERO_EXTEND);
  CombineTo(N, And);
  if (SDValue(Load, 0).hasOneUse()) {
    // The shift was the only value user; just rewire the chain.
    DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), ExtLoad.getValue(1));
  } else {
    // Other users of the narrow load remain; give them a truncate of the
    // wide extload instead.
    SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(Load),
                                Load->getValueType(0), ExtLoad);
    CombineTo(Load, Trunc, ExtLoad.getValue(1));
  }

  // N0 is dead at this point.
  recursivelyDeleteUnusedNodes(N0.getNode());

  return SDValue(N,0); // Return N so it doesn't get rechecked!
}
11660
11661 /// If we're narrowing or widening the result of a vector select and the final
11662 /// size is the same size as a setcc (compare) feeding the select, then try to
11663 /// apply the cast operation to the select's operands because matching vector
11664 /// sizes for a select condition and other operands should be more efficient.
matchVSelectOpSizesWithSetCC(SDNode * Cast)11665 SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
11666 unsigned CastOpcode = Cast->getOpcode();
11667 assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
11668 CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
11669 CastOpcode == ISD::FP_ROUND) &&
11670 "Unexpected opcode for vector select narrowing/widening");
11671
11672 // We only do this transform before legal ops because the pattern may be
11673 // obfuscated by target-specific operations after legalization. Do not create
11674 // an illegal select op, however, because that may be difficult to lower.
11675 EVT VT = Cast->getValueType(0);
11676 if (LegalOperations || !TLI.isOperationLegalOrCustom(ISD::VSELECT, VT))
11677 return SDValue();
11678
11679 SDValue VSel = Cast->getOperand(0);
11680 if (VSel.getOpcode() != ISD::VSELECT || !VSel.hasOneUse() ||
11681 VSel.getOperand(0).getOpcode() != ISD::SETCC)
11682 return SDValue();
11683
11684 // Does the setcc have the same vector size as the casted select?
11685 SDValue SetCC = VSel.getOperand(0);
11686 EVT SetCCVT = getSetCCResultType(SetCC.getOperand(0).getValueType());
11687 if (SetCCVT.getSizeInBits() != VT.getSizeInBits())
11688 return SDValue();
11689
11690 // cast (vsel (setcc X), A, B) --> vsel (setcc X), (cast A), (cast B)
11691 SDValue A = VSel.getOperand(1);
11692 SDValue B = VSel.getOperand(2);
11693 SDValue CastA, CastB;
11694 SDLoc DL(Cast);
11695 if (CastOpcode == ISD::FP_ROUND) {
11696 // FP_ROUND (fptrunc) has an extra flag operand to pass along.
11697 CastA = DAG.getNode(CastOpcode, DL, VT, A, Cast->getOperand(1));
11698 CastB = DAG.getNode(CastOpcode, DL, VT, B, Cast->getOperand(1));
11699 } else {
11700 CastA = DAG.getNode(CastOpcode, DL, VT, A);
11701 CastB = DAG.getNode(CastOpcode, DL, VT, B);
11702 }
11703 return DAG.getNode(ISD::VSELECT, DL, VT, SetCC, CastA, CastB);
11704 }
11705
11706 // fold ([s|z]ext ([s|z]extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
11707 // fold ([s|z]ext ( extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
static SDValue tryToFoldExtOfExtload(SelectionDAG &DAG, DAGCombiner &Combiner,
                                     const TargetLowering &TLI, EVT VT,
                                     bool LegalOperations, SDNode *N,
                                     SDValue N0, ISD::LoadExtType ExtLoadType) {
  // Accept either a load already extending the same way (sext for SEXTLOAD,
  // zext for ZEXTLOAD) or a plain any-extending load; it must be unindexed
  // and have only this extend as its value user.
  SDNode *N0Node = N0.getNode();
  bool isAExtLoad = (ExtLoadType == ISD::SEXTLOAD) ? ISD::isSEXTLoad(N0Node)
                                                   : ISD::isZEXTLoad(N0Node);
  if ((!isAExtLoad && !ISD::isEXTLoad(N0Node)) ||
      !ISD::isUNINDEXEDLoad(N0Node) || !N0.hasOneUse())
    return SDValue();

  // Before legalization a simple scalar load can be rewritten freely;
  // otherwise the target must support the extending load.
  LoadSDNode *LN0 = cast<LoadSDNode>(N0);
  EVT MemVT = LN0->getMemoryVT();
  if ((LegalOperations || !LN0->isSimple() ||
       VT.isVector()) &&
      !TLI.isLoadExtLegal(ExtLoadType, VT, MemVT))
    return SDValue();

  // Replace the extend with a single wide extending load, then rewire the
  // old load's chain and clean up if it became dead.
  SDValue ExtLoad =
      DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(),
                     LN0->getBasePtr(), MemVT, LN0->getMemOperand());
  Combiner.CombineTo(N, ExtLoad);
  DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
  if (LN0->use_empty())
    Combiner.recursivelyDeleteUnusedNodes(LN0);
  return SDValue(N, 0); // Return N so it doesn't get rechecked!
}
11735
11736 // fold ([s|z]ext (load x)) -> ([s|z]ext (truncate ([s|z]extload x)))
11737 // Only generate vector extloads when 1) they're legal, and 2) they are
11738 // deemed desirable by the target.
static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner,
                                  const TargetLowering &TLI, EVT VT,
                                  bool LegalOperations, SDNode *N, SDValue N0,
                                  ISD::LoadExtType ExtLoadType,
                                  ISD::NodeType ExtOpc) {
  // TODO: isFixedLengthVector() should be removed and any negative effects on
  // code generation being the result of that target's implementation of
  // isVectorLoadExtDesirable().
  // Match a plain unindexed load; before legalization a simple scalar load
  // may be rewritten freely, otherwise the extending load must be legal.
  if (!ISD::isNON_EXTLoad(N0.getNode()) ||
      !ISD::isUNINDEXEDLoad(N0.getNode()) ||
      ((LegalOperations || VT.isFixedLengthVector() ||
        !cast<LoadSDNode>(N0)->isSimple()) &&
       !TLI.isLoadExtLegal(ExtLoadType, VT, N0.getValueType())))
    return {};

  // If the load has other users, they must all be extendable (SETCCs
  // collected here are rewritten below); vector loads additionally need
  // target approval.
  bool DoXform = true;
  SmallVector<SDNode *, 4> SetCCs;
  if (!N0.hasOneUse())
    DoXform = ExtendUsesToFormExtLoad(VT, N, N0, ExtOpc, SetCCs, TLI);
  if (VT.isVector())
    DoXform &= TLI.isVectorLoadExtDesirable(SDValue(N, 0));
  if (!DoXform)
    return {};

  // Build the extending load and rewrite the collected SETCC users.
  LoadSDNode *LN0 = cast<LoadSDNode>(N0);
  SDValue ExtLoad = DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(),
                                   LN0->getBasePtr(), N0.getValueType(),
                                   LN0->getMemOperand());
  Combiner.ExtendSetCCUses(SetCCs, N0, ExtLoad, ExtOpc);
  // If the load value is used only by N, replace it via CombineTo N.
  bool NoReplaceTrunc = SDValue(LN0, 0).hasOneUse();
  Combiner.CombineTo(N, ExtLoad);
  if (NoReplaceTrunc) {
    DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
    Combiner.recursivelyDeleteUnusedNodes(LN0);
  } else {
    // Other users of the narrow load remain; feed them a truncate of the
    // wide extload.
    SDValue Trunc =
        DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), ExtLoad);
    Combiner.CombineTo(LN0, Trunc, ExtLoad.getValue(1));
  }
  return SDValue(N, 0); // Return N so it doesn't get rechecked!
}
11781
tryToFoldExtOfMaskedLoad(SelectionDAG & DAG,const TargetLowering & TLI,EVT VT,SDNode * N,SDValue N0,ISD::LoadExtType ExtLoadType,ISD::NodeType ExtOpc)11782 static SDValue tryToFoldExtOfMaskedLoad(SelectionDAG &DAG,
11783 const TargetLowering &TLI, EVT VT,
11784 SDNode *N, SDValue N0,
11785 ISD::LoadExtType ExtLoadType,
11786 ISD::NodeType ExtOpc) {
11787 if (!N0.hasOneUse())
11788 return SDValue();
11789
11790 MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(N0);
11791 if (!Ld || Ld->getExtensionType() != ISD::NON_EXTLOAD)
11792 return SDValue();
11793
11794 if (!TLI.isLoadExtLegalOrCustom(ExtLoadType, VT, Ld->getValueType(0)))
11795 return SDValue();
11796
11797 if (!TLI.isVectorLoadExtDesirable(SDValue(N, 0)))
11798 return SDValue();
11799
11800 SDLoc dl(Ld);
11801 SDValue PassThru = DAG.getNode(ExtOpc, dl, VT, Ld->getPassThru());
11802 SDValue NewLoad = DAG.getMaskedLoad(
11803 VT, dl, Ld->getChain(), Ld->getBasePtr(), Ld->getOffset(), Ld->getMask(),
11804 PassThru, Ld->getMemoryVT(), Ld->getMemOperand(), Ld->getAddressingMode(),
11805 ExtLoadType, Ld->isExpandingLoad());
11806 DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), SDValue(NewLoad.getNode(), 1));
11807 return NewLoad;
11808 }
11809
foldExtendedSignBitTest(SDNode * N,SelectionDAG & DAG,bool LegalOperations)11810 static SDValue foldExtendedSignBitTest(SDNode *N, SelectionDAG &DAG,
11811 bool LegalOperations) {
11812 assert((N->getOpcode() == ISD::SIGN_EXTEND ||
11813 N->getOpcode() == ISD::ZERO_EXTEND) && "Expected sext or zext");
11814
11815 SDValue SetCC = N->getOperand(0);
11816 if (LegalOperations || SetCC.getOpcode() != ISD::SETCC ||
11817 !SetCC.hasOneUse() || SetCC.getValueType() != MVT::i1)
11818 return SDValue();
11819
11820 SDValue X = SetCC.getOperand(0);
11821 SDValue Ones = SetCC.getOperand(1);
11822 ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get();
11823 EVT VT = N->getValueType(0);
11824 EVT XVT = X.getValueType();
11825 // setge X, C is canonicalized to setgt, so we do not need to match that
11826 // pattern. The setlt sibling is folded in SimplifySelectCC() because it does
11827 // not require the 'not' op.
11828 if (CC == ISD::SETGT && isAllOnesConstant(Ones) && VT == XVT) {
11829 // Invert and smear/shift the sign bit:
11830 // sext i1 (setgt iN X, -1) --> sra (not X), (N - 1)
11831 // zext i1 (setgt iN X, -1) --> srl (not X), (N - 1)
11832 SDLoc DL(N);
11833 unsigned ShCt = VT.getSizeInBits() - 1;
11834 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
11835 if (!TLI.shouldAvoidTransformToShift(VT, ShCt)) {
11836 SDValue NotX = DAG.getNOT(DL, X, VT);
11837 SDValue ShiftAmount = DAG.getConstant(ShCt, DL, VT);
11838 auto ShiftOpcode =
11839 N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SRA : ISD::SRL;
11840 return DAG.getNode(ShiftOpcode, DL, VT, NotX, ShiftAmount);
11841 }
11842 }
11843 return SDValue();
11844 }
11845
// Try to simplify (sext (setcc ...)) in a few ways: reuse a wider setcc
// result, zext/sext the compare operands, or lower to a select of constants.
SDValue DAGCombiner::foldSextSetcc(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  if (N0.getOpcode() != ISD::SETCC)
    return SDValue();

  SDValue N00 = N0.getOperand(0);
  SDValue N01 = N0.getOperand(1);
  ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
  EVT VT = N->getValueType(0);
  EVT N00VT = N00.getValueType();
  SDLoc DL(N);

  // Propagate fast-math-flags.
  SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());

  // On some architectures (such as SSE/NEON/etc) the SETCC result type is
  // the same size as the compared operands. Try to optimize sext(setcc())
  // if this is the case.
  if (VT.isVector() && !LegalOperations &&
      TLI.getBooleanContents(N00VT) ==
          TargetLowering::ZeroOrNegativeOneBooleanContent) {
    EVT SVT = getSetCCResultType(N00VT);

    // If we already have the desired type, don't change it.
    if (SVT != N0.getValueType()) {
      // We know that the # elements of the results is the same as the
      // # elements of the compare (and the # elements of the compare result
      // for that matter). Check to see that they are the same size. If so,
      // we know that the element size of the sext'd result matches the
      // element size of the compare operands.
      if (VT.getSizeInBits() == SVT.getSizeInBits())
        return DAG.getSetCC(DL, VT, N00, N01, CC);

      // If the desired elements are smaller or larger than the source
      // elements, we can use a matching integer vector type and then
      // truncate/sign extend.
      EVT MatchingVecType = N00VT.changeVectorElementTypeToInteger();
      if (SVT == MatchingVecType) {
        SDValue VsetCC = DAG.getSetCC(DL, MatchingVecType, N00, N01, CC);
        return DAG.getSExtOrTrunc(VsetCC, DL, VT);
      }
    }

    // Try to eliminate the sext of a setcc by zexting the compare operands.
    if (N0.hasOneUse() && TLI.isOperationLegalOrCustom(ISD::SETCC, VT) &&
        !TLI.isOperationLegalOrCustom(ISD::SETCC, SVT)) {
      // Signed compares need sign-preserving extension of the operands.
      bool IsSignedCmp = ISD::isSignedIntSetCC(CC);
      unsigned LoadOpcode = IsSignedCmp ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
      unsigned ExtOpcode = IsSignedCmp ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;

      // We have an unsupported narrow vector compare op that would be legal
      // if extended to the destination type. See if the compare operands
      // can be freely extended to the destination type.
      auto IsFreeToExtend = [&](SDValue V) {
        if (isConstantOrConstantVector(V, /*NoOpaques*/ true))
          return true;
        // Match a simple, non-extended load that can be converted to a
        // legal {z/s}ext-load.
        // TODO: Allow widening of an existing {z/s}ext-load?
        if (!(ISD::isNON_EXTLoad(V.getNode()) &&
              ISD::isUNINDEXEDLoad(V.getNode()) &&
              cast<LoadSDNode>(V)->isSimple() &&
              TLI.isLoadExtLegal(LoadOpcode, VT, V.getValueType())))
          return false;

        // Non-chain users of this value must either be the setcc in this
        // sequence or extends that can be folded into the new {z/s}ext-load.
        for (SDNode::use_iterator UI = V->use_begin(), UE = V->use_end();
             UI != UE; ++UI) {
          // Skip uses of the chain and the setcc.
          SDNode *User = *UI;
          if (UI.getUse().getResNo() != 0 || User == N0.getNode())
            continue;
          // Extra users must have exactly the same cast we are about to create.
          // TODO: This restriction could be eased if ExtendUsesToFormExtLoad()
          // is enhanced similarly.
          if (User->getOpcode() != ExtOpcode || User->getValueType(0) != VT)
            return false;
        }
        return true;
      };

      if (IsFreeToExtend(N00) && IsFreeToExtend(N01)) {
        SDValue Ext0 = DAG.getNode(ExtOpcode, DL, VT, N00);
        SDValue Ext1 = DAG.getNode(ExtOpcode, DL, VT, N01);
        return DAG.getSetCC(DL, VT, Ext0, Ext1, CC);
      }
    }
  }

  // sext(setcc x, y, cc) -> (select (setcc x, y, cc), T, 0)
  // Here, T can be 1 or -1, depending on the type of the setcc and
  // getBooleanContents().
  unsigned SetCCWidth = N0.getScalarValueSizeInBits();

  // To determine the "true" side of the select, we need to know the high bit
  // of the value returned by the setcc if it evaluates to true.
  // If the type of the setcc is i1, then the true case of the select is just
  // sext(i1 1), that is, -1.
  // If the type of the setcc is larger (say, i8) then the value of the high
  // bit depends on getBooleanContents(), so ask TLI for a real "true" value
  // of the appropriate width.
  SDValue ExtTrueVal = (SetCCWidth == 1)
                           ? DAG.getAllOnesConstant(DL, VT)
                           : DAG.getBoolConstant(true, DL, VT, N00VT);
  SDValue Zero = DAG.getConstant(0, DL, VT);
  if (SDValue SCC = SimplifySelectCC(DL, N00, N01, ExtTrueVal, Zero, CC, true))
    return SCC;

  if (!VT.isVector() && !TLI.convertSelectOfConstantsToMath(VT)) {
    EVT SetCCVT = getSetCCResultType(N00VT);
    // Don't do this transform for i1 because there's a select transform
    // that would reverse it.
    // TODO: We should not do this transform at all without a target hook
    // because a sext is likely cheaper than a select?
    if (SetCCVT.getScalarSizeInBits() != 1 &&
        (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, N00VT))) {
      SDValue SetCC = DAG.getSetCC(DL, SetCCVT, N00, N01, CC);
      return DAG.getSelect(DL, VT, SetCC, ExtTrueVal, Zero);
    }
  }

  return SDValue();
}
11970
visitSIGN_EXTEND(SDNode * N)11971 SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
11972 SDValue N0 = N->getOperand(0);
11973 EVT VT = N->getValueType(0);
11974 SDLoc DL(N);
11975
11976 // sext(undef) = 0 because the top bit will all be the same.
11977 if (N0.isUndef())
11978 return DAG.getConstant(0, DL, VT);
11979
11980 if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
11981 return Res;
11982
11983 // fold (sext (sext x)) -> (sext x)
11984 // fold (sext (aext x)) -> (sext x)
11985 if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND)
11986 return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N0.getOperand(0));
11987
11988 if (N0.getOpcode() == ISD::TRUNCATE) {
11989 // fold (sext (truncate (load x))) -> (sext (smaller load x))
11990 // fold (sext (truncate (srl (load x), c))) -> (sext (smaller load (x+c/n)))
11991 if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
11992 SDNode *oye = N0.getOperand(0).getNode();
11993 if (NarrowLoad.getNode() != N0.getNode()) {
11994 CombineTo(N0.getNode(), NarrowLoad);
11995 // CombineTo deleted the truncate, if needed, but not what's under it.
11996 AddToWorklist(oye);
11997 }
11998 return SDValue(N, 0); // Return N so it doesn't get rechecked!
11999 }
12000
12001 // See if the value being truncated is already sign extended. If so, just
12002 // eliminate the trunc/sext pair.
12003 SDValue Op = N0.getOperand(0);
12004 unsigned OpBits = Op.getScalarValueSizeInBits();
12005 unsigned MidBits = N0.getScalarValueSizeInBits();
12006 unsigned DestBits = VT.getScalarSizeInBits();
12007 unsigned NumSignBits = DAG.ComputeNumSignBits(Op);
12008
12009 if (OpBits == DestBits) {
12010 // Op is i32, Mid is i8, and Dest is i32. If Op has more than 24 sign
12011 // bits, it is already ready.
12012 if (NumSignBits > DestBits-MidBits)
12013 return Op;
12014 } else if (OpBits < DestBits) {
12015 // Op is i32, Mid is i8, and Dest is i64. If Op has more than 24 sign
12016 // bits, just sext from i32.
12017 if (NumSignBits > OpBits-MidBits)
12018 return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op);
12019 } else {
12020 // Op is i64, Mid is i8, and Dest is i32. If Op has more than 56 sign
12021 // bits, just truncate to i32.
12022 if (NumSignBits > OpBits-MidBits)
12023 return DAG.getNode(ISD::TRUNCATE, DL, VT, Op);
12024 }
12025
12026 // fold (sext (truncate x)) -> (sextinreg x).
12027 if (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND_INREG,
12028 N0.getValueType())) {
12029 if (OpBits < DestBits)
12030 Op = DAG.getNode(ISD::ANY_EXTEND, SDLoc(N0), VT, Op);
12031 else if (OpBits > DestBits)
12032 Op = DAG.getNode(ISD::TRUNCATE, SDLoc(N0), VT, Op);
12033 return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, Op,
12034 DAG.getValueType(N0.getValueType()));
12035 }
12036 }
12037
12038 // Try to simplify (sext (load x)).
12039 if (SDValue foldedExt =
12040 tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
12041 ISD::SEXTLOAD, ISD::SIGN_EXTEND))
12042 return foldedExt;
12043
12044 if (SDValue foldedExt =
12045 tryToFoldExtOfMaskedLoad(DAG, TLI, VT, N, N0, ISD::SEXTLOAD,
12046 ISD::SIGN_EXTEND))
12047 return foldedExt;
12048
12049 // fold (sext (load x)) to multiple smaller sextloads.
12050 // Only on illegal but splittable vectors.
12051 if (SDValue ExtLoad = CombineExtLoad(N))
12052 return ExtLoad;
12053
12054 // Try to simplify (sext (sextload x)).
12055 if (SDValue foldedExt = tryToFoldExtOfExtload(
12056 DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::SEXTLOAD))
12057 return foldedExt;
12058
12059 // fold (sext (and/or/xor (load x), cst)) ->
12060 // (and/or/xor (sextload x), (sext cst))
12061 if ((N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR ||
12062 N0.getOpcode() == ISD::XOR) &&
12063 isa<LoadSDNode>(N0.getOperand(0)) &&
12064 N0.getOperand(1).getOpcode() == ISD::Constant &&
12065 (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
12066 LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
12067 EVT MemVT = LN00->getMemoryVT();
12068 if (TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, MemVT) &&
12069 LN00->getExtensionType() != ISD::ZEXTLOAD && LN00->isUnindexed()) {
12070 SmallVector<SDNode*, 4> SetCCs;
12071 bool DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
12072 ISD::SIGN_EXTEND, SetCCs, TLI);
12073 if (DoXform) {
12074 SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(LN00), VT,
12075 LN00->getChain(), LN00->getBasePtr(),
12076 LN00->getMemoryVT(),
12077 LN00->getMemOperand());
12078 APInt Mask = N0.getConstantOperandAPInt(1).sext(VT.getSizeInBits());
12079 SDValue And = DAG.getNode(N0.getOpcode(), DL, VT,
12080 ExtLoad, DAG.getConstant(Mask, DL, VT));
12081 ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::SIGN_EXTEND);
12082 bool NoReplaceTruncAnd = !N0.hasOneUse();
12083 bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
12084 CombineTo(N, And);
12085 // If N0 has multiple uses, change other uses as well.
12086 if (NoReplaceTruncAnd) {
12087 SDValue TruncAnd =
12088 DAG.getNode(ISD::TRUNCATE, DL, N0.getValueType(), And);
12089 CombineTo(N0.getNode(), TruncAnd);
12090 }
12091 if (NoReplaceTrunc) {
12092 DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1));
12093 } else {
12094 SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00),
12095 LN00->getValueType(0), ExtLoad);
12096 CombineTo(LN00, Trunc, ExtLoad.getValue(1));
12097 }
12098 return SDValue(N,0); // Return N so it doesn't get rechecked!
12099 }
12100 }
12101 }
12102
12103 if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
12104 return V;
12105
12106 if (SDValue V = foldSextSetcc(N))
12107 return V;
12108
12109 // fold (sext x) -> (zext x) if the sign bit is known zero.
12110 if ((!LegalOperations || TLI.isOperationLegal(ISD::ZERO_EXTEND, VT)) &&
12111 DAG.SignBitIsZero(N0))
12112 return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0);
12113
12114 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
12115 return NewVSel;
12116
12117 // Eliminate this sign extend by doing a negation in the destination type:
12118 // sext i32 (0 - (zext i8 X to i32)) to i64 --> 0 - (zext i8 X to i64)
12119 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
12120 isNullOrNullSplat(N0.getOperand(0)) &&
12121 N0.getOperand(1).getOpcode() == ISD::ZERO_EXTEND &&
12122 TLI.isOperationLegalOrCustom(ISD::SUB, VT)) {
12123 SDValue Zext = DAG.getZExtOrTrunc(N0.getOperand(1).getOperand(0), DL, VT);
12124 return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), Zext);
12125 }
12126 // Eliminate this sign extend by doing a decrement in the destination type:
12127 // sext i32 ((zext i8 X to i32) + (-1)) to i64 --> (zext i8 X to i64) + (-1)
12128 if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
12129 isAllOnesOrAllOnesSplat(N0.getOperand(1)) &&
12130 N0.getOperand(0).getOpcode() == ISD::ZERO_EXTEND &&
12131 TLI.isOperationLegalOrCustom(ISD::ADD, VT)) {
12132 SDValue Zext = DAG.getZExtOrTrunc(N0.getOperand(0).getOperand(0), DL, VT);
12133 return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT));
12134 }
12135
12136 // fold sext (not i1 X) -> add (zext i1 X), -1
12137 // TODO: This could be extended to handle bool vectors.
12138 if (N0.getValueType() == MVT::i1 && isBitwiseNot(N0) && N0.hasOneUse() &&
12139 (!LegalOperations || (TLI.isOperationLegal(ISD::ZERO_EXTEND, VT) &&
12140 TLI.isOperationLegal(ISD::ADD, VT)))) {
12141 // If we can eliminate the 'not', the sext form should be better
12142 if (SDValue NewXor = visitXOR(N0.getNode())) {
12143 // Returning N0 is a form of in-visit replacement that may have
12144 // invalidated N0.
12145 if (NewXor.getNode() == N0.getNode()) {
12146 // Return SDValue here as the xor should have already been replaced in
12147 // this sext.
12148 return SDValue();
12149 }
12150
12151 // Return a new sext with the new xor.
12152 return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, NewXor);
12153 }
12154
12155 SDValue Zext = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0));
12156 return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT));
12157 }
12158
12159 if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG))
12160 return Res;
12161
12162 return SDValue();
12163 }
12164
12165 // isTruncateOf - If N is a truncate of some other value, return true, record
12166 // the value being truncated in Op and which of Op's bits are zero/one in Known.
12167 // This function computes KnownBits to avoid a duplicated call to
12168 // computeKnownBits in the caller.
isTruncateOf(SelectionDAG & DAG,SDValue N,SDValue & Op,KnownBits & Known)12169 static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op,
12170 KnownBits &Known) {
12171 if (N->getOpcode() == ISD::TRUNCATE) {
12172 Op = N->getOperand(0);
12173 Known = DAG.computeKnownBits(Op);
12174 return true;
12175 }
12176
12177 if (N.getOpcode() != ISD::SETCC ||
12178 N.getValueType().getScalarType() != MVT::i1 ||
12179 cast<CondCodeSDNode>(N.getOperand(2))->get() != ISD::SETNE)
12180 return false;
12181
12182 SDValue Op0 = N->getOperand(0);
12183 SDValue Op1 = N->getOperand(1);
12184 assert(Op0.getValueType() == Op1.getValueType());
12185
12186 if (isNullOrNullSplat(Op0))
12187 Op = Op1;
12188 else if (isNullOrNullSplat(Op1))
12189 Op = Op0;
12190 else
12191 return false;
12192
12193 Known = DAG.computeKnownBits(Op);
12194
12195 return (Known.Zero | 1).isAllOnes();
12196 }
12197
12198 /// Given an extending node with a pop-count operand, if the target does not
12199 /// support a pop-count in the narrow source type but does support it in the
12200 /// destination type, widen the pop-count to the destination type.
widenCtPop(SDNode * Extend,SelectionDAG & DAG)12201 static SDValue widenCtPop(SDNode *Extend, SelectionDAG &DAG) {
12202 assert((Extend->getOpcode() == ISD::ZERO_EXTEND ||
12203 Extend->getOpcode() == ISD::ANY_EXTEND) && "Expected extend op");
12204
12205 SDValue CtPop = Extend->getOperand(0);
12206 if (CtPop.getOpcode() != ISD::CTPOP || !CtPop.hasOneUse())
12207 return SDValue();
12208
12209 EVT VT = Extend->getValueType(0);
12210 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
12211 if (TLI.isOperationLegalOrCustom(ISD::CTPOP, CtPop.getValueType()) ||
12212 !TLI.isOperationLegalOrCustom(ISD::CTPOP, VT))
12213 return SDValue();
12214
12215 // zext (ctpop X) --> ctpop (zext X)
12216 SDLoc DL(Extend);
12217 SDValue NewZext = DAG.getZExtOrTrunc(CtPop.getOperand(0), DL, VT);
12218 return DAG.getNode(ISD::CTPOP, DL, VT, NewZext);
12219 }
12220
/// Combine a ZERO_EXTEND node: fold through other extends/truncates, narrow
/// or widen loads into zextloads, and push the extension through setcc,
/// shifts, logic-of-load, and ctpop patterns. Returns the replacement value,
/// SDValue(N, 0) when the combine was applied via CombineTo, or an empty
/// SDValue when nothing matched.
SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  // zext(undef) = 0
  if (N0.isUndef())
    return DAG.getConstant(0, SDLoc(N), VT);

  if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
    return Res;

  // fold (zext (zext x)) -> (zext x)
  // fold (zext (aext x)) -> (zext x)
  if (N0.getOpcode() == ISD::ZERO_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND)
    return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT,
                       N0.getOperand(0));

  // fold (zext (truncate x)) -> (zext x) or
  //      (zext (truncate x)) -> (truncate x)
  // This is valid when the truncated bits of x are already zero.
  SDValue Op;
  KnownBits Known;
  if (isTruncateOf(DAG, N0, Op, Known)) {
    // TruncatedBits = the bits discarded by the truncate that are still
    // significant in the final zext result. If they are all known zero,
    // the truncate+zext pair is a plain extension (or truncation) of Op.
    APInt TruncatedBits =
      (Op.getScalarValueSizeInBits() == N0.getScalarValueSizeInBits()) ?
      APInt(Op.getScalarValueSizeInBits(), 0) :
      APInt::getBitsSet(Op.getScalarValueSizeInBits(),
                        N0.getScalarValueSizeInBits(),
                        std::min(Op.getScalarValueSizeInBits(),
                                 VT.getScalarSizeInBits()));
    if (TruncatedBits.isSubsetOf(Known.Zero))
      return DAG.getZExtOrTrunc(Op, SDLoc(N), VT);
  }

  // fold (zext (truncate x)) -> (and x, mask)
  if (N0.getOpcode() == ISD::TRUNCATE) {
    // fold (zext (truncate (load x))) -> (zext (smaller load x))
    // fold (zext (truncate (srl (load x), c))) -> (zext (smaller load (x+c/n)))
    if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
      SDNode *oye = N0.getOperand(0).getNode();
      if (NarrowLoad.getNode() != N0.getNode()) {
        CombineTo(N0.getNode(), NarrowLoad);
        // CombineTo deleted the truncate, if needed, but not what's under it.
        AddToWorklist(oye);
      }
      return SDValue(N, 0); // Return N so it doesn't get rechecked!
    }

    EVT SrcVT = N0.getOperand(0).getValueType();
    EVT MinVT = N0.getValueType();

    // Try to mask before the extension to avoid having to generate a larger mask,
    // possibly over several sub-vectors.
    if (SrcVT.bitsLT(VT) && VT.isVector()) {
      if (!LegalOperations || (TLI.isOperationLegal(ISD::AND, SrcVT) &&
                               TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) {
        SDValue Op = N0.getOperand(0);
        // Mask in the narrow source type, then extend the masked value.
        Op = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT);
        AddToWorklist(Op.getNode());
        SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, SDLoc(N), VT);
        // Transfer the debug info; the new node is equivalent to N0.
        DAG.transferDbgValues(N0, ZExtOrTrunc);
        return ZExtOrTrunc;
      }
    }

    if (!LegalOperations || TLI.isOperationLegal(ISD::AND, VT)) {
      // Extend first (any-extend is fine since the mask clears the high
      // bits), then mask in the destination type.
      SDValue Op = DAG.getAnyExtOrTrunc(N0.getOperand(0), SDLoc(N), VT);
      AddToWorklist(Op.getNode());
      SDValue And = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT);
      // We may safely transfer the debug info describing the truncate node over
      // to the equivalent and operation.
      DAG.transferDbgValues(N0, And);
      return And;
    }
  }

  // Fold (zext (and (trunc x), cst)) -> (and x, cst),
  // if either of the casts is not free.
  if (N0.getOpcode() == ISD::AND &&
      N0.getOperand(0).getOpcode() == ISD::TRUNCATE &&
      N0.getOperand(1).getOpcode() == ISD::Constant &&
      (!TLI.isTruncateFree(N0.getOperand(0).getOperand(0).getValueType(),
                           N0.getValueType()) ||
       !TLI.isZExtFree(N0.getValueType(), VT))) {
    SDValue X = N0.getOperand(0).getOperand(0);
    X = DAG.getAnyExtOrTrunc(X, SDLoc(X), VT);
    // Widen the mask constant to the destination type; the zext of the AND
    // result equals the AND of the any-extended value with the widened mask.
    APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
    SDLoc DL(N);
    return DAG.getNode(ISD::AND, DL, VT,
                       X, DAG.getConstant(Mask, DL, VT));
  }

  // Try to simplify (zext (load x)).
  if (SDValue foldedExt =
          tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
                             ISD::ZEXTLOAD, ISD::ZERO_EXTEND))
    return foldedExt;

  if (SDValue foldedExt =
          tryToFoldExtOfMaskedLoad(DAG, TLI, VT, N, N0, ISD::ZEXTLOAD,
                                   ISD::ZERO_EXTEND))
    return foldedExt;

  // fold (zext (load x)) to multiple smaller zextloads.
  // Only on illegal but splittable vectors.
  if (SDValue ExtLoad = CombineExtLoad(N))
    return ExtLoad;

  // fold (zext (and/or/xor (load x), cst)) ->
  //      (and/or/xor (zextload x), (zext cst))
  // Unless (and (load x) cst) will match as a zextload already and has
  // additional users.
  if ((N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR ||
       N0.getOpcode() == ISD::XOR) &&
      isa<LoadSDNode>(N0.getOperand(0)) &&
      N0.getOperand(1).getOpcode() == ISD::Constant &&
      // NOTE(review): this requires BOTH pre-legalization and a legal logic
      // op in VT — confirm the `&&` (rather than `||`) is intended here.
      (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
    LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
    EVT MemVT = LN00->getMemoryVT();
    if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) &&
        LN00->getExtensionType() != ISD::SEXTLOAD && LN00->isUnindexed()) {
      bool DoXform = true;
      SmallVector<SDNode*, 4> SetCCs;
      if (!N0.hasOneUse()) {
        // If the AND already matches as a zextload on its own, don't disturb
        // it — other users would lose that match.
        if (N0.getOpcode() == ISD::AND) {
          auto *AndC = cast<ConstantSDNode>(N0.getOperand(1));
          EVT LoadResultTy = AndC->getValueType(0);
          EVT ExtVT;
          if (isAndLoadExtLoad(AndC, LN00, LoadResultTy, ExtVT))
            DoXform = false;
        }
      }
      if (DoXform)
        DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
                                          ISD::ZERO_EXTEND, SetCCs, TLI);
      if (DoXform) {
        SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(LN00), VT,
                                         LN00->getChain(), LN00->getBasePtr(),
                                         LN00->getMemoryVT(),
                                         LN00->getMemOperand());
        APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
        SDLoc DL(N);
        SDValue And = DAG.getNode(N0.getOpcode(), DL, VT,
                                  ExtLoad, DAG.getConstant(Mask, DL, VT));
        ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::ZERO_EXTEND);
        bool NoReplaceTruncAnd = !N0.hasOneUse();
        bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
        CombineTo(N, And);
        // If N0 has multiple uses, change other uses as well.
        if (NoReplaceTruncAnd) {
          SDValue TruncAnd =
              DAG.getNode(ISD::TRUNCATE, DL, N0.getValueType(), And);
          CombineTo(N0.getNode(), TruncAnd);
        }
        if (NoReplaceTrunc) {
          // The load value (result 0) has only the one use; just rewire the
          // chain (result 1) to the new extending load.
          DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1));
        } else {
          // Other users of the narrow load remain: give them a truncate of
          // the wide extending load instead.
          SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00),
                                      LN00->getValueType(0), ExtLoad);
          CombineTo(LN00, Trunc, ExtLoad.getValue(1));
        }
        return SDValue(N,0); // Return N so it doesn't get rechecked!
      }
    }
  }

  // fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
  //      (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
  if (SDValue ZExtLoad = CombineZExtLogicopShiftLoad(N))
    return ZExtLoad;

  // Try to simplify (zext (zextload x)).
  if (SDValue foldedExt = tryToFoldExtOfExtload(
          DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::ZEXTLOAD))
    return foldedExt;

  if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
    return V;

  if (N0.getOpcode() == ISD::SETCC) {
    // Propagate fast-math-flags.
    SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());

    // Only do this before legalize for now.
    if (!LegalOperations && VT.isVector() &&
        N0.getValueType().getVectorElementType() == MVT::i1) {
      EVT N00VT = N0.getOperand(0).getValueType();
      // If the setcc already produces the legal result type there is nothing
      // to improve; bail out of the whole visit.
      if (getSetCCResultType(N00VT) == N0.getValueType())
        return SDValue();

      // We know that the # elements of the results is the same as the #
      // elements of the compare (and the # elements of the compare result for
      // that matter). Check to see that they are the same size. If so, we know
      // that the element size of the sext'd result matches the element size of
      // the compare operands.
      SDLoc DL(N);
      if (VT.getSizeInBits() == N00VT.getSizeInBits()) {
        // zext(setcc) -> zext_in_reg(vsetcc) for vectors.
        SDValue VSetCC = DAG.getNode(ISD::SETCC, DL, VT, N0.getOperand(0),
                                     N0.getOperand(1), N0.getOperand(2));
        return DAG.getZeroExtendInReg(VSetCC, DL, N0.getValueType());
      }

      // If the desired elements are smaller or larger than the source
      // elements we can use a matching integer vector type and then
      // truncate/any extend followed by zext_in_reg.
      EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
      SDValue VsetCC =
          DAG.getNode(ISD::SETCC, DL, MatchingVectorType, N0.getOperand(0),
                      N0.getOperand(1), N0.getOperand(2));
      return DAG.getZeroExtendInReg(DAG.getAnyExtOrTrunc(VsetCC, DL, VT), DL,
                                    N0.getValueType());
    }

    // zext(setcc x,y,cc) -> zext(select x, y, true, false, cc)
    SDLoc DL(N);
    EVT N0VT = N0.getValueType();
    EVT N00VT = N0.getOperand(0).getValueType();
    if (SDValue SCC = SimplifySelectCC(
            DL, N0.getOperand(0), N0.getOperand(1),
            DAG.getBoolConstant(true, DL, N0VT, N00VT),
            DAG.getBoolConstant(false, DL, N0VT, N00VT),
            cast<CondCodeSDNode>(N0.getOperand(2))->get(), true))
      return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, SCC);
  }

  // (zext (shl (zext x), cst)) -> (shl (zext x), cst)
  if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
      isa<ConstantSDNode>(N0.getOperand(1)) &&
      N0.getOperand(0).getOpcode() == ISD::ZERO_EXTEND &&
      N0.hasOneUse()) {
    SDValue ShAmt = N0.getOperand(1);
    if (N0.getOpcode() == ISD::SHL) {
      SDValue InnerZExt = N0.getOperand(0);
      // If the original shl may be shifting out bits, do not perform this
      // transformation.
      // KnownZeroBits = number of high bits the inner zext guarantees zero.
      unsigned KnownZeroBits = InnerZExt.getValueSizeInBits() -
        InnerZExt.getOperand(0).getValueSizeInBits();
      if (cast<ConstantSDNode>(ShAmt)->getAPIntValue().ugt(KnownZeroBits))
        return SDValue();
    }

    SDLoc DL(N);

    // Ensure that the shift amount is wide enough for the shifted value.
    // NOTE(review): the amount is widened to a fixed MVT::i32 rather than the
    // target's shift-amount type — confirm this is intended for targets with
    // a different shift-amount type.
    if (Log2_32_Ceil(VT.getSizeInBits()) > ShAmt.getValueSizeInBits())
      ShAmt = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, ShAmt);

    return DAG.getNode(N0.getOpcode(), DL, VT,
                       DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0)),
                       ShAmt);
  }

  if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
    return NewVSel;

  if (SDValue NewCtPop = widenCtPop(N, DAG))
    return NewCtPop;

  if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG))
    return Res;

  return SDValue();
}
12486
/// Combine an ANY_EXTEND node: collapse through other extends/truncates,
/// convert loads into extending loads, and push the extension through setcc
/// and ctpop patterns. Returns the replacement value, SDValue(N, 0) when the
/// combine was applied via CombineTo, or an empty SDValue when nothing
/// matched.
SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  // aext(undef) = undef
  if (N0.isUndef())
    return DAG.getUNDEF(VT);

  if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
    return Res;

  // fold (aext (aext x)) -> (aext x)
  // fold (aext (zext x)) -> (zext x)
  // fold (aext (sext x)) -> (sext x)
  if (N0.getOpcode() == ISD::ANY_EXTEND  ||
      N0.getOpcode() == ISD::ZERO_EXTEND ||
      N0.getOpcode() == ISD::SIGN_EXTEND)
    return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, N0.getOperand(0));

  // fold (aext (truncate (load x))) -> (aext (smaller load x))
  // fold (aext (truncate (srl (load x), c))) -> (aext (small load (x+c/n)))
  if (N0.getOpcode() == ISD::TRUNCATE) {
    if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
      SDNode *oye = N0.getOperand(0).getNode();
      if (NarrowLoad.getNode() != N0.getNode()) {
        CombineTo(N0.getNode(), NarrowLoad);
        // CombineTo deleted the truncate, if needed, but not what's under it.
        AddToWorklist(oye);
      }
      return SDValue(N, 0); // Return N so it doesn't get rechecked!
    }
  }

  // fold (aext (truncate x))
  // Legal because an any-extend leaves the high bits unspecified anyway.
  if (N0.getOpcode() == ISD::TRUNCATE)
    return DAG.getAnyExtOrTrunc(N0.getOperand(0), SDLoc(N), VT);

  // Fold (aext (and (trunc x), cst)) -> (and x, cst)
  // if the trunc is not free.
  if (N0.getOpcode() == ISD::AND &&
      N0.getOperand(0).getOpcode() == ISD::TRUNCATE &&
      N0.getOperand(1).getOpcode() == ISD::Constant &&
      !TLI.isTruncateFree(N0.getOperand(0).getOperand(0).getValueType(),
                          N0.getValueType())) {
    SDLoc DL(N);
    SDValue X = DAG.getAnyExtOrTrunc(N0.getOperand(0).getOperand(0), DL, VT);
    // Extending a constant folds immediately; assert that it did.
    SDValue Y = DAG.getNode(ISD::ANY_EXTEND, DL, VT, N0.getOperand(1));
    assert(isa<ConstantSDNode>(Y) && "Expected constant to be folded!");
    return DAG.getNode(ISD::AND, DL, VT, X, Y);
  }

  // fold (aext (load x)) -> (aext (truncate (extload x)))
  // None of the supported targets knows how to perform load and any_ext
  // on vectors in one instruction, so attempt to fold to zext instead.
  if (VT.isVector()) {
    // Try to simplify (zext (load x)).
    if (SDValue foldedExt =
            tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
                               ISD::ZEXTLOAD, ISD::ZERO_EXTEND))
      return foldedExt;
  } else if (ISD::isNON_EXTLoad(N0.getNode()) &&
             ISD::isUNINDEXEDLoad(N0.getNode()) &&
             TLI.isLoadExtLegal(ISD::EXTLOAD, VT, N0.getValueType())) {
    bool DoXform = true;
    SmallVector<SDNode *, 4> SetCCs;
    // With multiple users, only transform if all users can also be rewritten
    // to use the extended value.
    if (!N0.hasOneUse())
      DoXform =
          ExtendUsesToFormExtLoad(VT, N, N0, ISD::ANY_EXTEND, SetCCs, TLI);
    if (DoXform) {
      LoadSDNode *LN0 = cast<LoadSDNode>(N0);
      SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, SDLoc(N), VT,
                                       LN0->getChain(), LN0->getBasePtr(),
                                       N0.getValueType(), LN0->getMemOperand());
      ExtendSetCCUses(SetCCs, N0, ExtLoad, ISD::ANY_EXTEND);
      // If the load value is used only by N, replace it via CombineTo N.
      bool NoReplaceTrunc = N0.hasOneUse();
      CombineTo(N, ExtLoad);
      if (NoReplaceTrunc) {
        // Only the chain result needs rewiring; the old load dies.
        DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
        recursivelyDeleteUnusedNodes(LN0);
      } else {
        // Other users of the narrow value get a truncate of the ext load.
        SDValue Trunc =
            DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), ExtLoad);
        CombineTo(LN0, Trunc, ExtLoad.getValue(1));
      }
      return SDValue(N, 0); // Return N so it doesn't get rechecked!
    }
  }

  // fold (aext (zextload x)) -> (aext (truncate (zextload x)))
  // fold (aext (sextload x)) -> (aext (truncate (sextload x)))
  // fold (aext ( extload x)) -> (aext (truncate (extload  x)))
  if (N0.getOpcode() == ISD::LOAD && !ISD::isNON_EXTLoad(N0.getNode()) &&
      ISD::isUNINDEXEDLoad(N0.getNode()) && N0.hasOneUse()) {
    LoadSDNode *LN0 = cast<LoadSDNode>(N0);
    ISD::LoadExtType ExtType = LN0->getExtensionType();
    EVT MemVT = LN0->getMemoryVT();
    if (!LegalOperations || TLI.isLoadExtLegal(ExtType, VT, MemVT)) {
      // Re-issue the existing extending load directly at the wider type,
      // preserving its extension kind.
      SDValue ExtLoad = DAG.getExtLoad(ExtType, SDLoc(N),
                                       VT, LN0->getChain(), LN0->getBasePtr(),
                                       MemVT, LN0->getMemOperand());
      CombineTo(N, ExtLoad);
      DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
      recursivelyDeleteUnusedNodes(LN0);
      return SDValue(N, 0); // Return N so it doesn't get rechecked!
    }
  }

  if (N0.getOpcode() == ISD::SETCC) {
    // Propagate fast-math-flags.
    SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());

    // For vectors:
    // aext(setcc) -> vsetcc
    // aext(setcc) -> truncate(vsetcc)
    // aext(setcc) -> aext(vsetcc)
    // Only do this before legalize for now.
    if (VT.isVector() && !LegalOperations) {
      EVT N00VT = N0.getOperand(0).getValueType();
      // If the setcc already produces the legal result type there is nothing
      // to improve; bail out of the whole visit.
      if (getSetCCResultType(N00VT) == N0.getValueType())
        return SDValue();

      // We know that the # elements of the results is the same as the
      // # elements of the compare (and the # elements of the compare result
      // for that matter). Check to see that they are the same size. If so,
      // we know that the element size of the sext'd result matches the
      // element size of the compare operands.
      if (VT.getSizeInBits() == N00VT.getSizeInBits())
        return DAG.getSetCC(SDLoc(N), VT, N0.getOperand(0),
                            N0.getOperand(1),
                            cast<CondCodeSDNode>(N0.getOperand(2))->get());

      // If the desired elements are smaller or larger than the source
      // elements we can use a matching integer vector type and then
      // truncate/any extend
      EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
      SDValue VsetCC =
          DAG.getSetCC(SDLoc(N), MatchingVectorType, N0.getOperand(0),
                       N0.getOperand(1),
                       cast<CondCodeSDNode>(N0.getOperand(2))->get());
      return DAG.getAnyExtOrTrunc(VsetCC, SDLoc(N), VT);
    }

    // aext(setcc x,y,cc) -> select_cc x, y, 1, 0, cc
    SDLoc DL(N);
    if (SDValue SCC = SimplifySelectCC(
            DL, N0.getOperand(0), N0.getOperand(1), DAG.getConstant(1, DL, VT),
            DAG.getConstant(0, DL, VT),
            cast<CondCodeSDNode>(N0.getOperand(2))->get(), true))
      return SCC;
  }

  if (SDValue NewCtPop = widenCtPop(N, DAG))
    return NewCtPop;

  if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG))
    return Res;

  return SDValue();
}
12647
visitAssertExt(SDNode * N)12648 SDValue DAGCombiner::visitAssertExt(SDNode *N) {
12649 unsigned Opcode = N->getOpcode();
12650 SDValue N0 = N->getOperand(0);
12651 SDValue N1 = N->getOperand(1);
12652 EVT AssertVT = cast<VTSDNode>(N1)->getVT();
12653
12654 // fold (assert?ext (assert?ext x, vt), vt) -> (assert?ext x, vt)
12655 if (N0.getOpcode() == Opcode &&
12656 AssertVT == cast<VTSDNode>(N0.getOperand(1))->getVT())
12657 return N0;
12658
12659 if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
12660 N0.getOperand(0).getOpcode() == Opcode) {
12661 // We have an assert, truncate, assert sandwich. Make one stronger assert
12662 // by asserting on the smallest asserted type to the larger source type.
12663 // This eliminates the later assert:
12664 // assert (trunc (assert X, i8) to iN), i1 --> trunc (assert X, i1) to iN
12665 // assert (trunc (assert X, i1) to iN), i8 --> trunc (assert X, i1) to iN
12666 SDLoc DL(N);
12667 SDValue BigA = N0.getOperand(0);
12668 EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT();
12669 EVT MinAssertVT = AssertVT.bitsLT(BigA_AssertVT) ? AssertVT : BigA_AssertVT;
12670 SDValue MinAssertVTVal = DAG.getValueType(MinAssertVT);
12671 SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(),
12672 BigA.getOperand(0), MinAssertVTVal);
12673 return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert);
12674 }
12675
12676 // If we have (AssertZext (truncate (AssertSext X, iX)), iY) and Y is smaller
12677 // than X. Just move the AssertZext in front of the truncate and drop the
12678 // AssertSExt.
12679 if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
12680 N0.getOperand(0).getOpcode() == ISD::AssertSext &&
12681 Opcode == ISD::AssertZext) {
12682 SDValue BigA = N0.getOperand(0);
12683 EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT();
12684 if (AssertVT.bitsLT(BigA_AssertVT)) {
12685 SDLoc DL(N);
12686 SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(),
12687 BigA.getOperand(0), N1);
12688 return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert);
12689 }
12690 }
12691
12692 return SDValue();
12693 }
12694
visitAssertAlign(SDNode * N)12695 SDValue DAGCombiner::visitAssertAlign(SDNode *N) {
12696 SDLoc DL(N);
12697
12698 Align AL = cast<AssertAlignSDNode>(N)->getAlign();
12699 SDValue N0 = N->getOperand(0);
12700
12701 // Fold (assertalign (assertalign x, AL0), AL1) ->
12702 // (assertalign x, max(AL0, AL1))
12703 if (auto *AAN = dyn_cast<AssertAlignSDNode>(N0))
12704 return DAG.getAssertAlign(DL, N0.getOperand(0),
12705 std::max(AL, AAN->getAlign()));
12706
12707 // In rare cases, there are trivial arithmetic ops in source operands. Sink
12708 // this assert down to source operands so that those arithmetic ops could be
12709 // exposed to the DAG combining.
12710 switch (N0.getOpcode()) {
12711 default:
12712 break;
12713 case ISD::ADD:
12714 case ISD::SUB: {
12715 unsigned AlignShift = Log2(AL);
12716 SDValue LHS = N0.getOperand(0);
12717 SDValue RHS = N0.getOperand(1);
12718 unsigned LHSAlignShift = DAG.computeKnownBits(LHS).countMinTrailingZeros();
12719 unsigned RHSAlignShift = DAG.computeKnownBits(RHS).countMinTrailingZeros();
12720 if (LHSAlignShift >= AlignShift || RHSAlignShift >= AlignShift) {
12721 if (LHSAlignShift < AlignShift)
12722 LHS = DAG.getAssertAlign(DL, LHS, AL);
12723 if (RHSAlignShift < AlignShift)
12724 RHS = DAG.getAssertAlign(DL, RHS, AL);
12725 return DAG.getNode(N0.getOpcode(), DL, N0.getValueType(), LHS, RHS);
12726 }
12727 break;
12728 }
12729 }
12730
12731 return SDValue();
12732 }
12733
12734 /// If the result of a load is shifted/masked/truncated to an effectively
12735 /// narrower type, try to transform the load to a narrower type and/or
12736 /// use an extending load.
reduceLoadWidth(SDNode * N)12737 SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
12738 unsigned Opc = N->getOpcode();
12739
12740 ISD::LoadExtType ExtType = ISD::NON_EXTLOAD;
12741 SDValue N0 = N->getOperand(0);
12742 EVT VT = N->getValueType(0);
12743 EVT ExtVT = VT;
12744
12745 // This transformation isn't valid for vector loads.
12746 if (VT.isVector())
12747 return SDValue();
12748
12749 // The ShAmt variable is used to indicate that we've consumed a right
12750 // shift. I.e. we want to narrow the width of the load by skipping to load the
12751 // ShAmt least significant bits.
12752 unsigned ShAmt = 0;
12753 // A special case is when the least significant bits from the load are masked
12754 // away, but using an AND rather than a right shift. HasShiftedOffset is used
12755 // to indicate that the narrowed load should be left-shifted ShAmt bits to get
12756 // the result.
12757 bool HasShiftedOffset = false;
12758 // Special case: SIGN_EXTEND_INREG is basically truncating to ExtVT then
12759 // extended to VT.
12760 if (Opc == ISD::SIGN_EXTEND_INREG) {
12761 ExtType = ISD::SEXTLOAD;
12762 ExtVT = cast<VTSDNode>(N->getOperand(1))->getVT();
12763 } else if (Opc == ISD::SRL || Opc == ISD::SRA) {
12764 // Another special-case: SRL/SRA is basically zero/sign-extending a narrower
12765 // value, or it may be shifting a higher subword, half or byte into the
12766 // lowest bits.
12767
12768 // Only handle shift with constant shift amount, and the shiftee must be a
12769 // load.
12770 auto *LN = dyn_cast<LoadSDNode>(N0);
12771 auto *N1C = dyn_cast<ConstantSDNode>(N->getOperand(1));
12772 if (!N1C || !LN)
12773 return SDValue();
12774 // If the shift amount is larger than the memory type then we're not
12775 // accessing any of the loaded bytes.
12776 ShAmt = N1C->getZExtValue();
12777 uint64_t MemoryWidth = LN->getMemoryVT().getScalarSizeInBits();
12778 if (MemoryWidth <= ShAmt)
12779 return SDValue();
12780 // Attempt to fold away the SRL by using ZEXTLOAD and SRA by using SEXTLOAD.
12781 ExtType = Opc == ISD::SRL ? ISD::ZEXTLOAD : ISD::SEXTLOAD;
12782 ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShAmt);
12783 // If original load is a SEXTLOAD then we can't simply replace it by a
12784 // ZEXTLOAD (we could potentially replace it by a more narrow SEXTLOAD
12785 // followed by a ZEXT, but that is not handled at the moment). Similarly if
12786 // the original load is a ZEXTLOAD and we want to use a SEXTLOAD.
12787 if ((LN->getExtensionType() == ISD::SEXTLOAD ||
12788 LN->getExtensionType() == ISD::ZEXTLOAD) &&
12789 LN->getExtensionType() != ExtType)
12790 return SDValue();
12791 } else if (Opc == ISD::AND) {
12792 // An AND with a constant mask is the same as a truncate + zero-extend.
12793 auto AndC = dyn_cast<ConstantSDNode>(N->getOperand(1));
12794 if (!AndC)
12795 return SDValue();
12796
12797 const APInt &Mask = AndC->getAPIntValue();
12798 unsigned ActiveBits = 0;
12799 if (Mask.isMask()) {
12800 ActiveBits = Mask.countTrailingOnes();
12801 } else if (Mask.isShiftedMask(ShAmt, ActiveBits)) {
12802 HasShiftedOffset = true;
12803 } else {
12804 return SDValue();
12805 }
12806
12807 ExtType = ISD::ZEXTLOAD;
12808 ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
12809 }
12810
12811 // In case Opc==SRL we've already prepared ExtVT/ExtType/ShAmt based on doing
12812 // a right shift. Here we redo some of those checks, to possibly adjust the
12813 // ExtVT even further based on "a masking AND". We could also end up here for
12814 // other reasons (e.g. based on Opc==TRUNCATE) and that is why some checks
12815 // need to be done here as well.
12816 if (Opc == ISD::SRL || N0.getOpcode() == ISD::SRL) {
12817 SDValue SRL = Opc == ISD::SRL ? SDValue(N, 0) : N0;
12818 // Bail out when the SRL has more than one use. This is done for historical
    // (undocumented) reasons. Maybe the intent was to guard the AND-masking
    // check below? And maybe it could be non-profitable to do the transform in
12821 // case the SRL has multiple uses and we get here with Opc!=ISD::SRL?
12822 // FIXME: Can't we just skip this check for the Opc==ISD::SRL case.
12823 if (!SRL.hasOneUse())
12824 return SDValue();
12825
12826 // Only handle shift with constant shift amount, and the shiftee must be a
12827 // load.
12828 auto *LN = dyn_cast<LoadSDNode>(SRL.getOperand(0));
12829 auto *SRL1C = dyn_cast<ConstantSDNode>(SRL.getOperand(1));
12830 if (!SRL1C || !LN)
12831 return SDValue();
12832
12833 // If the shift amount is larger than the input type then we're not
12834 // accessing any of the loaded bytes. If the load was a zextload/extload
12835 // then the result of the shift+trunc is zero/undef (handled elsewhere).
12836 ShAmt = SRL1C->getZExtValue();
12837 uint64_t MemoryWidth = LN->getMemoryVT().getSizeInBits();
12838 if (ShAmt >= MemoryWidth)
12839 return SDValue();
12840
12841 // Because a SRL must be assumed to *need* to zero-extend the high bits
12842 // (as opposed to anyext the high bits), we can't combine the zextload
12843 // lowering of SRL and an sextload.
12844 if (LN->getExtensionType() == ISD::SEXTLOAD)
12845 return SDValue();
12846
12847 // Avoid reading outside the memory accessed by the original load (could
12848 // happened if we only adjust the load base pointer by ShAmt). Instead we
12849 // try to narrow the load even further. The typical scenario here is:
12850 // (i64 (truncate (i96 (srl (load x), 64)))) ->
12851 // (i64 (truncate (i96 (zextload (load i32 + offset) from i32))))
12852 if (ExtVT.getScalarSizeInBits() > MemoryWidth - ShAmt) {
12853 // Don't replace sextload by zextload.
12854 if (ExtType == ISD::SEXTLOAD)
12855 return SDValue();
12856 // Narrow the load.
12857 ExtType = ISD::ZEXTLOAD;
12858 ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShAmt);
12859 }
12860
12861 // If the SRL is only used by a masking AND, we may be able to adjust
12862 // the ExtVT to make the AND redundant.
12863 SDNode *Mask = *(SRL->use_begin());
12864 if (SRL.hasOneUse() && Mask->getOpcode() == ISD::AND &&
12865 isa<ConstantSDNode>(Mask->getOperand(1))) {
12866 const APInt& ShiftMask = Mask->getConstantOperandAPInt(1);
12867 if (ShiftMask.isMask()) {
12868 EVT MaskedVT = EVT::getIntegerVT(*DAG.getContext(),
12869 ShiftMask.countTrailingOnes());
12870 // If the mask is smaller, recompute the type.
12871 if ((ExtVT.getScalarSizeInBits() > MaskedVT.getScalarSizeInBits()) &&
12872 TLI.isLoadExtLegal(ExtType, SRL.getValueType(), MaskedVT))
12873 ExtVT = MaskedVT;
12874 }
12875 }
12876
12877 N0 = SRL.getOperand(0);
12878 }
12879
12880 // If the load is shifted left (and the result isn't shifted back right), we
12881 // can fold a truncate through the shift. The typical scenario is that N
12882 // points at a TRUNCATE here so the attempted fold is:
12883 // (truncate (shl (load x), c))) -> (shl (narrow load x), c)
12884 // ShLeftAmt will indicate how much a narrowed load should be shifted left.
12885 unsigned ShLeftAmt = 0;
12886 if (ShAmt == 0 && N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
12887 ExtVT == VT && TLI.isNarrowingProfitable(N0.getValueType(), VT)) {
12888 if (ConstantSDNode *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
12889 ShLeftAmt = N01->getZExtValue();
12890 N0 = N0.getOperand(0);
12891 }
12892 }
12893
12894 // If we haven't found a load, we can't narrow it.
12895 if (!isa<LoadSDNode>(N0))
12896 return SDValue();
12897
12898 LoadSDNode *LN0 = cast<LoadSDNode>(N0);
12899 // Reducing the width of a volatile load is illegal. For atomics, we may be
12900 // able to reduce the width provided we never widen again. (see D66309)
12901 if (!LN0->isSimple() ||
12902 !isLegalNarrowLdSt(LN0, ExtType, ExtVT, ShAmt))
12903 return SDValue();
12904
12905 auto AdjustBigEndianShift = [&](unsigned ShAmt) {
12906 unsigned LVTStoreBits =
12907 LN0->getMemoryVT().getStoreSizeInBits().getFixedSize();
12908 unsigned EVTStoreBits = ExtVT.getStoreSizeInBits().getFixedSize();
12909 return LVTStoreBits - EVTStoreBits - ShAmt;
12910 };
12911
12912 // We need to adjust the pointer to the load by ShAmt bits in order to load
12913 // the correct bytes.
12914 unsigned PtrAdjustmentInBits =
12915 DAG.getDataLayout().isBigEndian() ? AdjustBigEndianShift(ShAmt) : ShAmt;
12916
12917 uint64_t PtrOff = PtrAdjustmentInBits / 8;
12918 Align NewAlign = commonAlignment(LN0->getAlign(), PtrOff);
12919 SDLoc DL(LN0);
12920 // The original load itself didn't wrap, so an offset within it doesn't.
12921 SDNodeFlags Flags;
12922 Flags.setNoUnsignedWrap(true);
12923 SDValue NewPtr = DAG.getMemBasePlusOffset(LN0->getBasePtr(),
12924 TypeSize::Fixed(PtrOff), DL, Flags);
12925 AddToWorklist(NewPtr.getNode());
12926
12927 SDValue Load;
12928 if (ExtType == ISD::NON_EXTLOAD)
12929 Load = DAG.getLoad(VT, DL, LN0->getChain(), NewPtr,
12930 LN0->getPointerInfo().getWithOffset(PtrOff), NewAlign,
12931 LN0->getMemOperand()->getFlags(), LN0->getAAInfo());
12932 else
12933 Load = DAG.getExtLoad(ExtType, DL, VT, LN0->getChain(), NewPtr,
12934 LN0->getPointerInfo().getWithOffset(PtrOff), ExtVT,
12935 NewAlign, LN0->getMemOperand()->getFlags(),
12936 LN0->getAAInfo());
12937
12938 // Replace the old load's chain with the new load's chain.
12939 WorklistRemover DeadNodes(*this);
12940 DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));
12941
12942 // Shift the result left, if we've swallowed a left shift.
12943 SDValue Result = Load;
12944 if (ShLeftAmt != 0) {
12945 EVT ShImmTy = getShiftAmountTy(Result.getValueType());
12946 if (!isUIntN(ShImmTy.getScalarSizeInBits(), ShLeftAmt))
12947 ShImmTy = VT;
12948 // If the shift amount is as large as the result size (but, presumably,
12949 // no larger than the source) then the useful bits of the result are
12950 // zero; we can't simply return the shortened shift, because the result
12951 // of that operation is undefined.
12952 if (ShLeftAmt >= VT.getScalarSizeInBits())
12953 Result = DAG.getConstant(0, DL, VT);
12954 else
12955 Result = DAG.getNode(ISD::SHL, DL, VT,
12956 Result, DAG.getConstant(ShLeftAmt, DL, ShImmTy));
12957 }
12958
12959 if (HasShiftedOffset) {
12960 // We're using a shifted mask, so the load now has an offset. This means
12961 // that data has been loaded into the lower bytes than it would have been
12962 // before, so we need to shl the loaded data into the correct position in the
12963 // register.
12964 SDValue ShiftC = DAG.getConstant(ShAmt, DL, VT);
12965 Result = DAG.getNode(ISD::SHL, DL, VT, Result, ShiftC);
12966 DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result);
12967 }
12968
12969 // Return the new loaded value.
12970 return Result;
12971 }
12972
visitSIGN_EXTEND_INREG(SDNode * N)12973 SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
12974 SDValue N0 = N->getOperand(0);
12975 SDValue N1 = N->getOperand(1);
12976 EVT VT = N->getValueType(0);
12977 EVT ExtVT = cast<VTSDNode>(N1)->getVT();
12978 unsigned VTBits = VT.getScalarSizeInBits();
12979 unsigned ExtVTBits = ExtVT.getScalarSizeInBits();
12980
12981 // sext_vector_inreg(undef) = 0 because the top bit will all be the same.
12982 if (N0.isUndef())
12983 return DAG.getConstant(0, SDLoc(N), VT);
12984
12985 // fold (sext_in_reg c1) -> c1
12986 if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
12987 return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, N0, N1);
12988
12989 // If the input is already sign extended, just drop the extension.
12990 if (ExtVTBits >= DAG.ComputeMaxSignificantBits(N0))
12991 return N0;
12992
12993 // fold (sext_in_reg (sext_in_reg x, VT2), VT1) -> (sext_in_reg x, minVT) pt2
12994 if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
12995 ExtVT.bitsLT(cast<VTSDNode>(N0.getOperand(1))->getVT()))
12996 return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, N0.getOperand(0),
12997 N1);
12998
12999 // fold (sext_in_reg (sext x)) -> (sext x)
13000 // fold (sext_in_reg (aext x)) -> (sext x)
13001 // if x is small enough or if we know that x has more than 1 sign bit and the
13002 // sign_extend_inreg is extending from one of them.
13003 if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) {
13004 SDValue N00 = N0.getOperand(0);
13005 unsigned N00Bits = N00.getScalarValueSizeInBits();
13006 if ((N00Bits <= ExtVTBits ||
13007 DAG.ComputeMaxSignificantBits(N00) <= ExtVTBits) &&
13008 (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT)))
13009 return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00);
13010 }
13011
13012 // fold (sext_in_reg (*_extend_vector_inreg x)) -> (sext_vector_inreg x)
13013 // if x is small enough or if we know that x has more than 1 sign bit and the
13014 // sign_extend_inreg is extending from one of them.
13015 if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
13016 N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG ||
13017 N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG) {
13018 SDValue N00 = N0.getOperand(0);
13019 unsigned N00Bits = N00.getScalarValueSizeInBits();
13020 unsigned DstElts = N0.getValueType().getVectorMinNumElements();
13021 unsigned SrcElts = N00.getValueType().getVectorMinNumElements();
13022 bool IsZext = N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG;
13023 APInt DemandedSrcElts = APInt::getLowBitsSet(SrcElts, DstElts);
13024 if ((N00Bits == ExtVTBits ||
13025 (!IsZext && (N00Bits < ExtVTBits ||
13026 DAG.ComputeMaxSignificantBits(N00) <= ExtVTBits))) &&
13027 (!LegalOperations ||
13028 TLI.isOperationLegal(ISD::SIGN_EXTEND_VECTOR_INREG, VT)))
13029 return DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, SDLoc(N), VT, N00);
13030 }
13031
13032 // fold (sext_in_reg (zext x)) -> (sext x)
13033 // iff we are extending the source sign bit.
13034 if (N0.getOpcode() == ISD::ZERO_EXTEND) {
13035 SDValue N00 = N0.getOperand(0);
13036 if (N00.getScalarValueSizeInBits() == ExtVTBits &&
13037 (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT)))
13038 return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00, N1);
13039 }
13040
13041 // fold (sext_in_reg x) -> (zext_in_reg x) if the sign bit is known zero.
13042 if (DAG.MaskedValueIsZero(N0, APInt::getOneBitSet(VTBits, ExtVTBits - 1)))
13043 return DAG.getZeroExtendInReg(N0, SDLoc(N), ExtVT);
13044
13045 // fold operands of sext_in_reg based on knowledge that the top bits are not
13046 // demanded.
13047 if (SimplifyDemandedBits(SDValue(N, 0)))
13048 return SDValue(N, 0);
13049
13050 // fold (sext_in_reg (load x)) -> (smaller sextload x)
13051 // fold (sext_in_reg (srl (load x), c)) -> (smaller sextload (x+c/evtbits))
13052 if (SDValue NarrowLoad = reduceLoadWidth(N))
13053 return NarrowLoad;
13054
13055 // fold (sext_in_reg (srl X, 24), i8) -> (sra X, 24)
13056 // fold (sext_in_reg (srl X, 23), i8) -> (sra X, 23) iff possible.
13057 // We already fold "(sext_in_reg (srl X, 25), i8) -> srl X, 25" above.
13058 if (N0.getOpcode() == ISD::SRL) {
13059 if (auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1)))
13060 if (ShAmt->getAPIntValue().ule(VTBits - ExtVTBits)) {
13061 // We can turn this into an SRA iff the input to the SRL is already sign
13062 // extended enough.
13063 unsigned InSignBits = DAG.ComputeNumSignBits(N0.getOperand(0));
13064 if (((VTBits - ExtVTBits) - ShAmt->getZExtValue()) < InSignBits)
13065 return DAG.getNode(ISD::SRA, SDLoc(N), VT, N0.getOperand(0),
13066 N0.getOperand(1));
13067 }
13068 }
13069
13070 // fold (sext_inreg (extload x)) -> (sextload x)
13071 // If sextload is not supported by target, we can only do the combine when
13072 // load has one use. Doing otherwise can block folding the extload with other
13073 // extends that the target does support.
13074 if (ISD::isEXTLoad(N0.getNode()) &&
13075 ISD::isUNINDEXEDLoad(N0.getNode()) &&
13076 ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
13077 ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple() &&
13078 N0.hasOneUse()) ||
13079 TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT))) {
13080 LoadSDNode *LN0 = cast<LoadSDNode>(N0);
13081 SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(N), VT,
13082 LN0->getChain(),
13083 LN0->getBasePtr(), ExtVT,
13084 LN0->getMemOperand());
13085 CombineTo(N, ExtLoad);
13086 CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
13087 AddToWorklist(ExtLoad.getNode());
13088 return SDValue(N, 0); // Return N so it doesn't get rechecked!
13089 }
13090
13091 // fold (sext_inreg (zextload x)) -> (sextload x) iff load has one use
13092 if (ISD::isZEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) &&
13093 N0.hasOneUse() &&
13094 ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
13095 ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple()) &&
13096 TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT))) {
13097 LoadSDNode *LN0 = cast<LoadSDNode>(N0);
13098 SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(N), VT,
13099 LN0->getChain(),
13100 LN0->getBasePtr(), ExtVT,
13101 LN0->getMemOperand());
13102 CombineTo(N, ExtLoad);
13103 CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
13104 return SDValue(N, 0); // Return N so it doesn't get rechecked!
13105 }
13106
13107 // fold (sext_inreg (masked_load x)) -> (sext_masked_load x)
13108 // ignore it if the masked load is already sign extended
13109 if (MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(N0)) {
13110 if (ExtVT == Ld->getMemoryVT() && N0.hasOneUse() &&
13111 Ld->getExtensionType() != ISD::LoadExtType::NON_EXTLOAD &&
13112 TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT)) {
13113 SDValue ExtMaskedLoad = DAG.getMaskedLoad(
13114 VT, SDLoc(N), Ld->getChain(), Ld->getBasePtr(), Ld->getOffset(),
13115 Ld->getMask(), Ld->getPassThru(), ExtVT, Ld->getMemOperand(),
13116 Ld->getAddressingMode(), ISD::SEXTLOAD, Ld->isExpandingLoad());
13117 CombineTo(N, ExtMaskedLoad);
13118 CombineTo(N0.getNode(), ExtMaskedLoad, ExtMaskedLoad.getValue(1));
13119 return SDValue(N, 0); // Return N so it doesn't get rechecked!
13120 }
13121 }
13122
13123 // fold (sext_inreg (masked_gather x)) -> (sext_masked_gather x)
13124 if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(N0)) {
13125 if (SDValue(GN0, 0).hasOneUse() &&
13126 ExtVT == GN0->getMemoryVT() &&
13127 TLI.isVectorLoadExtDesirable(SDValue(SDValue(GN0, 0)))) {
13128 SDValue Ops[] = {GN0->getChain(), GN0->getPassThru(), GN0->getMask(),
13129 GN0->getBasePtr(), GN0->getIndex(), GN0->getScale()};
13130
13131 SDValue ExtLoad = DAG.getMaskedGather(
13132 DAG.getVTList(VT, MVT::Other), ExtVT, SDLoc(N), Ops,
13133 GN0->getMemOperand(), GN0->getIndexType(), ISD::SEXTLOAD);
13134
13135 CombineTo(N, ExtLoad);
13136 CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
13137 AddToWorklist(ExtLoad.getNode());
13138 return SDValue(N, 0); // Return N so it doesn't get rechecked!
13139 }
13140 }
13141
13142 // Form (sext_inreg (bswap >> 16)) or (sext_inreg (rotl (bswap) 16))
13143 if (ExtVTBits <= 16 && N0.getOpcode() == ISD::OR) {
13144 if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
13145 N0.getOperand(1), false))
13146 return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, BSwap, N1);
13147 }
13148
13149 return SDValue();
13150 }
13151
visitEXTEND_VECTOR_INREG(SDNode * N)13152 SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
13153 SDValue N0 = N->getOperand(0);
13154 EVT VT = N->getValueType(0);
13155
13156 // {s/z}ext_vector_inreg(undef) = 0 because the top bits must be the same.
13157 if (N0.isUndef())
13158 return DAG.getConstant(0, SDLoc(N), VT);
13159
13160 if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
13161 return Res;
13162
13163 if (SimplifyDemandedVectorElts(SDValue(N, 0)))
13164 return SDValue(N, 0);
13165
13166 return SDValue();
13167 }
13168
/// Combine a TRUNCATE node: eliminate redundant truncates, push truncation
/// through extends/selects/shifts/build-vectors, narrow loads, and narrow
/// binary ops with constant operands. Returns the replacement value or an
/// empty SDValue if no combine applies.
SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  EVT SrcVT = N0.getValueType();
  bool isLE = DAG.getDataLayout().isLittleEndian();

  // noop truncate
  if (SrcVT == VT)
    return N0;

  // fold (truncate (truncate x)) -> (truncate x)
  if (N0.getOpcode() == ISD::TRUNCATE)
    return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0.getOperand(0));

  // fold (truncate c1) -> c1
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) {
    SDValue C = DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0);
    // Only return if constant folding actually produced a new node; otherwise
    // we would return N itself and loop.
    if (C.getNode() != N)
      return C;
  }

  // fold (truncate (ext x)) -> (ext x) or (truncate x) or x
  if (N0.getOpcode() == ISD::ZERO_EXTEND ||
      N0.getOpcode() == ISD::SIGN_EXTEND ||
      N0.getOpcode() == ISD::ANY_EXTEND) {
    // if the source is smaller than the dest, we still need an extend.
    if (N0.getOperand(0).getValueType().bitsLT(VT))
      return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, N0.getOperand(0));
    // if the source is larger than the dest, than we just need the truncate.
    if (N0.getOperand(0).getValueType().bitsGT(VT))
      return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0.getOperand(0));
    // if the source and dest are the same type, we can drop both the extend
    // and the truncate.
    return N0.getOperand(0);
  }

  // Try to narrow a truncate-of-sext_in_reg to the destination type:
  // trunc (sign_ext_inreg X, iM) to iN --> sign_ext_inreg (trunc X to iN), iM
  if (!LegalTypes && N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
      N0.hasOneUse()) {
    SDValue X = N0.getOperand(0);
    SDValue ExtVal = N0.getOperand(1);
    EVT ExtVT = cast<VTSDNode>(ExtVal)->getVT();
    // Only valid when the inreg type is strictly narrower than the truncate
    // result, so the sign-extended bits survive the truncation.
    if (ExtVT.bitsLT(VT)) {
      SDValue TrX = DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, X);
      return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, TrX, ExtVal);
    }
  }

  // If this is anyext(trunc), don't fold it, allow ourselves to be folded.
  if (N->hasOneUse() && (N->use_begin()->getOpcode() == ISD::ANY_EXTEND))
    return SDValue();

  // Fold extract-and-trunc into a narrow extract. For example:
  //   i64 x = EXTRACT_VECTOR_ELT(v2i64 val, i32 1)
  //   i32 y = TRUNCATE(i64 x)
  //        -- becomes --
  //   v16i8 b = BITCAST (v2i64 val)
  //   i8 x = EXTRACT_VECTOR_ELT(v16i8 b, i32 8)
  //
  // Note: We only run this optimization after type legalization (which often
  // creates this pattern) and before operation legalization after which
  // we need to be more careful about the vector instructions that we generate.
  if (N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
      LegalTypes && !LegalOperations && N0->hasOneUse() && VT != MVT::i1) {
    EVT VecTy = N0.getOperand(0).getValueType();
    EVT ExTy = N0.getValueType();
    EVT TrTy = N->getValueType(0);

    auto EltCnt = VecTy.getVectorElementCount();
    unsigned SizeRatio = ExTy.getSizeInBits()/TrTy.getSizeInBits();
    auto NewEltCnt = EltCnt * SizeRatio;

    EVT NVT = EVT::getVectorVT(*DAG.getContext(), TrTy, NewEltCnt);
    assert(NVT.getSizeInBits() == VecTy.getSizeInBits() && "Invalid Size");

    SDValue EltNo = N0->getOperand(1);
    if (isa<ConstantSDNode>(EltNo) && isTypeLegal(NVT)) {
      int Elt = cast<ConstantSDNode>(EltNo)->getZExtValue();
      // Endianness decides which sub-element of the wide element holds the
      // low (truncated) bits.
      int Index = isLE ? (Elt*SizeRatio) : (Elt*SizeRatio + (SizeRatio-1));

      SDLoc DL(N);
      return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, TrTy,
                         DAG.getBitcast(NVT, N0.getOperand(0)),
                         DAG.getVectorIdxConstant(Index, DL));
    }
  }

  // trunc (select c, a, b) -> select c, (trunc a), (trunc b)
  if (N0.getOpcode() == ISD::SELECT && N0.hasOneUse()) {
    if ((!LegalOperations || TLI.isOperationLegal(ISD::SELECT, SrcVT)) &&
        TLI.isTruncateFree(SrcVT, VT)) {
      SDLoc SL(N0);
      SDValue Cond = N0.getOperand(0);
      SDValue TruncOp0 = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(1));
      SDValue TruncOp1 = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(2));
      return DAG.getNode(ISD::SELECT, SDLoc(N), VT, Cond, TruncOp0, TruncOp1);
    }
  }

  // trunc (shl x, K) -> shl (trunc x), K => K < VT.getScalarSizeInBits()
  if (N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
      (!LegalOperations || TLI.isOperationLegal(ISD::SHL, VT)) &&
      TLI.isTypeDesirableForOp(ISD::SHL, VT)) {
    SDValue Amt = N0.getOperand(1);
    KnownBits Known = DAG.computeKnownBits(Amt);
    unsigned Size = VT.getScalarSizeInBits();
    // Only safe if the shift amount is provably < Size; otherwise the
    // narrowed shift would be out of range (undefined).
    if (Known.countMaxActiveBits() <= Log2_32(Size)) {
      SDLoc SL(N);
      EVT AmtVT = TLI.getShiftAmountTy(VT, DAG.getDataLayout());

      SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(0));
      if (AmtVT != Amt.getValueType()) {
        Amt = DAG.getZExtOrTrunc(Amt, SL, AmtVT);
        AddToWorklist(Amt.getNode());
      }
      return DAG.getNode(ISD::SHL, SL, VT, Trunc, Amt);
    }
  }

  if (SDValue V = foldSubToUSubSat(VT, N0.getNode()))
    return V;

  // Attempt to pre-truncate BUILD_VECTOR sources.
  if (N0.getOpcode() == ISD::BUILD_VECTOR && !LegalOperations &&
      TLI.isTruncateFree(SrcVT.getScalarType(), VT.getScalarType()) &&
      // Avoid creating illegal types if running after type legalizer.
      (!LegalTypes || TLI.isTypeLegal(VT.getScalarType()))) {
    SDLoc DL(N);
    EVT SVT = VT.getScalarType();
    SmallVector<SDValue, 8> TruncOps;
    for (const SDValue &Op : N0->op_values()) {
      SDValue TruncOp = DAG.getNode(ISD::TRUNCATE, DL, SVT, Op);
      TruncOps.push_back(TruncOp);
    }
    return DAG.getBuildVector(VT, DL, TruncOps);
  }

  // Fold a series of buildvector, bitcast, and truncate if possible.
  // For example fold
  //   (2xi32 trunc (bitcast ((4xi32)buildvector x, x, y, y) 2xi64)) to
  //   (2xi32 (buildvector x, y)).
  if (Level == AfterLegalizeVectorOps && VT.isVector() &&
      N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
      N0.getOperand(0).getOpcode() == ISD::BUILD_VECTOR &&
      N0.getOperand(0).hasOneUse()) {
    SDValue BuildVect = N0.getOperand(0);
    EVT BuildVectEltTy = BuildVect.getValueType().getVectorElementType();
    EVT TruncVecEltTy = VT.getVectorElementType();

    // Check that the element types match.
    if (BuildVectEltTy == TruncVecEltTy) {
      // Now we only need to compute the offset of the truncated elements.
      unsigned BuildVecNumElts = BuildVect.getNumOperands();
      unsigned TruncVecNumElts = VT.getVectorNumElements();
      unsigned TruncEltOffset = BuildVecNumElts / TruncVecNumElts;

      assert((BuildVecNumElts % TruncVecNumElts) == 0 &&
             "Invalid number of elements");

      // Keep every TruncEltOffset-th operand; those hold the low bits of
      // each wide element that the truncate preserves.
      SmallVector<SDValue, 8> Opnds;
      for (unsigned i = 0, e = BuildVecNumElts; i != e; i += TruncEltOffset)
        Opnds.push_back(BuildVect.getOperand(i));

      return DAG.getBuildVector(VT, SDLoc(N), Opnds);
    }
  }

  // fold (truncate (load x)) -> (smaller load x)
  // fold (truncate (srl (load x), c)) -> (smaller load (x+c/evtbits))
  if (!LegalTypes || TLI.isTypeDesirableForOp(N0.getOpcode(), VT)) {
    if (SDValue Reduced = reduceLoadWidth(N))
      return Reduced;

    // Handle the case where the load remains an extending load even
    // after truncation.
    if (N0.hasOneUse() && ISD::isUNINDEXEDLoad(N0.getNode())) {
      LoadSDNode *LN0 = cast<LoadSDNode>(N0);
      if (LN0->isSimple() && LN0->getMemoryVT().bitsLT(VT)) {
        SDValue NewLoad = DAG.getExtLoad(LN0->getExtensionType(), SDLoc(LN0),
                                         VT, LN0->getChain(), LN0->getBasePtr(),
                                         LN0->getMemoryVT(),
                                         LN0->getMemOperand());
        // Preserve the chain so other users of the old load stay ordered.
        DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), NewLoad.getValue(1));
        return NewLoad;
      }
    }
  }

  // fold (trunc (concat ... x ...)) -> (concat ..., (trunc x), ...)),
  // where ... are all 'undef'.
  if (N0.getOpcode() == ISD::CONCAT_VECTORS && !LegalTypes) {
    SmallVector<EVT, 8> VTs;
    SDValue V;
    unsigned Idx = 0;
    unsigned NumDefs = 0;

    for (unsigned i = 0, e = N0.getNumOperands(); i != e; ++i) {
      SDValue X = N0.getOperand(i);
      if (!X.isUndef()) {
        V = X;
        Idx = i;
        NumDefs++;
      }
      // Stop if more than one members are non-undef.
      if (NumDefs > 1)
        break;

      VTs.push_back(EVT::getVectorVT(*DAG.getContext(),
                                     VT.getVectorElementType(),
                                     X.getValueType().getVectorElementCount()));
    }

    if (NumDefs == 0)
      return DAG.getUNDEF(VT);

    if (NumDefs == 1) {
      assert(V.getNode() && "The single defined operand is empty!");
      SmallVector<SDValue, 8> Opnds;
      for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
        if (i != Idx) {
          Opnds.push_back(DAG.getUNDEF(VTs[i]));
          continue;
        }
        SDValue NV = DAG.getNode(ISD::TRUNCATE, SDLoc(V), VTs[i], V);
        AddToWorklist(NV.getNode());
        Opnds.push_back(NV);
      }
      return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Opnds);
    }
  }

  // Fold truncate of a bitcast of a vector to an extract of the low vector
  // element.
  //
  // e.g. trunc (i64 (bitcast v2i32:x)) -> extract_vector_elt v2i32:x, idx
  if (N0.getOpcode() == ISD::BITCAST && !VT.isVector()) {
    SDValue VecSrc = N0.getOperand(0);
    EVT VecSrcVT = VecSrc.getValueType();
    if (VecSrcVT.isVector() && VecSrcVT.getScalarType() == VT &&
        (!LegalOperations ||
         TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VecSrcVT))) {
      SDLoc SL(N);

      // The low bits live in element 0 on little-endian and in the last
      // element on big-endian targets.
      unsigned Idx = isLE ? 0 : VecSrcVT.getVectorNumElements() - 1;
      return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, VT, VecSrc,
                         DAG.getVectorIdxConstant(Idx, SL));
    }
  }

  // Simplify the operands using demanded-bits information.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  // See if we can simplify the input to this truncate through knowledge that
  // only the low bits are being used.
  // For example "trunc (or (shl x, 8), y)" // -> trunc y
  // Currently we only perform this optimization on scalars because vectors
  // may have different active low bits.
  if (!VT.isVector()) {
    APInt Mask =
        APInt::getLowBitsSet(N0.getValueSizeInBits(), VT.getSizeInBits());
    if (SDValue Shorter = DAG.GetDemandedBits(N0, Mask))
      return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Shorter);
  }

  // fold (truncate (extract_subvector(ext x))) ->
  //      (extract_subvector x)
  // TODO: This can be generalized to cover cases where the truncate and extract
  // do not fully cancel each other out.
  if (!LegalTypes && N0.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
    SDValue N00 = N0.getOperand(0);
    if (N00.getOpcode() == ISD::SIGN_EXTEND ||
        N00.getOpcode() == ISD::ZERO_EXTEND ||
        N00.getOpcode() == ISD::ANY_EXTEND) {
      if (N00.getOperand(0)->getValueType(0).getVectorElementType() ==
          VT.getVectorElementType())
        return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N0->getOperand(0)), VT,
                           N00.getOperand(0), N0.getOperand(1));
    }
  }

  if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
    return NewVSel;

  // Narrow a suitable binary operation with a non-opaque constant operand by
  // moving it ahead of the truncate. This is limited to pre-legalization
  // because targets may prefer a wider type during later combines and invert
  // this transform.
  switch (N0.getOpcode()) {
  case ISD::ADD:
  case ISD::SUB:
  case ISD::MUL:
  case ISD::AND:
  case ISD::OR:
  case ISD::XOR:
    if (!LegalOperations && N0.hasOneUse() &&
        (isConstantOrConstantVector(N0.getOperand(0), true) ||
         isConstantOrConstantVector(N0.getOperand(1), true))) {
      // TODO: We already restricted this to pre-legalization, but for vectors
      // we are extra cautious to not create an unsupported operation.
      // Target-specific changes are likely needed to avoid regressions here.
      if (VT.isScalarInteger() || TLI.isOperationLegal(N0.getOpcode(), VT)) {
        SDLoc DL(N);
        SDValue NarrowL = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
        SDValue NarrowR = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(1));
        return DAG.getNode(N0.getOpcode(), DL, VT, NarrowL, NarrowR);
      }
    }
    break;
  case ISD::ADDE:
  case ISD::ADDCARRY:
    // (trunc adde(X, Y, Carry)) -> (adde trunc(X), trunc(Y), Carry)
    // (trunc addcarry(X, Y, Carry)) -> (addcarry trunc(X), trunc(Y), Carry)
    // When the adde's carry is not used.
    // We only do for addcarry before legalize operation
    if (((!LegalOperations && N0.getOpcode() == ISD::ADDCARRY) ||
         TLI.isOperationLegal(N0.getOpcode(), VT)) &&
        N0.hasOneUse() && !N0->hasAnyUseOfValue(1)) {
      SDLoc DL(N);
      SDValue X = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
      SDValue Y = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(1));
      SDVTList VTs = DAG.getVTList(VT, N0->getValueType(1));
      return DAG.getNode(N0.getOpcode(), DL, VTs, X, Y, N0.getOperand(2));
    }
    break;
  case ISD::USUBSAT:
    // Truncate the USUBSAT only if LHS is a known zero-extension, its not
    // enough to know that the upper bits are zero we must ensure that we don't
    // introduce an extra truncate.
    if (!LegalOperations && N0.hasOneUse() &&
        N0.getOperand(0).getOpcode() == ISD::ZERO_EXTEND &&
        N0.getOperand(0).getOperand(0).getScalarValueSizeInBits() <=
            VT.getScalarSizeInBits() &&
        hasOperation(N0.getOpcode(), VT)) {
      return getTruncatedUSUBSAT(VT, SrcVT, N0.getOperand(0), N0.getOperand(1),
                                 DAG, SDLoc(N));
    }
    break;
  }

  return SDValue();
}
13512
getBuildPairElt(SDNode * N,unsigned i)13513 static SDNode *getBuildPairElt(SDNode *N, unsigned i) {
13514 SDValue Elt = N->getOperand(i);
13515 if (Elt.getOpcode() != ISD::MERGE_VALUES)
13516 return Elt.getNode();
13517 return Elt.getOperand(Elt.getResNo()).getNode();
13518 }
13519
13520 /// build_pair (load, load) -> load
13521 /// if load locations are consecutive.
CombineConsecutiveLoads(SDNode * N,EVT VT)13522 SDValue DAGCombiner::CombineConsecutiveLoads(SDNode *N, EVT VT) {
13523 assert(N->getOpcode() == ISD::BUILD_PAIR);
13524
13525 auto *LD1 = dyn_cast<LoadSDNode>(getBuildPairElt(N, 0));
13526 auto *LD2 = dyn_cast<LoadSDNode>(getBuildPairElt(N, 1));
13527
13528 // A BUILD_PAIR is always having the least significant part in elt 0 and the
13529 // most significant part in elt 1. So when combining into one large load, we
13530 // need to consider the endianness.
13531 if (DAG.getDataLayout().isBigEndian())
13532 std::swap(LD1, LD2);
13533
13534 if (!LD1 || !LD2 || !ISD::isNON_EXTLoad(LD1) || !ISD::isNON_EXTLoad(LD2) ||
13535 !LD1->hasOneUse() || !LD2->hasOneUse() ||
13536 LD1->getAddressSpace() != LD2->getAddressSpace())
13537 return SDValue();
13538
13539 bool LD1Fast = false;
13540 EVT LD1VT = LD1->getValueType(0);
13541 unsigned LD1Bytes = LD1VT.getStoreSize();
13542 if ((!LegalOperations || TLI.isOperationLegal(ISD::LOAD, VT)) &&
13543 DAG.areNonVolatileConsecutiveLoads(LD2, LD1, LD1Bytes, 1) &&
13544 TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
13545 *LD1->getMemOperand(), &LD1Fast) && LD1Fast)
13546 return DAG.getLoad(VT, SDLoc(N), LD1->getChain(), LD1->getBasePtr(),
13547 LD1->getPointerInfo(), LD1->getAlign());
13548
13549 return SDValue();
13550 }
13551
getPPCf128HiElementSelector(const SelectionDAG & DAG)13552 static unsigned getPPCf128HiElementSelector(const SelectionDAG &DAG) {
13553 // On little-endian machines, bitcasting from ppcf128 to i128 does swap the Hi
13554 // and Lo parts; on big-endian machines it doesn't.
13555 return DAG.getDataLayout().isBigEndian() ? 1 : 0;
13556 }
13557
foldBitcastedFPLogic(SDNode * N,SelectionDAG & DAG,const TargetLowering & TLI)13558 static SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
13559 const TargetLowering &TLI) {
13560 // If this is not a bitcast to an FP type or if the target doesn't have
13561 // IEEE754-compliant FP logic, we're done.
13562 EVT VT = N->getValueType(0);
13563 if (!VT.isFloatingPoint() || !TLI.hasBitPreservingFPLogic(VT))
13564 return SDValue();
13565
13566 // TODO: Handle cases where the integer constant is a different scalar
13567 // bitwidth to the FP.
13568 SDValue N0 = N->getOperand(0);
13569 EVT SourceVT = N0.getValueType();
13570 if (VT.getScalarSizeInBits() != SourceVT.getScalarSizeInBits())
13571 return SDValue();
13572
13573 unsigned FPOpcode;
13574 APInt SignMask;
13575 switch (N0.getOpcode()) {
13576 case ISD::AND:
13577 FPOpcode = ISD::FABS;
13578 SignMask = ~APInt::getSignMask(SourceVT.getScalarSizeInBits());
13579 break;
13580 case ISD::XOR:
13581 FPOpcode = ISD::FNEG;
13582 SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits());
13583 break;
13584 case ISD::OR:
13585 FPOpcode = ISD::FABS;
13586 SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits());
13587 break;
13588 default:
13589 return SDValue();
13590 }
13591
13592 // Fold (bitcast int (and (bitcast fp X to int), 0x7fff...) to fp) -> fabs X
13593 // Fold (bitcast int (xor (bitcast fp X to int), 0x8000...) to fp) -> fneg X
13594 // Fold (bitcast int (or (bitcast fp X to int), 0x8000...) to fp) ->
13595 // fneg (fabs X)
13596 SDValue LogicOp0 = N0.getOperand(0);
13597 ConstantSDNode *LogicOp1 = isConstOrConstSplat(N0.getOperand(1), true);
13598 if (LogicOp1 && LogicOp1->getAPIntValue() == SignMask &&
13599 LogicOp0.getOpcode() == ISD::BITCAST &&
13600 LogicOp0.getOperand(0).getValueType() == VT) {
13601 SDValue FPOp = DAG.getNode(FPOpcode, SDLoc(N), VT, LogicOp0.getOperand(0));
13602 NumFPLogicOpsConv++;
13603 if (N0.getOpcode() == ISD::OR)
13604 return DAG.getNode(ISD::FNEG, SDLoc(N), VT, FPOp);
13605 return FPOp;
13606 }
13607
13608 return SDValue();
13609 }
13610
/// Combine a BITCAST node: fold casts of constants, merge with loads, peek
/// through double casts, turn bitcast FP sign ops into integer logic (and
/// vice versa via foldBitcastedFPLogic), and simplify casted shuffles.
SDValue DAGCombiner::visitBITCAST(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  // Bitcast of undef is undef of the result type.
  if (N0.isUndef())
    return DAG.getUNDEF(VT);

  // If the input is a BUILD_VECTOR with all constant elements, fold this now.
  // Only do this before legalize types, unless both types are integer and the
  // scalar type is legal. Only do this before legalize ops, since the target
  // maybe depending on the bitcast.
  // First check to see if this is all constant.
  // TODO: Support FP bitcasts after legalize types.
  if (VT.isVector() &&
      (!LegalTypes ||
       (!LegalOperations && VT.isInteger() && N0.getValueType().isInteger() &&
        TLI.isTypeLegal(VT.getVectorElementType()))) &&
      N0.getOpcode() == ISD::BUILD_VECTOR && N0->hasOneUse() &&
      cast<BuildVectorSDNode>(N0)->isConstant())
    return ConstantFoldBITCASTofBUILD_VECTOR(N0.getNode(),
                                             VT.getVectorElementType());

  // If the input is a constant, let getNode fold it.
  if (isIntOrFPConstant(N0)) {
    // If we can't allow illegal operations, we need to check that this is just
    // a fp -> int or int -> fp conversion and that the resulting operation
    // will be legal.
    if (!LegalOperations ||
        (isa<ConstantSDNode>(N0) && VT.isFloatingPoint() && !VT.isVector() &&
         TLI.isOperationLegal(ISD::ConstantFP, VT)) ||
        (isa<ConstantFPSDNode>(N0) && VT.isInteger() && !VT.isVector() &&
         TLI.isOperationLegal(ISD::Constant, VT))) {
      SDValue C = DAG.getBitcast(VT, N0);
      // Guard against getNode returning N itself (no progress -> would loop).
      if (C.getNode() != N)
        return C;
    }
  }

  // (conv (conv x, t1), t2) -> (conv x, t2)
  if (N0.getOpcode() == ISD::BITCAST)
    return DAG.getBitcast(VT, N0.getOperand(0));

  // fold (conv (load x)) -> (load (conv*)x)
  // If the resultant load doesn't need a higher alignment than the original!
  if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
      // Do not remove the cast if the types differ in endian layout.
      TLI.hasBigEndianPartOrdering(N0.getValueType(), DAG.getDataLayout()) ==
          TLI.hasBigEndianPartOrdering(VT, DAG.getDataLayout()) &&
      // If the load is volatile, we only want to change the load type if the
      // resulting load is legal. Otherwise we might increase the number of
      // memory accesses. We don't care if the original type was legal or not
      // as we assume software couldn't rely on the number of accesses of an
      // illegal type.
      ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple()) ||
       TLI.isOperationLegal(ISD::LOAD, VT))) {
    LoadSDNode *LN0 = cast<LoadSDNode>(N0);

    if (TLI.isLoadBitCastBeneficial(N0.getValueType(), VT, DAG,
                                    *LN0->getMemOperand())) {
      // Build the replacement load and rewire the original load's chain
      // users to the new load's chain result.
      SDValue Load =
          DAG.getLoad(VT, SDLoc(N), LN0->getChain(), LN0->getBasePtr(),
                      LN0->getPointerInfo(), LN0->getAlign(),
                      LN0->getMemOperand()->getFlags(), LN0->getAAInfo());
      DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));
      return Load;
    }
  }

  if (SDValue V = foldBitcastedFPLogic(N, DAG, TLI))
    return V;

  // fold (bitconvert (fneg x)) -> (xor (bitconvert x), signbit)
  // fold (bitconvert (fabs x)) -> (and (bitconvert x), (not signbit))
  //
  // For ppc_fp128:
  // fold (bitcast (fneg x)) ->
  //     flipbit = signbit
  //     (xor (bitcast x) (build_pair flipbit, flipbit))
  //
  // fold (bitcast (fabs x)) ->
  //     flipbit = (and (extract_element (bitcast x), 0), signbit)
  //     (xor (bitcast x) (build_pair flipbit, flipbit))
  // This often reduces constant pool loads.
  if (((N0.getOpcode() == ISD::FNEG && !TLI.isFNegFree(N0.getValueType())) ||
       (N0.getOpcode() == ISD::FABS && !TLI.isFAbsFree(N0.getValueType()))) &&
      N0->hasOneUse() && VT.isInteger() && !VT.isVector() &&
      !N0.getValueType().isVector()) {
    SDValue NewConv = DAG.getBitcast(VT, N0.getOperand(0));
    AddToWorklist(NewConv.getNode());

    SDLoc DL(N);
    // ppc_fp128 is a pair of doubles: the sign bit lives in the Hi half, so
    // the flip mask must be built per-half with BUILD_PAIR.
    if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
      assert(VT.getSizeInBits() == 128);
      SDValue SignBit = DAG.getConstant(
          APInt::getSignMask(VT.getSizeInBits() / 2), SDLoc(N0), MVT::i64);
      SDValue FlipBit;
      if (N0.getOpcode() == ISD::FNEG) {
        // fneg: unconditionally flip the sign bit of both halves.
        FlipBit = SignBit;
        AddToWorklist(FlipBit.getNode());
      } else {
        assert(N0.getOpcode() == ISD::FABS);
        // fabs: flip the sign bit only if it is currently set, i.e. flip by
        // (Hi & signbit).
        SDValue Hi =
            DAG.getNode(ISD::EXTRACT_ELEMENT, SDLoc(NewConv), MVT::i64, NewConv,
                        DAG.getIntPtrConstant(getPPCf128HiElementSelector(DAG),
                                              SDLoc(NewConv)));
        AddToWorklist(Hi.getNode());
        FlipBit = DAG.getNode(ISD::AND, SDLoc(N0), MVT::i64, Hi, SignBit);
        AddToWorklist(FlipBit.getNode());
      }
      SDValue FlipBits =
          DAG.getNode(ISD::BUILD_PAIR, SDLoc(N0), VT, FlipBit, FlipBit);
      AddToWorklist(FlipBits.getNode());
      return DAG.getNode(ISD::XOR, DL, VT, NewConv, FlipBits);
    }
    APInt SignBit = APInt::getSignMask(VT.getSizeInBits());
    if (N0.getOpcode() == ISD::FNEG)
      return DAG.getNode(ISD::XOR, DL, VT,
                         NewConv, DAG.getConstant(SignBit, DL, VT));
    assert(N0.getOpcode() == ISD::FABS);
    return DAG.getNode(ISD::AND, DL, VT,
                       NewConv, DAG.getConstant(~SignBit, DL, VT));
  }

  // fold (bitconvert (fcopysign cst, x)) ->
  //         (or (and (bitconvert x), sign), (and cst, (not sign)))
  // Note that we don't handle (copysign x, cst) because this can always be
  // folded to an fneg or fabs.
  //
  // For ppc_fp128:
  // fold (bitcast (fcopysign cst, x)) ->
  //     flipbit = (and (extract_element
  //                     (xor (bitcast cst), (bitcast x)), 0),
  //                    signbit)
  //     (xor (bitcast cst) (build_pair flipbit, flipbit))
  if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse() &&
      isa<ConstantFPSDNode>(N0.getOperand(0)) && VT.isInteger() &&
      !VT.isVector()) {
    unsigned OrigXWidth = N0.getOperand(1).getValueSizeInBits();
    EVT IntXVT = EVT::getIntegerVT(*DAG.getContext(), OrigXWidth);
    if (isTypeLegal(IntXVT)) {
      SDValue X = DAG.getBitcast(IntXVT, N0.getOperand(1));
      AddToWorklist(X.getNode());

      // If X has a different width than the result/lhs, sext it or truncate it.
      unsigned VTWidth = VT.getSizeInBits();
      if (OrigXWidth < VTWidth) {
        X = DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, X);
        AddToWorklist(X.getNode());
      } else if (OrigXWidth > VTWidth) {
        // To get the sign bit in the right place, we have to shift it right
        // before truncating.
        SDLoc DL(X);
        X = DAG.getNode(ISD::SRL, DL,
                        X.getValueType(), X,
                        DAG.getConstant(OrigXWidth-VTWidth, DL,
                                        X.getValueType()));
        AddToWorklist(X.getNode());
        X = DAG.getNode(ISD::TRUNCATE, SDLoc(X), VT, X);
        AddToWorklist(X.getNode());
      }

      // ppc_fp128: flip the sign of cst wherever cst and x disagree in sign.
      if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
        APInt SignBit = APInt::getSignMask(VT.getSizeInBits() / 2);
        SDValue Cst = DAG.getBitcast(VT, N0.getOperand(0));
        AddToWorklist(Cst.getNode());
        SDValue X = DAG.getBitcast(VT, N0.getOperand(1));
        AddToWorklist(X.getNode());
        SDValue XorResult = DAG.getNode(ISD::XOR, SDLoc(N0), VT, Cst, X);
        AddToWorklist(XorResult.getNode());
        SDValue XorResult64 = DAG.getNode(
            ISD::EXTRACT_ELEMENT, SDLoc(XorResult), MVT::i64, XorResult,
            DAG.getIntPtrConstant(getPPCf128HiElementSelector(DAG),
                                  SDLoc(XorResult)));
        AddToWorklist(XorResult64.getNode());
        SDValue FlipBit =
            DAG.getNode(ISD::AND, SDLoc(XorResult64), MVT::i64, XorResult64,
                        DAG.getConstant(SignBit, SDLoc(XorResult64), MVT::i64));
        AddToWorklist(FlipBit.getNode());
        SDValue FlipBits =
            DAG.getNode(ISD::BUILD_PAIR, SDLoc(N0), VT, FlipBit, FlipBit);
        AddToWorklist(FlipBits.getNode());
        return DAG.getNode(ISD::XOR, SDLoc(N), VT, Cst, FlipBits);
      }
      // Generic case: (x & signbit) | (cst & ~signbit).
      APInt SignBit = APInt::getSignMask(VT.getSizeInBits());
      X = DAG.getNode(ISD::AND, SDLoc(X), VT,
                      X, DAG.getConstant(SignBit, SDLoc(X), VT));
      AddToWorklist(X.getNode());

      SDValue Cst = DAG.getBitcast(VT, N0.getOperand(0));
      Cst = DAG.getNode(ISD::AND, SDLoc(Cst), VT,
                        Cst, DAG.getConstant(~SignBit, SDLoc(Cst), VT));
      AddToWorklist(Cst.getNode());

      return DAG.getNode(ISD::OR, SDLoc(N), VT, X, Cst);
    }
  }

  // bitconvert(build_pair(ld, ld)) -> ld iff load locations are consecutive.
  if (N0.getOpcode() == ISD::BUILD_PAIR)
    if (SDValue CombineLD = CombineConsecutiveLoads(N0.getNode(), VT))
      return CombineLD;

  // Remove double bitcasts from shuffles - this is often a legacy of
  // XformToShuffleWithZero being used to combine bitmaskings (of
  // float vectors bitcast to integer vectors) into shuffles.
  // bitcast(shuffle(bitcast(s0),bitcast(s1))) -> shuffle(s0,s1)
  if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT) && VT.isVector() &&
      N0->getOpcode() == ISD::VECTOR_SHUFFLE && N0.hasOneUse() &&
      VT.getVectorNumElements() >= N0.getValueType().getVectorNumElements() &&
      !(VT.getVectorNumElements() % N0.getValueType().getVectorNumElements())) {
    ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N0);

    // If operands are a bitcast, peek through if it casts the original VT.
    // If operands are a constant, just bitcast back to original VT.
    auto PeekThroughBitcast = [&](SDValue Op) {
      if (Op.getOpcode() == ISD::BITCAST &&
          Op.getOperand(0).getValueType() == VT)
        return SDValue(Op.getOperand(0));
      if (Op.isUndef() || isAnyConstantBuildVector(Op))
        return DAG.getBitcast(VT, Op);
      // Empty SDValue signals "can't peek through" to the caller.
      return SDValue();
    };

    // FIXME: If either input vector is bitcast, try to convert the shuffle to
    // the result type of this bitcast. This would eliminate at least one
    // bitcast. See the transform in InstCombine.
    SDValue SV0 = PeekThroughBitcast(N0->getOperand(0));
    SDValue SV1 = PeekThroughBitcast(N0->getOperand(1));
    if (!(SV0 && SV1))
      return SDValue();

    // Widen the shuffle mask: each source element maps to MaskScale result
    // elements; undef lanes (M < 0) stay undef.
    int MaskScale =
        VT.getVectorNumElements() / N0.getValueType().getVectorNumElements();
    SmallVector<int, 8> NewMask;
    for (int M : SVN->getMask())
      for (int i = 0; i != MaskScale; ++i)
        NewMask.push_back(M < 0 ? -1 : M * MaskScale + i);

    SDValue LegalShuffle =
        TLI.buildLegalVectorShuffle(VT, SDLoc(N), SV0, SV1, NewMask, DAG);
    if (LegalShuffle)
      return LegalShuffle;
  }

  return SDValue();
}
13857
visitBUILD_PAIR(SDNode * N)13858 SDValue DAGCombiner::visitBUILD_PAIR(SDNode *N) {
13859 EVT VT = N->getValueType(0);
13860 return CombineConsecutiveLoads(N, VT);
13861 }
13862
visitFREEZE(SDNode * N)13863 SDValue DAGCombiner::visitFREEZE(SDNode *N) {
13864 SDValue N0 = N->getOperand(0);
13865
13866 if (DAG.isGuaranteedNotToBeUndefOrPoison(N0, /*PoisonOnly*/ false))
13867 return N0;
13868
13869 // Fold freeze(bitcast(x)) -> bitcast(freeze(x)).
13870 // TODO: Replace with pushFreezeToPreventPoisonFromPropagating fold.
13871 if (N0.getOpcode() == ISD::BITCAST)
13872 return DAG.getBitcast(N->getValueType(0),
13873 DAG.getNode(ISD::FREEZE, SDLoc(N0),
13874 N0.getOperand(0).getValueType(),
13875 N0.getOperand(0)));
13876
13877 return SDValue();
13878 }
13879
/// We know that BV is a build_vector node with Constant, ConstantFP or Undef
/// operands. DstEltVT indicates the destination element value type.
SDValue DAGCombiner::
ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) {
  EVT SrcEltVT = BV->getValueType(0).getVectorElementType();

  // If this is already the right type, we're done.
  if (SrcEltVT == DstEltVT) return SDValue(BV, 0);

  unsigned SrcBitSize = SrcEltVT.getSizeInBits();
  unsigned DstBitSize = DstEltVT.getSizeInBits();

  // If this is a conversion of N elements of one type to N elements of another
  // type, convert each element. This handles FP<->INT cases.
  if (SrcBitSize == DstBitSize) {
    SmallVector<SDValue, 8> Ops;
    for (SDValue Op : BV->op_values()) {
      // If the vector element type is not legal, the BUILD_VECTOR operands
      // are promoted and implicitly truncated. Make that explicit here.
      if (Op.getValueType() != SrcEltVT)
        Op = DAG.getNode(ISD::TRUNCATE, SDLoc(BV), SrcEltVT, Op);
      Ops.push_back(DAG.getBitcast(DstEltVT, Op));
      AddToWorklist(Ops.back().getNode());
    }
    EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT,
                              BV->getValueType(0).getVectorNumElements());
    return DAG.getBuildVector(VT, SDLoc(BV), Ops);
  }

  // Otherwise, we're growing or shrinking the elements. To avoid having to
  // handle annoying details of growing/shrinking FP values, we convert them to
  // int first.
  if (SrcEltVT.isFloatingPoint()) {
    // Convert the input float vector to a int vector where the elements are the
    // same sizes.
    EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), SrcEltVT.getSizeInBits());
    // Recurse: same-size FP->int conversion is handled by the branch above.
    BV = ConstantFoldBITCASTofBUILD_VECTOR(BV, IntVT).getNode();
    SrcEltVT = IntVT;
  }

  // Now we know the input is an integer vector. If the output is a FP type,
  // convert to integer first, then to FP of the right size.
  if (DstEltVT.isFloatingPoint()) {
    EVT TmpVT = EVT::getIntegerVT(*DAG.getContext(), DstEltVT.getSizeInBits());
    // Recurse: int->int resize first, then the same-size int->FP case above.
    SDNode *Tmp = ConstantFoldBITCASTofBUILD_VECTOR(BV, TmpVT).getNode();

    // Next, convert to FP elements of the same size.
    return ConstantFoldBITCASTofBUILD_VECTOR(Tmp, DstEltVT);
  }

  // Okay, we know the src/dst types are both integers of differing types.
  assert(SrcEltVT.isInteger() && DstEltVT.isInteger());

  // TODO: Should ConstantFoldBITCASTofBUILD_VECTOR always take a
  // BuildVectorSDNode?
  auto *BVN = cast<BuildVectorSDNode>(BV);

  // Extract the constant raw bit data, re-chunked into DstBitSize-wide APInts
  // with endianness taken into account.
  BitVector UndefElements;
  SmallVector<APInt> RawBits;
  bool IsLE = DAG.getDataLayout().isLittleEndian();
  if (!BVN->getConstantRawBits(IsLE, DstBitSize, RawBits, UndefElements))
    return SDValue();

  // Rebuild the vector; elements whose bits were entirely undef stay undef.
  SDLoc DL(BV);
  SmallVector<SDValue, 8> Ops;
  for (unsigned I = 0, E = RawBits.size(); I != E; ++I) {
    if (UndefElements[I])
      Ops.push_back(DAG.getUNDEF(DstEltVT));
    else
      Ops.push_back(DAG.getConstant(RawBits[I], DL, DstEltVT));
  }

  EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT, Ops.size());
  return DAG.getBuildVector(VT, DL, Ops);
}
13956
13957 // Returns true if floating point contraction is allowed on the FMUL-SDValue
13958 // `N`
isContractableFMUL(const TargetOptions & Options,SDValue N)13959 static bool isContractableFMUL(const TargetOptions &Options, SDValue N) {
13960 assert(N.getOpcode() == ISD::FMUL);
13961
13962 return Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath ||
13963 N->getFlags().hasAllowContract();
13964 }
13965
13966 // Returns true if `N` can assume no infinities involved in its computation.
hasNoInfs(const TargetOptions & Options,SDValue N)13967 static bool hasNoInfs(const TargetOptions &Options, SDValue N) {
13968 return Options.NoInfsFPMath || N->getFlags().hasNoInfs();
13969 }
13970
/// Try to perform FMA combining on a given FADD node: contract
/// (fadd (fmul ...), ...) patterns into FMA/FMAD, including variants that
/// look through FP_EXTEND and existing fused ops.
SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  SDLoc SL(N);

  const TargetOptions &Options = DAG.getTarget().Options;

  // Floating-point multiply-add with intermediate rounding.
  bool HasFMAD = (LegalOperations && TLI.isFMADLegal(DAG, N));

  // Floating-point multiply-add without intermediate rounding.
  bool HasFMA =
      TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
      (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));

  // No valid opcode, do not combine.
  if (!HasFMAD && !HasFMA)
    return SDValue();

  bool CanReassociate =
      Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
  bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
                              Options.UnsafeFPMath || HasFMAD);
  // If the addition is not contractable, do not combine.
  if (!AllowFusionGlobally && !N->getFlags().hasAllowContract())
    return SDValue();

  // Some targets form FMAs later, in the machine combiner; stay out of the
  // way if this target asked for that.
  if (TLI.generateFMAsInMachineCombiner(VT, OptLevel))
    return SDValue();

  // Always prefer FMAD to FMA for precision.
  unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
  bool Aggressive = TLI.enableAggressiveFMAFusion(VT);

  auto isFusedOp = [&](SDValue N) {
    unsigned Opcode = N.getOpcode();
    return Opcode == ISD::FMA || Opcode == ISD::FMAD;
  };

  // Is the node an FMUL and contractable either due to global flags or
  // SDNodeFlags.
  auto isContractableFMUL = [AllowFusionGlobally](SDValue N) {
    if (N.getOpcode() != ISD::FMUL)
      return false;
    return AllowFusionGlobally || N->getFlags().hasAllowContract();
  };
  // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
  // prefer to fold the multiply with fewer uses.
  if (Aggressive && isContractableFMUL(N0) && isContractableFMUL(N1)) {
    if (N0->use_size() > N1->use_size())
      std::swap(N0, N1);
  }

  // fold (fadd (fmul x, y), z) -> (fma x, y, z)
  if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) {
    return DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0),
                       N0.getOperand(1), N1);
  }

  // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
  // Note: Commutes FADD operands.
  if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) {
    return DAG.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0),
                       N1.getOperand(1), N0);
  }

  // fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E)
  // fadd E, (fma A, B, (fmul C, D)) --> fma A, B, (fma C, D, E)
  // This requires reassociation because it changes the order of operations.
  SDValue FMA, E;
  if (CanReassociate && isFusedOp(N0) &&
      N0.getOperand(2).getOpcode() == ISD::FMUL && N0.hasOneUse() &&
      N0.getOperand(2).hasOneUse()) {
    FMA = N0;
    E = N1;
  } else if (CanReassociate && isFusedOp(N1) &&
             N1.getOperand(2).getOpcode() == ISD::FMUL && N1.hasOneUse() &&
             N1.getOperand(2).hasOneUse()) {
    FMA = N1;
    E = N0;
  }
  if (FMA && E) {
    SDValue A = FMA.getOperand(0);
    SDValue B = FMA.getOperand(1);
    SDValue C = FMA.getOperand(2).getOperand(0);
    SDValue D = FMA.getOperand(2).getOperand(1);
    SDValue CDE = DAG.getNode(PreferredFusedOpcode, SL, VT, C, D, E);
    return DAG.getNode(PreferredFusedOpcode, SL, VT, A, B, CDE);
  }

  // Look through FP_EXTEND nodes to do more combining.

  // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
  if (N0.getOpcode() == ISD::FP_EXTEND) {
    SDValue N00 = N0.getOperand(0);
    if (isContractableFMUL(N00) &&
        TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                            N00.getValueType())) {
      return DAG.getNode(PreferredFusedOpcode, SL, VT,
                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
                         N1);
    }
  }

  // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x)
  // Note: Commutes FADD operands.
  if (N1.getOpcode() == ISD::FP_EXTEND) {
    SDValue N10 = N1.getOperand(0);
    if (isContractableFMUL(N10) &&
        TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                            N10.getValueType())) {
      return DAG.getNode(PreferredFusedOpcode, SL, VT,
                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0)),
                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)),
                         N0);
    }
  }

  // More folding opportunities when target permits.
  if (Aggressive) {
    // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
    //   -> (fma x, y, (fma (fpext u), (fpext v), z))
    auto FoldFAddFMAFPExtFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
                                    SDValue Z) {
      return DAG.getNode(PreferredFusedOpcode, SL, VT, X, Y,
                         DAG.getNode(PreferredFusedOpcode, SL, VT,
                                     DAG.getNode(ISD::FP_EXTEND, SL, VT, U),
                                     DAG.getNode(ISD::FP_EXTEND, SL, VT, V),
                                     Z));
    };
    if (isFusedOp(N0)) {
      SDValue N02 = N0.getOperand(2);
      if (N02.getOpcode() == ISD::FP_EXTEND) {
        SDValue N020 = N02.getOperand(0);
        if (isContractableFMUL(N020) &&
            TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                                N020.getValueType())) {
          return FoldFAddFMAFPExtFMul(N0.getOperand(0), N0.getOperand(1),
                                      N020.getOperand(0), N020.getOperand(1),
                                      N1);
        }
      }
    }

    // fold (fadd (fpext (fma x, y, (fmul u, v))), z)
    //   -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
    // FIXME: This turns two single-precision and one double-precision
    // operation into two double-precision operations, which might not be
    // interesting for all targets, especially GPUs.
    auto FoldFAddFPExtFMAFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
                                    SDValue Z) {
      return DAG.getNode(
          PreferredFusedOpcode, SL, VT, DAG.getNode(ISD::FP_EXTEND, SL, VT, X),
          DAG.getNode(ISD::FP_EXTEND, SL, VT, Y),
          DAG.getNode(PreferredFusedOpcode, SL, VT,
                      DAG.getNode(ISD::FP_EXTEND, SL, VT, U),
                      DAG.getNode(ISD::FP_EXTEND, SL, VT, V), Z));
    };
    if (N0.getOpcode() == ISD::FP_EXTEND) {
      SDValue N00 = N0.getOperand(0);
      if (isFusedOp(N00)) {
        SDValue N002 = N00.getOperand(2);
        if (isContractableFMUL(N002) &&
            TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                                N00.getValueType())) {
          return FoldFAddFPExtFMAFMul(N00.getOperand(0), N00.getOperand(1),
                                      N002.getOperand(0), N002.getOperand(1),
                                      N1);
        }
      }
    }

    // fold (fadd x, (fma y, z, (fpext (fmul u, v)))
    //   -> (fma y, z, (fma (fpext u), (fpext v), x))
    if (isFusedOp(N1)) {
      SDValue N12 = N1.getOperand(2);
      if (N12.getOpcode() == ISD::FP_EXTEND) {
        SDValue N120 = N12.getOperand(0);
        if (isContractableFMUL(N120) &&
            TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                                N120.getValueType())) {
          return FoldFAddFMAFPExtFMul(N1.getOperand(0), N1.getOperand(1),
                                      N120.getOperand(0), N120.getOperand(1),
                                      N0);
        }
      }
    }

    // fold (fadd x, (fpext (fma y, z, (fmul u, v)))
    //   -> (fma (fpext y), (fpext z), (fma (fpext u), (fpext v), x))
    // FIXME: This turns two single-precision and one double-precision
    // operation into two double-precision operations, which might not be
    // interesting for all targets, especially GPUs.
    if (N1.getOpcode() == ISD::FP_EXTEND) {
      SDValue N10 = N1.getOperand(0);
      if (isFusedOp(N10)) {
        SDValue N102 = N10.getOperand(2);
        if (isContractableFMUL(N102) &&
            TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                                N10.getValueType())) {
          return FoldFAddFPExtFMAFMul(N10.getOperand(0), N10.getOperand(1),
                                      N102.getOperand(0), N102.getOperand(1),
                                      N0);
        }
      }
    }
  }

  return SDValue();
}
14184
14185 /// Try to perform FMA combining on a given FSUB node.
visitFSUBForFMACombine(SDNode * N)14186 SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
14187 SDValue N0 = N->getOperand(0);
14188 SDValue N1 = N->getOperand(1);
14189 EVT VT = N->getValueType(0);
14190 SDLoc SL(N);
14191
14192 const TargetOptions &Options = DAG.getTarget().Options;
14193 // Floating-point multiply-add with intermediate rounding.
14194 bool HasFMAD = (LegalOperations && TLI.isFMADLegal(DAG, N));
14195
14196 // Floating-point multiply-add without intermediate rounding.
14197 bool HasFMA =
14198 TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
14199 (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
14200
14201 // No valid opcode, do not combine.
14202 if (!HasFMAD && !HasFMA)
14203 return SDValue();
14204
14205 const SDNodeFlags Flags = N->getFlags();
14206 bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
14207 Options.UnsafeFPMath || HasFMAD);
14208
14209 // If the subtraction is not contractable, do not combine.
14210 if (!AllowFusionGlobally && !N->getFlags().hasAllowContract())
14211 return SDValue();
14212
14213 if (TLI.generateFMAsInMachineCombiner(VT, OptLevel))
14214 return SDValue();
14215
14216 // Always prefer FMAD to FMA for precision.
14217 unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
14218 bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
14219 bool NoSignedZero = Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros();
14220
14221 // Is the node an FMUL and contractable either due to global flags or
14222 // SDNodeFlags.
14223 auto isContractableFMUL = [AllowFusionGlobally](SDValue N) {
14224 if (N.getOpcode() != ISD::FMUL)
14225 return false;
14226 return AllowFusionGlobally || N->getFlags().hasAllowContract();
14227 };
14228
14229 // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
14230 auto tryToFoldXYSubZ = [&](SDValue XY, SDValue Z) {
14231 if (isContractableFMUL(XY) && (Aggressive || XY->hasOneUse())) {
14232 return DAG.getNode(PreferredFusedOpcode, SL, VT, XY.getOperand(0),
14233 XY.getOperand(1), DAG.getNode(ISD::FNEG, SL, VT, Z));
14234 }
14235 return SDValue();
14236 };
14237
14238 // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
14239 // Note: Commutes FSUB operands.
14240 auto tryToFoldXSubYZ = [&](SDValue X, SDValue YZ) {
14241 if (isContractableFMUL(YZ) && (Aggressive || YZ->hasOneUse())) {
14242 return DAG.getNode(PreferredFusedOpcode, SL, VT,
14243 DAG.getNode(ISD::FNEG, SL, VT, YZ.getOperand(0)),
14244 YZ.getOperand(1), X);
14245 }
14246 return SDValue();
14247 };
14248
14249 // If we have two choices trying to fold (fsub (fmul u, v), (fmul x, y)),
14250 // prefer to fold the multiply with fewer uses.
14251 if (isContractableFMUL(N0) && isContractableFMUL(N1) &&
14252 (N0->use_size() > N1->use_size())) {
14253 // fold (fsub (fmul a, b), (fmul c, d)) -> (fma (fneg c), d, (fmul a, b))
14254 if (SDValue V = tryToFoldXSubYZ(N0, N1))
14255 return V;
14256 // fold (fsub (fmul a, b), (fmul c, d)) -> (fma a, b, (fneg (fmul c, d)))
14257 if (SDValue V = tryToFoldXYSubZ(N0, N1))
14258 return V;
14259 } else {
14260 // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
14261 if (SDValue V = tryToFoldXYSubZ(N0, N1))
14262 return V;
14263 // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
14264 if (SDValue V = tryToFoldXSubYZ(N0, N1))
14265 return V;
14266 }
14267
14268 // fold (fsub (fneg (fmul, x, y)), z) -> (fma (fneg x), y, (fneg z))
14269 if (N0.getOpcode() == ISD::FNEG && isContractableFMUL(N0.getOperand(0)) &&
14270 (Aggressive || (N0->hasOneUse() && N0.getOperand(0).hasOneUse()))) {
14271 SDValue N00 = N0.getOperand(0).getOperand(0);
14272 SDValue N01 = N0.getOperand(0).getOperand(1);
14273 return DAG.getNode(PreferredFusedOpcode, SL, VT,
14274 DAG.getNode(ISD::FNEG, SL, VT, N00), N01,
14275 DAG.getNode(ISD::FNEG, SL, VT, N1));
14276 }
14277
14278 // Look through FP_EXTEND nodes to do more combining.
14279
14280 // fold (fsub (fpext (fmul x, y)), z)
14281 // -> (fma (fpext x), (fpext y), (fneg z))
14282 if (N0.getOpcode() == ISD::FP_EXTEND) {
14283 SDValue N00 = N0.getOperand(0);
14284 if (isContractableFMUL(N00) &&
14285 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
14286 N00.getValueType())) {
14287 return DAG.getNode(PreferredFusedOpcode, SL, VT,
14288 DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
14289 DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
14290 DAG.getNode(ISD::FNEG, SL, VT, N1));
14291 }
14292 }
14293
14294 // fold (fsub x, (fpext (fmul y, z)))
14295 // -> (fma (fneg (fpext y)), (fpext z), x)
14296 // Note: Commutes FSUB operands.
14297 if (N1.getOpcode() == ISD::FP_EXTEND) {
14298 SDValue N10 = N1.getOperand(0);
14299 if (isContractableFMUL(N10) &&
14300 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
14301 N10.getValueType())) {
14302 return DAG.getNode(
14303 PreferredFusedOpcode, SL, VT,
14304 DAG.getNode(ISD::FNEG, SL, VT,
14305 DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0))),
14306 DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), N0);
14307 }
14308 }
14309
14310 // fold (fsub (fpext (fneg (fmul, x, y))), z)
14311 // -> (fneg (fma (fpext x), (fpext y), z))
14312 // Note: This could be removed with appropriate canonicalization of the
14313 // input expression into (fneg (fadd (fpext (fmul, x, y)), z). However, the
14314 // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent
14315 // from implementing the canonicalization in visitFSUB.
14316 if (N0.getOpcode() == ISD::FP_EXTEND) {
14317 SDValue N00 = N0.getOperand(0);
14318 if (N00.getOpcode() == ISD::FNEG) {
14319 SDValue N000 = N00.getOperand(0);
14320 if (isContractableFMUL(N000) &&
14321 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
14322 N00.getValueType())) {
14323 return DAG.getNode(
14324 ISD::FNEG, SL, VT,
14325 DAG.getNode(PreferredFusedOpcode, SL, VT,
14326 DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)),
14327 DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)),
14328 N1));
14329 }
14330 }
14331 }
14332
14333 // fold (fsub (fneg (fpext (fmul, x, y))), z)
14334 // -> (fneg (fma (fpext x)), (fpext y), z)
14335 // Note: This could be removed with appropriate canonicalization of the
14336 // input expression into (fneg (fadd (fpext (fmul, x, y)), z). However, the
14337 // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent
14338 // from implementing the canonicalization in visitFSUB.
14339 if (N0.getOpcode() == ISD::FNEG) {
14340 SDValue N00 = N0.getOperand(0);
14341 if (N00.getOpcode() == ISD::FP_EXTEND) {
14342 SDValue N000 = N00.getOperand(0);
14343 if (isContractableFMUL(N000) &&
14344 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
14345 N000.getValueType())) {
14346 return DAG.getNode(
14347 ISD::FNEG, SL, VT,
14348 DAG.getNode(PreferredFusedOpcode, SL, VT,
14349 DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)),
14350 DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)),
14351 N1));
14352 }
14353 }
14354 }
14355
14356 auto isReassociable = [Options](SDNode *N) {
14357 return Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
14358 };
14359
14360 auto isContractableAndReassociableFMUL = [isContractableFMUL,
14361 isReassociable](SDValue N) {
14362 return isContractableFMUL(N) && isReassociable(N.getNode());
14363 };
14364
14365 auto isFusedOp = [&](SDValue N) {
14366 unsigned Opcode = N.getOpcode();
14367 return Opcode == ISD::FMA || Opcode == ISD::FMAD;
14368 };
14369
14370 // More folding opportunities when target permits.
14371 if (Aggressive && isReassociable(N)) {
14372 bool CanFuse = Options.UnsafeFPMath || N->getFlags().hasAllowContract();
14373 // fold (fsub (fma x, y, (fmul u, v)), z)
14374 // -> (fma x, y (fma u, v, (fneg z)))
14375 if (CanFuse && isFusedOp(N0) &&
14376 isContractableAndReassociableFMUL(N0.getOperand(2)) &&
14377 N0->hasOneUse() && N0.getOperand(2)->hasOneUse()) {
14378 return DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0),
14379 N0.getOperand(1),
14380 DAG.getNode(PreferredFusedOpcode, SL, VT,
14381 N0.getOperand(2).getOperand(0),
14382 N0.getOperand(2).getOperand(1),
14383 DAG.getNode(ISD::FNEG, SL, VT, N1)));
14384 }
14385
14386 // fold (fsub x, (fma y, z, (fmul u, v)))
14387 // -> (fma (fneg y), z, (fma (fneg u), v, x))
14388 if (CanFuse && isFusedOp(N1) &&
14389 isContractableAndReassociableFMUL(N1.getOperand(2)) &&
14390 N1->hasOneUse() && NoSignedZero) {
14391 SDValue N20 = N1.getOperand(2).getOperand(0);
14392 SDValue N21 = N1.getOperand(2).getOperand(1);
14393 return DAG.getNode(
14394 PreferredFusedOpcode, SL, VT,
14395 DAG.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)), N1.getOperand(1),
14396 DAG.getNode(PreferredFusedOpcode, SL, VT,
14397 DAG.getNode(ISD::FNEG, SL, VT, N20), N21, N0));
14398 }
14399
14400 // fold (fsub (fma x, y, (fpext (fmul u, v))), z)
14401 // -> (fma x, y (fma (fpext u), (fpext v), (fneg z)))
14402 if (isFusedOp(N0) && N0->hasOneUse()) {
14403 SDValue N02 = N0.getOperand(2);
14404 if (N02.getOpcode() == ISD::FP_EXTEND) {
14405 SDValue N020 = N02.getOperand(0);
14406 if (isContractableAndReassociableFMUL(N020) &&
14407 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
14408 N020.getValueType())) {
14409 return DAG.getNode(
14410 PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1),
14411 DAG.getNode(
14412 PreferredFusedOpcode, SL, VT,
14413 DAG.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(0)),
14414 DAG.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(1)),
14415 DAG.getNode(ISD::FNEG, SL, VT, N1)));
14416 }
14417 }
14418 }
14419
14420 // fold (fsub (fpext (fma x, y, (fmul u, v))), z)
14421 // -> (fma (fpext x), (fpext y),
14422 // (fma (fpext u), (fpext v), (fneg z)))
14423 // FIXME: This turns two single-precision and one double-precision
14424 // operation into two double-precision operations, which might not be
14425 // interesting for all targets, especially GPUs.
14426 if (N0.getOpcode() == ISD::FP_EXTEND) {
14427 SDValue N00 = N0.getOperand(0);
14428 if (isFusedOp(N00)) {
14429 SDValue N002 = N00.getOperand(2);
14430 if (isContractableAndReassociableFMUL(N002) &&
14431 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
14432 N00.getValueType())) {
14433 return DAG.getNode(
14434 PreferredFusedOpcode, SL, VT,
14435 DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
14436 DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
14437 DAG.getNode(
14438 PreferredFusedOpcode, SL, VT,
14439 DAG.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(0)),
14440 DAG.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(1)),
14441 DAG.getNode(ISD::FNEG, SL, VT, N1)));
14442 }
14443 }
14444 }
14445
14446 // fold (fsub x, (fma y, z, (fpext (fmul u, v))))
14447 // -> (fma (fneg y), z, (fma (fneg (fpext u)), (fpext v), x))
14448 if (isFusedOp(N1) && N1.getOperand(2).getOpcode() == ISD::FP_EXTEND &&
14449 N1->hasOneUse()) {
14450 SDValue N120 = N1.getOperand(2).getOperand(0);
14451 if (isContractableAndReassociableFMUL(N120) &&
14452 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
14453 N120.getValueType())) {
14454 SDValue N1200 = N120.getOperand(0);
14455 SDValue N1201 = N120.getOperand(1);
14456 return DAG.getNode(
14457 PreferredFusedOpcode, SL, VT,
14458 DAG.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)), N1.getOperand(1),
14459 DAG.getNode(PreferredFusedOpcode, SL, VT,
14460 DAG.getNode(ISD::FNEG, SL, VT,
14461 DAG.getNode(ISD::FP_EXTEND, SL, VT, N1200)),
14462 DAG.getNode(ISD::FP_EXTEND, SL, VT, N1201), N0));
14463 }
14464 }
14465
14466 // fold (fsub x, (fpext (fma y, z, (fmul u, v))))
14467 // -> (fma (fneg (fpext y)), (fpext z),
14468 // (fma (fneg (fpext u)), (fpext v), x))
14469 // FIXME: This turns two single-precision and one double-precision
14470 // operation into two double-precision operations, which might not be
14471 // interesting for all targets, especially GPUs.
14472 if (N1.getOpcode() == ISD::FP_EXTEND && isFusedOp(N1.getOperand(0))) {
14473 SDValue CvtSrc = N1.getOperand(0);
14474 SDValue N100 = CvtSrc.getOperand(0);
14475 SDValue N101 = CvtSrc.getOperand(1);
14476 SDValue N102 = CvtSrc.getOperand(2);
14477 if (isContractableAndReassociableFMUL(N102) &&
14478 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
14479 CvtSrc.getValueType())) {
14480 SDValue N1020 = N102.getOperand(0);
14481 SDValue N1021 = N102.getOperand(1);
14482 return DAG.getNode(
14483 PreferredFusedOpcode, SL, VT,
14484 DAG.getNode(ISD::FNEG, SL, VT,
14485 DAG.getNode(ISD::FP_EXTEND, SL, VT, N100)),
14486 DAG.getNode(ISD::FP_EXTEND, SL, VT, N101),
14487 DAG.getNode(PreferredFusedOpcode, SL, VT,
14488 DAG.getNode(ISD::FNEG, SL, VT,
14489 DAG.getNode(ISD::FP_EXTEND, SL, VT, N1020)),
14490 DAG.getNode(ISD::FP_EXTEND, SL, VT, N1021), N0));
14491 }
14492 }
14493 }
14494
14495 return SDValue();
14496 }
14497
/// Try to perform FMA combining on a given FMUL node based on the distributive
/// law x * (y + 1) = x * y + x and variants thereof (commuted versions,
/// subtraction instead of addition).
///
/// Returns the fused node, or an empty SDValue if no fold applies.
SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  SDLoc SL(N);

  assert(N->getOpcode() == ISD::FMUL && "Expected FMUL Operation");

  const TargetOptions &Options = DAG.getTarget().Options;

  // The transforms below are incorrect when x == 0 and y == inf, because the
  // intermediate multiplication produces a nan.
  SDValue FAdd = N0.getOpcode() == ISD::FADD ? N0 : N1;
  if (!hasNoInfs(Options, FAdd))
    return SDValue();

  // Floating-point multiply-add without intermediate rounding.
  bool HasFMA =
      isContractableFMUL(Options, SDValue(N, 0)) &&
      TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
      (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));

  // Floating-point multiply-add with intermediate rounding. This can result
  // in a less precise result due to the changed rounding order.
  bool HasFMAD = Options.UnsafeFPMath &&
                 (LegalOperations && TLI.isFMADLegal(DAG, N));

  // No valid opcode, do not combine.
  if (!HasFMAD && !HasFMA)
    return SDValue();

  // Always prefer FMAD to FMA for precision.
  unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
  bool Aggressive = TLI.enableAggressiveFMAFusion(VT);

  // fold (fmul (fadd x0, +1.0), y) -> (fma x0, y, y)
  // fold (fmul (fadd x0, -1.0), y) -> (fma x0, y, (fneg y))
  // X is the candidate FADD operand, Y is the other FMUL operand. The FADD
  // must be single-use unless the target opted into aggressive fusion.
  auto FuseFADD = [&](SDValue X, SDValue Y) {
    if (X.getOpcode() == ISD::FADD && (Aggressive || X->hasOneUse())) {
      if (auto *C = isConstOrConstSplatFP(X.getOperand(1), true)) {
        if (C->isExactlyValue(+1.0))
          return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
                             Y);
        if (C->isExactlyValue(-1.0))
          return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
                             DAG.getNode(ISD::FNEG, SL, VT, Y));
      }
    }
    return SDValue();
  };

  // FMUL commutes, so try the FADD on either side.
  if (SDValue FMA = FuseFADD(N0, N1))
    return FMA;
  if (SDValue FMA = FuseFADD(N1, N0))
    return FMA;

  // fold (fmul (fsub +1.0, x1), y) -> (fma (fneg x1), y, y)
  // fold (fmul (fsub -1.0, x1), y) -> (fma (fneg x1), y, (fneg y))
  // fold (fmul (fsub x0, +1.0), y) -> (fma x0, y, (fneg y))
  // fold (fmul (fsub x0, -1.0), y) -> (fma x0, y, y)
  // As above, but FSUB is not commutative, so the +-1.0 constant is matched
  // on each operand separately.
  auto FuseFSUB = [&](SDValue X, SDValue Y) {
    if (X.getOpcode() == ISD::FSUB && (Aggressive || X->hasOneUse())) {
      if (auto *C0 = isConstOrConstSplatFP(X.getOperand(0), true)) {
        if (C0->isExactlyValue(+1.0))
          return DAG.getNode(PreferredFusedOpcode, SL, VT,
                             DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
                             Y);
        if (C0->isExactlyValue(-1.0))
          return DAG.getNode(PreferredFusedOpcode, SL, VT,
                             DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
                             DAG.getNode(ISD::FNEG, SL, VT, Y));
      }
      if (auto *C1 = isConstOrConstSplatFP(X.getOperand(1), true)) {
        if (C1->isExactlyValue(+1.0))
          return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
                             DAG.getNode(ISD::FNEG, SL, VT, Y));
        if (C1->isExactlyValue(-1.0))
          return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
                             Y);
      }
    }
    return SDValue();
  };

  if (SDValue FMA = FuseFSUB(N0, N1))
    return FMA;
  if (SDValue FMA = FuseFSUB(N1, N0))
    return FMA;

  return SDValue();
}
14592
/// Combine an FADD node: constant fold, canonicalize, apply algebraic
/// simplifications (most gated on fast-math flags or global unsafe-math
/// options), and finally try to form a fused multiply-add. The order of the
/// folds below is significant: earlier folds take priority.
SDValue DAGCombiner::visitFADD(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  bool N0CFP = DAG.isConstantFPBuildVectorOrConstantFP(N0);
  bool N1CFP = DAG.isConstantFPBuildVectorOrConstantFP(N1);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);
  const TargetOptions &Options = DAG.getTarget().Options;
  SDNodeFlags Flags = N->getFlags();
  // Nodes created below inherit N's fast-math flags.
  SelectionDAG::FlagInserter FlagsInserter(DAG, N);

  if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
    return R;

  // fold (fadd c1, c2) -> c1 + c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::FADD, DL, VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS
  if (N0CFP && !N1CFP)
    return DAG.getNode(ISD::FADD, DL, VT, N1, N0);

  // fold vector ops
  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

  // N0 + -0.0 --> N0 (also allowed with +0.0 and fast-math)
  // Adding -0.0 is always an identity; adding +0.0 is only an identity when
  // signed zeros can be ignored (x = -0.0 would otherwise yield +0.0).
  ConstantFPSDNode *N1C = isConstOrConstSplatFP(N1, true);
  if (N1C && N1C->isZero())
    if (N1C->isNegative() || Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())
      return N0;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // fold (fadd A, (fneg B)) -> (fsub A, B)
  if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT))
    if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
            N1, DAG, LegalOperations, ForCodeSize))
      return DAG.getNode(ISD::FSUB, DL, VT, N0, NegN1);

  // fold (fadd (fneg A), B) -> (fsub B, A)
  if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT))
    if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
            N0, DAG, LegalOperations, ForCodeSize))
      return DAG.getNode(ISD::FSUB, DL, VT, N1, NegN0);

  // Matches a single-use (fmul B, -2.0), possibly with a splat-vector
  // constant.
  auto isFMulNegTwo = [](SDValue FMul) {
    if (!FMul.hasOneUse() || FMul.getOpcode() != ISD::FMUL)
      return false;
    auto *C = isConstOrConstSplatFP(FMul.getOperand(1), true);
    return C && C->isExactlyValue(-2.0);
  };

  // fadd (fmul B, -2.0), A --> fsub A, (fadd B, B)
  if (isFMulNegTwo(N0)) {
    SDValue B = N0.getOperand(0);
    SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B);
    return DAG.getNode(ISD::FSUB, DL, VT, N1, Add);
  }
  // fadd A, (fmul B, -2.0) --> fsub A, (fadd B, B)
  if (isFMulNegTwo(N1)) {
    SDValue B = N1.getOperand(0);
    SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B);
    return DAG.getNode(ISD::FSUB, DL, VT, N0, Add);
  }

  // No FP constant should be created after legalization as Instruction
  // Selection pass has a hard time dealing with FP constants.
  bool AllowNewConst = (Level < AfterLegalizeDAG);

  // If nnan is enabled, fold lots of things.
  if ((Options.NoNaNsFPMath || Flags.hasNoNaNs()) && AllowNewConst) {
    // If allowed, fold (fadd (fneg x), x) -> 0.0
    if (N0.getOpcode() == ISD::FNEG && N0.getOperand(0) == N1)
      return DAG.getConstantFP(0.0, DL, VT);

    // If allowed, fold (fadd x, (fneg x)) -> 0.0
    if (N1.getOpcode() == ISD::FNEG && N1.getOperand(0) == N0)
      return DAG.getConstantFP(0.0, DL, VT);
  }

  // If 'unsafe math' or reassoc and nsz, fold lots of things.
  // TODO: break out portions of the transformations below for which Unsafe is
  //       considered and which do not require both nsz and reassoc
  if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
       (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
      AllowNewConst) {
    // fadd (fadd x, c1), c2 -> fadd x, c1 + c2
    if (N1CFP && N0.getOpcode() == ISD::FADD &&
        DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) {
      SDValue NewC = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1), N1);
      return DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(0), NewC);
    }

    // We can fold chains of FADD's of the same value into multiplications.
    // This transform is not safe in general because we are reducing the number
    // of rounding steps.
    if (TLI.isOperationLegalOrCustom(ISD::FMUL, VT) && !N0CFP && !N1CFP) {
      if (N0.getOpcode() == ISD::FMUL) {
        bool CFP00 = DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(0));
        bool CFP01 = DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1));

        // (fadd (fmul x, c), x) -> (fmul x, c+1)
        if (CFP01 && !CFP00 && N0.getOperand(0) == N1) {
          SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1),
                                       DAG.getConstantFP(1.0, DL, VT));
          return DAG.getNode(ISD::FMUL, DL, VT, N1, NewCFP);
        }

        // (fadd (fmul x, c), (fadd x, x)) -> (fmul x, c+2)
        if (CFP01 && !CFP00 && N1.getOpcode() == ISD::FADD &&
            N1.getOperand(0) == N1.getOperand(1) &&
            N0.getOperand(0) == N1.getOperand(0)) {
          SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1),
                                       DAG.getConstantFP(2.0, DL, VT));
          return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), NewCFP);
        }
      }

      if (N1.getOpcode() == ISD::FMUL) {
        bool CFP10 = DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(0));
        bool CFP11 = DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(1));

        // (fadd x, (fmul x, c)) -> (fmul x, c+1)
        if (CFP11 && !CFP10 && N1.getOperand(0) == N0) {
          SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N1.getOperand(1),
                                       DAG.getConstantFP(1.0, DL, VT));
          return DAG.getNode(ISD::FMUL, DL, VT, N0, NewCFP);
        }

        // (fadd (fadd x, x), (fmul x, c)) -> (fmul x, c+2)
        if (CFP11 && !CFP10 && N0.getOpcode() == ISD::FADD &&
            N0.getOperand(0) == N0.getOperand(1) &&
            N1.getOperand(0) == N0.getOperand(0)) {
          SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N1.getOperand(1),
                                       DAG.getConstantFP(2.0, DL, VT));
          return DAG.getNode(ISD::FMUL, DL, VT, N1.getOperand(0), NewCFP);
        }
      }

      if (N0.getOpcode() == ISD::FADD) {
        bool CFP00 = DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(0));
        // (fadd (fadd x, x), x) -> (fmul x, 3.0)
        if (!CFP00 && N0.getOperand(0) == N0.getOperand(1) &&
            (N0.getOperand(0) == N1)) {
          return DAG.getNode(ISD::FMUL, DL, VT, N1,
                             DAG.getConstantFP(3.0, DL, VT));
        }
      }

      if (N1.getOpcode() == ISD::FADD) {
        bool CFP10 = DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(0));
        // (fadd x, (fadd x, x)) -> (fmul x, 3.0)
        if (!CFP10 && N1.getOperand(0) == N1.getOperand(1) &&
            N1.getOperand(0) == N0) {
          return DAG.getNode(ISD::FMUL, DL, VT, N0,
                             DAG.getConstantFP(3.0, DL, VT));
        }
      }

      // (fadd (fadd x, x), (fadd x, x)) -> (fmul x, 4.0)
      if (N0.getOpcode() == ISD::FADD && N1.getOpcode() == ISD::FADD &&
          N0.getOperand(0) == N0.getOperand(1) &&
          N1.getOperand(0) == N1.getOperand(1) &&
          N0.getOperand(0) == N1.getOperand(0)) {
        return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0),
                           DAG.getConstantFP(4.0, DL, VT));
      }
    }
  } // enable-unsafe-fp-math

  // FADD -> FMA combines:
  if (SDValue Fused = visitFADDForFMACombine(N)) {
    AddToWorklist(Fused.getNode());
    return Fused;
  }
  return SDValue();
}
14773
visitSTRICT_FADD(SDNode * N)14774 SDValue DAGCombiner::visitSTRICT_FADD(SDNode *N) {
14775 SDValue Chain = N->getOperand(0);
14776 SDValue N0 = N->getOperand(1);
14777 SDValue N1 = N->getOperand(2);
14778 EVT VT = N->getValueType(0);
14779 EVT ChainVT = N->getValueType(1);
14780 SDLoc DL(N);
14781 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
14782
14783 // fold (strict_fadd A, (fneg B)) -> (strict_fsub A, B)
14784 if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::STRICT_FSUB, VT))
14785 if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
14786 N1, DAG, LegalOperations, ForCodeSize)) {
14787 return DAG.getNode(ISD::STRICT_FSUB, DL, DAG.getVTList(VT, ChainVT),
14788 {Chain, N0, NegN1});
14789 }
14790
14791 // fold (strict_fadd (fneg A), B) -> (strict_fsub B, A)
14792 if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::STRICT_FSUB, VT))
14793 if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
14794 N0, DAG, LegalOperations, ForCodeSize)) {
14795 return DAG.getNode(ISD::STRICT_FSUB, DL, DAG.getVTList(VT, ChainVT),
14796 {Chain, N1, NegN0});
14797 }
14798 return SDValue();
14799 }
14800
/// Combine an FSUB node: constant fold, apply algebraic identities (gated on
/// fast-math flags / unsafe-math options where needed), and finally try to
/// form a fused multiply-add. Fold order is significant.
SDValue DAGCombiner::visitFSUB(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, true);
  ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);
  const TargetOptions &Options = DAG.getTarget().Options;
  const SDNodeFlags Flags = N->getFlags();
  // Nodes created below inherit N's fast-math flags.
  SelectionDAG::FlagInserter FlagsInserter(DAG, N);

  if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
    return R;

  // fold (fsub c1, c2) -> c1-c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::FSUB, DL, VT, {N0, N1}))
    return C;

  // fold vector ops
  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // (fsub A, 0) -> A
  // Subtracting +0.0 is always an identity; subtracting -0.0 is only one when
  // signed zeros can be ignored (A = -0.0 would otherwise yield +0.0).
  if (N1CFP && N1CFP->isZero()) {
    if (!N1CFP->isNegative() || Options.NoSignedZerosFPMath ||
        Flags.hasNoSignedZeros()) {
      return N0;
    }
  }

  if (N0 == N1) {
    // (fsub x, x) -> 0.0
    // Only valid under nnan: NaN - NaN is NaN, not 0.0.
    if (Options.NoNaNsFPMath || Flags.hasNoNaNs())
      return DAG.getConstantFP(0.0f, DL, VT);
  }

  // (fsub -0.0, N1) -> -N1
  if (N0CFP && N0CFP->isZero()) {
    if (N0CFP->isNegative() ||
        (Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())) {
      // We cannot replace an FSUB(+-0.0,X) with FNEG(X) when denormals are
      // flushed to zero, unless all users treat denorms as zero (DAZ).
      // FIXME: This transform will change the sign of a NaN and the behavior
      // of a signaling NaN. It is only valid when a NoNaN flag is present.
      DenormalMode DenormMode = DAG.getDenormalMode(VT);
      if (DenormMode == DenormalMode::getIEEE()) {
        // Prefer a cheaper negated form of N1 if one exists; otherwise fall
        // back to an explicit FNEG node when that is legal.
        if (SDValue NegN1 =
                TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize))
          return NegN1;
        if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))
          return DAG.getNode(ISD::FNEG, DL, VT, N1);
      }
    }
  }

  // Reassociation folds; need both nsz and reassoc (or the global unsafe
  // equivalents) because they drop a term and can change the sign of zero.
  if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
       (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
      N1.getOpcode() == ISD::FADD) {
    // X - (X + Y) -> -Y
    if (N0 == N1->getOperand(0))
      return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(1));
    // X - (Y + X) -> -Y
    if (N0 == N1->getOperand(1))
      return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(0));
  }

  // fold (fsub A, (fneg B)) -> (fadd A, B)
  if (SDValue NegN1 =
          TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize))
    return DAG.getNode(ISD::FADD, DL, VT, N0, NegN1);

  // FSUB -> FMA combines:
  if (SDValue Fused = visitFSUBForFMACombine(N)) {
    AddToWorklist(Fused.getNode());
    return Fused;
  }

  return SDValue();
}
14884
/// Combine an FMUL node: constant fold, canonicalize, apply algebraic
/// simplifications (several gated on fast-math flags), and finally try the
/// FMA distributive combine. Fold order is significant.
SDValue DAGCombiner::visitFMUL(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);
  const TargetOptions &Options = DAG.getTarget().Options;
  const SDNodeFlags Flags = N->getFlags();
  // Nodes created below inherit N's fast-math flags.
  SelectionDAG::FlagInserter FlagsInserter(DAG, N);

  if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
    return R;

  // fold (fmul c1, c2) -> c1*c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::FMUL, DL, VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS
  if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
      !DAG.isConstantFPBuildVectorOrConstantFP(N1))
    return DAG.getNode(ISD::FMUL, DL, VT, N1, N0);

  // fold vector ops
  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // Reassociation-dependent folds.
  if (Options.UnsafeFPMath || Flags.hasAllowReassociation()) {
    // fmul (fmul X, C1), C2 -> fmul X, C1 * C2
    if (DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
        N0.getOpcode() == ISD::FMUL) {
      SDValue N00 = N0.getOperand(0);
      SDValue N01 = N0.getOperand(1);
      // Avoid an infinite loop by making sure that N00 is not a constant
      // (the inner multiply has not been constant folded yet).
      if (DAG.isConstantFPBuildVectorOrConstantFP(N01) &&
          !DAG.isConstantFPBuildVectorOrConstantFP(N00)) {
        SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, N01, N1);
        return DAG.getNode(ISD::FMUL, DL, VT, N00, MulConsts);
      }
    }

    // Match a special-case: we convert X * 2.0 into fadd.
    // fmul (fadd X, X), C -> fmul X, 2.0 * C
    if (N0.getOpcode() == ISD::FADD && N0.hasOneUse() &&
        N0.getOperand(0) == N0.getOperand(1)) {
      const SDValue Two = DAG.getConstantFP(2.0, DL, VT);
      SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, Two, N1);
      return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), MulConsts);
    }
  }

  // fold (fmul X, 2.0) -> (fadd X, X)
  if (N1CFP && N1CFP->isExactlyValue(+2.0))
    return DAG.getNode(ISD::FADD, DL, VT, N0, N0);

  // fold (fmul X, -1.0) -> (fsub -0.0, X)
  if (N1CFP && N1CFP->isExactlyValue(-1.0)) {
    if (!LegalOperations || TLI.isOperationLegal(ISD::FSUB, VT)) {
      return DAG.getNode(ISD::FSUB, DL, VT,
                         DAG.getConstantFP(-0.0, DL, VT), N0, Flags);
    }
  }

  // -N0 * -N1 --> N0 * N1
  // Only profitable when at least one negation is strictly cheaper than the
  // original operand; otherwise this would just shuffle FNEGs around.
  TargetLowering::NegatibleCost CostN0 =
      TargetLowering::NegatibleCost::Expensive;
  TargetLowering::NegatibleCost CostN1 =
      TargetLowering::NegatibleCost::Expensive;
  SDValue NegN0 =
      TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
  SDValue NegN1 =
      TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
  if (NegN0 && NegN1 &&
      (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
       CostN1 == TargetLowering::NegatibleCost::Cheaper))
    return DAG.getNode(ISD::FMUL, DL, VT, NegN0, NegN1);

  // fold (fmul X, (select (fcmp X > 0.0), -1.0, 1.0)) -> (fneg (fabs X))
  // fold (fmul X, (select (fcmp X > 0.0), 1.0, -1.0)) -> (fabs X)
  if (Flags.hasNoNaNs() && Flags.hasNoSignedZeros() &&
      (N0.getOpcode() == ISD::SELECT || N1.getOpcode() == ISD::SELECT) &&
      TLI.isOperationLegal(ISD::FABS, VT)) {
    SDValue Select = N0, X = N1;
    if (Select.getOpcode() != ISD::SELECT)
      std::swap(Select, X);

    SDValue Cond = Select.getOperand(0);
    auto TrueOpnd = dyn_cast<ConstantFPSDNode>(Select.getOperand(1));
    auto FalseOpnd = dyn_cast<ConstantFPSDNode>(Select.getOperand(2));

    // The select must compare X itself against 0.0.
    if (TrueOpnd && FalseOpnd &&
        Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == X &&
        isa<ConstantFPSDNode>(Cond.getOperand(1)) &&
        cast<ConstantFPSDNode>(Cond.getOperand(1))->isExactlyValue(0.0)) {
      ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
      switch (CC) {
      default: break;
      // For less-than comparisons, swap the select arms so the logic below
      // can treat everything as a greater-than comparison.
      case ISD::SETOLT:
      case ISD::SETULT:
      case ISD::SETOLE:
      case ISD::SETULE:
      case ISD::SETLT:
      case ISD::SETLE:
        std::swap(TrueOpnd, FalseOpnd);
        LLVM_FALLTHROUGH;
      case ISD::SETOGT:
      case ISD::SETUGT:
      case ISD::SETOGE:
      case ISD::SETUGE:
      case ISD::SETGT:
      case ISD::SETGE:
        if (TrueOpnd->isExactlyValue(-1.0) && FalseOpnd->isExactlyValue(1.0) &&
            TLI.isOperationLegal(ISD::FNEG, VT))
          return DAG.getNode(ISD::FNEG, DL, VT,
                             DAG.getNode(ISD::FABS, DL, VT, X));
        if (TrueOpnd->isExactlyValue(1.0) && FalseOpnd->isExactlyValue(-1.0))
          return DAG.getNode(ISD::FABS, DL, VT, X);

        break;
      }
    }
  }

  // FMUL -> FMA combines:
  if (SDValue Fused = visitFMULForFMADistributiveCombine(N)) {
    AddToWorklist(Fused.getNode());
    return Fused;
  }

  return SDValue();
}
15020
/// Combine an FMA node: constant fold, canonicalize, and simplify using
/// algebraic identities on the multiplicands and the addend. Fold order is
/// significant.
SDValue DAGCombiner::visitFMA(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue N2 = N->getOperand(2);
  ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
  ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);
  const TargetOptions &Options = DAG.getTarget().Options;
  // FMA nodes have flags that propagate to the created nodes.
  SelectionDAG::FlagInserter FlagsInserter(DAG, N);

  bool CanReassociate =
      Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();

  // Constant fold FMA.
  // getNode performs the actual folding when all three operands are
  // constants; this call does not create a new node in that case.
  if (isa<ConstantFPSDNode>(N0) &&
      isa<ConstantFPSDNode>(N1) &&
      isa<ConstantFPSDNode>(N2)) {
    return DAG.getNode(ISD::FMA, DL, VT, N0, N1, N2);
  }

  // (-N0 * -N1) + N2 --> (N0 * N1) + N2
  // Only profitable when at least one negation is strictly cheaper than the
  // original operand.
  TargetLowering::NegatibleCost CostN0 =
      TargetLowering::NegatibleCost::Expensive;
  TargetLowering::NegatibleCost CostN1 =
      TargetLowering::NegatibleCost::Expensive;
  SDValue NegN0 =
      TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
  SDValue NegN1 =
      TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
  if (NegN0 && NegN1 &&
      (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
       CostN1 == TargetLowering::NegatibleCost::Cheaper))
    return DAG.getNode(ISD::FMA, DL, VT, NegN0, NegN1, N2);

  // (fma 0, x, y) / (fma x, 0, y) -> y. Not safe in general: drops the sign
  // of zero and NaN propagation from the multiply.
  // FIXME: use fast math flags instead of Options.UnsafeFPMath
  if (Options.UnsafeFPMath) {
    if (N0CFP && N0CFP->isZero())
      return N2;
    if (N1CFP && N1CFP->isZero())
      return N2;
  }

  // (fma 1, x, y) / (fma x, 1, y) -> (fadd x, y)
  if (N0CFP && N0CFP->isExactlyValue(1.0))
    return DAG.getNode(ISD::FADD, SDLoc(N), VT, N1, N2);
  if (N1CFP && N1CFP->isExactlyValue(1.0))
    return DAG.getNode(ISD::FADD, SDLoc(N), VT, N0, N2);

  // Canonicalize (fma c, x, y) -> (fma x, c, y)
  if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
      !DAG.isConstantFPBuildVectorOrConstantFP(N1))
    return DAG.getNode(ISD::FMA, SDLoc(N), VT, N1, N0, N2);

  if (CanReassociate) {
    // (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2)
    if (N2.getOpcode() == ISD::FMUL && N0 == N2.getOperand(0) &&
        DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
        DAG.isConstantFPBuildVectorOrConstantFP(N2.getOperand(1))) {
      return DAG.getNode(ISD::FMUL, DL, VT, N0,
                         DAG.getNode(ISD::FADD, DL, VT, N1, N2.getOperand(1)));
    }

    // (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y)
    if (N0.getOpcode() == ISD::FMUL &&
        DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
        DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) {
      return DAG.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
                         DAG.getNode(ISD::FMUL, DL, VT, N1, N0.getOperand(1)),
                         N2);
    }
  }

  // (fma x, 1, y) -> (fadd x, y)
  // (fma x, -1, y) -> (fadd (fneg x), y)
  if (N1CFP) {
    if (N1CFP->isExactlyValue(1.0))
      return DAG.getNode(ISD::FADD, DL, VT, N0, N2);

    if (N1CFP->isExactlyValue(-1.0) &&
        (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))) {
      SDValue RHSNeg = DAG.getNode(ISD::FNEG, DL, VT, N0);
      AddToWorklist(RHSNeg.getNode());
      return DAG.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
    }

    // fma (fneg x), K, y -> fma x -K, y
    // Profitable when the negated constant is free (ConstantFP is legal) or
    // the original constant was going to be materialized anyway.
    if (N0.getOpcode() == ISD::FNEG &&
        (TLI.isOperationLegal(ISD::ConstantFP, VT) ||
         (N1.hasOneUse() && !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT,
                                              ForCodeSize)))) {
      return DAG.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
                         DAG.getNode(ISD::FNEG, DL, VT, N1), N2);
    }
  }

  if (CanReassociate) {
    // (fma x, c, x) -> (fmul x, (c+1))
    if (N1CFP && N0 == N2) {
      return DAG.getNode(
          ISD::FMUL, DL, VT, N0,
          DAG.getNode(ISD::FADD, DL, VT, N1, DAG.getConstantFP(1.0, DL, VT)));
    }

    // (fma x, c, (fneg x)) -> (fmul x, (c-1))
    if (N1CFP && N2.getOpcode() == ISD::FNEG && N2.getOperand(0) == N0) {
      return DAG.getNode(
          ISD::FMUL, DL, VT, N0,
          DAG.getNode(ISD::FADD, DL, VT, N1, DAG.getConstantFP(-1.0, DL, VT)));
    }
  }

  // fold ((fma (fneg X), Y, (fneg Z)) -> fneg (fma X, Y, Z))
  // fold ((fma X, (fneg Y), (fneg Z)) -> fneg (fma X, Y, Z))
  if (!TLI.isFNegFree(VT))
    if (SDValue Neg = TLI.getCheaperNegatedExpression(
            SDValue(N, 0), DAG, LegalOperations, ForCodeSize))
      return DAG.getNode(ISD::FNEG, DL, VT, Neg);
  return SDValue();
}
15140
15141 // Combine multiple FDIVs with the same divisor into multiple FMULs by the
15142 // reciprocal.
15143 // E.g., (a / D; b / D;) -> (recip = 1.0 / D; a * recip; b * recip)
15144 // Notice that this is not always beneficial. One reason is different targets
15145 // may have different costs for FDIV and FMUL, so sometimes the cost of two
15146 // FDIVs may be lower than the cost of one FDIV and two FMULs. Another reason
15147 // is the critical path is increased from "one FDIV" to "one FDIV + one FMUL".
SDValue DAGCombiner::combineRepeatedFPDivisors(SDNode *N) {
  // TODO: Limit this transform based on optsize/minsize - it always creates at
  // least 1 extra instruction. But the perf win may be substantial enough
  // that only minsize should restrict this.

  // The rewrite is only legal when reciprocal formation is permitted, either
  // globally (unsafe-fp-math) or by this node's 'arcp' flag. It is skipped
  // after legalization because it creates new FDIV/FMUL nodes.
  bool UnsafeMath = DAG.getTarget().Options.UnsafeFPMath;
  const SDNodeFlags Flags = N->getFlags();
  if (LegalDAG || (!UnsafeMath && !Flags.hasAllowReciprocal()))
    return SDValue();

  // Skip if current node is a reciprocal/fneg-reciprocal: rewriting
  // (+/-1.0 / D) in terms of a new (1.0 / D) would make no progress.
  SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
  ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, /* AllowUndefs */ true);
  if (N0CFP && (N0CFP->isExactlyValue(1.0) || N0CFP->isExactlyValue(-1.0)))
    return SDValue();

  // Exit early if the target does not want this transform or if there can't
  // possibly be enough uses of the divisor to make the transform worthwhile.
  unsigned MinUses = TLI.combineRepeatedFPDivisors();

  // For splat vectors, scale the number of uses by the splat factor. If we can
  // convert the division into a scalar op, that will likely be much faster.
  unsigned NumElts = 1;
  EVT VT = N->getValueType(0);
  if (VT.isVector() && DAG.isSplatValue(N1))
    NumElts = VT.getVectorMinNumElements();

  // N1->use_size() is an upper bound (some uses may not be FDIVs); the exact
  // count is re-checked below after filtering.
  if (!MinUses || (N1->use_size() * NumElts) < MinUses)
    return SDValue();

  // Find all FDIV users of the same divisor.
  // Use a set because duplicates may be present in the user list.
  SetVector<SDNode *> Users;
  for (auto *U : N1->uses()) {
    if (U->getOpcode() == ISD::FDIV && U->getOperand(1) == N1) {
      // Skip X/sqrt(X) that has not been simplified to sqrt(X) yet.
      // (visitFDIV performs that fold; let it happen first.)
      if (U->getOperand(1).getOpcode() == ISD::FSQRT &&
          U->getOperand(0) == U->getOperand(1).getOperand(0) &&
          U->getFlags().hasAllowReassociation() &&
          U->getFlags().hasNoSignedZeros())
        continue;

      // This division is eligible for optimization only if global unsafe math
      // is enabled or if this division allows reciprocal formation.
      if (UnsafeMath || U->getFlags().hasAllowReciprocal())
        Users.insert(U);
    }
  }

  // Now that we have the actual number of divisor uses, make sure it meets
  // the minimum threshold specified by the target.
  if ((Users.size() * NumElts) < MinUses)
    return SDValue();

  // Materialize one shared reciprocal of the divisor...
  SDLoc DL(N);
  SDValue FPOne = DAG.getConstantFP(1.0, DL, VT);
  SDValue Reciprocal = DAG.getNode(ISD::FDIV, DL, VT, FPOne, N1, Flags);

  // ...and rewrite every eligible division in terms of it:
  // Dividend / Divisor -> Dividend * Reciprocal
  for (auto *U : Users) {
    SDValue Dividend = U->getOperand(0);
    if (Dividend != FPOne) {
      SDValue NewNode = DAG.getNode(ISD::FMUL, SDLoc(U), VT, Dividend,
                                    Reciprocal, Flags);
      CombineTo(U, NewNode);
    } else if (U != Reciprocal.getNode()) {
      // In the absence of fast-math-flags, this user node is always the
      // same node as Reciprocal, but with FMF they may be different nodes.
      CombineTo(U, Reciprocal);
    }
  }
  return SDValue(N, 0); // N was replaced.
}
15220
SDValue DAGCombiner::visitFDIV(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);
  const TargetOptions &Options = DAG.getTarget().Options;
  SDNodeFlags Flags = N->getFlags();
  // Propagate N's fast-math flags onto any node created below.
  SelectionDAG::FlagInserter FlagsInserter(DAG, N);

  // Generic FP binop simplifications (undef/NaN operands, etc).
  if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
    return R;

  // fold (fdiv c1, c2) -> c1/c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::FDIV, DL, VT, {N0, N1}))
    return C;

  // fold vector ops
  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // Share a single reciprocal among multiple divisions by the same divisor.
  if (SDValue V = combineRepeatedFPDivisors(N))
    return V;

  // The folds below may lose precision, so they require either global unsafe
  // math or the node-level 'arcp' flag.
  if (Options.UnsafeFPMath || Flags.hasAllowReciprocal()) {
    // fold (fdiv X, c2) -> fmul X, 1/c2 if losing precision is acceptable.
    if (auto *N1CFP = dyn_cast<ConstantFPSDNode>(N1)) {
      // Compute the reciprocal 1.0 / c2.
      const APFloat &N1APF = N1CFP->getValueAPF();
      APFloat Recip(N1APF.getSemantics(), 1); // 1.0
      APFloat::opStatus st = Recip.divide(N1APF, APFloat::rmNearestTiesToEven);
      // Only do the transform if the reciprocal is a legal fp immediate that
      // isn't too nasty (eg NaN, denormal, ...).
      if ((st == APFloat::opOK || st == APFloat::opInexact) && // Not too nasty
          (!LegalOperations ||
           // FIXME: custom lowering of ConstantFP might fail (see e.g. ARM
           // backend)... we should handle this gracefully after Legalize.
           // TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT) ||
           TLI.isOperationLegal(ISD::ConstantFP, VT) ||
           TLI.isFPImmLegal(Recip, VT, ForCodeSize)))
        return DAG.getNode(ISD::FMUL, DL, VT, N0,
                           DAG.getConstantFP(Recip, DL, VT));
    }

    // If this FDIV is part of a reciprocal square root, it may be folded
    // into a target-specific square root estimate instruction.
    // X / sqrt(Y) --> X * rsqrt(Y)
    if (N1.getOpcode() == ISD::FSQRT) {
      if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0), Flags))
        return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
    } else if (N1.getOpcode() == ISD::FP_EXTEND &&
               N1.getOperand(0).getOpcode() == ISD::FSQRT) {
      // X / fpext(sqrt(Y)) --> X * fpext(rsqrt(Y))
      if (SDValue RV =
              buildRsqrtEstimate(N1.getOperand(0).getOperand(0), Flags)) {
        RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(N1), VT, RV);
        AddToWorklist(RV.getNode());
        return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
      }
    } else if (N1.getOpcode() == ISD::FP_ROUND &&
               N1.getOperand(0).getOpcode() == ISD::FSQRT) {
      // X / fpround(sqrt(Y)) --> X * fpround(rsqrt(Y))
      if (SDValue RV =
              buildRsqrtEstimate(N1.getOperand(0).getOperand(0), Flags)) {
        RV = DAG.getNode(ISD::FP_ROUND, SDLoc(N1), VT, RV, N1.getOperand(1));
        AddToWorklist(RV.getNode());
        return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
      }
    } else if (N1.getOpcode() == ISD::FMUL) {
      // Look through an FMUL. Even though this won't remove the FDIV directly,
      // it's still worthwhile to get rid of the FSQRT if possible.
      SDValue Sqrt, Y;
      if (N1.getOperand(0).getOpcode() == ISD::FSQRT) {
        Sqrt = N1.getOperand(0);
        Y = N1.getOperand(1);
      } else if (N1.getOperand(1).getOpcode() == ISD::FSQRT) {
        Sqrt = N1.getOperand(1);
        Y = N1.getOperand(0);
      }
      if (Sqrt.getNode()) {
        // If the other multiply operand is known positive, pull it into the
        // sqrt. That will eliminate the division if we convert to an estimate.
        if (Flags.hasAllowReassociation() && N1.hasOneUse() &&
            N1->getFlags().hasAllowReassociation() && Sqrt.hasOneUse()) {
          SDValue A;
          if (Y.getOpcode() == ISD::FABS && Y.hasOneUse())
            A = Y.getOperand(0);
          else if (Y == Sqrt.getOperand(0))
            A = Y;
          if (A) {
            // X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z)
            // X / (A * sqrt(A)) --> X / sqrt(A*A*A) --> X * rsqrt(A*A*A)
            SDValue AA = DAG.getNode(ISD::FMUL, DL, VT, A, A);
            SDValue AAZ =
                DAG.getNode(ISD::FMUL, DL, VT, AA, Sqrt.getOperand(0));
            if (SDValue Rsqrt = buildRsqrtEstimate(AAZ, Flags))
              return DAG.getNode(ISD::FMUL, DL, VT, N0, Rsqrt);

            // Estimate creation failed. Clean up speculatively created nodes.
            recursivelyDeleteUnusedNodes(AAZ.getNode());
          }
        }

        // We found a FSQRT, so try to make this fold:
        // X / (Y * sqrt(Z)) -> X * (rsqrt(Z) / Y)
        if (SDValue Rsqrt = buildRsqrtEstimate(Sqrt.getOperand(0), Flags)) {
          SDValue Div = DAG.getNode(ISD::FDIV, SDLoc(N1), VT, Rsqrt, Y);
          AddToWorklist(Div.getNode());
          return DAG.getNode(ISD::FMUL, DL, VT, N0, Div);
        }
      }
    }

    // Fold into a reciprocal estimate and multiply instead of a real divide.
    if (Options.NoInfsFPMath || Flags.hasNoInfs())
      if (SDValue RV = BuildDivEstimate(N0, N1, Flags))
        return RV;
  }

  // Fold X/Sqrt(X) -> Sqrt(X)
  // Requires nsz (X==0 -> 0/0 vs sqrt(0)) and reassoc.
  if ((Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros()) &&
      (Options.UnsafeFPMath || Flags.hasAllowReassociation()))
    if (N1.getOpcode() == ISD::FSQRT && N0 == N1.getOperand(0))
      return N1;

  // (fdiv (fneg X), (fneg Y)) -> (fdiv X, Y)
  // Profitable only when at least one negation is strictly cheaper to remove.
  TargetLowering::NegatibleCost CostN0 =
      TargetLowering::NegatibleCost::Expensive;
  TargetLowering::NegatibleCost CostN1 =
      TargetLowering::NegatibleCost::Expensive;
  SDValue NegN0 =
      TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
  SDValue NegN1 =
      TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
  if (NegN0 && NegN1 &&
      (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
       CostN1 == TargetLowering::NegatibleCost::Cheaper))
    return DAG.getNode(ISD::FDIV, SDLoc(N), VT, NegN0, NegN1);

  return SDValue();
}
15362
visitFREM(SDNode * N)15363 SDValue DAGCombiner::visitFREM(SDNode *N) {
15364 SDValue N0 = N->getOperand(0);
15365 SDValue N1 = N->getOperand(1);
15366 EVT VT = N->getValueType(0);
15367 SDNodeFlags Flags = N->getFlags();
15368 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
15369
15370 if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
15371 return R;
15372
15373 // fold (frem c1, c2) -> fmod(c1,c2)
15374 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FREM, SDLoc(N), VT, {N0, N1}))
15375 return C;
15376
15377 if (SDValue NewSel = foldBinOpIntoSelect(N))
15378 return NewSel;
15379
15380 return SDValue();
15381 }
15382
visitFSQRT(SDNode * N)15383 SDValue DAGCombiner::visitFSQRT(SDNode *N) {
15384 SDNodeFlags Flags = N->getFlags();
15385 const TargetOptions &Options = DAG.getTarget().Options;
15386
15387 // Require 'ninf' flag since sqrt(+Inf) = +Inf, but the estimation goes as:
15388 // sqrt(+Inf) == rsqrt(+Inf) * +Inf = 0 * +Inf = NaN
15389 if (!Flags.hasApproximateFuncs() ||
15390 (!Options.NoInfsFPMath && !Flags.hasNoInfs()))
15391 return SDValue();
15392
15393 SDValue N0 = N->getOperand(0);
15394 if (TLI.isFsqrtCheap(N0, DAG))
15395 return SDValue();
15396
15397 // FSQRT nodes have flags that propagate to the created nodes.
15398 // TODO: If this is N0/sqrt(N0), and we reach this node before trying to
15399 // transform the fdiv, we may produce a sub-optimal estimate sequence
15400 // because the reciprocal calculation may not have to filter out a
15401 // 0.0 input.
15402 return buildSqrtEstimate(N0, Flags);
15403 }
15404
15405 /// copysign(x, fp_extend(y)) -> copysign(x, y)
15406 /// copysign(x, fp_round(y)) -> copysign(x, y)
CanCombineFCOPYSIGN_EXTEND_ROUND(SDNode * N)15407 static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(SDNode *N) {
15408 SDValue N1 = N->getOperand(1);
15409 if ((N1.getOpcode() == ISD::FP_EXTEND ||
15410 N1.getOpcode() == ISD::FP_ROUND)) {
15411 EVT N1VT = N1->getValueType(0);
15412 EVT N1Op0VT = N1->getOperand(0).getValueType();
15413
15414 // Always fold no-op FP casts.
15415 if (N1VT == N1Op0VT)
15416 return true;
15417
15418 // Do not optimize out type conversion of f128 type yet.
15419 // For some targets like x86_64, configuration is changed to keep one f128
15420 // value in one SSE register, but instruction selection cannot handle
15421 // FCOPYSIGN on SSE registers yet.
15422 if (N1Op0VT == MVT::f128)
15423 return false;
15424
15425 // Avoid mismatched vector operand types, for better instruction selection.
15426 if (N1Op0VT.isVector())
15427 return false;
15428
15429 return true;
15430 }
15431 return false;
15432 }
15433
visitFCOPYSIGN(SDNode * N)15434 SDValue DAGCombiner::visitFCOPYSIGN(SDNode *N) {
15435 SDValue N0 = N->getOperand(0);
15436 SDValue N1 = N->getOperand(1);
15437 EVT VT = N->getValueType(0);
15438
15439 // fold (fcopysign c1, c2) -> fcopysign(c1,c2)
15440 if (SDValue C =
15441 DAG.FoldConstantArithmetic(ISD::FCOPYSIGN, SDLoc(N), VT, {N0, N1}))
15442 return C;
15443
15444 if (ConstantFPSDNode *N1C = isConstOrConstSplatFP(N->getOperand(1))) {
15445 const APFloat &V = N1C->getValueAPF();
15446 // copysign(x, c1) -> fabs(x) iff ispos(c1)
15447 // copysign(x, c1) -> fneg(fabs(x)) iff isneg(c1)
15448 if (!V.isNegative()) {
15449 if (!LegalOperations || TLI.isOperationLegal(ISD::FABS, VT))
15450 return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0);
15451 } else {
15452 if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))
15453 return DAG.getNode(ISD::FNEG, SDLoc(N), VT,
15454 DAG.getNode(ISD::FABS, SDLoc(N0), VT, N0));
15455 }
15456 }
15457
15458 // copysign(fabs(x), y) -> copysign(x, y)
15459 // copysign(fneg(x), y) -> copysign(x, y)
15460 // copysign(copysign(x,z), y) -> copysign(x, y)
15461 if (N0.getOpcode() == ISD::FABS || N0.getOpcode() == ISD::FNEG ||
15462 N0.getOpcode() == ISD::FCOPYSIGN)
15463 return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0.getOperand(0), N1);
15464
15465 // copysign(x, abs(y)) -> abs(x)
15466 if (N1.getOpcode() == ISD::FABS)
15467 return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0);
15468
15469 // copysign(x, copysign(y,z)) -> copysign(x, z)
15470 if (N1.getOpcode() == ISD::FCOPYSIGN)
15471 return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0, N1.getOperand(1));
15472
15473 // copysign(x, fp_extend(y)) -> copysign(x, y)
15474 // copysign(x, fp_round(y)) -> copysign(x, y)
15475 if (CanCombineFCOPYSIGN_EXTEND_ROUND(N))
15476 return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0, N1.getOperand(0));
15477
15478 return SDValue();
15479 }
15480
// Replace pow(X, C) with cheaper cbrt/sqrt expansions for the special
// exponents 1/3, 1/4, and 3/4 when the node's fast-math flags allow the
// approximation.
SDValue DAGCombiner::visitFPOW(SDNode *N) {
  // All folds below require a (splat) constant exponent.
  ConstantFPSDNode *ExponentC = isConstOrConstSplatFP(N->getOperand(1));
  if (!ExponentC)
    return SDValue();
  SelectionDAG::FlagInserter FlagsInserter(DAG, N);

  // Try to convert x ** (1/3) into cube root.
  // TODO: Handle the various flavors of long double.
  // TODO: Since we're approximating, we don't need an exact 1/3 exponent.
  //       Some range near 1/3 should be fine.
  EVT VT = N->getValueType(0);
  if ((VT == MVT::f32 && ExponentC->getValueAPF().isExactlyValue(1.0f/3.0f)) ||
      (VT == MVT::f64 && ExponentC->getValueAPF().isExactlyValue(1.0/3.0))) {
    // pow(-0.0, 1/3) = +0.0; cbrt(-0.0) = -0.0.
    // pow(-inf, 1/3) = +inf; cbrt(-inf) = -inf.
    // pow(-val, 1/3) = nan; cbrt(-val) = -num.
    // For regular numbers, rounding may cause the results to differ.
    // Therefore, we require { nsz ninf nnan afn } for this transform.
    // TODO: We could select out the special cases if we don't have nsz/ninf.
    SDNodeFlags Flags = N->getFlags();
    if (!Flags.hasNoSignedZeros() || !Flags.hasNoInfs() || !Flags.hasNoNaNs() ||
        !Flags.hasApproximateFuncs())
      return SDValue();

    // Do not create a cbrt() libcall if the target does not have it, and do not
    // turn a pow that has lowering support into a cbrt() libcall.
    if (!DAG.getLibInfo().has(LibFunc_cbrt) ||
        (!DAG.getTargetLoweringInfo().isOperationExpand(ISD::FPOW, VT) &&
         DAG.getTargetLoweringInfo().isOperationExpand(ISD::FCBRT, VT)))
      return SDValue();

    return DAG.getNode(ISD::FCBRT, SDLoc(N), VT, N->getOperand(0));
  }

  // Try to convert x ** (1/4) and x ** (3/4) into square roots.
  // x ** (1/2) is canonicalized to sqrt, so we do not bother with that case.
  // TODO: This could be extended (using a target hook) to handle smaller
  // power-of-2 fractional exponents.
  bool ExponentIs025 = ExponentC->getValueAPF().isExactlyValue(0.25);
  bool ExponentIs075 = ExponentC->getValueAPF().isExactlyValue(0.75);
  if (ExponentIs025 || ExponentIs075) {
    // pow(-0.0, 0.25) = +0.0; sqrt(sqrt(-0.0)) = -0.0.
    // pow(-inf, 0.25) = +inf; sqrt(sqrt(-inf)) = NaN.
    // pow(-0.0, 0.75) = +0.0; sqrt(-0.0) * sqrt(sqrt(-0.0)) = +0.0.
    // pow(-inf, 0.75) = +inf; sqrt(-inf) * sqrt(sqrt(-inf)) = NaN.
    // For regular numbers, rounding may cause the results to differ.
    // Therefore, we require { nsz ninf afn } for this transform.
    // TODO: We could select out the special cases if we don't have nsz/ninf.
    SDNodeFlags Flags = N->getFlags();

    // We only need no signed zeros for the 0.25 case.
    if ((!Flags.hasNoSignedZeros() && ExponentIs025) || !Flags.hasNoInfs() ||
        !Flags.hasApproximateFuncs())
      return SDValue();

    // Don't double the number of libcalls. We are trying to inline fast code.
    if (!DAG.getTargetLoweringInfo().isOperationLegalOrCustom(ISD::FSQRT, VT))
      return SDValue();

    // Assume that libcalls are the smallest code.
    // TODO: This restriction should probably be lifted for vectors.
    if (ForCodeSize)
      return SDValue();

    // pow(X, 0.25) --> sqrt(sqrt(X))
    SDLoc DL(N);
    SDValue Sqrt = DAG.getNode(ISD::FSQRT, DL, VT, N->getOperand(0));
    SDValue SqrtSqrt = DAG.getNode(ISD::FSQRT, DL, VT, Sqrt);
    if (ExponentIs025)
      return SqrtSqrt;
    // pow(X, 0.75) --> sqrt(X) * sqrt(sqrt(X))
    return DAG.getNode(ISD::FMUL, DL, VT, Sqrt, SqrtSqrt);
  }

  return SDValue();
}
15557
foldFPToIntToFP(SDNode * N,SelectionDAG & DAG,const TargetLowering & TLI)15558 static SDValue foldFPToIntToFP(SDNode *N, SelectionDAG &DAG,
15559 const TargetLowering &TLI) {
15560 // We only do this if the target has legal ftrunc. Otherwise, we'd likely be
15561 // replacing casts with a libcall. We also must be allowed to ignore -0.0
15562 // because FTRUNC will return -0.0 for (-1.0, -0.0), but using integer
15563 // conversions would return +0.0.
15564 // FIXME: We should be able to use node-level FMF here.
15565 // TODO: If strict math, should we use FABS (+ range check for signed cast)?
15566 EVT VT = N->getValueType(0);
15567 if (!TLI.isOperationLegal(ISD::FTRUNC, VT) ||
15568 !DAG.getTarget().Options.NoSignedZerosFPMath)
15569 return SDValue();
15570
15571 // fptosi/fptoui round towards zero, so converting from FP to integer and
15572 // back is the same as an 'ftrunc': [us]itofp (fpto[us]i X) --> ftrunc X
15573 SDValue N0 = N->getOperand(0);
15574 if (N->getOpcode() == ISD::SINT_TO_FP && N0.getOpcode() == ISD::FP_TO_SINT &&
15575 N0.getOperand(0).getValueType() == VT)
15576 return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0.getOperand(0));
15577
15578 if (N->getOpcode() == ISD::UINT_TO_FP && N0.getOpcode() == ISD::FP_TO_UINT &&
15579 N0.getOperand(0).getValueType() == VT)
15580 return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0.getOperand(0));
15581
15582 return SDValue();
15583 }
15584
visitSINT_TO_FP(SDNode * N)15585 SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) {
15586 SDValue N0 = N->getOperand(0);
15587 EVT VT = N->getValueType(0);
15588 EVT OpVT = N0.getValueType();
15589
15590 // [us]itofp(undef) = 0, because the result value is bounded.
15591 if (N0.isUndef())
15592 return DAG.getConstantFP(0.0, SDLoc(N), VT);
15593
15594 // fold (sint_to_fp c1) -> c1fp
15595 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
15596 // ...but only if the target supports immediate floating-point values
15597 (!LegalOperations ||
15598 TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
15599 return DAG.getNode(ISD::SINT_TO_FP, SDLoc(N), VT, N0);
15600
15601 // If the input is a legal type, and SINT_TO_FP is not legal on this target,
15602 // but UINT_TO_FP is legal on this target, try to convert.
15603 if (!hasOperation(ISD::SINT_TO_FP, OpVT) &&
15604 hasOperation(ISD::UINT_TO_FP, OpVT)) {
15605 // If the sign bit is known to be zero, we can change this to UINT_TO_FP.
15606 if (DAG.SignBitIsZero(N0))
15607 return DAG.getNode(ISD::UINT_TO_FP, SDLoc(N), VT, N0);
15608 }
15609
15610 // The next optimizations are desirable only if SELECT_CC can be lowered.
15611 // fold (sint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), -1.0, 0.0)
15612 if (N0.getOpcode() == ISD::SETCC && N0.getValueType() == MVT::i1 &&
15613 !VT.isVector() &&
15614 (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
15615 SDLoc DL(N);
15616 return DAG.getSelect(DL, VT, N0, DAG.getConstantFP(-1.0, DL, VT),
15617 DAG.getConstantFP(0.0, DL, VT));
15618 }
15619
15620 // fold (sint_to_fp (zext (setcc x, y, cc))) ->
15621 // (select (setcc x, y, cc), 1.0, 0.0)
15622 if (N0.getOpcode() == ISD::ZERO_EXTEND &&
15623 N0.getOperand(0).getOpcode() == ISD::SETCC && !VT.isVector() &&
15624 (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
15625 SDLoc DL(N);
15626 return DAG.getSelect(DL, VT, N0.getOperand(0),
15627 DAG.getConstantFP(1.0, DL, VT),
15628 DAG.getConstantFP(0.0, DL, VT));
15629 }
15630
15631 if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI))
15632 return FTrunc;
15633
15634 return SDValue();
15635 }
15636
visitUINT_TO_FP(SDNode * N)15637 SDValue DAGCombiner::visitUINT_TO_FP(SDNode *N) {
15638 SDValue N0 = N->getOperand(0);
15639 EVT VT = N->getValueType(0);
15640 EVT OpVT = N0.getValueType();
15641
15642 // [us]itofp(undef) = 0, because the result value is bounded.
15643 if (N0.isUndef())
15644 return DAG.getConstantFP(0.0, SDLoc(N), VT);
15645
15646 // fold (uint_to_fp c1) -> c1fp
15647 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
15648 // ...but only if the target supports immediate floating-point values
15649 (!LegalOperations ||
15650 TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
15651 return DAG.getNode(ISD::UINT_TO_FP, SDLoc(N), VT, N0);
15652
15653 // If the input is a legal type, and UINT_TO_FP is not legal on this target,
15654 // but SINT_TO_FP is legal on this target, try to convert.
15655 if (!hasOperation(ISD::UINT_TO_FP, OpVT) &&
15656 hasOperation(ISD::SINT_TO_FP, OpVT)) {
15657 // If the sign bit is known to be zero, we can change this to SINT_TO_FP.
15658 if (DAG.SignBitIsZero(N0))
15659 return DAG.getNode(ISD::SINT_TO_FP, SDLoc(N), VT, N0);
15660 }
15661
15662 // fold (uint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), 1.0, 0.0)
15663 if (N0.getOpcode() == ISD::SETCC && !VT.isVector() &&
15664 (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
15665 SDLoc DL(N);
15666 return DAG.getSelect(DL, VT, N0, DAG.getConstantFP(1.0, DL, VT),
15667 DAG.getConstantFP(0.0, DL, VT));
15668 }
15669
15670 if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI))
15671 return FTrunc;
15672
15673 return SDValue();
15674 }
15675
15676 // Fold (fp_to_{s/u}int ({s/u}int_to_fpx)) -> zext x, sext x, trunc x, or x
FoldIntToFPToInt(SDNode * N,SelectionDAG & DAG)15677 static SDValue FoldIntToFPToInt(SDNode *N, SelectionDAG &DAG) {
15678 SDValue N0 = N->getOperand(0);
15679 EVT VT = N->getValueType(0);
15680
15681 if (N0.getOpcode() != ISD::UINT_TO_FP && N0.getOpcode() != ISD::SINT_TO_FP)
15682 return SDValue();
15683
15684 SDValue Src = N0.getOperand(0);
15685 EVT SrcVT = Src.getValueType();
15686 bool IsInputSigned = N0.getOpcode() == ISD::SINT_TO_FP;
15687 bool IsOutputSigned = N->getOpcode() == ISD::FP_TO_SINT;
15688
15689 // We can safely assume the conversion won't overflow the output range,
15690 // because (for example) (uint8_t)18293.f is undefined behavior.
15691
15692 // Since we can assume the conversion won't overflow, our decision as to
15693 // whether the input will fit in the float should depend on the minimum
15694 // of the input range and output range.
15695
15696 // This means this is also safe for a signed input and unsigned output, since
15697 // a negative input would lead to undefined behavior.
15698 unsigned InputSize = (int)SrcVT.getScalarSizeInBits() - IsInputSigned;
15699 unsigned OutputSize = (int)VT.getScalarSizeInBits();
15700 unsigned ActualSize = std::min(InputSize, OutputSize);
15701 const fltSemantics &sem = DAG.EVTToAPFloatSemantics(N0.getValueType());
15702
15703 // We can only fold away the float conversion if the input range can be
15704 // represented exactly in the float range.
15705 if (APFloat::semanticsPrecision(sem) >= ActualSize) {
15706 if (VT.getScalarSizeInBits() > SrcVT.getScalarSizeInBits()) {
15707 unsigned ExtOp = IsInputSigned && IsOutputSigned ? ISD::SIGN_EXTEND
15708 : ISD::ZERO_EXTEND;
15709 return DAG.getNode(ExtOp, SDLoc(N), VT, Src);
15710 }
15711 if (VT.getScalarSizeInBits() < SrcVT.getScalarSizeInBits())
15712 return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Src);
15713 return DAG.getBitcast(VT, Src);
15714 }
15715 return SDValue();
15716 }
15717
visitFP_TO_SINT(SDNode * N)15718 SDValue DAGCombiner::visitFP_TO_SINT(SDNode *N) {
15719 SDValue N0 = N->getOperand(0);
15720 EVT VT = N->getValueType(0);
15721
15722 // fold (fp_to_sint undef) -> undef
15723 if (N0.isUndef())
15724 return DAG.getUNDEF(VT);
15725
15726 // fold (fp_to_sint c1fp) -> c1
15727 if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
15728 return DAG.getNode(ISD::FP_TO_SINT, SDLoc(N), VT, N0);
15729
15730 return FoldIntToFPToInt(N, DAG);
15731 }
15732
visitFP_TO_UINT(SDNode * N)15733 SDValue DAGCombiner::visitFP_TO_UINT(SDNode *N) {
15734 SDValue N0 = N->getOperand(0);
15735 EVT VT = N->getValueType(0);
15736
15737 // fold (fp_to_uint undef) -> undef
15738 if (N0.isUndef())
15739 return DAG.getUNDEF(VT);
15740
15741 // fold (fp_to_uint c1fp) -> c1
15742 if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
15743 return DAG.getNode(ISD::FP_TO_UINT, SDLoc(N), VT, N0);
15744
15745 return FoldIntToFPToInt(N, DAG);
15746 }
15747
SDValue DAGCombiner::visitFP_ROUND(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  // Operand 1 is the "trunc" flag: 1 means the round is known to be
  // value-preserving (no bits of the value are lost).
  SDValue N1 = N->getOperand(1);
  ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
  EVT VT = N->getValueType(0);

  // fold (fp_round c1fp) -> c1fp
  if (N0CFP)
    return DAG.getNode(ISD::FP_ROUND, SDLoc(N), VT, N0, N1);

  // fold (fp_round (fp_extend x)) -> x
  if (N0.getOpcode() == ISD::FP_EXTEND && VT == N0.getOperand(0).getValueType())
    return N0.getOperand(0);

  // fold (fp_round (fp_round x)) -> (fp_round x)
  if (N0.getOpcode() == ISD::FP_ROUND) {
    const bool NIsTrunc = N->getConstantOperandVal(1) == 1;
    const bool N0IsTrunc = N0.getConstantOperandVal(1) == 1;

    // Skip this folding if it results in an fp_round from f80 to f16.
    //
    // f80 to f16 always generates an expensive (and as yet, unimplemented)
    // libcall to __truncxfhf2 instead of selecting native f16 conversion
    // instructions from f32 or f64. Moreover, the first (value-preserving)
    // fp_round from f80 to either f32 or f64 may become a NOP in platforms like
    // x86.
    if (N0.getOperand(0).getValueType() == MVT::f80 && VT == MVT::f16)
      return SDValue();

    // If the first fp_round isn't a value preserving truncation, it might
    // introduce a tie in the second fp_round, that wouldn't occur in the
    // single-step fp_round we want to fold to.
    // In other words, double rounding isn't the same as rounding.
    // Also, this is a value preserving truncation iff both fp_round's are.
    if (DAG.getTarget().Options.UnsafeFPMath || N0IsTrunc) {
      SDLoc DL(N);
      // The merged round is value-preserving only if both inputs were.
      return DAG.getNode(ISD::FP_ROUND, DL, VT, N0.getOperand(0),
                         DAG.getIntPtrConstant(NIsTrunc && N0IsTrunc, DL));
    }
  }

  // fold (fp_round (copysign X, Y)) -> (copysign (fp_round X), Y)
  // Rounding the magnitude first lets the copysign operate on the narrow type.
  if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse()) {
    SDValue Tmp = DAG.getNode(ISD::FP_ROUND, SDLoc(N0), VT,
                              N0.getOperand(0), N1);
    AddToWorklist(Tmp.getNode());
    return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT,
                       Tmp, N0.getOperand(1));
  }

  if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
    return NewVSel;

  return SDValue();
}
15803
SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  // If this is fp_round(fpextend), don't fold it, allow ourselves to be folded.
  if (N->hasOneUse() &&
      N->use_begin()->getOpcode() == ISD::FP_ROUND)
    return SDValue();

  // fold (fp_extend c1fp) -> c1fp
  if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
    return DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, N0);

  // fold (fp_extend (fp16_to_fp op)) -> (fp16_to_fp op)
  // FP16_TO_FP can produce the wider type directly when it is legal there.
  if (N0.getOpcode() == ISD::FP16_TO_FP &&
      TLI.getOperationAction(ISD::FP16_TO_FP, VT) == TargetLowering::Legal)
    return DAG.getNode(ISD::FP16_TO_FP, SDLoc(N), VT, N0.getOperand(0));

  // Turn fp_extend(fp_round(X, 1)) -> x since the fp_round doesn't affect the
  // value of X (operand 1 == 1 marks the round as value-preserving).
  if (N0.getOpcode() == ISD::FP_ROUND
      && N0.getConstantOperandVal(1) == 1) {
    SDValue In = N0.getOperand(0);
    if (In.getValueType() == VT) return In;
    if (VT.bitsLT(In.getValueType()))
      return DAG.getNode(ISD::FP_ROUND, SDLoc(N), VT,
                         In, N0.getOperand(1));
    return DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, In);
  }

  // fold (fpext (load x)) -> (fpext (fptrunc (extload x)))
  // The extending load computes the wide value directly; the narrow value
  // needed by other users of the load is recovered with a value-preserving
  // fp_round of the extload result.
  if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
      TLI.isLoadExtLegalOrCustom(ISD::EXTLOAD, VT, N0.getValueType())) {
    LoadSDNode *LN0 = cast<LoadSDNode>(N0);
    SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, SDLoc(N), VT,
                                     LN0->getChain(),
                                     LN0->getBasePtr(), N0.getValueType(),
                                     LN0->getMemOperand());
    CombineTo(N, ExtLoad);
    // Replace both results of the old load: the value (with the rounded
    // extload) and the chain (with the extload's chain).
    CombineTo(N0.getNode(),
              DAG.getNode(ISD::FP_ROUND, SDLoc(N0),
                          N0.getValueType(), ExtLoad,
                          DAG.getIntPtrConstant(1, SDLoc(N0))),
              ExtLoad.getValue(1));
    return SDValue(N, 0); // Return N so it doesn't get rechecked!
  }

  if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
    return NewVSel;

  return SDValue();
}
15856
visitFCEIL(SDNode * N)15857 SDValue DAGCombiner::visitFCEIL(SDNode *N) {
15858 SDValue N0 = N->getOperand(0);
15859 EVT VT = N->getValueType(0);
15860
15861 // fold (fceil c1) -> fceil(c1)
15862 if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
15863 return DAG.getNode(ISD::FCEIL, SDLoc(N), VT, N0);
15864
15865 return SDValue();
15866 }
15867
visitFTRUNC(SDNode * N)15868 SDValue DAGCombiner::visitFTRUNC(SDNode *N) {
15869 SDValue N0 = N->getOperand(0);
15870 EVT VT = N->getValueType(0);
15871
15872 // fold (ftrunc c1) -> ftrunc(c1)
15873 if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
15874 return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0);
15875
15876 // fold ftrunc (known rounded int x) -> x
15877 // ftrunc is a part of fptosi/fptoui expansion on some targets, so this is
15878 // likely to be generated to extract integer from a rounded floating value.
15879 switch (N0.getOpcode()) {
15880 default: break;
15881 case ISD::FRINT:
15882 case ISD::FTRUNC:
15883 case ISD::FNEARBYINT:
15884 case ISD::FFLOOR:
15885 case ISD::FCEIL:
15886 return N0;
15887 }
15888
15889 return SDValue();
15890 }
15891
visitFFLOOR(SDNode * N)15892 SDValue DAGCombiner::visitFFLOOR(SDNode *N) {
15893 SDValue N0 = N->getOperand(0);
15894 EVT VT = N->getValueType(0);
15895
15896 // fold (ffloor c1) -> ffloor(c1)
15897 if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
15898 return DAG.getNode(ISD::FFLOOR, SDLoc(N), VT, N0);
15899
15900 return SDValue();
15901 }
15902
visitFNEG(SDNode * N)15903 SDValue DAGCombiner::visitFNEG(SDNode *N) {
15904 SDValue N0 = N->getOperand(0);
15905 EVT VT = N->getValueType(0);
15906 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
15907
15908 // Constant fold FNEG.
15909 if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
15910 return DAG.getNode(ISD::FNEG, SDLoc(N), VT, N0);
15911
15912 if (SDValue NegN0 =
15913 TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize))
15914 return NegN0;
15915
15916 // -(X-Y) -> (Y-X) is unsafe because when X==Y, -0.0 != +0.0
15917 // FIXME: This is duplicated in getNegatibleCost, but getNegatibleCost doesn't
15918 // know it was called from a context with a nsz flag if the input fsub does
15919 // not.
15920 if (N0.getOpcode() == ISD::FSUB &&
15921 (DAG.getTarget().Options.NoSignedZerosFPMath ||
15922 N->getFlags().hasNoSignedZeros()) && N0.hasOneUse()) {
15923 return DAG.getNode(ISD::FSUB, SDLoc(N), VT, N0.getOperand(1),
15924 N0.getOperand(0));
15925 }
15926
15927 if (SDValue Cast = foldSignChangeInBitcast(N))
15928 return Cast;
15929
15930 return SDValue();
15931 }
15932
/// Combine the four FP min/max opcodes. They differ in NaN semantics:
/// FMINIMUM/FMAXIMUM propagate a NaN operand, while FMINNUM/FMAXNUM return
/// the other (non-NaN) operand instead.
SDValue DAGCombiner::visitFMinMax(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  const SDNodeFlags Flags = N->getFlags();
  unsigned Opc = N->getOpcode();
  bool PropagatesNaN = Opc == ISD::FMINIMUM || Opc == ISD::FMAXIMUM;
  bool IsMin = Opc == ISD::FMINNUM || Opc == ISD::FMINIMUM;
  SelectionDAG::FlagInserter FlagsInserter(DAG, N);

  // Constant fold.
  if (SDValue C = DAG.FoldConstantArithmetic(Opc, SDLoc(N), VT, {N0, N1}))
    return C;

  // Canonicalize to constant on RHS.
  if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
      !DAG.isConstantFPBuildVectorOrConstantFP(N1))
    return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0);

  // The folds below only apply when the RHS is a constant (or splat constant).
  if (const ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1)) {
    const APFloat &AF = N1CFP->getValueAPF();

    // minnum(X, nan) -> X
    // maxnum(X, nan) -> X
    // minimum(X, nan) -> nan
    // maximum(X, nan) -> nan
    if (AF.isNaN())
      return PropagatesNaN ? N->getOperand(1) : N->getOperand(0);

    // In the following folds, inf can be replaced with the largest finite
    // float, if the ninf flag is set.
    if (AF.isInfinity() || (Flags.hasNoInfs() && AF.isLargest())) {
      // minnum(X, -inf) -> -inf
      // maxnum(X, +inf) -> +inf
      // minimum(X, -inf) -> -inf if nnan
      // maximum(X, +inf) -> +inf if nnan
      // (The nnan requirement is because a NaN X would win for the
      // NaN-propagating opcodes.)
      if (IsMin == AF.isNegative() && (!PropagatesNaN || Flags.hasNoNaNs()))
        return N->getOperand(1);

      // minnum(X, +inf) -> X if nnan
      // maxnum(X, -inf) -> X if nnan
      // minimum(X, +inf) -> X
      // maximum(X, -inf) -> X
      if (IsMin != AF.isNegative() && (PropagatesNaN || Flags.hasNoNaNs()))
        return N->getOperand(0);
    }
  }

  return SDValue();
}
15983
visitFABS(SDNode * N)15984 SDValue DAGCombiner::visitFABS(SDNode *N) {
15985 SDValue N0 = N->getOperand(0);
15986 EVT VT = N->getValueType(0);
15987
15988 // fold (fabs c1) -> fabs(c1)
15989 if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
15990 return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0);
15991
15992 // fold (fabs (fabs x)) -> (fabs x)
15993 if (N0.getOpcode() == ISD::FABS)
15994 return N->getOperand(0);
15995
15996 // fold (fabs (fneg x)) -> (fabs x)
15997 // fold (fabs (fcopysign x, y)) -> (fabs x)
15998 if (N0.getOpcode() == ISD::FNEG || N0.getOpcode() == ISD::FCOPYSIGN)
15999 return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0.getOperand(0));
16000
16001 if (SDValue Cast = foldSignChangeInBitcast(N))
16002 return Cast;
16003
16004 return SDValue();
16005 }
16006
SDValue DAGCombiner::visitBRCOND(SDNode *N) {
  SDValue Chain = N->getOperand(0); // Incoming chain.
  SDValue N1 = N->getOperand(1);    // Condition.
  SDValue N2 = N->getOperand(2);    // Destination basic block.

  // BRCOND(FREEZE(cond)) is equivalent to BRCOND(cond) (both are
  // nondeterministic jumps).
  if (N1->getOpcode() == ISD::FREEZE && N1.hasOneUse()) {
    return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other, Chain,
                       N1->getOperand(0), N2);
  }

  // If N is a constant we could fold this into a fallthrough or unconditional
  // branch. However that doesn't happen very often in normal code, because
  // Instcombine/SimplifyCFG should have handled the available opportunities.
  // If we did this folding here, it would be necessary to update the
  // MachineBasicBlock CFG, which is awkward.

  // fold a brcond with a setcc condition into a BR_CC node if BR_CC is legal
  // on the target.
  if (N1.getOpcode() == ISD::SETCC &&
      TLI.isOperationLegalOrCustom(ISD::BR_CC,
                                   N1.getOperand(0).getValueType())) {
    return DAG.getNode(ISD::BR_CC, SDLoc(N), MVT::Other,
                       Chain, N1.getOperand(2),
                       N1.getOperand(0), N1.getOperand(1), N2);
  }

  if (N1.hasOneUse()) {
    // rebuildSetCC calls visitXor which may change the Chain when there is a
    // STRICT_FSETCC/STRICT_FSETCCS involved. Use a handle to track changes.
    HandleSDNode ChainHandle(Chain);
    if (SDValue NewN1 = rebuildSetCC(N1))
      return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other,
                         ChainHandle.getValue(), NewN1, N2);
  }

  return SDValue();
}
16046
/// Try to rewrite a brcond condition \p N into an explicit SETCC so that the
/// branch can later be matched as a compare-and-branch. Returns the new
/// condition value, or an empty SDValue if no rewrite applies.
SDValue DAGCombiner::rebuildSetCC(SDValue N) {
  if (N.getOpcode() == ISD::SRL ||
      (N.getOpcode() == ISD::TRUNCATE &&
       (N.getOperand(0).hasOneUse() &&
        N.getOperand(0).getOpcode() == ISD::SRL))) {
    // Look past the truncate.
    if (N.getOpcode() == ISD::TRUNCATE)
      N = N.getOperand(0);

    // Match this pattern so that we can generate simpler code:
    //
    //   %a = ...
    //   %b = and i32 %a, 2
    //   %c = srl i32 %b, 1
    //   brcond i32 %c ...
    //
    // into
    //
    //   %a = ...
    //   %b = and i32 %a, 2
    //   %c = setcc eq %b, 0
    //   brcond %c ...
    //
    // This applies only when the AND constant value has one bit set and the
    // SRL constant is equal to the log2 of the AND constant. The back-end is
    // smart enough to convert the result into a TEST/JMP sequence.
    SDValue Op0 = N.getOperand(0);
    SDValue Op1 = N.getOperand(1);

    if (Op0.getOpcode() == ISD::AND && Op1.getOpcode() == ISD::Constant) {
      SDValue AndOp1 = Op0.getOperand(1);

      if (AndOp1.getOpcode() == ISD::Constant) {
        const APInt &AndConst = cast<ConstantSDNode>(AndOp1)->getAPIntValue();

        if (AndConst.isPowerOf2() &&
            cast<ConstantSDNode>(Op1)->getAPIntValue() == AndConst.logBase2()) {
          SDLoc DL(N);
          // The shifted-out bit is the only bit that can be set, so test
          // the AND result directly against zero.
          return DAG.getSetCC(DL, getSetCCResultType(Op0.getValueType()),
                              Op0, DAG.getConstant(0, DL, Op0.getValueType()),
                              ISD::SETNE);
        }
      }
    }
  }

  // Transform (brcond (xor x, y)) -> (brcond (setcc, x, y, ne))
  // Transform (brcond (xor (xor x, y), -1)) -> (brcond (setcc, x, y, eq))
  if (N.getOpcode() == ISD::XOR) {
    // Because we may call this on a speculatively constructed
    // SimplifiedSetCC Node, we need to simplify this node first.
    // Ideally this should be folded into SimplifySetCC and not
    // here. For now, grab a handle to N so we don't lose it from
    // replacements internal to the visit.
    HandleSDNode XORHandle(N);
    while (N.getOpcode() == ISD::XOR) {
      SDValue Tmp = visitXOR(N.getNode());
      // No simplification done.
      if (!Tmp.getNode())
        break;
      // Returning N signals an in-visit replacement that may have
      // invalidated N. Grab the value from the handle instead.
      if (Tmp.getNode() == N.getNode())
        N = XORHandle.getValue();
      else // Node simplified. Try simplifying again.
        N = Tmp;
    }

    // visitXOR may have simplified the XOR away entirely; if so, the result
    // is already the rewritten condition.
    if (N.getOpcode() != ISD::XOR)
      return N;

    SDValue Op0 = N->getOperand(0);
    SDValue Op1 = N->getOperand(1);

    if (Op0.getOpcode() != ISD::SETCC && Op1.getOpcode() != ISD::SETCC) {
      bool Equal = false;
      // (brcond (xor (xor x, y), -1)) -> (brcond (setcc x, y, eq))
      if (isBitwiseNot(N) && Op0.hasOneUse() && Op0.getOpcode() == ISD::XOR &&
          Op0.getValueType() == MVT::i1) {
        N = Op0;
        Op0 = N->getOperand(0);
        Op1 = N->getOperand(1);
        Equal = true;
      }

      EVT SetCCVT = N.getValueType();
      if (LegalTypes)
        SetCCVT = getSetCCResultType(SetCCVT);
      // Replace the uses of XOR with SETCC
      return DAG.getSetCC(SDLoc(N), SetCCVT, Op0, Op1,
                          Equal ? ISD::SETEQ : ISD::SETNE);
    }
  }

  return SDValue();
}
16143
16144 // Operand List for BR_CC: Chain, CondCC, CondLHS, CondRHS, DestBB.
16145 //
visitBR_CC(SDNode * N)16146 SDValue DAGCombiner::visitBR_CC(SDNode *N) {
16147 CondCodeSDNode *CC = cast<CondCodeSDNode>(N->getOperand(1));
16148 SDValue CondLHS = N->getOperand(2), CondRHS = N->getOperand(3);
16149
16150 // If N is a constant we could fold this into a fallthrough or unconditional
16151 // branch. However that doesn't happen very often in normal code, because
16152 // Instcombine/SimplifyCFG should have handled the available opportunities.
16153 // If we did this folding here, it would be necessary to update the
16154 // MachineBasicBlock CFG, which is awkward.
16155
16156 // Use SimplifySetCC to simplify SETCC's.
16157 SDValue Simp = SimplifySetCC(getSetCCResultType(CondLHS.getValueType()),
16158 CondLHS, CondRHS, CC->get(), SDLoc(N),
16159 false);
16160 if (Simp.getNode()) AddToWorklist(Simp.getNode());
16161
16162 // fold to a simpler setcc
16163 if (Simp.getNode() && Simp.getOpcode() == ISD::SETCC)
16164 return DAG.getNode(ISD::BR_CC, SDLoc(N), MVT::Other,
16165 N->getOperand(0), Simp.getOperand(2),
16166 Simp.getOperand(0), Simp.getOperand(1),
16167 N->getOperand(4));
16168
16169 return SDValue();
16170 }
16171
getCombineLoadStoreParts(SDNode * N,unsigned Inc,unsigned Dec,bool & IsLoad,bool & IsMasked,SDValue & Ptr,const TargetLowering & TLI)16172 static bool getCombineLoadStoreParts(SDNode *N, unsigned Inc, unsigned Dec,
16173 bool &IsLoad, bool &IsMasked, SDValue &Ptr,
16174 const TargetLowering &TLI) {
16175 if (LoadSDNode *LD = dyn_cast<LoadSDNode>(N)) {
16176 if (LD->isIndexed())
16177 return false;
16178 EVT VT = LD->getMemoryVT();
16179 if (!TLI.isIndexedLoadLegal(Inc, VT) && !TLI.isIndexedLoadLegal(Dec, VT))
16180 return false;
16181 Ptr = LD->getBasePtr();
16182 } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(N)) {
16183 if (ST->isIndexed())
16184 return false;
16185 EVT VT = ST->getMemoryVT();
16186 if (!TLI.isIndexedStoreLegal(Inc, VT) && !TLI.isIndexedStoreLegal(Dec, VT))
16187 return false;
16188 Ptr = ST->getBasePtr();
16189 IsLoad = false;
16190 } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(N)) {
16191 if (LD->isIndexed())
16192 return false;
16193 EVT VT = LD->getMemoryVT();
16194 if (!TLI.isIndexedMaskedLoadLegal(Inc, VT) &&
16195 !TLI.isIndexedMaskedLoadLegal(Dec, VT))
16196 return false;
16197 Ptr = LD->getBasePtr();
16198 IsMasked = true;
16199 } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(N)) {
16200 if (ST->isIndexed())
16201 return false;
16202 EVT VT = ST->getMemoryVT();
16203 if (!TLI.isIndexedMaskedStoreLegal(Inc, VT) &&
16204 !TLI.isIndexedMaskedStoreLegal(Dec, VT))
16205 return false;
16206 Ptr = ST->getBasePtr();
16207 IsLoad = false;
16208 IsMasked = true;
16209 } else {
16210 return false;
16211 }
16212 return true;
16213 }
16214
16215 /// Try turning a load/store into a pre-indexed load/store when the base
16216 /// pointer is an add or subtract and it has other uses besides the load/store.
16217 /// After the transformation, the new indexed load/store has effectively folded
16218 /// the add/subtract in and all of its other uses are redirected to the
16219 /// new load/store.
bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) {
  // Indexed forms are only introduced once the DAG is fully legalized.
  if (Level < AfterLegalizeDAG)
    return false;

  bool IsLoad = true;
  bool IsMasked = false;
  SDValue Ptr;
  if (!getCombineLoadStoreParts(N, ISD::PRE_INC, ISD::PRE_DEC, IsLoad, IsMasked,
                                Ptr, TLI))
    return false;

  // If the pointer is not an add/sub, or if it doesn't have multiple uses, bail
  // out. There is no reason to make this a preinc/predec.
  if ((Ptr.getOpcode() != ISD::ADD && Ptr.getOpcode() != ISD::SUB) ||
      Ptr->hasOneUse())
    return false;

  // Ask the target to do addressing mode selection.
  SDValue BasePtr;
  SDValue Offset;
  ISD::MemIndexedMode AM = ISD::UNINDEXED;
  if (!TLI.getPreIndexedAddressParts(N, BasePtr, Offset, AM, DAG))
    return false;

  // Backends without true r+i pre-indexed forms may need to pass a
  // constant base with a variable offset so that constant coercion
  // will work with the patterns in canonical form.
  bool Swapped = false;
  if (isa<ConstantSDNode>(BasePtr)) {
    std::swap(BasePtr, Offset);
    Swapped = true;
  }

  // Don't create a indexed load / store with zero offset.
  if (isNullConstant(Offset))
    return false;

  // Try turning it into a pre-indexed load / store except when:
  // 1) The new base ptr is a frame index.
  // 2) If N is a store and the new base ptr is either the same as or is a
  //    predecessor of the value being stored.
  // 3) Another use of old base ptr is a predecessor of N. If ptr is folded
  //    that would create a cycle.
  // 4) All uses are load / store ops that use it as old base ptr.

  // Check #1.  Preinc'ing a frame index would require copying the stack pointer
  // (plus the implicit offset) to a register to preinc anyway.
  if (isa<FrameIndexSDNode>(BasePtr) || isa<RegisterSDNode>(BasePtr))
    return false;

  // Check #2.
  if (!IsLoad) {
    SDValue Val = IsMasked ? cast<MaskedStoreSDNode>(N)->getValue()
                           : cast<StoreSDNode>(N)->getValue();

    // Would require a copy.
    if (Val == BasePtr)
      return false;

    // Would create a cycle.
    if (Val == Ptr || Ptr->isPredecessorOf(Val.getNode()))
      return false;
  }

  // Caches for hasPredecessorHelper. Shared across the calls below so the
  // graph is only walked once overall.
  SmallPtrSet<const SDNode *, 32> Visited;
  SmallVector<const SDNode *, 16> Worklist;
  Worklist.push_back(N);

  // If the offset is a constant, there may be other adds of constants that
  // can be folded with this one. We should do this to avoid having to keep
  // a copy of the original base pointer.
  SmallVector<SDNode *, 16> OtherUses;
  if (isa<ConstantSDNode>(Offset))
    for (SDNode::use_iterator UI = BasePtr->use_begin(),
                              UE = BasePtr->use_end();
         UI != UE; ++UI) {
      SDUse &Use = UI.getUse();
      // Skip the use that is Ptr and uses of other results from BasePtr's
      // node (important for nodes that return multiple results).
      if (Use.getUser() == Ptr.getNode() || Use != BasePtr)
        continue;

      if (SDNode::hasPredecessorHelper(Use.getUser(), Visited, Worklist))
        continue;

      // Only plain constant add/sub uses can be rewritten; any other use
      // makes the whole rewrite unprofitable, so give up on all of them.
      if (Use.getUser()->getOpcode() != ISD::ADD &&
          Use.getUser()->getOpcode() != ISD::SUB) {
        OtherUses.clear();
        break;
      }

      SDValue Op1 = Use.getUser()->getOperand((UI.getOperandNo() + 1) & 1);
      if (!isa<ConstantSDNode>(Op1)) {
        OtherUses.clear();
        break;
      }

      // FIXME: In some cases, we can be smarter about this.
      if (Op1.getValueType() != Offset.getValueType()) {
        OtherUses.clear();
        break;
      }

      OtherUses.push_back(Use.getUser());
    }

  // Restore the operand order the target selected before building the
  // indexed node.
  if (Swapped)
    std::swap(BasePtr, Offset);

  // Now check for #3 and #4.
  bool RealUse = false;

  for (SDNode *Use : Ptr->uses()) {
    if (Use == N)
      continue;
    if (SDNode::hasPredecessorHelper(Use, Visited, Worklist))
      return false;

    // If Ptr may be folded in addressing mode of other use, then it's
    // not profitable to do this transformation.
    if (!canFoldInAddressingMode(Ptr.getNode(), Use, DAG, TLI))
      RealUse = true;
  }

  if (!RealUse)
    return false;

  SDValue Result;
  if (!IsMasked) {
    if (IsLoad)
      Result = DAG.getIndexedLoad(SDValue(N, 0), SDLoc(N), BasePtr, Offset, AM);
    else
      Result =
          DAG.getIndexedStore(SDValue(N, 0), SDLoc(N), BasePtr, Offset, AM);
  } else {
    if (IsLoad)
      Result = DAG.getIndexedMaskedLoad(SDValue(N, 0), SDLoc(N), BasePtr,
                                        Offset, AM);
    else
      Result = DAG.getIndexedMaskedStore(SDValue(N, 0), SDLoc(N), BasePtr,
                                         Offset, AM);
  }
  ++PreIndexedNodes;
  ++NodesCombined;
  LLVM_DEBUG(dbgs() << "\nReplacing.4 "; N->dump(&DAG); dbgs() << "\nWith: ";
             Result.dump(&DAG); dbgs() << '\n');
  WorklistRemover DeadNodes(*this);
  // Indexed load results are (value, new base pointer, chain); indexed store
  // results are (new base pointer, chain).
  if (IsLoad) {
    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0));
    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Result.getValue(2));
  } else {
    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(1));
  }

  // Finally, since the node is now dead, remove it from the graph.
  deleteAndRecombine(N);

  // Swap again so that Offset is the constant when rewriting the other uses
  // below.
  if (Swapped)
    std::swap(BasePtr, Offset);

  // Replace other uses of BasePtr that can be updated to use Ptr
  for (unsigned i = 0, e = OtherUses.size(); i != e; ++i) {
    unsigned OffsetIdx = 1;
    if (OtherUses[i]->getOperand(OffsetIdx).getNode() == BasePtr.getNode())
      OffsetIdx = 0;
    assert(OtherUses[i]->getOperand(!OffsetIdx).getNode() ==
           BasePtr.getNode() && "Expected BasePtr operand");

    // We need to replace ptr0 in the following expression:
    //   x0 * offset0 + y0 * ptr0 = t0
    // knowing that
    //   x1 * offset1 + y1 * ptr0 = t1 (the indexed load/store)
    //
    // where x0, x1, y0 and y1 in {-1, 1} are given by the types of the
    // indexed load/store and the expression that needs to be re-written.
    //
    // Therefore, we have:
    //   t0 = (x0 * offset0 - x1 * y0 * y1 *offset1) + (y0 * y1) * t1

    auto *CN = cast<ConstantSDNode>(OtherUses[i]->getOperand(OffsetIdx));
    const APInt &Offset0 = CN->getAPIntValue();
    const APInt &Offset1 = cast<ConstantSDNode>(Offset)->getAPIntValue();
    int X0 = (OtherUses[i]->getOpcode() == ISD::SUB && OffsetIdx == 1) ? -1 : 1;
    int Y0 = (OtherUses[i]->getOpcode() == ISD::SUB && OffsetIdx == 0) ? -1 : 1;
    int X1 = (AM == ISD::PRE_DEC && !Swapped) ? -1 : 1;
    int Y1 = (AM == ISD::PRE_DEC && Swapped) ? -1 : 1;

    unsigned Opcode = (Y0 * Y1 < 0) ? ISD::SUB : ISD::ADD;

    APInt CNV = Offset0;
    if (X0 < 0) CNV = -CNV;
    if (X1 * Y0 * Y1 < 0) CNV = CNV + Offset1;
    else CNV = CNV - Offset1;

    SDLoc DL(OtherUses[i]);

    // We can now generate the new expression.
    SDValue NewOp1 = DAG.getConstant(CNV, DL, CN->getValueType(0));
    SDValue NewOp2 = Result.getValue(IsLoad ? 1 : 0);

    SDValue NewUse = DAG.getNode(Opcode,
                                 DL,
                                 OtherUses[i]->getValueType(0), NewOp1, NewOp2);
    DAG.ReplaceAllUsesOfValueWith(SDValue(OtherUses[i], 0), NewUse);
    deleteAndRecombine(OtherUses[i]);
  }

  // Replace the uses of Ptr with uses of the updated base value.
  DAG.ReplaceAllUsesOfValueWith(Ptr, Result.getValue(IsLoad ? 1 : 0));
  deleteAndRecombine(Ptr.getNode());
  AddToWorklist(Result.getNode());

  return true;
}
16435
/// Check whether folding the add/sub \p PtrUse of \p Ptr into \p N as a
/// post-indexed access is legal (per the target) and profitable. On success,
/// the target-selected \p BasePtr, \p Offset and addressing mode \p AM have
/// been populated by getPostIndexedAddressParts.
static bool shouldCombineToPostInc(SDNode *N, SDValue Ptr, SDNode *PtrUse,
                                   SDValue &BasePtr, SDValue &Offset,
                                   ISD::MemIndexedMode &AM,
                                   SelectionDAG &DAG,
                                   const TargetLowering &TLI) {
  if (PtrUse == N ||
      (PtrUse->getOpcode() != ISD::ADD && PtrUse->getOpcode() != ISD::SUB))
    return false;

  if (!TLI.getPostIndexedAddressParts(N, PtrUse, BasePtr, Offset, AM, DAG))
    return false;

  // Don't create a indexed load / store with zero offset.
  if (isNullConstant(Offset))
    return false;

  if (isa<FrameIndexSDNode>(BasePtr) || isa<RegisterSDNode>(BasePtr))
    return false;

  SmallPtrSet<const SDNode *, 32> Visited;
  for (SDNode *Use : BasePtr->uses()) {
    if (Use == Ptr.getNode())
      continue;

    // No if there's a later user which could perform the index instead.
    if (isa<MemSDNode>(Use)) {
      bool IsLoad = true;
      bool IsMasked = false;
      SDValue OtherPtr;
      if (getCombineLoadStoreParts(Use, ISD::POST_INC, ISD::POST_DEC, IsLoad,
                                   IsMasked, OtherPtr, TLI)) {
        SmallVector<const SDNode *, 2> Worklist;
        Worklist.push_back(Use);
        if (SDNode::hasPredecessorHelper(N, Visited, Worklist))
          return false;
      }
    }

    // If all the uses are load / store addresses, then don't do the
    // transformation.
    if (Use->getOpcode() == ISD::ADD || Use->getOpcode() == ISD::SUB) {
      for (SDNode *UseUse : Use->uses())
        if (canFoldInAddressingMode(Use, UseUse, DAG, TLI))
          return false;
    }
  }
  return true;
}
16484
/// Find an add/sub user of \p N's base pointer that can be folded into \p N
/// as a post-indexed access. Returns the chosen pointer-arithmetic node (with
/// \p BasePtr, \p Offset and \p AM filled in) or nullptr if none qualifies.
static SDNode *getPostIndexedLoadStoreOp(SDNode *N, bool &IsLoad,
                                         bool &IsMasked, SDValue &Ptr,
                                         SDValue &BasePtr, SDValue &Offset,
                                         ISD::MemIndexedMode &AM,
                                         SelectionDAG &DAG,
                                         const TargetLowering &TLI) {
  if (!getCombineLoadStoreParts(N, ISD::POST_INC, ISD::POST_DEC, IsLoad,
                                IsMasked, Ptr, TLI) ||
      Ptr->hasOneUse())
    return nullptr;

  // Try turning it into a post-indexed load / store except when
  // 1) All uses are load / store ops that use it as base ptr (and
  //    it may be folded as addressing mode).
  // 2) Op must be independent of N, i.e. Op is neither a predecessor
  //    nor a successor of N. Otherwise, if Op is folded that would
  //    create a cycle.
  for (SDNode *Op : Ptr->uses()) {
    // Check for #1.
    if (!shouldCombineToPostInc(N, Ptr, Op, BasePtr, Offset, AM, DAG, TLI))
      continue;

    // Check for #2.
    SmallPtrSet<const SDNode *, 32> Visited;
    SmallVector<const SDNode *, 8> Worklist;
    // Ptr is predecessor to both N and Op.
    Visited.insert(Ptr.getNode());
    Worklist.push_back(N);
    Worklist.push_back(Op);
    if (!SDNode::hasPredecessorHelper(N, Visited, Worklist) &&
        !SDNode::hasPredecessorHelper(Op, Visited, Worklist))
      return Op;
  }
  return nullptr;
}
16520
16521 /// Try to combine a load/store with a add/sub of the base pointer node into a
16522 /// post-indexed load/store. The transformation folded the add/subtract into the
16523 /// new indexed load/store effectively and all of its uses are redirected to the
16524 /// new load/store.
bool DAGCombiner::CombineToPostIndexedLoadStore(SDNode *N) {
  // Indexed forms are only introduced once the DAG is fully legalized.
  if (Level < AfterLegalizeDAG)
    return false;

  bool IsLoad = true;
  bool IsMasked = false;
  SDValue Ptr;
  SDValue BasePtr;
  SDValue Offset;
  ISD::MemIndexedMode AM = ISD::UNINDEXED;
  SDNode *Op = getPostIndexedLoadStoreOp(N, IsLoad, IsMasked, Ptr, BasePtr,
                                         Offset, AM, DAG, TLI);
  if (!Op)
    return false;

  SDValue Result;
  if (!IsMasked)
    Result = IsLoad ? DAG.getIndexedLoad(SDValue(N, 0), SDLoc(N), BasePtr,
                                         Offset, AM)
                    : DAG.getIndexedStore(SDValue(N, 0), SDLoc(N),
                                          BasePtr, Offset, AM);
  else
    Result = IsLoad ? DAG.getIndexedMaskedLoad(SDValue(N, 0), SDLoc(N),
                                               BasePtr, Offset, AM)
                    : DAG.getIndexedMaskedStore(SDValue(N, 0), SDLoc(N),
                                                BasePtr, Offset, AM);
  ++PostIndexedNodes;
  ++NodesCombined;
  LLVM_DEBUG(dbgs() << "\nReplacing.5 "; N->dump(&DAG); dbgs() << "\nWith: ";
             Result.dump(&DAG); dbgs() << '\n');
  WorklistRemover DeadNodes(*this);
  // Indexed load results are (value, new base pointer, chain); indexed store
  // results are (new base pointer, chain).
  if (IsLoad) {
    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0));
    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Result.getValue(2));
  } else {
    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(1));
  }

  // Finally, since the node is now dead, remove it from the graph.
  deleteAndRecombine(N);

  // Replace the uses of Use with uses of the updated base value.
  DAG.ReplaceAllUsesOfValueWith(SDValue(Op, 0),
                                Result.getValue(IsLoad ? 1 : 0));
  deleteAndRecombine(Op);
  return true;
}
16572
16573 /// Return the base-pointer arithmetic from an indexed \p LD.
SplitIndexingFromLoad(LoadSDNode * LD)16574 SDValue DAGCombiner::SplitIndexingFromLoad(LoadSDNode *LD) {
16575 ISD::MemIndexedMode AM = LD->getAddressingMode();
16576 assert(AM != ISD::UNINDEXED);
16577 SDValue BP = LD->getOperand(1);
16578 SDValue Inc = LD->getOperand(2);
16579
16580 // Some backends use TargetConstants for load offsets, but don't expect
16581 // TargetConstants in general ADD nodes. We can convert these constants into
16582 // regular Constants (if the constant is not opaque).
16583 assert((Inc.getOpcode() != ISD::TargetConstant ||
16584 !cast<ConstantSDNode>(Inc)->isOpaque()) &&
16585 "Cannot split out indexing using opaque target constants");
16586 if (Inc.getOpcode() == ISD::TargetConstant) {
16587 ConstantSDNode *ConstInc = cast<ConstantSDNode>(Inc);
16588 Inc = DAG.getConstant(*ConstInc->getConstantIntValue(), SDLoc(Inc),
16589 ConstInc->getValueType(0));
16590 }
16591
16592 unsigned Opc =
16593 (AM == ISD::PRE_INC || AM == ISD::POST_INC ? ISD::ADD : ISD::SUB);
16594 return DAG.getNode(Opc, SDLoc(LD), BP.getSimpleValueType(), BP, Inc);
16595 }
16596
numVectorEltsOrZero(EVT T)16597 static inline ElementCount numVectorEltsOrZero(EVT T) {
16598 return T.isVector() ? T.getVectorElementCount() : ElementCount::getFixed(0);
16599 }
16600
getTruncatedStoreValue(StoreSDNode * ST,SDValue & Val)16601 bool DAGCombiner::getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val) {
16602 Val = ST->getValue();
16603 EVT STType = Val.getValueType();
16604 EVT STMemType = ST->getMemoryVT();
16605 if (STType == STMemType)
16606 return true;
16607 if (isTypeLegal(STMemType))
16608 return false; // fail.
16609 if (STType.isFloatingPoint() && STMemType.isFloatingPoint() &&
16610 TLI.isOperationLegal(ISD::FTRUNC, STMemType)) {
16611 Val = DAG.getNode(ISD::FTRUNC, SDLoc(ST), STMemType, Val);
16612 return true;
16613 }
16614 if (numVectorEltsOrZero(STType) == numVectorEltsOrZero(STMemType) &&
16615 STType.isInteger() && STMemType.isInteger()) {
16616 Val = DAG.getNode(ISD::TRUNCATE, SDLoc(ST), STMemType, Val);
16617 return true;
16618 }
16619 if (STType.getSizeInBits() == STMemType.getSizeInBits()) {
16620 Val = DAG.getBitcast(STMemType, Val);
16621 return true;
16622 }
16623 return false; // fail.
16624 }
16625
extendLoadedValueToExtension(LoadSDNode * LD,SDValue & Val)16626 bool DAGCombiner::extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val) {
16627 EVT LDMemType = LD->getMemoryVT();
16628 EVT LDType = LD->getValueType(0);
16629 assert(Val.getValueType() == LDMemType &&
16630 "Attempting to extend value of non-matching type");
16631 if (LDType == LDMemType)
16632 return true;
16633 if (LDMemType.isInteger() && LDType.isInteger()) {
16634 switch (LD->getExtensionType()) {
16635 case ISD::NON_EXTLOAD:
16636 Val = DAG.getBitcast(LDType, Val);
16637 return true;
16638 case ISD::EXTLOAD:
16639 Val = DAG.getNode(ISD::ANY_EXTEND, SDLoc(LD), LDType, Val);
16640 return true;
16641 case ISD::SEXTLOAD:
16642 Val = DAG.getNode(ISD::SIGN_EXTEND, SDLoc(LD), LDType, Val);
16643 return true;
16644 case ISD::ZEXTLOAD:
16645 Val = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(LD), LDType, Val);
16646 return true;
16647 }
16648 }
16649 return false;
16650 }
16651
/// If LD loads from the address a directly-preceding (chain-adjacent) store
/// just wrote, forward the stored value to the load's users and bypass the
/// memory round trip.
SDValue DAGCombiner::ForwardStoreValueToDirectLoad(LoadSDNode *LD) {
  if (OptLevel == CodeGenOpt::None || !LD->isSimple())
    return SDValue();
  SDValue Chain = LD->getOperand(0);
  StoreSDNode *ST = dyn_cast<StoreSDNode>(Chain.getNode());
  // TODO: Relax this restriction for unordered atomics (see D66309)
  if (!ST || !ST->isSimple())
    return SDValue();

  EVT LDType = LD->getValueType(0);
  EVT LDMemType = LD->getMemoryVT();
  EVT STMemType = ST->getMemoryVT();
  EVT STType = ST->getValue().getValueType();

  // There are two cases to consider here:
  //  1. The store is fixed width and the load is scalable. In this case we
  //     don't know at compile time if the store completely envelops the load
  //     so we abandon the optimisation.
  //  2. The store is scalable and the load is fixed width. We could
  //     potentially support a limited number of cases here, but there has been
  //     no cost-benefit analysis to prove it's worth it.
  bool LdStScalable = LDMemType.isScalableVector();
  if (LdStScalable != STMemType.isScalableVector())
    return SDValue();

  // If we are dealing with scalable vectors on a big endian platform the
  // calculation of offsets below becomes trickier, since we do not know at
  // compile time the absolute size of the vector. Until we've done more
  // analysis on big-endian platforms it seems better to bail out for now.
  if (LdStScalable && DAG.getDataLayout().isBigEndian())
    return SDValue();

  BaseIndexOffset BasePtrLD = BaseIndexOffset::match(LD, DAG);
  BaseIndexOffset BasePtrST = BaseIndexOffset::match(ST, DAG);
  int64_t Offset;
  if (!BasePtrST.equalBaseIndex(BasePtrLD, DAG, Offset))
    return SDValue();

  // Normalize for Endianness. After this Offset=0 will denote that the least
  // significant bit in the loaded value maps to the least significant bit in
  // the stored value). With Offset=n (for n > 0) the loaded value starts at the
  // n:th least significant byte of the stored value.
  if (DAG.getDataLayout().isBigEndian())
    Offset = ((int64_t)STMemType.getStoreSizeInBits().getFixedSize() -
              (int64_t)LDMemType.getStoreSizeInBits().getFixedSize()) /
                 8 -
             Offset;

  // Check that the stored value covers all bits that are loaded.
  bool STCoversLD;

  TypeSize LdMemSize = LDMemType.getSizeInBits();
  TypeSize StMemSize = STMemType.getSizeInBits();
  if (LdStScalable)
    STCoversLD = (Offset == 0) && LdMemSize == StMemSize;
  else
    STCoversLD = (Offset >= 0) && (Offset * 8 + LdMemSize.getFixedSize() <=
                                   StMemSize.getFixedSize());

  // Helper that replaces LD with Val/Chain, splitting off the index
  // arithmetic first when LD is an indexed load.
  auto ReplaceLd = [&](LoadSDNode *LD, SDValue Val, SDValue Chain) -> SDValue {
    if (LD->isIndexed()) {
      // Cannot handle opaque target constants and we must respect the user's
      // request not to split indexes from loads.
      if (!canSplitIdx(LD))
        return SDValue();
      SDValue Idx = SplitIndexingFromLoad(LD);
      SDValue Ops[] = {Val, Idx, Chain};
      return CombineTo(LD, Ops, 3);
    }
    return CombineTo(LD, Val, Chain);
  };

  if (!STCoversLD)
    return SDValue();

  // Memory as copy space (potentially masked).
  if (Offset == 0 && LDType == STType && STMemType == LDMemType) {
    // Simple case: Direct non-truncating forwarding
    if (LDType.getSizeInBits() == LdMemSize)
      return ReplaceLd(LD, ST->getValue(), Chain);
    // Can we model the truncate and extension with an and mask?
    if (STType.isInteger() && LDMemType.isInteger() && !STType.isVector() &&
        !LDMemType.isVector() && LD->getExtensionType() != ISD::SEXTLOAD) {
      // Mask to size of LDMemType
      auto Mask =
          DAG.getConstant(APInt::getLowBitsSet(STType.getFixedSizeInBits(),
                                               StMemSize.getFixedSize()),
                          SDLoc(ST), STType);
      auto Val = DAG.getNode(ISD::AND, SDLoc(LD), LDType, ST->getValue(), Mask);
      return ReplaceLd(LD, Val, Chain);
    }
  }

  // TODO: Deal with nonzero offset.
  if (LD->getBasePtr().isUndef() || Offset != 0)
    return SDValue();
  // Model necessary truncations / extensions.
  SDValue Val;
  // Truncate Value To Stored Memory Size.
  // (The do/while(false) lets any failing step fall through to the cleanup
  // below via `continue`.)
  do {
    if (!getTruncatedStoreValue(ST, Val))
      continue;
    if (!isTypeLegal(LDMemType))
      continue;
    if (STMemType != LDMemType) {
      // TODO: Support vectors? This requires extract_subvector/bitcast.
      if (!STMemType.isVector() && !LDMemType.isVector() &&
          STMemType.isInteger() && LDMemType.isInteger())
        Val = DAG.getNode(ISD::TRUNCATE, SDLoc(LD), LDMemType, Val);
      else
        continue;
    }
    if (!extendLoadedValueToExtension(LD, Val))
      continue;
    return ReplaceLd(LD, Val, Chain);
  } while (false);

  // On failure, cleanup dead nodes we may have created.
  if (Val->use_empty())
    deleteAndRecombine(Val.getNode());
  return SDValue();
}
16774
/// Combine a LOAD node: delete dead loads, forward directly-stored values,
/// refine alignment, re-chain past non-aliasing memory ops, form indexed
/// loads, and slice the load into smaller loads when profitable.
/// \returns the replacement value, SDValue(N, 0) when N was changed in
/// place, or an empty SDValue when no combine applied.
SDValue DAGCombiner::visitLOAD(SDNode *N) {
  LoadSDNode *LD = cast<LoadSDNode>(N);
  SDValue Chain = LD->getChain();
  SDValue Ptr = LD->getBasePtr();

  // If load is not volatile and there are no uses of the loaded value (and
  // the updated indexed value in case of indexed loads), change uses of the
  // chain value into uses of the chain input (i.e. delete the dead load).
  // TODO: Allow this for unordered atomics (see D66309)
  if (LD->isSimple()) {
    if (N->getValueType(1) == MVT::Other) {
      // Unindexed loads.
      if (!N->hasAnyUseOfValue(0)) {
        // It's not safe to use the two value CombineTo variant here. e.g.
        // v1, chain2 = load chain1, loc
        // v2, chain3 = load chain2, loc
        // v3 = add v2, c
        // Now we replace use of chain2 with chain1. This makes the second load
        // isomorphic to the one we are deleting, and thus makes this load live.
        LLVM_DEBUG(dbgs() << "\nReplacing.6 "; N->dump(&DAG);
                   dbgs() << "\nWith chain: "; Chain.dump(&DAG);
                   dbgs() << "\n");
        WorklistRemover DeadNodes(*this);
        DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain);
        AddUsersToWorklist(Chain.getNode());
        if (N->use_empty())
          deleteAndRecombine(N);

        return SDValue(N, 0);   // Return N so it doesn't get rechecked!
      }
    } else {
      // Indexed loads.
      assert(N->getValueType(2) == MVT::Other && "Malformed indexed loads?");

      // If this load has an opaque TargetConstant offset, then we cannot split
      // the indexing into an add/sub directly (that TargetConstant may not be
      // valid for a different type of node, and we cannot convert an opaque
      // target constant into a regular constant).
      bool CanSplitIdx = canSplitIdx(LD);

      // The load is dead if its loaded value has no uses and either the
      // updated-pointer result is also unused or we can recover the pointer
      // arithmetic by splitting the indexing off the load.
      if (!N->hasAnyUseOfValue(0) && (CanSplitIdx || !N->hasAnyUseOfValue(1))) {
        SDValue Undef = DAG.getUNDEF(N->getValueType(0));
        SDValue Index;
        if (N->hasAnyUseOfValue(1) && CanSplitIdx) {
          Index = SplitIndexingFromLoad(LD);
          // Try to fold the base pointer arithmetic into subsequent loads and
          // stores.
          AddUsersToWorklist(N);
        } else
          Index = DAG.getUNDEF(N->getValueType(1));
        LLVM_DEBUG(dbgs() << "\nReplacing.7 "; N->dump(&DAG);
                   dbgs() << "\nWith: "; Undef.dump(&DAG);
                   dbgs() << " and 2 other values\n");
        WorklistRemover DeadNodes(*this);
        DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Undef);
        DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Index);
        DAG.ReplaceAllUsesOfValueWith(SDValue(N, 2), Chain);
        deleteAndRecombine(N);
        return SDValue(N, 0);   // Return N so it doesn't get rechecked!
      }
    }
  }

  // If this load is directly stored, replace the load value with the stored
  // value.
  if (auto V = ForwardStoreValueToDirectLoad(LD))
    return V;

  // Try to infer better alignment information than the load already has.
  if (OptLevel != CodeGenOpt::None && LD->isUnindexed() && !LD->isAtomic()) {
    if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) {
      if (*Alignment > LD->getAlign() &&
          isAligned(*Alignment, LD->getSrcValueOffset())) {
        SDValue NewLoad = DAG.getExtLoad(
            LD->getExtensionType(), SDLoc(N), LD->getValueType(0), Chain, Ptr,
            LD->getPointerInfo(), LD->getMemoryVT(), *Alignment,
            LD->getMemOperand()->getFlags(), LD->getAAInfo());
        // NewLoad will always be N as we are only refining the alignment
        assert(NewLoad.getNode() == N);
        (void)NewLoad;
      }
    }
  }

  if (LD->isUnindexed()) {
    // Walk up chain skipping non-aliasing memory nodes.
    SDValue BetterChain = FindBetterChain(LD, Chain);

    // If there is a better chain.
    if (Chain != BetterChain) {
      SDValue ReplLoad;

      // Replace the chain to avoid a dependency on the old chain.
      if (LD->getExtensionType() == ISD::NON_EXTLOAD) {
        ReplLoad = DAG.getLoad(N->getValueType(0), SDLoc(LD),
                               BetterChain, Ptr, LD->getMemOperand());
      } else {
        ReplLoad = DAG.getExtLoad(LD->getExtensionType(), SDLoc(LD),
                                  LD->getValueType(0),
                                  BetterChain, Ptr, LD->getMemoryVT(),
                                  LD->getMemOperand());
      }

      // Create token factor to keep old chain connected.
      SDValue Token = DAG.getNode(ISD::TokenFactor, SDLoc(N),
                                  MVT::Other, Chain, ReplLoad.getValue(1));

      // Replace uses with load result and token factor
      return CombineTo(N, ReplLoad.getValue(0), Token);
    }
  }

  // Try transforming N to an indexed load.
  if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
    return SDValue(N, 0);

  // Try to slice up N to more direct loads if the slices are mapped to
  // different register banks or pairing can take place.
  if (SliceUpLoad(N))
    return SDValue(N, 0);

  return SDValue();
}
16898
namespace {

/// Helper structure used to slice a load in smaller loads.
/// Basically a slice is obtained from the following sequence:
/// Origin = load Ty1, Base
/// Shift = srl Ty1 Origin, CstTy Amount
/// Inst = trunc Shift to Ty2
///
/// Then, it will be rewritten into:
/// Slice = load SliceTy, Base + SliceOffset
/// [Inst = zext Slice to Ty2], only if SliceTy <> Ty2
///
/// SliceTy is deduced from the number of bits that are actually used to
/// build Inst.
struct LoadedSlice {
  /// Helper structure used to compute the cost of a slice.
  struct Cost {
    /// Are we optimizing for code size.
    bool ForCodeSize = false;

    /// Various cost.
    unsigned Loads = 0;
    unsigned Truncates = 0;
    unsigned CrossRegisterBanksCopies = 0;
    unsigned ZExts = 0;
    unsigned Shift = 0;

    /// Build a zero cost; only records whether we optimize for size.
    explicit Cost(bool ForCodeSize) : ForCodeSize(ForCodeSize) {}

    /// Get the cost of one isolated slice: one load, plus a zext when the
    /// slice's loaded type differs from the truncate's result type and the
    /// target does not consider that zext free.
    Cost(const LoadedSlice &LS, bool ForCodeSize)
        : ForCodeSize(ForCodeSize), Loads(1) {
      EVT TruncType = LS.Inst->getValueType(0);
      EVT LoadedType = LS.getLoadedType();
      if (TruncType != LoadedType &&
          !LS.DAG->getTargetLoweringInfo().isZExtFree(LoadedType, TruncType))
        ZExts = 1;
    }

    /// Account for slicing gain in the current cost.
    /// Slicing provide a few gains like removing a shift or a
    /// truncate. This method allows to grow the cost of the original
    /// load with the gain from this slice.
    void addSliceGain(const LoadedSlice &LS) {
      // Each slice saves a truncate.
      const TargetLowering &TLI = LS.DAG->getTargetLoweringInfo();
      if (!TLI.isTruncateFree(LS.Inst->getOperand(0).getValueType(),
                              LS.Inst->getValueType(0)))
        ++Truncates;
      // If there is a shift amount, this slice gets rid of it.
      if (LS.Shift)
        ++Shift;
      // If this slice can merge a cross register bank copy, account for it.
      if (LS.canMergeExpensiveCrossRegisterBankCopy())
        ++CrossRegisterBanksCopies;
    }

    /// Accumulate another cost into this one, component-wise.
    Cost &operator+=(const Cost &RHS) {
      Loads += RHS.Loads;
      Truncates += RHS.Truncates;
      CrossRegisterBanksCopies += RHS.CrossRegisterBanksCopies;
      ZExts += RHS.ZExts;
      Shift += RHS.Shift;
      return *this;
    }

    /// Costs are equal iff every component matches.
    bool operator==(const Cost &RHS) const {
      return Loads == RHS.Loads && Truncates == RHS.Truncates &&
             CrossRegisterBanksCopies == RHS.CrossRegisterBanksCopies &&
             ZExts == RHS.ZExts && Shift == RHS.Shift;
    }

    bool operator!=(const Cost &RHS) const { return !(*this == RHS); }

    /// Order costs: unless optimizing for size, the count of expensive
    /// operations (loads and cross-register-bank copies) dominates; the sum
    /// of all operations breaks ties.
    bool operator<(const Cost &RHS) const {
      // Assume cross register banks copies are as expensive as loads.
      // FIXME: Do we want some more target hooks?
      unsigned ExpensiveOpsLHS = Loads + CrossRegisterBanksCopies;
      unsigned ExpensiveOpsRHS = RHS.Loads + RHS.CrossRegisterBanksCopies;
      // Unless we are optimizing for code size, consider the
      // expensive operation first.
      if (!ForCodeSize && ExpensiveOpsLHS != ExpensiveOpsRHS)
        return ExpensiveOpsLHS < ExpensiveOpsRHS;
      return (Truncates + ZExts + Shift + ExpensiveOpsLHS) <
             (RHS.Truncates + RHS.ZExts + RHS.Shift + ExpensiveOpsRHS);
    }

    bool operator>(const Cost &RHS) const { return RHS < *this; }

    bool operator<=(const Cost &RHS) const { return !(RHS < *this); }

    bool operator>=(const Cost &RHS) const { return !(*this < RHS); }
  };

  // The last instruction that represent the slice. This should be a
  // truncate instruction.
  SDNode *Inst;

  // The original load instruction.
  LoadSDNode *Origin;

  // The right shift amount in bits from the original load.
  unsigned Shift;

  // The DAG from which Origin came from.
  // This is used to get some contextual information about legal types, etc.
  SelectionDAG *DAG;

  LoadedSlice(SDNode *Inst = nullptr, LoadSDNode *Origin = nullptr,
              unsigned Shift = 0, SelectionDAG *DAG = nullptr)
      : Inst(Inst), Origin(Origin), Shift(Shift), DAG(DAG) {}

  /// Get the bits used in a chunk of bits \p BitWidth large.
  /// \return Result is \p BitWidth and has used bits set to 1 and
  /// not used bits set to 0.
  APInt getUsedBits() const {
    // Reproduce the trunc(lshr) sequence:
    // - Start from the truncated value.
    // - Zero extend to the desired bit width.
    // - Shift left.
    assert(Origin && "No original load to compare against.");
    unsigned BitWidth = Origin->getValueSizeInBits(0);
    assert(Inst && "This slice is not bound to an instruction");
    assert(Inst->getValueSizeInBits(0) <= BitWidth &&
           "Extracted slice is bigger than the whole type!");
    APInt UsedBits(Inst->getValueSizeInBits(0), 0);
    UsedBits.setAllBits();
    UsedBits = UsedBits.zext(BitWidth);
    UsedBits <<= Shift;
    return UsedBits;
  }

  /// Get the size of the slice to be loaded in bytes.
  unsigned getLoadedSize() const {
    unsigned SliceSize = getUsedBits().countPopulation();
    assert(!(SliceSize & 0x7) && "Size is not a multiple of a byte.");
    return SliceSize / 8;
  }

  /// Get the type that will be loaded for this slice.
  /// Note: This may not be the final type for the slice.
  EVT getLoadedType() const {
    assert(DAG && "Missing context");
    LLVMContext &Ctxt = *DAG->getContext();
    return EVT::getIntegerVT(Ctxt, getLoadedSize() * 8);
  }

  /// Get the alignment of the load used for this slice.
  /// The slice reads at Origin's address plus getOffsetFromBase(), so its
  /// alignment is the original alignment lowered to be compatible with
  /// that offset.
  Align getAlign() const {
    Align Alignment = Origin->getAlign();
    uint64_t Offset = getOffsetFromBase();
    if (Offset != 0)
      Alignment = commonAlignment(Alignment, Alignment.value() + Offset);
    return Alignment;
  }

  /// Check if this slice can be rewritten with legal operations.
  bool isLegal() const {
    // An invalid slice is not legal.
    if (!Origin || !Inst || !DAG)
      return false;

    // Offsets are for indexed load only, we do not handle that.
    if (!Origin->getOffset().isUndef())
      return false;

    const TargetLowering &TLI = DAG->getTargetLoweringInfo();

    // Check that the type is legal.
    EVT SliceType = getLoadedType();
    if (!TLI.isTypeLegal(SliceType))
      return false;

    // Check that the load is legal for this type.
    if (!TLI.isOperationLegal(ISD::LOAD, SliceType))
      return false;

    // Check that the offset can be computed.
    // 1. Check its type.
    EVT PtrType = Origin->getBasePtr().getValueType();
    if (PtrType == MVT::Untyped || PtrType.isExtended())
      return false;

    // 2. Check that it fits in the immediate.
    if (!TLI.isLegalAddImmediate(getOffsetFromBase()))
      return false;

    // 3. Check that the computation is legal.
    if (!TLI.isOperationLegal(ISD::ADD, PtrType))
      return false;

    // Check that the zext is legal if it needs one.
    EVT TruncateType = Inst->getValueType(0);
    if (TruncateType != SliceType &&
        !TLI.isOperationLegal(ISD::ZERO_EXTEND, TruncateType))
      return false;

    return true;
  }

  /// Get the offset in bytes of this slice in the original chunk of
  /// bits.
  /// \pre DAG != nullptr.
  uint64_t getOffsetFromBase() const {
    assert(DAG && "Missing context.");
    bool IsBigEndian = DAG->getDataLayout().isBigEndian();
    assert(!(Shift & 0x7) && "Shifts not aligned on Bytes are not supported.");
    uint64_t Offset = Shift / 8;
    unsigned TySizeInBytes = Origin->getValueSizeInBits(0) / 8;
    assert(!(Origin->getValueSizeInBits(0) & 0x7) &&
           "The size of the original loaded type is not a multiple of a"
           " byte.");
    // If Offset is bigger than TySizeInBytes, it means we are loading all
    // zeros. This should have been optimized before in the process.
    assert(TySizeInBytes > Offset &&
           "Invalid shift amount for given loaded size");
    // On big-endian targets the shift counts from the most significant end,
    // so mirror the byte offset within the original loaded value.
    if (IsBigEndian)
      Offset = TySizeInBytes - Offset - getLoadedSize();
    return Offset;
  }

  /// Generate the sequence of instructions to load the slice
  /// represented by this object and redirect the uses of this slice to
  /// this new sequence of instructions.
  /// \pre this->Inst && this->Origin are valid Instructions and this
  /// object passed the legal check: LoadedSlice::isLegal returned true.
  /// \return The last instruction of the sequence used to load the slice.
  SDValue loadSlice() const {
    assert(Inst && Origin && "Unable to replace a non-existing slice.");
    const SDValue &OldBaseAddr = Origin->getBasePtr();
    SDValue BaseAddr = OldBaseAddr;
    // Get the offset in that chunk of bytes w.r.t. the endianness.
    int64_t Offset = static_cast<int64_t>(getOffsetFromBase());
    assert(Offset >= 0 && "Offset too big to fit in int64_t!");
    if (Offset) {
      // BaseAddr = BaseAddr + Offset.
      EVT ArithType = BaseAddr.getValueType();
      SDLoc DL(Origin);
      BaseAddr = DAG->getNode(ISD::ADD, DL, ArithType, BaseAddr,
                              DAG->getConstant(Offset, DL, ArithType));
    }

    // Create the type of the loaded slice according to its size.
    EVT SliceType = getLoadedType();

    // Create the load for the slice.
    SDValue LastInst =
        DAG->getLoad(SliceType, SDLoc(Origin), Origin->getChain(), BaseAddr,
                     Origin->getPointerInfo().getWithOffset(Offset), getAlign(),
                     Origin->getMemOperand()->getFlags());
    // If the final type is not the same as the loaded type, this means that
    // we have to pad with zero. Create a zero extend for that.
    EVT FinalType = Inst->getValueType(0);
    if (SliceType != FinalType)
      LastInst =
          DAG->getNode(ISD::ZERO_EXTEND, SDLoc(LastInst), FinalType, LastInst);
    return LastInst;
  }

  /// Check if this slice can be merged with an expensive cross register
  /// bank copy. E.g.,
  /// i = load i32
  /// f = bitcast i32 i to float
  bool canMergeExpensiveCrossRegisterBankCopy() const {
    // The merge only applies when the slice's sole user is a bitcast.
    if (!Inst || !Inst->hasOneUse())
      return false;
    SDNode *Use = *Inst->use_begin();
    if (Use->getOpcode() != ISD::BITCAST)
      return false;
    assert(DAG && "Missing context");
    const TargetLowering &TLI = DAG->getTargetLoweringInfo();
    EVT ResVT = Use->getValueType(0);
    const TargetRegisterClass *ResRC =
        TLI.getRegClassFor(ResVT.getSimpleVT(), Use->isDivergent());
    const TargetRegisterClass *ArgRC =
        TLI.getRegClassFor(Use->getOperand(0).getValueType().getSimpleVT(),
                           Use->getOperand(0)->isDivergent());
    if (ArgRC == ResRC || !TLI.isOperationLegal(ISD::LOAD, ResVT))
      return false;

    // At this point, we know that we perform a cross-register-bank copy.
    // Check if it is expensive.
    const TargetRegisterInfo *TRI = DAG->getSubtarget().getRegisterInfo();
    // Assume bitcasts are cheap, unless both register classes do not
    // explicitly share a common sub class.
    if (!TRI || TRI->getCommonSubClass(ArgRC, ResRC))
      return false;

    // Check if it will be merged with the load.
    // 1. Check the alignment / fast memory access constraint.
    bool IsFast = false;
    if (!TLI.allowsMemoryAccess(*DAG->getContext(), DAG->getDataLayout(), ResVT,
                                Origin->getAddressSpace(), getAlign(),
                                Origin->getMemOperand()->getFlags(), &IsFast) ||
        !IsFast)
      return false;

    // 2. Check that the load is a legal operation for that type.
    if (!TLI.isOperationLegal(ISD::LOAD, ResVT))
      return false;

    // 3. Check that we do not have a zext in the way.
    if (Inst->getValueType(0) != getLoadedType())
      return false;

    return true;
  }
};

} // end anonymous namespace
17209
17210 /// Check that all bits set in \p UsedBits form a dense region, i.e.,
17211 /// \p UsedBits looks like 0..0 1..1 0..0.
areUsedBitsDense(const APInt & UsedBits)17212 static bool areUsedBitsDense(const APInt &UsedBits) {
17213 // If all the bits are one, this is dense!
17214 if (UsedBits.isAllOnes())
17215 return true;
17216
17217 // Get rid of the unused bits on the right.
17218 APInt NarrowedUsedBits = UsedBits.lshr(UsedBits.countTrailingZeros());
17219 // Get rid of the unused bits on the left.
17220 if (NarrowedUsedBits.countLeadingZeros())
17221 NarrowedUsedBits = NarrowedUsedBits.trunc(NarrowedUsedBits.getActiveBits());
17222 // Check that the chunk of bits is completely used.
17223 return NarrowedUsedBits.isAllOnes();
17224 }
17225
17226 /// Check whether or not \p First and \p Second are next to each other
17227 /// in memory. This means that there is no hole between the bits loaded
17228 /// by \p First and the bits loaded by \p Second.
areSlicesNextToEachOther(const LoadedSlice & First,const LoadedSlice & Second)17229 static bool areSlicesNextToEachOther(const LoadedSlice &First,
17230 const LoadedSlice &Second) {
17231 assert(First.Origin == Second.Origin && First.Origin &&
17232 "Unable to match different memory origins.");
17233 APInt UsedBits = First.getUsedBits();
17234 assert((UsedBits & Second.getUsedBits()) == 0 &&
17235 "Slices are not supposed to overlap.");
17236 UsedBits |= Second.getUsedBits();
17237 return areUsedBitsDense(UsedBits);
17238 }
17239
/// Adjust the \p GlobalLSCost according to the target
/// pairing capabilities and the layout of the slices.
/// Each pair of adjacent, same-typed slices that the target can load as a
/// paired load removes one load from \p GlobalLSCost.
/// \pre \p GlobalLSCost should account for at least as many loads as
/// there is in the slices in \p LoadedSlices.
static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices,
                                 LoadedSlice::Cost &GlobalLSCost) {
  unsigned NumberOfSlices = LoadedSlices.size();
  // If there is less than 2 elements, no pairing is possible.
  if (NumberOfSlices < 2)
    return;

  // Sort the slices so that elements that are likely to be next to each
  // other in memory are next to each other in the list.
  llvm::sort(LoadedSlices, [](const LoadedSlice &LHS, const LoadedSlice &RHS) {
    assert(LHS.Origin == RHS.Origin && "Different bases not implemented.");
    return LHS.getOffsetFromBase() < RHS.getOffsetFromBase();
  });
  const TargetLowering &TLI = LoadedSlices[0].DAG->getTargetLoweringInfo();
  // First (resp. Second) is the first (resp. Second) potentially candidate
  // to be placed in a paired load.
  const LoadedSlice *First = nullptr;
  const LoadedSlice *Second = nullptr;
  for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice,
                // Set the beginning of the pair.
                First = Second) {
    Second = &LoadedSlices[CurrSlice];

    // If First is NULL, it means we start a new pair.
    // Get to the next slice.
    if (!First)
      continue;

    EVT LoadedType = First->getLoadedType();

    // If the types of the slices are different, we cannot pair them.
    if (LoadedType != Second->getLoadedType())
      continue;

    // Check if the target supplies paired loads for this type.
    Align RequiredAlignment;
    if (!TLI.hasPairedLoad(LoadedType, RequiredAlignment)) {
      // move to the next pair, this type is hopeless.
      Second = nullptr;
      continue;
    }
    // Check if we meet the alignment requirement.
    if (First->getAlign() < RequiredAlignment)
      continue;

    // Check that both loads are next to each other in memory.
    if (!areSlicesNextToEachOther(*First, *Second))
      continue;

    // This pair can be merged into a single load: drop one load from the
    // global cost.
    assert(GlobalLSCost.Loads > 0 && "We save more loads than we created!");
    --GlobalLSCost.Loads;
    // Move to the next pair.
    Second = nullptr;
  }
}
17299
17300 /// Check the profitability of all involved LoadedSlice.
17301 /// Currently, it is considered profitable if there is exactly two
17302 /// involved slices (1) which are (2) next to each other in memory, and
17303 /// whose cost (\see LoadedSlice::Cost) is smaller than the original load (3).
17304 ///
17305 /// Note: The order of the elements in \p LoadedSlices may be modified, but not
17306 /// the elements themselves.
17307 ///
17308 /// FIXME: When the cost model will be mature enough, we can relax
17309 /// constraints (1) and (2).
isSlicingProfitable(SmallVectorImpl<LoadedSlice> & LoadedSlices,const APInt & UsedBits,bool ForCodeSize)17310 static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices,
17311 const APInt &UsedBits, bool ForCodeSize) {
17312 unsigned NumberOfSlices = LoadedSlices.size();
17313 if (StressLoadSlicing)
17314 return NumberOfSlices > 1;
17315
17316 // Check (1).
17317 if (NumberOfSlices != 2)
17318 return false;
17319
17320 // Check (2).
17321 if (!areUsedBitsDense(UsedBits))
17322 return false;
17323
17324 // Check (3).
17325 LoadedSlice::Cost OrigCost(ForCodeSize), GlobalSlicingCost(ForCodeSize);
17326 // The original code has one big load.
17327 OrigCost.Loads = 1;
17328 for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice) {
17329 const LoadedSlice &LS = LoadedSlices[CurrSlice];
17330 // Accumulate the cost of all the slices.
17331 LoadedSlice::Cost SliceCost(LS, ForCodeSize);
17332 GlobalSlicingCost += SliceCost;
17333
17334 // Account as cost in the original configuration the gain obtained
17335 // with the current slices.
17336 OrigCost.addSliceGain(LS);
17337 }
17338
17339 // If the target supports paired load, adjust the cost accordingly.
17340 adjustCostForPairing(LoadedSlices, GlobalSlicingCost);
17341 return OrigCost > GlobalSlicingCost;
17342 }
17343
/// If the given load, \p LI, is used only by trunc or trunc(lshr)
/// operations, split it in the various pieces being extracted.
///
/// This sort of thing is introduced by SROA.
/// This slicing takes care not to insert overlapping loads.
/// \pre LI is a simple load (i.e., not an atomic or volatile load).
/// \return true if \p N was replaced by smaller loads.
bool DAGCombiner::SliceUpLoad(SDNode *N) {
  // Slicing is only performed once the DAG is fully legalized, so the
  // narrow loads we create do not need further legalization.
  if (Level < AfterLegalizeDAG)
    return false;

  LoadSDNode *LD = cast<LoadSDNode>(N);
  if (!LD->isSimple() || !ISD::isNormalLoad(LD) ||
      !LD->getValueType(0).isInteger())
    return false;

  // The algorithm to split up a load of a scalable vector into individual
  // elements currently requires knowing the length of the loaded type,
  // so will need adjusting to work on scalable vectors.
  if (LD->getValueType(0).isScalableVector())
    return false;

  // Keep track of already used bits to detect overlapping values.
  // In that case, we will just abort the transformation.
  APInt UsedBits(LD->getValueSizeInBits(0), 0);

  SmallVector<LoadedSlice, 4> LoadedSlices;

  // Check if this load is used as several smaller chunks of bits.
  // Basically, look for uses in trunc or trunc(lshr) and record a new chain
  // of computation for each trunc.
  for (SDNode::use_iterator UI = LD->use_begin(), UIEnd = LD->use_end();
       UI != UIEnd; ++UI) {
    // Skip the uses of the chain.
    if (UI.getUse().getResNo() != 0)
      continue;

    SDNode *User = *UI;
    unsigned Shift = 0;

    // Check if this is a trunc(lshr).
    if (User->getOpcode() == ISD::SRL && User->hasOneUse() &&
        isa<ConstantSDNode>(User->getOperand(1))) {
      Shift = User->getConstantOperandVal(1);
      User = *User->use_begin();
    }

    // At this point, User must be a TRUNCATE, i.e. the use is either
    // trunc(load) or trunc(lshr(load, C)). Any other user defeats slicing.
    if (User->getOpcode() != ISD::TRUNCATE)
      return false;

    // The width of the type must be a power of 2 and at least 8 bits.
    // Otherwise the load cannot be represented in LLVM IR.
    // Moreover, if we shifted with a non-8-bits multiple, the slice
    // will be across several bytes. We do not support that.
    unsigned Width = User->getValueSizeInBits(0);
    if (Width < 8 || !isPowerOf2_32(Width) || (Shift & 0x7))
      return false;

    // Build the slice for this chain of computations.
    LoadedSlice LS(User, LD, Shift, &DAG);
    APInt CurrentUsedBits = LS.getUsedBits();

    // Check if this slice overlaps with another.
    if ((CurrentUsedBits & UsedBits) != 0)
      return false;
    // Update the bits used globally.
    UsedBits |= CurrentUsedBits;

    // Check if the new slice would be legal.
    if (!LS.isLegal())
      return false;

    // Record the slice.
    LoadedSlices.push_back(LS);
  }

  // Abort slicing if it does not seem to be profitable.
  if (!isSlicingProfitable(LoadedSlices, UsedBits, ForCodeSize))
    return false;

  ++SlicedLoads;

  // Rewrite each chain to use an independent load.
  // By construction, each chain can be represented by a unique load.

  // Prepare the argument for the new token factor for all the slices.
  SmallVector<SDValue, 8> ArgChains;
  for (const LoadedSlice &LS : LoadedSlices) {
    SDValue SliceInst = LS.loadSlice();
    CombineTo(LS.Inst, SliceInst, true);
    // loadSlice() may have wrapped the load in a zext; peel it off to get
    // at the load whose chain result we must collect.
    if (SliceInst.getOpcode() != ISD::LOAD)
      SliceInst = SliceInst.getOperand(0);
    assert(SliceInst->getOpcode() == ISD::LOAD &&
           "It takes more than a zext to get to the loaded slice!!");
    ArgChains.push_back(SliceInst.getValue(1));
  }

  // Tie all the new loads' chains together and replace the original load's
  // chain result, so ordering with other memory operations is preserved.
  SDValue Chain = DAG.getNode(ISD::TokenFactor, SDLoc(LD), MVT::Other,
                              ArgChains);
  DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain);
  AddToWorklist(Chain.getNode());
  return true;
}
17448
/// Check to see if V is (and load (ptr), imm), where the load is having
/// specific bytes cleared out. If so, return the byte size being masked out
/// and the shift amount.
/// \returns {MaskedBytes, ByteShift}, or {0, 0} when the pattern does not
/// match.
static std::pair<unsigned, unsigned>
CheckForMaskedLoad(SDValue V, SDValue Ptr, SDValue Chain) {
  std::pair<unsigned, unsigned> Result(0, 0);

  // Check for the structure we're looking for.
  if (V->getOpcode() != ISD::AND ||
      !isa<ConstantSDNode>(V->getOperand(1)) ||
      !ISD::isNormalLoad(V->getOperand(0).getNode()))
    return Result;

  // Check the chain and pointer.
  LoadSDNode *LD = cast<LoadSDNode>(V->getOperand(0));
  if (LD->getBasePtr() != Ptr) return Result;  // Not from same pointer.

  // This only handles simple types.
  if (V.getValueType() != MVT::i16 &&
      V.getValueType() != MVT::i32 &&
      V.getValueType() != MVT::i64)
    return Result;

  // Check the constant mask. Invert it so that the bits being masked out are
  // 0 and the bits being kept are 1. Use getSExtValue so that leading bits
  // follow the sign bit for uniformity.
  uint64_t NotMask = ~cast<ConstantSDNode>(V->getOperand(1))->getSExtValue();
  unsigned NotMaskLZ = countLeadingZeros(NotMask);
  if (NotMaskLZ & 7) return Result;  // Must be multiple of a byte.
  unsigned NotMaskTZ = countTrailingZeros(NotMask);
  if (NotMaskTZ & 7) return Result;  // Must be multiple of a byte.
  if (NotMaskLZ == 64) return Result;  // All zero mask.

  // See if we have a continuous run of bits. If so, we have 0*1+0*
  if (countTrailingOnes(NotMask >> NotMaskTZ) + NotMaskTZ + NotMaskLZ != 64)
    return Result;

  // Adjust NotMaskLZ down to be from the actual size of the int instead of i64.
  if (V.getValueType() != MVT::i64 && NotMaskLZ)
    NotMaskLZ -= 64-V.getValueSizeInBits();

  // Number of bytes whose bits are cleared by the AND.
  unsigned MaskedBytes = (V.getValueSizeInBits()-NotMaskLZ-NotMaskTZ)/8;
  switch (MaskedBytes) {
  case 1:
  case 2:
  case 4: break;
  default: return Result; // All one mask, or 5-byte mask.
  }

  // Verify that the first bit starts at a multiple of mask so that the access
  // is aligned the same as the access width.
  if (NotMaskTZ && NotMaskTZ/8 % MaskedBytes) return Result;

  // For narrowing to be valid, it must be the case that the load is the
  // immediately preceding memory operation before the store.
  if (LD == Chain.getNode())
    ; // ok.
  else if (Chain->getOpcode() == ISD::TokenFactor &&
           SDValue(LD, 1).hasOneUse()) {
    // LD has only 1 chain use so there are no indirect dependencies.
    if (!LD->isOperandOf(Chain.getNode()))
      return Result;
  } else
    return Result; // Fail.

  Result.first = MaskedBytes;
  Result.second = NotMaskTZ/8;
  return Result;
}
17518
/// Check to see if IVal is something that provides a value as specified by
/// MaskInfo. If so, replace the specified store with a narrower store of
/// truncated IVal.
///
/// \param MaskInfo {NumBytes, ByteShift} as produced by CheckForMaskedLoad:
///        the number of bytes being kept and their byte offset inside the
///        original value.
/// \returns the new (narrower) store, or an empty SDValue on failure.
static SDValue
ShrinkLoadReplaceStoreWithStore(const std::pair<unsigned, unsigned> &MaskInfo,
                                SDValue IVal, StoreSDNode *St,
                                DAGCombiner *DC) {
  unsigned NumBytes = MaskInfo.first;
  unsigned ByteShift = MaskInfo.second;
  SelectionDAG &DAG = DC->getDAG();

  // Check to see if IVal is all zeros in the part being masked in by the 'or'
  // that uses this. If not, this is not a replacement.
  APInt Mask = ~APInt::getBitsSet(IVal.getValueSizeInBits(),
                                  ByteShift*8, (ByteShift+NumBytes)*8);
  if (!DAG.MaskedValueIsZero(IVal, Mask)) return SDValue();

  // Check that it is legal on the target to do this. It is legal if the new
  // VT we're shrinking to (i8/i16/i32) is legal or we're still before type
  // legalization. If the source type is legal, but the store type isn't, see
  // if we can use a truncating store.
  MVT VT = MVT::getIntegerVT(NumBytes * 8);
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  bool UseTruncStore;
  if (DC->isTypeLegal(VT))
    UseTruncStore = false;
  else if (TLI.isTypeLegal(IVal.getValueType()) &&
           TLI.isTruncStoreLegal(IVal.getValueType(), VT))
    UseTruncStore = true;
  else
    return SDValue();
  // Check that the target doesn't think this is a bad idea.
  if (St->getMemOperand() &&
      !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
                              *St->getMemOperand()))
    return SDValue();

  // Okay, we can do this!  Replace the 'St' store with a store of IVal that is
  // shifted by ByteShift and truncated down to NumBytes.
  if (ByteShift) {
    SDLoc DL(IVal);
    IVal = DAG.getNode(ISD::SRL, DL, IVal.getValueType(), IVal,
                       DAG.getConstant(ByteShift*8, DL,
                                    DC->getShiftAmountTy(IVal.getValueType())));
  }

  // Figure out the offset for the store and the alignment of the access.
  unsigned StOffset;
  if (DAG.getDataLayout().isLittleEndian())
    StOffset = ByteShift;
  else
    // On big-endian targets the kept bytes sit at the opposite end of the
    // value, so the in-memory byte offset is counted from the other side.
    StOffset = IVal.getValueType().getStoreSize() - ByteShift - NumBytes;

  SDValue Ptr = St->getBasePtr();
  if (StOffset) {
    SDLoc DL(IVal);
    Ptr = DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(StOffset), DL);
  }

  ++OpsNarrowed;
  if (UseTruncStore)
    return DAG.getTruncStore(St->getChain(), SDLoc(St), IVal, Ptr,
                             St->getPointerInfo().getWithOffset(StOffset),
                             VT, St->getOriginalAlign());

  // Truncate down to the new size.
  IVal = DAG.getNode(ISD::TRUNCATE, SDLoc(IVal), VT, IVal);

  return DAG
      .getStore(St->getChain(), SDLoc(St), IVal, Ptr,
                St->getPointerInfo().getWithOffset(StOffset),
                St->getOriginalAlign());
}
17592
/// Look for sequence of load / op / store where op is one of 'or', 'xor', and
/// 'and' of immediates. If 'op' is only touching some of the loaded bits, try
/// narrowing the load and store if it would end up being a win for performance
/// or code size.
///
/// \returns the new narrow store on success, or an empty SDValue otherwise.
SDValue DAGCombiner::ReduceLoadOpStoreWidth(SDNode *N) {
  StoreSDNode *ST = cast<StoreSDNode>(N);
  // Volatile/atomic stores must not be narrowed.
  if (!ST->isSimple())
    return SDValue();

  SDValue Chain = ST->getChain();
  SDValue Value = ST->getValue();
  SDValue Ptr = ST->getBasePtr();
  EVT VT = Value.getValueType();

  if (ST->isTruncatingStore() || VT.isVector())
    return SDValue();

  unsigned Opc = Value.getOpcode();

  // Only a single-use bitwise op between the load and the store can be folded
  // away; with more uses the op would have to be materialized anyway.
  if ((Opc != ISD::OR && Opc != ISD::XOR && Opc != ISD::AND) ||
      !Value.hasOneUse())
    return SDValue();

  // If this is "store (or X, Y), P" and X is "(and (load P), cst)", where cst
  // is a byte mask indicating a consecutive number of bytes, check to see if
  // Y is known to provide just those bytes. If so, we try to replace the
  // load + replace + store sequence with a single (narrower) store, which makes
  // the load dead.
  if (Opc == ISD::OR && EnableShrinkLoadReplaceStoreWithStore) {
    std::pair<unsigned, unsigned> MaskedLoad;
    MaskedLoad = CheckForMaskedLoad(Value.getOperand(0), Ptr, Chain);
    if (MaskedLoad.first)
      if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskedLoad,
                                                  Value.getOperand(1), ST,this))
        return NewST;

    // Or is commutative, so try swapping X and Y.
    MaskedLoad = CheckForMaskedLoad(Value.getOperand(1), Ptr, Chain);
    if (MaskedLoad.first)
      if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskedLoad,
                                                  Value.getOperand(0), ST,this))
        return NewST;
  }

  if (!EnableReduceLoadOpStoreWidth)
    return SDValue();

  if (Value.getOperand(1).getOpcode() != ISD::Constant)
    return SDValue();

  SDValue N0 = Value.getOperand(0);
  // The load must feed only this op and the store chain must come directly
  // from the load, so the load can be made dead after narrowing.
  if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
      Chain == SDValue(N0.getNode(), 1)) {
    LoadSDNode *LD = cast<LoadSDNode>(N0);
    if (LD->getBasePtr() != Ptr ||
        LD->getPointerInfo().getAddrSpace() !=
        ST->getPointerInfo().getAddrSpace())
      return SDValue();

    // Find the type to narrow it the load / op / store to.
    SDValue N1 = Value.getOperand(1);
    unsigned BitWidth = N1.getValueSizeInBits();
    APInt Imm = cast<ConstantSDNode>(N1)->getAPIntValue();
    // For AND, invert the constant so that in all three cases Imm has 1-bits
    // exactly where the stored value may differ from the loaded value.
    if (Opc == ISD::AND)
      Imm ^= APInt::getAllOnes(BitWidth);
    // No bit changes (dead op) or all bits change (full-width store needed).
    if (Imm == 0 || Imm.isAllOnes())
      return SDValue();
    unsigned ShAmt = Imm.countTrailingZeros();
    unsigned MSB = BitWidth - Imm.countLeadingZeros() - 1;
    unsigned NewBW = NextPowerOf2(MSB - ShAmt);
    EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), NewBW);
    // The narrowing should be profitable, the load/store operation should be
    // legal (or custom) and the store size should be equal to the NewVT width.
    while (NewBW < BitWidth &&
           (NewVT.getStoreSizeInBits() != NewBW ||
            !TLI.isOperationLegalOrCustom(Opc, NewVT) ||
            !TLI.isNarrowingProfitable(VT, NewVT))) {
      NewBW = NextPowerOf2(NewBW);
      NewVT = EVT::getIntegerVT(*DAG.getContext(), NewBW);
    }
    if (NewBW >= BitWidth)
      return SDValue();

    // If the lsb changed does not start at the type bitwidth boundary,
    // start at the previous one.
    if (ShAmt % NewBW)
      ShAmt = (((ShAmt + NewBW - 1) / NewBW) * NewBW) - NewBW;
    APInt Mask = APInt::getBitsSet(BitWidth, ShAmt,
                                   std::min(BitWidth, ShAmt + NewBW));
    // Only transform if every changed bit fits in the NewBW-wide window;
    // otherwise a narrow store could not cover all modified bits.
    if ((Imm & Mask) == Imm) {
      APInt NewImm = (Imm & Mask).lshr(ShAmt).trunc(NewBW);
      // Undo the earlier inversion so the narrowed AND uses the real mask.
      if (Opc == ISD::AND)
        NewImm ^= APInt::getAllOnes(NewBW);
      uint64_t PtrOff = ShAmt / 8;
      // For big endian targets, we need to adjust the offset to the pointer to
      // load the correct bytes.
      if (DAG.getDataLayout().isBigEndian())
        PtrOff = (BitWidth + 7 - NewBW) / 8 - PtrOff;

      // The narrowed access must still be allowed and fast at its new
      // alignment in this address space.
      bool IsFast = false;
      Align NewAlign = commonAlignment(LD->getAlign(), PtrOff);
      if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), NewVT,
                                  LD->getAddressSpace(), NewAlign,
                                  LD->getMemOperand()->getFlags(), &IsFast) ||
          !IsFast)
        return SDValue();

      // Build the narrow load / op / store sequence.
      SDValue NewPtr =
          DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(PtrOff), SDLoc(LD));
      SDValue NewLD =
          DAG.getLoad(NewVT, SDLoc(N0), LD->getChain(), NewPtr,
                      LD->getPointerInfo().getWithOffset(PtrOff), NewAlign,
                      LD->getMemOperand()->getFlags(), LD->getAAInfo());
      SDValue NewVal = DAG.getNode(Opc, SDLoc(Value), NewVT, NewLD,
                                   DAG.getConstant(NewImm, SDLoc(Value),
                                                   NewVT));
      SDValue NewST =
          DAG.getStore(Chain, SDLoc(N), NewVal, NewPtr,
                       ST->getPointerInfo().getWithOffset(PtrOff), NewAlign);

      AddToWorklist(NewPtr.getNode());
      AddToWorklist(NewLD.getNode());
      AddToWorklist(NewVal.getNode());
      WorklistRemover DeadNodes(*this);
      // Redirect chain users of the old load to the new load so the old
      // load becomes dead.
      DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), NewLD.getValue(1));
      ++OpsNarrowed;
      return NewST;
    }
  }

  return SDValue();
}
17725
17726 /// For a given floating point load / store pair, if the load value isn't used
17727 /// by any other operations, then consider transforming the pair to integer
17728 /// load / store operations if the target deems the transformation profitable.
TransformFPLoadStorePair(SDNode * N)17729 SDValue DAGCombiner::TransformFPLoadStorePair(SDNode *N) {
17730 StoreSDNode *ST = cast<StoreSDNode>(N);
17731 SDValue Value = ST->getValue();
17732 if (ISD::isNormalStore(ST) && ISD::isNormalLoad(Value.getNode()) &&
17733 Value.hasOneUse()) {
17734 LoadSDNode *LD = cast<LoadSDNode>(Value);
17735 EVT VT = LD->getMemoryVT();
17736 if (!VT.isFloatingPoint() ||
17737 VT != ST->getMemoryVT() ||
17738 LD->isNonTemporal() ||
17739 ST->isNonTemporal() ||
17740 LD->getPointerInfo().getAddrSpace() != 0 ||
17741 ST->getPointerInfo().getAddrSpace() != 0)
17742 return SDValue();
17743
17744 TypeSize VTSize = VT.getSizeInBits();
17745
17746 // We don't know the size of scalable types at compile time so we cannot
17747 // create an integer of the equivalent size.
17748 if (VTSize.isScalable())
17749 return SDValue();
17750
17751 bool FastLD = false, FastST = false;
17752 EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), VTSize.getFixedSize());
17753 if (!TLI.isOperationLegal(ISD::LOAD, IntVT) ||
17754 !TLI.isOperationLegal(ISD::STORE, IntVT) ||
17755 !TLI.isDesirableToTransformToIntegerOp(ISD::LOAD, VT) ||
17756 !TLI.isDesirableToTransformToIntegerOp(ISD::STORE, VT) ||
17757 !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), IntVT,
17758 *LD->getMemOperand(), &FastLD) ||
17759 !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), IntVT,
17760 *ST->getMemOperand(), &FastST) ||
17761 !FastLD || !FastST)
17762 return SDValue();
17763
17764 SDValue NewLD =
17765 DAG.getLoad(IntVT, SDLoc(Value), LD->getChain(), LD->getBasePtr(),
17766 LD->getPointerInfo(), LD->getAlign());
17767
17768 SDValue NewST =
17769 DAG.getStore(ST->getChain(), SDLoc(N), NewLD, ST->getBasePtr(),
17770 ST->getPointerInfo(), ST->getAlign());
17771
17772 AddToWorklist(NewLD.getNode());
17773 AddToWorklist(NewST.getNode());
17774 WorklistRemover DeadNodes(*this);
17775 DAG.ReplaceAllUsesOfValueWith(Value.getValue(1), NewLD.getValue(1));
17776 ++LdStFP2Int;
17777 return NewST;
17778 }
17779
17780 return SDValue();
17781 }
17782
17783 // This is a helper function for visitMUL to check the profitability
17784 // of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
17785 // MulNode is the original multiply, AddNode is (add x, c1),
17786 // and ConstNode is c2.
17787 //
17788 // If the (add x, c1) has multiple uses, we could increase
17789 // the number of adds if we make this transformation.
17790 // It would only be worth doing this if we can remove a
17791 // multiply in the process. Check for that here.
17792 // To illustrate:
17793 // (A + c1) * c3
17794 // (A + c2) * c3
17795 // We're checking for cases where we have common "c3 * A" expressions.
isMulAddWithConstProfitable(SDNode * MulNode,SDValue AddNode,SDValue ConstNode)17796 bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode, SDValue AddNode,
17797 SDValue ConstNode) {
17798 APInt Val;
17799
17800 // If the add only has one use, and the target thinks the folding is
17801 // profitable or does not lead to worse code, this would be OK to do.
17802 if (AddNode->hasOneUse() &&
17803 TLI.isMulAddWithConstProfitable(AddNode, ConstNode))
17804 return true;
17805
17806 // Walk all the users of the constant with which we're multiplying.
17807 for (SDNode *Use : ConstNode->uses()) {
17808 if (Use == MulNode) // This use is the one we're on right now. Skip it.
17809 continue;
17810
17811 if (Use->getOpcode() == ISD::MUL) { // We have another multiply use.
17812 SDNode *OtherOp;
17813 SDNode *MulVar = AddNode.getOperand(0).getNode();
17814
17815 // OtherOp is what we're multiplying against the constant.
17816 if (Use->getOperand(0) == ConstNode)
17817 OtherOp = Use->getOperand(1).getNode();
17818 else
17819 OtherOp = Use->getOperand(0).getNode();
17820
17821 // Check to see if multiply is with the same operand of our "add".
17822 //
17823 // ConstNode = CONST
17824 // Use = ConstNode * A <-- visiting Use. OtherOp is A.
17825 // ...
17826 // AddNode = (A + c1) <-- MulVar is A.
17827 // = AddNode * ConstNode <-- current visiting instruction.
17828 //
17829 // If we make this transformation, we will have a common
17830 // multiply (ConstNode * A) that we can save.
17831 if (OtherOp == MulVar)
17832 return true;
17833
17834 // Now check to see if a future expansion will give us a common
17835 // multiply.
17836 //
17837 // ConstNode = CONST
17838 // AddNode = (A + c1)
17839 // ... = AddNode * ConstNode <-- current visiting instruction.
17840 // ...
17841 // OtherOp = (A + c2)
17842 // Use = OtherOp * ConstNode <-- visiting Use.
17843 //
17844 // If we make this transformation, we will have a common
17845 // multiply (CONST * A) after we also do the same transformation
17846 // to the "t2" instruction.
17847 if (OtherOp->getOpcode() == ISD::ADD &&
17848 DAG.isConstantIntBuildVectorOrConstantInt(OtherOp->getOperand(1)) &&
17849 OtherOp->getOperand(0).getNode() == MulVar)
17850 return true;
17851 }
17852 }
17853
17854 // Didn't find a case where this would be profitable.
17855 return false;
17856 }
17857
getMergeStoreChains(SmallVectorImpl<MemOpLink> & StoreNodes,unsigned NumStores)17858 SDValue DAGCombiner::getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
17859 unsigned NumStores) {
17860 SmallVector<SDValue, 8> Chains;
17861 SmallPtrSet<const SDNode *, 8> Visited;
17862 SDLoc StoreDL(StoreNodes[0].MemNode);
17863
17864 for (unsigned i = 0; i < NumStores; ++i) {
17865 Visited.insert(StoreNodes[i].MemNode);
17866 }
17867
17868 // don't include nodes that are children or repeated nodes.
17869 for (unsigned i = 0; i < NumStores; ++i) {
17870 if (Visited.insert(StoreNodes[i].MemNode->getChain().getNode()).second)
17871 Chains.push_back(StoreNodes[i].MemNode->getChain());
17872 }
17873
17874 assert(Chains.size() > 0 && "Chain should have generated a chain");
17875 return DAG.getTokenFactor(StoreDL, Chains);
17876 }
17877
/// Merge the first NumStores entries of StoreNodes (all with memory type
/// MemVT) into a single wide store. The merged value is built either as a
/// vector (UseVector) from constant or extracted-element sources, or as one
/// wide integer constant; UseTrunc emits the result as a truncating store.
/// Returns true if the merge was performed.
bool DAGCombiner::mergeStoresOfConstantsOrVecElts(
    SmallVectorImpl<MemOpLink> &StoreNodes, EVT MemVT, unsigned NumStores,
    bool IsConstantSrc, bool UseVector, bool UseTrunc) {
  // Make sure we have something to merge.
  if (NumStores < 2)
    return false;

  assert((!UseTrunc || !UseVector) &&
         "This optimization cannot emit a vector truncating store");

  // The latest Node in the DAG.
  SDLoc DL(StoreNodes[0].MemNode);

  TypeSize ElementSizeBits = MemVT.getStoreSizeInBits();
  unsigned SizeInBits = NumStores * ElementSizeBits;
  unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;

  // All stores must share identical memory-operand flags; AA metadata is
  // merged across the set.
  Optional<MachineMemOperand::Flags> Flags;
  AAMDNodes AAInfo;
  for (unsigned I = 0; I != NumStores; ++I) {
    StoreSDNode *St = cast<StoreSDNode>(StoreNodes[I].MemNode);
    if (!Flags) {
      Flags = St->getMemOperand()->getFlags();
      AAInfo = St->getAAInfo();
      continue;
    }
    // Skip merging if there's an inconsistent flag.
    if (Flags != St->getMemOperand()->getFlags())
      return false;
    // Concatenate AA metadata.
    AAInfo = AAInfo.concat(St->getAAInfo());
  }

  // Type of the wide value to store: a vector of all the elements, or one
  // wide integer covering the whole range.
  EVT StoreTy;
  if (UseVector) {
    unsigned Elts = NumStores * NumMemElts;
    // Get the type for the merged vector store.
    StoreTy = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
  } else
    StoreTy = EVT::getIntegerVT(*DAG.getContext(), SizeInBits);

  SDValue StoredVal;
  if (UseVector) {
    if (IsConstantSrc) {
      // Build a vector from the stored constants, normalizing each element
      // to type MemVT first.
      SmallVector<SDValue, 8> BuildVector;
      for (unsigned I = 0; I != NumStores; ++I) {
        StoreSDNode *St = cast<StoreSDNode>(StoreNodes[I].MemNode);
        SDValue Val = St->getValue();
        // If constant is of the wrong type, convert it now.
        if (MemVT != Val.getValueType()) {
          Val = peekThroughBitcasts(Val);
          // Deal with constants of wrong size.
          if (ElementSizeBits != Val.getValueSizeInBits()) {
            EVT IntMemVT =
                EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits());
            if (isa<ConstantFPSDNode>(Val)) {
              // Not clear how to truncate FP values.
              return false;
            }

            if (auto *C = dyn_cast<ConstantSDNode>(Val))
              Val = DAG.getConstant(C->getAPIntValue()
                                        .zextOrTrunc(Val.getValueSizeInBits())
                                        .zextOrTrunc(ElementSizeBits),
                                    SDLoc(C), IntMemVT);
          }
          // Make sure the correctly-sized value also has the correct type.
          Val = DAG.getBitcast(MemVT, Val);
        }
        BuildVector.push_back(Val);
      }
      StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS
                                               : ISD::BUILD_VECTOR,
                              DL, StoreTy, BuildVector);
    } else {
      SmallVector<SDValue, 8> Ops;
      for (unsigned i = 0; i < NumStores; ++i) {
        StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
        SDValue Val = peekThroughBitcasts(St->getValue());
        // All operands of BUILD_VECTOR / CONCAT_VECTOR must be of
        // type MemVT. If the underlying value is not the correct
        // type, but it is an extraction of an appropriate vector we
        // can recast Val to be of the correct type. This may require
        // converting between EXTRACT_VECTOR_ELT and
        // EXTRACT_SUBVECTOR.
        if ((MemVT != Val.getValueType()) &&
            (Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
             Val.getOpcode() == ISD::EXTRACT_SUBVECTOR)) {
          EVT MemVTScalarTy = MemVT.getScalarType();
          // We may need to add a bitcast here to get types to line up.
          if (MemVTScalarTy != Val.getValueType().getScalarType()) {
            Val = DAG.getBitcast(MemVT, Val);
          } else {
            unsigned OpC = MemVT.isVector() ? ISD::EXTRACT_SUBVECTOR
                                            : ISD::EXTRACT_VECTOR_ELT;
            SDValue Vec = Val.getOperand(0);
            SDValue Idx = Val.getOperand(1);
            Val = DAG.getNode(OpC, SDLoc(Val), MemVT, Vec, Idx);
          }
        }
        Ops.push_back(Val);
      }

      // Build the extracted vector elements back into a vector.
      StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS
                                               : ISD::BUILD_VECTOR,
                              DL, StoreTy, Ops);
    }
  } else {
    // We should always use a vector store when merging extracted vector
    // elements, so this path implies a store of constants.
    assert(IsConstantSrc && "Merged vector elements should use vector store");

    APInt StoreInt(SizeInBits, 0);

    // Construct a single integer constant which is made of the smaller
    // constant inputs.
    bool IsLE = DAG.getDataLayout().isLittleEndian();
    for (unsigned i = 0; i < NumStores; ++i) {
      // On little-endian, the lowest-addressed store must end up in the
      // low bits, so the elements are folded in reverse order.
      unsigned Idx = IsLE ? (NumStores - 1 - i) : i;
      StoreSDNode *St = cast<StoreSDNode>(StoreNodes[Idx].MemNode);

      SDValue Val = St->getValue();
      Val = peekThroughBitcasts(Val);
      StoreInt <<= ElementSizeBits;
      if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val)) {
        StoreInt |= C->getAPIntValue()
                        .zextOrTrunc(ElementSizeBits)
                        .zextOrTrunc(SizeInBits);
      } else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Val)) {
        StoreInt |= C->getValueAPF()
                        .bitcastToAPInt()
                        .zextOrTrunc(ElementSizeBits)
                        .zextOrTrunc(SizeInBits);
        // If fp truncation is necessary give up for now.
        if (MemVT.getSizeInBits() != ElementSizeBits)
          return false;
      } else {
        llvm_unreachable("Invalid constant element type");
      }
    }

    // Create the new Load and Store operations.
    StoredVal = DAG.getConstant(StoreInt, DL, StoreTy);
  }

  LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
  SDValue NewChain = getMergeStoreChains(StoreNodes, NumStores);

  // make sure we use trunc store if it's necessary to be legal.
  SDValue NewStore;
  if (!UseTrunc) {
    NewStore = DAG.getStore(NewChain, DL, StoredVal, FirstInChain->getBasePtr(),
                            FirstInChain->getPointerInfo(),
                            FirstInChain->getAlign(), *Flags, AAInfo);
  } else { // Must be realized as a trunc store
    EVT LegalizedStoredValTy =
        TLI.getTypeToTransformTo(*DAG.getContext(), StoredVal.getValueType());
    unsigned LegalizedStoreSize = LegalizedStoredValTy.getSizeInBits();
    ConstantSDNode *C = cast<ConstantSDNode>(StoredVal);
    SDValue ExtendedStoreVal =
        DAG.getConstant(C->getAPIntValue().zextOrTrunc(LegalizedStoreSize), DL,
                        LegalizedStoredValTy);
    NewStore = DAG.getTruncStore(
        NewChain, DL, ExtendedStoreVal, FirstInChain->getBasePtr(),
        FirstInChain->getPointerInfo(), StoredVal.getValueType() /*TVT*/,
        FirstInChain->getAlign(), *Flags, AAInfo);
  }

  // Replace all merged stores with the new store.
  for (unsigned i = 0; i < NumStores; ++i)
    CombineTo(StoreNodes[i].MemNode, NewStore);

  AddToWorklist(NewChain.getNode());
  return true;
}
18054
/// Populate StoreNodes with every store reachable from a common chain root
/// that is a potential merge partner for St (same source kind, compatible
/// type, and an address that differs from St's only by a constant offset).
/// RootNode is set to the chain ancestor used for the search.
void DAGCombiner::getStoreMergeCandidates(
    StoreSDNode *St, SmallVectorImpl<MemOpLink> &StoreNodes,
    SDNode *&RootNode) {
  // This holds the base pointer, index, and the offset in bytes from the base
  // pointer. We must have a base and an offset. Do not handle stores to undef
  // base pointers.
  BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
  if (!BasePtr.getBase().getNode() || BasePtr.getBase().isUndef())
    return;

  SDValue Val = peekThroughBitcasts(St->getValue());
  StoreSource StoreSrc = getStoreSource(Val);
  assert(StoreSrc != StoreSource::Unknown && "Expected known source for store");

  // Match on loadbaseptr if relevant.
  EVT MemVT = St->getMemoryVT();
  BaseIndexOffset LBasePtr;
  EVT LoadVT;
  if (StoreSrc == StoreSource::Load) {
    auto *Ld = cast<LoadSDNode>(Val);
    LBasePtr = BaseIndexOffset::match(Ld, DAG);
    LoadVT = Ld->getMemoryVT();
    // Load and store should be the same type.
    if (MemVT != LoadVT)
      return;
    // Loads must only have one use.
    if (!Ld->hasNUsesOfValue(1, 0))
      return;
    // The memory operands must not be volatile/indexed/atomic.
    // TODO: May be able to relax for unordered atomics (see D66309)
    if (!Ld->isSimple() || Ld->isIndexed())
      return;
  }
  // Returns true if Other can be merged with St; on success Ptr/Offset hold
  // Other's address decomposition relative to BasePtr.
  auto CandidateMatch = [&](StoreSDNode *Other, BaseIndexOffset &Ptr,
                            int64_t &Offset) -> bool {
    // The memory operands must not be volatile/indexed/atomic.
    // TODO: May be able to relax for unordered atomics (see D66309)
    if (!Other->isSimple() || Other->isIndexed())
      return false;
    // Don't mix temporal stores with non-temporal stores.
    if (St->isNonTemporal() != Other->isNonTemporal())
      return false;
    SDValue OtherBC = peekThroughBitcasts(Other->getValue());
    // Allow merging constants of different types as integers.
    bool NoTypeMatch = (MemVT.isInteger()) ? !MemVT.bitsEq(Other->getMemoryVT())
                                           : Other->getMemoryVT() != MemVT;
    switch (StoreSrc) {
    case StoreSource::Load: {
      if (NoTypeMatch)
        return false;
      // The Load's Base Ptr must also match.
      auto *OtherLd = dyn_cast<LoadSDNode>(OtherBC);
      if (!OtherLd)
        return false;
      BaseIndexOffset LPtr = BaseIndexOffset::match(OtherLd, DAG);
      if (LoadVT != OtherLd->getMemoryVT())
        return false;
      // Loads must only have one use.
      if (!OtherLd->hasNUsesOfValue(1, 0))
        return false;
      // The memory operands must not be volatile/indexed/atomic.
      // TODO: May be able to relax for unordered atomics (see D66309)
      if (!OtherLd->isSimple() || OtherLd->isIndexed())
        return false;
      // Don't mix temporal loads with non-temporal loads.
      if (cast<LoadSDNode>(Val)->isNonTemporal() != OtherLd->isNonTemporal())
        return false;
      if (!(LBasePtr.equalBaseIndex(LPtr, DAG)))
        return false;
      break;
    }
    case StoreSource::Constant:
      if (NoTypeMatch)
        return false;
      if (!isIntOrFPConstant(OtherBC))
        return false;
      break;
    case StoreSource::Extract:
      // Do not merge truncated stores here.
      if (Other->isTruncatingStore())
        return false;
      if (!MemVT.bitsEq(OtherBC.getValueType()))
        return false;
      if (OtherBC.getOpcode() != ISD::EXTRACT_VECTOR_ELT &&
          OtherBC.getOpcode() != ISD::EXTRACT_SUBVECTOR)
        return false;
      break;
    default:
      llvm_unreachable("Unhandled store source for merging");
    }
    Ptr = BaseIndexOffset::match(Other, DAG);
    return (BasePtr.equalBaseIndex(Ptr, DAG, Offset));
  };

  // Check if the pair of StoreNode and the RootNode already bail out many
  // times which is over the limit in dependence check.
  auto OverLimitInDependenceCheck = [&](SDNode *StoreNode,
                                        SDNode *RootNode) -> bool {
    auto RootCount = StoreRootCountMap.find(StoreNode);
    return RootCount != StoreRootCountMap.end() &&
           RootCount->second.first == RootNode &&
           RootCount->second.second > StoreMergeDependenceLimit;
  };

  // Append the use to StoreNodes if it is a matching store candidate.
  auto TryToAddCandidate = [&](SDNode::use_iterator UseIter) {
    // This must be a chain use.
    if (UseIter.getOperandNo() != 0)
      return;
    if (auto *OtherStore = dyn_cast<StoreSDNode>(*UseIter)) {
      BaseIndexOffset Ptr;
      int64_t PtrDiff;
      if (CandidateMatch(OtherStore, Ptr, PtrDiff) &&
          !OverLimitInDependenceCheck(OtherStore, RootNode))
        StoreNodes.push_back(MemOpLink(OtherStore, PtrDiff));
    }
  };

  // We are looking for a root node which is an ancestor to all mergeable
  // stores. We search up through a load, to our root and then down
  // through all children. For instance we will find Store{1,2,3} if
  // St is Store1, Store2, or Store3 where the root is not a load
  // which is always true for nonvolatile ops. TODO: Expand
  // the search to find all valid candidates through multiple layers of loads.
  //
  // Root
  // |-------|-------|
  // Load    Load    Store3
  // |       |
  // Store1   Store2
  //
  // FIXME: We should be able to climb and
  // descend TokenFactors to find candidates as well.

  RootNode = St->getChain().getNode();

  unsigned NumNodesExplored = 0;
  const unsigned MaxSearchNodes = 1024;
  if (auto *Ldn = dyn_cast<LoadSDNode>(RootNode)) {
    RootNode = Ldn->getChain().getNode();
    for (auto I = RootNode->use_begin(), E = RootNode->use_end();
         I != E && NumNodesExplored < MaxSearchNodes; ++I, ++NumNodesExplored) {
      if (I.getOperandNo() == 0 && isa<LoadSDNode>(*I)) { // walk down chain
        for (auto I2 = (*I)->use_begin(), E2 = (*I)->use_end(); I2 != E2; ++I2)
          TryToAddCandidate(I2);
      }
      // Check stores that depend on the root (e.g. Store 3 in the chart above).
      if (I.getOperandNo() == 0 && isa<StoreSDNode>(*I)) {
        TryToAddCandidate(I);
      }
    }
  } else {
    for (auto I = RootNode->use_begin(), E = RootNode->use_end();
         I != E && NumNodesExplored < MaxSearchNodes; ++I, ++NumNodesExplored)
      TryToAddCandidate(I);
  }
}
18211
// We need to check that merging these stores does not cause a loop in the
// DAG. Any store candidate may depend on another candidate indirectly through
// its operands. Check in parallel by searching up from operands of candidates.
bool DAGCombiner::checkMergeStoreCandidatesForDependencies(
    SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
    SDNode *RootNode) {
  // FIXME: We should be able to truncate a full search of
  // predecessors by doing a BFS and keeping tabs the originating
  // stores from which worklist nodes come from in a similar way to
  // TokenFactor simplification.

  SmallPtrSet<const SDNode *, 32> Visited;
  SmallVector<const SDNode *, 8> Worklist;

  // RootNode is a predecessor to all candidates so we need not search
  // past it. Add RootNode (peeking through TokenFactors). Do not count
  // these towards size check.

  Worklist.push_back(RootNode);
  while (!Worklist.empty()) {
    auto N = Worklist.pop_back_val();
    if (!Visited.insert(N).second)
      continue; // Already present in Visited.
    // Flatten TokenFactors so all chain inputs of the root act as pruning
    // points for the predecessor search below.
    if (N->getOpcode() == ISD::TokenFactor) {
      for (SDValue Op : N->ops())
        Worklist.push_back(Op.getNode());
    }
  }

  // Don't count pruning nodes towards max.
  unsigned int Max = 1024 + Visited.size();
  // Search Ops of store candidates.
  for (unsigned i = 0; i < NumStores; ++i) {
    SDNode *N = StoreNodes[i].MemNode;
    // Of the 4 Store Operands:
    //   * Chain (Op 0) -> We have already considered these
    //                     in candidate selection, but only by following the
    //                     chain dependencies. We could still have a chain
    //                     dependency to a load, that has a non-chain dep to
    //                     another load, that depends on a store, etc. So it is
    //                     possible to have dependencies that consist of a mix
    //                     of chain and non-chain deps, and we need to include
    //                     chain operands in the analysis here..
    //   * Value (Op 1) -> Cycles may happen (e.g. through load chains)
    //   * Address (Op 2) -> Merged addresses may only vary by a fixed constant,
    //                       but aren't necessarily from the same base node, so
    //                       cycles possible (e.g. via indexed store).
    //   * (Op 3) -> Represents the pre or post-indexing offset (or undef for
    //               non-indexed stores). Not constant on all targets (e.g. ARM)
    //               and so can participate in a cycle.
    for (unsigned j = 0; j < N->getNumOperands(); ++j)
      Worklist.push_back(N->getOperand(j).getNode());
  }
  // Search through DAG. We can stop early if we find a store node.
  for (unsigned i = 0; i < NumStores; ++i)
    if (SDNode::hasPredecessorHelper(StoreNodes[i].MemNode, Visited, Worklist,
                                     Max)) {
      // If the search bails out, record the StoreNode and RootNode in the
      // StoreRootCountMap. If we have seen the pair many times over a limit,
      // we won't add the StoreNode into StoreNodes set again.
      if (Visited.size() >= Max) {
        auto &RootCount = StoreRootCountMap[StoreNodes[i].MemNode];
        if (RootCount.first == RootNode)
          RootCount.second++;
        else
          RootCount = {RootNode, 1};
      }
      return false;
    }
  return true;
}
18283
18284 unsigned
getConsecutiveStores(SmallVectorImpl<MemOpLink> & StoreNodes,int64_t ElementSizeBytes) const18285 DAGCombiner::getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
18286 int64_t ElementSizeBytes) const {
18287 while (true) {
18288 // Find a store past the width of the first store.
18289 size_t StartIdx = 0;
18290 while ((StartIdx + 1 < StoreNodes.size()) &&
18291 StoreNodes[StartIdx].OffsetFromBase + ElementSizeBytes !=
18292 StoreNodes[StartIdx + 1].OffsetFromBase)
18293 ++StartIdx;
18294
18295 // Bail if we don't have enough candidates to merge.
18296 if (StartIdx + 1 >= StoreNodes.size())
18297 return 0;
18298
18299 // Trim stores that overlapped with the first store.
18300 if (StartIdx)
18301 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + StartIdx);
18302
18303 // Scan the memory operations on the chain and find the first
18304 // non-consecutive store memory address.
18305 unsigned NumConsecutiveStores = 1;
18306 int64_t StartAddress = StoreNodes[0].OffsetFromBase;
18307 // Check that the addresses are consecutive starting from the second
18308 // element in the list of stores.
18309 for (unsigned i = 1, e = StoreNodes.size(); i < e; ++i) {
18310 int64_t CurrAddress = StoreNodes[i].OffsetFromBase;
18311 if (CurrAddress - StartAddress != (ElementSizeBytes * i))
18312 break;
18313 NumConsecutiveStores = i + 1;
18314 }
18315 if (NumConsecutiveStores > 1)
18316 return NumConsecutiveStores;
18317
18318 // There are no consecutive stores at the start of the list.
18319 // Remove the first store and try again.
18320 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 1);
18321 }
18322 }
18323
tryStoreMergeOfConstants(SmallVectorImpl<MemOpLink> & StoreNodes,unsigned NumConsecutiveStores,EVT MemVT,SDNode * RootNode,bool AllowVectors)18324 bool DAGCombiner::tryStoreMergeOfConstants(
18325 SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores,
18326 EVT MemVT, SDNode *RootNode, bool AllowVectors) {
18327 LLVMContext &Context = *DAG.getContext();
18328 const DataLayout &DL = DAG.getDataLayout();
18329 int64_t ElementSizeBytes = MemVT.getStoreSize();
18330 unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
18331 bool MadeChange = false;
18332
18333 // Store the constants into memory as one consecutive store.
18334 while (NumConsecutiveStores >= 2) {
18335 LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
18336 unsigned FirstStoreAS = FirstInChain->getAddressSpace();
18337 Align FirstStoreAlign = FirstInChain->getAlign();
18338 unsigned LastLegalType = 1;
18339 unsigned LastLegalVectorType = 1;
18340 bool LastIntegerTrunc = false;
18341 bool NonZero = false;
18342 unsigned FirstZeroAfterNonZero = NumConsecutiveStores;
18343 for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
18344 StoreSDNode *ST = cast<StoreSDNode>(StoreNodes[i].MemNode);
18345 SDValue StoredVal = ST->getValue();
18346 bool IsElementZero = false;
18347 if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(StoredVal))
18348 IsElementZero = C->isZero();
18349 else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(StoredVal))
18350 IsElementZero = C->getConstantFPValue()->isNullValue();
18351 if (IsElementZero) {
18352 if (NonZero && FirstZeroAfterNonZero == NumConsecutiveStores)
18353 FirstZeroAfterNonZero = i;
18354 }
18355 NonZero |= !IsElementZero;
18356
18357 // Find a legal type for the constant store.
18358 unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
18359 EVT StoreTy = EVT::getIntegerVT(Context, SizeInBits);
18360 bool IsFast = false;
18361
18362 // Break early when size is too large to be legal.
18363 if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
18364 break;
18365
18366 if (TLI.isTypeLegal(StoreTy) &&
18367 TLI.canMergeStoresTo(FirstStoreAS, StoreTy,
18368 DAG.getMachineFunction()) &&
18369 TLI.allowsMemoryAccess(Context, DL, StoreTy,
18370 *FirstInChain->getMemOperand(), &IsFast) &&
18371 IsFast) {
18372 LastIntegerTrunc = false;
18373 LastLegalType = i + 1;
18374 // Or check whether a truncstore is legal.
18375 } else if (TLI.getTypeAction(Context, StoreTy) ==
18376 TargetLowering::TypePromoteInteger) {
18377 EVT LegalizedStoredValTy =
18378 TLI.getTypeToTransformTo(Context, StoredVal.getValueType());
18379 if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) &&
18380 TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy,
18381 DAG.getMachineFunction()) &&
18382 TLI.allowsMemoryAccess(Context, DL, StoreTy,
18383 *FirstInChain->getMemOperand(), &IsFast) &&
18384 IsFast) {
18385 LastIntegerTrunc = true;
18386 LastLegalType = i + 1;
18387 }
18388 }
18389
18390 // We only use vectors if the constant is known to be zero or the
18391 // target allows it and the function is not marked with the
18392 // noimplicitfloat attribute.
18393 if ((!NonZero ||
18394 TLI.storeOfVectorConstantIsCheap(MemVT, i + 1, FirstStoreAS)) &&
18395 AllowVectors) {
18396 // Find a legal type for the vector store.
18397 unsigned Elts = (i + 1) * NumMemElts;
18398 EVT Ty = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
18399 if (TLI.isTypeLegal(Ty) && TLI.isTypeLegal(MemVT) &&
18400 TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG.getMachineFunction()) &&
18401 TLI.allowsMemoryAccess(Context, DL, Ty,
18402 *FirstInChain->getMemOperand(), &IsFast) &&
18403 IsFast)
18404 LastLegalVectorType = i + 1;
18405 }
18406 }
18407
18408 bool UseVector = (LastLegalVectorType > LastLegalType) && AllowVectors;
18409 unsigned NumElem = (UseVector) ? LastLegalVectorType : LastLegalType;
18410 bool UseTrunc = LastIntegerTrunc && !UseVector;
18411
18412 // Check if we found a legal integer type that creates a meaningful
18413 // merge.
18414 if (NumElem < 2) {
18415 // We know that candidate stores are in order and of correct
18416 // shape. While there is no mergeable sequence from the
18417 // beginning one may start later in the sequence. The only
18418 // reason a merge of size N could have failed where another of
18419 // the same size would not have, is if the alignment has
18420 // improved or we've dropped a non-zero value. Drop as many
18421 // candidates as we can here.
18422 unsigned NumSkip = 1;
18423 while ((NumSkip < NumConsecutiveStores) &&
18424 (NumSkip < FirstZeroAfterNonZero) &&
18425 (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
18426 NumSkip++;
18427
18428 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
18429 NumConsecutiveStores -= NumSkip;
18430 continue;
18431 }
18432
18433 // Check that we can merge these candidates without causing a cycle.
18434 if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem,
18435 RootNode)) {
18436 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
18437 NumConsecutiveStores -= NumElem;
18438 continue;
18439 }
18440
18441 MadeChange |= mergeStoresOfConstantsOrVecElts(StoreNodes, MemVT, NumElem,
18442 /*IsConstantSrc*/ true,
18443 UseVector, UseTrunc);
18444
18445 // Remove merged stores for next iteration.
18446 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
18447 NumConsecutiveStores -= NumElem;
18448 }
18449 return MadeChange;
18450 }
18451
tryStoreMergeOfExtracts(SmallVectorImpl<MemOpLink> & StoreNodes,unsigned NumConsecutiveStores,EVT MemVT,SDNode * RootNode)18452 bool DAGCombiner::tryStoreMergeOfExtracts(
18453 SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores,
18454 EVT MemVT, SDNode *RootNode) {
18455 LLVMContext &Context = *DAG.getContext();
18456 const DataLayout &DL = DAG.getDataLayout();
18457 unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
18458 bool MadeChange = false;
18459
18460 // Loop on Consecutive Stores on success.
18461 while (NumConsecutiveStores >= 2) {
18462 LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
18463 unsigned FirstStoreAS = FirstInChain->getAddressSpace();
18464 Align FirstStoreAlign = FirstInChain->getAlign();
18465 unsigned NumStoresToMerge = 1;
18466 for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
18467 // Find a legal type for the vector store.
18468 unsigned Elts = (i + 1) * NumMemElts;
18469 EVT Ty = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
18470 bool IsFast = false;
18471
18472 // Break early when size is too large to be legal.
18473 if (Ty.getSizeInBits() > MaximumLegalStoreInBits)
18474 break;
18475
18476 if (TLI.isTypeLegal(Ty) &&
18477 TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG.getMachineFunction()) &&
18478 TLI.allowsMemoryAccess(Context, DL, Ty,
18479 *FirstInChain->getMemOperand(), &IsFast) &&
18480 IsFast)
18481 NumStoresToMerge = i + 1;
18482 }
18483
18484 // Check if we found a legal integer type creating a meaningful
18485 // merge.
18486 if (NumStoresToMerge < 2) {
18487 // We know that candidate stores are in order and of correct
18488 // shape. While there is no mergeable sequence from the
18489 // beginning one may start later in the sequence. The only
18490 // reason a merge of size N could have failed where another of
18491 // the same size would not have, is if the alignment has
18492 // improved. Drop as many candidates as we can here.
18493 unsigned NumSkip = 1;
18494 while ((NumSkip < NumConsecutiveStores) &&
18495 (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
18496 NumSkip++;
18497
18498 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
18499 NumConsecutiveStores -= NumSkip;
18500 continue;
18501 }
18502
18503 // Check that we can merge these candidates without causing a cycle.
18504 if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumStoresToMerge,
18505 RootNode)) {
18506 StoreNodes.erase(StoreNodes.begin(),
18507 StoreNodes.begin() + NumStoresToMerge);
18508 NumConsecutiveStores -= NumStoresToMerge;
18509 continue;
18510 }
18511
18512 MadeChange |= mergeStoresOfConstantsOrVecElts(
18513 StoreNodes, MemVT, NumStoresToMerge, /*IsConstantSrc*/ false,
18514 /*UseVector*/ true, /*UseTrunc*/ false);
18515
18516 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumStoresToMerge);
18517 NumConsecutiveStores -= NumStoresToMerge;
18518 }
18519 return MadeChange;
18520 }
18521
tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> & StoreNodes,unsigned NumConsecutiveStores,EVT MemVT,SDNode * RootNode,bool AllowVectors,bool IsNonTemporalStore,bool IsNonTemporalLoad)18522 bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
18523 unsigned NumConsecutiveStores, EVT MemVT,
18524 SDNode *RootNode, bool AllowVectors,
18525 bool IsNonTemporalStore,
18526 bool IsNonTemporalLoad) {
18527 LLVMContext &Context = *DAG.getContext();
18528 const DataLayout &DL = DAG.getDataLayout();
18529 int64_t ElementSizeBytes = MemVT.getStoreSize();
18530 unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
18531 bool MadeChange = false;
18532
18533 // Look for load nodes which are used by the stored values.
18534 SmallVector<MemOpLink, 8> LoadNodes;
18535
18536 // Find acceptable loads. Loads need to have the same chain (token factor),
18537 // must not be zext, volatile, indexed, and they must be consecutive.
18538 BaseIndexOffset LdBasePtr;
18539
18540 for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
18541 StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
18542 SDValue Val = peekThroughBitcasts(St->getValue());
18543 LoadSDNode *Ld = cast<LoadSDNode>(Val);
18544
18545 BaseIndexOffset LdPtr = BaseIndexOffset::match(Ld, DAG);
18546 // If this is not the first ptr that we check.
18547 int64_t LdOffset = 0;
18548 if (LdBasePtr.getBase().getNode()) {
18549 // The base ptr must be the same.
18550 if (!LdBasePtr.equalBaseIndex(LdPtr, DAG, LdOffset))
18551 break;
18552 } else {
18553 // Check that all other base pointers are the same as this one.
18554 LdBasePtr = LdPtr;
18555 }
18556
18557 // We found a potential memory operand to merge.
18558 LoadNodes.push_back(MemOpLink(Ld, LdOffset));
18559 }
18560
18561 while (NumConsecutiveStores >= 2 && LoadNodes.size() >= 2) {
18562 Align RequiredAlignment;
18563 bool NeedRotate = false;
18564 if (LoadNodes.size() == 2) {
18565 // If we have load/store pair instructions and we only have two values,
18566 // don't bother merging.
18567 if (TLI.hasPairedLoad(MemVT, RequiredAlignment) &&
18568 StoreNodes[0].MemNode->getAlign() >= RequiredAlignment) {
18569 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 2);
18570 LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + 2);
18571 break;
18572 }
18573 // If the loads are reversed, see if we can rotate the halves into place.
18574 int64_t Offset0 = LoadNodes[0].OffsetFromBase;
18575 int64_t Offset1 = LoadNodes[1].OffsetFromBase;
18576 EVT PairVT = EVT::getIntegerVT(Context, ElementSizeBytes * 8 * 2);
18577 if (Offset0 - Offset1 == ElementSizeBytes &&
18578 (hasOperation(ISD::ROTL, PairVT) ||
18579 hasOperation(ISD::ROTR, PairVT))) {
18580 std::swap(LoadNodes[0], LoadNodes[1]);
18581 NeedRotate = true;
18582 }
18583 }
18584 LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
18585 unsigned FirstStoreAS = FirstInChain->getAddressSpace();
18586 Align FirstStoreAlign = FirstInChain->getAlign();
18587 LoadSDNode *FirstLoad = cast<LoadSDNode>(LoadNodes[0].MemNode);
18588
18589 // Scan the memory operations on the chain and find the first
18590 // non-consecutive load memory address. These variables hold the index in
18591 // the store node array.
18592
18593 unsigned LastConsecutiveLoad = 1;
18594
18595 // This variable refers to the size and not index in the array.
18596 unsigned LastLegalVectorType = 1;
18597 unsigned LastLegalIntegerType = 1;
18598 bool isDereferenceable = true;
18599 bool DoIntegerTruncate = false;
18600 int64_t StartAddress = LoadNodes[0].OffsetFromBase;
18601 SDValue LoadChain = FirstLoad->getChain();
18602 for (unsigned i = 1; i < LoadNodes.size(); ++i) {
18603 // All loads must share the same chain.
18604 if (LoadNodes[i].MemNode->getChain() != LoadChain)
18605 break;
18606
18607 int64_t CurrAddress = LoadNodes[i].OffsetFromBase;
18608 if (CurrAddress - StartAddress != (ElementSizeBytes * i))
18609 break;
18610 LastConsecutiveLoad = i;
18611
18612 if (isDereferenceable && !LoadNodes[i].MemNode->isDereferenceable())
18613 isDereferenceable = false;
18614
18615 // Find a legal type for the vector store.
18616 unsigned Elts = (i + 1) * NumMemElts;
18617 EVT StoreTy = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
18618
18619 // Break early when size is too large to be legal.
18620 if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
18621 break;
18622
18623 bool IsFastSt = false;
18624 bool IsFastLd = false;
18625 // Don't try vector types if we need a rotate. We may still fail the
18626 // legality checks for the integer type, but we can't handle the rotate
18627 // case with vectors.
18628 // FIXME: We could use a shuffle in place of the rotate.
18629 if (!NeedRotate && TLI.isTypeLegal(StoreTy) &&
18630 TLI.canMergeStoresTo(FirstStoreAS, StoreTy,
18631 DAG.getMachineFunction()) &&
18632 TLI.allowsMemoryAccess(Context, DL, StoreTy,
18633 *FirstInChain->getMemOperand(), &IsFastSt) &&
18634 IsFastSt &&
18635 TLI.allowsMemoryAccess(Context, DL, StoreTy,
18636 *FirstLoad->getMemOperand(), &IsFastLd) &&
18637 IsFastLd) {
18638 LastLegalVectorType = i + 1;
18639 }
18640
18641 // Find a legal type for the integer store.
18642 unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
18643 StoreTy = EVT::getIntegerVT(Context, SizeInBits);
18644 if (TLI.isTypeLegal(StoreTy) &&
18645 TLI.canMergeStoresTo(FirstStoreAS, StoreTy,
18646 DAG.getMachineFunction()) &&
18647 TLI.allowsMemoryAccess(Context, DL, StoreTy,
18648 *FirstInChain->getMemOperand(), &IsFastSt) &&
18649 IsFastSt &&
18650 TLI.allowsMemoryAccess(Context, DL, StoreTy,
18651 *FirstLoad->getMemOperand(), &IsFastLd) &&
18652 IsFastLd) {
18653 LastLegalIntegerType = i + 1;
18654 DoIntegerTruncate = false;
18655 // Or check whether a truncstore and extload is legal.
18656 } else if (TLI.getTypeAction(Context, StoreTy) ==
18657 TargetLowering::TypePromoteInteger) {
18658 EVT LegalizedStoredValTy = TLI.getTypeToTransformTo(Context, StoreTy);
18659 if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) &&
18660 TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy,
18661 DAG.getMachineFunction()) &&
18662 TLI.isLoadExtLegal(ISD::ZEXTLOAD, LegalizedStoredValTy, StoreTy) &&
18663 TLI.isLoadExtLegal(ISD::SEXTLOAD, LegalizedStoredValTy, StoreTy) &&
18664 TLI.isLoadExtLegal(ISD::EXTLOAD, LegalizedStoredValTy, StoreTy) &&
18665 TLI.allowsMemoryAccess(Context, DL, StoreTy,
18666 *FirstInChain->getMemOperand(), &IsFastSt) &&
18667 IsFastSt &&
18668 TLI.allowsMemoryAccess(Context, DL, StoreTy,
18669 *FirstLoad->getMemOperand(), &IsFastLd) &&
18670 IsFastLd) {
18671 LastLegalIntegerType = i + 1;
18672 DoIntegerTruncate = true;
18673 }
18674 }
18675 }
18676
18677 // Only use vector types if the vector type is larger than the integer
18678 // type. If they are the same, use integers.
18679 bool UseVectorTy =
18680 LastLegalVectorType > LastLegalIntegerType && AllowVectors;
18681 unsigned LastLegalType =
18682 std::max(LastLegalVectorType, LastLegalIntegerType);
18683
18684 // We add +1 here because the LastXXX variables refer to location while
18685 // the NumElem refers to array/index size.
18686 unsigned NumElem = std::min(NumConsecutiveStores, LastConsecutiveLoad + 1);
18687 NumElem = std::min(LastLegalType, NumElem);
18688 Align FirstLoadAlign = FirstLoad->getAlign();
18689
18690 if (NumElem < 2) {
18691 // We know that candidate stores are in order and of correct
18692 // shape. While there is no mergeable sequence from the
18693 // beginning one may start later in the sequence. The only
18694 // reason a merge of size N could have failed where another of
18695 // the same size would not have is if the alignment or either
18696 // the load or store has improved. Drop as many candidates as we
18697 // can here.
18698 unsigned NumSkip = 1;
18699 while ((NumSkip < LoadNodes.size()) &&
18700 (LoadNodes[NumSkip].MemNode->getAlign() <= FirstLoadAlign) &&
18701 (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
18702 NumSkip++;
18703 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
18704 LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumSkip);
18705 NumConsecutiveStores -= NumSkip;
18706 continue;
18707 }
18708
18709 // Check that we can merge these candidates without causing a cycle.
18710 if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem,
18711 RootNode)) {
18712 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
18713 LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
18714 NumConsecutiveStores -= NumElem;
18715 continue;
18716 }
18717
18718 // Find if it is better to use vectors or integers to load and store
18719 // to memory.
18720 EVT JointMemOpVT;
18721 if (UseVectorTy) {
18722 // Find a legal type for the vector store.
18723 unsigned Elts = NumElem * NumMemElts;
18724 JointMemOpVT = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
18725 } else {
18726 unsigned SizeInBits = NumElem * ElementSizeBytes * 8;
18727 JointMemOpVT = EVT::getIntegerVT(Context, SizeInBits);
18728 }
18729
18730 SDLoc LoadDL(LoadNodes[0].MemNode);
18731 SDLoc StoreDL(StoreNodes[0].MemNode);
18732
18733 // The merged loads are required to have the same incoming chain, so
18734 // using the first's chain is acceptable.
18735
18736 SDValue NewStoreChain = getMergeStoreChains(StoreNodes, NumElem);
18737 AddToWorklist(NewStoreChain.getNode());
18738
18739 MachineMemOperand::Flags LdMMOFlags =
18740 isDereferenceable ? MachineMemOperand::MODereferenceable
18741 : MachineMemOperand::MONone;
18742 if (IsNonTemporalLoad)
18743 LdMMOFlags |= MachineMemOperand::MONonTemporal;
18744
18745 MachineMemOperand::Flags StMMOFlags = IsNonTemporalStore
18746 ? MachineMemOperand::MONonTemporal
18747 : MachineMemOperand::MONone;
18748
18749 SDValue NewLoad, NewStore;
18750 if (UseVectorTy || !DoIntegerTruncate) {
18751 NewLoad = DAG.getLoad(
18752 JointMemOpVT, LoadDL, FirstLoad->getChain(), FirstLoad->getBasePtr(),
18753 FirstLoad->getPointerInfo(), FirstLoadAlign, LdMMOFlags);
18754 SDValue StoreOp = NewLoad;
18755 if (NeedRotate) {
18756 unsigned LoadWidth = ElementSizeBytes * 8 * 2;
18757 assert(JointMemOpVT == EVT::getIntegerVT(Context, LoadWidth) &&
18758 "Unexpected type for rotate-able load pair");
18759 SDValue RotAmt =
18760 DAG.getShiftAmountConstant(LoadWidth / 2, JointMemOpVT, LoadDL);
18761 // Target can convert to the identical ROTR if it does not have ROTL.
18762 StoreOp = DAG.getNode(ISD::ROTL, LoadDL, JointMemOpVT, NewLoad, RotAmt);
18763 }
18764 NewStore = DAG.getStore(
18765 NewStoreChain, StoreDL, StoreOp, FirstInChain->getBasePtr(),
18766 FirstInChain->getPointerInfo(), FirstStoreAlign, StMMOFlags);
18767 } else { // This must be the truncstore/extload case
18768 EVT ExtendedTy =
18769 TLI.getTypeToTransformTo(*DAG.getContext(), JointMemOpVT);
18770 NewLoad = DAG.getExtLoad(ISD::EXTLOAD, LoadDL, ExtendedTy,
18771 FirstLoad->getChain(), FirstLoad->getBasePtr(),
18772 FirstLoad->getPointerInfo(), JointMemOpVT,
18773 FirstLoadAlign, LdMMOFlags);
18774 NewStore = DAG.getTruncStore(
18775 NewStoreChain, StoreDL, NewLoad, FirstInChain->getBasePtr(),
18776 FirstInChain->getPointerInfo(), JointMemOpVT,
18777 FirstInChain->getAlign(), FirstInChain->getMemOperand()->getFlags());
18778 }
18779
18780 // Transfer chain users from old loads to the new load.
18781 for (unsigned i = 0; i < NumElem; ++i) {
18782 LoadSDNode *Ld = cast<LoadSDNode>(LoadNodes[i].MemNode);
18783 DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1),
18784 SDValue(NewLoad.getNode(), 1));
18785 }
18786
18787 // Replace all stores with the new store. Recursively remove corresponding
18788 // values if they are no longer used.
18789 for (unsigned i = 0; i < NumElem; ++i) {
18790 SDValue Val = StoreNodes[i].MemNode->getOperand(1);
18791 CombineTo(StoreNodes[i].MemNode, NewStore);
18792 if (Val->use_empty())
18793 recursivelyDeleteUnusedNodes(Val.getNode());
18794 }
18795
18796 MadeChange = true;
18797 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
18798 LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
18799 NumConsecutiveStores -= NumElem;
18800 }
18801 return MadeChange;
18802 }
18803
mergeConsecutiveStores(StoreSDNode * St)18804 bool DAGCombiner::mergeConsecutiveStores(StoreSDNode *St) {
18805 if (OptLevel == CodeGenOpt::None || !EnableStoreMerging)
18806 return false;
18807
18808 // TODO: Extend this function to merge stores of scalable vectors.
18809 // (i.e. two <vscale x 8 x i8> stores can be merged to one <vscale x 16 x i8>
18810 // store since we know <vscale x 16 x i8> is exactly twice as large as
18811 // <vscale x 8 x i8>). Until then, bail out for scalable vectors.
18812 EVT MemVT = St->getMemoryVT();
18813 if (MemVT.isScalableVector())
18814 return false;
18815 if (!MemVT.isSimple() || MemVT.getSizeInBits() * 2 > MaximumLegalStoreInBits)
18816 return false;
18817
18818 // This function cannot currently deal with non-byte-sized memory sizes.
18819 int64_t ElementSizeBytes = MemVT.getStoreSize();
18820 if (ElementSizeBytes * 8 != (int64_t)MemVT.getSizeInBits())
18821 return false;
18822
18823 // Do not bother looking at stored values that are not constants, loads, or
18824 // extracted vector elements.
18825 SDValue StoredVal = peekThroughBitcasts(St->getValue());
18826 const StoreSource StoreSrc = getStoreSource(StoredVal);
18827 if (StoreSrc == StoreSource::Unknown)
18828 return false;
18829
18830 SmallVector<MemOpLink, 8> StoreNodes;
18831 SDNode *RootNode;
18832 // Find potential store merge candidates by searching through chain sub-DAG
18833 getStoreMergeCandidates(St, StoreNodes, RootNode);
18834
18835 // Check if there is anything to merge.
18836 if (StoreNodes.size() < 2)
18837 return false;
18838
18839 // Sort the memory operands according to their distance from the
18840 // base pointer.
18841 llvm::sort(StoreNodes, [](MemOpLink LHS, MemOpLink RHS) {
18842 return LHS.OffsetFromBase < RHS.OffsetFromBase;
18843 });
18844
18845 bool AllowVectors = !DAG.getMachineFunction().getFunction().hasFnAttribute(
18846 Attribute::NoImplicitFloat);
18847 bool IsNonTemporalStore = St->isNonTemporal();
18848 bool IsNonTemporalLoad = StoreSrc == StoreSource::Load &&
18849 cast<LoadSDNode>(StoredVal)->isNonTemporal();
18850
18851 // Store Merge attempts to merge the lowest stores. This generally
18852 // works out as if successful, as the remaining stores are checked
18853 // after the first collection of stores is merged. However, in the
18854 // case that a non-mergeable store is found first, e.g., {p[-2],
18855 // p[0], p[1], p[2], p[3]}, we would fail and miss the subsequent
18856 // mergeable cases. To prevent this, we prune such stores from the
18857 // front of StoreNodes here.
18858 bool MadeChange = false;
18859 while (StoreNodes.size() > 1) {
18860 unsigned NumConsecutiveStores =
18861 getConsecutiveStores(StoreNodes, ElementSizeBytes);
18862 // There are no more stores in the list to examine.
18863 if (NumConsecutiveStores == 0)
18864 return MadeChange;
18865
18866 // We have at least 2 consecutive stores. Try to merge them.
18867 assert(NumConsecutiveStores >= 2 && "Expected at least 2 stores");
18868 switch (StoreSrc) {
18869 case StoreSource::Constant:
18870 MadeChange |= tryStoreMergeOfConstants(StoreNodes, NumConsecutiveStores,
18871 MemVT, RootNode, AllowVectors);
18872 break;
18873
18874 case StoreSource::Extract:
18875 MadeChange |= tryStoreMergeOfExtracts(StoreNodes, NumConsecutiveStores,
18876 MemVT, RootNode);
18877 break;
18878
18879 case StoreSource::Load:
18880 MadeChange |= tryStoreMergeOfLoads(StoreNodes, NumConsecutiveStores,
18881 MemVT, RootNode, AllowVectors,
18882 IsNonTemporalStore, IsNonTemporalLoad);
18883 break;
18884
18885 default:
18886 llvm_unreachable("Unhandled store source type");
18887 }
18888 }
18889 return MadeChange;
18890 }
18891
replaceStoreChain(StoreSDNode * ST,SDValue BetterChain)18892 SDValue DAGCombiner::replaceStoreChain(StoreSDNode *ST, SDValue BetterChain) {
18893 SDLoc SL(ST);
18894 SDValue ReplStore;
18895
18896 // Replace the chain to avoid dependency.
18897 if (ST->isTruncatingStore()) {
18898 ReplStore = DAG.getTruncStore(BetterChain, SL, ST->getValue(),
18899 ST->getBasePtr(), ST->getMemoryVT(),
18900 ST->getMemOperand());
18901 } else {
18902 ReplStore = DAG.getStore(BetterChain, SL, ST->getValue(), ST->getBasePtr(),
18903 ST->getMemOperand());
18904 }
18905
18906 // Create token to keep both nodes around.
18907 SDValue Token = DAG.getNode(ISD::TokenFactor, SL,
18908 MVT::Other, ST->getChain(), ReplStore);
18909
18910 // Make sure the new and old chains are cleaned up.
18911 AddToWorklist(Token.getNode());
18912
18913 // Don't add users to work list.
18914 return CombineTo(ST, Token, false);
18915 }
18916
replaceStoreOfFPConstant(StoreSDNode * ST)18917 SDValue DAGCombiner::replaceStoreOfFPConstant(StoreSDNode *ST) {
18918 SDValue Value = ST->getValue();
18919 if (Value.getOpcode() == ISD::TargetConstantFP)
18920 return SDValue();
18921
18922 if (!ISD::isNormalStore(ST))
18923 return SDValue();
18924
18925 SDLoc DL(ST);
18926
18927 SDValue Chain = ST->getChain();
18928 SDValue Ptr = ST->getBasePtr();
18929
18930 const ConstantFPSDNode *CFP = cast<ConstantFPSDNode>(Value);
18931
18932 // NOTE: If the original store is volatile, this transform must not increase
18933 // the number of stores. For example, on x86-32 an f64 can be stored in one
18934 // processor operation but an i64 (which is not legal) requires two. So the
18935 // transform should not be done in this case.
18936
18937 SDValue Tmp;
18938 switch (CFP->getSimpleValueType(0).SimpleTy) {
18939 default:
18940 llvm_unreachable("Unknown FP type");
18941 case MVT::f16: // We don't do this for these yet.
18942 case MVT::bf16:
18943 case MVT::f80:
18944 case MVT::f128:
18945 case MVT::ppcf128:
18946 return SDValue();
18947 case MVT::f32:
18948 if ((isTypeLegal(MVT::i32) && !LegalOperations && ST->isSimple()) ||
18949 TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) {
18950 Tmp = DAG.getConstant((uint32_t)CFP->getValueAPF().
18951 bitcastToAPInt().getZExtValue(), SDLoc(CFP),
18952 MVT::i32);
18953 return DAG.getStore(Chain, DL, Tmp, Ptr, ST->getMemOperand());
18954 }
18955
18956 return SDValue();
18957 case MVT::f64:
18958 if ((TLI.isTypeLegal(MVT::i64) && !LegalOperations &&
18959 ST->isSimple()) ||
18960 TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i64)) {
18961 Tmp = DAG.getConstant(CFP->getValueAPF().bitcastToAPInt().
18962 getZExtValue(), SDLoc(CFP), MVT::i64);
18963 return DAG.getStore(Chain, DL, Tmp,
18964 Ptr, ST->getMemOperand());
18965 }
18966
18967 if (ST->isSimple() &&
18968 TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) {
18969 // Many FP stores are not made apparent until after legalize, e.g. for
18970 // argument passing. Since this is so common, custom legalize the
18971 // 64-bit integer store into two 32-bit stores.
18972 uint64_t Val = CFP->getValueAPF().bitcastToAPInt().getZExtValue();
18973 SDValue Lo = DAG.getConstant(Val & 0xFFFFFFFF, SDLoc(CFP), MVT::i32);
18974 SDValue Hi = DAG.getConstant(Val >> 32, SDLoc(CFP), MVT::i32);
18975 if (DAG.getDataLayout().isBigEndian())
18976 std::swap(Lo, Hi);
18977
18978 MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
18979 AAMDNodes AAInfo = ST->getAAInfo();
18980
18981 SDValue St0 = DAG.getStore(Chain, DL, Lo, Ptr, ST->getPointerInfo(),
18982 ST->getOriginalAlign(), MMOFlags, AAInfo);
18983 Ptr = DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(4), DL);
18984 SDValue St1 = DAG.getStore(Chain, DL, Hi, Ptr,
18985 ST->getPointerInfo().getWithOffset(4),
18986 ST->getOriginalAlign(), MMOFlags, AAInfo);
18987 return DAG.getNode(ISD::TokenFactor, DL, MVT::Other,
18988 St0, St1);
18989 }
18990
18991 return SDValue();
18992 }
18993 }
18994
visitSTORE(SDNode * N)18995 SDValue DAGCombiner::visitSTORE(SDNode *N) {
18996 StoreSDNode *ST = cast<StoreSDNode>(N);
18997 SDValue Chain = ST->getChain();
18998 SDValue Value = ST->getValue();
18999 SDValue Ptr = ST->getBasePtr();
19000
19001 // If this is a store of a bit convert, store the input value if the
19002 // resultant store does not need a higher alignment than the original.
19003 if (Value.getOpcode() == ISD::BITCAST && !ST->isTruncatingStore() &&
19004 ST->isUnindexed()) {
19005 EVT SVT = Value.getOperand(0).getValueType();
19006 // If the store is volatile, we only want to change the store type if the
19007 // resulting store is legal. Otherwise we might increase the number of
19008 // memory accesses. We don't care if the original type was legal or not
19009 // as we assume software couldn't rely on the number of accesses of an
19010 // illegal type.
19011 // TODO: May be able to relax for unordered atomics (see D66309)
19012 if (((!LegalOperations && ST->isSimple()) ||
19013 TLI.isOperationLegal(ISD::STORE, SVT)) &&
19014 TLI.isStoreBitCastBeneficial(Value.getValueType(), SVT,
19015 DAG, *ST->getMemOperand())) {
19016 return DAG.getStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
19017 ST->getMemOperand());
19018 }
19019 }
19020
19021 // Turn 'store undef, Ptr' -> nothing.
19022 if (Value.isUndef() && ST->isUnindexed())
19023 return Chain;
19024
19025 // Try to infer better alignment information than the store already has.
19026 if (OptLevel != CodeGenOpt::None && ST->isUnindexed() && !ST->isAtomic()) {
19027 if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) {
19028 if (*Alignment > ST->getAlign() &&
19029 isAligned(*Alignment, ST->getSrcValueOffset())) {
19030 SDValue NewStore =
19031 DAG.getTruncStore(Chain, SDLoc(N), Value, Ptr, ST->getPointerInfo(),
19032 ST->getMemoryVT(), *Alignment,
19033 ST->getMemOperand()->getFlags(), ST->getAAInfo());
19034 // NewStore will always be N as we are only refining the alignment
19035 assert(NewStore.getNode() == N);
19036 (void)NewStore;
19037 }
19038 }
19039 }
19040
19041 // Try transforming a pair floating point load / store ops to integer
19042 // load / store ops.
19043 if (SDValue NewST = TransformFPLoadStorePair(N))
19044 return NewST;
19045
19046 // Try transforming several stores into STORE (BSWAP).
19047 if (SDValue Store = mergeTruncStores(ST))
19048 return Store;
19049
19050 if (ST->isUnindexed()) {
19051 // Walk up chain skipping non-aliasing memory nodes, on this store and any
19052 // adjacent stores.
19053 if (findBetterNeighborChains(ST)) {
19054 // replaceStoreChain uses CombineTo, which handled all of the worklist
19055 // manipulation. Return the original node to not do anything else.
19056 return SDValue(ST, 0);
19057 }
19058 Chain = ST->getChain();
19059 }
19060
19061 // FIXME: is there such a thing as a truncating indexed store?
19062 if (ST->isTruncatingStore() && ST->isUnindexed() &&
19063 Value.getValueType().isInteger() &&
19064 (!isa<ConstantSDNode>(Value) ||
19065 !cast<ConstantSDNode>(Value)->isOpaque())) {
19066 // Convert a truncating store of a extension into a standard store.
19067 if ((Value.getOpcode() == ISD::ZERO_EXTEND ||
19068 Value.getOpcode() == ISD::SIGN_EXTEND ||
19069 Value.getOpcode() == ISD::ANY_EXTEND) &&
19070 Value.getOperand(0).getValueType() == ST->getMemoryVT() &&
19071 TLI.isOperationLegalOrCustom(ISD::STORE, ST->getMemoryVT()))
19072 return DAG.getStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
19073 ST->getMemOperand());
19074
19075 APInt TruncDemandedBits =
19076 APInt::getLowBitsSet(Value.getScalarValueSizeInBits(),
19077 ST->getMemoryVT().getScalarSizeInBits());
19078
19079 // See if we can simplify the input to this truncstore with knowledge that
19080 // only the low bits are being used. For example:
19081 // "truncstore (or (shl x, 8), y), i8" -> "truncstore y, i8"
19082 AddToWorklist(Value.getNode());
19083 if (SDValue Shorter = DAG.GetDemandedBits(Value, TruncDemandedBits))
19084 return DAG.getTruncStore(Chain, SDLoc(N), Shorter, Ptr, ST->getMemoryVT(),
19085 ST->getMemOperand());
19086
19087 // Otherwise, see if we can simplify the operation with
19088 // SimplifyDemandedBits, which only works if the value has a single use.
19089 if (SimplifyDemandedBits(Value, TruncDemandedBits)) {
19090 // Re-visit the store if anything changed and the store hasn't been merged
19091 // with another node (N is deleted) SimplifyDemandedBits will add Value's
19092 // node back to the worklist if necessary, but we also need to re-visit
19093 // the Store node itself.
19094 if (N->getOpcode() != ISD::DELETED_NODE)
19095 AddToWorklist(N);
19096 return SDValue(N, 0);
19097 }
19098 }
19099
19100 // If this is a load followed by a store to the same location, then the store
19101 // is dead/noop.
19102 // TODO: Can relax for unordered atomics (see D66309)
19103 if (LoadSDNode *Ld = dyn_cast<LoadSDNode>(Value)) {
19104 if (Ld->getBasePtr() == Ptr && ST->getMemoryVT() == Ld->getMemoryVT() &&
19105 ST->isUnindexed() && ST->isSimple() &&
19106 Ld->getAddressSpace() == ST->getAddressSpace() &&
19107 // There can't be any side effects between the load and store, such as
19108 // a call or store.
19109 Chain.reachesChainWithoutSideEffects(SDValue(Ld, 1))) {
19110 // The store is dead, remove it.
19111 return Chain;
19112 }
19113 }
19114
19115 // TODO: Can relax for unordered atomics (see D66309)
19116 if (StoreSDNode *ST1 = dyn_cast<StoreSDNode>(Chain)) {
19117 if (ST->isUnindexed() && ST->isSimple() &&
19118 ST1->isUnindexed() && ST1->isSimple()) {
19119 if (OptLevel != CodeGenOpt::None && ST1->getBasePtr() == Ptr &&
19120 ST1->getValue() == Value && ST->getMemoryVT() == ST1->getMemoryVT() &&
19121 ST->getAddressSpace() == ST1->getAddressSpace()) {
19122 // If this is a store followed by a store with the same value to the
19123 // same location, then the store is dead/noop.
19124 return Chain;
19125 }
19126
19127 if (OptLevel != CodeGenOpt::None && ST1->hasOneUse() &&
19128 !ST1->getBasePtr().isUndef() &&
19129 // BaseIndexOffset and the code below requires knowing the size
19130 // of a vector, so bail out if MemoryVT is scalable.
19131 !ST->getMemoryVT().isScalableVector() &&
19132 !ST1->getMemoryVT().isScalableVector() &&
19133 ST->getAddressSpace() == ST1->getAddressSpace()) {
19134 const BaseIndexOffset STBase = BaseIndexOffset::match(ST, DAG);
19135 const BaseIndexOffset ChainBase = BaseIndexOffset::match(ST1, DAG);
19136 unsigned STBitSize = ST->getMemoryVT().getFixedSizeInBits();
19137 unsigned ChainBitSize = ST1->getMemoryVT().getFixedSizeInBits();
        // If the preceding store writes to a subset of this store's location
        // and no other node is chained to it, then the preceding store is
        // redundant and can be dropped. Do not remove stores to undef as they
        // may be used as data sinks.
19142 if (STBase.contains(DAG, STBitSize, ChainBase, ChainBitSize)) {
19143 CombineTo(ST1, ST1->getChain());
19144 return SDValue();
19145 }
19146 }
19147 }
19148 }
19149
19150 // If this is an FP_ROUND or TRUNC followed by a store, fold this into a
19151 // truncating store. We can do this even if this is already a truncstore.
19152 if ((Value.getOpcode() == ISD::FP_ROUND ||
19153 Value.getOpcode() == ISD::TRUNCATE) &&
19154 Value->hasOneUse() && ST->isUnindexed() &&
19155 TLI.canCombineTruncStore(Value.getOperand(0).getValueType(),
19156 ST->getMemoryVT(), LegalOperations)) {
19157 return DAG.getTruncStore(Chain, SDLoc(N), Value.getOperand(0),
19158 Ptr, ST->getMemoryVT(), ST->getMemOperand());
19159 }
19160
19161 // Always perform this optimization before types are legal. If the target
19162 // prefers, also try this after legalization to catch stores that were created
19163 // by intrinsics or other nodes.
19164 if (!LegalTypes || (TLI.mergeStoresAfterLegalization(ST->getMemoryVT()))) {
19165 while (true) {
19166 // There can be multiple store sequences on the same chain.
19167 // Keep trying to merge store sequences until we are unable to do so
19168 // or until we merge the last store on the chain.
19169 bool Changed = mergeConsecutiveStores(ST);
19170 if (!Changed) break;
19171 // Return N as merge only uses CombineTo and no worklist clean
19172 // up is necessary.
19173 if (N->getOpcode() == ISD::DELETED_NODE || !isa<StoreSDNode>(N))
19174 return SDValue(N, 0);
19175 }
19176 }
19177
19178 // Try transforming N to an indexed store.
19179 if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
19180 return SDValue(N, 0);
19181
19182 // Turn 'store float 1.0, Ptr' -> 'store int 0x12345678, Ptr'
19183 //
19184 // Make sure to do this only after attempting to merge stores in order to
19185 // avoid changing the types of some subset of stores due to visit order,
19186 // preventing their merging.
19187 if (isa<ConstantFPSDNode>(ST->getValue())) {
19188 if (SDValue NewSt = replaceStoreOfFPConstant(ST))
19189 return NewSt;
19190 }
19191
19192 if (SDValue NewSt = splitMergedValStore(ST))
19193 return NewSt;
19194
19195 return ReduceLoadOpStoreWidth(N);
19196 }
19197
visitLIFETIME_END(SDNode * N)19198 SDValue DAGCombiner::visitLIFETIME_END(SDNode *N) {
19199 const auto *LifetimeEnd = cast<LifetimeSDNode>(N);
19200 if (!LifetimeEnd->hasOffset())
19201 return SDValue();
19202
19203 const BaseIndexOffset LifetimeEndBase(N->getOperand(1), SDValue(),
19204 LifetimeEnd->getOffset(), false);
19205
19206 // We walk up the chains to find stores.
19207 SmallVector<SDValue, 8> Chains = {N->getOperand(0)};
19208 while (!Chains.empty()) {
19209 SDValue Chain = Chains.pop_back_val();
19210 if (!Chain.hasOneUse())
19211 continue;
19212 switch (Chain.getOpcode()) {
19213 case ISD::TokenFactor:
19214 for (unsigned Nops = Chain.getNumOperands(); Nops;)
19215 Chains.push_back(Chain.getOperand(--Nops));
19216 break;
19217 case ISD::LIFETIME_START:
19218 case ISD::LIFETIME_END:
19219 // We can forward past any lifetime start/end that can be proven not to
19220 // alias the node.
19221 if (!mayAlias(Chain.getNode(), N))
19222 Chains.push_back(Chain.getOperand(0));
19223 break;
19224 case ISD::STORE: {
19225 StoreSDNode *ST = dyn_cast<StoreSDNode>(Chain);
19226 // TODO: Can relax for unordered atomics (see D66309)
19227 if (!ST->isSimple() || ST->isIndexed())
19228 continue;
19229 const TypeSize StoreSize = ST->getMemoryVT().getStoreSize();
19230 // The bounds of a scalable store are not known until runtime, so this
19231 // store cannot be elided.
19232 if (StoreSize.isScalable())
19233 continue;
19234 const BaseIndexOffset StoreBase = BaseIndexOffset::match(ST, DAG);
19235 // If we store purely within object bounds just before its lifetime ends,
19236 // we can remove the store.
19237 if (LifetimeEndBase.contains(DAG, LifetimeEnd->getSize() * 8, StoreBase,
19238 StoreSize.getFixedSize() * 8)) {
19239 LLVM_DEBUG(dbgs() << "\nRemoving store:"; StoreBase.dump();
19240 dbgs() << "\nwithin LIFETIME_END of : ";
19241 LifetimeEndBase.dump(); dbgs() << "\n");
19242 CombineTo(ST, ST->getChain());
19243 return SDValue(N, 0);
19244 }
19245 }
19246 }
19247 }
19248 return SDValue();
19249 }
19250
19251 /// For the instruction sequence of store below, F and I values
19252 /// are bundled together as an i64 value before being stored into memory.
/// Sometimes it is more efficient to generate separate stores for F and I,
19254 /// which can remove the bitwise instructions or sink them to colder places.
19255 ///
19256 /// (store (or (zext (bitcast F to i32) to i64),
19257 /// (shl (zext I to i64), 32)), addr) -->
19258 /// (store F, addr) and (store I, addr+4)
19259 ///
19260 /// Similarly, splitting for other merged store can also be beneficial, like:
19261 /// For pair of {i32, i32}, i64 store --> two i32 stores.
19262 /// For pair of {i32, i16}, i64 store --> two i32 stores.
19263 /// For pair of {i16, i16}, i32 store --> two i16 stores.
19264 /// For pair of {i16, i8}, i32 store --> two i16 stores.
19265 /// For pair of {i8, i8}, i16 store --> two i8 stores.
19266 ///
19267 /// We allow each target to determine specifically which kind of splitting is
19268 /// supported.
19269 ///
19270 /// The store patterns are commonly seen from the simple code snippet below
/// if only std::make_pair(...) is SROA-transformed before being inlined into hoo().
19272 /// void goo(const std::pair<int, float> &);
19273 /// hoo() {
19274 /// ...
19275 /// goo(std::make_pair(tmp, ftmp));
19276 /// ...
19277 /// }
19278 ///
SDValue DAGCombiner::splitMergedValStore(StoreSDNode *ST) {
  // Splitting only pays off when optimizing; skip at -O0.
  if (OptLevel == CodeGenOpt::None)
    return SDValue();

  // Can't change the number of memory accesses for a volatile store or break
  // atomicity for an atomic one.
  if (!ST->isSimple())
    return SDValue();

  SDValue Val = ST->getValue();
  SDLoc DL(ST);

  // Match OR operand.
  if (!Val.getValueType().isScalarInteger() || Val.getOpcode() != ISD::OR)
    return SDValue();

  // Match SHL operand and get Lower and Higher parts of Val.
  SDValue Op1 = Val.getOperand(0);
  SDValue Op2 = Val.getOperand(1);
  SDValue Lo, Hi;
  // OR is commutative, so the SHL may be on either side; try both orders.
  if (Op1.getOpcode() != ISD::SHL) {
    std::swap(Op1, Op2);
    if (Op1.getOpcode() != ISD::SHL)
      return SDValue();
  }
  // The un-shifted operand is the low half; the shifted value is the high
  // half. The SHL must have no other users or it still has to be computed.
  Lo = Op2;
  Hi = Op1.getOperand(0);
  if (!Op1.hasOneUse())
    return SDValue();

  // Match shift amount to HalfValBitSize.
  unsigned HalfValBitSize = Val.getValueSizeInBits() / 2;
  ConstantSDNode *ShAmt = dyn_cast<ConstantSDNode>(Op1.getOperand(1));
  if (!ShAmt || ShAmt->getAPIntValue() != HalfValBitSize)
    return SDValue();

  // Lo and Hi are zero-extended from int with size less equal than 32
  // to i64.
  if (Lo.getOpcode() != ISD::ZERO_EXTEND || !Lo.hasOneUse() ||
      !Lo.getOperand(0).getValueType().isScalarInteger() ||
      Lo.getOperand(0).getValueSizeInBits() > HalfValBitSize ||
      Hi.getOpcode() != ISD::ZERO_EXTEND || !Hi.hasOneUse() ||
      !Hi.getOperand(0).getValueType().isScalarInteger() ||
      Hi.getOperand(0).getValueSizeInBits() > HalfValBitSize)
    return SDValue();

  // Use the EVT of low and high parts before bitcast as the input
  // of target query.
  EVT LowTy = (Lo.getOperand(0).getOpcode() == ISD::BITCAST)
                  ? Lo.getOperand(0).getValueType()
                  : Lo.getValueType();
  EVT HighTy = (Hi.getOperand(0).getOpcode() == ISD::BITCAST)
                   ? Hi.getOperand(0).getValueType()
                   : Hi.getValueType();
  if (!TLI.isMultiStoresCheaperThanBitsMerge(LowTy, HighTy))
    return SDValue();

  // Start to split store.
  MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
  AAMDNodes AAInfo = ST->getAAInfo();

  // Change the sizes of Lo and Hi's value types to HalfValBitSize.
  EVT VT = EVT::getIntegerVT(*DAG.getContext(), HalfValBitSize);
  Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Lo.getOperand(0));
  Hi = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Hi.getOperand(0));

  SDValue Chain = ST->getChain();
  SDValue Ptr = ST->getBasePtr();
  // Lower value store.
  // NOTE(review): the low half is stored at the base address and the high
  // half at base + HalfValBitSize/8, which appears to assume little-endian
  // layout — presumably only little-endian targets opt in via
  // isMultiStoresCheaperThanBitsMerge; confirm.
  SDValue St0 = DAG.getStore(Chain, DL, Lo, Ptr, ST->getPointerInfo(),
                             ST->getOriginalAlign(), MMOFlags, AAInfo);
  Ptr = DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(HalfValBitSize / 8), DL);
  // Higher value store.
  SDValue St1 = DAG.getStore(
      St0, DL, Hi, Ptr, ST->getPointerInfo().getWithOffset(HalfValBitSize / 8),
      ST->getOriginalAlign(), MMOFlags, AAInfo);
  // Chain the two stores (St1 depends on St0) and return the last one so the
  // caller replaces the merged store with this sequence.
  return St1;
}
19357
19358 /// Convert a disguised subvector insertion into a shuffle:
combineInsertEltToShuffle(SDNode * N,unsigned InsIndex)19359 SDValue DAGCombiner::combineInsertEltToShuffle(SDNode *N, unsigned InsIndex) {
19360 assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT &&
19361 "Expected extract_vector_elt");
19362 SDValue InsertVal = N->getOperand(1);
19363 SDValue Vec = N->getOperand(0);
19364
19365 // (insert_vector_elt (vector_shuffle X, Y), (extract_vector_elt X, N),
19366 // InsIndex)
19367 // --> (vector_shuffle X, Y) and variations where shuffle operands may be
19368 // CONCAT_VECTORS.
19369 if (Vec.getOpcode() == ISD::VECTOR_SHUFFLE && Vec.hasOneUse() &&
19370 InsertVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
19371 isa<ConstantSDNode>(InsertVal.getOperand(1))) {
19372 ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Vec.getNode());
19373 ArrayRef<int> Mask = SVN->getMask();
19374
19375 SDValue X = Vec.getOperand(0);
19376 SDValue Y = Vec.getOperand(1);
19377
19378 // Vec's operand 0 is using indices from 0 to N-1 and
19379 // operand 1 from N to 2N - 1, where N is the number of
19380 // elements in the vectors.
19381 SDValue InsertVal0 = InsertVal.getOperand(0);
19382 int ElementOffset = -1;
19383
19384 // We explore the inputs of the shuffle in order to see if we find the
19385 // source of the extract_vector_elt. If so, we can use it to modify the
19386 // shuffle rather than perform an insert_vector_elt.
19387 SmallVector<std::pair<int, SDValue>, 8> ArgWorkList;
19388 ArgWorkList.emplace_back(Mask.size(), Y);
19389 ArgWorkList.emplace_back(0, X);
19390
19391 while (!ArgWorkList.empty()) {
19392 int ArgOffset;
19393 SDValue ArgVal;
19394 std::tie(ArgOffset, ArgVal) = ArgWorkList.pop_back_val();
19395
19396 if (ArgVal == InsertVal0) {
19397 ElementOffset = ArgOffset;
19398 break;
19399 }
19400
19401 // Peek through concat_vector.
19402 if (ArgVal.getOpcode() == ISD::CONCAT_VECTORS) {
19403 int CurrentArgOffset =
19404 ArgOffset + ArgVal.getValueType().getVectorNumElements();
19405 int Step = ArgVal.getOperand(0).getValueType().getVectorNumElements();
19406 for (SDValue Op : reverse(ArgVal->ops())) {
19407 CurrentArgOffset -= Step;
19408 ArgWorkList.emplace_back(CurrentArgOffset, Op);
19409 }
19410
19411 // Make sure we went through all the elements and did not screw up index
19412 // computation.
19413 assert(CurrentArgOffset == ArgOffset);
19414 }
19415 }
19416
19417 // If we failed to find a match, see if we can replace an UNDEF shuffle
19418 // operand.
19419 if (ElementOffset == -1 && Y.isUndef() &&
19420 InsertVal0.getValueType() == Y.getValueType()) {
19421 ElementOffset = Mask.size();
19422 Y = InsertVal0;
19423 }
19424
19425 if (ElementOffset != -1) {
19426 SmallVector<int, 16> NewMask(Mask.begin(), Mask.end());
19427
19428 auto *ExtrIndex = cast<ConstantSDNode>(InsertVal.getOperand(1));
19429 NewMask[InsIndex] = ElementOffset + ExtrIndex->getZExtValue();
19430 assert(NewMask[InsIndex] <
19431 (int)(2 * Vec.getValueType().getVectorNumElements()) &&
19432 NewMask[InsIndex] >= 0 && "NewMask[InsIndex] is out of bound");
19433
19434 SDValue LegalShuffle =
19435 TLI.buildLegalVectorShuffle(Vec.getValueType(), SDLoc(N), X,
19436 Y, NewMask, DAG);
19437 if (LegalShuffle)
19438 return LegalShuffle;
19439 }
19440 }
19441
19442 // insert_vector_elt V, (bitcast X from vector type), IdxC -->
19443 // bitcast(shuffle (bitcast V), (extended X), Mask)
19444 // Note: We do not use an insert_subvector node because that requires a
19445 // legal subvector type.
19446 if (InsertVal.getOpcode() != ISD::BITCAST || !InsertVal.hasOneUse() ||
19447 !InsertVal.getOperand(0).getValueType().isVector())
19448 return SDValue();
19449
19450 SDValue SubVec = InsertVal.getOperand(0);
19451 SDValue DestVec = N->getOperand(0);
19452 EVT SubVecVT = SubVec.getValueType();
19453 EVT VT = DestVec.getValueType();
19454 unsigned NumSrcElts = SubVecVT.getVectorNumElements();
19455 // If the source only has a single vector element, the cost of creating adding
19456 // it to a vector is likely to exceed the cost of a insert_vector_elt.
19457 if (NumSrcElts == 1)
19458 return SDValue();
19459 unsigned ExtendRatio = VT.getSizeInBits() / SubVecVT.getSizeInBits();
19460 unsigned NumMaskVals = ExtendRatio * NumSrcElts;
19461
19462 // Step 1: Create a shuffle mask that implements this insert operation. The
19463 // vector that we are inserting into will be operand 0 of the shuffle, so
19464 // those elements are just 'i'. The inserted subvector is in the first
19465 // positions of operand 1 of the shuffle. Example:
19466 // insert v4i32 V, (v2i16 X), 2 --> shuffle v8i16 V', X', {0,1,2,3,8,9,6,7}
19467 SmallVector<int, 16> Mask(NumMaskVals);
19468 for (unsigned i = 0; i != NumMaskVals; ++i) {
19469 if (i / NumSrcElts == InsIndex)
19470 Mask[i] = (i % NumSrcElts) + NumMaskVals;
19471 else
19472 Mask[i] = i;
19473 }
19474
19475 // Bail out if the target can not handle the shuffle we want to create.
19476 EVT SubVecEltVT = SubVecVT.getVectorElementType();
19477 EVT ShufVT = EVT::getVectorVT(*DAG.getContext(), SubVecEltVT, NumMaskVals);
19478 if (!TLI.isShuffleMaskLegal(Mask, ShufVT))
19479 return SDValue();
19480
19481 // Step 2: Create a wide vector from the inserted source vector by appending
19482 // undefined elements. This is the same size as our destination vector.
19483 SDLoc DL(N);
19484 SmallVector<SDValue, 8> ConcatOps(ExtendRatio, DAG.getUNDEF(SubVecVT));
19485 ConcatOps[0] = SubVec;
19486 SDValue PaddedSubV = DAG.getNode(ISD::CONCAT_VECTORS, DL, ShufVT, ConcatOps);
19487
19488 // Step 3: Shuffle in the padded subvector.
19489 SDValue DestVecBC = DAG.getBitcast(ShufVT, DestVec);
19490 SDValue Shuf = DAG.getVectorShuffle(ShufVT, DL, DestVecBC, PaddedSubV, Mask);
19491 AddToWorklist(PaddedSubV.getNode());
19492 AddToWorklist(DestVecBC.getNode());
19493 AddToWorklist(Shuf.getNode());
19494 return DAG.getBitcast(VT, Shuf);
19495 }
19496
SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
  SDValue InVec = N->getOperand(0);
  SDValue InVal = N->getOperand(1);
  SDValue EltNo = N->getOperand(2);
  SDLoc DL(N);

  EVT VT = InVec.getValueType();
  auto *IndexC = dyn_cast<ConstantSDNode>(EltNo);

  // Insert into out-of-bounds element is undefined.
  if (IndexC && VT.isFixedLengthVector() &&
      IndexC->getZExtValue() >= VT.getVectorNumElements())
    return DAG.getUNDEF(VT);

  // Remove redundant insertions:
  // (insert_vector_elt x (extract_vector_elt x idx) idx) -> x
  if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
      InVec == InVal.getOperand(0) && EltNo == InVal.getOperand(1))
    return InVec;

  if (!IndexC) {
    // If this is variable insert to undef vector, it might be better to splat:
    // inselt undef, InVal, EltNo --> build_vector < InVal, InVal, ... >
    if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT)) {
      if (VT.isScalableVector())
        return DAG.getSplatVector(VT, DL, InVal);

      SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), InVal);
      return DAG.getBuildVector(VT, DL, Ops);
    }
    // Without a constant index, none of the folds below apply.
    return SDValue();
  }

  // The remaining folds assume a fixed-width vector.
  if (VT.isScalableVector())
    return SDValue();

  unsigned NumElts = VT.getVectorNumElements();

  // We must know which element is being inserted for folds below here.
  unsigned Elt = IndexC->getZExtValue();

  if (SDValue Shuf = combineInsertEltToShuffle(N, Elt))
    return Shuf;

  // Handle <1 x ???> vector insertion special cases.
  if (NumElts == 1) {
    // insert_vector_elt(x, extract_vector_elt(y, 0), 0) -> y
    if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
        InVal.getOperand(0).getValueType() == VT &&
        isNullConstant(InVal.getOperand(1)))
      return InVal.getOperand(0);
  }

  // Canonicalize insert_vector_elt dag nodes.
  // Example:
  // (insert_vector_elt (insert_vector_elt A, Idx0), Idx1)
  // -> (insert_vector_elt (insert_vector_elt A, Idx1), Idx0)
  //
  // Do this only if the child insert_vector node has one use; also
  // do this only if indices are both constants and Idx1 < Idx0.
  if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT && InVec.hasOneUse()
      && isa<ConstantSDNode>(InVec.getOperand(2))) {
    unsigned OtherElt = InVec.getConstantOperandVal(2);
    if (Elt < OtherElt) {
      // Swap nodes.
      SDValue NewOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT,
                                  InVec.getOperand(0), InVal, EltNo);
      AddToWorklist(NewOp.getNode());
      return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(InVec.getNode()),
                         VT, NewOp, InVec.getOperand(1), InVec.getOperand(2));
    }
  }

  // Attempt to convert an insert_vector_elt chain into a legal build_vector.
  if (!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)) {
    // vXi1 vector - we don't need to recurse.
    if (NumElts == 1)
      return DAG.getBuildVector(VT, DL, {InVal});

    // If we haven't already collected the element, insert into the op list.
    // MaxEltVT tracks the widest integer element seen so all BUILD_VECTOR
    // operands can be extended to a common type.
    EVT MaxEltVT = InVal.getValueType();
    auto AddBuildVectorOp = [&](SmallVectorImpl<SDValue> &Ops, SDValue Elt,
                                unsigned Idx) {
      if (!Ops[Idx]) {
        Ops[Idx] = Elt;
        if (VT.isInteger()) {
          EVT EltVT = Elt.getValueType();
          MaxEltVT = MaxEltVT.bitsGE(EltVT) ? MaxEltVT : EltVT;
        }
      }
    };

    // Ensure all the operands are the same value type, fill any missing
    // operands with UNDEF and create the BUILD_VECTOR.
    auto CanonicalizeBuildVector = [&](SmallVectorImpl<SDValue> &Ops) {
      assert(Ops.size() == NumElts && "Unexpected vector size");
      for (SDValue &Op : Ops) {
        if (Op)
          Op = VT.isInteger() ? DAG.getAnyExtOrTrunc(Op, DL, MaxEltVT) : Op;
        else
          Op = DAG.getUNDEF(MaxEltVT);
      }
      return DAG.getBuildVector(VT, DL, Ops);
    };

    // Ops[i] holds the value inserted at element i; a null SDValue marks an
    // element not (yet) known.
    SmallVector<SDValue, 8> Ops(NumElts, SDValue());
    Ops[Elt] = InVal;

    // Recurse up a INSERT_VECTOR_ELT chain to build a BUILD_VECTOR.
    for (SDValue CurVec = InVec; CurVec;) {
      // UNDEF - build new BUILD_VECTOR from already inserted operands.
      if (CurVec.isUndef())
        return CanonicalizeBuildVector(Ops);

      // BUILD_VECTOR - insert unused operands and build new BUILD_VECTOR.
      if (CurVec.getOpcode() == ISD::BUILD_VECTOR && CurVec.hasOneUse()) {
        for (unsigned I = 0; I != NumElts; ++I)
          AddBuildVectorOp(Ops, CurVec.getOperand(I), I);
        return CanonicalizeBuildVector(Ops);
      }

      // SCALAR_TO_VECTOR - insert unused scalar and build new BUILD_VECTOR.
      if (CurVec.getOpcode() == ISD::SCALAR_TO_VECTOR && CurVec.hasOneUse()) {
        AddBuildVectorOp(Ops, CurVec.getOperand(0), 0);
        return CanonicalizeBuildVector(Ops);
      }

      // INSERT_VECTOR_ELT - insert operand and continue up the chain.
      if (CurVec.getOpcode() == ISD::INSERT_VECTOR_ELT && CurVec.hasOneUse())
        if (auto *CurIdx = dyn_cast<ConstantSDNode>(CurVec.getOperand(2)))
          if (CurIdx->getAPIntValue().ult(NumElts)) {
            unsigned Idx = CurIdx->getZExtValue();
            AddBuildVectorOp(Ops, CurVec.getOperand(1), Idx);

            // Found entire BUILD_VECTOR.
            if (all_of(Ops, [](SDValue Op) { return !!Op; }))
              return CanonicalizeBuildVector(Ops);

            CurVec = CurVec->getOperand(0);
            continue;
          }

      // Failed to find a match in the chain - bail.
      break;
    }
  }

  return SDValue();
}
19646
// Replace (extract_vector_elt (load V), Idx) with a narrow scalar load of
// just the extracted element, when the target allows narrowing the load and
// the access remains legal and fast.
SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
                                                  SDValue EltNo,
                                                  LoadSDNode *OriginalLoad) {
  assert(OriginalLoad->isSimple());

  EVT ResultVT = EVE->getValueType(0);
  EVT VecEltVT = InVecVT.getVectorElementType();

  // If the vector element type is not a multiple of a byte then we are unable
  // to correctly compute an address to load only the extracted element as a
  // scalar.
  if (!VecEltVT.isByteSized())
    return SDValue();

  ISD::LoadExtType ExtTy =
      ResultVT.bitsGT(VecEltVT) ? ISD::NON_EXTLOAD : ISD::EXTLOAD;
  if (!TLI.isOperationLegalOrCustom(ISD::LOAD, VecEltVT) ||
      !TLI.shouldReduceLoadWidth(OriginalLoad, ExtTy, VecEltVT))
    return SDValue();

  Align Alignment = OriginalLoad->getAlign();
  MachinePointerInfo MPI;
  SDLoc DL(EVE);
  if (auto *ConstEltNo = dyn_cast<ConstantSDNode>(EltNo)) {
    // Constant index: the byte offset of the element is known, so the
    // pointer info and alignment can be refined exactly.
    int Elt = ConstEltNo->getZExtValue();
    unsigned PtrOff = VecEltVT.getSizeInBits() * Elt / 8;
    MPI = OriginalLoad->getPointerInfo().getWithOffset(PtrOff);
    Alignment = commonAlignment(Alignment, PtrOff);
  } else {
    // Discard the pointer info except the address space because the memory
    // operand can't represent this new access since the offset is variable.
    MPI = MachinePointerInfo(OriginalLoad->getPointerInfo().getAddrSpace());
    Alignment = commonAlignment(Alignment, VecEltVT.getSizeInBits() / 8);
  }

  // Give up unless the target reports the narrowed access as both allowed
  // and fast.
  bool IsFast = false;
  if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VecEltVT,
                              OriginalLoad->getAddressSpace(), Alignment,
                              OriginalLoad->getMemOperand()->getFlags(),
                              &IsFast) ||
      !IsFast)
    return SDValue();

  SDValue NewPtr = TLI.getVectorElementPointer(DAG, OriginalLoad->getBasePtr(),
                                               InVecVT, EltNo);

  // We are replacing a vector load with a scalar load. The new load must have
  // identical memory op ordering to the original.
  SDValue Load;
  if (ResultVT.bitsGT(VecEltVT)) {
    // If the result type of vextract is wider than the load, then issue an
    // extending load instead.
    ISD::LoadExtType ExtType =
        TLI.isLoadExtLegal(ISD::ZEXTLOAD, ResultVT, VecEltVT) ? ISD::ZEXTLOAD
                                                              : ISD::EXTLOAD;
    Load = DAG.getExtLoad(ExtType, DL, ResultVT, OriginalLoad->getChain(),
                          NewPtr, MPI, VecEltVT, Alignment,
                          OriginalLoad->getMemOperand()->getFlags(),
                          OriginalLoad->getAAInfo());
    DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
  } else {
    // The result type is narrower or the same width as the vector element
    Load = DAG.getLoad(VecEltVT, DL, OriginalLoad->getChain(), NewPtr, MPI,
                       Alignment, OriginalLoad->getMemOperand()->getFlags(),
                       OriginalLoad->getAAInfo());
    DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
    if (ResultVT.bitsLT(VecEltVT))
      Load = DAG.getNode(ISD::TRUNCATE, DL, ResultVT, Load);
    else
      Load = DAG.getBitcast(ResultVT, Load);
  }
  // Statistic: count loads narrowed by this combine.
  ++OpsNarrowed;
  return Load;
}
19721
19722 /// Transform a vector binary operation into a scalar binary operation by moving
19723 /// the math/logic after an extract element of a vector.
scalarizeExtractedBinop(SDNode * ExtElt,SelectionDAG & DAG,bool LegalOperations)19724 static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
19725 bool LegalOperations) {
19726 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
19727 SDValue Vec = ExtElt->getOperand(0);
19728 SDValue Index = ExtElt->getOperand(1);
19729 auto *IndexC = dyn_cast<ConstantSDNode>(Index);
19730 if (!IndexC || !TLI.isBinOp(Vec.getOpcode()) || !Vec.hasOneUse() ||
19731 Vec->getNumValues() != 1)
19732 return SDValue();
19733
19734 // Targets may want to avoid this to prevent an expensive register transfer.
19735 if (!TLI.shouldScalarizeBinop(Vec))
19736 return SDValue();
19737
19738 // Extracting an element of a vector constant is constant-folded, so this
19739 // transform is just replacing a vector op with a scalar op while moving the
19740 // extract.
19741 SDValue Op0 = Vec.getOperand(0);
19742 SDValue Op1 = Vec.getOperand(1);
19743 APInt SplatVal;
19744 if (isAnyConstantBuildVector(Op0, true) ||
19745 ISD::isConstantSplatVector(Op0.getNode(), SplatVal) ||
19746 isAnyConstantBuildVector(Op1, true) ||
19747 ISD::isConstantSplatVector(Op1.getNode(), SplatVal)) {
19748 // extractelt (binop X, C), IndexC --> binop (extractelt X, IndexC), C'
19749 // extractelt (binop C, X), IndexC --> binop C', (extractelt X, IndexC)
19750 SDLoc DL(ExtElt);
19751 EVT VT = ExtElt->getValueType(0);
19752 SDValue Ext0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Index);
19753 SDValue Ext1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op1, Index);
19754 return DAG.getNode(Vec.getOpcode(), DL, VT, Ext0, Ext1);
19755 }
19756
19757 return SDValue();
19758 }
19759
visitEXTRACT_VECTOR_ELT(SDNode * N)19760 SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
19761 SDValue VecOp = N->getOperand(0);
19762 SDValue Index = N->getOperand(1);
19763 EVT ScalarVT = N->getValueType(0);
19764 EVT VecVT = VecOp.getValueType();
19765 if (VecOp.isUndef())
19766 return DAG.getUNDEF(ScalarVT);
19767
19768 // extract_vector_elt (insert_vector_elt vec, val, idx), idx) -> val
19769 //
19770 // This only really matters if the index is non-constant since other combines
19771 // on the constant elements already work.
19772 SDLoc DL(N);
19773 if (VecOp.getOpcode() == ISD::INSERT_VECTOR_ELT &&
19774 Index == VecOp.getOperand(2)) {
19775 SDValue Elt = VecOp.getOperand(1);
19776 return VecVT.isInteger() ? DAG.getAnyExtOrTrunc(Elt, DL, ScalarVT) : Elt;
19777 }
19778
19779 // (vextract (scalar_to_vector val, 0) -> val
19780 if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR) {
19781 // Only 0'th element of SCALAR_TO_VECTOR is defined.
19782 if (DAG.isKnownNeverZero(Index))
19783 return DAG.getUNDEF(ScalarVT);
19784
19785 // Check if the result type doesn't match the inserted element type. A
19786 // SCALAR_TO_VECTOR may truncate the inserted element and the
19787 // EXTRACT_VECTOR_ELT may widen the extracted vector.
19788 SDValue InOp = VecOp.getOperand(0);
19789 if (InOp.getValueType() != ScalarVT) {
19790 assert(InOp.getValueType().isInteger() && ScalarVT.isInteger() &&
19791 InOp.getValueType().bitsGT(ScalarVT));
19792 return DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, InOp);
19793 }
19794 return InOp;
19795 }
19796
19797 // extract_vector_elt of out-of-bounds element -> UNDEF
19798 auto *IndexC = dyn_cast<ConstantSDNode>(Index);
19799 if (IndexC && VecVT.isFixedLengthVector() &&
19800 IndexC->getAPIntValue().uge(VecVT.getVectorNumElements()))
19801 return DAG.getUNDEF(ScalarVT);
19802
19803 // extract_vector_elt (build_vector x, y), 1 -> y
19804 if (((IndexC && VecOp.getOpcode() == ISD::BUILD_VECTOR) ||
19805 VecOp.getOpcode() == ISD::SPLAT_VECTOR) &&
19806 TLI.isTypeLegal(VecVT) &&
19807 (VecOp.hasOneUse() || TLI.aggressivelyPreferBuildVectorSources(VecVT))) {
19808 assert((VecOp.getOpcode() != ISD::BUILD_VECTOR ||
19809 VecVT.isFixedLengthVector()) &&
19810 "BUILD_VECTOR used for scalable vectors");
19811 unsigned IndexVal =
19812 VecOp.getOpcode() == ISD::BUILD_VECTOR ? IndexC->getZExtValue() : 0;
19813 SDValue Elt = VecOp.getOperand(IndexVal);
19814 EVT InEltVT = Elt.getValueType();
19815
19816 // Sometimes build_vector's scalar input types do not match result type.
19817 if (ScalarVT == InEltVT)
19818 return Elt;
19819
19820 // TODO: It may be useful to truncate if free if the build_vector implicitly
19821 // converts.
19822 }
19823
19824 if (SDValue BO = scalarizeExtractedBinop(N, DAG, LegalOperations))
19825 return BO;
19826
19827 if (VecVT.isScalableVector())
19828 return SDValue();
19829
19830 // All the code from this point onwards assumes fixed width vectors, but it's
19831 // possible that some of the combinations could be made to work for scalable
19832 // vectors too.
19833 unsigned NumElts = VecVT.getVectorNumElements();
19834 unsigned VecEltBitWidth = VecVT.getScalarSizeInBits();
19835
19836 // TODO: These transforms should not require the 'hasOneUse' restriction, but
19837 // there are regressions on multiple targets without it. We can end up with a
19838 // mess of scalar and vector code if we reduce only part of the DAG to scalar.
19839 if (IndexC && VecOp.getOpcode() == ISD::BITCAST && VecVT.isInteger() &&
19840 VecOp.hasOneUse()) {
19841 // The vector index of the LSBs of the source depend on the endian-ness.
19842 bool IsLE = DAG.getDataLayout().isLittleEndian();
19843 unsigned ExtractIndex = IndexC->getZExtValue();
19844 // extract_elt (v2i32 (bitcast i64:x)), BCTruncElt -> i32 (trunc i64:x)
19845 unsigned BCTruncElt = IsLE ? 0 : NumElts - 1;
19846 SDValue BCSrc = VecOp.getOperand(0);
19847 if (ExtractIndex == BCTruncElt && BCSrc.getValueType().isScalarInteger())
19848 return DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, BCSrc);
19849
19850 if (LegalTypes && BCSrc.getValueType().isInteger() &&
19851 BCSrc.getOpcode() == ISD::SCALAR_TO_VECTOR) {
19852 // ext_elt (bitcast (scalar_to_vec i64 X to v2i64) to v4i32), TruncElt -->
19853 // trunc i64 X to i32
19854 SDValue X = BCSrc.getOperand(0);
19855 assert(X.getValueType().isScalarInteger() && ScalarVT.isScalarInteger() &&
19856 "Extract element and scalar to vector can't change element type "
19857 "from FP to integer.");
19858 unsigned XBitWidth = X.getValueSizeInBits();
19859 BCTruncElt = IsLE ? 0 : XBitWidth / VecEltBitWidth - 1;
19860
19861 // An extract element return value type can be wider than its vector
19862 // operand element type. In that case, the high bits are undefined, so
19863 // it's possible that we may need to extend rather than truncate.
19864 if (ExtractIndex == BCTruncElt && XBitWidth > VecEltBitWidth) {
19865 assert(XBitWidth % VecEltBitWidth == 0 &&
19866 "Scalar bitwidth must be a multiple of vector element bitwidth");
19867 return DAG.getAnyExtOrTrunc(X, DL, ScalarVT);
19868 }
19869 }
19870 }
19871
19872 // Transform: (EXTRACT_VECTOR_ELT( VECTOR_SHUFFLE )) -> EXTRACT_VECTOR_ELT.
19873 // We only perform this optimization before the op legalization phase because
19874 // we may introduce new vector instructions which are not backed by TD
19875 // patterns. For example on AVX, extracting elements from a wide vector
19876 // without using extract_subvector. However, if we can find an underlying
19877 // scalar value, then we can always use that.
19878 if (IndexC && VecOp.getOpcode() == ISD::VECTOR_SHUFFLE) {
19879 auto *Shuf = cast<ShuffleVectorSDNode>(VecOp);
19880 // Find the new index to extract from.
19881 int OrigElt = Shuf->getMaskElt(IndexC->getZExtValue());
19882
19883 // Extracting an undef index is undef.
19884 if (OrigElt == -1)
19885 return DAG.getUNDEF(ScalarVT);
19886
19887 // Select the right vector half to extract from.
19888 SDValue SVInVec;
19889 if (OrigElt < (int)NumElts) {
19890 SVInVec = VecOp.getOperand(0);
19891 } else {
19892 SVInVec = VecOp.getOperand(1);
19893 OrigElt -= NumElts;
19894 }
19895
19896 if (SVInVec.getOpcode() == ISD::BUILD_VECTOR) {
19897 SDValue InOp = SVInVec.getOperand(OrigElt);
19898 if (InOp.getValueType() != ScalarVT) {
19899 assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
19900 InOp = DAG.getSExtOrTrunc(InOp, DL, ScalarVT);
19901 }
19902
19903 return InOp;
19904 }
19905
19906 // FIXME: We should handle recursing on other vector shuffles and
19907 // scalar_to_vector here as well.
19908
19909 if (!LegalOperations ||
19910 // FIXME: Should really be just isOperationLegalOrCustom.
19911 TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VecVT) ||
19912 TLI.isOperationExpand(ISD::VECTOR_SHUFFLE, VecVT)) {
19913 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, SVInVec,
19914 DAG.getVectorIdxConstant(OrigElt, DL));
19915 }
19916 }
19917
19918 // If only EXTRACT_VECTOR_ELT nodes use the source vector we can
19919 // simplify it based on the (valid) extraction indices.
19920 if (llvm::all_of(VecOp->uses(), [&](SDNode *Use) {
19921 return Use->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
19922 Use->getOperand(0) == VecOp &&
19923 isa<ConstantSDNode>(Use->getOperand(1));
19924 })) {
19925 APInt DemandedElts = APInt::getZero(NumElts);
19926 for (SDNode *Use : VecOp->uses()) {
19927 auto *CstElt = cast<ConstantSDNode>(Use->getOperand(1));
19928 if (CstElt->getAPIntValue().ult(NumElts))
19929 DemandedElts.setBit(CstElt->getZExtValue());
19930 }
19931 if (SimplifyDemandedVectorElts(VecOp, DemandedElts, true)) {
19932 // We simplified the vector operand of this extract element. If this
19933 // extract is not dead, visit it again so it is folded properly.
19934 if (N->getOpcode() != ISD::DELETED_NODE)
19935 AddToWorklist(N);
19936 return SDValue(N, 0);
19937 }
19938 APInt DemandedBits = APInt::getAllOnes(VecEltBitWidth);
19939 if (SimplifyDemandedBits(VecOp, DemandedBits, DemandedElts, true)) {
19940 // We simplified the vector operand of this extract element. If this
19941 // extract is not dead, visit it again so it is folded properly.
19942 if (N->getOpcode() != ISD::DELETED_NODE)
19943 AddToWorklist(N);
19944 return SDValue(N, 0);
19945 }
19946 }
19947
19948 // Everything under here is trying to match an extract of a loaded value.
19949 // If the result of load has to be truncated, then it's not necessarily
19950 // profitable.
19951 bool BCNumEltsChanged = false;
19952 EVT ExtVT = VecVT.getVectorElementType();
19953 EVT LVT = ExtVT;
19954 if (ScalarVT.bitsLT(LVT) && !TLI.isTruncateFree(LVT, ScalarVT))
19955 return SDValue();
19956
19957 if (VecOp.getOpcode() == ISD::BITCAST) {
19958 // Don't duplicate a load with other uses.
19959 if (!VecOp.hasOneUse())
19960 return SDValue();
19961
19962 EVT BCVT = VecOp.getOperand(0).getValueType();
19963 if (!BCVT.isVector() || ExtVT.bitsGT(BCVT.getVectorElementType()))
19964 return SDValue();
19965 if (NumElts != BCVT.getVectorNumElements())
19966 BCNumEltsChanged = true;
19967 VecOp = VecOp.getOperand(0);
19968 ExtVT = BCVT.getVectorElementType();
19969 }
19970
19971 // extract (vector load $addr), i --> load $addr + i * size
19972 if (!LegalOperations && !IndexC && VecOp.hasOneUse() &&
19973 ISD::isNormalLoad(VecOp.getNode()) &&
19974 !Index->hasPredecessor(VecOp.getNode())) {
19975 auto *VecLoad = dyn_cast<LoadSDNode>(VecOp);
19976 if (VecLoad && VecLoad->isSimple())
19977 return scalarizeExtractedVectorLoad(N, VecVT, Index, VecLoad);
19978 }
19979
19980 // Perform only after legalization to ensure build_vector / vector_shuffle
19981 // optimizations have already been done.
19982 if (!LegalOperations || !IndexC)
19983 return SDValue();
19984
19985 // (vextract (v4f32 load $addr), c) -> (f32 load $addr+c*size)
19986 // (vextract (v4f32 s2v (f32 load $addr)), c) -> (f32 load $addr+c*size)
19987 // (vextract (v4f32 shuffle (load $addr), <1,u,u,u>), 0) -> (f32 load $addr)
19988 int Elt = IndexC->getZExtValue();
19989 LoadSDNode *LN0 = nullptr;
19990 if (ISD::isNormalLoad(VecOp.getNode())) {
19991 LN0 = cast<LoadSDNode>(VecOp);
19992 } else if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR &&
19993 VecOp.getOperand(0).getValueType() == ExtVT &&
19994 ISD::isNormalLoad(VecOp.getOperand(0).getNode())) {
19995 // Don't duplicate a load with other uses.
19996 if (!VecOp.hasOneUse())
19997 return SDValue();
19998
19999 LN0 = cast<LoadSDNode>(VecOp.getOperand(0));
20000 }
20001 if (auto *Shuf = dyn_cast<ShuffleVectorSDNode>(VecOp)) {
20002 // (vextract (vector_shuffle (load $addr), v2, <1, u, u, u>), 1)
20003 // =>
20004 // (load $addr+1*size)
20005
20006 // Don't duplicate a load with other uses.
20007 if (!VecOp.hasOneUse())
20008 return SDValue();
20009
20010 // If the bit convert changed the number of elements, it is unsafe
20011 // to examine the mask.
20012 if (BCNumEltsChanged)
20013 return SDValue();
20014
20015 // Select the input vector, guarding against out of range extract vector.
20016 int Idx = (Elt > (int)NumElts) ? -1 : Shuf->getMaskElt(Elt);
20017 VecOp = (Idx < (int)NumElts) ? VecOp.getOperand(0) : VecOp.getOperand(1);
20018
20019 if (VecOp.getOpcode() == ISD::BITCAST) {
20020 // Don't duplicate a load with other uses.
20021 if (!VecOp.hasOneUse())
20022 return SDValue();
20023
20024 VecOp = VecOp.getOperand(0);
20025 }
20026 if (ISD::isNormalLoad(VecOp.getNode())) {
20027 LN0 = cast<LoadSDNode>(VecOp);
20028 Elt = (Idx < (int)NumElts) ? Idx : Idx - (int)NumElts;
20029 Index = DAG.getConstant(Elt, DL, Index.getValueType());
20030 }
20031 } else if (VecOp.getOpcode() == ISD::CONCAT_VECTORS && !BCNumEltsChanged &&
20032 VecVT.getVectorElementType() == ScalarVT &&
20033 (!LegalTypes ||
20034 TLI.isTypeLegal(
20035 VecOp.getOperand(0).getValueType().getVectorElementType()))) {
20036 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 0
20037 // -> extract_vector_elt a, 0
20038 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 1
20039 // -> extract_vector_elt a, 1
20040 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 2
20041 // -> extract_vector_elt b, 0
20042 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 3
20043 // -> extract_vector_elt b, 1
20044 SDLoc SL(N);
20045 EVT ConcatVT = VecOp.getOperand(0).getValueType();
20046 unsigned ConcatNumElts = ConcatVT.getVectorNumElements();
20047 SDValue NewIdx = DAG.getConstant(Elt % ConcatNumElts, SL,
20048 Index.getValueType());
20049
20050 SDValue ConcatOp = VecOp.getOperand(Elt / ConcatNumElts);
20051 SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL,
20052 ConcatVT.getVectorElementType(),
20053 ConcatOp, NewIdx);
20054 return DAG.getNode(ISD::BITCAST, SL, ScalarVT, Elt);
20055 }
20056
20057 // Make sure we found a non-volatile load and the extractelement is
20058 // the only use.
20059 if (!LN0 || !LN0->hasNUsesOfValue(1,0) || !LN0->isSimple())
20060 return SDValue();
20061
20062 // If Idx was -1 above, Elt is going to be -1, so just return undef.
20063 if (Elt == -1)
20064 return DAG.getUNDEF(LVT);
20065
20066 return scalarizeExtractedVectorLoad(N, VecVT, Index, LN0);
20067 }
20068
20069 // Simplify (build_vec (ext )) to (bitcast (build_vec ))
reduceBuildVecExtToExtBuildVec(SDNode * N)20070 SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) {
20071 // We perform this optimization post type-legalization because
20072 // the type-legalizer often scalarizes integer-promoted vectors.
20073 // Performing this optimization before may create bit-casts which
20074 // will be type-legalized to complex code sequences.
20075 // We perform this optimization only before the operation legalizer because we
20076 // may introduce illegal operations.
20077 if (Level != AfterLegalizeVectorOps && Level != AfterLegalizeTypes)
20078 return SDValue();
20079
20080 unsigned NumInScalars = N->getNumOperands();
20081 SDLoc DL(N);
20082 EVT VT = N->getValueType(0);
20083
20084 // Check to see if this is a BUILD_VECTOR of a bunch of values
20085 // which come from any_extend or zero_extend nodes. If so, we can create
20086 // a new BUILD_VECTOR using bit-casts which may enable other BUILD_VECTOR
20087 // optimizations. We do not handle sign-extend because we can't fill the sign
20088 // using shuffles.
20089 EVT SourceType = MVT::Other;
20090 bool AllAnyExt = true;
20091
20092 for (unsigned i = 0; i != NumInScalars; ++i) {
20093 SDValue In = N->getOperand(i);
20094 // Ignore undef inputs.
20095 if (In.isUndef()) continue;
20096
20097 bool AnyExt = In.getOpcode() == ISD::ANY_EXTEND;
20098 bool ZeroExt = In.getOpcode() == ISD::ZERO_EXTEND;
20099
20100 // Abort if the element is not an extension.
20101 if (!ZeroExt && !AnyExt) {
20102 SourceType = MVT::Other;
20103 break;
20104 }
20105
20106 // The input is a ZeroExt or AnyExt. Check the original type.
20107 EVT InTy = In.getOperand(0).getValueType();
20108
20109 // Check that all of the widened source types are the same.
20110 if (SourceType == MVT::Other)
20111 // First time.
20112 SourceType = InTy;
20113 else if (InTy != SourceType) {
20114 // Multiple income types. Abort.
20115 SourceType = MVT::Other;
20116 break;
20117 }
20118
20119 // Check if all of the extends are ANY_EXTENDs.
20120 AllAnyExt &= AnyExt;
20121 }
20122
20123 // In order to have valid types, all of the inputs must be extended from the
20124 // same source type and all of the inputs must be any or zero extend.
20125 // Scalar sizes must be a power of two.
20126 EVT OutScalarTy = VT.getScalarType();
20127 bool ValidTypes = SourceType != MVT::Other &&
20128 isPowerOf2_32(OutScalarTy.getSizeInBits()) &&
20129 isPowerOf2_32(SourceType.getSizeInBits());
20130
20131 // Create a new simpler BUILD_VECTOR sequence which other optimizations can
20132 // turn into a single shuffle instruction.
20133 if (!ValidTypes)
20134 return SDValue();
20135
20136 // If we already have a splat buildvector, then don't fold it if it means
20137 // introducing zeros.
20138 if (!AllAnyExt && DAG.isSplatValue(SDValue(N, 0), /*AllowUndefs*/ true))
20139 return SDValue();
20140
20141 bool isLE = DAG.getDataLayout().isLittleEndian();
20142 unsigned ElemRatio = OutScalarTy.getSizeInBits()/SourceType.getSizeInBits();
20143 assert(ElemRatio > 1 && "Invalid element size ratio");
20144 SDValue Filler = AllAnyExt ? DAG.getUNDEF(SourceType):
20145 DAG.getConstant(0, DL, SourceType);
20146
20147 unsigned NewBVElems = ElemRatio * VT.getVectorNumElements();
20148 SmallVector<SDValue, 8> Ops(NewBVElems, Filler);
20149
20150 // Populate the new build_vector
20151 for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
20152 SDValue Cast = N->getOperand(i);
20153 assert((Cast.getOpcode() == ISD::ANY_EXTEND ||
20154 Cast.getOpcode() == ISD::ZERO_EXTEND ||
20155 Cast.isUndef()) && "Invalid cast opcode");
20156 SDValue In;
20157 if (Cast.isUndef())
20158 In = DAG.getUNDEF(SourceType);
20159 else
20160 In = Cast->getOperand(0);
20161 unsigned Index = isLE ? (i * ElemRatio) :
20162 (i * ElemRatio + (ElemRatio - 1));
20163
20164 assert(Index < Ops.size() && "Invalid index");
20165 Ops[Index] = In;
20166 }
20167
20168 // The type of the new BUILD_VECTOR node.
20169 EVT VecVT = EVT::getVectorVT(*DAG.getContext(), SourceType, NewBVElems);
20170 assert(VecVT.getSizeInBits() == VT.getSizeInBits() &&
20171 "Invalid vector size");
20172 // Check if the new vector type is legal.
20173 if (!isTypeLegal(VecVT) ||
20174 (!TLI.isOperationLegal(ISD::BUILD_VECTOR, VecVT) &&
20175 TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)))
20176 return SDValue();
20177
20178 // Make the new BUILD_VECTOR.
20179 SDValue BV = DAG.getBuildVector(VecVT, DL, Ops);
20180
20181 // The new BUILD_VECTOR node has the potential to be further optimized.
20182 AddToWorklist(BV.getNode());
20183 // Bitcast to the desired type.
20184 return DAG.getBitcast(VT, BV);
20185 }
20186
20187 // Simplify (build_vec (trunc $1)
20188 // (trunc (srl $1 half-width))
20189 // (trunc (srl $1 (2 * half-width))) …)
20190 // to (bitcast $1)
reduceBuildVecTruncToBitCast(SDNode * N)20191 SDValue DAGCombiner::reduceBuildVecTruncToBitCast(SDNode *N) {
20192 assert(N->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");
20193
20194 // Only for little endian
20195 if (!DAG.getDataLayout().isLittleEndian())
20196 return SDValue();
20197
20198 SDLoc DL(N);
20199 EVT VT = N->getValueType(0);
20200 EVT OutScalarTy = VT.getScalarType();
20201 uint64_t ScalarTypeBitsize = OutScalarTy.getSizeInBits();
20202
20203 // Only for power of two types to be sure that bitcast works well
20204 if (!isPowerOf2_64(ScalarTypeBitsize))
20205 return SDValue();
20206
20207 unsigned NumInScalars = N->getNumOperands();
20208
20209 // Look through bitcasts
20210 auto PeekThroughBitcast = [](SDValue Op) {
20211 if (Op.getOpcode() == ISD::BITCAST)
20212 return Op.getOperand(0);
20213 return Op;
20214 };
20215
20216 // The source value where all the parts are extracted.
20217 SDValue Src;
20218 for (unsigned i = 0; i != NumInScalars; ++i) {
20219 SDValue In = PeekThroughBitcast(N->getOperand(i));
20220 // Ignore undef inputs.
20221 if (In.isUndef()) continue;
20222
20223 if (In.getOpcode() != ISD::TRUNCATE)
20224 return SDValue();
20225
20226 In = PeekThroughBitcast(In.getOperand(0));
20227
20228 if (In.getOpcode() != ISD::SRL) {
20229 // For now only build_vec without shuffling, handle shifts here in the
20230 // future.
20231 if (i != 0)
20232 return SDValue();
20233
20234 Src = In;
20235 } else {
20236 // In is SRL
20237 SDValue part = PeekThroughBitcast(In.getOperand(0));
20238
20239 if (!Src) {
20240 Src = part;
20241 } else if (Src != part) {
20242 // Vector parts do not stem from the same variable
20243 return SDValue();
20244 }
20245
20246 SDValue ShiftAmtVal = In.getOperand(1);
20247 if (!isa<ConstantSDNode>(ShiftAmtVal))
20248 return SDValue();
20249
20250 uint64_t ShiftAmt = In.getConstantOperandVal(1);
20251
20252 // The extracted value is not extracted at the right position
20253 if (ShiftAmt != i * ScalarTypeBitsize)
20254 return SDValue();
20255 }
20256 }
20257
20258 // Only cast if the size is the same
20259 if (Src.getValueType().getSizeInBits() != VT.getSizeInBits())
20260 return SDValue();
20261
20262 return DAG.getBitcast(VT, Src);
20263 }
20264
/// Build a VECTOR_SHUFFLE that gathers, from \p VecIn1 / \p VecIn2, the
/// elements of the BUILD_VECTOR \p N whose \p VectorMask entry is \p LeftIdx
/// or \p LeftIdx + 1. Lanes belonging to other source vectors (or undef/zero
/// lanes) stay undef in the produced shuffle. Returns an empty SDValue when
/// the input vector sizes cannot be reconciled with the result type.
/// \p DidSplitVec indicates that both inputs came from splitting one original
/// vector, so the extract indices in N's operands are already relative to
/// VecIn1.
SDValue DAGCombiner::createBuildVecShuffle(const SDLoc &DL, SDNode *N,
                                           ArrayRef<int> VectorMask,
                                           SDValue VecIn1, SDValue VecIn2,
                                           unsigned LeftIdx, bool DidSplitVec) {
  SDValue ZeroIdx = DAG.getVectorIdxConstant(0, DL);

  EVT VT = N->getValueType(0);
  EVT InVT1 = VecIn1.getValueType();
  // VecIn2 may be absent; treat it as having VecIn1's type in that case.
  EVT InVT2 = VecIn2.getNode() ? VecIn2.getValueType() : InVT1;

  unsigned NumElems = VT.getVectorNumElements();
  unsigned ShuffleNumElems = NumElems;

  // If we artificially split a vector in two already, then the offsets in the
  // operands will all be based off of VecIn1, even those in VecIn2.
  unsigned Vec2Offset = DidSplitVec ? 0 : InVT1.getVectorNumElements();

  uint64_t VTSize = VT.getFixedSizeInBits();
  uint64_t InVT1Size = InVT1.getFixedSizeInBits();
  uint64_t InVT2Size = InVT2.getFixedSizeInBits();

  assert(InVT2Size <= InVT1Size &&
         "Inputs must be sorted to be in non-increasing vector size order.");

  // We can't generate a shuffle node with mismatched input and output types.
  // Try to make the types match the type of the output.
  if (InVT1 != VT || InVT2 != VT) {
    if ((VTSize % InVT1Size == 0) && InVT1 == InVT2) {
      // If the output vector length is a multiple of both input lengths,
      // we can concatenate them and pad the rest with undefs.
      unsigned NumConcats = VTSize / InVT1Size;
      assert(NumConcats >= 2 && "Concat needs at least two inputs!");
      SmallVector<SDValue, 2> ConcatOps(NumConcats, DAG.getUNDEF(InVT1));
      ConcatOps[0] = VecIn1;
      ConcatOps[1] = VecIn2 ? VecIn2 : DAG.getUNDEF(InVT1);
      VecIn1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
      // Both inputs are now folded into VecIn1.
      VecIn2 = SDValue();
    } else if (InVT1Size == VTSize * 2) {
      if (!TLI.isExtractSubvectorCheap(VT, InVT1, NumElems))
        return SDValue();

      if (!VecIn2.getNode()) {
        // If we only have one input vector, and it's twice the size of the
        // output, split it in two.
        VecIn2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, VecIn1,
                             DAG.getVectorIdxConstant(NumElems, DL));
        VecIn1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, VecIn1, ZeroIdx);
        // Since we now have shorter input vectors, adjust the offset of the
        // second vector's start.
        Vec2Offset = NumElems;
      } else {
        assert(InVT2Size <= InVT1Size &&
               "Second input is not going to be larger than the first one.");

        // VecIn1 is wider than the output, and we have another, possibly
        // smaller input. Pad the smaller input with undefs, shuffle at the
        // input vector width, and extract the output.
        // The shuffle type is different than VT, so check legality again.
        if (LegalOperations &&
            !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, InVT1))
          return SDValue();

        // Legalizing INSERT_SUBVECTOR is tricky - you basically have to
        // lower it back into a BUILD_VECTOR. So if the inserted type is
        // illegal, don't even try.
        if (InVT1 != InVT2) {
          if (!TLI.isTypeLegal(InVT2))
            return SDValue();
          VecIn2 = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, InVT1,
                               DAG.getUNDEF(InVT1), VecIn2, ZeroIdx);
        }
        // Shuffle at the (wider) input width; trimmed back to VT below.
        ShuffleNumElems = NumElems * 2;
      }
    } else if (InVT2Size * 2 == VTSize && InVT1Size == VTSize) {
      // Second input is half the width of the output: widen it with undefs
      // so both shuffle operands have the result type.
      SmallVector<SDValue, 2> ConcatOps(2, DAG.getUNDEF(InVT2));
      ConcatOps[0] = VecIn2;
      VecIn2 = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
    } else {
      // TODO: Support cases where the length mismatch isn't exactly by a
      // factor of 2.
      // TODO: Move this check upwards, so that if we have bad type
      // mismatches, we don't create any DAG nodes.
      return SDValue();
    }
  }

  // Initialize mask to undef.
  SmallVector<int, 8> Mask(ShuffleNumElems, -1);

  // Only need to run up to the number of elements actually used, not the
  // total number of elements in the shuffle - if we are shuffling a wider
  // vector, the high lanes should be set to undef.
  for (unsigned i = 0; i != NumElems; ++i) {
    // Entry <= 0 means undef (-1) or the zero vector (0); leave undef here.
    if (VectorMask[i] <= 0)
      continue;

    // Each remaining operand is an EXTRACT_VECTOR_ELT; operand 1 is the
    // constant lane extracted from its source vector.
    unsigned ExtIndex = N->getOperand(i).getConstantOperandVal(1);
    if (VectorMask[i] == (int)LeftIdx) {
      Mask[i] = ExtIndex;
    } else if (VectorMask[i] == (int)LeftIdx + 1) {
      Mask[i] = Vec2Offset + ExtIndex;
    }
  }

  // The type the input vectors may have changed above.
  InVT1 = VecIn1.getValueType();

  // If we already have a VecIn2, it should have the same type as VecIn1.
  // If we don't, get an undef/zero vector of the appropriate type.
  VecIn2 = VecIn2.getNode() ? VecIn2 : DAG.getUNDEF(InVT1);
  assert(InVT1 == VecIn2.getValueType() && "Unexpected second input type.");

  SDValue Shuffle = DAG.getVectorShuffle(InVT1, DL, VecIn1, VecIn2, Mask);
  // If we shuffled at a wider width, extract the low VT-sized part.
  if (ShuffleNumElems > NumElems)
    Shuffle = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Shuffle, ZeroIdx);

  return Shuffle;
}
20383
reduceBuildVecToShuffleWithZero(SDNode * BV,SelectionDAG & DAG)20384 static SDValue reduceBuildVecToShuffleWithZero(SDNode *BV, SelectionDAG &DAG) {
20385 assert(BV->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");
20386
20387 // First, determine where the build vector is not undef.
20388 // TODO: We could extend this to handle zero elements as well as undefs.
20389 int NumBVOps = BV->getNumOperands();
20390 int ZextElt = -1;
20391 for (int i = 0; i != NumBVOps; ++i) {
20392 SDValue Op = BV->getOperand(i);
20393 if (Op.isUndef())
20394 continue;
20395 if (ZextElt == -1)
20396 ZextElt = i;
20397 else
20398 return SDValue();
20399 }
20400 // Bail out if there's no non-undef element.
20401 if (ZextElt == -1)
20402 return SDValue();
20403
20404 // The build vector contains some number of undef elements and exactly
20405 // one other element. That other element must be a zero-extended scalar
20406 // extracted from a vector at a constant index to turn this into a shuffle.
20407 // Also, require that the build vector does not implicitly truncate/extend
20408 // its elements.
20409 // TODO: This could be enhanced to allow ANY_EXTEND as well as ZERO_EXTEND.
20410 EVT VT = BV->getValueType(0);
20411 SDValue Zext = BV->getOperand(ZextElt);
20412 if (Zext.getOpcode() != ISD::ZERO_EXTEND || !Zext.hasOneUse() ||
20413 Zext.getOperand(0).getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
20414 !isa<ConstantSDNode>(Zext.getOperand(0).getOperand(1)) ||
20415 Zext.getValueSizeInBits() != VT.getScalarSizeInBits())
20416 return SDValue();
20417
20418 // The zero-extend must be a multiple of the source size, and we must be
20419 // building a vector of the same size as the source of the extract element.
20420 SDValue Extract = Zext.getOperand(0);
20421 unsigned DestSize = Zext.getValueSizeInBits();
20422 unsigned SrcSize = Extract.getValueSizeInBits();
20423 if (DestSize % SrcSize != 0 ||
20424 Extract.getOperand(0).getValueSizeInBits() != VT.getSizeInBits())
20425 return SDValue();
20426
20427 // Create a shuffle mask that will combine the extracted element with zeros
20428 // and undefs.
20429 int ZextRatio = DestSize / SrcSize;
20430 int NumMaskElts = NumBVOps * ZextRatio;
20431 SmallVector<int, 32> ShufMask(NumMaskElts, -1);
20432 for (int i = 0; i != NumMaskElts; ++i) {
20433 if (i / ZextRatio == ZextElt) {
20434 // The low bits of the (potentially translated) extracted element map to
20435 // the source vector. The high bits map to zero. We will use a zero vector
20436 // as the 2nd source operand of the shuffle, so use the 1st element of
20437 // that vector (mask value is number-of-elements) for the high bits.
20438 if (i % ZextRatio == 0)
20439 ShufMask[i] = Extract.getConstantOperandVal(1);
20440 else
20441 ShufMask[i] = NumMaskElts;
20442 }
20443
20444 // Undef elements of the build vector remain undef because we initialize
20445 // the shuffle mask with -1.
20446 }
20447
20448 // buildvec undef, ..., (zext (extractelt V, IndexC)), undef... -->
20449 // bitcast (shuffle V, ZeroVec, VectorMask)
20450 SDLoc DL(BV);
20451 EVT VecVT = Extract.getOperand(0).getValueType();
20452 SDValue ZeroVec = DAG.getConstant(0, DL, VecVT);
20453 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
20454 SDValue Shuf = TLI.buildLegalVectorShuffle(VecVT, DL, Extract.getOperand(0),
20455 ZeroVec, ShufMask, DAG);
20456 if (!Shuf)
20457 return SDValue();
20458 return DAG.getBitcast(VT, Shuf);
20459 }
20460
20461 // FIXME: promote to STLExtras.
/// Return the index of the first occurrence of \p Val in \p Range, or -1
/// (expressed in the range's difference type) when it is absent.
template <typename R, typename T>
static auto getFirstIndexOf(R &&Range, const T &Val) {
  auto First = Range.begin();
  auto Last = Range.end();
  auto It = std::find(First, Last, Val);
  using DiffT = decltype(std::distance(First, It));
  return It == Last ? static_cast<DiffT>(-1) : std::distance(First, It);
}
20469
20470 // Check to see if this is a BUILD_VECTOR of a bunch of EXTRACT_VECTOR_ELT
20471 // operations. If the types of the vectors we're extracting from allow it,
20472 // turn this into a vector_shuffle node.
reduceBuildVecToShuffle(SDNode * N)20473 SDValue DAGCombiner::reduceBuildVecToShuffle(SDNode *N) {
20474 SDLoc DL(N);
20475 EVT VT = N->getValueType(0);
20476
20477 // Only type-legal BUILD_VECTOR nodes are converted to shuffle nodes.
20478 if (!isTypeLegal(VT))
20479 return SDValue();
20480
20481 if (SDValue V = reduceBuildVecToShuffleWithZero(N, DAG))
20482 return V;
20483
20484 // May only combine to shuffle after legalize if shuffle is legal.
20485 if (LegalOperations && !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, VT))
20486 return SDValue();
20487
20488 bool UsesZeroVector = false;
20489 unsigned NumElems = N->getNumOperands();
20490
20491 // Record, for each element of the newly built vector, which input vector
20492 // that element comes from. -1 stands for undef, 0 for the zero vector,
20493 // and positive values for the input vectors.
20494 // VectorMask maps each element to its vector number, and VecIn maps vector
20495 // numbers to their initial SDValues.
20496
20497 SmallVector<int, 8> VectorMask(NumElems, -1);
20498 SmallVector<SDValue, 8> VecIn;
20499 VecIn.push_back(SDValue());
20500
20501 for (unsigned i = 0; i != NumElems; ++i) {
20502 SDValue Op = N->getOperand(i);
20503
20504 if (Op.isUndef())
20505 continue;
20506
20507 // See if we can use a blend with a zero vector.
20508 // TODO: Should we generalize this to a blend with an arbitrary constant
20509 // vector?
20510 if (isNullConstant(Op) || isNullFPConstant(Op)) {
20511 UsesZeroVector = true;
20512 VectorMask[i] = 0;
20513 continue;
20514 }
20515
20516 // Not an undef or zero. If the input is something other than an
20517 // EXTRACT_VECTOR_ELT with an in-range constant index, bail out.
20518 if (Op.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
20519 !isa<ConstantSDNode>(Op.getOperand(1)))
20520 return SDValue();
20521 SDValue ExtractedFromVec = Op.getOperand(0);
20522
20523 if (ExtractedFromVec.getValueType().isScalableVector())
20524 return SDValue();
20525
20526 const APInt &ExtractIdx = Op.getConstantOperandAPInt(1);
20527 if (ExtractIdx.uge(ExtractedFromVec.getValueType().getVectorNumElements()))
20528 return SDValue();
20529
20530 // All inputs must have the same element type as the output.
20531 if (VT.getVectorElementType() !=
20532 ExtractedFromVec.getValueType().getVectorElementType())
20533 return SDValue();
20534
20535 // Have we seen this input vector before?
20536 // The vectors are expected to be tiny (usually 1 or 2 elements), so using
20537 // a map back from SDValues to numbers isn't worth it.
20538 int Idx = getFirstIndexOf(VecIn, ExtractedFromVec);
20539 if (Idx == -1) { // A new source vector?
20540 Idx = VecIn.size();
20541 VecIn.push_back(ExtractedFromVec);
20542 }
20543
20544 VectorMask[i] = Idx;
20545 }
20546
20547 // If we didn't find at least one input vector, bail out.
20548 if (VecIn.size() < 2)
20549 return SDValue();
20550
20551 // If all the Operands of BUILD_VECTOR extract from same
20552 // vector, then split the vector efficiently based on the maximum
20553 // vector access index and adjust the VectorMask and
20554 // VecIn accordingly.
20555 bool DidSplitVec = false;
20556 if (VecIn.size() == 2) {
20557 unsigned MaxIndex = 0;
20558 unsigned NearestPow2 = 0;
20559 SDValue Vec = VecIn.back();
20560 EVT InVT = Vec.getValueType();
20561 SmallVector<unsigned, 8> IndexVec(NumElems, 0);
20562
20563 for (unsigned i = 0; i < NumElems; i++) {
20564 if (VectorMask[i] <= 0)
20565 continue;
20566 unsigned Index = N->getOperand(i).getConstantOperandVal(1);
20567 IndexVec[i] = Index;
20568 MaxIndex = std::max(MaxIndex, Index);
20569 }
20570
20571 NearestPow2 = PowerOf2Ceil(MaxIndex);
20572 if (InVT.isSimple() && NearestPow2 > 2 && MaxIndex < NearestPow2 &&
20573 NumElems * 2 < NearestPow2) {
20574 unsigned SplitSize = NearestPow2 / 2;
20575 EVT SplitVT = EVT::getVectorVT(*DAG.getContext(),
20576 InVT.getVectorElementType(), SplitSize);
20577 if (TLI.isTypeLegal(SplitVT) &&
20578 SplitSize + SplitVT.getVectorNumElements() <=
20579 InVT.getVectorNumElements()) {
20580 SDValue VecIn2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SplitVT, Vec,
20581 DAG.getVectorIdxConstant(SplitSize, DL));
20582 SDValue VecIn1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SplitVT, Vec,
20583 DAG.getVectorIdxConstant(0, DL));
20584 VecIn.pop_back();
20585 VecIn.push_back(VecIn1);
20586 VecIn.push_back(VecIn2);
20587 DidSplitVec = true;
20588
20589 for (unsigned i = 0; i < NumElems; i++) {
20590 if (VectorMask[i] <= 0)
20591 continue;
20592 VectorMask[i] = (IndexVec[i] < SplitSize) ? 1 : 2;
20593 }
20594 }
20595 }
20596 }
20597
20598 // Sort input vectors by decreasing vector element count,
20599 // while preserving the relative order of equally-sized vectors.
  // Note that we keep the first "implicit" zero vector as-is.
20601 SmallVector<SDValue, 8> SortedVecIn(VecIn);
20602 llvm::stable_sort(MutableArrayRef<SDValue>(SortedVecIn).drop_front(),
20603 [](const SDValue &a, const SDValue &b) {
20604 return a.getValueType().getVectorNumElements() >
20605 b.getValueType().getVectorNumElements();
20606 });
20607
20608 // We now also need to rebuild the VectorMask, because it referenced element
20609 // order in VecIn, and we just sorted them.
20610 for (int &SourceVectorIndex : VectorMask) {
20611 if (SourceVectorIndex <= 0)
20612 continue;
20613 unsigned Idx = getFirstIndexOf(SortedVecIn, VecIn[SourceVectorIndex]);
20614 assert(Idx > 0 && Idx < SortedVecIn.size() &&
20615 VecIn[SourceVectorIndex] == SortedVecIn[Idx] && "Remapping failure");
20616 SourceVectorIndex = Idx;
20617 }
20618
20619 VecIn = std::move(SortedVecIn);
20620
20621 // TODO: Should this fire if some of the input vectors has illegal type (like
20622 // it does now), or should we let legalization run its course first?
20623
20624 // Shuffle phase:
20625 // Take pairs of vectors, and shuffle them so that the result has elements
20626 // from these vectors in the correct places.
20627 // For example, given:
20628 // t10: i32 = extract_vector_elt t1, Constant:i64<0>
20629 // t11: i32 = extract_vector_elt t2, Constant:i64<0>
20630 // t12: i32 = extract_vector_elt t3, Constant:i64<0>
20631 // t13: i32 = extract_vector_elt t1, Constant:i64<1>
20632 // t14: v4i32 = BUILD_VECTOR t10, t11, t12, t13
20633 // We will generate:
20634 // t20: v4i32 = vector_shuffle<0,4,u,1> t1, t2
20635 // t21: v4i32 = vector_shuffle<u,u,0,u> t3, undef
20636 SmallVector<SDValue, 4> Shuffles;
20637 for (unsigned In = 0, Len = (VecIn.size() / 2); In < Len; ++In) {
20638 unsigned LeftIdx = 2 * In + 1;
20639 SDValue VecLeft = VecIn[LeftIdx];
20640 SDValue VecRight =
20641 (LeftIdx + 1) < VecIn.size() ? VecIn[LeftIdx + 1] : SDValue();
20642
20643 if (SDValue Shuffle = createBuildVecShuffle(DL, N, VectorMask, VecLeft,
20644 VecRight, LeftIdx, DidSplitVec))
20645 Shuffles.push_back(Shuffle);
20646 else
20647 return SDValue();
20648 }
20649
20650 // If we need the zero vector as an "ingredient" in the blend tree, add it
20651 // to the list of shuffles.
20652 if (UsesZeroVector)
20653 Shuffles.push_back(VT.isInteger() ? DAG.getConstant(0, DL, VT)
20654 : DAG.getConstantFP(0.0, DL, VT));
20655
20656 // If we only have one shuffle, we're done.
20657 if (Shuffles.size() == 1)
20658 return Shuffles[0];
20659
20660 // Update the vector mask to point to the post-shuffle vectors.
20661 for (int &Vec : VectorMask)
20662 if (Vec == 0)
20663 Vec = Shuffles.size() - 1;
20664 else
20665 Vec = (Vec - 1) / 2;
20666
20667 // More than one shuffle. Generate a binary tree of blends, e.g. if from
20668 // the previous step we got the set of shuffles t10, t11, t12, t13, we will
20669 // generate:
20670 // t10: v8i32 = vector_shuffle<0,8,u,u,u,u,u,u> t1, t2
20671 // t11: v8i32 = vector_shuffle<u,u,0,8,u,u,u,u> t3, t4
20672 // t12: v8i32 = vector_shuffle<u,u,u,u,0,8,u,u> t5, t6
20673 // t13: v8i32 = vector_shuffle<u,u,u,u,u,u,0,8> t7, t8
20674 // t20: v8i32 = vector_shuffle<0,1,10,11,u,u,u,u> t10, t11
20675 // t21: v8i32 = vector_shuffle<u,u,u,u,4,5,14,15> t12, t13
20676 // t30: v8i32 = vector_shuffle<0,1,2,3,12,13,14,15> t20, t21
20677
20678 // Make sure the initial size of the shuffle list is even.
20679 if (Shuffles.size() % 2)
20680 Shuffles.push_back(DAG.getUNDEF(VT));
20681
20682 for (unsigned CurSize = Shuffles.size(); CurSize > 1; CurSize /= 2) {
20683 if (CurSize % 2) {
20684 Shuffles[CurSize] = DAG.getUNDEF(VT);
20685 CurSize++;
20686 }
20687 for (unsigned In = 0, Len = CurSize / 2; In < Len; ++In) {
20688 int Left = 2 * In;
20689 int Right = 2 * In + 1;
20690 SmallVector<int, 8> Mask(NumElems, -1);
20691 SDValue L = Shuffles[Left];
20692 ArrayRef<int> LMask;
20693 bool IsLeftShuffle = L.getOpcode() == ISD::VECTOR_SHUFFLE &&
20694 L.use_empty() && L.getOperand(1).isUndef() &&
20695 L.getOperand(0).getValueType() == L.getValueType();
20696 if (IsLeftShuffle) {
20697 LMask = cast<ShuffleVectorSDNode>(L.getNode())->getMask();
20698 L = L.getOperand(0);
20699 }
20700 SDValue R = Shuffles[Right];
20701 ArrayRef<int> RMask;
20702 bool IsRightShuffle = R.getOpcode() == ISD::VECTOR_SHUFFLE &&
20703 R.use_empty() && R.getOperand(1).isUndef() &&
20704 R.getOperand(0).getValueType() == R.getValueType();
20705 if (IsRightShuffle) {
20706 RMask = cast<ShuffleVectorSDNode>(R.getNode())->getMask();
20707 R = R.getOperand(0);
20708 }
20709 for (unsigned I = 0; I != NumElems; ++I) {
20710 if (VectorMask[I] == Left) {
20711 Mask[I] = I;
20712 if (IsLeftShuffle)
20713 Mask[I] = LMask[I];
20714 VectorMask[I] = In;
20715 } else if (VectorMask[I] == Right) {
20716 Mask[I] = I + NumElems;
20717 if (IsRightShuffle)
20718 Mask[I] = RMask[I] + NumElems;
20719 VectorMask[I] = In;
20720 }
20721 }
20722
20723 Shuffles[In] = DAG.getVectorShuffle(VT, DL, L, R, Mask);
20724 }
20725 }
20726 return Shuffles[0];
20727 }
20728
// Try to turn a build vector of zero extends of extract vector elts into a
// vector zero extend and possibly an extract subvector.
// TODO: Support sign extend?
// TODO: Allow undef elements?
SDValue DAGCombiner::convertBuildVecZextToZext(SDNode *N) {
  // After operation legalization, creating the wide extend node could be
  // illegal for the target, so only run before that point.
  if (LegalOperations)
    return SDValue();

  EVT VT = N->getValueType(0);

  // Set if at least one operand is a true ZERO_EXTEND (rather than
  // ANY_EXTEND); this decides the opcode of the replacement extend below.
  bool FoundZeroExtend = false;
  SDValue Op0 = N->getOperand(0);
  // Return the constant extract index if Op matches
  // (zext/aext (extract_vector_elt Src, C)) where Src is the same source
  // vector that operand 0 extracts from; return -1 otherwise.
  auto checkElem = [&](SDValue Op) -> int64_t {
    unsigned Opc = Op.getOpcode();
    FoundZeroExtend |= (Opc == ISD::ZERO_EXTEND);
    if ((Opc == ISD::ZERO_EXTEND || Opc == ISD::ANY_EXTEND) &&
        Op.getOperand(0).getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
        Op0.getOperand(0).getOperand(0) == Op.getOperand(0).getOperand(0))
      if (auto *C = dyn_cast<ConstantSDNode>(Op.getOperand(0).getOperand(1)))
        return C->getZExtValue();
    return -1;
  };

  // Make sure the first element matches
  // (zext (extract_vector_elt X, C))
  // Offset must be a constant multiple of the
  // known-minimum vector length of the result type.
  int64_t Offset = checkElem(Op0);
  if (Offset < 0 || (Offset % VT.getVectorNumElements()) != 0)
    return SDValue();

  unsigned NumElems = N->getNumOperands();
  // The common source vector all elements extract from.
  SDValue In = Op0.getOperand(0).getOperand(0);
  EVT InSVT = In.getValueType().getScalarType();
  EVT InVT = EVT::getVectorVT(*DAG.getContext(), InSVT, NumElems);

  // Don't create an illegal input type after type legalization.
  if (LegalTypes && !TLI.isTypeLegal(InVT))
    return SDValue();

  // Ensure all the elements come from the same vector and are adjacent.
  for (unsigned i = 1; i != NumElems; ++i) {
    if ((Offset + i) != checkElem(N->getOperand(i)))
      return SDValue();
  }

  SDLoc DL(N);
  // Grab the adjacent elements as a single subvector starting at the first
  // element's index, then extend the whole subvector at once.
  In = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InVT, In,
                   Op0.getOperand(0).getOperand(1));
  return DAG.getNode(FoundZeroExtend ? ISD::ZERO_EXTEND : ISD::ANY_EXTEND, DL,
                     VT, In);
}
20781
SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) {
  EVT VT = N->getValueType(0);

  // A vector built entirely of undefs is undef.
  if (ISD::allOperandsUndef(N))
    return DAG.getUNDEF(VT);

  // If this is a splat of a bitcast from another vector, change to a
  // concat_vector.
  // For example:
  // (build_vector (i64 (bitcast (v2i32 X))), (i64 (bitcast (v2i32 X)))) ->
  // (v2i64 (bitcast (concat_vectors (v2i32 X), (v2i32 X))))
  //
  // If X is a build_vector itself, the concat can become a larger build_vector.
  // TODO: Maybe this is useful for non-splat too?
  if (!LegalOperations) {
    if (SDValue Splat = cast<BuildVectorSDNode>(N)->getSplatValue()) {
      // Only interesting if the splatted value is (a bitcast of) a vector.
      Splat = peekThroughBitcasts(Splat);
      EVT SrcVT = Splat.getValueType();
      if (SrcVT.isVector()) {
        unsigned NumElts = N->getNumOperands() * SrcVT.getVectorNumElements();
        EVT NewVT = EVT::getVectorVT(*DAG.getContext(),
                                     SrcVT.getVectorElementType(), NumElts);
        if (!LegalTypes || TLI.isTypeLegal(NewVT)) {
          SmallVector<SDValue, 8> Ops(N->getNumOperands(), Splat);
          SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N),
                                       NewVT, Ops);
          return DAG.getBitcast(VT, Concat);
        }
      }
    }
  }

  // Check if we can express BUILD VECTOR via subvector extract.
  if (!LegalTypes && (N->getNumOperands() > 1)) {
    SDValue Op0 = N->getOperand(0);
    // Return the constant extract index if Op extracts an element from the
    // same source vector as operand 0. Note the uint64_t return type: the -1
    // failure value wraps to UINT64_MAX, which cannot collide with a valid
    // small index in the comparisons below.
    auto checkElem = [&](SDValue Op) -> uint64_t {
      if ((Op.getOpcode() == ISD::EXTRACT_VECTOR_ELT) &&
          (Op0.getOperand(0) == Op.getOperand(0)))
        if (auto CNode = dyn_cast<ConstantSDNode>(Op.getOperand(1)))
          return CNode->getZExtValue();
      return -1;
    };

    // All operands must extract consecutive indices Offset, Offset+1, ...
    int Offset = checkElem(Op0);
    for (unsigned i = 0; i < N->getNumOperands(); ++i) {
      if (Offset + i != checkElem(N->getOperand(i))) {
        Offset = -1;
        break;
      }
    }

    // Identity case: extracting the whole source vector in order.
    if ((Offset == 0) &&
        (Op0.getOperand(0).getValueType() == N->getValueType(0)))
      return Op0.getOperand(0);
    if ((Offset != -1) &&
        ((Offset % N->getValueType(0).getVectorNumElements()) ==
         0)) // IDX must be multiple of output size.
      return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), N->getValueType(0),
                         Op0.getOperand(0), Op0.getOperand(1));
  }

  // Try the more specialized build_vector reductions in turn.
  if (SDValue V = convertBuildVecZextToZext(N))
    return V;

  if (SDValue V = reduceBuildVecExtToExtBuildVec(N))
    return V;

  if (SDValue V = reduceBuildVecTruncToBitCast(N))
    return V;

  if (SDValue V = reduceBuildVecToShuffle(N))
    return V;

  // A splat of a single element is a SPLAT_VECTOR if supported on the target.
  // Do this late as some of the above may replace the splat.
  if (TLI.getOperationAction(ISD::SPLAT_VECTOR, VT) != TargetLowering::Expand)
    if (SDValue V = cast<BuildVectorSDNode>(N)->getSplatValue()) {
      assert(!V.isUndef() && "Splat of undef should have been handled earlier");
      return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), VT, V);
    }

  return SDValue();
}
20866
// Fold a CONCAT_VECTORS whose operands are all bitcasts of scalars (or
// undef) into a single BUILD_VECTOR of those scalars, bitcast to the
// concat's result type.
static SDValue combineConcatVectorOfScalars(SDNode *N, SelectionDAG &DAG) {
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  EVT OpVT = N->getOperand(0).getValueType();

  // If the operands are legal vectors, leave them alone.
  if (TLI.isTypeLegal(OpVT))
    return SDValue();

  SDLoc DL(N);
  EVT VT = N->getValueType(0);
  // Scalar sources of the operands (or scalar undefs).
  SmallVector<SDValue, 8> Ops;

  // Start with an integer scalar type the same width as one operand; it may
  // be switched to a floating point type below.
  EVT SVT = EVT::getIntegerVT(*DAG.getContext(), OpVT.getSizeInBits());
  SDValue ScalarUndef = DAG.getNode(ISD::UNDEF, DL, SVT);

  // Keep track of what we encounter.
  bool AnyInteger = false;
  bool AnyFP = false;
  for (const SDValue &Op : N->ops()) {
    if (ISD::BITCAST == Op.getOpcode() &&
        !Op.getOperand(0).getValueType().isVector())
      Ops.push_back(Op.getOperand(0));
    else if (ISD::UNDEF == Op.getOpcode())
      Ops.push_back(ScalarUndef);
    else
      return SDValue();

    // Note whether we encounter an integer or floating point scalar.
    // If it's neither, bail out, it could be something weird like x86mmx.
    EVT LastOpVT = Ops.back().getValueType();
    if (LastOpVT.isFloatingPoint())
      AnyFP = true;
    else if (LastOpVT.isInteger())
      AnyInteger = true;
    else
      return SDValue();
  }

  // If any of the operands is a floating point scalar bitcast to a vector,
  // use floating point types throughout, and bitcast everything.
  // Replace UNDEFs by another scalar UNDEF node, of the final desired type.
  if (AnyFP) {
    SVT = EVT::getFloatingPointVT(OpVT.getSizeInBits());
    ScalarUndef = DAG.getNode(ISD::UNDEF, DL, SVT);
    if (AnyInteger) {
      for (SDValue &Op : Ops) {
        if (Op.getValueType() == SVT)
          continue;
        if (Op.isUndef())
          Op = ScalarUndef;
        else
          Op = DAG.getBitcast(SVT, Op);
      }
    }
  }

  // Build a vector of the chosen scalar type wide enough to cover VT, then
  // bitcast back to the original concat result type.
  EVT VecVT = EVT::getVectorVT(*DAG.getContext(), SVT,
                               VT.getSizeInBits() / SVT.getSizeInBits());
  return DAG.getBitcast(VT, DAG.getBuildVector(VecVT, DL, Ops));
}
20927
20928 // Attempt to merge nested concat_vectors/undefs.
20929 // Fold concat_vectors(concat_vectors(x,y,z,w),u,u,concat_vectors(a,b,c,d))
20930 // --> concat_vectors(x,y,z,w,u,u,u,u,u,u,u,u,a,b,c,d)
combineConcatVectorOfConcatVectors(SDNode * N,SelectionDAG & DAG)20931 static SDValue combineConcatVectorOfConcatVectors(SDNode *N,
20932 SelectionDAG &DAG) {
20933 EVT VT = N->getValueType(0);
20934
20935 // Ensure we're concatenating UNDEF and CONCAT_VECTORS nodes of similar types.
20936 EVT SubVT;
20937 SDValue FirstConcat;
20938 for (const SDValue &Op : N->ops()) {
20939 if (Op.isUndef())
20940 continue;
20941 if (Op.getOpcode() != ISD::CONCAT_VECTORS)
20942 return SDValue();
20943 if (!FirstConcat) {
20944 SubVT = Op.getOperand(0).getValueType();
20945 if (!DAG.getTargetLoweringInfo().isTypeLegal(SubVT))
20946 return SDValue();
20947 FirstConcat = Op;
20948 continue;
20949 }
20950 if (SubVT != Op.getOperand(0).getValueType())
20951 return SDValue();
20952 }
20953 assert(FirstConcat && "Concat of all-undefs found");
20954
20955 SmallVector<SDValue> ConcatOps;
20956 for (const SDValue &Op : N->ops()) {
20957 if (Op.isUndef()) {
20958 ConcatOps.append(FirstConcat->getNumOperands(), DAG.getUNDEF(SubVT));
20959 continue;
20960 }
20961 ConcatOps.append(Op->op_begin(), Op->op_end());
20962 }
20963 return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, ConcatOps);
20964 }
20965
// Check to see if this is a CONCAT_VECTORS of a bunch of EXTRACT_SUBVECTOR
// operations. If so, and if the EXTRACT_SUBVECTOR vector inputs come from at
// most two distinct vectors the same size as the result, attempt to turn this
// into a legal shuffle.
static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) {
  EVT VT = N->getValueType(0);
  EVT OpVT = N->getOperand(0).getValueType();

  // We currently can't generate an appropriate shuffle for a scalable vector.
  if (VT.isScalableVector())
    return SDValue();

  int NumElts = VT.getVectorNumElements();
  int NumOpElts = OpVT.getVectorNumElements();

  // SV0/SV1 will become the (at most) two source vectors of the shuffle.
  SDValue SV0 = DAG.getUNDEF(VT), SV1 = DAG.getUNDEF(VT);
  SmallVector<int, 8> Mask;

  for (SDValue Op : N->ops()) {
    Op = peekThroughBitcasts(Op);

    // UNDEF nodes convert to UNDEF shuffle mask values.
    if (Op.isUndef()) {
      Mask.append((unsigned)NumOpElts, -1);
      continue;
    }

    if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
      return SDValue();

    // What vector are we extracting the subvector from and at what index?
    SDValue ExtVec = Op.getOperand(0);
    int ExtIdx = Op.getConstantOperandVal(1);

    // We want the EVT of the original extraction to correctly scale the
    // extraction index.
    EVT ExtVT = ExtVec.getValueType();
    ExtVec = peekThroughBitcasts(ExtVec);

    // UNDEF nodes convert to UNDEF shuffle mask values.
    if (ExtVec.isUndef()) {
      Mask.append((unsigned)NumOpElts, -1);
      continue;
    }

    // Ensure that we are extracting a subvector from a vector the same
    // size as the result.
    if (ExtVT.getSizeInBits() != VT.getSizeInBits())
      return SDValue();

    // Scale the subvector index to account for any bitcast.
    int NumExtElts = ExtVT.getVectorNumElements();
    if (0 == (NumExtElts % NumElts))
      ExtIdx /= (NumExtElts / NumElts);
    else if (0 == (NumElts % NumExtElts))
      ExtIdx *= (NumElts / NumExtElts);
    else
      return SDValue();

    // At most we can reference 2 inputs in the final shuffle.
    if (SV0.isUndef() || SV0 == ExtVec) {
      SV0 = ExtVec;
      for (int i = 0; i != NumOpElts; ++i)
        Mask.push_back(i + ExtIdx);
    } else if (SV1.isUndef() || SV1 == ExtVec) {
      SV1 = ExtVec;
      for (int i = 0; i != NumOpElts; ++i)
        Mask.push_back(i + ExtIdx + NumElts);
    } else {
      return SDValue();
    }
  }

  // Attempt to build a legal shuffle from the two sources and the mask; this
  // may fail and return an empty SDValue.
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  return TLI.buildLegalVectorShuffle(VT, SDLoc(N), DAG.getBitcast(VT, SV0),
                                     DAG.getBitcast(VT, SV1), Mask, DAG);
}
21043
combineConcatVectorOfCasts(SDNode * N,SelectionDAG & DAG)21044 static SDValue combineConcatVectorOfCasts(SDNode *N, SelectionDAG &DAG) {
21045 unsigned CastOpcode = N->getOperand(0).getOpcode();
21046 switch (CastOpcode) {
21047 case ISD::SINT_TO_FP:
21048 case ISD::UINT_TO_FP:
21049 case ISD::FP_TO_SINT:
21050 case ISD::FP_TO_UINT:
21051 // TODO: Allow more opcodes?
21052 // case ISD::BITCAST:
21053 // case ISD::TRUNCATE:
21054 // case ISD::ZERO_EXTEND:
21055 // case ISD::SIGN_EXTEND:
21056 // case ISD::FP_EXTEND:
21057 break;
21058 default:
21059 return SDValue();
21060 }
21061
21062 EVT SrcVT = N->getOperand(0).getOperand(0).getValueType();
21063 if (!SrcVT.isVector())
21064 return SDValue();
21065
21066 // All operands of the concat must be the same kind of cast from the same
21067 // source type.
21068 SmallVector<SDValue, 4> SrcOps;
21069 for (SDValue Op : N->ops()) {
21070 if (Op.getOpcode() != CastOpcode || !Op.hasOneUse() ||
21071 Op.getOperand(0).getValueType() != SrcVT)
21072 return SDValue();
21073 SrcOps.push_back(Op.getOperand(0));
21074 }
21075
21076 // The wider cast must be supported by the target. This is unusual because
21077 // the operation support type parameter depends on the opcode. In addition,
21078 // check the other type in the cast to make sure this is really legal.
21079 EVT VT = N->getValueType(0);
21080 EVT SrcEltVT = SrcVT.getVectorElementType();
21081 ElementCount NumElts = SrcVT.getVectorElementCount() * N->getNumOperands();
21082 EVT ConcatSrcVT = EVT::getVectorVT(*DAG.getContext(), SrcEltVT, NumElts);
21083 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
21084 switch (CastOpcode) {
21085 case ISD::SINT_TO_FP:
21086 case ISD::UINT_TO_FP:
21087 if (!TLI.isOperationLegalOrCustom(CastOpcode, ConcatSrcVT) ||
21088 !TLI.isTypeLegal(VT))
21089 return SDValue();
21090 break;
21091 case ISD::FP_TO_SINT:
21092 case ISD::FP_TO_UINT:
21093 if (!TLI.isOperationLegalOrCustom(CastOpcode, VT) ||
21094 !TLI.isTypeLegal(ConcatSrcVT))
21095 return SDValue();
21096 break;
21097 default:
21098 llvm_unreachable("Unexpected cast opcode");
21099 }
21100
21101 // concat (cast X), (cast Y)... -> cast (concat X, Y...)
21102 SDLoc DL(N);
21103 SDValue NewConcat = DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatSrcVT, SrcOps);
21104 return DAG.getNode(CastOpcode, DL, VT, NewConcat);
21105 }
21106
SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
  // If we only have one input vector, we don't need to do any concatenation.
  if (N->getNumOperands() == 1)
    return N->getOperand(0);

  // Check if all of the operands are undefs.
  EVT VT = N->getValueType(0);
  if (ISD::allOperandsUndef(N))
    return DAG.getUNDEF(VT);

  // Optimize concat_vectors where all but the first of the vectors are undef.
  if (all_of(drop_begin(N->ops()),
             [](const SDValue &Op) { return Op.isUndef(); })) {
    SDValue In = N->getOperand(0);
    assert(In.getValueType().isVector() && "Must concat vectors");

    // If the input is a concat_vectors, just make a larger concat by padding
    // with smaller undefs.
    if (In.getOpcode() == ISD::CONCAT_VECTORS && In.hasOneUse()) {
      unsigned NumOps = N->getNumOperands() * In.getNumOperands();
      SmallVector<SDValue, 4> Ops(In->op_begin(), In->op_end());
      Ops.resize(NumOps, DAG.getUNDEF(Ops[0].getValueType()));
      return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
    }

    SDValue Scalar = peekThroughOneUseBitcasts(In);

    // concat_vectors(scalar_to_vector(scalar), undef) ->
    // scalar_to_vector(scalar)
    if (!LegalOperations && Scalar.getOpcode() == ISD::SCALAR_TO_VECTOR &&
        Scalar.hasOneUse()) {
      EVT SVT = Scalar.getValueType().getVectorElementType();
      if (SVT == Scalar.getOperand(0).getValueType())
        Scalar = Scalar.getOperand(0);
    }

    // concat_vectors(scalar, undef) -> scalar_to_vector(scalar)
    if (!Scalar.getValueType().isVector()) {
      // If the bitcast type isn't legal, it might be a trunc of a legal type;
      // look through the trunc so we can still do the transform:
      // concat_vectors(trunc(scalar), undef) -> scalar_to_vector(scalar)
      if (Scalar->getOpcode() == ISD::TRUNCATE &&
          !TLI.isTypeLegal(Scalar.getValueType()) &&
          TLI.isTypeLegal(Scalar->getOperand(0).getValueType()))
        Scalar = Scalar->getOperand(0);

      EVT SclTy = Scalar.getValueType();

      if (!SclTy.isFloatingPoint() && !SclTy.isInteger())
        return SDValue();

      // Bail out if the vector size is not a multiple of the scalar size.
      if (VT.getSizeInBits() % SclTy.getSizeInBits())
        return SDValue();

      unsigned VNTNumElms = VT.getSizeInBits() / SclTy.getSizeInBits();
      if (VNTNumElms < 2)
        return SDValue();

      // The scalar_to_vector result type must be legal before we build it.
      EVT NVT = EVT::getVectorVT(*DAG.getContext(), SclTy, VNTNumElms);
      if (!TLI.isTypeLegal(NVT) || !TLI.isTypeLegal(Scalar.getValueType()))
        return SDValue();

      SDValue Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), NVT, Scalar);
      return DAG.getBitcast(VT, Res);
    }
  }

  // Fold any combination of BUILD_VECTOR or UNDEF nodes into one BUILD_VECTOR.
  // We have already tested above for an UNDEF only concatenation.
  // fold (concat_vectors (BUILD_VECTOR A, B, ...), (BUILD_VECTOR C, D, ...))
  // -> (BUILD_VECTOR A, B, ..., C, D, ...)
  auto IsBuildVectorOrUndef = [](const SDValue &Op) {
    return ISD::UNDEF == Op.getOpcode() || ISD::BUILD_VECTOR == Op.getOpcode();
  };
  if (llvm::all_of(N->ops(), IsBuildVectorOrUndef)) {
    SmallVector<SDValue, 8> Opnds;
    EVT SVT = VT.getScalarType();

    EVT MinVT = SVT;
    if (!SVT.isFloatingPoint()) {
      // If BUILD_VECTOR are built from integers, they may have different
      // operand types. Get the smallest type and truncate all operands to it.
      bool FoundMinVT = false;
      for (const SDValue &Op : N->ops())
        if (ISD::BUILD_VECTOR == Op.getOpcode()) {
          EVT OpSVT = Op.getOperand(0).getValueType();
          MinVT = (!FoundMinVT || OpSVT.bitsLE(MinVT)) ? OpSVT : MinVT;
          FoundMinVT = true;
        }
      assert(FoundMinVT && "Concat vector type mismatch");
    }

    for (const SDValue &Op : N->ops()) {
      EVT OpVT = Op.getValueType();
      unsigned NumElts = OpVT.getVectorNumElements();

      // An undef operand contributes NumElts undef scalars.
      if (ISD::UNDEF == Op.getOpcode())
        Opnds.append(NumElts, DAG.getUNDEF(MinVT));

      if (ISD::BUILD_VECTOR == Op.getOpcode()) {
        if (SVT.isFloatingPoint()) {
          assert(SVT == OpVT.getScalarType() && "Concat vector type mismatch");
          Opnds.append(Op->op_begin(), Op->op_begin() + NumElts);
        } else {
          // Integer path: truncate each scalar down to the common MinVT.
          for (unsigned i = 0; i != NumElts; ++i)
            Opnds.push_back(
                DAG.getNode(ISD::TRUNCATE, SDLoc(N), MinVT, Op.getOperand(i)));
        }
      }
    }

    assert(VT.getVectorNumElements() == Opnds.size() &&
           "Concat vector type mismatch");
    return DAG.getBuildVector(VT, SDLoc(N), Opnds);
  }

  // Fold CONCAT_VECTORS of only bitcast scalars (or undef) to BUILD_VECTOR.
  // FIXME: Add support for concat_vectors(bitcast(vec0),bitcast(vec1),...).
  if (SDValue V = combineConcatVectorOfScalars(N, DAG))
    return V;

  if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT)) {
    // Fold CONCAT_VECTORS of CONCAT_VECTORS (or undef) to VECTOR_SHUFFLE.
    if (SDValue V = combineConcatVectorOfConcatVectors(N, DAG))
      return V;

    // Fold CONCAT_VECTORS of EXTRACT_SUBVECTOR (or undef) to VECTOR_SHUFFLE.
    if (SDValue V = combineConcatVectorOfExtracts(N, DAG))
      return V;
  }

  // Fold CONCAT_VECTORS of matching casts to a cast of a wider concat.
  if (SDValue V = combineConcatVectorOfCasts(N, DAG))
    return V;

  // Type legalization of vectors and DAG canonicalization of SHUFFLE_VECTOR
  // nodes often generate nop CONCAT_VECTOR nodes. Scan the CONCAT_VECTOR
  // operands and look for a CONCAT operations that place the incoming vectors
  // at the exact same location.
  //
  // For scalable vectors, EXTRACT_SUBVECTOR indexes are implicitly scaled.
  SDValue SingleSource = SDValue();
  unsigned PartNumElem =
      N->getOperand(0).getValueType().getVectorMinNumElements();

  for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
    SDValue Op = N->getOperand(i);

    if (Op.isUndef())
      continue;

    // Check if this is the identity extract:
    if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
      return SDValue();

    // Find the single incoming vector for the extract_subvector.
    if (SingleSource.getNode()) {
      if (Op.getOperand(0) != SingleSource)
        return SDValue();
    } else {
      SingleSource = Op.getOperand(0);

      // Check the source type is the same as the type of the result.
      // If not, this concat may extend the vector, so we can not
      // optimize it away.
      if (SingleSource.getValueType() != N->getValueType(0))
        return SDValue();
    }

    // Check that we are reading from the identity index.
    unsigned IdentityIndex = i * PartNumElem;
    if (Op.getConstantOperandAPInt(1) != IdentityIndex)
      return SDValue();
  }

  // Every (non-undef) piece came from SingleSource at its identity position,
  // so the whole concat is a nop.
  if (SingleSource.getNode())
    return SingleSource;

  return SDValue();
}
21287
21288 // Helper that peeks through INSERT_SUBVECTOR/CONCAT_VECTORS to find
21289 // if the subvector can be sourced for free.
getSubVectorSrc(SDValue V,SDValue Index,EVT SubVT)21290 static SDValue getSubVectorSrc(SDValue V, SDValue Index, EVT SubVT) {
21291 if (V.getOpcode() == ISD::INSERT_SUBVECTOR &&
21292 V.getOperand(1).getValueType() == SubVT && V.getOperand(2) == Index) {
21293 return V.getOperand(1);
21294 }
21295 auto *IndexC = dyn_cast<ConstantSDNode>(Index);
21296 if (IndexC && V.getOpcode() == ISD::CONCAT_VECTORS &&
21297 V.getOperand(0).getValueType() == SubVT &&
21298 (IndexC->getZExtValue() % SubVT.getVectorMinNumElements()) == 0) {
21299 uint64_t SubIdx = IndexC->getZExtValue() / SubVT.getVectorMinNumElements();
21300 return V.getOperand(SubIdx);
21301 }
21302 return SDValue();
21303 }
21304
narrowInsertExtractVectorBinOp(SDNode * Extract,SelectionDAG & DAG,bool LegalOperations)21305 static SDValue narrowInsertExtractVectorBinOp(SDNode *Extract,
21306 SelectionDAG &DAG,
21307 bool LegalOperations) {
21308 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
21309 SDValue BinOp = Extract->getOperand(0);
21310 unsigned BinOpcode = BinOp.getOpcode();
21311 if (!TLI.isBinOp(BinOpcode) || BinOp->getNumValues() != 1)
21312 return SDValue();
21313
21314 EVT VecVT = BinOp.getValueType();
21315 SDValue Bop0 = BinOp.getOperand(0), Bop1 = BinOp.getOperand(1);
21316 if (VecVT != Bop0.getValueType() || VecVT != Bop1.getValueType())
21317 return SDValue();
21318
21319 SDValue Index = Extract->getOperand(1);
21320 EVT SubVT = Extract->getValueType(0);
21321 if (!TLI.isOperationLegalOrCustom(BinOpcode, SubVT, LegalOperations))
21322 return SDValue();
21323
21324 SDValue Sub0 = getSubVectorSrc(Bop0, Index, SubVT);
21325 SDValue Sub1 = getSubVectorSrc(Bop1, Index, SubVT);
21326
21327 // TODO: We could handle the case where only 1 operand is being inserted by
21328 // creating an extract of the other operand, but that requires checking
21329 // number of uses and/or costs.
21330 if (!Sub0 || !Sub1)
21331 return SDValue();
21332
21333 // We are inserting both operands of the wide binop only to extract back
21334 // to the narrow vector size. Eliminate all of the insert/extract:
21335 // ext (binop (ins ?, X, Index), (ins ?, Y, Index)), Index --> binop X, Y
21336 return DAG.getNode(BinOpcode, SDLoc(Extract), SubVT, Sub0, Sub1,
21337 BinOp->getFlags());
21338 }
21339
21340 /// If we are extracting a subvector produced by a wide binary operator try
21341 /// to use a narrow binary operator and/or avoid concatenation and extraction.
narrowExtractedVectorBinOp(SDNode * Extract,SelectionDAG & DAG,bool LegalOperations)21342 static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG,
21343 bool LegalOperations) {
21344 // TODO: Refactor with the caller (visitEXTRACT_SUBVECTOR), so we can share
21345 // some of these bailouts with other transforms.
21346
21347 if (SDValue V = narrowInsertExtractVectorBinOp(Extract, DAG, LegalOperations))
21348 return V;
21349
21350 // The extract index must be a constant, so we can map it to a concat operand.
21351 auto *ExtractIndexC = dyn_cast<ConstantSDNode>(Extract->getOperand(1));
21352 if (!ExtractIndexC)
21353 return SDValue();
21354
21355 // We are looking for an optionally bitcasted wide vector binary operator
21356 // feeding an extract subvector.
21357 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
21358 SDValue BinOp = peekThroughBitcasts(Extract->getOperand(0));
21359 unsigned BOpcode = BinOp.getOpcode();
21360 if (!TLI.isBinOp(BOpcode) || BinOp->getNumValues() != 1)
21361 return SDValue();
21362
21363 // Exclude the fake form of fneg (fsub -0.0, x) because that is likely to be
21364 // reduced to the unary fneg when it is visited, and we probably want to deal
21365 // with fneg in a target-specific way.
21366 if (BOpcode == ISD::FSUB) {
21367 auto *C = isConstOrConstSplatFP(BinOp.getOperand(0), /*AllowUndefs*/ true);
21368 if (C && C->getValueAPF().isNegZero())
21369 return SDValue();
21370 }
21371
21372 // The binop must be a vector type, so we can extract some fraction of it.
21373 EVT WideBVT = BinOp.getValueType();
21374 // The optimisations below currently assume we are dealing with fixed length
21375 // vectors. It is possible to add support for scalable vectors, but at the
21376 // moment we've done no analysis to prove whether they are profitable or not.
21377 if (!WideBVT.isFixedLengthVector())
21378 return SDValue();
21379
21380 EVT VT = Extract->getValueType(0);
21381 unsigned ExtractIndex = ExtractIndexC->getZExtValue();
21382 assert(ExtractIndex % VT.getVectorNumElements() == 0 &&
21383 "Extract index is not a multiple of the vector length.");
21384
21385 // Bail out if this is not a proper multiple width extraction.
21386 unsigned WideWidth = WideBVT.getSizeInBits();
21387 unsigned NarrowWidth = VT.getSizeInBits();
21388 if (WideWidth % NarrowWidth != 0)
21389 return SDValue();
21390
21391 // Bail out if we are extracting a fraction of a single operation. This can
21392 // occur because we potentially looked through a bitcast of the binop.
21393 unsigned NarrowingRatio = WideWidth / NarrowWidth;
21394 unsigned WideNumElts = WideBVT.getVectorNumElements();
21395 if (WideNumElts % NarrowingRatio != 0)
21396 return SDValue();
21397
21398 // Bail out if the target does not support a narrower version of the binop.
21399 EVT NarrowBVT = EVT::getVectorVT(*DAG.getContext(), WideBVT.getScalarType(),
21400 WideNumElts / NarrowingRatio);
21401 if (!TLI.isOperationLegalOrCustomOrPromote(BOpcode, NarrowBVT))
21402 return SDValue();
21403
21404 // If extraction is cheap, we don't need to look at the binop operands
21405 // for concat ops. The narrow binop alone makes this transform profitable.
21406 // We can't just reuse the original extract index operand because we may have
21407 // bitcasted.
21408 unsigned ConcatOpNum = ExtractIndex / VT.getVectorNumElements();
21409 unsigned ExtBOIdx = ConcatOpNum * NarrowBVT.getVectorNumElements();
21410 if (TLI.isExtractSubvectorCheap(NarrowBVT, WideBVT, ExtBOIdx) &&
21411 BinOp.hasOneUse() && Extract->getOperand(0)->hasOneUse()) {
21412 // extract (binop B0, B1), N --> binop (extract B0, N), (extract B1, N)
21413 SDLoc DL(Extract);
21414 SDValue NewExtIndex = DAG.getVectorIdxConstant(ExtBOIdx, DL);
21415 SDValue X = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
21416 BinOp.getOperand(0), NewExtIndex);
21417 SDValue Y = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
21418 BinOp.getOperand(1), NewExtIndex);
21419 SDValue NarrowBinOp =
21420 DAG.getNode(BOpcode, DL, NarrowBVT, X, Y, BinOp->getFlags());
21421 return DAG.getBitcast(VT, NarrowBinOp);
21422 }
21423
21424 // Only handle the case where we are doubling and then halving. A larger ratio
21425 // may require more than two narrow binops to replace the wide binop.
21426 if (NarrowingRatio != 2)
21427 return SDValue();
21428
21429 // TODO: The motivating case for this transform is an x86 AVX1 target. That
21430 // target has temptingly almost legal versions of bitwise logic ops in 256-bit
21431 // flavors, but no other 256-bit integer support. This could be extended to
21432 // handle any binop, but that may require fixing/adding other folds to avoid
21433 // codegen regressions.
21434 if (BOpcode != ISD::AND && BOpcode != ISD::OR && BOpcode != ISD::XOR)
21435 return SDValue();
21436
21437 // We need at least one concatenation operation of a binop operand to make
21438 // this transform worthwhile. The concat must double the input vector sizes.
21439 auto GetSubVector = [ConcatOpNum](SDValue V) -> SDValue {
21440 if (V.getOpcode() == ISD::CONCAT_VECTORS && V.getNumOperands() == 2)
21441 return V.getOperand(ConcatOpNum);
21442 return SDValue();
21443 };
21444 SDValue SubVecL = GetSubVector(peekThroughBitcasts(BinOp.getOperand(0)));
21445 SDValue SubVecR = GetSubVector(peekThroughBitcasts(BinOp.getOperand(1)));
21446
21447 if (SubVecL || SubVecR) {
21448 // If a binop operand was not the result of a concat, we must extract a
21449 // half-sized operand for our new narrow binop:
21450 // extract (binop (concat X1, X2), (concat Y1, Y2)), N --> binop XN, YN
21451 // extract (binop (concat X1, X2), Y), N --> binop XN, (extract Y, IndexC)
21452 // extract (binop X, (concat Y1, Y2)), N --> binop (extract X, IndexC), YN
21453 SDLoc DL(Extract);
21454 SDValue IndexC = DAG.getVectorIdxConstant(ExtBOIdx, DL);
21455 SDValue X = SubVecL ? DAG.getBitcast(NarrowBVT, SubVecL)
21456 : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
21457 BinOp.getOperand(0), IndexC);
21458
21459 SDValue Y = SubVecR ? DAG.getBitcast(NarrowBVT, SubVecR)
21460 : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
21461 BinOp.getOperand(1), IndexC);
21462
21463 SDValue NarrowBinOp = DAG.getNode(BOpcode, DL, NarrowBVT, X, Y);
21464 return DAG.getBitcast(VT, NarrowBinOp);
21465 }
21466
21467 return SDValue();
21468 }
21469
/// If we are extracting a subvector from a wide vector load, convert to a
/// narrow load to eliminate the extraction:
/// (extract_subvector (load wide vector)) --> (load narrow vector)
static SDValue narrowExtractedVectorLoad(SDNode *Extract, SelectionDAG &DAG) {
  // TODO: Add support for big-endian. The offset calculation must be adjusted.
  if (DAG.getDataLayout().isBigEndian())
    return SDValue();

  // Only plain loads can be narrowed: an extending load changes the loaded
  // value, and a volatile/atomic (non-simple) load must not be resized.
  auto *Ld = dyn_cast<LoadSDNode>(Extract->getOperand(0));
  if (!Ld || Ld->getExtensionType() || !Ld->isSimple())
    return SDValue();

  EVT VT = Extract->getValueType(0);

  // We can only create byte sized loads.
  if (!VT.isByteSized())
    return SDValue();

  unsigned Index = Extract->getConstantOperandVal(1);
  unsigned NumElts = VT.getVectorMinNumElements();

  // The definition of EXTRACT_SUBVECTOR states that the index must be a
  // multiple of the minimum number of elements in the result type.
  assert(Index % NumElts == 0 && "The extract subvector index is not a "
                                 "multiple of the result's element count");

  // It's fine to use TypeSize here as we know the offset will not be negative.
  TypeSize Offset = VT.getStoreSize() * (Index / NumElts);

  // Allow targets to opt-out of narrowing this load.
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  if (!TLI.shouldReduceLoadWidth(Ld, Ld->getExtensionType(), VT))
    return SDValue();

  // The narrow load will be offset from the base address of the old load if
  // we are extracting from something besides index 0 (little-endian).
  SDLoc DL(Extract);

  // TODO: Use "BaseIndexOffset" to make this more effective.
  SDValue NewAddr = DAG.getMemBasePlusOffset(Ld->getBasePtr(), Offset, DL);

  // Build the memory operand for the narrow load. When the offset is
  // scalable the exact byte offset is unknown at compile time, so only the
  // address space of the original pointer info can be kept.
  uint64_t StoreSize = MemoryLocation::getSizeOrUnknown(VT.getStoreSize());
  MachineFunction &MF = DAG.getMachineFunction();
  MachineMemOperand *MMO;
  if (Offset.isScalable()) {
    MachinePointerInfo MPI =
        MachinePointerInfo(Ld->getPointerInfo().getAddrSpace());
    MMO = MF.getMachineMemOperand(Ld->getMemOperand(), MPI, StoreSize);
  } else
    MMO = MF.getMachineMemOperand(Ld->getMemOperand(), Offset.getFixedSize(),
                                  StoreSize);

  // Create the narrow load and splice it into the chain in place of the wide
  // load so memory ordering is preserved.
  SDValue NewLd = DAG.getLoad(VT, DL, Ld->getChain(), NewAddr, MMO);
  DAG.makeEquivalentMemoryOrdering(Ld, NewLd);
  return NewLd;
}
21526
/// Given EXTRACT_SUBVECTOR(VECTOR_SHUFFLE(Op0, Op1, Mask)),
/// try to produce VECTOR_SHUFFLE(EXTRACT_SUBVECTOR(Op?, ?),
///                               EXTRACT_SUBVECTOR(Op?, ?),
///                               Mask'))
/// iff it is legal and profitable to do so. Notably, the trimmed mask
/// (containing only the elements that are extracted)
/// must reference at most two subvectors.
static SDValue foldExtractSubvectorFromShuffleVector(SDNode *N,
                                                     SelectionDAG &DAG,
                                                     const TargetLowering &TLI,
                                                     bool LegalOperations) {
  assert(N->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
         "Must only be called on EXTRACT_SUBVECTOR's");

  SDValue N0 = N->getOperand(0);

  // Only deal with non-scalable vectors.
  EVT NarrowVT = N->getValueType(0);
  EVT WideVT = N0.getValueType();
  if (!NarrowVT.isFixedLengthVector() || !WideVT.isFixedLengthVector())
    return SDValue();

  // The operand must be a shufflevector.
  auto *WideShuffleVector = dyn_cast<ShuffleVectorSDNode>(N0);
  if (!WideShuffleVector)
    return SDValue();

  // The old shuffle needs to go away, otherwise we would only be adding nodes.
  if (!WideShuffleVector->hasOneUse())
    return SDValue();

  // And the narrow shufflevector that we'll form must be legal.
  if (LegalOperations &&
      !TLI.isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, NarrowVT))
    return SDValue();

  uint64_t FirstExtractedEltIdx = N->getConstantOperandVal(1);
  int NumEltsExtracted = NarrowVT.getVectorNumElements();
  assert((FirstExtractedEltIdx % NumEltsExtracted) == 0 &&
         "Extract index is not a multiple of the output vector length.");

  int WideNumElts = WideVT.getVectorNumElements();

  SmallVector<int, 16> NewMask;
  NewMask.reserve(NumEltsExtracted);
  // The set of (operand, subvector-index) pairs the trimmed mask refers to.
  // SetVector keeps insertion order, so element positions double as the new
  // shuffle's operand numbers.
  SmallSetVector<std::pair<SDValue /*Op*/, int /*SubvectorIndex*/>, 2>
      DemandedSubvectors;

  // Try to decode the wide mask into narrow mask from at most two subvectors.
  for (int M : WideShuffleVector->getMask().slice(FirstExtractedEltIdx,
                                                  NumEltsExtracted)) {
    assert((M >= -1) && (M < (2 * WideNumElts)) &&
           "Out-of-bounds shuffle mask?");

    if (M < 0) {
      // Does not depend on operands, does not require adjustment.
      NewMask.emplace_back(M);
      continue;
    }

    // From which operand of the shuffle does this shuffle mask element pick?
    int WideShufOpIdx = M / WideNumElts;
    // Which element of that operand is picked?
    int OpEltIdx = M % WideNumElts;

    assert((OpEltIdx + WideShufOpIdx * WideNumElts) == M &&
           "Shuffle mask vector decomposition failure.");

    // And which NumEltsExtracted-sized subvector of that operand is that?
    int OpSubvecIdx = OpEltIdx / NumEltsExtracted;
    // And which element within that subvector of that operand is that?
    int OpEltIdxInSubvec = OpEltIdx % NumEltsExtracted;

    assert((OpEltIdxInSubvec + OpSubvecIdx * NumEltsExtracted) == OpEltIdx &&
           "Shuffle mask subvector decomposition failure.");

    assert((OpEltIdxInSubvec + OpSubvecIdx * NumEltsExtracted +
            WideShufOpIdx * WideNumElts) == M &&
           "Shuffle mask full decomposition failure.");

    SDValue Op = WideShuffleVector->getOperand(WideShufOpIdx);

    if (Op.isUndef()) {
      // Picking from an undef operand. Let's adjust mask instead.
      NewMask.emplace_back(-1);
      continue;
    }

    // Profitability check: only deal with extractions from the first subvector.
    if (OpSubvecIdx != 0)
      return SDValue();

    const std::pair<SDValue, int> DemandedSubvector =
        std::make_pair(Op, OpSubvecIdx);

    if (DemandedSubvectors.insert(DemandedSubvector)) {
      if (DemandedSubvectors.size() > 2)
        return SDValue(); // We can't handle more than two subvectors.
      // How many elements into the WideVT does this subvector start?
      int Index = NumEltsExtracted * OpSubvecIdx;
      // Bail out if the extraction isn't going to be cheap.
      if (!TLI.isExtractSubvectorCheap(NarrowVT, WideVT, Index))
        return SDValue();
    }

    // Ok, but from which operand of the new shuffle will this element pick?
    int NewOpIdx =
        getFirstIndexOf(DemandedSubvectors.getArrayRef(), DemandedSubvector);
    assert((NewOpIdx == 0 || NewOpIdx == 1) && "Unexpected operand index.");

    int AdjM = OpEltIdxInSubvec + NewOpIdx * NumEltsExtracted;
    NewMask.emplace_back(AdjM);
  }
  assert(NewMask.size() == (unsigned)NumEltsExtracted && "Produced bad mask.");
  assert(DemandedSubvectors.size() <= 2 &&
         "Should have ended up demanding at most two subvectors.");

  // Did we discover that the shuffle does not actually depend on operands?
  if (DemandedSubvectors.empty())
    return DAG.getUNDEF(NarrowVT);

  // We still perform the exact same EXTRACT_SUBVECTOR, just on different
  // operand[s]/index[es], so there is no point in checking for it's legality.

  // Do not turn a legal shuffle into an illegal one.
  if (TLI.isShuffleMaskLegal(WideShuffleVector->getMask(), WideVT) &&
      !TLI.isShuffleMaskLegal(NewMask, NarrowVT))
    return SDValue();

  SDLoc DL(N);

  // Materialize one EXTRACT_SUBVECTOR per demanded subvector; these become
  // the operands of the narrow shuffle.
  SmallVector<SDValue, 2> NewOps;
  for (const std::pair<SDValue /*Op*/, int /*SubvectorIndex*/>
           &DemandedSubvector : DemandedSubvectors) {
    // How many elements into the WideVT does this subvector start?
    int Index = NumEltsExtracted * DemandedSubvector.second;
    SDValue IndexC = DAG.getVectorIdxConstant(Index, DL);
    NewOps.emplace_back(DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowVT,
                                    DemandedSubvector.first, IndexC));
  }
  assert((NewOps.size() == 1 || NewOps.size() == 2) &&
         "Should end up with either one or two ops");

  // If we ended up with only one operand, pad with an undef.
  if (NewOps.size() == 1)
    NewOps.emplace_back(DAG.getUNDEF(NarrowVT));

  return DAG.getVectorShuffle(NarrowVT, DL, NewOps[0], NewOps[1], NewMask);
}
21676
visitEXTRACT_SUBVECTOR(SDNode * N)21677 SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) {
21678 EVT NVT = N->getValueType(0);
21679 SDValue V = N->getOperand(0);
21680 uint64_t ExtIdx = N->getConstantOperandVal(1);
21681
21682 // Extract from UNDEF is UNDEF.
21683 if (V.isUndef())
21684 return DAG.getUNDEF(NVT);
21685
21686 if (TLI.isOperationLegalOrCustomOrPromote(ISD::LOAD, NVT))
21687 if (SDValue NarrowLoad = narrowExtractedVectorLoad(N, DAG))
21688 return NarrowLoad;
21689
21690 // Combine an extract of an extract into a single extract_subvector.
21691 // ext (ext X, C), 0 --> ext X, C
21692 if (ExtIdx == 0 && V.getOpcode() == ISD::EXTRACT_SUBVECTOR && V.hasOneUse()) {
21693 if (TLI.isExtractSubvectorCheap(NVT, V.getOperand(0).getValueType(),
21694 V.getConstantOperandVal(1)) &&
21695 TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NVT)) {
21696 return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), NVT, V.getOperand(0),
21697 V.getOperand(1));
21698 }
21699 }
21700
21701 // ty1 extract_vector(ty2 splat(V))) -> ty1 splat(V)
21702 if (V.getOpcode() == ISD::SPLAT_VECTOR)
21703 if (DAG.isConstantValueOfAnyType(V.getOperand(0)) || V.hasOneUse())
21704 if (!LegalOperations || TLI.isOperationLegal(ISD::SPLAT_VECTOR, NVT))
21705 return DAG.getSplatVector(NVT, SDLoc(N), V.getOperand(0));
21706
21707 // Try to move vector bitcast after extract_subv by scaling extraction index:
21708 // extract_subv (bitcast X), Index --> bitcast (extract_subv X, Index')
21709 if (V.getOpcode() == ISD::BITCAST &&
21710 V.getOperand(0).getValueType().isVector() &&
21711 (!LegalOperations || TLI.isOperationLegal(ISD::BITCAST, NVT))) {
21712 SDValue SrcOp = V.getOperand(0);
21713 EVT SrcVT = SrcOp.getValueType();
21714 unsigned SrcNumElts = SrcVT.getVectorMinNumElements();
21715 unsigned DestNumElts = V.getValueType().getVectorMinNumElements();
21716 if ((SrcNumElts % DestNumElts) == 0) {
21717 unsigned SrcDestRatio = SrcNumElts / DestNumElts;
21718 ElementCount NewExtEC = NVT.getVectorElementCount() * SrcDestRatio;
21719 EVT NewExtVT = EVT::getVectorVT(*DAG.getContext(), SrcVT.getScalarType(),
21720 NewExtEC);
21721 if (TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NewExtVT)) {
21722 SDLoc DL(N);
21723 SDValue NewIndex = DAG.getVectorIdxConstant(ExtIdx * SrcDestRatio, DL);
21724 SDValue NewExtract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewExtVT,
21725 V.getOperand(0), NewIndex);
21726 return DAG.getBitcast(NVT, NewExtract);
21727 }
21728 }
21729 if ((DestNumElts % SrcNumElts) == 0) {
21730 unsigned DestSrcRatio = DestNumElts / SrcNumElts;
21731 if (NVT.getVectorElementCount().isKnownMultipleOf(DestSrcRatio)) {
21732 ElementCount NewExtEC =
21733 NVT.getVectorElementCount().divideCoefficientBy(DestSrcRatio);
21734 EVT ScalarVT = SrcVT.getScalarType();
21735 if ((ExtIdx % DestSrcRatio) == 0) {
21736 SDLoc DL(N);
21737 unsigned IndexValScaled = ExtIdx / DestSrcRatio;
21738 EVT NewExtVT =
21739 EVT::getVectorVT(*DAG.getContext(), ScalarVT, NewExtEC);
21740 if (TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NewExtVT)) {
21741 SDValue NewIndex = DAG.getVectorIdxConstant(IndexValScaled, DL);
21742 SDValue NewExtract =
21743 DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewExtVT,
21744 V.getOperand(0), NewIndex);
21745 return DAG.getBitcast(NVT, NewExtract);
21746 }
21747 if (NewExtEC.isScalar() &&
21748 TLI.isOperationLegalOrCustom(ISD::EXTRACT_VECTOR_ELT, ScalarVT)) {
21749 SDValue NewIndex = DAG.getVectorIdxConstant(IndexValScaled, DL);
21750 SDValue NewExtract =
21751 DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT,
21752 V.getOperand(0), NewIndex);
21753 return DAG.getBitcast(NVT, NewExtract);
21754 }
21755 }
21756 }
21757 }
21758 }
21759
21760 if (V.getOpcode() == ISD::CONCAT_VECTORS) {
21761 unsigned ExtNumElts = NVT.getVectorMinNumElements();
21762 EVT ConcatSrcVT = V.getOperand(0).getValueType();
21763 assert(ConcatSrcVT.getVectorElementType() == NVT.getVectorElementType() &&
21764 "Concat and extract subvector do not change element type");
21765 assert((ExtIdx % ExtNumElts) == 0 &&
21766 "Extract index is not a multiple of the input vector length.");
21767
21768 unsigned ConcatSrcNumElts = ConcatSrcVT.getVectorMinNumElements();
21769 unsigned ConcatOpIdx = ExtIdx / ConcatSrcNumElts;
21770
21771 // If the concatenated source types match this extract, it's a direct
21772 // simplification:
21773 // extract_subvec (concat V1, V2, ...), i --> Vi
21774 if (NVT.getVectorElementCount() == ConcatSrcVT.getVectorElementCount())
21775 return V.getOperand(ConcatOpIdx);
21776
21777 // If the concatenated source vectors are a multiple length of this extract,
21778 // then extract a fraction of one of those source vectors directly from a
21779 // concat operand. Example:
21780 // v2i8 extract_subvec (v16i8 concat (v8i8 X), (v8i8 Y), 14 -->
21781 // v2i8 extract_subvec v8i8 Y, 6
21782 if (NVT.isFixedLengthVector() && ConcatSrcVT.isFixedLengthVector() &&
21783 ConcatSrcNumElts % ExtNumElts == 0) {
21784 SDLoc DL(N);
21785 unsigned NewExtIdx = ExtIdx - ConcatOpIdx * ConcatSrcNumElts;
21786 assert(NewExtIdx + ExtNumElts <= ConcatSrcNumElts &&
21787 "Trying to extract from >1 concat operand?");
21788 assert(NewExtIdx % ExtNumElts == 0 &&
21789 "Extract index is not a multiple of the input vector length.");
21790 SDValue NewIndexC = DAG.getVectorIdxConstant(NewExtIdx, DL);
21791 return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NVT,
21792 V.getOperand(ConcatOpIdx), NewIndexC);
21793 }
21794 }
21795
21796 if (SDValue V =
21797 foldExtractSubvectorFromShuffleVector(N, DAG, TLI, LegalOperations))
21798 return V;
21799
21800 V = peekThroughBitcasts(V);
21801
21802 // If the input is a build vector. Try to make a smaller build vector.
21803 if (V.getOpcode() == ISD::BUILD_VECTOR) {
21804 EVT InVT = V.getValueType();
21805 unsigned ExtractSize = NVT.getSizeInBits();
21806 unsigned EltSize = InVT.getScalarSizeInBits();
21807 // Only do this if we won't split any elements.
21808 if (ExtractSize % EltSize == 0) {
21809 unsigned NumElems = ExtractSize / EltSize;
21810 EVT EltVT = InVT.getVectorElementType();
21811 EVT ExtractVT =
21812 NumElems == 1 ? EltVT
21813 : EVT::getVectorVT(*DAG.getContext(), EltVT, NumElems);
21814 if ((Level < AfterLegalizeDAG ||
21815 (NumElems == 1 ||
21816 TLI.isOperationLegal(ISD::BUILD_VECTOR, ExtractVT))) &&
21817 (!LegalTypes || TLI.isTypeLegal(ExtractVT))) {
21818 unsigned IdxVal = (ExtIdx * NVT.getScalarSizeInBits()) / EltSize;
21819
21820 if (NumElems == 1) {
21821 SDValue Src = V->getOperand(IdxVal);
21822 if (EltVT != Src.getValueType())
21823 Src = DAG.getNode(ISD::TRUNCATE, SDLoc(N), InVT, Src);
21824 return DAG.getBitcast(NVT, Src);
21825 }
21826
21827 // Extract the pieces from the original build_vector.
21828 SDValue BuildVec = DAG.getBuildVector(ExtractVT, SDLoc(N),
21829 V->ops().slice(IdxVal, NumElems));
21830 return DAG.getBitcast(NVT, BuildVec);
21831 }
21832 }
21833 }
21834
21835 if (V.getOpcode() == ISD::INSERT_SUBVECTOR) {
21836 // Handle only simple case where vector being inserted and vector
21837 // being extracted are of same size.
21838 EVT SmallVT = V.getOperand(1).getValueType();
21839 if (!NVT.bitsEq(SmallVT))
21840 return SDValue();
21841
21842 // Combine:
21843 // (extract_subvec (insert_subvec V1, V2, InsIdx), ExtIdx)
21844 // Into:
21845 // indices are equal or bit offsets are equal => V1
21846 // otherwise => (extract_subvec V1, ExtIdx)
21847 uint64_t InsIdx = V.getConstantOperandVal(2);
21848 if (InsIdx * SmallVT.getScalarSizeInBits() ==
21849 ExtIdx * NVT.getScalarSizeInBits()) {
21850 if (LegalOperations && !TLI.isOperationLegal(ISD::BITCAST, NVT))
21851 return SDValue();
21852
21853 return DAG.getBitcast(NVT, V.getOperand(1));
21854 }
21855 return DAG.getNode(
21856 ISD::EXTRACT_SUBVECTOR, SDLoc(N), NVT,
21857 DAG.getBitcast(N->getOperand(0).getValueType(), V.getOperand(0)),
21858 N->getOperand(1));
21859 }
21860
21861 if (SDValue NarrowBOp = narrowExtractedVectorBinOp(N, DAG, LegalOperations))
21862 return NarrowBOp;
21863
21864 if (SimplifyDemandedVectorElts(SDValue(N, 0)))
21865 return SDValue(N, 0);
21866
21867 return SDValue();
21868 }
21869
21870 /// Try to convert a wide shuffle of concatenated vectors into 2 narrow shuffles
21871 /// followed by concatenation. Narrow vector ops may have better performance
21872 /// than wide ops, and this can unlock further narrowing of other vector ops.
21873 /// Targets can invert this transform later if it is not profitable.
foldShuffleOfConcatUndefs(ShuffleVectorSDNode * Shuf,SelectionDAG & DAG)21874 static SDValue foldShuffleOfConcatUndefs(ShuffleVectorSDNode *Shuf,
21875 SelectionDAG &DAG) {
21876 SDValue N0 = Shuf->getOperand(0), N1 = Shuf->getOperand(1);
21877 if (N0.getOpcode() != ISD::CONCAT_VECTORS || N0.getNumOperands() != 2 ||
21878 N1.getOpcode() != ISD::CONCAT_VECTORS || N1.getNumOperands() != 2 ||
21879 !N0.getOperand(1).isUndef() || !N1.getOperand(1).isUndef())
21880 return SDValue();
21881
21882 // Split the wide shuffle mask into halves. Any mask element that is accessing
21883 // operand 1 is offset down to account for narrowing of the vectors.
21884 ArrayRef<int> Mask = Shuf->getMask();
21885 EVT VT = Shuf->getValueType(0);
21886 unsigned NumElts = VT.getVectorNumElements();
21887 unsigned HalfNumElts = NumElts / 2;
21888 SmallVector<int, 16> Mask0(HalfNumElts, -1);
21889 SmallVector<int, 16> Mask1(HalfNumElts, -1);
21890 for (unsigned i = 0; i != NumElts; ++i) {
21891 if (Mask[i] == -1)
21892 continue;
21893 // If we reference the upper (undef) subvector then the element is undef.
21894 if ((Mask[i] % NumElts) >= HalfNumElts)
21895 continue;
21896 int M = Mask[i] < (int)NumElts ? Mask[i] : Mask[i] - (int)HalfNumElts;
21897 if (i < HalfNumElts)
21898 Mask0[i] = M;
21899 else
21900 Mask1[i - HalfNumElts] = M;
21901 }
21902
21903 // Ask the target if this is a valid transform.
21904 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
21905 EVT HalfVT = EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(),
21906 HalfNumElts);
21907 if (!TLI.isShuffleMaskLegal(Mask0, HalfVT) ||
21908 !TLI.isShuffleMaskLegal(Mask1, HalfVT))
21909 return SDValue();
21910
21911 // shuffle (concat X, undef), (concat Y, undef), Mask -->
21912 // concat (shuffle X, Y, Mask0), (shuffle X, Y, Mask1)
21913 SDValue X = N0.getOperand(0), Y = N1.getOperand(0);
21914 SDLoc DL(Shuf);
21915 SDValue Shuf0 = DAG.getVectorShuffle(HalfVT, DL, X, Y, Mask0);
21916 SDValue Shuf1 = DAG.getVectorShuffle(HalfVT, DL, X, Y, Mask1);
21917 return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Shuf0, Shuf1);
21918 }
21919
// Tries to turn a shuffle of two CONCAT_VECTORS into a single concat,
// or turn a shuffle of a single concat into simpler shuffle then concat.
static SDValue partitionShuffleOfConcats(SDNode *N, SelectionDAG &DAG) {
  EVT VT = N->getValueType(0);
  unsigned NumElts = VT.getVectorNumElements();

  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N);
  ArrayRef<int> Mask = SVN->getMask();

  SmallVector<SDValue, 4> Ops;
  EVT ConcatVT = N0.getOperand(0).getValueType();
  unsigned NumElemsPerConcat = ConcatVT.getVectorNumElements();
  unsigned NumConcats = NumElts / NumElemsPerConcat;

  auto IsUndefMaskElt = [](int i) { return i == -1; };

  // Special case: shuffle(concat(A,B)) can be more efficiently represented
  // as concat(shuffle(A,B),UNDEF) if the shuffle doesn't set any of the high
  // half vector elements.
  if (NumElemsPerConcat * 2 == NumElts && N1.isUndef() &&
      llvm::all_of(Mask.slice(NumElemsPerConcat, NumElemsPerConcat),
                   IsUndefMaskElt)) {
    N0 = DAG.getVectorShuffle(ConcatVT, SDLoc(N), N0.getOperand(0),
                              N0.getOperand(1),
                              Mask.slice(0, NumElemsPerConcat));
    N1 = DAG.getUNDEF(ConcatVT);
    return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, N0, N1);
  }

  // Look at every vector that's inserted. We're looking for exact
  // subvector-sized copies from a concatenated vector
  for (unsigned I = 0; I != NumConcats; ++I) {
    unsigned Begin = I * NumElemsPerConcat;
    ArrayRef<int> SubMask = Mask.slice(Begin, NumElemsPerConcat);

    // An all-undef submask simply selects an undef concat operand.
    if (llvm::all_of(SubMask, IsUndefMaskElt)) {
      Ops.push_back(DAG.getUNDEF(ConcatVT));
      continue;
    }

    // Make sure we're dealing with a copy: every defined mask element must
    // select element i of the same NumElemsPerConcat-sized source chunk
    // (OpIdx), i.e. the submask is an identity within one concat operand.
    int OpIdx = -1;
    for (int i = 0; i != (int)NumElemsPerConcat; ++i) {
      if (IsUndefMaskElt(SubMask[i]))
        continue;
      if ((SubMask[i] % (int)NumElemsPerConcat) != i)
        return SDValue();
      int EltOpIdx = SubMask[i] / NumElemsPerConcat;
      if (0 <= OpIdx && EltOpIdx != OpIdx)
        return SDValue();
      OpIdx = EltOpIdx;
    }
    assert(0 <= OpIdx && "Unknown concat_vectors op");

    // Chunk indices beyond N0's operand count address N1's concat operands.
    if (OpIdx < (int)N0.getNumOperands())
      Ops.push_back(N0.getOperand(OpIdx));
    else
      Ops.push_back(N1.getOperand(OpIdx - N0.getNumOperands()));
  }

  return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
}
21984
// Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
// BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
//
// SHUFFLE(BUILD_VECTOR(), BUILD_VECTOR()) -> BUILD_VECTOR() is always
// a simplification in some sense, but it isn't appropriate in general: some
// BUILD_VECTORs are substantially cheaper than others. The general case
// of a BUILD_VECTOR requires inserting each element individually (or
// performing the equivalent in a temporary stack variable). A BUILD_VECTOR of
// all constants is a single constant pool load. A BUILD_VECTOR where each
// element is identical is a splat. A BUILD_VECTOR where most of the operands
// are undef lowers to a small number of element insertions.
//
// To deal with this, we currently use a bunch of mostly arbitrary heuristics.
// We don't fold shuffles where one side is a non-zero constant, and we don't
// fold shuffles if the resulting (non-splat) BUILD_VECTOR would have duplicate
// non-constant operands. This seems to work out reasonably well in practice.
static SDValue combineShuffleOfScalars(ShuffleVectorSDNode *SVN,
                                       SelectionDAG &DAG,
                                       const TargetLowering &TLI) {
  EVT VT = SVN->getValueType(0);
  unsigned NumElts = VT.getVectorNumElements();
  SDValue N0 = SVN->getOperand(0);
  SDValue N1 = SVN->getOperand(1);

  // Don't fold if the source vectors have other users; they would be kept
  // alive anyway and the fold would only add nodes.
  if (!N0->hasOneUse())
    return SDValue();

  // If only one of N0,N1 is constant, bail out if it is not ALL_ZEROS as
  // discussed above.
  if (!N1.isUndef()) {
    if (!N1->hasOneUse())
      return SDValue();

    bool N0AnyConst = isAnyConstantBuildVector(N0);
    bool N1AnyConst = isAnyConstantBuildVector(N1);
    if (N0AnyConst && !N1AnyConst && !ISD::isBuildVectorAllZeros(N0.getNode()))
      return SDValue();
    if (!N0AnyConst && N1AnyConst && !ISD::isBuildVectorAllZeros(N1.getNode()))
      return SDValue();
  }

  // If both inputs are splats of the same value then we can safely merge this
  // to a single BUILD_VECTOR with undef elements based on the shuffle mask.
  bool IsSplat = false;
  auto *BV0 = dyn_cast<BuildVectorSDNode>(N0);
  auto *BV1 = dyn_cast<BuildVectorSDNode>(N1);
  if (BV0 && BV1)
    if (SDValue Splat0 = BV0->getSplatValue())
      IsSplat = (Splat0 == BV1->getSplatValue());

  // Map each shuffle mask element to the scalar it selects from N0/N1.
  SmallVector<SDValue, 8> Ops;
  SmallSet<SDValue, 16> DuplicateOps;
  for (int M : SVN->getMask()) {
    SDValue Op = DAG.getUNDEF(VT.getScalarType());
    if (M >= 0) {
      int Idx = M < (int)NumElts ? M : M - NumElts;
      SDValue &S = (M < (int)NumElts ? N0 : N1);
      if (S.getOpcode() == ISD::BUILD_VECTOR) {
        Op = S.getOperand(Idx);
      } else if (S.getOpcode() == ISD::SCALAR_TO_VECTOR) {
        // SCALAR_TO_VECTOR only defines element 0; others are undef.
        SDValue Op0 = S.getOperand(0);
        Op = Idx == 0 ? Op0 : DAG.getUNDEF(Op0.getValueType());
      } else {
        // Operand can't be combined - bail out.
        return SDValue();
      }
    }

    // Don't duplicate a non-constant BUILD_VECTOR operand unless we're
    // generating a splat; semantically, this is fine, but it's likely to
    // generate low-quality code if the target can't reconstruct an appropriate
    // shuffle.
    if (!Op.isUndef() && !isIntOrFPConstant(Op))
      if (!IsSplat && !DuplicateOps.insert(Op).second)
        return SDValue();

    Ops.push_back(Op);
  }

  // BUILD_VECTOR requires all inputs to be of the same type, find the
  // maximum type and extend them all.
  EVT SVT = VT.getScalarType();
  if (SVT.isInteger())
    for (SDValue &Op : Ops)
      SVT = (SVT.bitsLT(Op.getValueType()) ? Op.getValueType() : SVT);
  if (SVT != VT.getScalarType())
    for (SDValue &Op : Ops)
      Op = Op.isUndef() ? DAG.getUNDEF(SVT)
                        : (TLI.isZExtFree(Op.getValueType(), SVT)
                               ? DAG.getZExtOrTrunc(Op, SDLoc(SVN), SVT)
                               : DAG.getSExtOrTrunc(Op, SDLoc(SVN), SVT));
  return DAG.getBuildVector(VT, SDLoc(SVN), Ops);
}
22078
22079 // Match shuffles that can be converted to any_vector_extend_in_reg.
22080 // This is often generated during legalization.
22081 // e.g. v4i32 <0,u,1,u> -> (v2i64 any_vector_extend_in_reg(v4i32 src))
22082 // TODO Add support for ZERO_EXTEND_VECTOR_INREG when we have a test case.
combineShuffleToVectorExtend(ShuffleVectorSDNode * SVN,SelectionDAG & DAG,const TargetLowering & TLI,bool LegalOperations)22083 static SDValue combineShuffleToVectorExtend(ShuffleVectorSDNode *SVN,
22084 SelectionDAG &DAG,
22085 const TargetLowering &TLI,
22086 bool LegalOperations) {
22087 EVT VT = SVN->getValueType(0);
22088 bool IsBigEndian = DAG.getDataLayout().isBigEndian();
22089
22090 // TODO Add support for big-endian when we have a test case.
22091 if (!VT.isInteger() || IsBigEndian)
22092 return SDValue();
22093
22094 unsigned NumElts = VT.getVectorNumElements();
22095 unsigned EltSizeInBits = VT.getScalarSizeInBits();
22096 ArrayRef<int> Mask = SVN->getMask();
22097 SDValue N0 = SVN->getOperand(0);
22098
22099 // shuffle<0,-1,1,-1> == (v2i64 anyextend_vector_inreg(v4i32))
22100 auto isAnyExtend = [&Mask, &NumElts](unsigned Scale) {
22101 for (unsigned i = 0; i != NumElts; ++i) {
22102 if (Mask[i] < 0)
22103 continue;
22104 if ((i % Scale) == 0 && Mask[i] == (int)(i / Scale))
22105 continue;
22106 return false;
22107 }
22108 return true;
22109 };
22110
22111 // Attempt to match a '*_extend_vector_inreg' shuffle, we just search for
22112 // power-of-2 extensions as they are the most likely.
22113 for (unsigned Scale = 2; Scale < NumElts; Scale *= 2) {
22114 // Check for non power of 2 vector sizes
22115 if (NumElts % Scale != 0)
22116 continue;
22117 if (!isAnyExtend(Scale))
22118 continue;
22119
22120 EVT OutSVT = EVT::getIntegerVT(*DAG.getContext(), EltSizeInBits * Scale);
22121 EVT OutVT = EVT::getVectorVT(*DAG.getContext(), OutSVT, NumElts / Scale);
22122 // Never create an illegal type. Only create unsupported operations if we
22123 // are pre-legalization.
22124 if (TLI.isTypeLegal(OutVT))
22125 if (!LegalOperations ||
22126 TLI.isOperationLegalOrCustom(ISD::ANY_EXTEND_VECTOR_INREG, OutVT))
22127 return DAG.getBitcast(VT,
22128 DAG.getNode(ISD::ANY_EXTEND_VECTOR_INREG,
22129 SDLoc(SVN), OutVT, N0));
22130 }
22131
22132 return SDValue();
22133 }
22134
// Detect 'truncate_vector_inreg' style shuffles that pack the lower parts of
// each source element of a large type into the lowest elements of a smaller
// destination type. This is often generated during legalization.
// If the source node itself was a '*_extend_vector_inreg' node then we should
// then be able to remove it.
static SDValue combineTruncationShuffle(ShuffleVectorSDNode *SVN,
                                        SelectionDAG &DAG) {
  EVT VT = SVN->getValueType(0);
  bool IsBigEndian = DAG.getDataLayout().isBigEndian();

  // TODO Add support for big-endian when we have a test case.
  if (!VT.isInteger() || IsBigEndian)
    return SDValue();

  SDValue N0 = peekThroughBitcasts(SVN->getOperand(0));

  // Only fold when the shuffle input is itself an in-register extension that
  // this truncation can undo.
  unsigned Opcode = N0.getOpcode();
  if (Opcode != ISD::ANY_EXTEND_VECTOR_INREG &&
      Opcode != ISD::SIGN_EXTEND_VECTOR_INREG &&
      Opcode != ISD::ZERO_EXTEND_VECTOR_INREG)
    return SDValue();

  SDValue N00 = N0.getOperand(0);
  ArrayRef<int> Mask = SVN->getMask();
  unsigned NumElts = VT.getVectorNumElements();
  unsigned EltSizeInBits = VT.getScalarSizeInBits();
  unsigned ExtSrcSizeInBits = N00.getScalarValueSizeInBits();
  unsigned ExtDstSizeInBits = N0.getScalarValueSizeInBits();

  if (ExtDstSizeInBits % ExtSrcSizeInBits != 0)
    return SDValue();
  unsigned ExtScale = ExtDstSizeInBits / ExtSrcSizeInBits;

  // (v4i32 truncate_vector_inreg(v2i64)) == shuffle<0,2,-1,-1>
  // (v8i16 truncate_vector_inreg(v4i32)) == shuffle<0,2,4,6,-1,-1,-1,-1>
  // (v8i16 truncate_vector_inreg(v2i64)) == shuffle<0,4,-1,-1,-1,-1,-1,-1>
  auto isTruncate = [&Mask, &NumElts](unsigned Scale) {
    for (unsigned i = 0; i != NumElts; ++i) {
      if (Mask[i] < 0)
        continue;
      if ((i * Scale) < NumElts && Mask[i] == (int)(i * Scale))
        continue;
      return false;
    }
    return true;
  };

  // At the moment we just handle the case where we've truncated back to the
  // same size as before the extension.
  // TODO: handle more extension/truncation cases as cases arise.
  if (EltSizeInBits != ExtSrcSizeInBits)
    return SDValue();

  // We can remove *extend_vector_inreg only if the truncation happens at
  // the same scale as the extension.
  if (isTruncate(ExtScale))
    return DAG.getBitcast(VT, N00);

  return SDValue();
}
22195
// Combine shuffles of splat-shuffles of the form:
// shuffle (shuffle V, undef, splat-mask), undef, M
// If splat-mask contains undef elements, we need to be careful about
// introducing undef's in the folded mask which are not the result of composing
// the masks of the shuffles.
static SDValue combineShuffleOfSplatVal(ShuffleVectorSDNode *Shuf,
                                        SelectionDAG &DAG) {
  // Only unary shuffles (second operand undef) are handled here.
  if (!Shuf->getOperand(1).isUndef())
    return SDValue();

  // If the inner operand is a known splat with no undefs, just return that directly.
  // TODO: Create DemandedElts mask from Shuf's mask.
  // TODO: Allow undef elements and merge with the shuffle code below.
  if (DAG.isSplatValue(Shuf->getOperand(0), /*AllowUndefs*/ false))
    return Shuf->getOperand(0);

  // Otherwise the inner operand must itself be a splat shuffle.
  auto *Splat = dyn_cast<ShuffleVectorSDNode>(Shuf->getOperand(0));
  if (!Splat || !Splat->isSplat())
    return SDValue();

  ArrayRef<int> ShufMask = Shuf->getMask();
  ArrayRef<int> SplatMask = Splat->getMask();
  assert(ShufMask.size() == SplatMask.size() && "Mask length mismatch");

  // Prefer simplifying to the splat-shuffle, if possible. This is legal if
  // every undef mask element in the splat-shuffle has a corresponding undef
  // element in the user-shuffle's mask or if the composition of mask elements
  // would result in undef.
  // Examples for (shuffle (shuffle v, undef, SplatMask), undef, UserMask):
  // * UserMask=[0,2,u,u], SplatMask=[2,u,2,u] -> [2,2,u,u]
  // In this case it is not legal to simplify to the splat-shuffle because we
  // may be exposing the users of the shuffle an undef element at index 1
  // which was not there before the combine.
  // * UserMask=[0,u,2,u], SplatMask=[2,u,2,u] -> [2,u,2,u]
  // In this case the composition of masks yields SplatMask, so it's ok to
  // simplify to the splat-shuffle.
  // * UserMask=[3,u,2,u], SplatMask=[2,u,2,u] -> [u,u,2,u]
  // In this case the composed mask includes all undef elements of SplatMask
  // and in addition sets element zero to undef. It is safe to simplify to
  // the splat-shuffle.
  auto CanSimplifyToExistingSplat = [](ArrayRef<int> UserMask,
                                       ArrayRef<int> SplatMask) {
    // Returning the splat-shuffle unchanged is only illegal when the user
    // mask demands (UserMask[i] != -1) a lane where the splat-shuffle is
    // undef (SplatMask[i] == -1) but the composed value would be defined
    // (SplatMask[UserMask[i]] != -1).
    for (unsigned i = 0, e = UserMask.size(); i != e; ++i)
      if (UserMask[i] != -1 && SplatMask[i] == -1 &&
          SplatMask[UserMask[i]] != -1)
        return false;
    return true;
  };
  if (CanSimplifyToExistingSplat(ShufMask, SplatMask))
    return Shuf->getOperand(0);

  // Create a new shuffle with a mask that is composed of the two shuffles'
  // masks.
  SmallVector<int, 32> NewMask;
  for (int Idx : ShufMask)
    NewMask.push_back(Idx == -1 ? -1 : SplatMask[Idx]);

  return DAG.getVectorShuffle(Splat->getValueType(0), SDLoc(Splat),
                              Splat->getOperand(0), Splat->getOperand(1),
                              NewMask);
}
22257
// Combine shuffles of bitcasts into a shuffle of the bitcast type, providing
// the mask can be treated as a larger type.
static SDValue combineShuffleOfBitcast(ShuffleVectorSDNode *SVN,
                                       SelectionDAG &DAG,
                                       const TargetLowering &TLI,
                                       bool LegalOperations) {
  SDValue Op0 = SVN->getOperand(0);
  SDValue Op1 = SVN->getOperand(1);
  EVT VT = SVN->getValueType(0);
  if (Op0.getOpcode() != ISD::BITCAST)
    return SDValue();
  EVT InVT = Op0.getOperand(0).getValueType();
  // Both inputs must be bitcasts from the same (vector) source type, or the
  // second input must be undef.
  if (!InVT.isVector() ||
      (!Op1.isUndef() && (Op1.getOpcode() != ISD::BITCAST ||
                          Op1.getOperand(0).getValueType() != InVT)))
    return SDValue();
  // Don't pull a bitcast through a shuffle of constants - other combines
  // handle constant build vectors better.
  if (isAnyConstantBuildVector(Op0.getOperand(0)) &&
      (Op1.isUndef() || isAnyConstantBuildVector(Op1.getOperand(0))))
    return SDValue();

  // The shuffle type must have more (and a whole multiple of) lanes than the
  // source type, so each source lane maps to a contiguous group of mask
  // elements.
  int VTLanes = VT.getVectorNumElements();
  int InLanes = InVT.getVectorNumElements();
  if (VTLanes <= InLanes || VTLanes % InLanes != 0 ||
      (LegalOperations &&
       !TLI.isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, InVT)))
    return SDValue();
  int Factor = VTLanes / InLanes;

  // Check that each group of lanes in the mask are either undef or make a valid
  // mask for the wider lane type.
  ArrayRef<int> Mask = SVN->getMask();
  SmallVector<int> NewMask;
  if (!widenShuffleMaskElts(Factor, Mask, NewMask))
    return SDValue();

  if (!TLI.isShuffleMaskLegal(NewMask, InVT))
    return SDValue();

  // Create the new shuffle with the new mask and bitcast it back to the
  // original type.
  SDLoc DL(SVN);
  Op0 = Op0.getOperand(0);
  Op1 = Op1.isUndef() ? DAG.getUNDEF(InVT) : Op1.getOperand(0);
  SDValue NewShuf = DAG.getVectorShuffle(InVT, DL, Op0, Op1, NewMask);
  return DAG.getBitcast(VT, NewShuf);
}
22304
/// Combine shuffle of shuffle of the form:
/// shuf (shuf X, undef, InnerMask), undef, OuterMask --> splat X
static SDValue formSplatFromShuffles(ShuffleVectorSDNode *OuterShuf,
                                     SelectionDAG &DAG) {
  // Both shuffles must be unary (second operand undef).
  if (!OuterShuf->getOperand(1).isUndef())
    return SDValue();
  auto *InnerShuf = dyn_cast<ShuffleVectorSDNode>(OuterShuf->getOperand(0));
  if (!InnerShuf || !InnerShuf->getOperand(1).isUndef())
    return SDValue();

  ArrayRef<int> OuterMask = OuterShuf->getMask();
  ArrayRef<int> InnerMask = InnerShuf->getMask();
  unsigned NumElts = OuterMask.size();
  assert(NumElts == InnerMask.size() && "Mask length mismatch");
  // Compose the two masks; succeed only if every defined lane selects the
  // same source element (a splat).
  SmallVector<int, 32> CombinedMask(NumElts, -1);
  int SplatIndex = -1;
  for (unsigned i = 0; i != NumElts; ++i) {
    // Undef lanes remain undef.
    int OuterMaskElt = OuterMask[i];
    if (OuterMaskElt == -1)
      continue;

    // Peek through the shuffle masks to get the underlying source element.
    int InnerMaskElt = InnerMask[OuterMaskElt];
    if (InnerMaskElt == -1)
      continue;

    // Initialize the splatted element.
    if (SplatIndex == -1)
      SplatIndex = InnerMaskElt;

    // Non-matching index - this is not a splat.
    if (SplatIndex != InnerMaskElt)
      return SDValue();

    CombinedMask[i] = InnerMaskElt;
  }
  assert((all_of(CombinedMask, [](int M) { return M == -1; }) ||
          getSplatIndex(CombinedMask) != -1) &&
         "Expected a splat mask");

  // TODO: The transform may be a win even if the mask is not legal.
  EVT VT = OuterShuf->getValueType(0);
  assert(VT == InnerShuf->getValueType(0) && "Expected matching shuffle types");
  if (!DAG.getTargetLoweringInfo().isShuffleMaskLegal(CombinedMask, VT))
    return SDValue();

  // Replace the pair of shuffles with a single splat shuffle of the inner
  // shuffle's source.
  return DAG.getVectorShuffle(VT, SDLoc(OuterShuf), InnerShuf->getOperand(0),
                              InnerShuf->getOperand(1), CombinedMask);
}
22355
22356 /// If the shuffle mask is taking exactly one element from the first vector
22357 /// operand and passing through all other elements from the second vector
22358 /// operand, return the index of the mask element that is choosing an element
22359 /// from the first operand. Otherwise, return -1.
getShuffleMaskIndexOfOneElementFromOp0IntoOp1(ArrayRef<int> Mask)22360 static int getShuffleMaskIndexOfOneElementFromOp0IntoOp1(ArrayRef<int> Mask) {
22361 int MaskSize = Mask.size();
22362 int EltFromOp0 = -1;
22363 // TODO: This does not match if there are undef elements in the shuffle mask.
22364 // Should we ignore undefs in the shuffle mask instead? The trade-off is
22365 // removing an instruction (a shuffle), but losing the knowledge that some
22366 // vector lanes are not needed.
22367 for (int i = 0; i != MaskSize; ++i) {
22368 if (Mask[i] >= 0 && Mask[i] < MaskSize) {
22369 // We're looking for a shuffle of exactly one element from operand 0.
22370 if (EltFromOp0 != -1)
22371 return -1;
22372 EltFromOp0 = i;
22373 } else if (Mask[i] != i + MaskSize) {
22374 // Nothing from operand 1 can change lanes.
22375 return -1;
22376 }
22377 }
22378 return EltFromOp0;
22379 }
22380
/// If a shuffle inserts exactly one element from a source vector operand into
/// another vector operand and we can access the specified element as a scalar,
/// then we can eliminate the shuffle.
static SDValue replaceShuffleOfInsert(ShuffleVectorSDNode *Shuf,
                                      SelectionDAG &DAG) {
  // First, check if we are taking one element of a vector and shuffling that
  // element into another vector.
  ArrayRef<int> Mask = Shuf->getMask();
  SmallVector<int, 16> CommutedMask(Mask.begin(), Mask.end());
  SDValue Op0 = Shuf->getOperand(0);
  SDValue Op1 = Shuf->getOperand(1);
  int ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(Mask);
  if (ShufOp0Index == -1) {
    // Commute mask and check again.
    ShuffleVectorSDNode::commuteMask(CommutedMask);
    ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(CommutedMask);
    if (ShufOp0Index == -1)
      return SDValue();
    // Commute operands to match the commuted shuffle mask.
    std::swap(Op0, Op1);
    Mask = CommutedMask;
  }

  // The shuffle inserts exactly one element from operand 0 into operand 1.
  // Now see if we can access that element as a scalar via a real insert element
  // instruction.
  // TODO: We can try harder to locate the element as a scalar. Examples: it
  // could be an operand of SCALAR_TO_VECTOR, BUILD_VECTOR, or a constant.
  assert(Mask[ShufOp0Index] >= 0 && Mask[ShufOp0Index] < (int)Mask.size() &&
         "Shuffle mask value must be from operand 0");
  if (Op0.getOpcode() != ISD::INSERT_VECTOR_ELT)
    return SDValue();

  // The insertion index must be a constant that matches the element the
  // shuffle selects from operand 0, so the inserted scalar is the element
  // being moved.
  auto *InsIndexC = dyn_cast<ConstantSDNode>(Op0.getOperand(2));
  if (!InsIndexC || InsIndexC->getSExtValue() != Mask[ShufOp0Index])
    return SDValue();

  // There's an existing insertelement with constant insertion index, so we
  // don't need to check the legality/profitability of a replacement operation
  // that differs at most in the constant value. The target should be able to
  // lower any of those in a similar way. If not, legalization will expand this
  // to a scalar-to-vector plus shuffle.
  //
  // Note that the shuffle may move the scalar from the position that the insert
  // element used. Therefore, our new insert element occurs at the shuffle's
  // mask index value, not the insert's index value.
  // shuffle (insertelt v1, x, C), v2, mask --> insertelt v2, x, C'
  SDValue NewInsIndex = DAG.getVectorIdxConstant(ShufOp0Index, SDLoc(Shuf));
  return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(Shuf), Op0.getValueType(),
                     Op1, Op0.getOperand(1), NewInsIndex);
}
22432
22433 /// If we have a unary shuffle of a shuffle, see if it can be folded away
22434 /// completely. This has the potential to lose undef knowledge because the first
22435 /// shuffle may not have an undef mask element where the second one does. So
22436 /// only call this after doing simplifications based on demanded elements.
simplifyShuffleOfShuffle(ShuffleVectorSDNode * Shuf)22437 static SDValue simplifyShuffleOfShuffle(ShuffleVectorSDNode *Shuf) {
22438 // shuf (shuf0 X, Y, Mask0), undef, Mask
22439 auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(Shuf->getOperand(0));
22440 if (!Shuf0 || !Shuf->getOperand(1).isUndef())
22441 return SDValue();
22442
22443 ArrayRef<int> Mask = Shuf->getMask();
22444 ArrayRef<int> Mask0 = Shuf0->getMask();
22445 for (int i = 0, e = (int)Mask.size(); i != e; ++i) {
22446 // Ignore undef elements.
22447 if (Mask[i] == -1)
22448 continue;
22449 assert(Mask[i] >= 0 && Mask[i] < e && "Unexpected shuffle mask value");
22450
22451 // Is the element of the shuffle operand chosen by this shuffle the same as
22452 // the element chosen by the shuffle operand itself?
22453 if (Mask0[Mask[i]] != Mask0[i])
22454 return SDValue();
22455 }
22456 // Every element of this shuffle is identical to the result of the previous
22457 // shuffle, so we can replace this value.
22458 return Shuf->getOperand(0);
22459 }
22460
visitVECTOR_SHUFFLE(SDNode * N)22461 SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
22462 EVT VT = N->getValueType(0);
22463 unsigned NumElts = VT.getVectorNumElements();
22464
22465 SDValue N0 = N->getOperand(0);
22466 SDValue N1 = N->getOperand(1);
22467
22468 assert(N0.getValueType() == VT && "Vector shuffle must be normalized in DAG");
22469
22470 // Canonicalize shuffle undef, undef -> undef
22471 if (N0.isUndef() && N1.isUndef())
22472 return DAG.getUNDEF(VT);
22473
22474 ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N);
22475
22476 // Canonicalize shuffle v, v -> v, undef
22477 if (N0 == N1)
22478 return DAG.getVectorShuffle(VT, SDLoc(N), N0, DAG.getUNDEF(VT),
22479 createUnaryMask(SVN->getMask(), NumElts));
22480
22481 // Canonicalize shuffle undef, v -> v, undef. Commute the shuffle mask.
22482 if (N0.isUndef())
22483 return DAG.getCommutedVectorShuffle(*SVN);
22484
22485 // Remove references to rhs if it is undef
22486 if (N1.isUndef()) {
22487 bool Changed = false;
22488 SmallVector<int, 8> NewMask;
22489 for (unsigned i = 0; i != NumElts; ++i) {
22490 int Idx = SVN->getMaskElt(i);
22491 if (Idx >= (int)NumElts) {
22492 Idx = -1;
22493 Changed = true;
22494 }
22495 NewMask.push_back(Idx);
22496 }
22497 if (Changed)
22498 return DAG.getVectorShuffle(VT, SDLoc(N), N0, N1, NewMask);
22499 }
22500
22501 if (SDValue InsElt = replaceShuffleOfInsert(SVN, DAG))
22502 return InsElt;
22503
22504 // A shuffle of a single vector that is a splatted value can always be folded.
22505 if (SDValue V = combineShuffleOfSplatVal(SVN, DAG))
22506 return V;
22507
22508 if (SDValue V = formSplatFromShuffles(SVN, DAG))
22509 return V;
22510
22511 // If it is a splat, check if the argument vector is another splat or a
22512 // build_vector.
22513 if (SVN->isSplat() && SVN->getSplatIndex() < (int)NumElts) {
22514 int SplatIndex = SVN->getSplatIndex();
22515 if (N0.hasOneUse() && TLI.isExtractVecEltCheap(VT, SplatIndex) &&
22516 TLI.isBinOp(N0.getOpcode()) && N0->getNumValues() == 1) {
22517 // splat (vector_bo L, R), Index -->
22518 // splat (scalar_bo (extelt L, Index), (extelt R, Index))
22519 SDValue L = N0.getOperand(0), R = N0.getOperand(1);
22520 SDLoc DL(N);
22521 EVT EltVT = VT.getScalarType();
22522 SDValue Index = DAG.getVectorIdxConstant(SplatIndex, DL);
22523 SDValue ExtL = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, L, Index);
22524 SDValue ExtR = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, R, Index);
22525 SDValue NewBO =
22526 DAG.getNode(N0.getOpcode(), DL, EltVT, ExtL, ExtR, N0->getFlags());
22527 SDValue Insert = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT, NewBO);
22528 SmallVector<int, 16> ZeroMask(VT.getVectorNumElements(), 0);
22529 return DAG.getVectorShuffle(VT, DL, Insert, DAG.getUNDEF(VT), ZeroMask);
22530 }
22531
22532 // splat(scalar_to_vector(x), 0) -> build_vector(x,...,x)
22533 // splat(insert_vector_elt(v, x, c), c) -> build_vector(x,...,x)
22534 if ((!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)) &&
22535 N0.hasOneUse()) {
22536 if (N0.getOpcode() == ISD::SCALAR_TO_VECTOR && SplatIndex == 0)
22537 return DAG.getSplatBuildVector(VT, SDLoc(N), N0.getOperand(0));
22538
22539 if (N0.getOpcode() == ISD::INSERT_VECTOR_ELT)
22540 if (auto *Idx = dyn_cast<ConstantSDNode>(N0.getOperand(2)))
22541 if (Idx->getAPIntValue() == SplatIndex)
22542 return DAG.getSplatBuildVector(VT, SDLoc(N), N0.getOperand(1));
22543 }
22544
22545 // If this is a bit convert that changes the element type of the vector but
22546 // not the number of vector elements, look through it. Be careful not to
22547 // look though conversions that change things like v4f32 to v2f64.
22548 SDNode *V = N0.getNode();
22549 if (V->getOpcode() == ISD::BITCAST) {
22550 SDValue ConvInput = V->getOperand(0);
22551 if (ConvInput.getValueType().isVector() &&
22552 ConvInput.getValueType().getVectorNumElements() == NumElts)
22553 V = ConvInput.getNode();
22554 }
22555
22556 if (V->getOpcode() == ISD::BUILD_VECTOR) {
22557 assert(V->getNumOperands() == NumElts &&
22558 "BUILD_VECTOR has wrong number of operands");
22559 SDValue Base;
22560 bool AllSame = true;
22561 for (unsigned i = 0; i != NumElts; ++i) {
22562 if (!V->getOperand(i).isUndef()) {
22563 Base = V->getOperand(i);
22564 break;
22565 }
22566 }
22567 // Splat of <u, u, u, u>, return <u, u, u, u>
22568 if (!Base.getNode())
22569 return N0;
22570 for (unsigned i = 0; i != NumElts; ++i) {
22571 if (V->getOperand(i) != Base) {
22572 AllSame = false;
22573 break;
22574 }
22575 }
22576 // Splat of <x, x, x, x>, return <x, x, x, x>
22577 if (AllSame)
22578 return N0;
22579
22580 // Canonicalize any other splat as a build_vector.
22581 SDValue Splatted = V->getOperand(SplatIndex);
22582 SmallVector<SDValue, 8> Ops(NumElts, Splatted);
22583 SDValue NewBV = DAG.getBuildVector(V->getValueType(0), SDLoc(N), Ops);
22584
22585 // We may have jumped through bitcasts, so the type of the
22586 // BUILD_VECTOR may not match the type of the shuffle.
22587 if (V->getValueType(0) != VT)
22588 NewBV = DAG.getBitcast(VT, NewBV);
22589 return NewBV;
22590 }
22591 }
22592
22593 // Simplify source operands based on shuffle mask.
22594 if (SimplifyDemandedVectorElts(SDValue(N, 0)))
22595 return SDValue(N, 0);
22596
22597 // This is intentionally placed after demanded elements simplification because
22598 // it could eliminate knowledge of undef elements created by this shuffle.
22599 if (SDValue ShufOp = simplifyShuffleOfShuffle(SVN))
22600 return ShufOp;
22601
22602 // Match shuffles that can be converted to any_vector_extend_in_reg.
22603 if (SDValue V = combineShuffleToVectorExtend(SVN, DAG, TLI, LegalOperations))
22604 return V;
22605
22606 // Combine "truncate_vector_in_reg" style shuffles.
22607 if (SDValue V = combineTruncationShuffle(SVN, DAG))
22608 return V;
22609
22610 if (N0.getOpcode() == ISD::CONCAT_VECTORS &&
22611 Level < AfterLegalizeVectorOps &&
22612 (N1.isUndef() ||
22613 (N1.getOpcode() == ISD::CONCAT_VECTORS &&
22614 N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType()))) {
22615 if (SDValue V = partitionShuffleOfConcats(N, DAG))
22616 return V;
22617 }
22618
22619 // A shuffle of a concat of the same narrow vector can be reduced to use
22620 // only low-half elements of a concat with undef:
22621 // shuf (concat X, X), undef, Mask --> shuf (concat X, undef), undef, Mask'
22622 if (N0.getOpcode() == ISD::CONCAT_VECTORS && N1.isUndef() &&
22623 N0.getNumOperands() == 2 &&
22624 N0.getOperand(0) == N0.getOperand(1)) {
22625 int HalfNumElts = (int)NumElts / 2;
22626 SmallVector<int, 8> NewMask;
22627 for (unsigned i = 0; i != NumElts; ++i) {
22628 int Idx = SVN->getMaskElt(i);
22629 if (Idx >= HalfNumElts) {
22630 assert(Idx < (int)NumElts && "Shuffle mask chooses undef op");
22631 Idx -= HalfNumElts;
22632 }
22633 NewMask.push_back(Idx);
22634 }
22635 if (TLI.isShuffleMaskLegal(NewMask, VT)) {
22636 SDValue UndefVec = DAG.getUNDEF(N0.getOperand(0).getValueType());
22637 SDValue NewCat = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT,
22638 N0.getOperand(0), UndefVec);
22639 return DAG.getVectorShuffle(VT, SDLoc(N), NewCat, N1, NewMask);
22640 }
22641 }
22642
22643 // See if we can replace a shuffle with an insert_subvector.
22644 // e.g. v2i32 into v8i32:
22645 // shuffle(lhs,concat(rhs0,rhs1,rhs2,rhs3),0,1,2,3,10,11,6,7).
22646 // --> insert_subvector(lhs,rhs1,4).
22647 if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT) &&
22648 TLI.isOperationLegalOrCustom(ISD::INSERT_SUBVECTOR, VT)) {
22649 auto ShuffleToInsert = [&](SDValue LHS, SDValue RHS, ArrayRef<int> Mask) {
22650 // Ensure RHS subvectors are legal.
22651 assert(RHS.getOpcode() == ISD::CONCAT_VECTORS && "Can't find subvectors");
22652 EVT SubVT = RHS.getOperand(0).getValueType();
22653 int NumSubVecs = RHS.getNumOperands();
22654 int NumSubElts = SubVT.getVectorNumElements();
22655 assert((NumElts % NumSubElts) == 0 && "Subvector mismatch");
22656 if (!TLI.isTypeLegal(SubVT))
22657 return SDValue();
22658
22659 // Don't bother if we have an unary shuffle (matches undef + LHS elts).
22660 if (all_of(Mask, [NumElts](int M) { return M < (int)NumElts; }))
22661 return SDValue();
22662
22663 // Search [NumSubElts] spans for RHS sequence.
22664 // TODO: Can we avoid nested loops to increase performance?
22665 SmallVector<int> InsertionMask(NumElts);
22666 for (int SubVec = 0; SubVec != NumSubVecs; ++SubVec) {
22667 for (int SubIdx = 0; SubIdx != (int)NumElts; SubIdx += NumSubElts) {
22668 // Reset mask to identity.
22669 std::iota(InsertionMask.begin(), InsertionMask.end(), 0);
22670
22671 // Add subvector insertion.
22672 std::iota(InsertionMask.begin() + SubIdx,
22673 InsertionMask.begin() + SubIdx + NumSubElts,
22674 NumElts + (SubVec * NumSubElts));
22675
22676 // See if the shuffle mask matches the reference insertion mask.
22677 bool MatchingShuffle = true;
22678 for (int i = 0; i != (int)NumElts; ++i) {
22679 int ExpectIdx = InsertionMask[i];
22680 int ActualIdx = Mask[i];
22681 if (0 <= ActualIdx && ExpectIdx != ActualIdx) {
22682 MatchingShuffle = false;
22683 break;
22684 }
22685 }
22686
22687 if (MatchingShuffle)
22688 return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, LHS,
22689 RHS.getOperand(SubVec),
22690 DAG.getVectorIdxConstant(SubIdx, SDLoc(N)));
22691 }
22692 }
22693 return SDValue();
22694 };
22695 ArrayRef<int> Mask = SVN->getMask();
22696 if (N1.getOpcode() == ISD::CONCAT_VECTORS)
22697 if (SDValue InsertN1 = ShuffleToInsert(N0, N1, Mask))
22698 return InsertN1;
22699 if (N0.getOpcode() == ISD::CONCAT_VECTORS) {
22700 SmallVector<int> CommuteMask(Mask.begin(), Mask.end());
22701 ShuffleVectorSDNode::commuteMask(CommuteMask);
22702 if (SDValue InsertN0 = ShuffleToInsert(N1, N0, CommuteMask))
22703 return InsertN0;
22704 }
22705 }
22706
22707 // If we're not performing a select/blend shuffle, see if we can convert the
22708 // shuffle into a AND node, with all the out-of-lane elements are known zero.
22709 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
22710 bool IsInLaneMask = true;
22711 ArrayRef<int> Mask = SVN->getMask();
22712 SmallVector<int, 16> ClearMask(NumElts, -1);
22713 APInt DemandedLHS = APInt::getNullValue(NumElts);
22714 APInt DemandedRHS = APInt::getNullValue(NumElts);
22715 for (int I = 0; I != (int)NumElts; ++I) {
22716 int M = Mask[I];
22717 if (M < 0)
22718 continue;
22719 ClearMask[I] = M == I ? I : (I + NumElts);
22720 IsInLaneMask &= (M == I) || (M == (int)(I + NumElts));
22721 if (M != I) {
22722 APInt &Demanded = M < (int)NumElts ? DemandedLHS : DemandedRHS;
22723 Demanded.setBit(M % NumElts);
22724 }
22725 }
22726 // TODO: Should we try to mask with N1 as well?
22727 if (!IsInLaneMask &&
22728 (!DemandedLHS.isNullValue() || !DemandedRHS.isNullValue()) &&
22729 (DemandedLHS.isNullValue() ||
22730 DAG.MaskedVectorIsZero(N0, DemandedLHS)) &&
22731 (DemandedRHS.isNullValue() ||
22732 DAG.MaskedVectorIsZero(N1, DemandedRHS))) {
22733 SDLoc DL(N);
22734 EVT IntVT = VT.changeVectorElementTypeToInteger();
22735 EVT IntSVT = VT.getVectorElementType().changeTypeToInteger();
22736 // Transform the type to a legal type so that the buildvector constant
22737 // elements are not illegal. Make sure that the result is larger than the
22738 // original type, incase the value is split into two (eg i64->i32).
22739 if (!TLI.isTypeLegal(IntSVT) && LegalTypes)
22740 IntSVT = TLI.getTypeToTransformTo(*DAG.getContext(), IntSVT);
22741 if (IntSVT.getSizeInBits() >= IntVT.getScalarSizeInBits()) {
22742 SDValue ZeroElt = DAG.getConstant(0, DL, IntSVT);
22743 SDValue AllOnesElt = DAG.getAllOnesConstant(DL, IntSVT);
22744 SmallVector<SDValue, 16> AndMask(NumElts, DAG.getUNDEF(IntSVT));
22745 for (int I = 0; I != (int)NumElts; ++I)
22746 if (0 <= Mask[I])
22747 AndMask[I] = Mask[I] == I ? AllOnesElt : ZeroElt;
22748
22749 // See if a clear mask is legal instead of going via
22750 // XformToShuffleWithZero which loses UNDEF mask elements.
22751 if (TLI.isVectorClearMaskLegal(ClearMask, IntVT))
22752 return DAG.getBitcast(
22753 VT, DAG.getVectorShuffle(IntVT, DL, DAG.getBitcast(IntVT, N0),
22754 DAG.getConstant(0, DL, IntVT), ClearMask));
22755
22756 if (TLI.isOperationLegalOrCustom(ISD::AND, IntVT))
22757 return DAG.getBitcast(
22758 VT, DAG.getNode(ISD::AND, DL, IntVT, DAG.getBitcast(IntVT, N0),
22759 DAG.getBuildVector(IntVT, DL, AndMask)));
22760 }
22761 }
22762 }
22763
22764 // Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
22765 // BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
22766 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT))
22767 if (SDValue Res = combineShuffleOfScalars(SVN, DAG, TLI))
22768 return Res;
22769
22770 // If this shuffle only has a single input that is a bitcasted shuffle,
22771 // attempt to merge the 2 shuffles and suitably bitcast the inputs/output
22772 // back to their original types.
22773 if (N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
22774 N1.isUndef() && Level < AfterLegalizeVectorOps &&
22775 TLI.isTypeLegal(VT)) {
22776
22777 SDValue BC0 = peekThroughOneUseBitcasts(N0);
22778 if (BC0.getOpcode() == ISD::VECTOR_SHUFFLE && BC0.hasOneUse()) {
22779 EVT SVT = VT.getScalarType();
22780 EVT InnerVT = BC0->getValueType(0);
22781 EVT InnerSVT = InnerVT.getScalarType();
22782
22783 // Determine which shuffle works with the smaller scalar type.
22784 EVT ScaleVT = SVT.bitsLT(InnerSVT) ? VT : InnerVT;
22785 EVT ScaleSVT = ScaleVT.getScalarType();
22786
22787 if (TLI.isTypeLegal(ScaleVT) &&
22788 0 == (InnerSVT.getSizeInBits() % ScaleSVT.getSizeInBits()) &&
22789 0 == (SVT.getSizeInBits() % ScaleSVT.getSizeInBits())) {
22790 int InnerScale = InnerSVT.getSizeInBits() / ScaleSVT.getSizeInBits();
22791 int OuterScale = SVT.getSizeInBits() / ScaleSVT.getSizeInBits();
22792
22793 // Scale the shuffle masks to the smaller scalar type.
22794 ShuffleVectorSDNode *InnerSVN = cast<ShuffleVectorSDNode>(BC0);
22795 SmallVector<int, 8> InnerMask;
22796 SmallVector<int, 8> OuterMask;
22797 narrowShuffleMaskElts(InnerScale, InnerSVN->getMask(), InnerMask);
22798 narrowShuffleMaskElts(OuterScale, SVN->getMask(), OuterMask);
22799
22800 // Merge the shuffle masks.
22801 SmallVector<int, 8> NewMask;
22802 for (int M : OuterMask)
22803 NewMask.push_back(M < 0 ? -1 : InnerMask[M]);
22804
22805 // Test for shuffle mask legality over both commutations.
22806 SDValue SV0 = BC0->getOperand(0);
22807 SDValue SV1 = BC0->getOperand(1);
22808 bool LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
22809 if (!LegalMask) {
22810 std::swap(SV0, SV1);
22811 ShuffleVectorSDNode::commuteMask(NewMask);
22812 LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
22813 }
22814
22815 if (LegalMask) {
22816 SV0 = DAG.getBitcast(ScaleVT, SV0);
22817 SV1 = DAG.getBitcast(ScaleVT, SV1);
22818 return DAG.getBitcast(
22819 VT, DAG.getVectorShuffle(ScaleVT, SDLoc(N), SV0, SV1, NewMask));
22820 }
22821 }
22822 }
22823 }
22824
22825 // Match shuffles of bitcasts, so long as the mask can be treated as the
22826 // larger type.
22827 if (SDValue V = combineShuffleOfBitcast(SVN, DAG, TLI, LegalOperations))
22828 return V;
22829
22830 // Compute the combined shuffle mask for a shuffle with SV0 as the first
22831 // operand, and SV1 as the second operand.
22832 // i.e. Merge SVN(OtherSVN, N1) -> shuffle(SV0, SV1, Mask) iff Commute = false
22833 // Merge SVN(N1, OtherSVN) -> shuffle(SV0, SV1, Mask') iff Commute = true
22834 auto MergeInnerShuffle =
22835 [NumElts, &VT](bool Commute, ShuffleVectorSDNode *SVN,
22836 ShuffleVectorSDNode *OtherSVN, SDValue N1,
22837 const TargetLowering &TLI, SDValue &SV0, SDValue &SV1,
22838 SmallVectorImpl<int> &Mask) -> bool {
22839 // Don't try to fold splats; they're likely to simplify somehow, or they
22840 // might be free.
22841 if (OtherSVN->isSplat())
22842 return false;
22843
22844 SV0 = SV1 = SDValue();
22845 Mask.clear();
22846
22847 for (unsigned i = 0; i != NumElts; ++i) {
22848 int Idx = SVN->getMaskElt(i);
22849 if (Idx < 0) {
22850 // Propagate Undef.
22851 Mask.push_back(Idx);
22852 continue;
22853 }
22854
22855 if (Commute)
22856 Idx = (Idx < (int)NumElts) ? (Idx + NumElts) : (Idx - NumElts);
22857
22858 SDValue CurrentVec;
22859 if (Idx < (int)NumElts) {
22860 // This shuffle index refers to the inner shuffle N0. Lookup the inner
22861 // shuffle mask to identify which vector is actually referenced.
22862 Idx = OtherSVN->getMaskElt(Idx);
22863 if (Idx < 0) {
22864 // Propagate Undef.
22865 Mask.push_back(Idx);
22866 continue;
22867 }
22868 CurrentVec = (Idx < (int)NumElts) ? OtherSVN->getOperand(0)
22869 : OtherSVN->getOperand(1);
22870 } else {
22871 // This shuffle index references an element within N1.
22872 CurrentVec = N1;
22873 }
22874
22875 // Simple case where 'CurrentVec' is UNDEF.
22876 if (CurrentVec.isUndef()) {
22877 Mask.push_back(-1);
22878 continue;
22879 }
22880
22881 // Canonicalize the shuffle index. We don't know yet if CurrentVec
22882 // will be the first or second operand of the combined shuffle.
22883 Idx = Idx % NumElts;
22884 if (!SV0.getNode() || SV0 == CurrentVec) {
22885 // Ok. CurrentVec is the left hand side.
22886 // Update the mask accordingly.
22887 SV0 = CurrentVec;
22888 Mask.push_back(Idx);
22889 continue;
22890 }
22891 if (!SV1.getNode() || SV1 == CurrentVec) {
22892 // Ok. CurrentVec is the right hand side.
22893 // Update the mask accordingly.
22894 SV1 = CurrentVec;
22895 Mask.push_back(Idx + NumElts);
22896 continue;
22897 }
22898
22899 // Last chance - see if the vector is another shuffle and if it
22900 // uses one of the existing candidate shuffle ops.
22901 if (auto *CurrentSVN = dyn_cast<ShuffleVectorSDNode>(CurrentVec)) {
22902 int InnerIdx = CurrentSVN->getMaskElt(Idx);
22903 if (InnerIdx < 0) {
22904 Mask.push_back(-1);
22905 continue;
22906 }
22907 SDValue InnerVec = (InnerIdx < (int)NumElts)
22908 ? CurrentSVN->getOperand(0)
22909 : CurrentSVN->getOperand(1);
22910 if (InnerVec.isUndef()) {
22911 Mask.push_back(-1);
22912 continue;
22913 }
22914 InnerIdx %= NumElts;
22915 if (InnerVec == SV0) {
22916 Mask.push_back(InnerIdx);
22917 continue;
22918 }
22919 if (InnerVec == SV1) {
22920 Mask.push_back(InnerIdx + NumElts);
22921 continue;
22922 }
22923 }
22924
22925 // Bail out if we cannot convert the shuffle pair into a single shuffle.
22926 return false;
22927 }
22928
22929 if (llvm::all_of(Mask, [](int M) { return M < 0; }))
22930 return true;
22931
22932 // Avoid introducing shuffles with illegal mask.
22933 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
22934 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
22935 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
22936 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, A, M2)
22937 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, A, M2)
22938 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, B, M2)
22939 if (TLI.isShuffleMaskLegal(Mask, VT))
22940 return true;
22941
22942 std::swap(SV0, SV1);
22943 ShuffleVectorSDNode::commuteMask(Mask);
22944 return TLI.isShuffleMaskLegal(Mask, VT);
22945 };
22946
22947 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
22948 // Canonicalize shuffles according to rules:
22949 // shuffle(A, shuffle(A, B)) -> shuffle(shuffle(A,B), A)
22950 // shuffle(B, shuffle(A, B)) -> shuffle(shuffle(A,B), B)
22951 // shuffle(B, shuffle(A, Undef)) -> shuffle(shuffle(A, Undef), B)
22952 if (N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
22953 N0.getOpcode() != ISD::VECTOR_SHUFFLE) {
22954 // The incoming shuffle must be of the same type as the result of the
22955 // current shuffle.
22956 assert(N1->getOperand(0).getValueType() == VT &&
22957 "Shuffle types don't match");
22958
22959 SDValue SV0 = N1->getOperand(0);
22960 SDValue SV1 = N1->getOperand(1);
22961 bool HasSameOp0 = N0 == SV0;
22962 bool IsSV1Undef = SV1.isUndef();
22963 if (HasSameOp0 || IsSV1Undef || N0 == SV1)
22964 // Commute the operands of this shuffle so merging below will trigger.
22965 return DAG.getCommutedVectorShuffle(*SVN);
22966 }
22967
22968 // Canonicalize splat shuffles to the RHS to improve merging below.
22969 // shuffle(splat(A,u), shuffle(C,D)) -> shuffle'(shuffle(C,D), splat(A,u))
22970 if (N0.getOpcode() == ISD::VECTOR_SHUFFLE &&
22971 N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
22972 cast<ShuffleVectorSDNode>(N0)->isSplat() &&
22973 !cast<ShuffleVectorSDNode>(N1)->isSplat()) {
22974 return DAG.getCommutedVectorShuffle(*SVN);
22975 }
22976
22977 // Try to fold according to rules:
22978 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
22979 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
22980 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
22981 // Don't try to fold shuffles with illegal type.
22982 // Only fold if this shuffle is the only user of the other shuffle.
22983 // Try matching shuffle(C,shuffle(A,B)) commutted patterns as well.
22984 for (int i = 0; i != 2; ++i) {
22985 if (N->getOperand(i).getOpcode() == ISD::VECTOR_SHUFFLE &&
22986 N->isOnlyUserOf(N->getOperand(i).getNode())) {
22987 // The incoming shuffle must be of the same type as the result of the
22988 // current shuffle.
22989 auto *OtherSV = cast<ShuffleVectorSDNode>(N->getOperand(i));
22990 assert(OtherSV->getOperand(0).getValueType() == VT &&
22991 "Shuffle types don't match");
22992
22993 SDValue SV0, SV1;
22994 SmallVector<int, 4> Mask;
22995 if (MergeInnerShuffle(i != 0, SVN, OtherSV, N->getOperand(1 - i), TLI,
22996 SV0, SV1, Mask)) {
22997 // Check if all indices in Mask are Undef. In case, propagate Undef.
22998 if (llvm::all_of(Mask, [](int M) { return M < 0; }))
22999 return DAG.getUNDEF(VT);
23000
23001 return DAG.getVectorShuffle(VT, SDLoc(N),
23002 SV0 ? SV0 : DAG.getUNDEF(VT),
23003 SV1 ? SV1 : DAG.getUNDEF(VT), Mask);
23004 }
23005 }
23006 }
23007
23008 // Merge shuffles through binops if we are able to merge it with at least
23009 // one other shuffles.
23010 // shuffle(bop(shuffle(x,y),shuffle(z,w)),undef)
23011 // shuffle(bop(shuffle(x,y),shuffle(z,w)),bop(shuffle(a,b),shuffle(c,d)))
23012 unsigned SrcOpcode = N0.getOpcode();
23013 if (TLI.isBinOp(SrcOpcode) && N->isOnlyUserOf(N0.getNode()) &&
23014 (N1.isUndef() ||
23015 (SrcOpcode == N1.getOpcode() && N->isOnlyUserOf(N1.getNode())))) {
23016 // Get binop source ops, or just pass on the undef.
23017 SDValue Op00 = N0.getOperand(0);
23018 SDValue Op01 = N0.getOperand(1);
23019 SDValue Op10 = N1.isUndef() ? N1 : N1.getOperand(0);
23020 SDValue Op11 = N1.isUndef() ? N1 : N1.getOperand(1);
23021 // TODO: We might be able to relax the VT check but we don't currently
23022 // have any isBinOp() that has different result/ops VTs so play safe until
23023 // we have test coverage.
23024 if (Op00.getValueType() == VT && Op10.getValueType() == VT &&
23025 Op01.getValueType() == VT && Op11.getValueType() == VT &&
23026 (Op00.getOpcode() == ISD::VECTOR_SHUFFLE ||
23027 Op10.getOpcode() == ISD::VECTOR_SHUFFLE ||
23028 Op01.getOpcode() == ISD::VECTOR_SHUFFLE ||
23029 Op11.getOpcode() == ISD::VECTOR_SHUFFLE)) {
23030 auto CanMergeInnerShuffle = [&](SDValue &SV0, SDValue &SV1,
23031 SmallVectorImpl<int> &Mask, bool LeftOp,
23032 bool Commute) {
23033 SDValue InnerN = Commute ? N1 : N0;
23034 SDValue Op0 = LeftOp ? Op00 : Op01;
23035 SDValue Op1 = LeftOp ? Op10 : Op11;
23036 if (Commute)
23037 std::swap(Op0, Op1);
23038 // Only accept the merged shuffle if we don't introduce undef elements,
23039 // or the inner shuffle already contained undef elements.
23040 auto *SVN0 = dyn_cast<ShuffleVectorSDNode>(Op0);
23041 return SVN0 && InnerN->isOnlyUserOf(SVN0) &&
23042 MergeInnerShuffle(Commute, SVN, SVN0, Op1, TLI, SV0, SV1,
23043 Mask) &&
23044 (llvm::any_of(SVN0->getMask(), [](int M) { return M < 0; }) ||
23045 llvm::none_of(Mask, [](int M) { return M < 0; }));
23046 };
23047
23048 // Ensure we don't increase the number of shuffles - we must merge a
23049 // shuffle from at least one of the LHS and RHS ops.
23050 bool MergedLeft = false;
23051 SDValue LeftSV0, LeftSV1;
23052 SmallVector<int, 4> LeftMask;
23053 if (CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, false) ||
23054 CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, true)) {
23055 MergedLeft = true;
23056 } else {
23057 LeftMask.assign(SVN->getMask().begin(), SVN->getMask().end());
23058 LeftSV0 = Op00, LeftSV1 = Op10;
23059 }
23060
23061 bool MergedRight = false;
23062 SDValue RightSV0, RightSV1;
23063 SmallVector<int, 4> RightMask;
23064 if (CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, false) ||
23065 CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, true)) {
23066 MergedRight = true;
23067 } else {
23068 RightMask.assign(SVN->getMask().begin(), SVN->getMask().end());
23069 RightSV0 = Op01, RightSV1 = Op11;
23070 }
23071
23072 if (MergedLeft || MergedRight) {
23073 SDLoc DL(N);
23074 SDValue LHS = DAG.getVectorShuffle(
23075 VT, DL, LeftSV0 ? LeftSV0 : DAG.getUNDEF(VT),
23076 LeftSV1 ? LeftSV1 : DAG.getUNDEF(VT), LeftMask);
23077 SDValue RHS = DAG.getVectorShuffle(
23078 VT, DL, RightSV0 ? RightSV0 : DAG.getUNDEF(VT),
23079 RightSV1 ? RightSV1 : DAG.getUNDEF(VT), RightMask);
23080 return DAG.getNode(SrcOpcode, DL, VT, LHS, RHS);
23081 }
23082 }
23083 }
23084 }
23085
23086 if (SDValue V = foldShuffleOfConcatUndefs(SVN, DAG))
23087 return V;
23088
23089 return SDValue();
23090 }
23091
/// Combine SCALAR_TO_VECTOR nodes, primarily by replacing a
/// scalar_to_vector(extract_vector_elt(V, C0)) pattern with a shuffle of V
/// (plus a possible truncate or subvector extract to fix up the types).
SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) {
  SDValue InVal = N->getOperand(0);
  EVT VT = N->getValueType(0);

  // Replace a SCALAR_TO_VECTOR(EXTRACT_VECTOR_ELT(V,C0)) pattern
  // with a VECTOR_SHUFFLE and possible truncate.
  // Restricted to fixed-length vectors: the shuffle mask built below is
  // per-element and cannot describe a scalable vector.
  if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
      VT.isFixedLengthVector() &&
      InVal->getOperand(0).getValueType().isFixedLengthVector()) {
    SDValue InVec = InVal->getOperand(0);
    SDValue EltNo = InVal->getOperand(1);
    auto InVecT = InVec.getValueType();
    // Only fold when the extract index is a compile-time constant; the
    // shuffle mask needs a fixed source lane.
    if (ConstantSDNode *C0 = dyn_cast<ConstantSDNode>(EltNo)) {
      // Mask that moves element Elt into lane 0; all other lanes undef.
      SmallVector<int, 8> NewMask(InVecT.getVectorNumElements(), -1);
      int Elt = C0->getZExtValue();
      NewMask[0] = Elt;
      // If we have an implicit truncate, do the truncate here as long as the
      // narrow scalar type is legal; the resulting scalar_to_vector will be
      // revisited by the combiner.
      if (VT.getScalarType() != InVal.getValueType() &&
          InVal.getValueType().isScalarInteger() &&
          isTypeLegal(VT.getScalarType())) {
        SDValue Val =
            DAG.getNode(ISD::TRUNCATE, SDLoc(InVal), VT.getScalarType(), InVal);
        return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), VT, Val);
      }
      // Matching scalar types: build a (target-legal) shuffle on the input
      // vector type that moves the wanted element into lane 0.
      if (VT.getScalarType() == InVecT.getScalarType() &&
          VT.getVectorNumElements() <= InVecT.getVectorNumElements()) {
        SDValue LegalShuffle =
            TLI.buildLegalVectorShuffle(InVecT, SDLoc(N), InVec,
                                        DAG.getUNDEF(InVecT), NewMask, DAG);
        if (LegalShuffle) {
          // If the initial vector is the correct size this shuffle is a
          // valid result.
          if (VT == InVecT)
            return LegalShuffle;
          // If not we must truncate the vector.
          if (VT.getVectorNumElements() != InVecT.getVectorNumElements()) {
            SDValue ZeroIdx = DAG.getVectorIdxConstant(0, SDLoc(N));
            EVT SubVT = EVT::getVectorVT(*DAG.getContext(),
                                         InVecT.getVectorElementType(),
                                         VT.getVectorNumElements());
            return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), SubVT,
                               LegalShuffle, ZeroIdx);
          }
        }
      }
    }
  }

  return SDValue();
}
23143
/// Combine INSERT_SUBVECTOR nodes: drop no-op inserts, fold inserts into
/// undef, push bitcasts through, merge chained inserts at the same index,
/// canonicalize the order of independent inserts, and fold inserts into an
/// existing concat_vectors piece.
SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) {
  EVT VT = N->getValueType(0);
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue N2 = N->getOperand(2);
  // Operand 2 (the insert index) is always a constant for this node.
  uint64_t InsIdx = N->getConstantOperandVal(2);

  // If inserting an UNDEF, just return the original vector.
  if (N1.isUndef())
    return N0;

  // If this is an insert of an extracted vector into an undef vector, we can
  // just use the input to the extract.
  if (N0.isUndef() && N1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
      N1.getOperand(1) == N2 && N1.getOperand(0).getValueType() == VT)
    return N1.getOperand(0);

  // Simplify scalar inserts into an undef vector:
  // insert_subvector undef, (splat X), N2 -> splat X
  if (N0.isUndef() && N1.getOpcode() == ISD::SPLAT_VECTOR)
    return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), VT, N1.getOperand(0));

  // If we are inserting a bitcast value into an undef, with the same
  // number of elements, just use the bitcast input of the extract.
  // i.e. INSERT_SUBVECTOR UNDEF (BITCAST N1) N2 ->
  //        BITCAST (INSERT_SUBVECTOR UNDEF N1 N2)
  if (N0.isUndef() && N1.getOpcode() == ISD::BITCAST &&
      N1.getOperand(0).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
      N1.getOperand(0).getOperand(1) == N2 &&
      N1.getOperand(0).getOperand(0).getValueType().getVectorElementCount() ==
          VT.getVectorElementCount() &&
      N1.getOperand(0).getOperand(0).getValueType().getSizeInBits() ==
          VT.getSizeInBits()) {
    return DAG.getBitcast(VT, N1.getOperand(0).getOperand(0));
  }

  // If both N1 and N2 are bitcast values on which insert_subvector
  // would make sense, pull the bitcast through.
  // i.e. INSERT_SUBVECTOR (BITCAST N0) (BITCAST N1) N2 ->
  //        BITCAST (INSERT_SUBVECTOR N0 N1 N2)
  if (N0.getOpcode() == ISD::BITCAST && N1.getOpcode() == ISD::BITCAST) {
    SDValue CN0 = N0.getOperand(0);
    SDValue CN1 = N1.getOperand(0);
    EVT CN0VT = CN0.getValueType();
    EVT CN1VT = CN1.getValueType();
    // Element types must agree and the element count must match the result
    // so the insert index keeps the same meaning after the swap.
    if (CN0VT.isVector() && CN1VT.isVector() &&
        CN0VT.getVectorElementType() == CN1VT.getVectorElementType() &&
        CN0VT.getVectorElementCount() == VT.getVectorElementCount()) {
      SDValue NewINSERT = DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N),
                                      CN0.getValueType(), CN0, CN1, N2);
      return DAG.getBitcast(VT, NewINSERT);
    }
  }

  // Combine INSERT_SUBVECTORs where we are inserting to the same index.
  // INSERT_SUBVECTOR( INSERT_SUBVECTOR( Vec, SubOld, Idx ), SubNew, Idx )
  // --> INSERT_SUBVECTOR( Vec, SubNew, Idx )
  if (N0.getOpcode() == ISD::INSERT_SUBVECTOR &&
      N0.getOperand(1).getValueType() == N1.getValueType() &&
      N0.getOperand(2) == N2)
    return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0.getOperand(0),
                       N1, N2);

  // Eliminate an intermediate insert into an undef vector:
  // insert_subvector undef, (insert_subvector undef, X, 0), N2 -->
  // insert_subvector undef, X, N2
  if (N0.isUndef() && N1.getOpcode() == ISD::INSERT_SUBVECTOR &&
      N1.getOperand(0).isUndef() && isNullConstant(N1.getOperand(2)))
    return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0,
                       N1.getOperand(1), N2);

  // Push subvector bitcasts to the output, adjusting the index as we go.
  // insert_subvector(bitcast(v), bitcast(s), c1)
  //  -> bitcast(insert_subvector(v, s, c2))
  if ((N0.isUndef() || N0.getOpcode() == ISD::BITCAST) &&
      N1.getOpcode() == ISD::BITCAST) {
    SDValue N0Src = peekThroughBitcasts(N0);
    SDValue N1Src = peekThroughBitcasts(N1);
    EVT N0SrcSVT = N0Src.getValueType().getScalarType();
    EVT N1SrcSVT = N1Src.getValueType().getScalarType();
    if ((N0.isUndef() || N0SrcSVT == N1SrcSVT) &&
        N0Src.getValueType().isVector() && N1Src.getValueType().isVector()) {
      EVT NewVT;
      SDLoc DL(N);
      SDValue NewIdx;
      LLVMContext &Ctx = *DAG.getContext();
      ElementCount NumElts = VT.getVectorElementCount();
      unsigned EltSizeInBits = VT.getScalarSizeInBits();
      // Scale the insert index by the ratio of the element sizes on either
      // side of the bitcast (narrower source -> multiply, wider -> divide).
      if ((EltSizeInBits % N1SrcSVT.getSizeInBits()) == 0) {
        unsigned Scale = EltSizeInBits / N1SrcSVT.getSizeInBits();
        NewVT = EVT::getVectorVT(Ctx, N1SrcSVT, NumElts * Scale);
        NewIdx = DAG.getVectorIdxConstant(InsIdx * Scale, DL);
      } else if ((N1SrcSVT.getSizeInBits() % EltSizeInBits) == 0) {
        unsigned Scale = N1SrcSVT.getSizeInBits() / EltSizeInBits;
        // The wider-element form only works if the index divides evenly.
        if (NumElts.isKnownMultipleOf(Scale) && (InsIdx % Scale) == 0) {
          NewVT = EVT::getVectorVT(Ctx, N1SrcSVT,
                                   NumElts.divideCoefficientBy(Scale));
          NewIdx = DAG.getVectorIdxConstant(InsIdx / Scale, DL);
        }
      }
      if (NewIdx && hasOperation(ISD::INSERT_SUBVECTOR, NewVT)) {
        SDValue Res = DAG.getBitcast(NewVT, N0Src);
        Res = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, NewVT, Res, N1Src, NewIdx);
        return DAG.getBitcast(VT, Res);
      }
    }
  }

  // Canonicalize insert_subvector dag nodes.
  // Example:
  // (insert_subvector (insert_subvector A, Idx0), Idx1)
  // -> (insert_subvector (insert_subvector A, Idx1), Idx0)
  if (N0.getOpcode() == ISD::INSERT_SUBVECTOR && N0.hasOneUse() &&
      N1.getValueType() == N0.getOperand(1).getValueType()) {
    unsigned OtherIdx = N0.getConstantOperandVal(2);
    if (InsIdx < OtherIdx) {
      // Swap nodes.
      SDValue NewOp = DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT,
                                  N0.getOperand(0), N1, N2);
      AddToWorklist(NewOp.getNode());
      return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N0.getNode()),
                         VT, NewOp, N0.getOperand(1), N0.getOperand(2));
    }
  }

  // If the input vector is a concatenation, and the insert replaces
  // one of the pieces, we can optimize into a single concat_vectors.
  if (N0.getOpcode() == ISD::CONCAT_VECTORS && N0.hasOneUse() &&
      N0.getOperand(0).getValueType() == N1.getValueType() &&
      N0.getOperand(0).getValueType().isScalableVector() ==
          N1.getValueType().isScalableVector()) {
    unsigned Factor = N1.getValueType().getVectorMinNumElements();
    SmallVector<SDValue, 8> Ops(N0->op_begin(), N0->op_end());
    Ops[InsIdx / Factor] = N1;
    return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
  }

  // Simplify source operands based on insertion.
  if (SimplifyDemandedVectorElts(SDValue(N, 0)))
    return SDValue(N, 0);

  return SDValue();
}
23287
visitFP_TO_FP16(SDNode * N)23288 SDValue DAGCombiner::visitFP_TO_FP16(SDNode *N) {
23289 SDValue N0 = N->getOperand(0);
23290
23291 // fold (fp_to_fp16 (fp16_to_fp op)) -> op
23292 if (N0->getOpcode() == ISD::FP16_TO_FP)
23293 return N0->getOperand(0);
23294
23295 return SDValue();
23296 }
23297
visitFP16_TO_FP(SDNode * N)23298 SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
23299 SDValue N0 = N->getOperand(0);
23300
23301 // fold fp16_to_fp(op & 0xffff) -> fp16_to_fp(op)
23302 if (!TLI.shouldKeepZExtForFP16Conv() && N0->getOpcode() == ISD::AND) {
23303 ConstantSDNode *AndConst = getAsNonOpaqueConstant(N0.getOperand(1));
23304 if (AndConst && AndConst->getAPIntValue() == 0xffff) {
23305 return DAG.getNode(ISD::FP16_TO_FP, SDLoc(N), N->getValueType(0),
23306 N0.getOperand(0));
23307 }
23308 }
23309
23310 return SDValue();
23311 }
23312
visitFP_TO_BF16(SDNode * N)23313 SDValue DAGCombiner::visitFP_TO_BF16(SDNode *N) {
23314 SDValue N0 = N->getOperand(0);
23315
23316 // fold (fp_to_bf16 (bf16_to_fp op)) -> op
23317 if (N0->getOpcode() == ISD::BF16_TO_FP)
23318 return N0->getOperand(0);
23319
23320 return SDValue();
23321 }
23322
/// Combine VECREDUCE_* nodes: scalarize single-element reductions, convert
/// boolean and/or reductions to umin/umax when only the latter are legal, and
/// peek through insert_subvector when the outer vector is the reduction's
/// neutral element.
SDValue DAGCombiner::visitVECREDUCE(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N0.getValueType();
  unsigned Opcode = N->getOpcode();

  // VECREDUCE over 1-element vector is just an extract.
  if (VT.getVectorElementCount().isScalar()) {
    SDLoc dl(N);
    SDValue Res =
        DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT.getVectorElementType(), N0,
                    DAG.getVectorIdxConstant(0, dl));
    // The reduction's result type may be wider than the vector element type.
    if (Res.getValueType() != N->getValueType(0))
      Res = DAG.getNode(ISD::ANY_EXTEND, dl, N->getValueType(0), Res);
    return Res;
  }

  // On a boolean vector an and/or reduction is the same as a umin/umax
  // reduction. Convert them if the latter is legal while the former isn't.
  if (Opcode == ISD::VECREDUCE_AND || Opcode == ISD::VECREDUCE_OR) {
    unsigned NewOpcode = Opcode == ISD::VECREDUCE_AND
                             ? ISD::VECREDUCE_UMIN : ISD::VECREDUCE_UMAX;
    // NumSignBits == full bitwidth means every element is all-ones or
    // all-zeros, i.e. a boolean vector.
    if (!TLI.isOperationLegalOrCustom(Opcode, VT) &&
        TLI.isOperationLegalOrCustom(NewOpcode, VT) &&
        DAG.ComputeNumSignBits(N0) == VT.getScalarSizeInBits())
      return DAG.getNode(NewOpcode, SDLoc(N), N->getValueType(0), N0);
  }

  // vecreduce_or(insert_subvector(zero or undef, val)) -> vecreduce_or(val)
  // vecreduce_and(insert_subvector(ones or undef, val)) -> vecreduce_and(val)
  if (N0.getOpcode() == ISD::INSERT_SUBVECTOR &&
      TLI.isTypeLegal(N0.getOperand(1).getValueType())) {
    SDValue Vec = N0.getOperand(0);
    SDValue Subvec = N0.getOperand(1);
    // The outer vector must be the neutral element of the reduction (zero
    // for OR, all-ones for AND) or undef for the fold to be valid.
    if ((Opcode == ISD::VECREDUCE_OR &&
         (N0.getOperand(0).isUndef() || isNullOrNullSplat(Vec))) ||
        (Opcode == ISD::VECREDUCE_AND &&
         (N0.getOperand(0).isUndef() || isAllOnesOrAllOnesSplat(Vec))))
      return DAG.getNode(Opcode, SDLoc(N), N->getValueType(0), Subvec);
  }

  return SDValue();
}
23365
visitVPOp(SDNode * N)23366 SDValue DAGCombiner::visitVPOp(SDNode *N) {
23367 // VP operations in which all vector elements are disabled - either by
23368 // determining that the mask is all false or that the EVL is 0 - can be
23369 // eliminated.
23370 bool AreAllEltsDisabled = false;
23371 if (auto EVLIdx = ISD::getVPExplicitVectorLengthIdx(N->getOpcode()))
23372 AreAllEltsDisabled |= isNullConstant(N->getOperand(*EVLIdx));
23373 if (auto MaskIdx = ISD::getVPMaskIdx(N->getOpcode()))
23374 AreAllEltsDisabled |=
23375 ISD::isConstantSplatVectorAllZeros(N->getOperand(*MaskIdx).getNode());
23376
23377 // This is the only generic VP combine we support for now.
23378 if (!AreAllEltsDisabled)
23379 return SDValue();
23380
23381 // Binary operations can be replaced by UNDEF.
23382 if (ISD::isVPBinaryOp(N->getOpcode()))
23383 return DAG.getUNDEF(N->getValueType(0));
23384
23385 // VP Memory operations can be replaced by either the chain (stores) or the
23386 // chain + undef (loads).
23387 if (const auto *MemSD = dyn_cast<MemSDNode>(N)) {
23388 if (MemSD->writeMem())
23389 return MemSD->getChain();
23390 return CombineTo(N, DAG.getUNDEF(N->getValueType(0)), MemSD->getChain());
23391 }
23392
23393 // Reduction operations return the start operand when no elements are active.
23394 if (ISD::isVPReduction(N->getOpcode()))
23395 return N->getOperand(0);
23396
23397 return SDValue();
23398 }
23399
/// Returns a vector_shuffle if it able to transform an AND to a vector_shuffle
/// with the destination vector and a zero vector.
/// e.g. AND V, <0xffffffff, 0, 0xffffffff, 0>. ==>
///      vector_shuffle V, Zero, <0, 4, 2, 4>
SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) {
  assert(N->getOpcode() == ISD::AND && "Unexpected opcode!");

  EVT VT = N->getValueType(0);
  SDValue LHS = N->getOperand(0);
  SDValue RHS = peekThroughBitcasts(N->getOperand(1));
  SDLoc DL(N);

  // Make sure we're not running after operation legalization where it
  // may have custom lowered the vector shuffles.
  if (LegalOperations)
    return SDValue();

  // The mask operand must be a constant build_vector for us to inspect it.
  if (RHS.getOpcode() != ISD::BUILD_VECTOR)
    return SDValue();

  EVT RVT = RHS.getValueType();
  unsigned NumElts = RHS.getNumOperands();

  // Attempt to create a valid clear mask, splitting the mask into
  // sub elements and checking to see if each is
  // all zeros or all ones - suitable for shuffle masking.
  auto BuildClearMask = [&](int Split) {
    // View each RVT element as 'Split' sub-elements of NumSubBits bits each.
    int NumSubElts = NumElts * Split;
    int NumSubBits = RVT.getScalarSizeInBits() / Split;

    SmallVector<int, 8> Indices;
    for (int i = 0; i != NumSubElts; ++i) {
      int EltIdx = i / Split;
      int SubIdx = i % Split;
      SDValue Elt = RHS.getOperand(EltIdx);
      // X & undef --> 0 (not undef). So this lane must be converted to choose
      // from the zero constant vector (same as if the element had all 0-bits).
      if (Elt.isUndef()) {
        Indices.push_back(i + NumSubElts);
        continue;
      }

      // Only integer/FP constants can contribute to a clear mask.
      APInt Bits;
      if (isa<ConstantSDNode>(Elt))
        Bits = cast<ConstantSDNode>(Elt)->getAPIntValue();
      else if (isa<ConstantFPSDNode>(Elt))
        Bits = cast<ConstantFPSDNode>(Elt)->getValueAPF().bitcastToAPInt();
      else
        return SDValue();

      // Extract the sub element from the constant bit mask.
      // The in-memory order of sub-elements depends on endianness.
      if (DAG.getDataLayout().isBigEndian())
        Bits = Bits.extractBits(NumSubBits, (Split - SubIdx - 1) * NumSubBits);
      else
        Bits = Bits.extractBits(NumSubBits, SubIdx * NumSubBits);

      // All-ones keeps the LHS lane (index i); all-zeros selects the zero
      // vector (index i + NumSubElts). Mixed bits cannot form a shuffle.
      if (Bits.isAllOnes())
        Indices.push_back(i);
      else if (Bits == 0)
        Indices.push_back(i + NumSubElts);
      else
        return SDValue();
    }

    // Let's see if the target supports this vector_shuffle.
    EVT ClearSVT = EVT::getIntegerVT(*DAG.getContext(), NumSubBits);
    EVT ClearVT = EVT::getVectorVT(*DAG.getContext(), ClearSVT, NumSubElts);
    if (!TLI.isVectorClearMaskLegal(Indices, ClearVT))
      return SDValue();

    SDValue Zero = DAG.getConstant(0, DL, ClearVT);
    return DAG.getBitcast(VT, DAG.getVectorShuffle(ClearVT, DL,
                                                   DAG.getBitcast(ClearVT, LHS),
                                                   Zero, Indices));
  };

  // Determine maximum split level (byte level masking).
  int MaxSplit = 1;
  if (RVT.getScalarSizeInBits() % 8 == 0)
    MaxSplit = RVT.getScalarSizeInBits() / 8;

  // Try whole elements first, then progressively finer sub-element splits.
  for (int Split = 1; Split <= MaxSplit; ++Split)
    if (RVT.getScalarSizeInBits() % Split == 0)
      if (SDValue S = BuildClearMask(Split))
        return S;

  return SDValue();
}
23488
/// If a vector binop is performed on splat values, it may be profitable to
/// extract, scalarize, and insert/splat.
/// i.e. bo (splat X, Index), (splat Y, Index) --> splat (bo X, Y), Index
static SDValue scalarizeBinOpOfSplats(SDNode *N, SelectionDAG &DAG,
                                      const SDLoc &DL) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  unsigned Opcode = N->getOpcode();
  EVT VT = N->getValueType(0);
  EVT EltVT = VT.getVectorElementType();
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();

  // TODO: Remove/replace the extract cost check? If the elements are available
  // as scalars, then there may be no extract cost. Should we ask if
  // inserting a scalar back into a vector is cheap instead?
  int Index0, Index1;
  SDValue Src0 = DAG.getSplatSourceVector(N0, Index0);
  SDValue Src1 = DAG.getSplatSourceVector(N1, Index1);
  // Extract element from splat_vector should be free.
  // TODO: use DAG.isSplatValue instead?
  bool IsBothSplatVector = N0.getOpcode() == ISD::SPLAT_VECTOR &&
                           N1.getOpcode() == ISD::SPLAT_VECTOR;
  // Bail out unless: both operands splat the same lane, the source element
  // types match the result element type, extraction is cheap (or free via
  // SPLAT_VECTOR), and the scalar form of the op is legal or custom.
  if (!Src0 || !Src1 || Index0 != Index1 ||
      Src0.getValueType().getVectorElementType() != EltVT ||
      Src1.getValueType().getVectorElementType() != EltVT ||
      !(IsBothSplatVector || TLI.isExtractVecEltCheap(VT, Index0)) ||
      !TLI.isOperationLegalOrCustom(Opcode, EltVT))
    return SDValue();

  // Extract the two splatted scalars and perform the operation once.
  SDValue IndexC = DAG.getVectorIdxConstant(Index0, DL);
  SDValue X = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Src0, IndexC);
  SDValue Y = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Src1, IndexC);
  SDValue ScalarBO = DAG.getNode(Opcode, DL, EltVT, X, Y, N->getFlags());

  // If all lanes but 1 are undefined, no need to splat the scalar result.
  // TODO: Keep track of undefs and use that info in the general case.
  if (N0.getOpcode() == ISD::BUILD_VECTOR && N0.getOpcode() == N1.getOpcode() &&
      count_if(N0->ops(), [](SDValue V) { return !V.isUndef(); }) == 1 &&
      count_if(N1->ops(), [](SDValue V) { return !V.isUndef(); }) == 1) {
    // bo (build_vec ..undef, X, undef...), (build_vec ..undef, Y, undef...) -->
    // build_vec ..undef, (bo X, Y), undef...
    SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), DAG.getUNDEF(EltVT));
    Ops[Index0] = ScalarBO;
    return DAG.getBuildVector(VT, DL, Ops);
  }

  // bo (splat X, Index), (splat Y, Index) --> splat (bo X, Y), Index
  if (VT.isScalableVector())
    return DAG.getSplatVector(VT, DL, ScalarBO);
  SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), ScalarBO);
  return DAG.getBuildVector(VT, DL, Ops);
}
23540
23541 /// Visit a binary vector operation, like ADD.
SimplifyVBinOp(SDNode * N,const SDLoc & DL)23542 SDValue DAGCombiner::SimplifyVBinOp(SDNode *N, const SDLoc &DL) {
23543 EVT VT = N->getValueType(0);
23544 assert(VT.isVector() && "SimplifyVBinOp only works on vectors!");
23545
23546 SDValue LHS = N->getOperand(0);
23547 SDValue RHS = N->getOperand(1);
23548 unsigned Opcode = N->getOpcode();
23549 SDNodeFlags Flags = N->getFlags();
23550
23551 // Move unary shuffles with identical masks after a vector binop:
23552 // VBinOp (shuffle A, Undef, Mask), (shuffle B, Undef, Mask))
23553 // --> shuffle (VBinOp A, B), Undef, Mask
23554 // This does not require type legality checks because we are creating the
23555 // same types of operations that are in the original sequence. We do have to
23556 // restrict ops like integer div that have immediate UB (eg, div-by-zero)
23557 // though. This code is adapted from the identical transform in instcombine.
23558 if (Opcode != ISD::UDIV && Opcode != ISD::SDIV &&
23559 Opcode != ISD::UREM && Opcode != ISD::SREM &&
23560 Opcode != ISD::UDIVREM && Opcode != ISD::SDIVREM) {
23561 auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(LHS);
23562 auto *Shuf1 = dyn_cast<ShuffleVectorSDNode>(RHS);
23563 if (Shuf0 && Shuf1 && Shuf0->getMask().equals(Shuf1->getMask()) &&
23564 LHS.getOperand(1).isUndef() && RHS.getOperand(1).isUndef() &&
23565 (LHS.hasOneUse() || RHS.hasOneUse() || LHS == RHS)) {
23566 SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, LHS.getOperand(0),
23567 RHS.getOperand(0), Flags);
23568 SDValue UndefV = LHS.getOperand(1);
23569 return DAG.getVectorShuffle(VT, DL, NewBinOp, UndefV, Shuf0->getMask());
23570 }
23571
23572 // Try to sink a splat shuffle after a binop with a uniform constant.
23573 // This is limited to cases where neither the shuffle nor the constant have
23574 // undefined elements because that could be poison-unsafe or inhibit
23575 // demanded elements analysis. It is further limited to not change a splat
23576 // of an inserted scalar because that may be optimized better by
23577 // load-folding or other target-specific behaviors.
23578 if (isConstOrConstSplat(RHS) && Shuf0 && is_splat(Shuf0->getMask()) &&
23579 Shuf0->hasOneUse() && Shuf0->getOperand(1).isUndef() &&
23580 Shuf0->getOperand(0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
23581 // binop (splat X), (splat C) --> splat (binop X, C)
23582 SDValue X = Shuf0->getOperand(0);
23583 SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, X, RHS, Flags);
23584 return DAG.getVectorShuffle(VT, DL, NewBinOp, DAG.getUNDEF(VT),
23585 Shuf0->getMask());
23586 }
23587 if (isConstOrConstSplat(LHS) && Shuf1 && is_splat(Shuf1->getMask()) &&
23588 Shuf1->hasOneUse() && Shuf1->getOperand(1).isUndef() &&
23589 Shuf1->getOperand(0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
23590 // binop (splat C), (splat X) --> splat (binop C, X)
23591 SDValue X = Shuf1->getOperand(0);
23592 SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, LHS, X, Flags);
23593 return DAG.getVectorShuffle(VT, DL, NewBinOp, DAG.getUNDEF(VT),
23594 Shuf1->getMask());
23595 }
23596 }
23597
23598 // The following pattern is likely to emerge with vector reduction ops. Moving
23599 // the binary operation ahead of insertion may allow using a narrower vector
23600 // instruction that has better performance than the wide version of the op:
23601 // VBinOp (ins undef, X, Z), (ins undef, Y, Z) --> ins VecC, (VBinOp X, Y), Z
23602 if (LHS.getOpcode() == ISD::INSERT_SUBVECTOR && LHS.getOperand(0).isUndef() &&
23603 RHS.getOpcode() == ISD::INSERT_SUBVECTOR && RHS.getOperand(0).isUndef() &&
23604 LHS.getOperand(2) == RHS.getOperand(2) &&
23605 (LHS.hasOneUse() || RHS.hasOneUse())) {
23606 SDValue X = LHS.getOperand(1);
23607 SDValue Y = RHS.getOperand(1);
23608 SDValue Z = LHS.getOperand(2);
23609 EVT NarrowVT = X.getValueType();
23610 if (NarrowVT == Y.getValueType() &&
23611 TLI.isOperationLegalOrCustomOrPromote(Opcode, NarrowVT,
23612 LegalOperations)) {
23613 // (binop undef, undef) may not return undef, so compute that result.
23614 SDValue VecC =
23615 DAG.getNode(Opcode, DL, VT, DAG.getUNDEF(VT), DAG.getUNDEF(VT));
23616 SDValue NarrowBO = DAG.getNode(Opcode, DL, NarrowVT, X, Y);
23617 return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, VecC, NarrowBO, Z);
23618 }
23619 }
23620
23621 // Make sure all but the first op are undef or constant.
23622 auto ConcatWithConstantOrUndef = [](SDValue Concat) {
23623 return Concat.getOpcode() == ISD::CONCAT_VECTORS &&
23624 all_of(drop_begin(Concat->ops()), [](const SDValue &Op) {
23625 return Op.isUndef() ||
23626 ISD::isBuildVectorOfConstantSDNodes(Op.getNode());
23627 });
23628 };
23629
23630 // The following pattern is likely to emerge with vector reduction ops. Moving
23631 // the binary operation ahead of the concat may allow using a narrower vector
23632 // instruction that has better performance than the wide version of the op:
23633 // VBinOp (concat X, undef/constant), (concat Y, undef/constant) -->
23634 // concat (VBinOp X, Y), VecC
23635 if (ConcatWithConstantOrUndef(LHS) && ConcatWithConstantOrUndef(RHS) &&
23636 (LHS.hasOneUse() || RHS.hasOneUse())) {
23637 EVT NarrowVT = LHS.getOperand(0).getValueType();
23638 if (NarrowVT == RHS.getOperand(0).getValueType() &&
23639 TLI.isOperationLegalOrCustomOrPromote(Opcode, NarrowVT)) {
23640 unsigned NumOperands = LHS.getNumOperands();
23641 SmallVector<SDValue, 4> ConcatOps;
23642 for (unsigned i = 0; i != NumOperands; ++i) {
        // This constant-folds for operands 1 and up (they are undef or constant).
23644 ConcatOps.push_back(DAG.getNode(Opcode, DL, NarrowVT, LHS.getOperand(i),
23645 RHS.getOperand(i)));
23646 }
23647
23648 return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
23649 }
23650 }
23651
23652 if (SDValue V = scalarizeBinOpOfSplats(N, DAG, DL))
23653 return V;
23654
23655 return SDValue();
23656 }
23657
SimplifySelect(const SDLoc & DL,SDValue N0,SDValue N1,SDValue N2)23658 SDValue DAGCombiner::SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1,
23659 SDValue N2) {
23660 assert(N0.getOpcode() ==ISD::SETCC && "First argument must be a SetCC node!");
23661
23662 SDValue SCC = SimplifySelectCC(DL, N0.getOperand(0), N0.getOperand(1), N1, N2,
23663 cast<CondCodeSDNode>(N0.getOperand(2))->get());
23664
23665 // If we got a simplified select_cc node back from SimplifySelectCC, then
23666 // break it down into a new SETCC node, and a new SELECT node, and then return
23667 // the SELECT node, since we were called with a SELECT node.
23668 if (SCC.getNode()) {
23669 // Check to see if we got a select_cc back (to turn into setcc/select).
23670 // Otherwise, just return whatever node we got back, like fabs.
23671 if (SCC.getOpcode() == ISD::SELECT_CC) {
23672 const SDNodeFlags Flags = N0->getFlags();
23673 SDValue SETCC = DAG.getNode(ISD::SETCC, SDLoc(N0),
23674 N0.getValueType(),
23675 SCC.getOperand(0), SCC.getOperand(1),
23676 SCC.getOperand(4), Flags);
23677 AddToWorklist(SETCC.getNode());
23678 SDValue SelectNode = DAG.getSelect(SDLoc(SCC), SCC.getValueType(), SETCC,
23679 SCC.getOperand(2), SCC.getOperand(3));
23680 SelectNode->setFlags(Flags);
23681 return SelectNode;
23682 }
23683
23684 return SCC;
23685 }
23686 return SDValue();
23687 }
23688
/// Given a SELECT or a SELECT_CC node, where LHS and RHS are the two values
/// being selected between, see if we can simplify the select. Callers of this
/// should assume that TheSelect is deleted if this returns true. As such, they
/// should return the appropriate thing (e.g. the node) back to the top-level of
/// the DAG combiner loop to avoid it being looked at.
bool DAGCombiner::SimplifySelectOps(SDNode *TheSelect, SDValue LHS,
                                    SDValue RHS) {
  // fold (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
  // The select + setcc is redundant, because fsqrt returns NaN for X < 0.
  if (const ConstantFPSDNode *NaN = isConstOrConstSplatFP(LHS)) {
    if (NaN->isNaN() && RHS.getOpcode() == ISD::FSQRT) {
      // We have: (select (setcc ?, ?, ?), NaN, (fsqrt ?))
      SDValue Sqrt = RHS;
      ISD::CondCode CC;
      SDValue CmpLHS;
      const ConstantFPSDNode *Zero = nullptr;

      if (TheSelect->getOpcode() == ISD::SELECT_CC) {
        // select_cc carries its condition operands inline on the node.
        CC = cast<CondCodeSDNode>(TheSelect->getOperand(4))->get();
        CmpLHS = TheSelect->getOperand(0);
        Zero = isConstOrConstSplatFP(TheSelect->getOperand(1));
      } else {
        // SELECT or VSELECT: peek through a SETCC condition operand.
        SDValue Cmp = TheSelect->getOperand(0);
        if (Cmp.getOpcode() == ISD::SETCC) {
          CC = cast<CondCodeSDNode>(Cmp.getOperand(2))->get();
          CmpLHS = Cmp.getOperand(0);
          Zero = isConstOrConstSplatFP(Cmp.getOperand(1));
        }
      }
      // NOTE: if Zero stayed null above, the short-circuit below leaves CC
      // unread, so the possibly-uninitialized CC is never used.
      if (Zero && Zero->isZero() &&
          Sqrt.getOperand(0) == CmpLHS && (CC == ISD::SETOLT ||
          CC == ISD::SETULT || CC == ISD::SETLT)) {
        // We have: (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
        CombineTo(TheSelect, Sqrt);
        return true;
      }
    }
  }
  // Cannot simplify select with vector condition
  if (TheSelect->getOperand(0).getValueType().isVector()) return false;

  // If this is a select from two identical things, try to pull the operation
  // through the select.
  if (LHS.getOpcode() != RHS.getOpcode() ||
      !LHS.hasOneUse() || !RHS.hasOneUse())
    return false;

  // If this is a load and the token chain is identical, replace the select
  // of two loads with a load through a select of the address to load from.
  // This triggers in things like "select bool X, 10.0, 123.0" after the FP
  // constants have been dropped into the constant pool.
  if (LHS.getOpcode() == ISD::LOAD) {
    LoadSDNode *LLD = cast<LoadSDNode>(LHS);
    LoadSDNode *RLD = cast<LoadSDNode>(RHS);

    // Token chains must be identical.
    if (LHS.getOperand(0) != RHS.getOperand(0) ||
        // Do not let this transformation reduce the number of volatile loads.
        // Be conservative for atomics for the moment
        // TODO: This does appear to be legal for unordered atomics (see D66309)
        !LLD->isSimple() || !RLD->isSimple() ||
        // FIXME: If either is a pre/post inc/dec load,
        // we'd need to split out the address adjustment.
        LLD->isIndexed() || RLD->isIndexed() ||
        // If this is an EXTLOAD, the VT's must match.
        LLD->getMemoryVT() != RLD->getMemoryVT() ||
        // If this is an EXTLOAD, the kind of extension must match.
        (LLD->getExtensionType() != RLD->getExtensionType() &&
         // The only exception is if one of the extensions is anyext.
         LLD->getExtensionType() != ISD::EXTLOAD &&
         RLD->getExtensionType() != ISD::EXTLOAD) ||
        // FIXME: this discards src value information. This is
        // over-conservative. It would be beneficial to be able to remember
        // both potential memory locations. Since we are discarding
        // src value info, don't do the transformation if the memory
        // locations are not in the default address space.
        LLD->getPointerInfo().getAddrSpace() != 0 ||
        RLD->getPointerInfo().getAddrSpace() != 0 ||
        // We can't produce a CMOV of a TargetFrameIndex since we won't
        // generate the address generation required.
        LLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
        RLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
        !TLI.isOperationLegalOrCustom(TheSelect->getOpcode(),
                                      LLD->getBasePtr().getValueType()))
      return false;

    // The loads must not depend on one another.
    if (LLD->isPredecessorOf(RLD) || RLD->isPredecessorOf(LLD))
      return false;

    // Check that the select condition doesn't reach either load. If so,
    // folding this will induce a cycle into the DAG. If not, this is safe to
    // xform, so create a select of the addresses.

    // Visited/Worklist are deliberately shared across all the
    // hasPredecessorHelper calls below: each call extends the same search
    // incrementally rather than restarting from scratch.
    SmallPtrSet<const SDNode *, 32> Visited;
    SmallVector<const SDNode *, 16> Worklist;

    // Always fail if LLD and RLD are not independent. TheSelect is a
    // predecessor to all Nodes in question so we need not search past it.

    Visited.insert(TheSelect);
    Worklist.push_back(LLD);
    Worklist.push_back(RLD);

    if (SDNode::hasPredecessorHelper(LLD, Visited, Worklist) ||
        SDNode::hasPredecessorHelper(RLD, Visited, Worklist))
      return false;

    SDValue Addr;
    if (TheSelect->getOpcode() == ISD::SELECT) {
      // We cannot do this optimization if any pair of {RLD, LLD} is a
      // predecessor to {RLD, LLD, CondNode}. As we've already compared the
      // Loads, we only need to check if CondNode is a successor to one of the
      // loads. We can further avoid this if there's no use of their chain
      // value.
      SDNode *CondNode = TheSelect->getOperand(0).getNode();
      Worklist.push_back(CondNode);

      if ((LLD->hasAnyUseOfValue(1) &&
           SDNode::hasPredecessorHelper(LLD, Visited, Worklist)) ||
          (RLD->hasAnyUseOfValue(1) &&
           SDNode::hasPredecessorHelper(RLD, Visited, Worklist)))
        return false;

      Addr = DAG.getSelect(SDLoc(TheSelect),
                           LLD->getBasePtr().getValueType(),
                           TheSelect->getOperand(0), LLD->getBasePtr(),
                           RLD->getBasePtr());
    } else { // Otherwise SELECT_CC
      // We cannot do this optimization if any pair of {RLD, LLD} is a
      // predecessor to {RLD, LLD, CondLHS, CondRHS}. As we've already compared
      // the Loads, we only need to check if CondLHS/CondRHS is a successor to
      // one of the loads. We can further avoid this if there's no use of their
      // chain value.

      SDNode *CondLHS = TheSelect->getOperand(0).getNode();
      SDNode *CondRHS = TheSelect->getOperand(1).getNode();
      Worklist.push_back(CondLHS);
      Worklist.push_back(CondRHS);

      if ((LLD->hasAnyUseOfValue(1) &&
           SDNode::hasPredecessorHelper(LLD, Visited, Worklist)) ||
          (RLD->hasAnyUseOfValue(1) &&
           SDNode::hasPredecessorHelper(RLD, Visited, Worklist)))
        return false;

      Addr = DAG.getNode(ISD::SELECT_CC, SDLoc(TheSelect),
                         LLD->getBasePtr().getValueType(),
                         TheSelect->getOperand(0),
                         TheSelect->getOperand(1),
                         LLD->getBasePtr(), RLD->getBasePtr(),
                         TheSelect->getOperand(4));
    }

    SDValue Load;
    // It is safe to replace the two loads if they have different alignments,
    // but the new load must be the minimum (most restrictive) alignment of the
    // inputs.
    Align Alignment = std::min(LLD->getAlign(), RLD->getAlign());
    // Start from LLD's MMO flags and drop invariant/dereferenceable unless
    // RLD also has them: the merged load may only keep what both guarantee.
    MachineMemOperand::Flags MMOFlags = LLD->getMemOperand()->getFlags();
    if (!RLD->isInvariant())
      MMOFlags &= ~MachineMemOperand::MOInvariant;
    if (!RLD->isDereferenceable())
      MMOFlags &= ~MachineMemOperand::MODereferenceable;
    if (LLD->getExtensionType() == ISD::NON_EXTLOAD) {
      // Non-extending case: a plain load of the select's result type.
      // FIXME: Discards pointer and AA info.
      Load = DAG.getLoad(TheSelect->getValueType(0), SDLoc(TheSelect),
                         LLD->getChain(), Addr, MachinePointerInfo(), Alignment,
                         MMOFlags);
    } else {
      // Extending case. If one side is EXTLOAD (anyext), take the other
      // side's (stricter) extension kind, per the matching rules above.
      // FIXME: Discards pointer and AA info.
      Load = DAG.getExtLoad(
          LLD->getExtensionType() == ISD::EXTLOAD ? RLD->getExtensionType()
                                                  : LLD->getExtensionType(),
          SDLoc(TheSelect), TheSelect->getValueType(0), LLD->getChain(), Addr,
          MachinePointerInfo(), LLD->getMemoryVT(), Alignment, MMOFlags);
    }

    // Users of the select now use the result of the load.
    CombineTo(TheSelect, Load);

    // Users of the old loads now use the new load's chain. We know the
    // old-load value is dead now.
    CombineTo(LHS.getNode(), Load.getValue(0), Load.getValue(1));
    CombineTo(RHS.getNode(), Load.getValue(0), Load.getValue(1));
    return true;
  }

  return false;
}
23880
23881 /// Try to fold an expression of the form (N0 cond N1) ? N2 : N3 to a shift and
23882 /// bitwise 'and'.
foldSelectCCToShiftAnd(const SDLoc & DL,SDValue N0,SDValue N1,SDValue N2,SDValue N3,ISD::CondCode CC)23883 SDValue DAGCombiner::foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0,
23884 SDValue N1, SDValue N2, SDValue N3,
23885 ISD::CondCode CC) {
23886 // If this is a select where the false operand is zero and the compare is a
23887 // check of the sign bit, see if we can perform the "gzip trick":
23888 // select_cc setlt X, 0, A, 0 -> and (sra X, size(X)-1), A
23889 // select_cc setgt X, 0, A, 0 -> and (not (sra X, size(X)-1)), A
23890 EVT XType = N0.getValueType();
23891 EVT AType = N2.getValueType();
23892 if (!isNullConstant(N3) || !XType.bitsGE(AType))
23893 return SDValue();
23894
23895 // If the comparison is testing for a positive value, we have to invert
23896 // the sign bit mask, so only do that transform if the target has a bitwise
23897 // 'and not' instruction (the invert is free).
23898 if (CC == ISD::SETGT && TLI.hasAndNot(N2)) {
23899 // (X > -1) ? A : 0
23900 // (X > 0) ? X : 0 <-- This is canonical signed max.
23901 if (!(isAllOnesConstant(N1) || (isNullConstant(N1) && N0 == N2)))
23902 return SDValue();
23903 } else if (CC == ISD::SETLT) {
23904 // (X < 0) ? A : 0
23905 // (X < 1) ? X : 0 <-- This is un-canonicalized signed min.
23906 if (!(isNullConstant(N1) || (isOneConstant(N1) && N0 == N2)))
23907 return SDValue();
23908 } else {
23909 return SDValue();
23910 }
23911
23912 // and (sra X, size(X)-1), A -> "and (srl X, C2), A" iff A is a single-bit
23913 // constant.
23914 EVT ShiftAmtTy = getShiftAmountTy(N0.getValueType());
23915 auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
23916 if (N2C && ((N2C->getAPIntValue() & (N2C->getAPIntValue() - 1)) == 0)) {
23917 unsigned ShCt = XType.getSizeInBits() - N2C->getAPIntValue().logBase2() - 1;
23918 if (!TLI.shouldAvoidTransformToShift(XType, ShCt)) {
23919 SDValue ShiftAmt = DAG.getConstant(ShCt, DL, ShiftAmtTy);
23920 SDValue Shift = DAG.getNode(ISD::SRL, DL, XType, N0, ShiftAmt);
23921 AddToWorklist(Shift.getNode());
23922
23923 if (XType.bitsGT(AType)) {
23924 Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift);
23925 AddToWorklist(Shift.getNode());
23926 }
23927
23928 if (CC == ISD::SETGT)
23929 Shift = DAG.getNOT(DL, Shift, AType);
23930
23931 return DAG.getNode(ISD::AND, DL, AType, Shift, N2);
23932 }
23933 }
23934
23935 unsigned ShCt = XType.getSizeInBits() - 1;
23936 if (TLI.shouldAvoidTransformToShift(XType, ShCt))
23937 return SDValue();
23938
23939 SDValue ShiftAmt = DAG.getConstant(ShCt, DL, ShiftAmtTy);
23940 SDValue Shift = DAG.getNode(ISD::SRA, DL, XType, N0, ShiftAmt);
23941 AddToWorklist(Shift.getNode());
23942
23943 if (XType.bitsGT(AType)) {
23944 Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift);
23945 AddToWorklist(Shift.getNode());
23946 }
23947
23948 if (CC == ISD::SETGT)
23949 Shift = DAG.getNOT(DL, Shift, AType);
23950
23951 return DAG.getNode(ISD::AND, DL, AType, Shift, N2);
23952 }
23953
23954 // Fold select(cc, binop(), binop()) -> binop(select(), select()) etc.
foldSelectOfBinops(SDNode * N)23955 SDValue DAGCombiner::foldSelectOfBinops(SDNode *N) {
23956 SDValue N0 = N->getOperand(0);
23957 SDValue N1 = N->getOperand(1);
23958 SDValue N2 = N->getOperand(2);
23959 EVT VT = N->getValueType(0);
23960 SDLoc DL(N);
23961
23962 unsigned BinOpc = N1.getOpcode();
23963 if (!TLI.isBinOp(BinOpc) || (N2.getOpcode() != BinOpc))
23964 return SDValue();
23965
23966 // The use checks are intentionally on SDNode because we may be dealing
23967 // with opcodes that produce more than one SDValue.
23968 // TODO: Do we really need to check N0 (the condition operand of the select)?
23969 // But removing that clause could cause an infinite loop...
23970 if (!N0->hasOneUse() || !N1->hasOneUse() || !N2->hasOneUse())
23971 return SDValue();
23972
23973 // Binops may include opcodes that return multiple values, so all values
23974 // must be created/propagated from the newly created binops below.
23975 SDVTList OpVTs = N1->getVTList();
23976
23977 // Fold select(cond, binop(x, y), binop(z, y))
23978 // --> binop(select(cond, x, z), y)
23979 if (N1.getOperand(1) == N2.getOperand(1)) {
23980 SDValue NewSel =
23981 DAG.getSelect(DL, VT, N0, N1.getOperand(0), N2.getOperand(0));
23982 SDValue NewBinOp = DAG.getNode(BinOpc, DL, OpVTs, NewSel, N1.getOperand(1));
23983 NewBinOp->setFlags(N1->getFlags());
23984 NewBinOp->intersectFlagsWith(N2->getFlags());
23985 return NewBinOp;
23986 }
23987
23988 // Fold select(cond, binop(x, y), binop(x, z))
23989 // --> binop(x, select(cond, y, z))
23990 // Second op VT might be different (e.g. shift amount type)
23991 if (N1.getOperand(0) == N2.getOperand(0) &&
23992 VT == N1.getOperand(1).getValueType() &&
23993 VT == N2.getOperand(1).getValueType()) {
23994 SDValue NewSel =
23995 DAG.getSelect(DL, VT, N0, N1.getOperand(1), N2.getOperand(1));
23996 SDValue NewBinOp = DAG.getNode(BinOpc, DL, OpVTs, N1.getOperand(0), NewSel);
23997 NewBinOp->setFlags(N1->getFlags());
23998 NewBinOp->intersectFlagsWith(N2->getFlags());
23999 return NewBinOp;
24000 }
24001
24002 // TODO: Handle isCommutativeBinOp patterns as well?
24003 return SDValue();
24004 }
24005
24006 // Transform (fneg/fabs (bitconvert x)) to avoid loading constant pool values.
foldSignChangeInBitcast(SDNode * N)24007 SDValue DAGCombiner::foldSignChangeInBitcast(SDNode *N) {
24008 SDValue N0 = N->getOperand(0);
24009 EVT VT = N->getValueType(0);
24010 bool IsFabs = N->getOpcode() == ISD::FABS;
24011 bool IsFree = IsFabs ? TLI.isFAbsFree(VT) : TLI.isFNegFree(VT);
24012
24013 if (IsFree || N0.getOpcode() != ISD::BITCAST || !N0.hasOneUse())
24014 return SDValue();
24015
24016 SDValue Int = N0.getOperand(0);
24017 EVT IntVT = Int.getValueType();
24018
24019 // The operand to cast should be integer.
24020 if (!IntVT.isInteger() || IntVT.isVector())
24021 return SDValue();
24022
24023 // (fneg (bitconvert x)) -> (bitconvert (xor x sign))
24024 // (fabs (bitconvert x)) -> (bitconvert (and x ~sign))
24025 APInt SignMask;
24026 if (N0.getValueType().isVector()) {
24027 // For vector, create a sign mask (0x80...) or its inverse (for fabs,
24028 // 0x7f...) per element and splat it.
24029 SignMask = APInt::getSignMask(N0.getScalarValueSizeInBits());
24030 if (IsFabs)
24031 SignMask = ~SignMask;
24032 SignMask = APInt::getSplat(IntVT.getSizeInBits(), SignMask);
24033 } else {
24034 // For scalar, just use the sign mask (0x80... or the inverse, 0x7f...)
24035 SignMask = APInt::getSignMask(IntVT.getSizeInBits());
24036 if (IsFabs)
24037 SignMask = ~SignMask;
24038 }
24039 SDLoc DL(N0);
24040 Int = DAG.getNode(IsFabs ? ISD::AND : ISD::XOR, DL, IntVT, Int,
24041 DAG.getConstant(SignMask, DL, IntVT));
24042 AddToWorklist(Int.getNode());
24043 return DAG.getBitcast(VT, Int);
24044 }
24045
/// Turn "(a cond b) ? 1.0f : 2.0f" into "load (tmp + ((a cond b) ? 0 : 4)"
/// where "tmp" is a constant pool entry containing an array with 1.0 and 2.0
/// in it. This may be a win when the constant is not otherwise available
/// because it replaces two constant pool loads with one.
SDValue DAGCombiner::convertSelectOfFPConstantsToLoadOffset(
    const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
    ISD::CondCode CC) {
  // Target must opt in for this FP-constant type.
  if (!TLI.reduceSelectOfFPConstantLoads(N0.getValueType()))
    return SDValue();

  // If we are before legalize types, we want the other legalization to happen
  // first (for example, to avoid messing with soft float).
  auto *TV = dyn_cast<ConstantFPSDNode>(N2);
  auto *FV = dyn_cast<ConstantFPSDNode>(N3);
  EVT VT = N2.getValueType();
  if (!TV || !FV || !TLI.isTypeLegal(VT))
    return SDValue();

  // If a constant can be materialized without loads, this does not make sense.
  if (TLI.getOperationAction(ISD::ConstantFP, VT) == TargetLowering::Legal ||
      TLI.isFPImmLegal(TV->getValueAPF(), TV->getValueType(0), ForCodeSize) ||
      TLI.isFPImmLegal(FV->getValueAPF(), FV->getValueType(0), ForCodeSize))
    return SDValue();

  // If both constants have multiple uses, then we won't need to do an extra
  // load. The values are likely around in registers for other users.
  if (!TV->hasOneUse() && !FV->hasOneUse())
    return SDValue();

  // Element 0 is the false value, element 1 the true value: a true condition
  // selects the non-zero byte offset (EltSize) below, i.e. element 1.
  Constant *Elts[] = { const_cast<ConstantFP*>(FV->getConstantFPValue()),
                       const_cast<ConstantFP*>(TV->getConstantFPValue()) };
  Type *FPTy = Elts[0]->getType();
  const DataLayout &TD = DAG.getDataLayout();

  // Create a ConstantArray of the two constants.
  Constant *CA = ConstantArray::get(ArrayType::get(FPTy, 2), Elts);
  SDValue CPIdx = DAG.getConstantPool(CA, TLI.getPointerTy(DAG.getDataLayout()),
                                      TD.getPrefTypeAlign(FPTy));
  // Remember the alignment the pool entry actually got for the final load.
  Align Alignment = cast<ConstantPoolSDNode>(CPIdx)->getAlign();

  // Get offsets to the 0 and 1 elements of the array, so we can select between
  // them.
  SDValue Zero = DAG.getIntPtrConstant(0, DL);
  unsigned EltSize = (unsigned)TD.getTypeAllocSize(Elts[0]->getType());
  SDValue One = DAG.getIntPtrConstant(EltSize, SDLoc(FV));
  SDValue Cond =
      DAG.getSetCC(DL, getSetCCResultType(N0.getValueType()), N0, N1, CC);
  AddToWorklist(Cond.getNode());
  SDValue CstOffset = DAG.getSelect(DL, Zero.getValueType(), Cond, One, Zero);
  AddToWorklist(CstOffset.getNode());
  CPIdx = DAG.getNode(ISD::ADD, DL, CPIdx.getValueType(), CPIdx, CstOffset);
  AddToWorklist(CPIdx.getNode());
  return DAG.getLoad(TV->getValueType(0), DL, DAG.getEntryNode(), CPIdx,
                     MachinePointerInfo::getConstantPool(
                         DAG.getMachineFunction()), Alignment);
}
24102
/// Simplify an expression of the form (N0 cond N1) ? N2 : N3
/// where 'cond' is the comparison specified by CC.
/// If \p NotExtCompare is true, the caller does not want the result expressed
/// as a zero-extended setcc (see the zext-of-compare fold below).
SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
                                      SDValue N2, SDValue N3, ISD::CondCode CC,
                                      bool NotExtCompare) {
  // (x ? y : y) -> y.
  if (N2 == N3) return N2;

  EVT CmpOpVT = N0.getValueType();
  EVT CmpResVT = getSetCCResultType(CmpOpVT);
  EVT VT = N2.getValueType();
  auto *N1C = dyn_cast<ConstantSDNode>(N1.getNode());
  auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
  auto *N3C = dyn_cast<ConstantSDNode>(N3.getNode());

  // Determine if the condition we're dealing with is constant.
  if (SDValue SCC = DAG.FoldSetCC(CmpResVT, N0, N1, CC, DL)) {
    AddToWorklist(SCC.getNode());
    if (auto *SCCC = dyn_cast<ConstantSDNode>(SCC)) {
      // fold select_cc true, x, y -> x
      // fold select_cc false, x, y -> y
      return !(SCCC->isZero()) ? N2 : N3;
    }
  }

  if (SDValue V =
          convertSelectOfFPConstantsToLoadOffset(DL, N0, N1, N2, N3, CC))
    return V;

  if (SDValue V = foldSelectCCToShiftAnd(DL, N0, N1, N2, N3, CC))
    return V;

  // fold (select_cc seteq (and x, y), 0, 0, A) -> (and (shr (shl x)) A)
  // where y has a single bit set.
  // A plaintext description would be, we can turn the SELECT_CC into an AND
  // when the condition can be materialized as an all-ones register. Any
  // single bit-test can be materialized as an all-ones register with
  // shift-left and shift-right-arith.
  if (CC == ISD::SETEQ && N0->getOpcode() == ISD::AND &&
      N0->getValueType(0) == VT && isNullConstant(N1) && isNullConstant(N2)) {
    SDValue AndLHS = N0->getOperand(0);
    auto *ConstAndRHS = dyn_cast<ConstantSDNode>(N0->getOperand(1));
    if (ConstAndRHS && ConstAndRHS->getAPIntValue().countPopulation() == 1) {
      // Shift the tested bit over the sign bit.
      const APInt &AndMask = ConstAndRHS->getAPIntValue();
      unsigned ShCt = AndMask.getBitWidth() - 1;
      if (!TLI.shouldAvoidTransformToShift(VT, ShCt)) {
        SDValue ShlAmt =
            DAG.getConstant(AndMask.countLeadingZeros(), SDLoc(AndLHS),
                            getShiftAmountTy(AndLHS.getValueType()));
        SDValue Shl = DAG.getNode(ISD::SHL, SDLoc(N0), VT, AndLHS, ShlAmt);

        // Now arithmetic right shift it all the way over, so the result is
        // either all-ones, or zero.
        SDValue ShrAmt =
            DAG.getConstant(ShCt, SDLoc(Shl),
                            getShiftAmountTy(Shl.getValueType()));
        SDValue Shr = DAG.getNode(ISD::SRA, SDLoc(N0), VT, Shl, ShrAmt);

        return DAG.getNode(ISD::AND, DL, VT, Shr, N3);
      }
    }
  }

  // fold select C, 16, 0 -> shl C, 4
  bool Fold = N2C && isNullConstant(N3) && N2C->getAPIntValue().isPowerOf2();
  bool Swap = N3C && isNullConstant(N2) && N3C->getAPIntValue().isPowerOf2();

  if ((Fold || Swap) &&
      TLI.getBooleanContents(CmpOpVT) ==
          TargetLowering::ZeroOrOneBooleanContent &&
      (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, CmpOpVT))) {

    if (Swap) {
      // Invert the condition so the power-of-2 constant becomes the true
      // value; after the swap the Fold path below applies.
      CC = ISD::getSetCCInverse(CC, CmpOpVT);
      std::swap(N2C, N3C);
    }

    // If the caller doesn't want us to simplify this into a zext of a compare,
    // don't do it.
    if (NotExtCompare && N2C->isOne())
      return SDValue();

    SDValue Temp, SCC;
    // zext (setcc n0, n1)
    if (LegalTypes) {
      SCC = DAG.getSetCC(DL, CmpResVT, N0, N1, CC);
      if (VT.bitsLT(SCC.getValueType()))
        Temp = DAG.getZeroExtendInReg(SCC, SDLoc(N2), VT);
      else
        Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2), VT, SCC);
    } else {
      SCC = DAG.getSetCC(SDLoc(N0), MVT::i1, N0, N1, CC);
      Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2), VT, SCC);
    }

    AddToWorklist(SCC.getNode());
    AddToWorklist(Temp.getNode());

    if (N2C->isOne())
      return Temp;

    unsigned ShCt = N2C->getAPIntValue().logBase2();
    if (TLI.shouldAvoidTransformToShift(VT, ShCt))
      return SDValue();

    // shl setcc result by log2 n2c
    return DAG.getNode(ISD::SHL, DL, N2.getValueType(), Temp,
                       DAG.getConstant(ShCt, SDLoc(Temp),
                                       getShiftAmountTy(Temp.getValueType())));
  }

  // select_cc seteq X, 0, sizeof(X), ctlz(X) -> ctlz(X)
  // select_cc seteq X, 0, sizeof(X), ctlz_zero_undef(X) -> ctlz(X)
  // select_cc seteq X, 0, sizeof(X), cttz(X) -> cttz(X)
  // select_cc seteq X, 0, sizeof(X), cttz_zero_undef(X) -> cttz(X)
  // select_cc setne X, 0, ctlz(X), sizeof(X) -> ctlz(X)
  // select_cc setne X, 0, ctlz_zero_undef(X), sizeof(X) -> ctlz(X)
  // select_cc setne X, 0, cttz(X), sizeof(X) -> cttz(X)
  // select_cc setne X, 0, cttz_zero_undef(X), sizeof(X) -> cttz(X)
  if (N1C && N1C->isZero() && (CC == ISD::SETEQ || CC == ISD::SETNE)) {
    SDValue ValueOnZero = N2;
    SDValue Count = N3;
    // If the condition is NE instead of E, swap the operands.
    if (CC == ISD::SETNE)
      std::swap(ValueOnZero, Count);
    // Check if the value on zero is a constant equal to the bits in the type.
    if (auto *ValueOnZeroC = dyn_cast<ConstantSDNode>(ValueOnZero)) {
      if (ValueOnZeroC->getAPIntValue() == VT.getSizeInBits()) {
        // If the other operand is cttz/cttz_zero_undef of N0, and cttz is
        // legal, combine to just cttz.
        if ((Count.getOpcode() == ISD::CTTZ ||
             Count.getOpcode() == ISD::CTTZ_ZERO_UNDEF) &&
            N0 == Count.getOperand(0) &&
            (!LegalOperations || TLI.isOperationLegal(ISD::CTTZ, VT)))
          return DAG.getNode(ISD::CTTZ, DL, VT, N0);
        // If the other operand is ctlz/ctlz_zero_undef of N0, and ctlz is
        // legal, combine to just ctlz.
        if ((Count.getOpcode() == ISD::CTLZ ||
             Count.getOpcode() == ISD::CTLZ_ZERO_UNDEF) &&
            N0 == Count.getOperand(0) &&
            (!LegalOperations || TLI.isOperationLegal(ISD::CTLZ, VT)))
          return DAG.getNode(ISD::CTLZ, DL, VT, N0);
      }
    }
  }

  // Fold select_cc setgt X, -1, C, ~C -> xor (ashr X, BW-1), C
  // Fold select_cc setlt X, 0, C, ~C -> xor (ashr X, BW-1), ~C
  if (!NotExtCompare && N1C && N2C && N3C &&
      N2C->getAPIntValue() == ~N3C->getAPIntValue() &&
      ((N1C->isAllOnes() && CC == ISD::SETGT) ||
       (N1C->isZero() && CC == ISD::SETLT)) &&
      !TLI.shouldAvoidTransformToShift(VT, CmpOpVT.getScalarSizeInBits() - 1)) {
    SDValue ASR = DAG.getNode(
        ISD::SRA, DL, CmpOpVT, N0,
        DAG.getConstant(CmpOpVT.getScalarSizeInBits() - 1, DL, CmpOpVT));
    return DAG.getNode(ISD::XOR, DL, VT, DAG.getSExtOrTrunc(ASR, DL, VT),
                       DAG.getSExtOrTrunc(CC == ISD::SETLT ? N3 : N2, DL, VT));
  }

  if (SDValue S = PerformMinMaxFpToSatCombine(N0, N1, N2, N3, CC, DAG))
    return S;
  if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N2, N3, CC, DAG))
    return S;

  return SDValue();
}
24271
24272 /// This is a stub for TargetLowering::SimplifySetCC.
SimplifySetCC(EVT VT,SDValue N0,SDValue N1,ISD::CondCode Cond,const SDLoc & DL,bool foldBooleans)24273 SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
24274 ISD::CondCode Cond, const SDLoc &DL,
24275 bool foldBooleans) {
24276 TargetLowering::DAGCombinerInfo
24277 DagCombineInfo(DAG, Level, false, this);
24278 return TLI.SimplifySetCC(VT, N0, N1, Cond, foldBooleans, DagCombineInfo, DL);
24279 }
24280
24281 /// Given an ISD::SDIV node expressing a divide by constant, return
24282 /// a DAG expression to select that will generate the same value by multiplying
24283 /// by a magic number.
24284 /// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
BuildSDIV(SDNode * N)24285 SDValue DAGCombiner::BuildSDIV(SDNode *N) {
24286 // when optimising for minimum size, we don't want to expand a div to a mul
24287 // and a shift.
24288 if (DAG.getMachineFunction().getFunction().hasMinSize())
24289 return SDValue();
24290
24291 SmallVector<SDNode *, 8> Built;
24292 if (SDValue S = TLI.BuildSDIV(N, DAG, LegalOperations, Built)) {
24293 for (SDNode *N : Built)
24294 AddToWorklist(N);
24295 return S;
24296 }
24297
24298 return SDValue();
24299 }
24300
24301 /// Given an ISD::SDIV node expressing a divide by constant power of 2, return a
24302 /// DAG expression that will generate the same value by right shifting.
BuildSDIVPow2(SDNode * N)24303 SDValue DAGCombiner::BuildSDIVPow2(SDNode *N) {
24304 ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1));
24305 if (!C)
24306 return SDValue();
24307
24308 // Avoid division by zero.
24309 if (C->isZero())
24310 return SDValue();
24311
24312 SmallVector<SDNode *, 8> Built;
24313 if (SDValue S = TLI.BuildSDIVPow2(N, C->getAPIntValue(), DAG, Built)) {
24314 for (SDNode *N : Built)
24315 AddToWorklist(N);
24316 return S;
24317 }
24318
24319 return SDValue();
24320 }
24321
24322 /// Given an ISD::UDIV node expressing a divide by constant, return a DAG
24323 /// expression that will generate the same value by multiplying by a magic
24324 /// number.
24325 /// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
BuildUDIV(SDNode * N)24326 SDValue DAGCombiner::BuildUDIV(SDNode *N) {
24327 // when optimising for minimum size, we don't want to expand a div to a mul
24328 // and a shift.
24329 if (DAG.getMachineFunction().getFunction().hasMinSize())
24330 return SDValue();
24331
24332 SmallVector<SDNode *, 8> Built;
24333 if (SDValue S = TLI.BuildUDIV(N, DAG, LegalOperations, Built)) {
24334 for (SDNode *N : Built)
24335 AddToWorklist(N);
24336 return S;
24337 }
24338
24339 return SDValue();
24340 }
24341
24342 /// Given an ISD::SREM node expressing a remainder by constant power of 2,
24343 /// return a DAG expression that will generate the same value.
BuildSREMPow2(SDNode * N)24344 SDValue DAGCombiner::BuildSREMPow2(SDNode *N) {
24345 ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1));
24346 if (!C)
24347 return SDValue();
24348
24349 // Avoid division by zero.
24350 if (C->isZero())
24351 return SDValue();
24352
24353 SmallVector<SDNode *, 8> Built;
24354 if (SDValue S = TLI.BuildSREMPow2(N, C->getAPIntValue(), DAG, Built)) {
24355 for (SDNode *N : Built)
24356 AddToWorklist(N);
24357 return S;
24358 }
24359
24360 return SDValue();
24361 }
24362
24363 /// Determines the LogBase2 value for a non-null input value using the
24364 /// transform: LogBase2(V) = (EltBits - 1) - ctlz(V).
BuildLogBase2(SDValue V,const SDLoc & DL)24365 SDValue DAGCombiner::BuildLogBase2(SDValue V, const SDLoc &DL) {
24366 EVT VT = V.getValueType();
24367 SDValue Ctlz = DAG.getNode(ISD::CTLZ, DL, VT, V);
24368 SDValue Base = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
24369 SDValue LogBase2 = DAG.getNode(ISD::SUB, DL, VT, Base, Ctlz);
24370 return LogBase2;
24371 }
24372
24373 /// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
24374 /// For the reciprocal, we need to find the zero of the function:
24375 /// F(X) = 1/X - A [which has a zero at X = 1/A]
24376 /// =>
24377 /// X_{i+1} = X_i (2 - A X_i) = X_i + X_i (1 - A X_i) [this second form
24378 /// does not require additional intermediate precision]
24379 /// For the last iteration, put numerator N into it to gain more precision:
24380 /// Result = N X_i + X_i (N - N A X_i)
SDValue DAGCombiner::BuildDivEstimate(SDValue N, SDValue Op,
                                      SDNodeFlags Flags) {
  // Estimate expansion runs only before full legalization, since it freely
  // creates new FP nodes.
  if (LegalDAG)
    return SDValue();

  // Only the standard scalar FP types are supported.
  // TODO: Handle extended types?
  EVT VT = Op.getValueType();
  if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
      VT.getScalarType() != MVT::f64)
    return SDValue();

  // If estimates are explicitly disabled for this function, we're done.
  MachineFunction &MF = DAG.getMachineFunction();
  int Enabled = TLI.getRecipEstimateDivEnabled(VT, MF);
  if (Enabled == TLI.ReciprocalEstimate::Disabled)
    return SDValue();

  // Estimates may be explicitly enabled for this type with a custom number of
  // refinement steps.
  int Iterations = TLI.getDivRefinementSteps(VT, MF);
  if (SDValue Est = TLI.getRecipEstimate(Op, DAG, Enabled, Iterations)) {
    AddToWorklist(Est.getNode());

    SDLoc DL(Op);
    if (Iterations) {
      SDValue FPOne = DAG.getConstantFP(1.0, DL, VT);

      // Newton iterations: Est = Est + Est (N - Arg * Est)
      // If this is the last iteration, also multiply by the numerator.
      for (int i = 0; i < Iterations; ++i) {
        SDValue MulEst = Est;

        if (i == Iterations - 1) {
          // Fold the numerator N into the final step:
          //   Result = N * Est + Est * (N - Op * N * Est)
          // which gains precision versus a separate trailing multiply.
          MulEst = DAG.getNode(ISD::FMUL, DL, VT, N, Est, Flags);
          AddToWorklist(MulEst.getNode());
        }

        // NewEst = Op * MulEst, then (1 - NewEst) (or (N - NewEst) on the
        // last step), then scale by the current estimate.
        SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Op, MulEst, Flags);
        AddToWorklist(NewEst.getNode());

        NewEst = DAG.getNode(ISD::FSUB, DL, VT,
                             (i == Iterations - 1 ? N : FPOne), NewEst, Flags);
        AddToWorklist(NewEst.getNode());

        NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags);
        AddToWorklist(NewEst.getNode());

        Est = DAG.getNode(ISD::FADD, DL, VT, MulEst, NewEst, Flags);
        AddToWorklist(Est.getNode());
      }
    } else {
      // If no iterations are available, multiply with N.
      Est = DAG.getNode(ISD::FMUL, DL, VT, Est, N, Flags);
      AddToWorklist(Est.getNode());
    }

    return Est;
  }

  return SDValue();
}
24442
24443 /// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
24444 /// For the reciprocal sqrt, we need to find the zero of the function:
24445 /// F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
24446 /// =>
24447 /// X_{i+1} = X_i (1.5 - A X_i^2 / 2)
24448 /// As a result, we precompute A/2 prior to the iteration loop.
SDValue DAGCombiner::buildSqrtNROneConst(SDValue Arg, SDValue Est,
                                         unsigned Iterations,
                                         SDNodeFlags Flags, bool Reciprocal) {
  // Refine Est (an estimate of 1/sqrt(Arg)) with 'Iterations' Newton steps
  // using the one-constant form; if !Reciprocal, convert to sqrt(Arg) at the
  // end via sqrt(Arg) = Arg * (1/sqrt(Arg)).
  EVT VT = Arg.getValueType();
  SDLoc DL(Arg);
  SDValue ThreeHalves = DAG.getConstantFP(1.5, DL, VT);

  // We now need 0.5 * Arg which we can write as (1.5 * Arg - Arg) so that
  // this entire sequence requires only one FP constant.
  SDValue HalfArg = DAG.getNode(ISD::FMUL, DL, VT, ThreeHalves, Arg, Flags);
  HalfArg = DAG.getNode(ISD::FSUB, DL, VT, HalfArg, Arg, Flags);

  // Newton iterations: Est = Est * (1.5 - HalfArg * Est * Est)
  for (unsigned i = 0; i < Iterations; ++i) {
    SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, Est, Flags);
    NewEst = DAG.getNode(ISD::FMUL, DL, VT, HalfArg, NewEst, Flags);
    NewEst = DAG.getNode(ISD::FSUB, DL, VT, ThreeHalves, NewEst, Flags);
    Est = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags);
  }

  // If non-reciprocal square root is requested, multiply the result by Arg.
  if (!Reciprocal)
    Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Arg, Flags);

  return Est;
}
24475
24476 /// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
24477 /// For the reciprocal sqrt, we need to find the zero of the function:
24478 /// F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
24479 /// =>
24480 /// X_{i+1} = (-0.5 * X_i) * (A * X_i * X_i + (-3.0))
SDValue DAGCombiner::buildSqrtNRTwoConst(SDValue Arg, SDValue Est,
                                         unsigned Iterations,
                                         SDNodeFlags Flags, bool Reciprocal) {
  // Refine Est (an estimate of 1/sqrt(Arg)) with 'Iterations' Newton steps
  // using the two-constant form. For !Reciprocal, the conversion to
  // sqrt(Arg) is folded into the last loop iteration (see below).
  EVT VT = Arg.getValueType();
  SDLoc DL(Arg);
  SDValue MinusThree = DAG.getConstantFP(-3.0, DL, VT);
  SDValue MinusHalf = DAG.getConstantFP(-0.5, DL, VT);

  // This routine must enter the loop below to work correctly
  // when (Reciprocal == false): the Arg multiply for the sqrt result only
  // happens inside the final iteration.
  assert(Iterations > 0);

  // Newton iterations for reciprocal square root:
  // E = (E * -0.5) * ((A * E) * E + -3.0)
  for (unsigned i = 0; i < Iterations; ++i) {
    SDValue AE = DAG.getNode(ISD::FMUL, DL, VT, Arg, Est, Flags);
    SDValue AEE = DAG.getNode(ISD::FMUL, DL, VT, AE, Est, Flags);
    SDValue RHS = DAG.getNode(ISD::FADD, DL, VT, AEE, MinusThree, Flags);

    // When calculating a square root at the last iteration build:
    // S = ((A * E) * -0.5) * ((A * E) * E + -3.0)
    // (notice a common subexpression)
    SDValue LHS;
    if (Reciprocal || (i + 1) < Iterations) {
      // RSQRT: LHS = (E * -0.5)
      LHS = DAG.getNode(ISD::FMUL, DL, VT, Est, MinusHalf, Flags);
    } else {
      // SQRT: LHS = (A * E) * -0.5
      LHS = DAG.getNode(ISD::FMUL, DL, VT, AE, MinusHalf, Flags);
    }

    Est = DAG.getNode(ISD::FMUL, DL, VT, LHS, RHS, Flags);
  }

  return Est;
}
24517
24518 /// Build code to calculate either rsqrt(Op) or sqrt(Op). In the latter case
24519 /// Op*rsqrt(Op) is actually computed, so additional postprocessing is needed if
24520 /// Op can be zero.
SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
                                           bool Reciprocal) {
  // Estimate expansion runs only before full legalization, since it freely
  // creates new FP nodes.
  if (LegalDAG)
    return SDValue();

  // Only the standard scalar FP types are supported.
  // TODO: Handle extended types?
  EVT VT = Op.getValueType();
  if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
      VT.getScalarType() != MVT::f64)
    return SDValue();

  // If estimates are explicitly disabled for this function, we're done.
  MachineFunction &MF = DAG.getMachineFunction();
  int Enabled = TLI.getRecipEstimateSqrtEnabled(VT, MF);
  if (Enabled == TLI.ReciprocalEstimate::Disabled)
    return SDValue();

  // Estimates may be explicitly enabled for this type with a custom number of
  // refinement steps.
  int Iterations = TLI.getSqrtRefinementSteps(VT, MF);

  // The target picks the estimate opcode and tells us which Newton variant
  // (one-constant or two-constant) to refine it with.
  bool UseOneConstNR = false;
  if (SDValue Est =
      TLI.getSqrtEstimate(Op, DAG, Enabled, Iterations, UseOneConstNR,
                          Reciprocal)) {
    AddToWorklist(Est.getNode());

    if (Iterations)
      Est = UseOneConstNR
            ? buildSqrtNROneConst(Op, Est, Iterations, Flags, Reciprocal)
            : buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal);
    if (!Reciprocal) {
      SDLoc DL(Op);
      // Try the target specific test first.
      SDValue Test = TLI.getSqrtInputTest(Op, DAG, DAG.getDenormalMode(VT));

      // The estimate is now completely wrong if the input was exactly 0.0 or
      // possibly a denormal. Force the answer to 0.0 or value provided by
      // target for those cases.
      Est = DAG.getNode(
          Test.getValueType().isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT,
          Test, TLI.getSqrtResultForDenormInput(Op, DAG), Est);
    }
    return Est;
  }

  return SDValue();
}
24569
buildRsqrtEstimate(SDValue Op,SDNodeFlags Flags)24570 SDValue DAGCombiner::buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags) {
24571 return buildSqrtEstimateImpl(Op, Flags, true);
24572 }
24573
buildSqrtEstimate(SDValue Op,SDNodeFlags Flags)24574 SDValue DAGCombiner::buildSqrtEstimate(SDValue Op, SDNodeFlags Flags) {
24575 return buildSqrtEstimateImpl(Op, Flags, false);
24576 }
24577
24578 /// Return true if there is any possibility that the two addresses overlap.
mayAlias(SDNode * Op0,SDNode * Op1) const24579 bool DAGCombiner::mayAlias(SDNode *Op0, SDNode *Op1) const {
24580
24581 struct MemUseCharacteristics {
24582 bool IsVolatile;
24583 bool IsAtomic;
24584 SDValue BasePtr;
24585 int64_t Offset;
24586 Optional<int64_t> NumBytes;
24587 MachineMemOperand *MMO;
24588 };
24589
24590 auto getCharacteristics = [](SDNode *N) -> MemUseCharacteristics {
24591 if (const auto *LSN = dyn_cast<LSBaseSDNode>(N)) {
24592 int64_t Offset = 0;
24593 if (auto *C = dyn_cast<ConstantSDNode>(LSN->getOffset()))
24594 Offset = (LSN->getAddressingMode() == ISD::PRE_INC)
24595 ? C->getSExtValue()
24596 : (LSN->getAddressingMode() == ISD::PRE_DEC)
24597 ? -1 * C->getSExtValue()
24598 : 0;
24599 uint64_t Size =
24600 MemoryLocation::getSizeOrUnknown(LSN->getMemoryVT().getStoreSize());
24601 return {LSN->isVolatile(), LSN->isAtomic(), LSN->getBasePtr(),
24602 Offset /*base offset*/,
24603 Optional<int64_t>(Size),
24604 LSN->getMemOperand()};
24605 }
24606 if (const auto *LN = cast<LifetimeSDNode>(N))
24607 return {false /*isVolatile*/, /*isAtomic*/ false, LN->getOperand(1),
24608 (LN->hasOffset()) ? LN->getOffset() : 0,
24609 (LN->hasOffset()) ? Optional<int64_t>(LN->getSize())
24610 : Optional<int64_t>(),
24611 (MachineMemOperand *)nullptr};
24612 // Default.
24613 return {false /*isvolatile*/, /*isAtomic*/ false, SDValue(),
24614 (int64_t)0 /*offset*/,
24615 Optional<int64_t>() /*size*/, (MachineMemOperand *)nullptr};
24616 };
24617
24618 MemUseCharacteristics MUC0 = getCharacteristics(Op0),
24619 MUC1 = getCharacteristics(Op1);
24620
24621 // If they are to the same address, then they must be aliases.
24622 if (MUC0.BasePtr.getNode() && MUC0.BasePtr == MUC1.BasePtr &&
24623 MUC0.Offset == MUC1.Offset)
24624 return true;
24625
24626 // If they are both volatile then they cannot be reordered.
24627 if (MUC0.IsVolatile && MUC1.IsVolatile)
24628 return true;
24629
24630 // Be conservative about atomics for the moment
24631 // TODO: This is way overconservative for unordered atomics (see D66309)
24632 if (MUC0.IsAtomic && MUC1.IsAtomic)
24633 return true;
24634
24635 if (MUC0.MMO && MUC1.MMO) {
24636 if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
24637 (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
24638 return false;
24639 }
24640
24641 // Try to prove that there is aliasing, or that there is no aliasing. Either
24642 // way, we can return now. If nothing can be proved, proceed with more tests.
24643 bool IsAlias;
24644 if (BaseIndexOffset::computeAliasing(Op0, MUC0.NumBytes, Op1, MUC1.NumBytes,
24645 DAG, IsAlias))
24646 return IsAlias;
24647
24648 // The following all rely on MMO0 and MMO1 being valid. Fail conservatively if
24649 // either are not known.
24650 if (!MUC0.MMO || !MUC1.MMO)
24651 return true;
24652
24653 // If one operation reads from invariant memory, and the other may store, they
24654 // cannot alias. These should really be checking the equivalent of mayWrite,
24655 // but it only matters for memory nodes other than load /store.
24656 if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
24657 (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
24658 return false;
24659
24660 // If we know required SrcValue1 and SrcValue2 have relatively large
24661 // alignment compared to the size and offset of the access, we may be able
24662 // to prove they do not alias. This check is conservative for now to catch
24663 // cases created by splitting vector types, it only works when the offsets are
24664 // multiples of the size of the data.
24665 int64_t SrcValOffset0 = MUC0.MMO->getOffset();
24666 int64_t SrcValOffset1 = MUC1.MMO->getOffset();
24667 Align OrigAlignment0 = MUC0.MMO->getBaseAlign();
24668 Align OrigAlignment1 = MUC1.MMO->getBaseAlign();
24669 auto &Size0 = MUC0.NumBytes;
24670 auto &Size1 = MUC1.NumBytes;
24671 if (OrigAlignment0 == OrigAlignment1 && SrcValOffset0 != SrcValOffset1 &&
24672 Size0.has_value() && Size1.has_value() && *Size0 == *Size1 &&
24673 OrigAlignment0 > *Size0 && SrcValOffset0 % *Size0 == 0 &&
24674 SrcValOffset1 % *Size1 == 0) {
24675 int64_t OffAlign0 = SrcValOffset0 % OrigAlignment0.value();
24676 int64_t OffAlign1 = SrcValOffset1 % OrigAlignment1.value();
24677
24678 // There is no overlap between these relatively aligned accesses of
24679 // similar size. Return no alias.
24680 if ((OffAlign0 + *Size0) <= OffAlign1 || (OffAlign1 + *Size1) <= OffAlign0)
24681 return false;
24682 }
24683
24684 bool UseAA = CombinerGlobalAA.getNumOccurrences() > 0
24685 ? CombinerGlobalAA
24686 : DAG.getSubtarget().useAA();
24687 #ifndef NDEBUG
24688 if (CombinerAAOnlyFunc.getNumOccurrences() &&
24689 CombinerAAOnlyFunc != DAG.getMachineFunction().getName())
24690 UseAA = false;
24691 #endif
24692
24693 if (UseAA && AA && MUC0.MMO->getValue() && MUC1.MMO->getValue() && Size0 &&
24694 Size1) {
24695 // Use alias analysis information.
24696 int64_t MinOffset = std::min(SrcValOffset0, SrcValOffset1);
24697 int64_t Overlap0 = *Size0 + SrcValOffset0 - MinOffset;
24698 int64_t Overlap1 = *Size1 + SrcValOffset1 - MinOffset;
24699 if (AA->isNoAlias(
24700 MemoryLocation(MUC0.MMO->getValue(), Overlap0,
24701 UseTBAA ? MUC0.MMO->getAAInfo() : AAMDNodes()),
24702 MemoryLocation(MUC1.MMO->getValue(), Overlap1,
24703 UseTBAA ? MUC1.MMO->getAAInfo() : AAMDNodes())))
24704 return false;
24705 }
24706
24707 // Otherwise we have to assume they alias.
24708 return true;
24709 }
24710
24711 /// Walk up chain skipping non-aliasing memory nodes,
24712 /// looking for aliasing nodes and adding them to the Aliases vector.
void DAGCombiner::GatherAllAliases(SDNode *N, SDValue OriginalChain,
                                   SmallVectorImpl<SDValue> &Aliases) {
  SmallVector<SDValue, 8> Chains; // List of chains to visit.
  SmallPtrSet<SDNode *, 16> Visited; // Visited node set.

  // Get alias information for node.
  // TODO: relax aliasing for unordered atomics (see D66309)
  const bool IsLoad = isa<LoadSDNode>(N) && cast<LoadSDNode>(N)->isSimple();

  // Starting off.
  Chains.push_back(OriginalChain);
  // Depth counts the number of chain-walking steps taken so far, shared
  // across all worklist entries; it bounds the total work done.
  unsigned Depth = 0;

  // Attempt to improve chain by a single step: returns true if C was
  // advanced past a non-aliasing node (C is updated in place; an exhausted
  // chain becomes a null SDValue), false if C must be treated as an alias.
  auto ImproveChain = [&](SDValue &C) -> bool {
    switch (C.getOpcode()) {
    case ISD::EntryToken:
      // No need to mark EntryToken.
      C = SDValue();
      return true;
    case ISD::LOAD:
    case ISD::STORE: {
      // Get alias information for C.
      // TODO: Relax aliasing for unordered atomics (see D66309)
      bool IsOpLoad = isa<LoadSDNode>(C.getNode()) &&
                      cast<LSBaseSDNode>(C.getNode())->isSimple();
      // Two simple loads never conflict, so a load can always step past
      // another load; otherwise we must prove no aliasing.
      if ((IsLoad && IsOpLoad) || !mayAlias(N, C.getNode())) {
        // Look further up the chain.
        C = C.getOperand(0);
        return true;
      }
      // Alias, so stop here.
      return false;
    }

    case ISD::CopyFromReg:
      // Always forward past CopyFromReg.
      C = C.getOperand(0);
      return true;

    case ISD::LIFETIME_START:
    case ISD::LIFETIME_END: {
      // We can forward past any lifetime start/end that can be proven not to
      // alias the memory access.
      if (!mayAlias(N, C.getNode())) {
        // Look further up the chain.
        C = C.getOperand(0);
        return true;
      }
      return false;
    }
    default:
      // Unknown chain producer: conservatively treat it as an alias.
      return false;
    }
  };

  // Look at each chain and determine if it is an alias. If so, add it to the
  // aliases list. If not, then continue up the chain looking for the next
  // candidate.
  while (!Chains.empty()) {
    SDValue Chain = Chains.pop_back_val();

    // Don't bother if we've seen Chain before.
    if (!Visited.insert(Chain.getNode()).second)
      continue;

    // For TokenFactor nodes, look at each operand and only continue up the
    // chain until we reach the depth limit.
    //
    // FIXME: The depth check could be made to return the last non-aliasing
    // chain we found before we hit a tokenfactor rather than the original
    // chain.
    if (Depth > TLI.getGatherAllAliasesMaxDepth()) {
      // Give up: fall back to the original chain as the sole alias.
      Aliases.clear();
      Aliases.push_back(OriginalChain);
      return;
    }

    if (Chain.getOpcode() == ISD::TokenFactor) {
      // We have to check each of the operands of the token factor for "small"
      // token factors, so we queue them up. Adding the operands to the queue
      // (stack) in reverse order maintains the original order and increases the
      // likelihood that getNode will find a matching token factor (CSE.)
      if (Chain.getNumOperands() > 16) {
        Aliases.push_back(Chain);
        continue;
      }
      for (unsigned n = Chain.getNumOperands(); n;)
        Chains.push_back(Chain.getOperand(--n));
      ++Depth;
      continue;
    }
    // Everything else
    if (ImproveChain(Chain)) {
      // Updated Chain Found, Consider new chain if one exists.
      if (Chain.getNode())
        Chains.push_back(Chain);
      ++Depth;
      continue;
    }
    // No Improved Chain Possible, treat as Alias.
    Aliases.push_back(Chain);
  }
}
24817
24818 /// Walk up chain skipping non-aliasing memory nodes, looking for a better chain
24819 /// (aliasing node.)
FindBetterChain(SDNode * N,SDValue OldChain)24820 SDValue DAGCombiner::FindBetterChain(SDNode *N, SDValue OldChain) {
24821 if (OptLevel == CodeGenOpt::None)
24822 return OldChain;
24823
24824 // Ops for replacing token factor.
24825 SmallVector<SDValue, 8> Aliases;
24826
24827 // Accumulate all the aliases to this node.
24828 GatherAllAliases(N, OldChain, Aliases);
24829
24830 // If no operands then chain to entry token.
24831 if (Aliases.size() == 0)
24832 return DAG.getEntryNode();
24833
24834 // If a single operand then chain to it. We don't need to revisit it.
24835 if (Aliases.size() == 1)
24836 return Aliases[0];
24837
24838 // Construct a custom tailored token factor.
24839 return DAG.getTokenFactor(SDLoc(N), Aliases);
24840 }
24841
namespace {
// Empty payload type for the IntervalMap in parallelizeChainedStores: only
// interval coverage matters, not any per-interval value, so all UnitT values
// compare equal.
// TODO: Replace with std::monostate when we move to C++17.
struct UnitT {};
UnitT Unit;
bool operator==(const UnitT &, const UnitT &) { return true; }
bool operator!=(const UnitT &, const UnitT &) { return false; }
} // namespace
24848
24849 // This function tries to collect a bunch of potentially interesting
24850 // nodes to improve the chains of, all at once. This might seem
24851 // redundant, as this function gets called when visiting every store
24852 // node, so why not let the work be done on each store as it's visited?
24853 //
24854 // I believe this is mainly important because mergeConsecutiveStores
24855 // is unable to deal with merging stores of different sizes, so unless
24856 // we improve the chains of all the potential candidates up-front
24857 // before running mergeConsecutiveStores, it might only see some of
24858 // the nodes that will eventually be candidates, and then not be able
24859 // to go from a partially-merged state to the desired final
24860 // fully-merged state.
24861
bool DAGCombiner::parallelizeChainedStores(StoreSDNode *St) {
  SmallVector<StoreSDNode *, 8> ChainedStores;
  StoreSDNode *STChain = St;
  // Intervals records which offsets from BaseIndex have been covered. In
  // the common case, every store writes to the immediately previous address
  // space and thus merged with the previous interval at insertion time.

  using IMap =
      llvm::IntervalMap<int64_t, UnitT, 8, IntervalMapHalfOpenInfo<int64_t>>;
  IMap::Allocator A;
  IMap Intervals(A);

  // This holds the base pointer, index, and the offset in bytes from the base
  // pointer.
  const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);

  // We must have a base and an offset.
  if (!BasePtr.getBase().getNode())
    return false;

  // Do not handle stores to undef base pointers.
  if (BasePtr.getBase().isUndef())
    return false;

  // Do not handle stores to opaque types
  if (St->getMemoryVT().isZeroSized())
    return false;

  // BaseIndexOffset assumes that offsets are fixed-size, which
  // is not valid for scalable vectors where the offsets are
  // scaled by `vscale`, so bail out early.
  if (St->getMemoryVT().isScalableVector())
    return false;

  // Add ST's interval: [0, store size in bytes), rounding bits up to bytes.
  Intervals.insert(0, (St->getMemoryVT().getSizeInBits() + 7) / 8, Unit);

  // Walk up the chain collecting disjoint stores off the same base pointer.
  while (StoreSDNode *Chain = dyn_cast<StoreSDNode>(STChain->getChain())) {
    if (Chain->getMemoryVT().isScalableVector())
      return false;

    // If the chain has more than one use, then we can't reorder the mem ops.
    if (!SDValue(Chain, 0)->hasOneUse())
      break;
    // TODO: Relax for unordered atomics (see D66309)
    if (!Chain->isSimple() || Chain->isIndexed())
      break;

    // Find the base pointer and offset for this memory node.
    const BaseIndexOffset Ptr = BaseIndexOffset::match(Chain, DAG);
    // Check that the base pointer is the same as the original one.
    int64_t Offset;
    if (!BasePtr.equalBaseIndex(Ptr, DAG, Offset))
      break;
    int64_t Length = (Chain->getMemoryVT().getSizeInBits() + 7) / 8;
    // Make sure we don't overlap with other intervals by checking the ones to
    // the left or right before inserting.
    auto I = Intervals.find(Offset);
    // If there's a next interval, we should end before it.
    if (I != Intervals.end() && I.start() < (Offset + Length))
      break;
    // If there's a previous interval, we should start after it.
    // NOTE(review): given IntervalMap::find semantics, any interval
    // preceding the found position ends at or before Offset, so this check
    // conservatively stops the scan whenever the new store does not lie
    // before all intervals collected so far — confirm this is intended.
    if (I != Intervals.begin() && (--I).stop() <= Offset)
      break;
    Intervals.insert(Offset, Offset + Length, Unit);

    ChainedStores.push_back(Chain);
    STChain = Chain;
  }

  // If we didn't find a chained store, exit.
  if (ChainedStores.size() == 0)
    return false;

  // Improve all chained stores (St and ChainedStores members) starting from
  // where the store chain ended and return single TokenFactor.
  SDValue NewChain = STChain->getChain();
  SmallVector<SDValue, 8> TFOps;
  for (unsigned I = ChainedStores.size(); I;) {
    StoreSDNode *S = ChainedStores[--I];
    SDValue BetterChain = FindBetterChain(S, NewChain);
    // Rewrite the store in place with its improved chain (other operands
    // unchanged: value, pointer, offset).
    S = cast<StoreSDNode>(DAG.UpdateNodeOperands(
        S, BetterChain, S->getOperand(1), S->getOperand(2), S->getOperand(3)));
    TFOps.push_back(SDValue(S, 0));
    ChainedStores[I] = S;
  }

  // Improve St's chain. Use a new node to avoid creating a loop from CombineTo.
  SDValue BetterChain = FindBetterChain(St, NewChain);
  SDValue NewST;
  if (St->isTruncatingStore())
    NewST = DAG.getTruncStore(BetterChain, SDLoc(St), St->getValue(),
                              St->getBasePtr(), St->getMemoryVT(),
                              St->getMemOperand());
  else
    NewST = DAG.getStore(BetterChain, SDLoc(St), St->getValue(),
                         St->getBasePtr(), St->getMemOperand());

  TFOps.push_back(NewST);

  // If we improved every element of TFOps, then we've lost the dependence on
  // NewChain to successors of St and we need to add it back to TFOps. Do so at
  // the beginning to keep relative order consistent with FindBetterChains.
  auto hasImprovedChain = [&](SDValue ST) -> bool {
    return ST->getOperand(0) != NewChain;
  };
  bool AddNewChain = llvm::all_of(TFOps, hasImprovedChain);
  if (AddNewChain)
    TFOps.insert(TFOps.begin(), NewChain);

  SDValue TF = DAG.getTokenFactor(SDLoc(STChain), TFOps);
  CombineTo(St, TF);

  // Add TF and its operands to the worklist.
  AddToWorklist(TF.getNode());
  for (const SDValue &Op : TF->ops())
    AddToWorklist(Op.getNode());
  AddToWorklist(STChain);
  return true;
}
24982
findBetterNeighborChains(StoreSDNode * St)24983 bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) {
24984 if (OptLevel == CodeGenOpt::None)
24985 return false;
24986
24987 const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
24988
24989 // We must have a base and an offset.
24990 if (!BasePtr.getBase().getNode())
24991 return false;
24992
24993 // Do not handle stores to undef base pointers.
24994 if (BasePtr.getBase().isUndef())
24995 return false;
24996
24997 // Directly improve a chain of disjoint stores starting at St.
24998 if (parallelizeChainedStores(St))
24999 return true;
25000
25001 // Improve St's Chain..
25002 SDValue BetterChain = FindBetterChain(St, St->getChain());
25003 if (St->getChain() != BetterChain) {
25004 replaceStoreChain(St, BetterChain);
25005 return true;
25006 }
25007 return false;
25008 }
25009
25010 /// This is the entry point for the file.
void SelectionDAG::Combine(CombineLevel Level, AliasAnalysis *AA,
                           CodeGenOpt::Level OptLevel) {
  // This is the main entry point to this class: construct a combiner over
  // this DAG (with the given alias analysis and optimization level) and run
  // it for the requested combine phase.
  DAGCombiner(*this, AA, OptLevel).Run(Level);
}
25016