//===- AMDGPULegalizerInfo.cpp -----------------------------------*- C++ -*-==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
/// This file implements the targeting of the MachineLegalizer class for
/// AMDGPU.
/// \todo This should be generated by TableGen.
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "AMDGPULegalizerInfo.h"
#include "AMDGPUTargetMachine.h"
#include "SIMachineFunctionInfo.h"

#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Debug.h"

using namespace llvm;
using namespace LegalizeActions;
using namespace LegalizeMutations;
using namespace LegalityPredicates;

static LegalityPredicate isMultiple32(unsigned TypeIdx,
                                      unsigned MaxSize = 512) {
  return [=](const LegalityQuery &Query) {
    const LLT Ty = Query.Types[TypeIdx];
    const LLT EltTy = Ty.getScalarType();
    return Ty.getSizeInBits() <= MaxSize && EltTy.getSizeInBits() % 32 == 0;
  };
}

static LegalityPredicate isSmallOddVector(unsigned TypeIdx) {
  return [=](const LegalityQuery &Query) {
    const LLT Ty = Query.Types[TypeIdx];
    return Ty.isVector() &&
           Ty.getNumElements() % 2 != 0 &&
           Ty.getElementType().getSizeInBits() < 32;
  };
}

static LegalizeMutation oneMoreElement(unsigned TypeIdx) {
  return [=](const LegalityQuery &Query) {
    const LLT Ty = Query.Types[TypeIdx];
    const LLT EltTy = Ty.getElementType();
    return std::make_pair(TypeIdx, LLT::vector(Ty.getNumElements() + 1, EltTy));
  };
}
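// For example, a rule written as
//   moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
// pads an odd-count vector of sub-32-bit elements such as v3s16 out to
// v4s16, so the value evenly fills 32-bit registers.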

static LegalizeMutation fewerEltsToSize64Vector(unsigned TypeIdx) {
  return [=](const LegalityQuery &Query) {
    const LLT Ty = Query.Types[TypeIdx];
    const LLT EltTy = Ty.getElementType();
    unsigned Size = Ty.getSizeInBits();
    unsigned Pieces = (Size + 63) / 64;
    unsigned NewNumElts = (Ty.getNumElements() + 1) / Pieces;
    return std::make_pair(TypeIdx, LLT::scalarOrVector(NewNumElts, EltTy));
  };
}
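// Worked example: v4s32 is 128 bits, so Pieces = 2 and NewNumElts = 2,
// splitting it into v2s32 halves. v3s32 (96 bits) also yields Pieces = 2
// and NewNumElts = 2.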

static LegalityPredicate vectorWiderThan(unsigned TypeIdx, unsigned Size) {
  return [=](const LegalityQuery &Query) {
    const LLT QueryTy = Query.Types[TypeIdx];
    return QueryTy.isVector() && QueryTy.getSizeInBits() > Size;
  };
}

static LegalityPredicate numElementsNotEven(unsigned TypeIdx) {
  return [=](const LegalityQuery &Query) {
    const LLT QueryTy = Query.Types[TypeIdx];
    return QueryTy.isVector() && QueryTy.getNumElements() % 2 != 0;
  };
}

AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST,
                                         const GCNTargetMachine &TM) {
  using namespace TargetOpcode;

  auto GetAddrSpacePtr = [&TM](unsigned AS) {
    return LLT::pointer(AS, TM.getPointerSizeInBits(AS));
  };

  const LLT S1 = LLT::scalar(1);
  const LLT S8 = LLT::scalar(8);
  const LLT S16 = LLT::scalar(16);
  const LLT S32 = LLT::scalar(32);
  const LLT S64 = LLT::scalar(64);
  const LLT S128 = LLT::scalar(128);
  const LLT S256 = LLT::scalar(256);
  const LLT S512 = LLT::scalar(512);

  const LLT V2S16 = LLT::vector(2, 16);
  const LLT V4S16 = LLT::vector(4, 16);
  const LLT V8S16 = LLT::vector(8, 16);

  const LLT V2S32 = LLT::vector(2, 32);
  const LLT V3S32 = LLT::vector(3, 32);
  const LLT V4S32 = LLT::vector(4, 32);
  const LLT V5S32 = LLT::vector(5, 32);
  const LLT V6S32 = LLT::vector(6, 32);
  const LLT V7S32 = LLT::vector(7, 32);
  const LLT V8S32 = LLT::vector(8, 32);
  const LLT V9S32 = LLT::vector(9, 32);
  const LLT V10S32 = LLT::vector(10, 32);
  const LLT V11S32 = LLT::vector(11, 32);
  const LLT V12S32 = LLT::vector(12, 32);
  const LLT V13S32 = LLT::vector(13, 32);
  const LLT V14S32 = LLT::vector(14, 32);
  const LLT V15S32 = LLT::vector(15, 32);
  const LLT V16S32 = LLT::vector(16, 32);

  const LLT V2S64 = LLT::vector(2, 64);
  const LLT V3S64 = LLT::vector(3, 64);
  const LLT V4S64 = LLT::vector(4, 64);
  const LLT V5S64 = LLT::vector(5, 64);
  const LLT V6S64 = LLT::vector(6, 64);
  const LLT V7S64 = LLT::vector(7, 64);
  const LLT V8S64 = LLT::vector(8, 64);

  std::initializer_list<LLT> AllS32Vectors =
    {V2S32, V3S32, V4S32, V5S32, V6S32, V7S32, V8S32,
     V9S32, V10S32, V11S32, V12S32, V13S32, V14S32, V15S32, V16S32};
  std::initializer_list<LLT> AllS64Vectors =
    {V2S64, V3S64, V4S64, V5S64, V6S64, V7S64, V8S64};

  const LLT GlobalPtr = GetAddrSpacePtr(AMDGPUAS::GLOBAL_ADDRESS);
  const LLT ConstantPtr = GetAddrSpacePtr(AMDGPUAS::CONSTANT_ADDRESS);
  const LLT LocalPtr = GetAddrSpacePtr(AMDGPUAS::LOCAL_ADDRESS);
  const LLT FlatPtr = GetAddrSpacePtr(AMDGPUAS::FLAT_ADDRESS);
  const LLT PrivatePtr = GetAddrSpacePtr(AMDGPUAS::PRIVATE_ADDRESS);

  const LLT CodePtr = FlatPtr;

  const std::initializer_list<LLT> AddrSpaces64 = {
    GlobalPtr, ConstantPtr, FlatPtr
  };

  const std::initializer_list<LLT> AddrSpaces32 = {
    LocalPtr, PrivatePtr
  };
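  // On amdgcn, global, constant and flat pointers are 64 bits wide, while
  // local (LDS) and private (scratch) pointers are 32 bits, so many of the
  // rules below are stated once per pointer-width group.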

  setAction({G_BRCOND, S1}, Legal);

  // TODO: All multiples of 32, vectors of pointers, all v2s16 pairs, more
  // elements for v3s16
  getActionDefinitionsBuilder(G_PHI)
    .legalFor({S32, S64, V2S16, V4S16, S1, S128, S256})
    .legalFor(AllS32Vectors)
    .legalFor(AllS64Vectors)
    .legalFor(AddrSpaces64)
    .legalFor(AddrSpaces32)
    .clampScalar(0, S32, S256)
    .widenScalarToNextPow2(0, 32)
    .clampMaxNumElements(0, S32, 16)
    .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
    .legalIf(isPointer(0));
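  // E.g. an s48 G_PHI passes through clampScalar unchanged (48 is within
  // [32, 256]) and is then widened by widenScalarToNextPow2 to the legal s64.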

  getActionDefinitionsBuilder({G_ADD, G_SUB, G_MUL, G_UMULH, G_SMULH})
    .legalFor({S32})
    .clampScalar(0, S32, S32)
    .scalarize(0);

  // Report legal for any types we can handle anywhere. For the cases only legal
  // on the SALU, RegBankSelect will be able to re-legalize.
  getActionDefinitionsBuilder({G_AND, G_OR, G_XOR})
    .legalFor({S32, S1, S64, V2S32, V2S16, V4S16})
    .clampScalar(0, S32, S64)
    .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
    .fewerElementsIf(vectorWiderThan(0, 32), fewerEltsToSize64Vector(0))
    .widenScalarToNextPow2(0)
    .scalarize(0);

  getActionDefinitionsBuilder({G_UADDO, G_SADDO, G_USUBO, G_SSUBO,
                               G_UADDE, G_SADDE, G_USUBE, G_SSUBE})
    .legalFor({{S32, S1}})
    .clampScalar(0, S32, S32);

  getActionDefinitionsBuilder(G_BITCAST)
    .legalForCartesianProduct({S32, V2S16})
    .legalForCartesianProduct({S64, V2S32, V4S16})
    .legalForCartesianProduct({V2S64, V4S32})
    // Don't worry about the size constraint.
    .legalIf(all(isPointer(0), isPointer(1)));

  if (ST.has16BitInsts()) {
    getActionDefinitionsBuilder(G_FCONSTANT)
      .legalFor({S32, S64, S16})
      .clampScalar(0, S16, S64);
  } else {
    getActionDefinitionsBuilder(G_FCONSTANT)
      .legalFor({S32, S64})
      .clampScalar(0, S32, S64);
  }

  getActionDefinitionsBuilder(G_IMPLICIT_DEF)
    .legalFor({S1, S32, S64, V2S32, V4S32, V2S16, V4S16, GlobalPtr,
               ConstantPtr, LocalPtr, FlatPtr, PrivatePtr})
    .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
    .clampScalarOrElt(0, S32, S512)
    .legalIf(isMultiple32(0))
    .widenScalarToNextPow2(0, 32)
    .clampMaxNumElements(0, S32, 16);

  // FIXME: i1 operands to intrinsics should always be legal, but other i1
  // values may not be legal.  We need to figure out how to distinguish
  // between these two scenarios.
  getActionDefinitionsBuilder(G_CONSTANT)
    .legalFor({S1, S32, S64, GlobalPtr,
               LocalPtr, ConstantPtr, PrivatePtr, FlatPtr})
    .clampScalar(0, S32, S64)
    .widenScalarToNextPow2(0)
    .legalIf(isPointer(0));

  setAction({G_FRAME_INDEX, PrivatePtr}, Legal);

  auto &FPOpActions = getActionDefinitionsBuilder(
    { G_FADD, G_FMUL, G_FNEG, G_FABS, G_FMA, G_FCANONICALIZE})
    .legalFor({S32, S64});

  if (ST.has16BitInsts()) {
    if (ST.hasVOP3PInsts())
      FPOpActions.legalFor({S16, V2S16});
    else
      FPOpActions.legalFor({S16});
  }

  if (ST.hasVOP3PInsts())
    FPOpActions.clampMaxNumElements(0, S16, 2);
  FPOpActions
    .scalarize(0)
    .clampScalar(0, ST.has16BitInsts() ? S16 : S32, S64);

  if (ST.has16BitInsts()) {
    getActionDefinitionsBuilder(G_FSQRT)
      .legalFor({S32, S64, S16})
      .scalarize(0)
      .clampScalar(0, S16, S64);
  } else {
    getActionDefinitionsBuilder(G_FSQRT)
      .legalFor({S32, S64})
      .scalarize(0)
      .clampScalar(0, S32, S64);
  }

  getActionDefinitionsBuilder(G_FPTRUNC)
    .legalFor({{S32, S64}, {S16, S32}})
    .scalarize(0);

  getActionDefinitionsBuilder(G_FPEXT)
    .legalFor({{S64, S32}, {S32, S16}})
    .lowerFor({{S64, S16}}) // FIXME: Implement
    .scalarize(0);

  getActionDefinitionsBuilder(G_FCOPYSIGN)
    .legalForCartesianProduct({S16, S32, S64}, {S16, S32, S64})
    .scalarize(0);

  getActionDefinitionsBuilder(G_FSUB)
      // Use actual fsub instruction
      .legalFor({S32})
      // Must use fadd + fneg
      .lowerFor({S64, S16, V2S16})
      .scalarize(0)
      .clampScalar(0, S32, S64);

  getActionDefinitionsBuilder({G_SEXT, G_ZEXT, G_ANYEXT})
    .legalFor({{S64, S32}, {S32, S16}, {S64, S16},
               {S32, S1}, {S64, S1}, {S16, S1},
               // FIXME: Hack
               {S64, LLT::scalar(33)},
               {S32, S8}, {S128, S32}, {S128, S64}, {S32, LLT::scalar(24)}})
    .scalarize(0);

  getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
    .legalFor({{S32, S32}, {S64, S32}})
    .lowerFor({{S32, S64}})
    .customFor({{S64, S64}})
    .scalarize(0);

  getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
    .legalFor({{S32, S32}, {S32, S64}})
    .scalarize(0);

  getActionDefinitionsBuilder(G_INTRINSIC_ROUND)
    .legalFor({S32, S64})
    .scalarize(0);

  if (ST.getGeneration() >= AMDGPUSubtarget::SEA_ISLANDS) {
    getActionDefinitionsBuilder({G_INTRINSIC_TRUNC, G_FCEIL, G_FRINT})
      .legalFor({S32, S64})
      .clampScalar(0, S32, S64)
      .scalarize(0);
  } else {
    getActionDefinitionsBuilder({G_INTRINSIC_TRUNC, G_FCEIL, G_FRINT})
      .legalFor({S32})
      .customFor({S64})
      .clampScalar(0, S32, S64)
      .scalarize(0);
  }

  getActionDefinitionsBuilder(G_GEP)
    .legalForCartesianProduct(AddrSpaces64, {S64})
    .legalForCartesianProduct(AddrSpaces32, {S32})
    .scalarize(0);

  setAction({G_BLOCK_ADDR, CodePtr}, Legal);

  getActionDefinitionsBuilder(G_ICMP)
    .legalForCartesianProduct(
      {S1}, {S32, S64, GlobalPtr, LocalPtr, ConstantPtr, PrivatePtr, FlatPtr})
    .legalFor({{S1, S32}, {S1, S64}})
    .widenScalarToNextPow2(1)
    .clampScalar(1, S32, S64)
    .scalarize(0)
    .legalIf(all(typeIs(0, S1), isPointer(1)));

  getActionDefinitionsBuilder(G_FCMP)
    .legalFor({{S1, S32}, {S1, S64}})
    .widenScalarToNextPow2(1)
    .clampScalar(1, S32, S64)
    .scalarize(0);

  // FIXME: fexp, flog2, flog10 need to be custom lowered.
  getActionDefinitionsBuilder({G_FPOW, G_FEXP, G_FEXP2,
                               G_FLOG, G_FLOG2, G_FLOG10})
    .legalFor({S32})
    .scalarize(0);

  // The 64-bit versions produce 32-bit results, but only on the SALU.
  getActionDefinitionsBuilder({G_CTLZ, G_CTLZ_ZERO_UNDEF,
                               G_CTTZ, G_CTTZ_ZERO_UNDEF,
                               G_CTPOP})
    .legalFor({{S32, S32}, {S32, S64}})
    .clampScalar(0, S32, S32)
    .clampScalar(1, S32, S64)
    .scalarize(0)
    .widenScalarToNextPow2(0, 32)
    .widenScalarToNextPow2(1, 32);

  // TODO: Expand for > s32
  getActionDefinitionsBuilder(G_BSWAP)
    .legalFor({S32})
    .clampScalar(0, S32, S32)
    .scalarize(0);

  if (ST.has16BitInsts()) {
    if (ST.hasVOP3PInsts()) {
      getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX})
        .legalFor({S32, S16, V2S16})
        .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
        .clampMaxNumElements(0, S16, 2)
        .clampScalar(0, S16, S32)
        .widenScalarToNextPow2(0)
        .scalarize(0);
    } else {
      getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX})
        .legalFor({S32, S16})
        .widenScalarToNextPow2(0)
        .clampScalar(0, S16, S32)
        .scalarize(0);
    }
  } else {
    getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX})
      .legalFor({S32})
      .clampScalar(0, S32, S32)
      .widenScalarToNextPow2(0)
      .scalarize(0);
  }

  auto smallerThan = [](unsigned TypeIdx0, unsigned TypeIdx1) {
    return [=](const LegalityQuery &Query) {
      return Query.Types[TypeIdx0].getSizeInBits() <
             Query.Types[TypeIdx1].getSizeInBits();
    };
  };

  auto greaterThan = [](unsigned TypeIdx0, unsigned TypeIdx1) {
    return [=](const LegalityQuery &Query) {
      return Query.Types[TypeIdx0].getSizeInBits() >
             Query.Types[TypeIdx1].getSizeInBits();
    };
  };

  getActionDefinitionsBuilder(G_INTTOPTR)
    // List the common cases
    .legalForCartesianProduct(AddrSpaces64, {S64})
    .legalForCartesianProduct(AddrSpaces32, {S32})
    .scalarize(0)
    // Accept any address space as long as the size matches
    .legalIf(sameSize(0, 1))
    .widenScalarIf(smallerThan(1, 0),
      [](const LegalityQuery &Query) {
        return std::make_pair(1, LLT::scalar(Query.Types[0].getSizeInBits()));
      })
    .narrowScalarIf(greaterThan(1, 0),
      [](const LegalityQuery &Query) {
        return std::make_pair(1, LLT::scalar(Query.Types[0].getSizeInBits()));
      });
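  // E.g. a G_INTTOPTR producing a 32-bit local pointer from an s16 widens
  // the integer operand to s32 first; an s128 source for a 64-bit flat
  // pointer is narrowed to s64.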

  getActionDefinitionsBuilder(G_PTRTOINT)
    // List the common cases
    .legalForCartesianProduct(AddrSpaces64, {S64})
    .legalForCartesianProduct(AddrSpaces32, {S32})
    .scalarize(0)
    // Accept any address space as long as the size matches
    .legalIf(sameSize(0, 1))
    .widenScalarIf(smallerThan(0, 1),
      [](const LegalityQuery &Query) {
        return std::make_pair(0, LLT::scalar(Query.Types[1].getSizeInBits()));
      })
    .narrowScalarIf(
      greaterThan(0, 1),
      [](const LegalityQuery &Query) {
        return std::make_pair(0, LLT::scalar(Query.Types[1].getSizeInBits()));
      });

  if (ST.hasFlatAddressSpace()) {
    getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
      .scalarize(0)
      .custom();
  }

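  // For example, a 96-bit (v3s32) load or store stays whole only on
  // SEA_ISLANDS and newer; on older targets the fewerElementsIf rule below
  // first splits it into v2s32 pieces.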
  getActionDefinitionsBuilder({G_LOAD, G_STORE})
    .narrowScalarIf([](const LegalityQuery &Query) {
        unsigned Size = Query.Types[0].getSizeInBits();
        unsigned MemSize = Query.MMODescrs[0].SizeInBits;
        return (Size > 32 && MemSize < Size);
      },
      [](const LegalityQuery &Query) {
        return std::make_pair(0, LLT::scalar(32));
      })
    .fewerElementsIf([=, &ST](const LegalityQuery &Query) {
        unsigned MemSize = Query.MMODescrs[0].SizeInBits;
        return (MemSize == 96) &&
               Query.Types[0].isVector() &&
               ST.getGeneration() < AMDGPUSubtarget::SEA_ISLANDS;
      },
      [=](const LegalityQuery &Query) {
        return std::make_pair(0, V2S32);
      })
    .legalIf([=, &ST](const LegalityQuery &Query) {
        const LLT &Ty0 = Query.Types[0];

        unsigned Size = Ty0.getSizeInBits();
        unsigned MemSize = Query.MMODescrs[0].SizeInBits;
        if (Size < 32 || (Size > 32 && MemSize < Size))
          return false;

        if (Ty0.isVector() && Size != MemSize)
          return false;

        // TODO: Decompose private loads into 4-byte components.
        // TODO: Illegal flat loads on SI
        switch (MemSize) {
        case 8:
        case 16:
          return Size == 32;
        case 32:
        case 64:
        case 128:
          return true;

        case 96:
          // XXX hasLoadX3
          return (ST.getGeneration() >= AMDGPUSubtarget::SEA_ISLANDS);

        case 256:
        case 512:
          // TODO: constant loads
        default:
          return false;
        }
      })
    .clampScalar(0, S32, S64);

  // FIXME: Handle alignment requirements.
  auto &ExtLoads = getActionDefinitionsBuilder({G_SEXTLOAD, G_ZEXTLOAD})
    .legalForTypesWithMemDesc({
        {S32, GlobalPtr, 8, 8},
        {S32, GlobalPtr, 16, 8},
        {S32, LocalPtr, 8, 8},
        {S32, LocalPtr, 16, 8},
        {S32, PrivatePtr, 8, 8},
        {S32, PrivatePtr, 16, 8}});
  if (ST.hasFlatAddressSpace()) {
    ExtLoads.legalForTypesWithMemDesc({{S32, FlatPtr, 8, 8},
                                       {S32, FlatPtr, 16, 8}});
  }

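  // Anything not matched above (e.g. an extending load from a constant
  // pointer) falls through to lower(), which is expected to expand it into a
  // plain load followed by a sign/zero extend.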
  ExtLoads.clampScalar(0, S32, S32)
          .widenScalarToNextPow2(0)
          .unsupportedIfMemSizeNotPow2()
          .lower();

  auto &Atomics = getActionDefinitionsBuilder(
    {G_ATOMICRMW_XCHG, G_ATOMICRMW_ADD, G_ATOMICRMW_SUB,
     G_ATOMICRMW_AND, G_ATOMICRMW_OR, G_ATOMICRMW_XOR,
     G_ATOMICRMW_MAX, G_ATOMICRMW_MIN, G_ATOMICRMW_UMAX,
     G_ATOMICRMW_UMIN, G_ATOMIC_CMPXCHG})
    .legalFor({{S32, GlobalPtr}, {S32, LocalPtr},
               {S64, GlobalPtr}, {S64, LocalPtr}});
  if (ST.hasFlatAddressSpace()) {
    Atomics.legalFor({{S32, FlatPtr}, {S64, FlatPtr}});
  }

  // TODO: Pointer types, any 32-bit or 64-bit vector
  getActionDefinitionsBuilder(G_SELECT)
    .legalForCartesianProduct({S32, S64, V2S32, V2S16, V4S16,
          GlobalPtr, LocalPtr, FlatPtr, PrivatePtr,
          LLT::vector(2, LocalPtr), LLT::vector(2, PrivatePtr)}, {S1})
    .clampScalar(0, S32, S64)
    .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
    .fewerElementsIf(numElementsNotEven(0), scalarize(0))
    .scalarize(1)
    .clampMaxNumElements(0, S32, 2)
    .clampMaxNumElements(0, LocalPtr, 2)
    .clampMaxNumElements(0, PrivatePtr, 2)
    .scalarize(0)
    .widenScalarToNextPow2(0)
    .legalIf(all(isPointer(0), typeIs(1, S1)));

  // TODO: Only the low 4/5/6 bits of the shift amount are observed, so we can
  // be more flexible with the shift amount type.
  auto &Shifts = getActionDefinitionsBuilder({G_SHL, G_LSHR, G_ASHR})
    .legalFor({{S32, S32}, {S64, S32}});
  if (ST.has16BitInsts()) {
    if (ST.hasVOP3PInsts()) {
      Shifts.legalFor({{S16, S32}, {S16, S16}, {V2S16, V2S16}})
            .clampMaxNumElements(0, S16, 2);
    } else
      Shifts.legalFor({{S16, S32}, {S16, S16}});

    Shifts.clampScalar(1, S16, S32);
    Shifts.clampScalar(0, S16, S64);
    Shifts.widenScalarToNextPow2(0, 16);
  } else {
    // Make sure we legalize the shift amount type first, as the general
    // expansion for the shifted type will produce much worse code if it hasn't
    // been truncated already.
    Shifts.clampScalar(1, S32, S32);
    Shifts.clampScalar(0, S32, S64);
    Shifts.widenScalarToNextPow2(0, 32);
  }
  Shifts.scalarize(0);

  for (unsigned Op : {G_EXTRACT_VECTOR_ELT, G_INSERT_VECTOR_ELT}) {
    unsigned VecTypeIdx = Op == G_EXTRACT_VECTOR_ELT ? 1 : 0;
    unsigned EltTypeIdx = Op == G_EXTRACT_VECTOR_ELT ? 0 : 1;
    unsigned IdxTypeIdx = 2;

    getActionDefinitionsBuilder(Op)
      .legalIf([=](const LegalityQuery &Query) {
          const LLT &VecTy = Query.Types[VecTypeIdx];
          const LLT &IdxTy = Query.Types[IdxTypeIdx];
          return VecTy.getSizeInBits() % 32 == 0 &&
            VecTy.getSizeInBits() <= 512 &&
            IdxTy.getSizeInBits() == 32;
        })
      .clampScalar(EltTypeIdx, S32, S64)
      .clampScalar(VecTypeIdx, S32, S64)
      .clampScalar(IdxTypeIdx, S32, S32);
  }

  getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT)
    .unsupportedIf([=](const LegalityQuery &Query) {
        const LLT &EltTy = Query.Types[1].getElementType();
        return Query.Types[0] != EltTy;
      });
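  // I.e. an extract must produce exactly the vector's element type;
  // implicitly extending or truncating the extracted value is unsupported.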

  for (unsigned Op : {G_EXTRACT, G_INSERT}) {
    unsigned BigTyIdx = Op == G_EXTRACT ? 1 : 0;
    unsigned LitTyIdx = Op == G_EXTRACT ? 0 : 1;

    // FIXME: Doesn't handle extract of illegal sizes.
    getActionDefinitionsBuilder(Op)
      .legalIf([=](const LegalityQuery &Query) {
          const LLT BigTy = Query.Types[BigTyIdx];
          const LLT LitTy = Query.Types[LitTyIdx];
          return (BigTy.getSizeInBits() % 32 == 0) &&
                 (LitTy.getSizeInBits() % 16 == 0);
        })
      .widenScalarIf(
        [=](const LegalityQuery &Query) {
          const LLT BigTy = Query.Types[BigTyIdx];
          return (BigTy.getScalarSizeInBits() < 16);
        },
        LegalizeMutations::widenScalarOrEltToNextPow2(BigTyIdx, 16))
      .widenScalarIf(
        [=](const LegalityQuery &Query) {
          const LLT LitTy = Query.Types[LitTyIdx];
          return (LitTy.getScalarSizeInBits() < 16);
        },
        LegalizeMutations::widenScalarOrEltToNextPow2(LitTyIdx, 16))
      .moreElementsIf(isSmallOddVector(BigTyIdx), oneMoreElement(BigTyIdx))
      .widenScalarToNextPow2(BigTyIdx, 32);
  }

  // TODO: vectors of pointers
  getActionDefinitionsBuilder(G_BUILD_VECTOR)
      .legalForCartesianProduct(AllS32Vectors, {S32})
      .legalForCartesianProduct(AllS64Vectors, {S64})
      .clampNumElements(0, V16S32, V16S32)
      .clampNumElements(0, V2S64, V8S64)
      .minScalarSameAs(1, 0)
      // FIXME: Sort of a hack to make progress on other legalizations.
      .legalIf([=](const LegalityQuery &Query) {
        return Query.Types[0].getScalarSizeInBits() <= 32 ||
               Query.Types[0].getScalarSizeInBits() == 64;
      });

  // TODO: Support any combination of v2s32
  getActionDefinitionsBuilder(G_CONCAT_VECTORS)
    .legalFor({{V4S32, V2S32},
               {V8S32, V2S32},
               {V8S32, V4S32},
               {V4S64, V2S64},
               {V4S16, V2S16},
               {V8S16, V2S16},
               {V8S16, V4S16},
               {LLT::vector(4, LocalPtr), LLT::vector(2, LocalPtr)},
               {LLT::vector(4, PrivatePtr), LLT::vector(2, PrivatePtr)}});

  // Merge/Unmerge
  for (unsigned Op : {G_MERGE_VALUES, G_UNMERGE_VALUES}) {
    unsigned BigTyIdx = Op == G_MERGE_VALUES ? 0 : 1;
    unsigned LitTyIdx = Op == G_MERGE_VALUES ? 1 : 0;

    auto notValidElt = [=](const LegalityQuery &Query, unsigned TypeIdx) {
      const LLT &Ty = Query.Types[TypeIdx];
      if (Ty.isVector()) {
        const LLT &EltTy = Ty.getElementType();
        if (EltTy.getSizeInBits() < 8 || EltTy.getSizeInBits() > 64)
          return true;
        if (!isPowerOf2_32(EltTy.getSizeInBits()))
          return true;
      }
      return false;
    };

    getActionDefinitionsBuilder(Op)
      .widenScalarToNextPow2(LitTyIdx, /*Min*/ 16)
      // Clamp the little scalar to s16-s256 and make it a power of 2. It's not
      // worth considering the multiples of 64 since 2*192 and 2*384 are not
      // valid.
      .clampScalar(LitTyIdx, S16, S256)
      .widenScalarToNextPow2(LitTyIdx, /*Min*/ 32)

      // Break up vectors with weird elements into scalars
      .fewerElementsIf(
        [=](const LegalityQuery &Query) { return notValidElt(Query, 0); },
        scalarize(0))
      .fewerElementsIf(
        [=](const LegalityQuery &Query) { return notValidElt(Query, 1); },
        scalarize(1))
      .clampScalar(BigTyIdx, S32, S512)
      .widenScalarIf(
        [=](const LegalityQuery &Query) {
          const LLT &Ty = Query.Types[BigTyIdx];
          return !isPowerOf2_32(Ty.getSizeInBits()) &&
                 Ty.getSizeInBits() % 16 != 0;
        },
        [=](const LegalityQuery &Query) {
          // Pick the next power of 2, or a multiple of 64 over 128.
          // Whichever is smaller.
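          // E.g. s65 becomes s128 (next power of 2, still below 256), while
          // s260 would round up to s512, but alignTo<64>(261) = 320 is
          // smaller, so s320 is used instead.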
          const LLT &Ty = Query.Types[BigTyIdx];
          unsigned NewSizeInBits = 1 << Log2_32_Ceil(Ty.getSizeInBits() + 1);
          if (NewSizeInBits >= 256) {
            unsigned RoundedTo = alignTo<64>(Ty.getSizeInBits() + 1);
            if (RoundedTo < NewSizeInBits)
              NewSizeInBits = RoundedTo;
          }
          return std::make_pair(BigTyIdx, LLT::scalar(NewSizeInBits));
        })
      .legalIf([=](const LegalityQuery &Query) {
          const LLT &BigTy = Query.Types[BigTyIdx];
          const LLT &LitTy = Query.Types[LitTyIdx];

          if (BigTy.isVector() && BigTy.getSizeInBits() < 32)
            return false;
          if (LitTy.isVector() && LitTy.getSizeInBits() < 32)
            return false;

          return BigTy.getSizeInBits() % 16 == 0 &&
                 LitTy.getSizeInBits() % 16 == 0 &&
                 BigTy.getSizeInBits() <= 512;
        })
      // Any vectors left are the wrong size. Scalarize them.
      .scalarize(0)
      .scalarize(1);
  }

  computeTables();
  verify(*ST.getInstrInfo());
}

bool AMDGPULegalizerInfo::legalizeCustom(MachineInstr &MI,
                                         MachineRegisterInfo &MRI,
                                         MachineIRBuilder &MIRBuilder,
                                         GISelChangeObserver &Observer) const {
  switch (MI.getOpcode()) {
  case TargetOpcode::G_ADDRSPACE_CAST:
    return legalizeAddrSpaceCast(MI, MRI, MIRBuilder);
  case TargetOpcode::G_FRINT:
    return legalizeFrint(MI, MRI, MIRBuilder);
  case TargetOpcode::G_FCEIL:
    return legalizeFceil(MI, MRI, MIRBuilder);
  case TargetOpcode::G_INTRINSIC_TRUNC:
    return legalizeIntrinsicTrunc(MI, MRI, MIRBuilder);
  case TargetOpcode::G_SITOFP:
    return legalizeITOFP(MI, MRI, MIRBuilder, true);
  case TargetOpcode::G_UITOFP:
    return legalizeITOFP(MI, MRI, MIRBuilder, false);
  default:
    return false;
  }

  llvm_unreachable("expected switch to return");
}

unsigned AMDGPULegalizerInfo::getSegmentAperture(
  unsigned AS,
  MachineRegisterInfo &MRI,
  MachineIRBuilder &MIRBuilder) const {
  MachineFunction &MF = MIRBuilder.getMF();
  const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
  const LLT S32 = LLT::scalar(32);

  if (ST.hasApertureRegs()) {
    // FIXME: Use inline constants (src_{shared, private}_base) instead of
    // getreg.
    unsigned Offset = AS == AMDGPUAS::LOCAL_ADDRESS ?
        AMDGPU::Hwreg::OFFSET_SRC_SHARED_BASE :
        AMDGPU::Hwreg::OFFSET_SRC_PRIVATE_BASE;
    unsigned WidthM1 = AS == AMDGPUAS::LOCAL_ADDRESS ?
        AMDGPU::Hwreg::WIDTH_M1_SRC_SHARED_BASE :
        AMDGPU::Hwreg::WIDTH_M1_SRC_PRIVATE_BASE;
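    // Pack the S_GETREG_B32 immediate: hardware register id plus the bit
    // offset and width-1 of the aperture field within MEM_BASES.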
    unsigned Encoding =
        AMDGPU::Hwreg::ID_MEM_BASES << AMDGPU::Hwreg::ID_SHIFT_ |
        Offset << AMDGPU::Hwreg::OFFSET_SHIFT_ |
        WidthM1 << AMDGPU::Hwreg::WIDTH_M1_SHIFT_;

    unsigned ApertureReg = MRI.createGenericVirtualRegister(S32);
    unsigned GetReg = MRI.createVirtualRegister(&AMDGPU::SReg_32RegClass);

    MIRBuilder.buildInstr(AMDGPU::S_GETREG_B32)
      .addDef(GetReg)
      .addImm(Encoding);
    MRI.setType(GetReg, S32);

    auto ShiftAmt = MIRBuilder.buildConstant(S32, WidthM1 + 1);
    MIRBuilder.buildInstr(TargetOpcode::G_SHL)
      .addDef(ApertureReg)
      .addUse(GetReg)
      .addUse(ShiftAmt.getReg(0));

    return ApertureReg;
  }

  unsigned QueuePtr = MRI.createGenericVirtualRegister(
    LLT::pointer(AMDGPUAS::CONSTANT_ADDRESS, 64));

  // FIXME: Placeholder until we can track the input registers.
  MIRBuilder.buildConstant(QueuePtr, 0xdeadbeef);

  // Offset into amd_queue_t for group_segment_aperture_base_hi /
  // private_segment_aperture_base_hi.
  uint32_t StructOffset = (AS == AMDGPUAS::LOCAL_ADDRESS) ? 0x40 : 0x44;

  // FIXME: Don't use undef
  Value *V = UndefValue::get(PointerType::get(
                               Type::getInt8Ty(MF.getFunction().getContext()),
                               AMDGPUAS::CONSTANT_ADDRESS));

  MachinePointerInfo PtrInfo(V, StructOffset);
  MachineMemOperand *MMO = MF.getMachineMemOperand(
    PtrInfo,
    MachineMemOperand::MOLoad |
    MachineMemOperand::MODereferenceable |
    MachineMemOperand::MOInvariant,
    4,
    MinAlign(64, StructOffset));

  unsigned LoadResult = MRI.createGenericVirtualRegister(S32);
  unsigned LoadAddr = AMDGPU::NoRegister;

  MIRBuilder.materializeGEP(LoadAddr, QueuePtr, LLT::scalar(64), StructOffset);
  MIRBuilder.buildLoad(LoadResult, LoadAddr, *MMO);
  return LoadResult;
}

bool AMDGPULegalizerInfo::legalizeAddrSpaceCast(
  MachineInstr &MI, MachineRegisterInfo &MRI,
  MachineIRBuilder &MIRBuilder) const {
  MachineFunction &MF = MIRBuilder.getMF();

  MIRBuilder.setInstr(MI);

  unsigned Dst = MI.getOperand(0).getReg();
  unsigned Src = MI.getOperand(1).getReg();

  LLT DstTy = MRI.getType(Dst);
  LLT SrcTy = MRI.getType(Src);
  unsigned DestAS = DstTy.getAddressSpace();
  unsigned SrcAS = SrcTy.getAddressSpace();

  // TODO: Avoid reloading from the queue ptr for each cast, or at least each
  // vector element.
  assert(!DstTy.isVector());

  const AMDGPUTargetMachine &TM
    = static_cast<const AMDGPUTargetMachine &>(MF.getTarget());

  const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
  if (ST.getTargetLowering()->isNoopAddrSpaceCast(SrcAS, DestAS)) {
    MI.setDesc(MIRBuilder.getTII().get(TargetOpcode::G_BITCAST));
    return true;
  }

  if (SrcAS == AMDGPUAS::FLAT_ADDRESS) {
    assert(DestAS == AMDGPUAS::LOCAL_ADDRESS ||
           DestAS == AMDGPUAS::PRIVATE_ADDRESS);
    unsigned NullVal = TM.getNullPointerValue(DestAS);

    auto SegmentNull = MIRBuilder.buildConstant(DstTy, NullVal);
    auto FlatNull = MIRBuilder.buildConstant(SrcTy, 0);

    unsigned PtrLo32 = MRI.createGenericVirtualRegister(DstTy);

    // Extract low 32-bits of the pointer.
    MIRBuilder.buildExtract(PtrLo32, Src, 0);

    unsigned CmpRes = MRI.createGenericVirtualRegister(LLT::scalar(1));
    MIRBuilder.buildICmp(CmpInst::ICMP_NE, CmpRes, Src, FlatNull.getReg(0));
    MIRBuilder.buildSelect(Dst, CmpRes, PtrLo32, SegmentNull.getReg(0));

    MI.eraseFromParent();
    return true;
  }

  assert(SrcAS == AMDGPUAS::LOCAL_ADDRESS ||
         SrcAS == AMDGPUAS::PRIVATE_ADDRESS);

  auto SegmentNull =
      MIRBuilder.buildConstant(SrcTy, TM.getNullPointerValue(SrcAS));
  auto FlatNull =
      MIRBuilder.buildConstant(DstTy, TM.getNullPointerValue(DestAS));

  unsigned ApertureReg = getSegmentAperture(DestAS, MRI, MIRBuilder);

  unsigned CmpRes = MRI.createGenericVirtualRegister(LLT::scalar(1));
  MIRBuilder.buildICmp(CmpInst::ICMP_NE, CmpRes, Src, SegmentNull.getReg(0));

  unsigned BuildPtr = MRI.createGenericVirtualRegister(DstTy);

  // Coerce the type of the low half of the result so we can use merge_values.
  unsigned SrcAsInt = MRI.createGenericVirtualRegister(LLT::scalar(32));
  MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
    .addDef(SrcAsInt)
    .addUse(Src);

  // TODO: Should we allow mismatched types but matching sizes in merges to
  // avoid the ptrtoint?
  MIRBuilder.buildMerge(BuildPtr, {SrcAsInt, ApertureReg});
  MIRBuilder.buildSelect(Dst, CmpRes, BuildPtr, FlatNull.getReg(0));

  MI.eraseFromParent();
  return true;
}

bool AMDGPULegalizerInfo::legalizeFrint(
  MachineInstr &MI, MachineRegisterInfo &MRI,
  MachineIRBuilder &MIRBuilder) const {
  MIRBuilder.setInstr(MI);

  unsigned Src = MI.getOperand(1).getReg();
  LLT Ty = MRI.getType(Src);
  assert(Ty.isScalar() && Ty.getSizeInBits() == 64);

  APFloat C1Val(APFloat::IEEEdouble(), "0x1.0p+52");
  APFloat C2Val(APFloat::IEEEdouble(), "0x1.fffffffffffffp+51");
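  // C1 is 2^52; adding and then subtracting it (with the source's sign)
  // forces the fraction bits to round away, leaving a whole number. C2 is
  // 2^52 - 0.5, the largest magnitude that is not already an integer;
  // anything bigger is returned unchanged by the select below.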

  auto C1 = MIRBuilder.buildFConstant(Ty, C1Val);
  auto CopySign = MIRBuilder.buildFCopysign(Ty, C1, Src);

  // TODO: Should this propagate fast-math-flags?
  auto Tmp1 = MIRBuilder.buildFAdd(Ty, Src, CopySign);
  auto Tmp2 = MIRBuilder.buildFSub(Ty, Tmp1, CopySign);

  auto C2 = MIRBuilder.buildFConstant(Ty, C2Val);
  auto Fabs = MIRBuilder.buildFAbs(Ty, Src);

  auto Cond = MIRBuilder.buildFCmp(CmpInst::FCMP_OGT, LLT::scalar(1), Fabs, C2);
  MIRBuilder.buildSelect(MI.getOperand(0).getReg(), Cond, Src, Tmp2);
  MI.eraseFromParent();
  return true;
}

bool AMDGPULegalizerInfo::legalizeFceil(
  MachineInstr &MI, MachineRegisterInfo &MRI,
  MachineIRBuilder &B) const {
  B.setInstr(MI);

  const LLT S1 = LLT::scalar(1);
  const LLT S64 = LLT::scalar(64);

  unsigned Src = MI.getOperand(1).getReg();
  assert(MRI.getType(Src) == S64);

  // result = trunc(src)
  // if (src > 0.0 && src != result)
  //   result += 1.0

  auto Trunc = B.buildInstr(TargetOpcode::G_INTRINSIC_TRUNC, {S64}, {Src});

  const auto Zero = B.buildFConstant(S64, 0.0);
  const auto One = B.buildFConstant(S64, 1.0);
  auto Lt0 = B.buildFCmp(CmpInst::FCMP_OGT, S1, Src, Zero);
  auto NeTrunc = B.buildFCmp(CmpInst::FCMP_ONE, S1, Src, Trunc);
  auto And = B.buildAnd(S1, Lt0, NeTrunc);
  auto Add = B.buildSelect(S64, And, One, Zero);

  // TODO: Should this propagate fast-math-flags?
  B.buildFAdd(MI.getOperand(0).getReg(), Trunc, Add);
  MI.eraseFromParent();
  return true;
}

static MachineInstrBuilder extractF64Exponent(unsigned Hi,
                                              MachineIRBuilder &B) {
  const unsigned FractBits = 52;
  const unsigned ExpBits = 11;
  LLT S32 = LLT::scalar(32);

  auto Const0 = B.buildConstant(S32, FractBits - 32);
  auto Const1 = B.buildConstant(S32, ExpBits);

  // llvm.amdgcn.ubfe(src, offset, width): pull the 11 exponent bits out of
  // the high dword.
  auto ExpPart = B.buildIntrinsic(Intrinsic::amdgcn_ubfe, {S32}, false)
    .addUse(Hi)
    .addUse(Const0.getReg(0))
    .addUse(Const1.getReg(0));

  return B.buildSub(S32, ExpPart, B.buildConstant(S32, 1023));
}

bool AMDGPULegalizerInfo::legalizeIntrinsicTrunc(
  MachineInstr &MI, MachineRegisterInfo &MRI,
  MachineIRBuilder &B) const {
  B.setInstr(MI);

  const LLT S1 = LLT::scalar(1);
  const LLT S32 = LLT::scalar(32);
  const LLT S64 = LLT::scalar(64);

  unsigned Src = MI.getOperand(1).getReg();
  assert(MRI.getType(Src) == S64);

  // TODO: Should this use extract since the low half is unused?
  auto Unmerge = B.buildUnmerge({S32, S32}, Src);
  unsigned Hi = Unmerge.getReg(1);

  // Extract the upper half, since this is where we will find the sign and
  // exponent.
  auto Exp = extractF64Exponent(Hi, B);

  const unsigned FractBits = 52;

  // Extract the sign bit.
  const auto SignBitMask = B.buildConstant(S32, UINT32_C(1) << 31);
  auto SignBit = B.buildAnd(S32, Hi, SignBitMask);

  const auto FractMask = B.buildConstant(S64, (UINT64_C(1) << FractBits) - 1);

  const auto Zero32 = B.buildConstant(S32, 0);

  // Extend back to 64-bits.
  auto SignBit64 = B.buildMerge(S64, {Zero32.getReg(0), SignBit.getReg(0)});

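  // With unbiased exponent E, the top E bits of the 52-bit fraction are
  // integer bits; shifting FractMask right arithmetically by E leaves ones
  // over exactly the fractional bits, which Tmp0 then clears to truncate
  // toward zero.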
  auto Shr = B.buildAShr(S64, FractMask, Exp);
  auto Not = B.buildNot(S64, Shr);
  auto Tmp0 = B.buildAnd(S64, Src, Not);
  auto FiftyOne = B.buildConstant(S32, FractBits - 1);

  auto ExpLt0 = B.buildICmp(CmpInst::ICMP_SLT, S1, Exp, Zero32);
  auto ExpGt51 = B.buildICmp(CmpInst::ICMP_SGT, S1, Exp, FiftyOne);

  auto Tmp1 = B.buildSelect(S64, ExpLt0, SignBit64, Tmp0);
  B.buildSelect(MI.getOperand(0).getReg(), ExpGt51, Src, Tmp1);
  MI.eraseFromParent();
  return true;
}

bool AMDGPULegalizerInfo::legalizeITOFP(
  MachineInstr &MI, MachineRegisterInfo &MRI,
  MachineIRBuilder &B, bool Signed) const {
  B.setInstr(MI);

  unsigned Dst = MI.getOperand(0).getReg();
  unsigned Src = MI.getOperand(1).getReg();

  const LLT S64 = LLT::scalar(64);
  const LLT S32 = LLT::scalar(32);

  assert(MRI.getType(Src) == S64 && MRI.getType(Dst) == S64);

  auto Unmerge = B.buildUnmerge({S32, S32}, Src);

  auto CvtHi = Signed ?
    B.buildSITOFP(S64, Unmerge.getReg(1)) :
    B.buildUITOFP(S64, Unmerge.getReg(1));

  auto CvtLo = B.buildUITOFP(S64, Unmerge.getReg(0));

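  // Result = (fp)hi * 2^32 + (fp)lo; the scale by 2^32 is applied with
  // llvm.amdgcn.ldexp on the converted high half.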
  auto ThirtyTwo = B.buildConstant(S32, 32);
  auto LdExp = B.buildIntrinsic(Intrinsic::amdgcn_ldexp, {S64}, false)
    .addUse(CvtHi.getReg(0))
    .addUse(ThirtyTwo.getReg(0));

  // TODO: Should this propagate fast-math-flags?
  B.buildFAdd(Dst, LdExp, CvtLo);
  MI.eraseFromParent();
  return true;
}