1 //===- AMDGPULegalizerInfo.cpp -----------------------------------*- C++ -*-==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
/// This file implements the targeting of the MachineLegalizer class for
10 /// AMDGPU.
11 /// \todo This should be generated by TableGen.
12 //===----------------------------------------------------------------------===//
13 
14 #include "AMDGPU.h"
15 #include "AMDGPULegalizerInfo.h"
16 #include "AMDGPUTargetMachine.h"
17 #include "SIMachineFunctionInfo.h"
18 
19 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
20 #include "llvm/CodeGen/TargetOpcodes.h"
21 #include "llvm/CodeGen/ValueTypes.h"
22 #include "llvm/IR/DerivedTypes.h"
23 #include "llvm/IR/Type.h"
24 #include "llvm/Support/Debug.h"
25 
26 #define DEBUG_TYPE "amdgpu-legalinfo"
27 
28 using namespace llvm;
29 using namespace LegalizeActions;
30 using namespace LegalizeMutations;
31 using namespace LegalityPredicates;
32 
33 
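// True if the type's total size is at most MaxSize bits and its scalar size
// is a multiple of 32.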
34 static LegalityPredicate isMultiple32(unsigned TypeIdx,
35                                       unsigned MaxSize = 512) {
36   return [=](const LegalityQuery &Query) {
37     const LLT Ty = Query.Types[TypeIdx];
38     const LLT EltTy = Ty.getScalarType();
39     return Ty.getSizeInBits() <= MaxSize && EltTy.getSizeInBits() % 32 == 0;
40   };
41 }
42 
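// True for vectors with an odd number of elements narrower than 32 bits;
// such vectors are widened by one element (see oneMoreElement).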
43 static LegalityPredicate isSmallOddVector(unsigned TypeIdx) {
44   return [=](const LegalityQuery &Query) {
45     const LLT Ty = Query.Types[TypeIdx];
46     return Ty.isVector() &&
47            Ty.getNumElements() % 2 != 0 &&
48            Ty.getElementType().getSizeInBits() < 32;
49   };
50 }
51 
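// Grow a vector type by a single element, keeping the element type.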
52 static LegalizeMutation oneMoreElement(unsigned TypeIdx) {
53   return [=](const LegalityQuery &Query) {
54     const LLT Ty = Query.Types[TypeIdx];
55     const LLT EltTy = Ty.getElementType();
56     return std::make_pair(TypeIdx, LLT::vector(Ty.getNumElements() + 1, EltTy));
57   };
58 }
59 
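// Reduce the element count so each resulting piece is at most 64 bits wide.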
60 static LegalizeMutation fewerEltsToSize64Vector(unsigned TypeIdx) {
61   return [=](const LegalityQuery &Query) {
62     const LLT Ty = Query.Types[TypeIdx];
63     const LLT EltTy = Ty.getElementType();
64     unsigned Size = Ty.getSizeInBits();
65     unsigned Pieces = (Size + 63) / 64;
66     unsigned NewNumElts = (Ty.getNumElements() + 1) / Pieces;
67     return std::make_pair(TypeIdx, LLT::scalarOrVector(NewNumElts, EltTy));
68   };
69 }
70 
71 static LegalityPredicate vectorWiderThan(unsigned TypeIdx, unsigned Size) {
72   return [=](const LegalityQuery &Query) {
73     const LLT QueryTy = Query.Types[TypeIdx];
74     return QueryTy.isVector() && QueryTy.getSizeInBits() > Size;
75   };
76 }
77 
78 static LegalityPredicate numElementsNotEven(unsigned TypeIdx) {
79   return [=](const LegalityQuery &Query) {
80     const LLT QueryTy = Query.Types[TypeIdx];
81     return QueryTy.isVector() && QueryTy.getNumElements() % 2 != 0;
82   };
83 }
84 
85 AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
86                                          const GCNTargetMachine &TM)
87   :  ST(ST_) {
88   using namespace TargetOpcode;
89 
90   auto GetAddrSpacePtr = [&TM](unsigned AS) {
91     return LLT::pointer(AS, TM.getPointerSizeInBits(AS));
92   };
93 
94   const LLT S1 = LLT::scalar(1);
95   const LLT S8 = LLT::scalar(8);
96   const LLT S16 = LLT::scalar(16);
97   const LLT S32 = LLT::scalar(32);
98   const LLT S64 = LLT::scalar(64);
99   const LLT S128 = LLT::scalar(128);
100   const LLT S256 = LLT::scalar(256);
101   const LLT S512 = LLT::scalar(512);
102 
103   const LLT V2S16 = LLT::vector(2, 16);
104   const LLT V4S16 = LLT::vector(4, 16);
105   const LLT V8S16 = LLT::vector(8, 16);
106 
107   const LLT V2S32 = LLT::vector(2, 32);
108   const LLT V3S32 = LLT::vector(3, 32);
109   const LLT V4S32 = LLT::vector(4, 32);
110   const LLT V5S32 = LLT::vector(5, 32);
111   const LLT V6S32 = LLT::vector(6, 32);
112   const LLT V7S32 = LLT::vector(7, 32);
113   const LLT V8S32 = LLT::vector(8, 32);
114   const LLT V9S32 = LLT::vector(9, 32);
115   const LLT V10S32 = LLT::vector(10, 32);
116   const LLT V11S32 = LLT::vector(11, 32);
117   const LLT V12S32 = LLT::vector(12, 32);
118   const LLT V13S32 = LLT::vector(13, 32);
119   const LLT V14S32 = LLT::vector(14, 32);
120   const LLT V15S32 = LLT::vector(15, 32);
121   const LLT V16S32 = LLT::vector(16, 32);
122 
123   const LLT V2S64 = LLT::vector(2, 64);
124   const LLT V3S64 = LLT::vector(3, 64);
125   const LLT V4S64 = LLT::vector(4, 64);
126   const LLT V5S64 = LLT::vector(5, 64);
127   const LLT V6S64 = LLT::vector(6, 64);
128   const LLT V7S64 = LLT::vector(7, 64);
129   const LLT V8S64 = LLT::vector(8, 64);
130 
131   std::initializer_list<LLT> AllS32Vectors =
132     {V2S32, V3S32, V4S32, V5S32, V6S32, V7S32, V8S32,
133      V9S32, V10S32, V11S32, V12S32, V13S32, V14S32, V15S32, V16S32};
134   std::initializer_list<LLT> AllS64Vectors =
135     {V2S64, V3S64, V4S64, V5S64, V6S64, V7S64, V8S64};
136 
137   const LLT GlobalPtr = GetAddrSpacePtr(AMDGPUAS::GLOBAL_ADDRESS);
138   const LLT ConstantPtr = GetAddrSpacePtr(AMDGPUAS::CONSTANT_ADDRESS);
139   const LLT LocalPtr = GetAddrSpacePtr(AMDGPUAS::LOCAL_ADDRESS);
140   const LLT FlatPtr = GetAddrSpacePtr(AMDGPUAS::FLAT_ADDRESS);
141   const LLT PrivatePtr = GetAddrSpacePtr(AMDGPUAS::PRIVATE_ADDRESS);
142 
143   const LLT CodePtr = FlatPtr;
144 
145   const std::initializer_list<LLT> AddrSpaces64 = {
146     GlobalPtr, ConstantPtr, FlatPtr
147   };
148 
149   const std::initializer_list<LLT> AddrSpaces32 = {
150     LocalPtr, PrivatePtr
151   };
152 
153   const std::initializer_list<LLT> FPTypesBase = {
154     S32, S64
155   };
156 
157   const std::initializer_list<LLT> FPTypes16 = {
158     S32, S64, S16
159   };
160 
161   setAction({G_BRCOND, S1}, Legal);
162 
163   // TODO: All multiples of 32, vectors of pointers, all v2s16 pairs, more
164   // elements for v3s16
165   getActionDefinitionsBuilder(G_PHI)
166     .legalFor({S32, S64, V2S16, V4S16, S1, S128, S256})
167     .legalFor(AllS32Vectors)
168     .legalFor(AllS64Vectors)
169     .legalFor(AddrSpaces64)
170     .legalFor(AddrSpaces32)
171     .clampScalar(0, S32, S256)
172     .widenScalarToNextPow2(0, 32)
173     .clampMaxNumElements(0, S32, 16)
174     .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
175     .legalIf(isPointer(0));
176 
177   if (ST.has16BitInsts()) {
178     getActionDefinitionsBuilder({G_ADD, G_SUB, G_MUL})
179       .legalFor({S32, S16})
180       .clampScalar(0, S16, S32)
181       .scalarize(0);
182   } else {
183     getActionDefinitionsBuilder({G_ADD, G_SUB, G_MUL})
184       .legalFor({S32})
185       .clampScalar(0, S32, S32)
186       .scalarize(0);
187   }
188 
189   getActionDefinitionsBuilder({G_UMULH, G_SMULH})
190     .legalFor({S32})
191     .clampScalar(0, S32, S32)
192     .scalarize(0);
193 
  // Report legal for any types we can handle anywhere. For the cases that are
  // only legal on the SALU, RegBankSelect will be able to re-legalize.
196   getActionDefinitionsBuilder({G_AND, G_OR, G_XOR})
197     .legalFor({S32, S1, S64, V2S32, V2S16, V4S16})
198     .clampScalar(0, S32, S64)
199     .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
200     .fewerElementsIf(vectorWiderThan(0, 32), fewerEltsToSize64Vector(0))
201     .widenScalarToNextPow2(0)
202     .scalarize(0);
203 
204   getActionDefinitionsBuilder({G_UADDO, G_SADDO, G_USUBO, G_SSUBO,
205                                G_UADDE, G_SADDE, G_USUBE, G_SSUBE})
206     .legalFor({{S32, S1}})
207     .clampScalar(0, S32, S32);
208 
209   getActionDefinitionsBuilder(G_BITCAST)
210     .legalForCartesianProduct({S32, V2S16})
211     .legalForCartesianProduct({S64, V2S32, V4S16})
212     .legalForCartesianProduct({V2S64, V4S32})
213     // Don't worry about the size constraint.
214     .legalIf(all(isPointer(0), isPointer(1)));
215 
216   if (ST.has16BitInsts()) {
217     getActionDefinitionsBuilder(G_FCONSTANT)
218       .legalFor({S32, S64, S16})
219       .clampScalar(0, S16, S64);
220   } else {
221     getActionDefinitionsBuilder(G_FCONSTANT)
222       .legalFor({S32, S64})
223       .clampScalar(0, S32, S64);
224   }
225 
226   getActionDefinitionsBuilder(G_IMPLICIT_DEF)
227     .legalFor({S1, S32, S64, V2S32, V4S32, V2S16, V4S16, GlobalPtr,
228                ConstantPtr, LocalPtr, FlatPtr, PrivatePtr})
229     .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
230     .clampScalarOrElt(0, S32, S512)
231     .legalIf(isMultiple32(0))
232     .widenScalarToNextPow2(0, 32)
233     .clampMaxNumElements(0, S32, 16);
234 
235 
236   // FIXME: i1 operands to intrinsics should always be legal, but other i1
237   // values may not be legal.  We need to figure out how to distinguish
238   // between these two scenarios.
239   getActionDefinitionsBuilder(G_CONSTANT)
240     .legalFor({S1, S32, S64, GlobalPtr,
241                LocalPtr, ConstantPtr, PrivatePtr, FlatPtr })
242     .clampScalar(0, S32, S64)
243     .widenScalarToNextPow2(0)
244     .legalIf(isPointer(0));
245 
246   setAction({G_FRAME_INDEX, PrivatePtr}, Legal);
247 
248   auto &FPOpActions = getActionDefinitionsBuilder(
249     { G_FADD, G_FMUL, G_FNEG, G_FABS, G_FMA, G_FCANONICALIZE})
250     .legalFor({S32, S64});
251 
252   if (ST.has16BitInsts()) {
253     if (ST.hasVOP3PInsts())
254       FPOpActions.legalFor({S16, V2S16});
255     else
256       FPOpActions.legalFor({S16});
257   }
258 
259   if (ST.hasVOP3PInsts())
260     FPOpActions.clampMaxNumElements(0, S16, 2);
261   FPOpActions
262     .scalarize(0)
263     .clampScalar(0, ST.has16BitInsts() ? S16 : S32, S64);
264 
265   if (ST.has16BitInsts()) {
266     getActionDefinitionsBuilder(G_FSQRT)
267       .legalFor({S32, S64, S16})
268       .scalarize(0)
269       .clampScalar(0, S16, S64);
270   } else {
271     getActionDefinitionsBuilder(G_FSQRT)
272       .legalFor({S32, S64})
273       .scalarize(0)
274       .clampScalar(0, S32, S64);
275   }
276 
277   getActionDefinitionsBuilder(G_FPTRUNC)
278     .legalFor({{S32, S64}, {S16, S32}})
279     .scalarize(0);
280 
281   getActionDefinitionsBuilder(G_FPEXT)
282     .legalFor({{S64, S32}, {S32, S16}})
283     .lowerFor({{S64, S16}}) // FIXME: Implement
284     .scalarize(0);
285 
286   getActionDefinitionsBuilder(G_FCOPYSIGN)
287     .legalForCartesianProduct({S16, S32, S64}, {S16, S32, S64})
288     .scalarize(0);
289 
290   getActionDefinitionsBuilder(G_FSUB)
291       // Use actual fsub instruction
292       .legalFor({S32})
293       // Must use fadd + fneg
294       .lowerFor({S64, S16, V2S16})
295       .scalarize(0)
296       .clampScalar(0, S32, S64);
297 
298   getActionDefinitionsBuilder({G_SEXT, G_ZEXT, G_ANYEXT})
299     .legalFor({{S64, S32}, {S32, S16}, {S64, S16},
300                {S32, S1}, {S64, S1}, {S16, S1},
301                // FIXME: Hack
302                {S64, LLT::scalar(33)},
303                {S32, S8}, {S128, S32}, {S128, S64}, {S32, LLT::scalar(24)}})
304     .scalarize(0);
305 
306   getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
307     .legalFor({{S32, S32}, {S64, S32}})
308     .lowerFor({{S32, S64}})
309     .customFor({{S64, S64}})
310     .scalarize(0);
311 
312   getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
313     .legalFor({{S32, S32}, {S32, S64}})
314     .scalarize(0);
315 
316   getActionDefinitionsBuilder(G_INTRINSIC_ROUND)
317     .legalFor({S32, S64})
318     .scalarize(0);
319 
320   if (ST.getGeneration() >= AMDGPUSubtarget::SEA_ISLANDS) {
321     getActionDefinitionsBuilder({G_INTRINSIC_TRUNC, G_FCEIL, G_FRINT})
322       .legalFor({S32, S64})
323       .clampScalar(0, S32, S64)
324       .scalarize(0);
325   } else {
326     getActionDefinitionsBuilder({G_INTRINSIC_TRUNC, G_FCEIL, G_FRINT})
327       .legalFor({S32})
328       .customFor({S64})
329       .clampScalar(0, S32, S64)
330       .scalarize(0);
331   }
332 
333   getActionDefinitionsBuilder(G_GEP)
334     .legalForCartesianProduct(AddrSpaces64, {S64})
335     .legalForCartesianProduct(AddrSpaces32, {S32})
336     .scalarize(0);
337 
338   setAction({G_BLOCK_ADDR, CodePtr}, Legal);
339 
340   getActionDefinitionsBuilder(G_ICMP)
341     .legalForCartesianProduct(
342       {S1}, {S32, S64, GlobalPtr, LocalPtr, ConstantPtr, PrivatePtr, FlatPtr})
343     .legalFor({{S1, S32}, {S1, S64}})
344     .widenScalarToNextPow2(1)
345     .clampScalar(1, S32, S64)
346     .scalarize(0)
347     .legalIf(all(typeIs(0, S1), isPointer(1)));
348 
349   getActionDefinitionsBuilder(G_FCMP)
350     .legalForCartesianProduct({S1}, ST.has16BitInsts() ? FPTypes16 : FPTypesBase)
351     .widenScalarToNextPow2(1)
352     .clampScalar(1, S32, S64)
353     .scalarize(0);
354 
  // FIXME: fexp, flog2, flog10 need to be custom lowered.
356   getActionDefinitionsBuilder({G_FPOW, G_FEXP, G_FEXP2,
357                                G_FLOG, G_FLOG2, G_FLOG10})
358     .legalFor({S32})
359     .scalarize(0);
360 
361   // The 64-bit versions produce 32-bit results, but only on the SALU.
362   getActionDefinitionsBuilder({G_CTLZ, G_CTLZ_ZERO_UNDEF,
363                                G_CTTZ, G_CTTZ_ZERO_UNDEF,
364                                G_CTPOP})
365     .legalFor({{S32, S32}, {S32, S64}})
366     .clampScalar(0, S32, S32)
367     .clampScalar(1, S32, S64)
368     .scalarize(0)
369     .widenScalarToNextPow2(0, 32)
370     .widenScalarToNextPow2(1, 32);
371 
372   // TODO: Expand for > s32
373   getActionDefinitionsBuilder(G_BSWAP)
374     .legalFor({S32})
375     .clampScalar(0, S32, S32)
376     .scalarize(0);
377 
378   if (ST.has16BitInsts()) {
379     if (ST.hasVOP3PInsts()) {
380       getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX})
381         .legalFor({S32, S16, V2S16})
382         .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
383         .clampMaxNumElements(0, S16, 2)
384         .clampScalar(0, S16, S32)
385         .widenScalarToNextPow2(0)
386         .scalarize(0);
387     } else {
388       getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX})
389         .legalFor({S32, S16})
390         .widenScalarToNextPow2(0)
391         .clampScalar(0, S16, S32)
392         .scalarize(0);
393     }
394   } else {
395     getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX})
396       .legalFor({S32})
397       .clampScalar(0, S32, S32)
398       .widenScalarToNextPow2(0)
399       .scalarize(0);
400   }
401 
402   auto smallerThan = [](unsigned TypeIdx0, unsigned TypeIdx1) {
403     return [=](const LegalityQuery &Query) {
404       return Query.Types[TypeIdx0].getSizeInBits() <
405              Query.Types[TypeIdx1].getSizeInBits();
406     };
407   };
408 
409   auto greaterThan = [](unsigned TypeIdx0, unsigned TypeIdx1) {
410     return [=](const LegalityQuery &Query) {
411       return Query.Types[TypeIdx0].getSizeInBits() >
412              Query.Types[TypeIdx1].getSizeInBits();
413     };
414   };
415 
416   getActionDefinitionsBuilder(G_INTTOPTR)
417     // List the common cases
418     .legalForCartesianProduct(AddrSpaces64, {S64})
419     .legalForCartesianProduct(AddrSpaces32, {S32})
420     .scalarize(0)
421     // Accept any address space as long as the size matches
422     .legalIf(sameSize(0, 1))
423     .widenScalarIf(smallerThan(1, 0),
424       [](const LegalityQuery &Query) {
425         return std::make_pair(1, LLT::scalar(Query.Types[0].getSizeInBits()));
426       })
427     .narrowScalarIf(greaterThan(1, 0),
428       [](const LegalityQuery &Query) {
429         return std::make_pair(1, LLT::scalar(Query.Types[0].getSizeInBits()));
430       });
431 
432   getActionDefinitionsBuilder(G_PTRTOINT)
433     // List the common cases
434     .legalForCartesianProduct(AddrSpaces64, {S64})
435     .legalForCartesianProduct(AddrSpaces32, {S32})
436     .scalarize(0)
437     // Accept any address space as long as the size matches
438     .legalIf(sameSize(0, 1))
439     .widenScalarIf(smallerThan(0, 1),
440       [](const LegalityQuery &Query) {
441         return std::make_pair(0, LLT::scalar(Query.Types[1].getSizeInBits()));
442       })
443     .narrowScalarIf(
444       greaterThan(0, 1),
445       [](const LegalityQuery &Query) {
446         return std::make_pair(0, LLT::scalar(Query.Types[1].getSizeInBits()));
447       });
448 
449   if (ST.hasFlatAddressSpace()) {
450     getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
451       .scalarize(0)
452       .custom();
453   }
454 
455   getActionDefinitionsBuilder({G_LOAD, G_STORE})
456     .narrowScalarIf([](const LegalityQuery &Query) {
457         unsigned Size = Query.Types[0].getSizeInBits();
458         unsigned MemSize = Query.MMODescrs[0].SizeInBits;
459         return (Size > 32 && MemSize < Size);
460       },
461       [](const LegalityQuery &Query) {
462         return std::make_pair(0, LLT::scalar(32));
463       })
464     .fewerElementsIf([=](const LegalityQuery &Query) {
465         unsigned MemSize = Query.MMODescrs[0].SizeInBits;
466         return (MemSize == 96) &&
467                Query.Types[0].isVector() &&
468                !ST.hasDwordx3LoadStores();
469       },
470       [=](const LegalityQuery &Query) {
471         return std::make_pair(0, V2S32);
472       })
473     .legalIf([=](const LegalityQuery &Query) {
474         const LLT &Ty0 = Query.Types[0];
475 
476         unsigned Size = Ty0.getSizeInBits();
477         unsigned MemSize = Query.MMODescrs[0].SizeInBits;
478         if (Size < 32 || (Size > 32 && MemSize < Size))
479           return false;
480 
481         if (Ty0.isVector() && Size != MemSize)
482           return false;
483 
484         // TODO: Decompose private loads into 4-byte components.
485         // TODO: Illegal flat loads on SI
486         switch (MemSize) {
487         case 8:
488         case 16:
489           return Size == 32;
490         case 32:
491         case 64:
492         case 128:
493           return true;
494 
495         case 96:
496           return ST.hasDwordx3LoadStores();
497 
498         case 256:
499         case 512:
500           // TODO: constant loads
501         default:
502           return false;
503         }
504       })
505     .clampScalar(0, S32, S64);
506 
507 
508   // FIXME: Handle alignment requirements.
509   auto &ExtLoads = getActionDefinitionsBuilder({G_SEXTLOAD, G_ZEXTLOAD})
510     .legalForTypesWithMemDesc({
511         {S32, GlobalPtr, 8, 8},
512         {S32, GlobalPtr, 16, 8},
513         {S32, LocalPtr, 8, 8},
514         {S32, LocalPtr, 16, 8},
515         {S32, PrivatePtr, 8, 8},
516         {S32, PrivatePtr, 16, 8}});
517   if (ST.hasFlatAddressSpace()) {
518     ExtLoads.legalForTypesWithMemDesc({{S32, FlatPtr, 8, 8},
519                                        {S32, FlatPtr, 16, 8}});
520   }
521 
522   ExtLoads.clampScalar(0, S32, S32)
523           .widenScalarToNextPow2(0)
524           .unsupportedIfMemSizeNotPow2()
525           .lower();
526 
527   auto &Atomics = getActionDefinitionsBuilder(
528     {G_ATOMICRMW_XCHG, G_ATOMICRMW_ADD, G_ATOMICRMW_SUB,
529      G_ATOMICRMW_AND, G_ATOMICRMW_OR, G_ATOMICRMW_XOR,
530      G_ATOMICRMW_MAX, G_ATOMICRMW_MIN, G_ATOMICRMW_UMAX,
531      G_ATOMICRMW_UMIN, G_ATOMIC_CMPXCHG})
532     .legalFor({{S32, GlobalPtr}, {S32, LocalPtr},
533                {S64, GlobalPtr}, {S64, LocalPtr}});
534   if (ST.hasFlatAddressSpace()) {
535     Atomics.legalFor({{S32, FlatPtr}, {S64, FlatPtr}});
536   }
537 
538   // TODO: Pointer types, any 32-bit or 64-bit vector
539   getActionDefinitionsBuilder(G_SELECT)
540     .legalForCartesianProduct({S32, S64, S16, V2S32, V2S16, V4S16,
541           GlobalPtr, LocalPtr, FlatPtr, PrivatePtr,
542           LLT::vector(2, LocalPtr), LLT::vector(2, PrivatePtr)}, {S1})
543     .clampScalar(0, S16, S64)
544     .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
545     .fewerElementsIf(numElementsNotEven(0), scalarize(0))
546     .scalarize(1)
547     .clampMaxNumElements(0, S32, 2)
548     .clampMaxNumElements(0, LocalPtr, 2)
549     .clampMaxNumElements(0, PrivatePtr, 2)
550     .scalarize(0)
551     .widenScalarToNextPow2(0)
552     .legalIf(all(isPointer(0), typeIs(1, S1)));
553 
554   // TODO: Only the low 4/5/6 bits of the shift amount are observed, so we can
555   // be more flexible with the shift amount type.
556   auto &Shifts = getActionDefinitionsBuilder({G_SHL, G_LSHR, G_ASHR})
557     .legalFor({{S32, S32}, {S64, S32}});
558   if (ST.has16BitInsts()) {
559     if (ST.hasVOP3PInsts()) {
560       Shifts.legalFor({{S16, S32}, {S16, S16}, {V2S16, V2S16}})
561             .clampMaxNumElements(0, S16, 2);
562     } else
563       Shifts.legalFor({{S16, S32}, {S16, S16}});
564 
565     Shifts.clampScalar(1, S16, S32);
566     Shifts.clampScalar(0, S16, S64);
567     Shifts.widenScalarToNextPow2(0, 16);
568   } else {
569     // Make sure we legalize the shift amount type first, as the general
570     // expansion for the shifted type will produce much worse code if it hasn't
571     // been truncated already.
572     Shifts.clampScalar(1, S32, S32);
573     Shifts.clampScalar(0, S32, S64);
574     Shifts.widenScalarToNextPow2(0, 32);
575   }
576   Shifts.scalarize(0);
577 
578   for (unsigned Op : {G_EXTRACT_VECTOR_ELT, G_INSERT_VECTOR_ELT}) {
579     unsigned VecTypeIdx = Op == G_EXTRACT_VECTOR_ELT ? 1 : 0;
580     unsigned EltTypeIdx = Op == G_EXTRACT_VECTOR_ELT ? 0 : 1;
581     unsigned IdxTypeIdx = 2;
582 
583     getActionDefinitionsBuilder(Op)
584       .legalIf([=](const LegalityQuery &Query) {
585           const LLT &VecTy = Query.Types[VecTypeIdx];
586           const LLT &IdxTy = Query.Types[IdxTypeIdx];
587           return VecTy.getSizeInBits() % 32 == 0 &&
588             VecTy.getSizeInBits() <= 512 &&
589             IdxTy.getSizeInBits() == 32;
590         })
591       .clampScalar(EltTypeIdx, S32, S64)
592       .clampScalar(VecTypeIdx, S32, S64)
593       .clampScalar(IdxTypeIdx, S32, S32);
594   }
595 
596   getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT)
597     .unsupportedIf([=](const LegalityQuery &Query) {
598         const LLT &EltTy = Query.Types[1].getElementType();
599         return Query.Types[0] != EltTy;
600       });
601 
602   for (unsigned Op : {G_EXTRACT, G_INSERT}) {
603     unsigned BigTyIdx = Op == G_EXTRACT ? 1 : 0;
604     unsigned LitTyIdx = Op == G_EXTRACT ? 0 : 1;
605 
606     // FIXME: Doesn't handle extract of illegal sizes.
607     getActionDefinitionsBuilder(Op)
608       .legalIf([=](const LegalityQuery &Query) {
609           const LLT BigTy = Query.Types[BigTyIdx];
610           const LLT LitTy = Query.Types[LitTyIdx];
611           return (BigTy.getSizeInBits() % 32 == 0) &&
612                  (LitTy.getSizeInBits() % 16 == 0);
613         })
614       .widenScalarIf(
615         [=](const LegalityQuery &Query) {
616           const LLT BigTy = Query.Types[BigTyIdx];
617           return (BigTy.getScalarSizeInBits() < 16);
618         },
619         LegalizeMutations::widenScalarOrEltToNextPow2(BigTyIdx, 16))
620       .widenScalarIf(
621         [=](const LegalityQuery &Query) {
622           const LLT LitTy = Query.Types[LitTyIdx];
623           return (LitTy.getScalarSizeInBits() < 16);
624         },
625         LegalizeMutations::widenScalarOrEltToNextPow2(LitTyIdx, 16))
626       .moreElementsIf(isSmallOddVector(BigTyIdx), oneMoreElement(BigTyIdx))
627       .widenScalarToNextPow2(BigTyIdx, 32);
628 
629   }
630 
631   // TODO: vectors of pointers
632   getActionDefinitionsBuilder(G_BUILD_VECTOR)
633       .legalForCartesianProduct(AllS32Vectors, {S32})
634       .legalForCartesianProduct(AllS64Vectors, {S64})
635       .clampNumElements(0, V16S32, V16S32)
636       .clampNumElements(0, V2S64, V8S64)
637       .minScalarSameAs(1, 0)
638       // FIXME: Sort of a hack to make progress on other legalizations.
639       .legalIf([=](const LegalityQuery &Query) {
640         return Query.Types[0].getScalarSizeInBits() <= 32 ||
641                Query.Types[0].getScalarSizeInBits() == 64;
642       });
643 
644   // TODO: Support any combination of v2s32
645   getActionDefinitionsBuilder(G_CONCAT_VECTORS)
646     .legalFor({{V4S32, V2S32},
647                {V8S32, V2S32},
648                {V8S32, V4S32},
649                {V4S64, V2S64},
650                {V4S16, V2S16},
651                {V8S16, V2S16},
652                {V8S16, V4S16},
653                {LLT::vector(4, LocalPtr), LLT::vector(2, LocalPtr)},
654                {LLT::vector(4, PrivatePtr), LLT::vector(2, PrivatePtr)}});
655 
656   // Merge/Unmerge
657   for (unsigned Op : {G_MERGE_VALUES, G_UNMERGE_VALUES}) {
658     unsigned BigTyIdx = Op == G_MERGE_VALUES ? 0 : 1;
659     unsigned LitTyIdx = Op == G_MERGE_VALUES ? 1 : 0;
660 
661     auto notValidElt = [=](const LegalityQuery &Query, unsigned TypeIdx) {
662       const LLT &Ty = Query.Types[TypeIdx];
663       if (Ty.isVector()) {
664         const LLT &EltTy = Ty.getElementType();
665         if (EltTy.getSizeInBits() < 8 || EltTy.getSizeInBits() > 64)
666           return true;
667         if (!isPowerOf2_32(EltTy.getSizeInBits()))
668           return true;
669       }
670       return false;
671     };
672 
673     getActionDefinitionsBuilder(Op)
674       .widenScalarToNextPow2(LitTyIdx, /*Min*/ 16)
      // Clamp the little scalar to s16-s256 and make it a power of 2. It's not
      // worth considering the multiples of 64 since 2*192 and 2*384 are not
      // valid.
678       .clampScalar(LitTyIdx, S16, S256)
679       .widenScalarToNextPow2(LitTyIdx, /*Min*/ 32)
680 
681       // Break up vectors with weird elements into scalars
682       .fewerElementsIf(
683         [=](const LegalityQuery &Query) { return notValidElt(Query, 0); },
684         scalarize(0))
685       .fewerElementsIf(
686         [=](const LegalityQuery &Query) { return notValidElt(Query, 1); },
687         scalarize(1))
688       .clampScalar(BigTyIdx, S32, S512)
689       .widenScalarIf(
690         [=](const LegalityQuery &Query) {
691           const LLT &Ty = Query.Types[BigTyIdx];
692           return !isPowerOf2_32(Ty.getSizeInBits()) &&
693                  Ty.getSizeInBits() % 16 != 0;
694         },
695         [=](const LegalityQuery &Query) {
          // Pick the next power of 2, or a multiple of 64 over 128,
          // whichever is smaller.
698           const LLT &Ty = Query.Types[BigTyIdx];
699           unsigned NewSizeInBits = 1 << Log2_32_Ceil(Ty.getSizeInBits() + 1);
700           if (NewSizeInBits >= 256) {
701             unsigned RoundedTo = alignTo<64>(Ty.getSizeInBits() + 1);
702             if (RoundedTo < NewSizeInBits)
703               NewSizeInBits = RoundedTo;
704           }
705           return std::make_pair(BigTyIdx, LLT::scalar(NewSizeInBits));
706         })
707       .legalIf([=](const LegalityQuery &Query) {
708           const LLT &BigTy = Query.Types[BigTyIdx];
709           const LLT &LitTy = Query.Types[LitTyIdx];
710 
711           if (BigTy.isVector() && BigTy.getSizeInBits() < 32)
712             return false;
713           if (LitTy.isVector() && LitTy.getSizeInBits() < 32)
714             return false;
715 
716           return BigTy.getSizeInBits() % 16 == 0 &&
717                  LitTy.getSizeInBits() % 16 == 0 &&
718                  BigTy.getSizeInBits() <= 512;
719         })
720       // Any vectors left are the wrong size. Scalarize them.
721       .scalarize(0)
722       .scalarize(1);
723   }
724 
725   computeTables();
726   verify(*ST.getInstrInfo());
727 }
728 
729 bool AMDGPULegalizerInfo::legalizeCustom(MachineInstr &MI,
730                                          MachineRegisterInfo &MRI,
731                                          MachineIRBuilder &MIRBuilder,
732                                          GISelChangeObserver &Observer) const {
733   switch (MI.getOpcode()) {
734   case TargetOpcode::G_ADDRSPACE_CAST:
735     return legalizeAddrSpaceCast(MI, MRI, MIRBuilder);
736   case TargetOpcode::G_FRINT:
737     return legalizeFrint(MI, MRI, MIRBuilder);
738   case TargetOpcode::G_FCEIL:
739     return legalizeFceil(MI, MRI, MIRBuilder);
740   case TargetOpcode::G_INTRINSIC_TRUNC:
741     return legalizeIntrinsicTrunc(MI, MRI, MIRBuilder);
742   case TargetOpcode::G_SITOFP:
743     return legalizeITOFP(MI, MRI, MIRBuilder, true);
744   case TargetOpcode::G_UITOFP:
745     return legalizeITOFP(MI, MRI, MIRBuilder, false);
746   default:
747     return false;
748   }
749 
750   llvm_unreachable("expected switch to return");
751 }
752 
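// Return a 32-bit value holding the high half of the aperture base for the
// given LDS or private address space. Subtargets with aperture registers
// read it from the hardware register; otherwise it is loaded from the queue
// pointer.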
753 Register AMDGPULegalizerInfo::getSegmentAperture(
754   unsigned AS,
755   MachineRegisterInfo &MRI,
756   MachineIRBuilder &MIRBuilder) const {
757   MachineFunction &MF = MIRBuilder.getMF();
758   const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
759   const LLT S32 = LLT::scalar(32);
760 
761   if (ST.hasApertureRegs()) {
762     // FIXME: Use inline constants (src_{shared, private}_base) instead of
763     // getreg.
764     unsigned Offset = AS == AMDGPUAS::LOCAL_ADDRESS ?
765         AMDGPU::Hwreg::OFFSET_SRC_SHARED_BASE :
766         AMDGPU::Hwreg::OFFSET_SRC_PRIVATE_BASE;
767     unsigned WidthM1 = AS == AMDGPUAS::LOCAL_ADDRESS ?
768         AMDGPU::Hwreg::WIDTH_M1_SRC_SHARED_BASE :
769         AMDGPU::Hwreg::WIDTH_M1_SRC_PRIVATE_BASE;
770     unsigned Encoding =
771         AMDGPU::Hwreg::ID_MEM_BASES << AMDGPU::Hwreg::ID_SHIFT_ |
772         Offset << AMDGPU::Hwreg::OFFSET_SHIFT_ |
773         WidthM1 << AMDGPU::Hwreg::WIDTH_M1_SHIFT_;
774 
775     Register ApertureReg = MRI.createGenericVirtualRegister(S32);
776     Register GetReg = MRI.createVirtualRegister(&AMDGPU::SReg_32RegClass);
777 
778     MIRBuilder.buildInstr(AMDGPU::S_GETREG_B32)
779       .addDef(GetReg)
780       .addImm(Encoding);
781     MRI.setType(GetReg, S32);
782 
783     auto ShiftAmt = MIRBuilder.buildConstant(S32, WidthM1 + 1);
784     MIRBuilder.buildInstr(TargetOpcode::G_SHL)
785       .addDef(ApertureReg)
786       .addUse(GetReg)
787       .addUse(ShiftAmt.getReg(0));
788 
789     return ApertureReg;
790   }
791 
792   Register QueuePtr = MRI.createGenericVirtualRegister(
793     LLT::pointer(AMDGPUAS::CONSTANT_ADDRESS, 64));
794 
795   // FIXME: Placeholder until we can track the input registers.
796   MIRBuilder.buildConstant(QueuePtr, 0xdeadbeef);
797 
798   // Offset into amd_queue_t for group_segment_aperture_base_hi /
799   // private_segment_aperture_base_hi.
800   uint32_t StructOffset = (AS == AMDGPUAS::LOCAL_ADDRESS) ? 0x40 : 0x44;
801 
802   // FIXME: Don't use undef
803   Value *V = UndefValue::get(PointerType::get(
804                                Type::getInt8Ty(MF.getFunction().getContext()),
805                                AMDGPUAS::CONSTANT_ADDRESS));
806 
807   MachinePointerInfo PtrInfo(V, StructOffset);
808   MachineMemOperand *MMO = MF.getMachineMemOperand(
809     PtrInfo,
810     MachineMemOperand::MOLoad |
811     MachineMemOperand::MODereferenceable |
812     MachineMemOperand::MOInvariant,
813     4,
814     MinAlign(64, StructOffset));
815 
816   Register LoadResult = MRI.createGenericVirtualRegister(S32);
817   Register LoadAddr;
818 
819   MIRBuilder.materializeGEP(LoadAddr, QueuePtr, LLT::scalar(64), StructOffset);
820   MIRBuilder.buildLoad(LoadResult, LoadAddr, *MMO);
821   return LoadResult;
822 }
823 
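// Lower G_ADDRSPACE_CAST. No-op casts become G_BITCAST. Casts from flat to
// local/private extract the low 32 bits and select against the segment null
// value; casts from local/private to flat merge the 32-bit pointer with the
// segment aperture and select against the flat null value.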
824 bool AMDGPULegalizerInfo::legalizeAddrSpaceCast(
825   MachineInstr &MI, MachineRegisterInfo &MRI,
826   MachineIRBuilder &MIRBuilder) const {
827   MachineFunction &MF = MIRBuilder.getMF();
828 
829   MIRBuilder.setInstr(MI);
830 
831   Register Dst = MI.getOperand(0).getReg();
832   Register Src = MI.getOperand(1).getReg();
833 
834   LLT DstTy = MRI.getType(Dst);
835   LLT SrcTy = MRI.getType(Src);
836   unsigned DestAS = DstTy.getAddressSpace();
837   unsigned SrcAS = SrcTy.getAddressSpace();
838 
839   // TODO: Avoid reloading from the queue ptr for each cast, or at least each
840   // vector element.
841   assert(!DstTy.isVector());
842 
843   const AMDGPUTargetMachine &TM
844     = static_cast<const AMDGPUTargetMachine &>(MF.getTarget());
845 
846   const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
847   if (ST.getTargetLowering()->isNoopAddrSpaceCast(SrcAS, DestAS)) {
848     MI.setDesc(MIRBuilder.getTII().get(TargetOpcode::G_BITCAST));
849     return true;
850   }
851 
852   if (SrcAS == AMDGPUAS::FLAT_ADDRESS) {
853     assert(DestAS == AMDGPUAS::LOCAL_ADDRESS ||
854            DestAS == AMDGPUAS::PRIVATE_ADDRESS);
855     unsigned NullVal = TM.getNullPointerValue(DestAS);
856 
857     auto SegmentNull = MIRBuilder.buildConstant(DstTy, NullVal);
858     auto FlatNull = MIRBuilder.buildConstant(SrcTy, 0);
859 
860     Register PtrLo32 = MRI.createGenericVirtualRegister(DstTy);
861 
862     // Extract low 32-bits of the pointer.
863     MIRBuilder.buildExtract(PtrLo32, Src, 0);
864 
865     Register CmpRes = MRI.createGenericVirtualRegister(LLT::scalar(1));
866     MIRBuilder.buildICmp(CmpInst::ICMP_NE, CmpRes, Src, FlatNull.getReg(0));
867     MIRBuilder.buildSelect(Dst, CmpRes, PtrLo32, SegmentNull.getReg(0));
868 
869     MI.eraseFromParent();
870     return true;
871   }
872 
873   assert(SrcAS == AMDGPUAS::LOCAL_ADDRESS ||
874          SrcAS == AMDGPUAS::PRIVATE_ADDRESS);
875 
876   auto SegmentNull =
877       MIRBuilder.buildConstant(SrcTy, TM.getNullPointerValue(SrcAS));
878   auto FlatNull =
879       MIRBuilder.buildConstant(DstTy, TM.getNullPointerValue(DestAS));
880 
881   Register ApertureReg = getSegmentAperture(DestAS, MRI, MIRBuilder);
882 
883   Register CmpRes = MRI.createGenericVirtualRegister(LLT::scalar(1));
884   MIRBuilder.buildICmp(CmpInst::ICMP_NE, CmpRes, Src, SegmentNull.getReg(0));
885 
886   Register BuildPtr = MRI.createGenericVirtualRegister(DstTy);
887 
888   // Coerce the type of the low half of the result so we can use merge_values.
889   Register SrcAsInt = MRI.createGenericVirtualRegister(LLT::scalar(32));
890   MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
891     .addDef(SrcAsInt)
892     .addUse(Src);
893 
894   // TODO: Should we allow mismatched types but matching sizes in merges to
895   // avoid the ptrtoint?
896   MIRBuilder.buildMerge(BuildPtr, {SrcAsInt, ApertureReg});
897   MIRBuilder.buildSelect(Dst, CmpRes, BuildPtr, FlatNull.getReg(0));
898 
899   MI.eraseFromParent();
900   return true;
901 }
902 
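// Lower G_FRINT for f64: adding and then subtracting a copysigned 2^52
// rounds the value to an integer, and inputs whose magnitude is above
// 0x1.fffffffffffffp+51 are already integral and are returned unchanged.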
903 bool AMDGPULegalizerInfo::legalizeFrint(
904   MachineInstr &MI, MachineRegisterInfo &MRI,
905   MachineIRBuilder &MIRBuilder) const {
906   MIRBuilder.setInstr(MI);
907 
908   Register Src = MI.getOperand(1).getReg();
909   LLT Ty = MRI.getType(Src);
910   assert(Ty.isScalar() && Ty.getSizeInBits() == 64);
911 
912   APFloat C1Val(APFloat::IEEEdouble(), "0x1.0p+52");
913   APFloat C2Val(APFloat::IEEEdouble(), "0x1.fffffffffffffp+51");
914 
915   auto C1 = MIRBuilder.buildFConstant(Ty, C1Val);
916   auto CopySign = MIRBuilder.buildFCopysign(Ty, C1, Src);
917 
918   // TODO: Should this propagate fast-math-flags?
919   auto Tmp1 = MIRBuilder.buildFAdd(Ty, Src, CopySign);
920   auto Tmp2 = MIRBuilder.buildFSub(Ty, Tmp1, CopySign);
921 
922   auto C2 = MIRBuilder.buildFConstant(Ty, C2Val);
923   auto Fabs = MIRBuilder.buildFAbs(Ty, Src);
924 
925   auto Cond = MIRBuilder.buildFCmp(CmpInst::FCMP_OGT, LLT::scalar(1), Fabs, C2);
926   MIRBuilder.buildSelect(MI.getOperand(0).getReg(), Cond, Src, Tmp2);
927   return true;
928 }
929 
930 bool AMDGPULegalizerInfo::legalizeFceil(
931   MachineInstr &MI, MachineRegisterInfo &MRI,
932   MachineIRBuilder &B) const {
933   B.setInstr(MI);
934 
935   const LLT S1 = LLT::scalar(1);
936   const LLT S64 = LLT::scalar(64);
937 
938   Register Src = MI.getOperand(1).getReg();
939   assert(MRI.getType(Src) == S64);
940 
941   // result = trunc(src)
942   // if (src > 0.0 && src != result)
943   //   result += 1.0
944 
945   auto Trunc = B.buildInstr(TargetOpcode::G_INTRINSIC_TRUNC, {S64}, {Src});
946 
947   const auto Zero = B.buildFConstant(S64, 0.0);
948   const auto One = B.buildFConstant(S64, 1.0);
949   auto Lt0 = B.buildFCmp(CmpInst::FCMP_OGT, S1, Src, Zero);
950   auto NeTrunc = B.buildFCmp(CmpInst::FCMP_ONE, S1, Src, Trunc);
951   auto And = B.buildAnd(S1, Lt0, NeTrunc);
952   auto Add = B.buildSelect(S64, And, One, Zero);
953 
954   // TODO: Should this propagate fast-math-flags?
955   B.buildFAdd(MI.getOperand(0).getReg(), Trunc, Add);
956   return true;
957 }
958 
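// Extract the biased exponent field from the high 32 bits of an f64 value
// and subtract the exponent bias (1023).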
959 static MachineInstrBuilder extractF64Exponent(unsigned Hi,
960                                               MachineIRBuilder &B) {
961   const unsigned FractBits = 52;
962   const unsigned ExpBits = 11;
963   LLT S32 = LLT::scalar(32);
964 
965   auto Const0 = B.buildConstant(S32, FractBits - 32);
966   auto Const1 = B.buildConstant(S32, ExpBits);
967 
  auto ExpPart = B.buildIntrinsic(Intrinsic::amdgcn_ubfe, {S32}, false)
    .addUse(Hi)
    .addUse(Const0.getReg(0))
    .addUse(Const1.getReg(0));
971 
972   return B.buildSub(S32, ExpPart, B.buildConstant(S32, 1023));
973 }
974 
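// Lower G_INTRINSIC_TRUNC for f64 by masking off the fractional bits
// selected by the exponent: a negative exponent produces a signed zero, and
// an exponent greater than 51 means the value is already an integer.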
975 bool AMDGPULegalizerInfo::legalizeIntrinsicTrunc(
976   MachineInstr &MI, MachineRegisterInfo &MRI,
977   MachineIRBuilder &B) const {
978   B.setInstr(MI);
979 
980   const LLT S1 = LLT::scalar(1);
981   const LLT S32 = LLT::scalar(32);
982   const LLT S64 = LLT::scalar(64);
983 
984   Register Src = MI.getOperand(1).getReg();
985   assert(MRI.getType(Src) == S64);
986 
987   // TODO: Should this use extract since the low half is unused?
988   auto Unmerge = B.buildUnmerge({S32, S32}, Src);
989   Register Hi = Unmerge.getReg(1);
990 
991   // Extract the upper half, since this is where we will find the sign and
992   // exponent.
993   auto Exp = extractF64Exponent(Hi, B);
994 
995   const unsigned FractBits = 52;
996 
997   // Extract the sign bit.
998   const auto SignBitMask = B.buildConstant(S32, UINT32_C(1) << 31);
999   auto SignBit = B.buildAnd(S32, Hi, SignBitMask);
1000 
1001   const auto FractMask = B.buildConstant(S64, (UINT64_C(1) << FractBits) - 1);
1002 
1003   const auto Zero32 = B.buildConstant(S32, 0);
1004 
1005   // Extend back to 64-bits.
1006   auto SignBit64 = B.buildMerge(S64, {Zero32.getReg(0), SignBit.getReg(0)});
1007 
1008   auto Shr = B.buildAShr(S64, FractMask, Exp);
1009   auto Not = B.buildNot(S64, Shr);
1010   auto Tmp0 = B.buildAnd(S64, Src, Not);
1011   auto FiftyOne = B.buildConstant(S32, FractBits - 1);
1012 
1013   auto ExpLt0 = B.buildICmp(CmpInst::ICMP_SLT, S1, Exp, Zero32);
1014   auto ExpGt51 = B.buildICmp(CmpInst::ICMP_SGT, S1, Exp, FiftyOne);
1015 
1016   auto Tmp1 = B.buildSelect(S64, ExpLt0, SignBit64, Tmp0);
1017   B.buildSelect(MI.getOperand(0).getReg(), ExpGt51, Src, Tmp1);
1018   return true;
1019 }
1020 
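// Lower a 64-bit integer to f64 conversion by converting the two 32-bit
// halves separately and recombining them as ldexp(hi, 32) + lo. Only the
// high half needs a signed conversion for G_SITOFP.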
1021 bool AMDGPULegalizerInfo::legalizeITOFP(
1022   MachineInstr &MI, MachineRegisterInfo &MRI,
1023   MachineIRBuilder &B, bool Signed) const {
1024   B.setInstr(MI);
1025 
1026   Register Dst = MI.getOperand(0).getReg();
1027   Register Src = MI.getOperand(1).getReg();
1028 
1029   const LLT S64 = LLT::scalar(64);
1030   const LLT S32 = LLT::scalar(32);
1031 
1032   assert(MRI.getType(Src) == S64 && MRI.getType(Dst) == S64);
1033 
1034   auto Unmerge = B.buildUnmerge({S32, S32}, Src);
1035 
1036   auto CvtHi = Signed ?
1037     B.buildSITOFP(S64, Unmerge.getReg(1)) :
1038     B.buildUITOFP(S64, Unmerge.getReg(1));
1039 
1040   auto CvtLo = B.buildUITOFP(S64, Unmerge.getReg(0));
1041 
1042   auto ThirtyTwo = B.buildConstant(S32, 32);
1043   auto LdExp = B.buildIntrinsic(Intrinsic::amdgcn_ldexp, {S64}, false)
1044     .addUse(CvtHi.getReg(0))
1045     .addUse(ThirtyTwo.getReg(0));
1046 
1047   // TODO: Should this propagate fast-math-flags?
1048   B.buildFAdd(Dst, LdExp, CvtLo);
1049   MI.eraseFromParent();
1050   return true;
1051 }
1052 
// Return the use branch instruction, or null if the usage is invalid.
1054 static MachineInstr *verifyCFIntrinsic(MachineInstr &MI,
1055                                        MachineRegisterInfo &MRI) {
1056   Register CondDef = MI.getOperand(0).getReg();
1057   if (!MRI.hasOneNonDBGUse(CondDef))
1058     return nullptr;
1059 
1060   MachineInstr &UseMI = *MRI.use_instr_nodbg_begin(CondDef);
1061   return UseMI.getParent() == MI.getParent() &&
1062     UseMI.getOpcode() == AMDGPU::G_BRCOND ? &UseMI : nullptr;
1063 }
1064 
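// Return the virtual register already associated with the physical live-in
// Reg, or create and record a new one of type Ty.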
1065 Register AMDGPULegalizerInfo::getLiveInRegister(MachineRegisterInfo &MRI,
1066                                                 Register Reg, LLT Ty) const {
1067   Register LiveIn = MRI.getLiveInVirtReg(Reg);
1068   if (LiveIn)
1069     return LiveIn;
1070 
1071   Register NewReg = MRI.createGenericVirtualRegister(Ty);
1072   MRI.addLiveIn(Reg, NewReg);
1073   return NewReg;
1074 }
1075 
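// Copy a preloaded input value into DstReg. Masked arguments (such as the
// packed workitem IDs) are shifted and masked down to the requested field,
// and the live-in copy is emitted in the entry block if it does not already
// exist.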
1076 bool AMDGPULegalizerInfo::loadInputValue(Register DstReg, MachineIRBuilder &B,
1077                                          const ArgDescriptor *Arg) const {
1078   if (!Arg->isRegister())
1079     return false; // TODO: Handle these
1080 
1081   assert(Arg->getRegister() != 0);
1082   assert(Arg->getRegister().isPhysical());
1083 
1084   MachineRegisterInfo &MRI = *B.getMRI();
1085 
1086   LLT Ty = MRI.getType(DstReg);
1087   Register LiveIn = getLiveInRegister(MRI, Arg->getRegister(), Ty);
1088 
1089   if (Arg->isMasked()) {
1090     // TODO: Should we try to emit this once in the entry block?
1091     const LLT S32 = LLT::scalar(32);
1092     const unsigned Mask = Arg->getMask();
1093     const unsigned Shift = countTrailingZeros<unsigned>(Mask);
1094 
1095     auto ShiftAmt = B.buildConstant(S32, Shift);
1096     auto LShr = B.buildLShr(S32, LiveIn, ShiftAmt);
1097     B.buildAnd(DstReg, LShr, B.buildConstant(S32, Mask >> Shift));
1098   } else
1099     B.buildCopy(DstReg, LiveIn);
1100 
  // Insert the argument copy if it doesn't already exist.
1102   // FIXME: It seems EmitLiveInCopies isn't called anywhere?
1103   if (!MRI.getVRegDef(LiveIn)) {
1104     MachineBasicBlock &EntryMBB = B.getMF().front();
1105     EntryMBB.addLiveIn(Arg->getRegister());
1106     B.setInsertPt(EntryMBB, EntryMBB.begin());
1107     B.buildCopy(LiveIn, Arg->getRegister());
1108   }
1109 
1110   return true;
1111 }
1112 
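// Legalize an intrinsic that reads a preloaded argument register by
// replacing its result with a copy of the corresponding input value.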
1113 bool AMDGPULegalizerInfo::legalizePreloadedArgIntrin(
1114   MachineInstr &MI,
1115   MachineRegisterInfo &MRI,
1116   MachineIRBuilder &B,
1117   AMDGPUFunctionArgInfo::PreloadedValue ArgType) const {
1118   B.setInstr(MI);
1119 
1120   const SIMachineFunctionInfo *MFI = B.getMF().getInfo<SIMachineFunctionInfo>();
1121 
1122   const ArgDescriptor *Arg;
1123   const TargetRegisterClass *RC;
1124   std::tie(Arg, RC) = MFI->getPreloadedValue(ArgType);
1125   if (!Arg) {
1126     LLVM_DEBUG(dbgs() << "Required arg register missing\n");
1127     return false;
1128   }
1129 
1130   if (loadInputValue(MI.getOperand(0).getReg(), B, Arg)) {
1131     MI.eraseFromParent();
1132     return true;
1133   }
1134 
1135   return false;
1136 }
1137 
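// For entry functions the implicit argument pointer is the kernarg segment
// pointer plus the implicit parameter offset; other functions receive it as
// a preloaded argument.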
1138 bool AMDGPULegalizerInfo::legalizeImplicitArgPtr(MachineInstr &MI,
1139                                                  MachineRegisterInfo &MRI,
1140                                                  MachineIRBuilder &B) const {
1141   const SIMachineFunctionInfo *MFI = B.getMF().getInfo<SIMachineFunctionInfo>();
1142   if (!MFI->isEntryFunction()) {
1143     return legalizePreloadedArgIntrin(MI, MRI, B,
1144                                       AMDGPUFunctionArgInfo::IMPLICIT_ARG_PTR);
1145   }
1146 
1147   B.setInstr(MI);
1148 
1149   uint64_t Offset =
1150     ST.getTargetLowering()->getImplicitParameterOffset(
1151       B.getMF(), AMDGPUTargetLowering::FIRST_IMPLICIT);
1152   Register DstReg = MI.getOperand(0).getReg();
1153   LLT DstTy = MRI.getType(DstReg);
1154   LLT IdxTy = LLT::scalar(DstTy.getSizeInBits());
1155 
1156   const ArgDescriptor *Arg;
1157   const TargetRegisterClass *RC;
1158   std::tie(Arg, RC)
1159     = MFI->getPreloadedValue(AMDGPUFunctionArgInfo::KERNARG_SEGMENT_PTR);
1160   if (!Arg)
1161     return false;
1162 
1163   Register KernargPtrReg = MRI.createGenericVirtualRegister(DstTy);
1164   if (!loadInputValue(KernargPtrReg, B, Arg))
1165     return false;
1166 
1167   B.buildGEP(DstReg, KernargPtrReg, B.buildConstant(IdxTy, Offset).getReg(0));
1168   MI.eraseFromParent();
1169   return true;
1170 }
1171 
1172 bool AMDGPULegalizerInfo::legalizeIntrinsic(MachineInstr &MI,
1173                                             MachineRegisterInfo &MRI,
1174                                             MachineIRBuilder &B) const {
  // Replace the G_BRCOND use with the exec-manipulating branch pseudos.
1176   switch (MI.getOperand(MI.getNumExplicitDefs()).getIntrinsicID()) {
1177   case Intrinsic::amdgcn_if: {
1178     if (MachineInstr *BrCond = verifyCFIntrinsic(MI, MRI)) {
1179       const SIRegisterInfo *TRI
1180         = static_cast<const SIRegisterInfo *>(MRI.getTargetRegisterInfo());
1181 
1182       B.setInstr(*BrCond);
1183       Register Def = MI.getOperand(1).getReg();
1184       Register Use = MI.getOperand(3).getReg();
1185       B.buildInstr(AMDGPU::SI_IF)
1186         .addDef(Def)
1187         .addUse(Use)
1188         .addMBB(BrCond->getOperand(1).getMBB());
1189 
1190       MRI.setRegClass(Def, TRI->getWaveMaskRegClass());
1191       MRI.setRegClass(Use, TRI->getWaveMaskRegClass());
1192       MI.eraseFromParent();
1193       BrCond->eraseFromParent();
1194       return true;
1195     }
1196 
1197     return false;
1198   }
1199   case Intrinsic::amdgcn_loop: {
1200     if (MachineInstr *BrCond = verifyCFIntrinsic(MI, MRI)) {
1201       const SIRegisterInfo *TRI
1202         = static_cast<const SIRegisterInfo *>(MRI.getTargetRegisterInfo());
1203 
1204       B.setInstr(*BrCond);
1205       Register Reg = MI.getOperand(2).getReg();
1206       B.buildInstr(AMDGPU::SI_LOOP)
1207         .addUse(Reg)
1208         .addMBB(BrCond->getOperand(1).getMBB());
1209       MI.eraseFromParent();
1210       BrCond->eraseFromParent();
1211       MRI.setRegClass(Reg, TRI->getWaveMaskRegClass());
1212       return true;
1213     }
1214 
1215     return false;
1216   }
1217   case Intrinsic::amdgcn_kernarg_segment_ptr:
1218     return legalizePreloadedArgIntrin(
1219       MI, MRI, B, AMDGPUFunctionArgInfo::KERNARG_SEGMENT_PTR);
1220   case Intrinsic::amdgcn_implicitarg_ptr:
1221     return legalizeImplicitArgPtr(MI, MRI, B);
1222   case Intrinsic::amdgcn_workitem_id_x:
1223     return legalizePreloadedArgIntrin(MI, MRI, B,
1224                                       AMDGPUFunctionArgInfo::WORKITEM_ID_X);
1225   case Intrinsic::amdgcn_workitem_id_y:
1226     return legalizePreloadedArgIntrin(MI, MRI, B,
1227                                       AMDGPUFunctionArgInfo::WORKITEM_ID_Y);
1228   case Intrinsic::amdgcn_workitem_id_z:
1229     return legalizePreloadedArgIntrin(MI, MRI, B,
1230                                       AMDGPUFunctionArgInfo::WORKITEM_ID_Z);
1231   case Intrinsic::amdgcn_workgroup_id_x:
1232     return legalizePreloadedArgIntrin(MI, MRI, B,
1233                                       AMDGPUFunctionArgInfo::WORKGROUP_ID_X);
1234   case Intrinsic::amdgcn_workgroup_id_y:
1235     return legalizePreloadedArgIntrin(MI, MRI, B,
1236                                       AMDGPUFunctionArgInfo::WORKGROUP_ID_Y);
1237   case Intrinsic::amdgcn_workgroup_id_z:
1238     return legalizePreloadedArgIntrin(MI, MRI, B,
1239                                       AMDGPUFunctionArgInfo::WORKGROUP_ID_Z);
1240   case Intrinsic::amdgcn_dispatch_ptr:
1241     return legalizePreloadedArgIntrin(MI, MRI, B,
1242                                       AMDGPUFunctionArgInfo::DISPATCH_PTR);
1243   case Intrinsic::amdgcn_queue_ptr:
1244     return legalizePreloadedArgIntrin(MI, MRI, B,
1245                                       AMDGPUFunctionArgInfo::QUEUE_PTR);
1246   case Intrinsic::amdgcn_implicit_buffer_ptr:
1247     return legalizePreloadedArgIntrin(
1248       MI, MRI, B, AMDGPUFunctionArgInfo::IMPLICIT_BUFFER_PTR);
1249   case Intrinsic::amdgcn_dispatch_id:
1250     return legalizePreloadedArgIntrin(MI, MRI, B,
1251                                       AMDGPUFunctionArgInfo::DISPATCH_ID);
1252   default:
1253     return true;
1254   }
1255 
1256   return true;
1257 }
1258