//===- AMDGPULegalizerInfo.cpp -----------------------------------*- C++ -*-==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
/// This file implements the targeting of the MachineLegalizer class for
/// AMDGPU.
/// \todo This should be generated by TableGen.
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "AMDGPULegalizerInfo.h"
#include "AMDGPUTargetMachine.h"
#include "SIMachineFunctionInfo.h"

#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Debug.h"

using namespace llvm;
using namespace LegalizeActions;
using namespace LegalizeMutations;
using namespace LegalityPredicates;


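// Return true if the type's total size is at most MaxSize bits and its scalar
// element size is a multiple of 32.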
static LegalityPredicate isMultiple32(unsigned TypeIdx,
                                      unsigned MaxSize = 512) {
  return [=](const LegalityQuery &Query) {
    const LLT Ty = Query.Types[TypeIdx];
    const LLT EltTy = Ty.getScalarType();
    return Ty.getSizeInBits() <= MaxSize && EltTy.getSizeInBits() % 32 == 0;
  };
}

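// Return true for vectors with an odd number of sub-32-bit elements; these are
// normally rounded up to an even element count with oneMoreElement below.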
static LegalityPredicate isSmallOddVector(unsigned TypeIdx) {
  return [=](const LegalityQuery &Query) {
    const LLT Ty = Query.Types[TypeIdx];
    return Ty.isVector() &&
           Ty.getNumElements() % 2 != 0 &&
           Ty.getElementType().getSizeInBits() < 32;
  };
}

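// Mutation that pads a vector type with one extra element of the same element
// type, used to turn an odd element count into an even one.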
static LegalizeMutation oneMoreElement(unsigned TypeIdx) {
  return [=](const LegalityQuery &Query) {
    const LLT Ty = Query.Types[TypeIdx];
    const LLT EltTy = Ty.getElementType();
    return std::make_pair(TypeIdx, LLT::vector(Ty.getNumElements() + 1, EltTy));
  };
}

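// Mutation that reduces a wide vector's element count so it splits into pieces
// of at most 64 bits each; scalarOrVector covers the single-element result.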
static LegalizeMutation fewerEltsToSize64Vector(unsigned TypeIdx) {
  return [=](const LegalityQuery &Query) {
    const LLT Ty = Query.Types[TypeIdx];
    const LLT EltTy = Ty.getElementType();
    unsigned Size = Ty.getSizeInBits();
    unsigned Pieces = (Size + 63) / 64;
    unsigned NewNumElts = (Ty.getNumElements() + 1) / Pieces;
    return std::make_pair(TypeIdx, LLT::scalarOrVector(NewNumElts, EltTy));
  };
}

static LegalityPredicate vectorWiderThan(unsigned TypeIdx, unsigned Size) {
  return [=](const LegalityQuery &Query) {
    const LLT QueryTy = Query.Types[TypeIdx];
    return QueryTy.isVector() && QueryTy.getSizeInBits() > Size;
  };
}

static LegalityPredicate numElementsNotEven(unsigned TypeIdx) {
  return [=](const LegalityQuery &Query) {
    const LLT QueryTy = Query.Types[TypeIdx];
    return QueryTy.isVector() && QueryTy.getNumElements() % 2 != 0;
  };
}

AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST,
                                         const GCNTargetMachine &TM) {
  using namespace TargetOpcode;

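  // Build an LLT pointer type for the given address space, using the pointer
  // width the target machine reports for that space.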
  auto GetAddrSpacePtr = [&TM](unsigned AS) {
    return LLT::pointer(AS, TM.getPointerSizeInBits(AS));
  };

  const LLT S1 = LLT::scalar(1);
  const LLT S8 = LLT::scalar(8);
  const LLT S16 = LLT::scalar(16);
  const LLT S32 = LLT::scalar(32);
  const LLT S64 = LLT::scalar(64);
  const LLT S128 = LLT::scalar(128);
  const LLT S256 = LLT::scalar(256);
  const LLT S512 = LLT::scalar(512);

  const LLT V2S16 = LLT::vector(2, 16);
  const LLT V4S16 = LLT::vector(4, 16);
  const LLT V8S16 = LLT::vector(8, 16);

  const LLT V2S32 = LLT::vector(2, 32);
  const LLT V3S32 = LLT::vector(3, 32);
  const LLT V4S32 = LLT::vector(4, 32);
  const LLT V5S32 = LLT::vector(5, 32);
  const LLT V6S32 = LLT::vector(6, 32);
  const LLT V7S32 = LLT::vector(7, 32);
  const LLT V8S32 = LLT::vector(8, 32);
  const LLT V9S32 = LLT::vector(9, 32);
  const LLT V10S32 = LLT::vector(10, 32);
  const LLT V11S32 = LLT::vector(11, 32);
  const LLT V12S32 = LLT::vector(12, 32);
  const LLT V13S32 = LLT::vector(13, 32);
  const LLT V14S32 = LLT::vector(14, 32);
  const LLT V15S32 = LLT::vector(15, 32);
  const LLT V16S32 = LLT::vector(16, 32);

  const LLT V2S64 = LLT::vector(2, 64);
  const LLT V3S64 = LLT::vector(3, 64);
  const LLT V4S64 = LLT::vector(4, 64);
  const LLT V5S64 = LLT::vector(5, 64);
  const LLT V6S64 = LLT::vector(6, 64);
  const LLT V7S64 = LLT::vector(7, 64);
  const LLT V8S64 = LLT::vector(8, 64);

  std::initializer_list<LLT> AllS32Vectors =
    {V2S32, V3S32, V4S32, V5S32, V6S32, V7S32, V8S32,
     V9S32, V10S32, V11S32, V12S32, V13S32, V14S32, V15S32, V16S32};
  std::initializer_list<LLT> AllS64Vectors =
    {V2S64, V3S64, V4S64, V5S64, V6S64, V7S64, V8S64};

  const LLT GlobalPtr = GetAddrSpacePtr(AMDGPUAS::GLOBAL_ADDRESS);
  const LLT ConstantPtr = GetAddrSpacePtr(AMDGPUAS::CONSTANT_ADDRESS);
  const LLT LocalPtr = GetAddrSpacePtr(AMDGPUAS::LOCAL_ADDRESS);
  const LLT FlatPtr = GetAddrSpacePtr(AMDGPUAS::FLAT_ADDRESS);
  const LLT PrivatePtr = GetAddrSpacePtr(AMDGPUAS::PRIVATE_ADDRESS);

  const LLT CodePtr = FlatPtr;

  const std::initializer_list<LLT> AddrSpaces64 = {
    GlobalPtr, ConstantPtr, FlatPtr
  };

  const std::initializer_list<LLT> AddrSpaces32 = {
    LocalPtr, PrivatePtr
  };

  setAction({G_BRCOND, S1}, Legal);

  // TODO: All multiples of 32, vectors of pointers, all v2s16 pairs, more
  // elements for v3s16
  getActionDefinitionsBuilder(G_PHI)
    .legalFor({S32, S64, V2S16, V4S16, S1, S128, S256})
    .legalFor(AllS32Vectors)
    .legalFor(AllS64Vectors)
    .legalFor(AddrSpaces64)
    .legalFor(AddrSpaces32)
    .clampScalar(0, S32, S256)
    .widenScalarToNextPow2(0, 32)
    .clampMaxNumElements(0, S32, 16)
    .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
    .legalIf(isPointer(0));


  getActionDefinitionsBuilder({G_ADD, G_SUB, G_MUL, G_UMULH, G_SMULH})
    .legalFor({S32})
    .clampScalar(0, S32, S32)
    .scalarize(0);

  // Report legal for any types we can handle anywhere. For the cases only legal
  // on the SALU, RegBankSelect will be able to re-legalize.
  getActionDefinitionsBuilder({G_AND, G_OR, G_XOR})
    .legalFor({S32, S1, S64, V2S32, V2S16, V4S16})
    .clampScalar(0, S32, S64)
    .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
    .fewerElementsIf(vectorWiderThan(0, 32), fewerEltsToSize64Vector(0))
    .widenScalarToNextPow2(0)
    .scalarize(0);

  getActionDefinitionsBuilder({G_UADDO, G_SADDO, G_USUBO, G_SSUBO,
                               G_UADDE, G_SADDE, G_USUBE, G_SSUBE})
    .legalFor({{S32, S1}})
    .clampScalar(0, S32, S32);

  getActionDefinitionsBuilder(G_BITCAST)
    .legalForCartesianProduct({S32, V2S16})
    .legalForCartesianProduct({S64, V2S32, V4S16})
    .legalForCartesianProduct({V2S64, V4S32})
    // Don't worry about the size constraint.
    .legalIf(all(isPointer(0), isPointer(1)));

  if (ST.has16BitInsts()) {
    getActionDefinitionsBuilder(G_FCONSTANT)
      .legalFor({S32, S64, S16})
      .clampScalar(0, S16, S64);
  } else {
    getActionDefinitionsBuilder(G_FCONSTANT)
      .legalFor({S32, S64})
      .clampScalar(0, S32, S64);
  }

  getActionDefinitionsBuilder(G_IMPLICIT_DEF)
    .legalFor({S1, S32, S64, V2S32, V4S32, V2S16, V4S16, GlobalPtr,
               ConstantPtr, LocalPtr, FlatPtr, PrivatePtr})
    .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
    .clampScalarOrElt(0, S32, S512)
    .legalIf(isMultiple32(0))
    .widenScalarToNextPow2(0, 32)
    .clampMaxNumElements(0, S32, 16);


  // FIXME: i1 operands to intrinsics should always be legal, but other i1
  // values may not be legal.  We need to figure out how to distinguish
  // between these two scenarios.
  getActionDefinitionsBuilder(G_CONSTANT)
    .legalFor({S1, S32, S64, GlobalPtr,
               LocalPtr, ConstantPtr, PrivatePtr, FlatPtr })
    .clampScalar(0, S32, S64)
    .widenScalarToNextPow2(0)
    .legalIf(isPointer(0));

  setAction({G_FRAME_INDEX, PrivatePtr}, Legal);

  auto &FPOpActions = getActionDefinitionsBuilder(
    { G_FADD, G_FMUL, G_FNEG, G_FABS, G_FMA, G_FCANONICALIZE})
    .legalFor({S32, S64});

  if (ST.has16BitInsts()) {
    if (ST.hasVOP3PInsts())
      FPOpActions.legalFor({S16, V2S16});
    else
      FPOpActions.legalFor({S16});
  }

  if (ST.hasVOP3PInsts())
    FPOpActions.clampMaxNumElements(0, S16, 2);
  FPOpActions
    .scalarize(0)
    .clampScalar(0, ST.has16BitInsts() ? S16 : S32, S64);

  if (ST.has16BitInsts()) {
    getActionDefinitionsBuilder(G_FSQRT)
      .legalFor({S32, S64, S16})
      .scalarize(0)
      .clampScalar(0, S16, S64);
  } else {
    getActionDefinitionsBuilder(G_FSQRT)
      .legalFor({S32, S64})
      .scalarize(0)
      .clampScalar(0, S32, S64);
  }

  getActionDefinitionsBuilder(G_FPTRUNC)
    .legalFor({{S32, S64}, {S16, S32}})
    .scalarize(0);

  getActionDefinitionsBuilder(G_FPEXT)
    .legalFor({{S64, S32}, {S32, S16}})
    .lowerFor({{S64, S16}}) // FIXME: Implement
    .scalarize(0);

  getActionDefinitionsBuilder(G_FSUB)
      // Use actual fsub instruction
      .legalFor({S32})
      // Must use fadd + fneg
      .lowerFor({S64, S16, V2S16})
      .scalarize(0)
      .clampScalar(0, S32, S64);

  getActionDefinitionsBuilder({G_SEXT, G_ZEXT, G_ANYEXT})
    .legalFor({{S64, S32}, {S32, S16}, {S64, S16},
               {S32, S1}, {S64, S1}, {S16, S1},
               // FIXME: Hack
               {S64, LLT::scalar(33)},
               {S32, S8}, {S128, S32}, {S128, S64}, {S32, LLT::scalar(24)}})
    .scalarize(0);

  getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
    .legalFor({{S32, S32}, {S64, S32}})
    .scalarize(0);

  getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
    .legalFor({{S32, S32}, {S32, S64}})
    .scalarize(0);

  getActionDefinitionsBuilder({G_INTRINSIC_TRUNC, G_INTRINSIC_ROUND})
    .legalFor({S32, S64})
    .scalarize(0);


  getActionDefinitionsBuilder(G_GEP)
    .legalForCartesianProduct(AddrSpaces64, {S64})
    .legalForCartesianProduct(AddrSpaces32, {S32})
    .scalarize(0);

  setAction({G_BLOCK_ADDR, CodePtr}, Legal);

  getActionDefinitionsBuilder(G_ICMP)
    .legalForCartesianProduct(
      {S1}, {S32, S64, GlobalPtr, LocalPtr, ConstantPtr, PrivatePtr, FlatPtr})
    .legalFor({{S1, S32}, {S1, S64}})
    .widenScalarToNextPow2(1)
    .clampScalar(1, S32, S64)
    .scalarize(0)
    .legalIf(all(typeIs(0, S1), isPointer(1)));

  getActionDefinitionsBuilder(G_FCMP)
    .legalFor({{S1, S32}, {S1, S64}})
    .widenScalarToNextPow2(1)
    .clampScalar(1, S32, S64)
    .scalarize(0);

  // FIXME: fexp, flog2, flog10 need to be custom lowered.
  getActionDefinitionsBuilder({G_FPOW, G_FEXP, G_FEXP2,
                               G_FLOG, G_FLOG2, G_FLOG10})
    .legalFor({S32})
    .scalarize(0);

  // The 64-bit versions produce 32-bit results, but only on the SALU.
  getActionDefinitionsBuilder({G_CTLZ, G_CTLZ_ZERO_UNDEF,
                               G_CTTZ, G_CTTZ_ZERO_UNDEF,
                               G_CTPOP})
    .legalFor({{S32, S32}, {S32, S64}})
    .clampScalar(0, S32, S32)
    .clampScalar(1, S32, S64)
    .scalarize(0)
    .widenScalarToNextPow2(0, 32)
    .widenScalarToNextPow2(1, 32);

  // TODO: Expand for > s32
  getActionDefinitionsBuilder(G_BSWAP)
    .legalFor({S32})
    .clampScalar(0, S32, S32)
    .scalarize(0);


  auto smallerThan = [](unsigned TypeIdx0, unsigned TypeIdx1) {
    return [=](const LegalityQuery &Query) {
      return Query.Types[TypeIdx0].getSizeInBits() <
             Query.Types[TypeIdx1].getSizeInBits();
    };
  };

  auto greaterThan = [](unsigned TypeIdx0, unsigned TypeIdx1) {
    return [=](const LegalityQuery &Query) {
      return Query.Types[TypeIdx0].getSizeInBits() >
             Query.Types[TypeIdx1].getSizeInBits();
    };
  };

  getActionDefinitionsBuilder(G_INTTOPTR)
    // List the common cases
    .legalForCartesianProduct(AddrSpaces64, {S64})
    .legalForCartesianProduct(AddrSpaces32, {S32})
    .scalarize(0)
    // Accept any address space as long as the size matches
    .legalIf(sameSize(0, 1))
    .widenScalarIf(smallerThan(1, 0),
      [](const LegalityQuery &Query) {
        return std::make_pair(1, LLT::scalar(Query.Types[0].getSizeInBits()));
      })
    .narrowScalarIf(greaterThan(1, 0),
      [](const LegalityQuery &Query) {
        return std::make_pair(1, LLT::scalar(Query.Types[0].getSizeInBits()));
      });

  getActionDefinitionsBuilder(G_PTRTOINT)
    // List the common cases
    .legalForCartesianProduct(AddrSpaces64, {S64})
    .legalForCartesianProduct(AddrSpaces32, {S32})
    .scalarize(0)
    // Accept any address space as long as the size matches
    .legalIf(sameSize(0, 1))
    .widenScalarIf(smallerThan(0, 1),
      [](const LegalityQuery &Query) {
        return std::make_pair(0, LLT::scalar(Query.Types[1].getSizeInBits()));
      })
    .narrowScalarIf(
      greaterThan(0, 1),
      [](const LegalityQuery &Query) {
        return std::make_pair(0, LLT::scalar(Query.Types[1].getSizeInBits()));
      });

  if (ST.hasFlatAddressSpace()) {
    getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
      .scalarize(0)
      .custom();
  }

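  // Loads and stores. Scalars wider than their in-memory size are narrowed to
  // 32 bits, and 96-bit vector accesses are broken up on subtargets before
  // SEA_ISLANDS, which lack the x3 variants of the memory instructions.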
  getActionDefinitionsBuilder({G_LOAD, G_STORE})
    .narrowScalarIf([](const LegalityQuery &Query) {
        unsigned Size = Query.Types[0].getSizeInBits();
        unsigned MemSize = Query.MMODescrs[0].SizeInBits;
        return (Size > 32 && MemSize < Size);
      },
      [](const LegalityQuery &Query) {
        return std::make_pair(0, LLT::scalar(32));
      })
    .fewerElementsIf([=, &ST](const LegalityQuery &Query) {
        unsigned MemSize = Query.MMODescrs[0].SizeInBits;
        return (MemSize == 96) &&
               Query.Types[0].isVector() &&
               ST.getGeneration() < AMDGPUSubtarget::SEA_ISLANDS;
      },
      [=](const LegalityQuery &Query) {
        return std::make_pair(0, V2S32);
      })
    .legalIf([=, &ST](const LegalityQuery &Query) {
        const LLT &Ty0 = Query.Types[0];

        unsigned Size = Ty0.getSizeInBits();
        unsigned MemSize = Query.MMODescrs[0].SizeInBits;
        if (Size < 32 || (Size > 32 && MemSize < Size))
          return false;

        if (Ty0.isVector() && Size != MemSize)
          return false;

        // TODO: Decompose private loads into 4-byte components.
        // TODO: Illegal flat loads on SI
        switch (MemSize) {
        case 8:
        case 16:
          return Size == 32;
        case 32:
        case 64:
        case 128:
          return true;

        case 96:
          // XXX hasLoadX3
          return (ST.getGeneration() >= AMDGPUSubtarget::SEA_ISLANDS);

        case 256:
        case 512:
          // TODO: constant loads
        default:
          return false;
        }
      })
    .clampScalar(0, S32, S64);


  // FIXME: Handle alignment requirements.
  auto &ExtLoads = getActionDefinitionsBuilder({G_SEXTLOAD, G_ZEXTLOAD})
    .legalForTypesWithMemDesc({
        {S32, GlobalPtr, 8, 8},
        {S32, GlobalPtr, 16, 8},
        {S32, LocalPtr, 8, 8},
        {S32, LocalPtr, 16, 8},
        {S32, PrivatePtr, 8, 8},
        {S32, PrivatePtr, 16, 8}});
  if (ST.hasFlatAddressSpace()) {
    ExtLoads.legalForTypesWithMemDesc({{S32, FlatPtr, 8, 8},
                                       {S32, FlatPtr, 16, 8}});
  }

  ExtLoads.clampScalar(0, S32, S32)
          .widenScalarToNextPow2(0)
          .unsupportedIfMemSizeNotPow2()
          .lower();

  auto &Atomics = getActionDefinitionsBuilder(
    {G_ATOMICRMW_XCHG, G_ATOMICRMW_ADD, G_ATOMICRMW_SUB,
     G_ATOMICRMW_AND, G_ATOMICRMW_OR, G_ATOMICRMW_XOR,
     G_ATOMICRMW_MAX, G_ATOMICRMW_MIN, G_ATOMICRMW_UMAX,
     G_ATOMICRMW_UMIN, G_ATOMIC_CMPXCHG})
    .legalFor({{S32, GlobalPtr}, {S32, LocalPtr},
               {S64, GlobalPtr}, {S64, LocalPtr}});
  if (ST.hasFlatAddressSpace()) {
    Atomics.legalFor({{S32, FlatPtr}, {S64, FlatPtr}});
  }

  // TODO: Pointer types, any 32-bit or 64-bit vector
  getActionDefinitionsBuilder(G_SELECT)
    .legalForCartesianProduct({S32, S64, V2S32, V2S16, V4S16,
          GlobalPtr, LocalPtr, FlatPtr, PrivatePtr,
          LLT::vector(2, LocalPtr), LLT::vector(2, PrivatePtr)}, {S1})
    .clampScalar(0, S32, S64)
    .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
    .fewerElementsIf(numElementsNotEven(0), scalarize(0))
    .scalarize(1)
    .clampMaxNumElements(0, S32, 2)
    .clampMaxNumElements(0, LocalPtr, 2)
    .clampMaxNumElements(0, PrivatePtr, 2)
    .scalarize(0)
    .widenScalarToNextPow2(0)
    .legalIf(all(isPointer(0), typeIs(1, S1)));

  // TODO: Only the low 4/5/6 bits of the shift amount are observed, so we can
  // be more flexible with the shift amount type.
  auto &Shifts = getActionDefinitionsBuilder({G_SHL, G_LSHR, G_ASHR})
    .legalFor({{S32, S32}, {S64, S32}});
  if (ST.has16BitInsts()) {
    if (ST.hasVOP3PInsts()) {
      Shifts.legalFor({{S16, S32}, {S16, S16}, {V2S16, V2S16}})
            .clampMaxNumElements(0, S16, 2);
    } else
      Shifts.legalFor({{S16, S32}, {S16, S16}});

    Shifts.clampScalar(1, S16, S32);
    Shifts.clampScalar(0, S16, S64);
    Shifts.widenScalarToNextPow2(0, 16);
  } else {
    // Make sure we legalize the shift amount type first, as the general
    // expansion for the shifted type will produce much worse code if it hasn't
    // been truncated already.
    Shifts.clampScalar(1, S32, S32);
    Shifts.clampScalar(0, S32, S64);
    Shifts.widenScalarToNextPow2(0, 32);
  }
  Shifts.scalarize(0);

  for (unsigned Op : {G_EXTRACT_VECTOR_ELT, G_INSERT_VECTOR_ELT}) {
    unsigned VecTypeIdx = Op == G_EXTRACT_VECTOR_ELT ? 1 : 0;
    unsigned EltTypeIdx = Op == G_EXTRACT_VECTOR_ELT ? 0 : 1;
    unsigned IdxTypeIdx = 2;

    getActionDefinitionsBuilder(Op)
      .legalIf([=](const LegalityQuery &Query) {
          const LLT &VecTy = Query.Types[VecTypeIdx];
          const LLT &IdxTy = Query.Types[IdxTypeIdx];
          return VecTy.getSizeInBits() % 32 == 0 &&
            VecTy.getSizeInBits() <= 512 &&
            IdxTy.getSizeInBits() == 32;
        })
      .clampScalar(EltTypeIdx, S32, S64)
      .clampScalar(VecTypeIdx, S32, S64)
      .clampScalar(IdxTypeIdx, S32, S32);
  }

  getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT)
    .unsupportedIf([=](const LegalityQuery &Query) {
        const LLT &EltTy = Query.Types[1].getElementType();
        return Query.Types[0] != EltTy;
      });

  for (unsigned Op : {G_EXTRACT, G_INSERT}) {
    unsigned BigTyIdx = Op == G_EXTRACT ? 1 : 0;
    unsigned LitTyIdx = Op == G_EXTRACT ? 0 : 1;

    // FIXME: Doesn't handle extract of illegal sizes.
    getActionDefinitionsBuilder(Op)
      .legalIf([=](const LegalityQuery &Query) {
          const LLT BigTy = Query.Types[BigTyIdx];
          const LLT LitTy = Query.Types[LitTyIdx];
          return (BigTy.getSizeInBits() % 32 == 0) &&
                 (LitTy.getSizeInBits() % 16 == 0);
        })
      .widenScalarIf(
        [=](const LegalityQuery &Query) {
          const LLT BigTy = Query.Types[BigTyIdx];
          return (BigTy.getScalarSizeInBits() < 16);
        },
        LegalizeMutations::widenScalarOrEltToNextPow2(BigTyIdx, 16))
      .widenScalarIf(
        [=](const LegalityQuery &Query) {
          const LLT LitTy = Query.Types[LitTyIdx];
          return (LitTy.getScalarSizeInBits() < 16);
        },
        LegalizeMutations::widenScalarOrEltToNextPow2(LitTyIdx, 16))
      .moreElementsIf(isSmallOddVector(BigTyIdx), oneMoreElement(BigTyIdx))
      .widenScalarToNextPow2(BigTyIdx, 32);

  }

  // TODO: vectors of pointers
  getActionDefinitionsBuilder(G_BUILD_VECTOR)
      .legalForCartesianProduct(AllS32Vectors, {S32})
      .legalForCartesianProduct(AllS64Vectors, {S64})
      .clampNumElements(0, V16S32, V16S32)
      .clampNumElements(0, V2S64, V8S64)
      .minScalarSameAs(1, 0)
      // FIXME: Sort of a hack to make progress on other legalizations.
      .legalIf([=](const LegalityQuery &Query) {
        return Query.Types[0].getScalarSizeInBits() <= 32 ||
               Query.Types[0].getScalarSizeInBits() == 64;
      });

  // TODO: Support any combination of v2s32
  getActionDefinitionsBuilder(G_CONCAT_VECTORS)
    .legalFor({{V4S32, V2S32},
               {V8S32, V2S32},
               {V8S32, V4S32},
               {V4S64, V2S64},
               {V4S16, V2S16},
               {V8S16, V2S16},
               {V8S16, V4S16},
               {LLT::vector(4, LocalPtr), LLT::vector(2, LocalPtr)},
               {LLT::vector(4, PrivatePtr), LLT::vector(2, PrivatePtr)}});

  // Merge/Unmerge
  for (unsigned Op : {G_MERGE_VALUES, G_UNMERGE_VALUES}) {
    unsigned BigTyIdx = Op == G_MERGE_VALUES ? 0 : 1;
    unsigned LitTyIdx = Op == G_MERGE_VALUES ? 1 : 0;

    auto notValidElt = [=](const LegalityQuery &Query, unsigned TypeIdx) {
      const LLT &Ty = Query.Types[TypeIdx];
      if (Ty.isVector()) {
        const LLT &EltTy = Ty.getElementType();
        if (EltTy.getSizeInBits() < 8 || EltTy.getSizeInBits() > 64)
          return true;
        if (!isPowerOf2_32(EltTy.getSizeInBits()))
          return true;
      }
      return false;
    };

    getActionDefinitionsBuilder(Op)
      .widenScalarToNextPow2(LitTyIdx, /*Min*/ 16)
      // Clamp the little scalar to s16-s256 and make it a power of 2. It's not
      // worth considering the multiples of 64 since 2*192 and 2*384 are not
      // valid.
      .clampScalar(LitTyIdx, S16, S256)
      .widenScalarToNextPow2(LitTyIdx, /*Min*/ 32)

      // Break up vectors with weird elements into scalars
      .fewerElementsIf(
        [=](const LegalityQuery &Query) { return notValidElt(Query, 0); },
        scalarize(0))
      .fewerElementsIf(
        [=](const LegalityQuery &Query) { return notValidElt(Query, 1); },
        scalarize(1))
      .clampScalar(BigTyIdx, S32, S512)
      .widenScalarIf(
        [=](const LegalityQuery &Query) {
          const LLT &Ty = Query.Types[BigTyIdx];
          return !isPowerOf2_32(Ty.getSizeInBits()) &&
                 Ty.getSizeInBits() % 16 != 0;
        },
        [=](const LegalityQuery &Query) {
          // Pick the next power of 2, or a multiple of 64 over 128,
          // whichever is smaller.
          const LLT &Ty = Query.Types[BigTyIdx];
          unsigned NewSizeInBits = 1 << Log2_32_Ceil(Ty.getSizeInBits() + 1);
          if (NewSizeInBits >= 256) {
            unsigned RoundedTo = alignTo<64>(Ty.getSizeInBits() + 1);
            if (RoundedTo < NewSizeInBits)
              NewSizeInBits = RoundedTo;
          }
          return std::make_pair(BigTyIdx, LLT::scalar(NewSizeInBits));
        })
      .legalIf([=](const LegalityQuery &Query) {
          const LLT &BigTy = Query.Types[BigTyIdx];
          const LLT &LitTy = Query.Types[LitTyIdx];

          if (BigTy.isVector() && BigTy.getSizeInBits() < 32)
            return false;
          if (LitTy.isVector() && LitTy.getSizeInBits() < 32)
            return false;

          return BigTy.getSizeInBits() % 16 == 0 &&
                 LitTy.getSizeInBits() % 16 == 0 &&
                 BigTy.getSizeInBits() <= 512;
        })
      // Any vectors left are the wrong size. Scalarize them.
      .scalarize(0)
      .scalarize(1);
  }

  computeTables();
  verify(*ST.getInstrInfo());
}

bool AMDGPULegalizerInfo::legalizeCustom(MachineInstr &MI,
                                         MachineRegisterInfo &MRI,
                                         MachineIRBuilder &MIRBuilder,
                                         GISelChangeObserver &Observer) const {
  switch (MI.getOpcode()) {
  case TargetOpcode::G_ADDRSPACE_CAST:
    return legalizeAddrSpaceCast(MI, MRI, MIRBuilder);
  default:
    return false;
  }

  llvm_unreachable("expected switch to return");
}

unsigned AMDGPULegalizerInfo::getSegmentAperture(
  unsigned AS,
  MachineRegisterInfo &MRI,
  MachineIRBuilder &MIRBuilder) const {
  MachineFunction &MF = MIRBuilder.getMF();
  const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
  const LLT S32 = LLT::scalar(32);

  if (ST.hasApertureRegs()) {
    // FIXME: Use inline constants (src_{shared, private}_base) instead of
    // getreg.
    unsigned Offset = AS == AMDGPUAS::LOCAL_ADDRESS ?
        AMDGPU::Hwreg::OFFSET_SRC_SHARED_BASE :
        AMDGPU::Hwreg::OFFSET_SRC_PRIVATE_BASE;
    unsigned WidthM1 = AS == AMDGPUAS::LOCAL_ADDRESS ?
        AMDGPU::Hwreg::WIDTH_M1_SRC_SHARED_BASE :
        AMDGPU::Hwreg::WIDTH_M1_SRC_PRIVATE_BASE;
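    // Pack the s_getreg_b32 immediate: hardware register id, starting bit
    // offset, and field width minus one.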
    unsigned Encoding =
        AMDGPU::Hwreg::ID_MEM_BASES << AMDGPU::Hwreg::ID_SHIFT_ |
        Offset << AMDGPU::Hwreg::OFFSET_SHIFT_ |
        WidthM1 << AMDGPU::Hwreg::WIDTH_M1_SHIFT_;

    unsigned ApertureReg = MRI.createGenericVirtualRegister(S32);
    unsigned GetReg = MRI.createVirtualRegister(&AMDGPU::SReg_32RegClass);

    MIRBuilder.buildInstr(AMDGPU::S_GETREG_B32)
      .addDef(GetReg)
      .addImm(Encoding);
    MRI.setType(GetReg, S32);

    auto ShiftAmt = MIRBuilder.buildConstant(S32, WidthM1 + 1);
    MIRBuilder.buildInstr(TargetOpcode::G_SHL)
      .addDef(ApertureReg)
      .addUse(GetReg)
      .addUse(ShiftAmt.getReg(0));

    return ApertureReg;
  }

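  // No aperture registers: load the aperture from the queue pointer
  // (amd_queue_t) instead.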
  unsigned QueuePtr = MRI.createGenericVirtualRegister(
    LLT::pointer(AMDGPUAS::CONSTANT_ADDRESS, 64));

  // FIXME: Placeholder until we can track the input registers.
  MIRBuilder.buildConstant(QueuePtr, 0xdeadbeef);

  // Offset into amd_queue_t for group_segment_aperture_base_hi /
  // private_segment_aperture_base_hi.
  uint32_t StructOffset = (AS == AMDGPUAS::LOCAL_ADDRESS) ? 0x40 : 0x44;

  // FIXME: Don't use undef
  Value *V = UndefValue::get(PointerType::get(
                               Type::getInt8Ty(MF.getFunction().getContext()),
                               AMDGPUAS::CONSTANT_ADDRESS));

  MachinePointerInfo PtrInfo(V, StructOffset);
  MachineMemOperand *MMO = MF.getMachineMemOperand(
    PtrInfo,
    MachineMemOperand::MOLoad |
    MachineMemOperand::MODereferenceable |
    MachineMemOperand::MOInvariant,
    4,
    MinAlign(64, StructOffset));

  unsigned LoadResult = MRI.createGenericVirtualRegister(S32);
  unsigned LoadAddr = AMDGPU::NoRegister;

  MIRBuilder.materializeGEP(LoadAddr, QueuePtr, LLT::scalar(64), StructOffset);
  MIRBuilder.buildLoad(LoadResult, LoadAddr, *MMO);
  return LoadResult;
}

bool AMDGPULegalizerInfo::legalizeAddrSpaceCast(
  MachineInstr &MI, MachineRegisterInfo &MRI,
  MachineIRBuilder &MIRBuilder) const {
  MachineFunction &MF = MIRBuilder.getMF();

  MIRBuilder.setInstr(MI);

  unsigned Dst = MI.getOperand(0).getReg();
  unsigned Src = MI.getOperand(1).getReg();

  LLT DstTy = MRI.getType(Dst);
  LLT SrcTy = MRI.getType(Src);
  unsigned DestAS = DstTy.getAddressSpace();
  unsigned SrcAS = SrcTy.getAddressSpace();

  // TODO: Avoid reloading from the queue ptr for each cast, or at least each
  // vector element.
  assert(!DstTy.isVector());

  const AMDGPUTargetMachine &TM
    = static_cast<const AMDGPUTargetMachine &>(MF.getTarget());

  const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
  if (ST.getTargetLowering()->isNoopAddrSpaceCast(SrcAS, DestAS)) {
    MI.setDesc(MIRBuilder.getTII().get(TargetOpcode::G_BITCAST));
    return true;
  }

  if (SrcAS == AMDGPUAS::FLAT_ADDRESS) {
    assert(DestAS == AMDGPUAS::LOCAL_ADDRESS ||
           DestAS == AMDGPUAS::PRIVATE_ADDRESS);
    unsigned NullVal = TM.getNullPointerValue(DestAS);

    auto SegmentNull = MIRBuilder.buildConstant(DstTy, NullVal);
    auto FlatNull = MIRBuilder.buildConstant(SrcTy, 0);

    unsigned PtrLo32 = MRI.createGenericVirtualRegister(DstTy);

    // Extract low 32-bits of the pointer.
    MIRBuilder.buildExtract(PtrLo32, Src, 0);

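    // The flat null pointer must map to the segment null value, so select
    // between the low 32 bits of the pointer and the segment null constant.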
    unsigned CmpRes = MRI.createGenericVirtualRegister(LLT::scalar(1));
    MIRBuilder.buildICmp(CmpInst::ICMP_NE, CmpRes, Src, FlatNull.getReg(0));
    MIRBuilder.buildSelect(Dst, CmpRes, PtrLo32, SegmentNull.getReg(0));

    MI.eraseFromParent();
    return true;
  }

  assert(SrcAS == AMDGPUAS::LOCAL_ADDRESS ||
         SrcAS == AMDGPUAS::PRIVATE_ADDRESS);

  auto SegmentNull =
      MIRBuilder.buildConstant(SrcTy, TM.getNullPointerValue(SrcAS));
  auto FlatNull =
      MIRBuilder.buildConstant(DstTy, TM.getNullPointerValue(DestAS));

  unsigned ApertureReg = getSegmentAperture(DestAS, MRI, MIRBuilder);

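  // The segment null value must map to the flat null pointer. Otherwise build
  // the 64-bit flat pointer from the 32-bit segment offset (low half) and the
  // aperture (high half).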
  unsigned CmpRes = MRI.createGenericVirtualRegister(LLT::scalar(1));
  MIRBuilder.buildICmp(CmpInst::ICMP_NE, CmpRes, Src, SegmentNull.getReg(0));

  unsigned BuildPtr = MRI.createGenericVirtualRegister(DstTy);

  // Coerce the type of the low half of the result so we can use merge_values.
  unsigned SrcAsInt = MRI.createGenericVirtualRegister(LLT::scalar(32));
  MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
    .addDef(SrcAsInt)
    .addUse(Src);

  // TODO: Should we allow mismatched types but matching sizes in merges to
  // avoid the ptrtoint?
  MIRBuilder.buildMerge(BuildPtr, {SrcAsInt, ApertureReg});
  MIRBuilder.buildSelect(Dst, CmpRes, BuildPtr, FlatNull.getReg(0));

  MI.eraseFromParent();
  return true;
}