//===- AMDGPULegalizerInfo.cpp -----------------------------------*- C++ -*-==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
/// This file implements the targeting of the MachineLegalizer class for
/// AMDGPU.
/// \todo This should be generated by TableGen.
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "AMDGPULegalizerInfo.h"
#include "AMDGPUTargetMachine.h"
#include "SIMachineFunctionInfo.h"

#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Debug.h"

using namespace llvm;
using namespace LegalizeActions;
using namespace LegalizeMutations;
using namespace LegalityPredicates;

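// Matches types whose scalar element size is a multiple of 32 bits and whose
// total size does not exceed MaxSize, i.e. types that map exactly onto some
// number of 32-bit registers.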
static LegalityPredicate isMultiple32(unsigned TypeIdx,
                                      unsigned MaxSize = 512) {
  return [=](const LegalityQuery &Query) {
    const LLT Ty = Query.Types[TypeIdx];
    const LLT EltTy = Ty.getScalarType();
    return Ty.getSizeInBits() <= MaxSize && EltTy.getSizeInBits() % 32 == 0;
  };
}

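// Matches vectors with an odd number of sub-32-bit elements, e.g. v3s16,
// which are padded out with one extra element by oneMoreElement below.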
static LegalityPredicate isSmallOddVector(unsigned TypeIdx) {
  return [=](const LegalityQuery &Query) {
    const LLT Ty = Query.Types[TypeIdx];
    return Ty.isVector() &&
           Ty.getNumElements() % 2 != 0 &&
           Ty.getElementType().getSizeInBits() < 32;
  };
}

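// Mutation that appends one element of the same element type,
// e.g. v3s16 -> v4s16.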
static LegalizeMutation oneMoreElement(unsigned TypeIdx) {
  return [=](const LegalityQuery &Query) {
    const LLT Ty = Query.Types[TypeIdx];
    const LLT EltTy = Ty.getElementType();
    return std::make_pair(TypeIdx, LLT::vector(Ty.getNumElements() + 1, EltTy));
  };
}

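// Mutation that breaks a wide vector into pieces of at most 64 bits each by
// reducing the element count; e.g. v3s32 (96 bits) -> v2s32, and v7s16
// (112 bits) -> v4s16.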
static LegalizeMutation fewerEltsToSize64Vector(unsigned TypeIdx) {
  return [=](const LegalityQuery &Query) {
    const LLT Ty = Query.Types[TypeIdx];
    const LLT EltTy = Ty.getElementType();
    unsigned Size = Ty.getSizeInBits();
    unsigned Pieces = (Size + 63) / 64;
    unsigned NewNumElts = (Ty.getNumElements() + 1) / Pieces;
    return std::make_pair(TypeIdx, LLT::scalarOrVector(NewNumElts, EltTy));
  };
}

static LegalityPredicate vectorWiderThan(unsigned TypeIdx, unsigned Size) {
  return [=](const LegalityQuery &Query) {
    const LLT QueryTy = Query.Types[TypeIdx];
    return QueryTy.isVector() && QueryTy.getSizeInBits() > Size;
  };
}

static LegalityPredicate numElementsNotEven(unsigned TypeIdx) {
  return [=](const LegalityQuery &Query) {
    const LLT QueryTy = Query.Types[TypeIdx];
    return QueryTy.isVector() && QueryTy.getNumElements() % 2 != 0;
  };
}

AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST,
                                         const GCNTargetMachine &TM) {
  using namespace TargetOpcode;

  auto GetAddrSpacePtr = [&TM](unsigned AS) {
    return LLT::pointer(AS, TM.getPointerSizeInBits(AS));
  };

  const LLT S1 = LLT::scalar(1);
  const LLT S8 = LLT::scalar(8);
  const LLT S16 = LLT::scalar(16);
  const LLT S32 = LLT::scalar(32);
  const LLT S64 = LLT::scalar(64);
  const LLT S128 = LLT::scalar(128);
  const LLT S256 = LLT::scalar(256);
  const LLT S512 = LLT::scalar(512);

  const LLT V2S16 = LLT::vector(2, 16);
  const LLT V4S16 = LLT::vector(4, 16);
  const LLT V8S16 = LLT::vector(8, 16);

  const LLT V2S32 = LLT::vector(2, 32);
  const LLT V3S32 = LLT::vector(3, 32);
  const LLT V4S32 = LLT::vector(4, 32);
  const LLT V5S32 = LLT::vector(5, 32);
  const LLT V6S32 = LLT::vector(6, 32);
  const LLT V7S32 = LLT::vector(7, 32);
  const LLT V8S32 = LLT::vector(8, 32);
  const LLT V9S32 = LLT::vector(9, 32);
  const LLT V10S32 = LLT::vector(10, 32);
  const LLT V11S32 = LLT::vector(11, 32);
  const LLT V12S32 = LLT::vector(12, 32);
  const LLT V13S32 = LLT::vector(13, 32);
  const LLT V14S32 = LLT::vector(14, 32);
  const LLT V15S32 = LLT::vector(15, 32);
  const LLT V16S32 = LLT::vector(16, 32);

  const LLT V2S64 = LLT::vector(2, 64);
  const LLT V3S64 = LLT::vector(3, 64);
  const LLT V4S64 = LLT::vector(4, 64);
  const LLT V5S64 = LLT::vector(5, 64);
  const LLT V6S64 = LLT::vector(6, 64);
  const LLT V7S64 = LLT::vector(7, 64);
  const LLT V8S64 = LLT::vector(8, 64);

  std::initializer_list<LLT> AllS32Vectors =
    {V2S32, V3S32, V4S32, V5S32, V6S32, V7S32, V8S32,
     V9S32, V10S32, V11S32, V12S32, V13S32, V14S32, V15S32, V16S32};
  std::initializer_list<LLT> AllS64Vectors =
    {V2S64, V3S64, V4S64, V5S64, V6S64, V7S64, V8S64};

  const LLT GlobalPtr = GetAddrSpacePtr(AMDGPUAS::GLOBAL_ADDRESS);
  const LLT ConstantPtr = GetAddrSpacePtr(AMDGPUAS::CONSTANT_ADDRESS);
  const LLT LocalPtr = GetAddrSpacePtr(AMDGPUAS::LOCAL_ADDRESS);
  const LLT FlatPtr = GetAddrSpacePtr(AMDGPUAS::FLAT_ADDRESS);
  const LLT PrivatePtr = GetAddrSpacePtr(AMDGPUAS::PRIVATE_ADDRESS);

  const LLT CodePtr = FlatPtr;

  const std::initializer_list<LLT> AddrSpaces64 = {
    GlobalPtr, ConstantPtr, FlatPtr
  };

  const std::initializer_list<LLT> AddrSpaces32 = {
    LocalPtr, PrivatePtr
  };

  setAction({G_BRCOND, S1}, Legal);

  // TODO: All multiples of 32, vectors of pointers, all v2s16 pairs, more
  // elements for v3s16
  getActionDefinitionsBuilder(G_PHI)
    .legalFor({S32, S64, V2S16, V4S16, S1, S128, S256})
    .legalFor(AllS32Vectors)
    .legalFor(AllS64Vectors)
    .legalFor(AddrSpaces64)
    .legalFor(AddrSpaces32)
    .clampScalar(0, S32, S256)
    .widenScalarToNextPow2(0, 32)
    .clampMaxNumElements(0, S32, 16)
    .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
    .legalIf(isPointer(0));

  getActionDefinitionsBuilder({G_ADD, G_SUB, G_MUL, G_UMULH, G_SMULH})
    .legalFor({S32})
    .clampScalar(0, S32, S32)
    .scalarize(0);

  // Report legal for any types we can handle anywhere. For the cases only
  // legal on the SALU, RegBankSelect will be able to re-legalize.
  getActionDefinitionsBuilder({G_AND, G_OR, G_XOR})
    .legalFor({S32, S1, S64, V2S32, V2S16, V4S16})
    .clampScalar(0, S32, S64)
    .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
    .fewerElementsIf(vectorWiderThan(0, 32), fewerEltsToSize64Vector(0))
    .widenScalarToNextPow2(0)
    .scalarize(0);

  getActionDefinitionsBuilder({G_UADDO, G_SADDO, G_USUBO, G_SSUBO,
                               G_UADDE, G_SADDE, G_USUBE, G_SSUBE})
    .legalFor({{S32, S1}})
    .clampScalar(0, S32, S32);

  getActionDefinitionsBuilder(G_BITCAST)
    .legalForCartesianProduct({S32, V2S16})
    .legalForCartesianProduct({S64, V2S32, V4S16})
    .legalForCartesianProduct({V2S64, V4S32})
    // Don't worry about the size constraint.
    .legalIf(all(isPointer(0), isPointer(1)));

  if (ST.has16BitInsts()) {
    getActionDefinitionsBuilder(G_FCONSTANT)
      .legalFor({S32, S64, S16})
      .clampScalar(0, S16, S64);
  } else {
    getActionDefinitionsBuilder(G_FCONSTANT)
      .legalFor({S32, S64})
      .clampScalar(0, S32, S64);
  }

  getActionDefinitionsBuilder(G_IMPLICIT_DEF)
    .legalFor({S1, S32, S64, V2S32, V4S32, V2S16, V4S16, GlobalPtr,
               ConstantPtr, LocalPtr, FlatPtr, PrivatePtr})
    .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
    .clampScalarOrElt(0, S32, S512)
    .legalIf(isMultiple32(0))
    .widenScalarToNextPow2(0, 32)
    .clampMaxNumElements(0, S32, 16);

  // FIXME: i1 operands to intrinsics should always be legal, but other i1
  // values may not be legal. We need to figure out how to distinguish
  // between these two scenarios.
  getActionDefinitionsBuilder(G_CONSTANT)
    .legalFor({S1, S32, S64, GlobalPtr,
               LocalPtr, ConstantPtr, PrivatePtr, FlatPtr})
    .clampScalar(0, S32, S64)
    .widenScalarToNextPow2(0)
    .legalIf(isPointer(0));

  setAction({G_FRAME_INDEX, PrivatePtr}, Legal);

  auto &FPOpActions = getActionDefinitionsBuilder(
    {G_FADD, G_FMUL, G_FNEG, G_FABS, G_FMA, G_FCANONICALIZE})
    .legalFor({S32, S64});

  if (ST.has16BitInsts()) {
    if (ST.hasVOP3PInsts())
      FPOpActions.legalFor({S16, V2S16});
    else
      FPOpActions.legalFor({S16});
  }

  if (ST.hasVOP3PInsts())
    FPOpActions.clampMaxNumElements(0, S16, 2);
  FPOpActions
    .scalarize(0)
    .clampScalar(0, ST.has16BitInsts() ? S16 : S32, S64);

  if (ST.has16BitInsts()) {
    getActionDefinitionsBuilder(G_FSQRT)
      .legalFor({S32, S64, S16})
      .scalarize(0)
      .clampScalar(0, S16, S64);
  } else {
    getActionDefinitionsBuilder(G_FSQRT)
      .legalFor({S32, S64})
      .scalarize(0)
      .clampScalar(0, S32, S64);
  }

  getActionDefinitionsBuilder(G_FPTRUNC)
    .legalFor({{S32, S64}, {S16, S32}})
    .scalarize(0);

  getActionDefinitionsBuilder(G_FPEXT)
    .legalFor({{S64, S32}, {S32, S16}})
    .lowerFor({{S64, S16}}) // FIXME: Implement
    .scalarize(0);

  getActionDefinitionsBuilder(G_FSUB)
      // Use actual fsub instruction
      .legalFor({S32})
      // Must use fadd + fneg
      .lowerFor({S64, S16, V2S16})
      .scalarize(0)
      .clampScalar(0, S32, S64);

  getActionDefinitionsBuilder({G_SEXT, G_ZEXT, G_ANYEXT})
    .legalFor({{S64, S32}, {S32, S16}, {S64, S16},
               {S32, S1}, {S64, S1}, {S16, S1},
               // FIXME: Hack
               {S64, LLT::scalar(33)},
               {S32, S8}, {S128, S32}, {S128, S64}, {S32, LLT::scalar(24)}})
    .scalarize(0);

  getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
    .legalFor({{S32, S32}, {S64, S32}})
    .scalarize(0);

  getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
    .legalFor({{S32, S32}, {S32, S64}})
    .scalarize(0);

  getActionDefinitionsBuilder({G_INTRINSIC_TRUNC, G_INTRINSIC_ROUND})
    .legalFor({S32, S64})
    .scalarize(0);

  getActionDefinitionsBuilder(G_GEP)
    .legalForCartesianProduct(AddrSpaces64, {S64})
    .legalForCartesianProduct(AddrSpaces32, {S32})
    .scalarize(0);

  setAction({G_BLOCK_ADDR, CodePtr}, Legal);

  getActionDefinitionsBuilder(G_ICMP)
    .legalForCartesianProduct(
      {S1}, {S32, S64, GlobalPtr, LocalPtr, ConstantPtr, PrivatePtr, FlatPtr})
    .legalFor({{S1, S32}, {S1, S64}})
    .widenScalarToNextPow2(1)
    .clampScalar(1, S32, S64)
    .scalarize(0)
    .legalIf(all(typeIs(0, S1), isPointer(1)));

  getActionDefinitionsBuilder(G_FCMP)
    .legalFor({{S1, S32}, {S1, S64}})
    .widenScalarToNextPow2(1)
    .clampScalar(1, S32, S64)
    .scalarize(0);

  // FIXME: fexp, flog2, flog10 need to be custom lowered.
  getActionDefinitionsBuilder({G_FPOW, G_FEXP, G_FEXP2,
                               G_FLOG, G_FLOG2, G_FLOG10})
    .legalFor({S32})
    .scalarize(0);

  // The 64-bit versions produce 32-bit results, but only on the SALU.
  getActionDefinitionsBuilder({G_CTLZ, G_CTLZ_ZERO_UNDEF,
                               G_CTTZ, G_CTTZ_ZERO_UNDEF,
                               G_CTPOP})
    .legalFor({{S32, S32}, {S32, S64}})
    .clampScalar(0, S32, S32)
    .clampScalar(1, S32, S64)
    .scalarize(0)
    .widenScalarToNextPow2(0, 32)
    .widenScalarToNextPow2(1, 32);

  // TODO: Expand for > s32
  getActionDefinitionsBuilder(G_BSWAP)
    .legalFor({S32})
    .clampScalar(0, S32, S32)
    .scalarize(0);

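  // Predicates comparing the bit widths of two type indices, used below to
  // widen or narrow the integer operand of pointer casts until it matches
  // the pointer size.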
  auto smallerThan = [](unsigned TypeIdx0, unsigned TypeIdx1) {
    return [=](const LegalityQuery &Query) {
      return Query.Types[TypeIdx0].getSizeInBits() <
             Query.Types[TypeIdx1].getSizeInBits();
    };
  };

  auto greaterThan = [](unsigned TypeIdx0, unsigned TypeIdx1) {
    return [=](const LegalityQuery &Query) {
      return Query.Types[TypeIdx0].getSizeInBits() >
             Query.Types[TypeIdx1].getSizeInBits();
    };
  };

  getActionDefinitionsBuilder(G_INTTOPTR)
    // List the common cases
    .legalForCartesianProduct(AddrSpaces64, {S64})
    .legalForCartesianProduct(AddrSpaces32, {S32})
    .scalarize(0)
    // Accept any address space as long as the size matches
    .legalIf(sameSize(0, 1))
    .widenScalarIf(smallerThan(1, 0),
      [](const LegalityQuery &Query) {
        return std::make_pair(1, LLT::scalar(Query.Types[0].getSizeInBits()));
      })
    .narrowScalarIf(greaterThan(1, 0),
      [](const LegalityQuery &Query) {
        return std::make_pair(1, LLT::scalar(Query.Types[0].getSizeInBits()));
      });

  getActionDefinitionsBuilder(G_PTRTOINT)
    // List the common cases
    .legalForCartesianProduct(AddrSpaces64, {S64})
    .legalForCartesianProduct(AddrSpaces32, {S32})
    .scalarize(0)
    // Accept any address space as long as the size matches
    .legalIf(sameSize(0, 1))
    .widenScalarIf(smallerThan(0, 1),
      [](const LegalityQuery &Query) {
        return std::make_pair(0, LLT::scalar(Query.Types[1].getSizeInBits()));
      })
    .narrowScalarIf(
      greaterThan(0, 1),
      [](const LegalityQuery &Query) {
        return std::make_pair(0, LLT::scalar(Query.Types[1].getSizeInBits()));
      });

  if (ST.hasFlatAddressSpace()) {
    getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
      .scalarize(0)
      .custom();
  }

  getActionDefinitionsBuilder({G_LOAD, G_STORE})
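    // Narrow results wider than 32 bits when the memory size is smaller than
    // the result type (i.e. extending loads); the extension is legalized
    // separately.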
    .narrowScalarIf([](const LegalityQuery &Query) {
        unsigned Size = Query.Types[0].getSizeInBits();
        unsigned MemSize = Query.MMODescrs[0].SizeInBits;
        return (Size > 32 && MemSize < Size);
      },
      [](const LegalityQuery &Query) {
        return std::make_pair(0, LLT::scalar(32));
      })
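    // Split 96-bit vector accesses on targets before SEA_ISLANDS, which lack
    // the x3 memory instructions.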
    .fewerElementsIf([=, &ST](const LegalityQuery &Query) {
        unsigned MemSize = Query.MMODescrs[0].SizeInBits;
        return (MemSize == 96) &&
               Query.Types[0].isVector() &&
               ST.getGeneration() < AMDGPUSubtarget::SEA_ISLANDS;
      },
      [=](const LegalityQuery &Query) {
        return std::make_pair(0, V2S32);
      })
    .legalIf([=, &ST](const LegalityQuery &Query) {
        const LLT &Ty0 = Query.Types[0];

        unsigned Size = Ty0.getSizeInBits();
        unsigned MemSize = Query.MMODescrs[0].SizeInBits;
        if (Size < 32 || (Size > 32 && MemSize < Size))
          return false;

        if (Ty0.isVector() && Size != MemSize)
          return false;

        // TODO: Decompose private loads into 4-byte components.
        // TODO: Illegal flat loads on SI
        switch (MemSize) {
        case 8:
        case 16:
          return Size == 32;
        case 32:
        case 64:
        case 128:
          return true;

        case 96:
          // XXX hasLoadX3
          return (ST.getGeneration() >= AMDGPUSubtarget::SEA_ISLANDS);

        case 256:
        case 512:
          // TODO: constant loads
        default:
          return false;
        }
      })
    .clampScalar(0, S32, S64);

  // FIXME: Handle alignment requirements.
  auto &ExtLoads = getActionDefinitionsBuilder({G_SEXTLOAD, G_ZEXTLOAD})
    .legalForTypesWithMemDesc({
        {S32, GlobalPtr, 8, 8},
        {S32, GlobalPtr, 16, 8},
        {S32, LocalPtr, 8, 8},
        {S32, LocalPtr, 16, 8},
        {S32, PrivatePtr, 8, 8},
        {S32, PrivatePtr, 16, 8}});
  if (ST.hasFlatAddressSpace()) {
    ExtLoads.legalForTypesWithMemDesc({{S32, FlatPtr, 8, 8},
                                       {S32, FlatPtr, 16, 8}});
  }

  ExtLoads.clampScalar(0, S32, S32)
          .widenScalarToNextPow2(0)
          .unsupportedIfMemSizeNotPow2()
          .lower();

  auto &Atomics = getActionDefinitionsBuilder(
    {G_ATOMICRMW_XCHG, G_ATOMICRMW_ADD, G_ATOMICRMW_SUB,
     G_ATOMICRMW_AND, G_ATOMICRMW_OR, G_ATOMICRMW_XOR,
     G_ATOMICRMW_MAX, G_ATOMICRMW_MIN, G_ATOMICRMW_UMAX,
     G_ATOMICRMW_UMIN, G_ATOMIC_CMPXCHG})
    .legalFor({{S32, GlobalPtr}, {S32, LocalPtr},
               {S64, GlobalPtr}, {S64, LocalPtr}});
  if (ST.hasFlatAddressSpace()) {
    Atomics.legalFor({{S32, FlatPtr}, {S64, FlatPtr}});
  }

  // TODO: Pointer types, any 32-bit or 64-bit vector
  getActionDefinitionsBuilder(G_SELECT)
    .legalForCartesianProduct({S32, S64, V2S32, V2S16, V4S16,
          GlobalPtr, LocalPtr, FlatPtr, PrivatePtr,
          LLT::vector(2, LocalPtr), LLT::vector(2, PrivatePtr)}, {S1})
    .clampScalar(0, S32, S64)
    .moreElementsIf(isSmallOddVector(0), oneMoreElement(0))
    .fewerElementsIf(numElementsNotEven(0), scalarize(0))
    .scalarize(1)
    .clampMaxNumElements(0, S32, 2)
    .clampMaxNumElements(0, LocalPtr, 2)
    .clampMaxNumElements(0, PrivatePtr, 2)
    .scalarize(0)
    .widenScalarToNextPow2(0)
    .legalIf(all(isPointer(0), typeIs(1, S1)));

  // TODO: Only the low 4/5/6 bits of the shift amount are observed, so we can
  // be more flexible with the shift amount type.
  auto &Shifts = getActionDefinitionsBuilder({G_SHL, G_LSHR, G_ASHR})
    .legalFor({{S32, S32}, {S64, S32}});
  if (ST.has16BitInsts()) {
    if (ST.hasVOP3PInsts()) {
      Shifts.legalFor({{S16, S32}, {S16, S16}, {V2S16, V2S16}})
            .clampMaxNumElements(0, S16, 2);
    } else
      Shifts.legalFor({{S16, S32}, {S16, S16}});

    Shifts.clampScalar(1, S16, S32);
    Shifts.clampScalar(0, S16, S64);
    Shifts.widenScalarToNextPow2(0, 16);
  } else {
    // Make sure we legalize the shift amount type first, as the general
    // expansion for the shifted type will produce much worse code if it hasn't
    // been truncated already.
    Shifts.clampScalar(1, S32, S32);
    Shifts.clampScalar(0, S32, S64);
    Shifts.widenScalarToNextPow2(0, 32);
  }
  Shifts.scalarize(0);

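  // Dynamic vector element access. Legal when the vector occupies a whole
  // number of 32-bit registers (up to 512 bits) and the index is 32-bit.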
  for (unsigned Op : {G_EXTRACT_VECTOR_ELT, G_INSERT_VECTOR_ELT}) {
    unsigned VecTypeIdx = Op == G_EXTRACT_VECTOR_ELT ? 1 : 0;
    unsigned EltTypeIdx = Op == G_EXTRACT_VECTOR_ELT ? 0 : 1;
    unsigned IdxTypeIdx = 2;

    getActionDefinitionsBuilder(Op)
      .legalIf([=](const LegalityQuery &Query) {
          const LLT &VecTy = Query.Types[VecTypeIdx];
          const LLT &IdxTy = Query.Types[IdxTypeIdx];
          return VecTy.getSizeInBits() % 32 == 0 &&
            VecTy.getSizeInBits() <= 512 &&
            IdxTy.getSizeInBits() == 32;
        })
      .clampScalar(EltTypeIdx, S32, S64)
      .clampScalar(VecTypeIdx, S32, S64)
      .clampScalar(IdxTypeIdx, S32, S32);
  }

  getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT)
    .unsupportedIf([=](const LegalityQuery &Query) {
        const LLT &EltTy = Query.Types[1].getElementType();
        return Query.Types[0] != EltTy;
      });

  for (unsigned Op : {G_EXTRACT, G_INSERT}) {
    unsigned BigTyIdx = Op == G_EXTRACT ? 1 : 0;
    unsigned LitTyIdx = Op == G_EXTRACT ? 0 : 1;

    // FIXME: Doesn't handle extract of illegal sizes.
    getActionDefinitionsBuilder(Op)
      .legalIf([=](const LegalityQuery &Query) {
          const LLT BigTy = Query.Types[BigTyIdx];
          const LLT LitTy = Query.Types[LitTyIdx];
          return (BigTy.getSizeInBits() % 32 == 0) &&
                 (LitTy.getSizeInBits() % 16 == 0);
        })
      .widenScalarIf(
        [=](const LegalityQuery &Query) {
          const LLT BigTy = Query.Types[BigTyIdx];
          return (BigTy.getScalarSizeInBits() < 16);
        },
        LegalizeMutations::widenScalarOrEltToNextPow2(BigTyIdx, 16))
      .widenScalarIf(
        [=](const LegalityQuery &Query) {
          const LLT LitTy = Query.Types[LitTyIdx];
          return (LitTy.getScalarSizeInBits() < 16);
        },
        LegalizeMutations::widenScalarOrEltToNextPow2(LitTyIdx, 16))
      .moreElementsIf(isSmallOddVector(BigTyIdx), oneMoreElement(BigTyIdx));
  }

  // TODO: vectors of pointers
  getActionDefinitionsBuilder(G_BUILD_VECTOR)
      .legalForCartesianProduct(AllS32Vectors, {S32})
      .legalForCartesianProduct(AllS64Vectors, {S64})
      .clampNumElements(0, V16S32, V16S32)
      .clampNumElements(0, V2S64, V8S64)
      .minScalarSameAs(1, 0)
      // FIXME: Sort of a hack to make progress on other legalizations.
      .legalIf([=](const LegalityQuery &Query) {
        return Query.Types[0].getScalarSizeInBits() <= 32 ||
               Query.Types[0].getScalarSizeInBits() == 64;
      });

  // TODO: Support any combination of v2s32
  getActionDefinitionsBuilder(G_CONCAT_VECTORS)
    .legalFor({{V4S32, V2S32},
               {V8S32, V2S32},
               {V8S32, V4S32},
               {V4S64, V2S64},
               {V4S16, V2S16},
               {V8S16, V2S16},
               {V8S16, V4S16},
               {LLT::vector(4, LocalPtr), LLT::vector(2, LocalPtr)},
               {LLT::vector(4, PrivatePtr), LLT::vector(2, PrivatePtr)}});

  // Merge/Unmerge
  for (unsigned Op : {G_MERGE_VALUES, G_UNMERGE_VALUES}) {
    unsigned BigTyIdx = Op == G_MERGE_VALUES ? 0 : 1;
    unsigned LitTyIdx = Op == G_MERGE_VALUES ? 1 : 0;

    auto notValidElt = [=](const LegalityQuery &Query, unsigned TypeIdx) {
      const LLT &Ty = Query.Types[TypeIdx];
      if (Ty.isVector()) {
        const LLT &EltTy = Ty.getElementType();
        if (EltTy.getSizeInBits() < 8 || EltTy.getSizeInBits() > 64)
          return true;
        if (!isPowerOf2_32(EltTy.getSizeInBits()))
          return true;
      }
      return false;
    };

    getActionDefinitionsBuilder(Op)
      .widenScalarToNextPow2(LitTyIdx, /*Min*/ 16)
      // Clamp the little scalar to s16-s256 and make it a power of 2. It's
      // not worth considering the multiples of 64 since 2*192 and 2*384 are
      // not valid.
      .clampScalar(LitTyIdx, S16, S256)
      .widenScalarToNextPow2(LitTyIdx, /*Min*/ 32)

      // Break up vectors with weird elements into scalars
      .fewerElementsIf(
        [=](const LegalityQuery &Query) { return notValidElt(Query, 0); },
        scalarize(0))
      .fewerElementsIf(
        [=](const LegalityQuery &Query) { return notValidElt(Query, 1); },
        scalarize(1))
      .clampScalar(BigTyIdx, S32, S512)
      .widenScalarIf(
        [=](const LegalityQuery &Query) {
          const LLT &Ty = Query.Types[BigTyIdx];
          return !isPowerOf2_32(Ty.getSizeInBits()) &&
                 Ty.getSizeInBits() % 16 != 0;
        },
        [=](const LegalityQuery &Query) {
          // Pick the next power of 2, or a multiple of 64 once the size is
          // over 128, whichever is smaller.
          const LLT &Ty = Query.Types[BigTyIdx];
          unsigned NewSizeInBits = 1 << Log2_32_Ceil(Ty.getSizeInBits() + 1);
          if (NewSizeInBits >= 256) {
            unsigned RoundedTo = alignTo<64>(Ty.getSizeInBits() + 1);
            if (RoundedTo < NewSizeInBits)
              NewSizeInBits = RoundedTo;
          }
          return std::make_pair(BigTyIdx, LLT::scalar(NewSizeInBits));
        })
      .legalIf([=](const LegalityQuery &Query) {
          const LLT &BigTy = Query.Types[BigTyIdx];
          const LLT &LitTy = Query.Types[LitTyIdx];

          if (BigTy.isVector() && BigTy.getSizeInBits() < 32)
            return false;
          if (LitTy.isVector() && LitTy.getSizeInBits() < 32)
            return false;

          return BigTy.getSizeInBits() % 16 == 0 &&
                 LitTy.getSizeInBits() % 16 == 0 &&
                 BigTy.getSizeInBits() <= 512;
        })
      // Any vectors left are the wrong size. Scalarize them.
      .scalarize(0)
      .scalarize(1);
  }

  computeTables();
  verify(*ST.getInstrInfo());
}

bool AMDGPULegalizerInfo::legalizeCustom(MachineInstr &MI,
                                         MachineRegisterInfo &MRI,
                                         MachineIRBuilder &MIRBuilder,
                                         GISelChangeObserver &Observer) const {
  switch (MI.getOpcode()) {
  case TargetOpcode::G_ADDRSPACE_CAST:
    return legalizeAddrSpaceCast(MI, MRI, MIRBuilder);
  default:
    return false;
  }

  llvm_unreachable("expected switch to return");
}

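// Returns a register holding the 32-bit aperture base (the high half of a
// 64-bit flat address) for the given segment address space (LOCAL or
// PRIVATE).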
unsigned AMDGPULegalizerInfo::getSegmentAperture(
  unsigned AS,
  MachineRegisterInfo &MRI,
  MachineIRBuilder &MIRBuilder) const {
  MachineFunction &MF = MIRBuilder.getMF();
  const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
  const LLT S32 = LLT::scalar(32);

  if (ST.hasApertureRegs()) {
    // FIXME: Use inline constants (src_{shared, private}_base) instead of
    // getreg.
    unsigned Offset = AS == AMDGPUAS::LOCAL_ADDRESS ?
        AMDGPU::Hwreg::OFFSET_SRC_SHARED_BASE :
        AMDGPU::Hwreg::OFFSET_SRC_PRIVATE_BASE;
    unsigned WidthM1 = AS == AMDGPUAS::LOCAL_ADDRESS ?
        AMDGPU::Hwreg::WIDTH_M1_SRC_SHARED_BASE :
        AMDGPU::Hwreg::WIDTH_M1_SRC_PRIVATE_BASE;
    unsigned Encoding =
        AMDGPU::Hwreg::ID_MEM_BASES << AMDGPU::Hwreg::ID_SHIFT_ |
        Offset << AMDGPU::Hwreg::OFFSET_SHIFT_ |
        WidthM1 << AMDGPU::Hwreg::WIDTH_M1_SHIFT_;

    unsigned ShiftAmt = MRI.createGenericVirtualRegister(S32);
    unsigned ApertureReg = MRI.createGenericVirtualRegister(S32);
    unsigned GetReg = MRI.createVirtualRegister(&AMDGPU::SReg_32RegClass);

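    // s_getreg returns the aperture field in the low bits of the result;
    // shift it left by the field width to reconstruct the 32-bit aperture
    // base.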
    MIRBuilder.buildInstr(AMDGPU::S_GETREG_B32)
      .addDef(GetReg)
      .addImm(Encoding);
    MRI.setType(GetReg, S32);

    MIRBuilder.buildConstant(ShiftAmt, WidthM1 + 1);
    MIRBuilder.buildInstr(TargetOpcode::G_SHL)
      .addDef(ApertureReg)
      .addUse(GetReg)
      .addUse(ShiftAmt);

    return ApertureReg;
  }

  unsigned QueuePtr = MRI.createGenericVirtualRegister(
    LLT::pointer(AMDGPUAS::CONSTANT_ADDRESS, 64));

  // FIXME: Placeholder until we can track the input registers.
  MIRBuilder.buildConstant(QueuePtr, 0xdeadbeef);

  // Offset into amd_queue_t for group_segment_aperture_base_hi /
  // private_segment_aperture_base_hi.
  uint32_t StructOffset = (AS == AMDGPUAS::LOCAL_ADDRESS) ? 0x40 : 0x44;

  // FIXME: Don't use undef
  Value *V = UndefValue::get(PointerType::get(
                               Type::getInt8Ty(MF.getFunction().getContext()),
                               AMDGPUAS::CONSTANT_ADDRESS));

  MachinePointerInfo PtrInfo(V, StructOffset);
  MachineMemOperand *MMO = MF.getMachineMemOperand(
    PtrInfo,
    MachineMemOperand::MOLoad |
    MachineMemOperand::MODereferenceable |
    MachineMemOperand::MOInvariant,
    4,
    MinAlign(64, StructOffset));

  unsigned LoadResult = MRI.createGenericVirtualRegister(S32);
  unsigned LoadAddr = AMDGPU::NoRegister;

  MIRBuilder.materializeGEP(LoadAddr, QueuePtr, LLT::scalar(64), StructOffset);
  MIRBuilder.buildLoad(LoadResult, LoadAddr, *MMO);
  return LoadResult;
}

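// Lower G_ADDRSPACE_CAST. No-op casts become bitcasts; casts between the
// flat and segment (local/private) address spaces are expanded with an
// explicit null-pointer check and, for segment to flat, the segment's
// aperture base.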
bool AMDGPULegalizerInfo::legalizeAddrSpaceCast(
  MachineInstr &MI, MachineRegisterInfo &MRI,
  MachineIRBuilder &MIRBuilder) const {
  MachineFunction &MF = MIRBuilder.getMF();

  MIRBuilder.setInstr(MI);

  unsigned Dst = MI.getOperand(0).getReg();
  unsigned Src = MI.getOperand(1).getReg();

  LLT DstTy = MRI.getType(Dst);
  LLT SrcTy = MRI.getType(Src);
  unsigned DestAS = DstTy.getAddressSpace();
  unsigned SrcAS = SrcTy.getAddressSpace();

  // TODO: Avoid reloading from the queue ptr for each cast, or at least each
  // vector element.
  assert(!DstTy.isVector());

  const AMDGPUTargetMachine &TM
    = static_cast<const AMDGPUTargetMachine &>(MF.getTarget());

  const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
  if (ST.getTargetLowering()->isNoopAddrSpaceCast(SrcAS, DestAS)) {
    MI.setDesc(MIRBuilder.getTII().get(TargetOpcode::G_BITCAST));
    return true;
  }

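  // Flat -> segment: the result is the low 32 bits of the flat pointer,
  // except that a null flat pointer must map to the segment's null value.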
  if (SrcAS == AMDGPUAS::FLAT_ADDRESS) {
    assert(DestAS == AMDGPUAS::LOCAL_ADDRESS ||
           DestAS == AMDGPUAS::PRIVATE_ADDRESS);
    unsigned NullVal = TM.getNullPointerValue(DestAS);

    unsigned SegmentNullReg = MRI.createGenericVirtualRegister(DstTy);
    unsigned FlatNullReg = MRI.createGenericVirtualRegister(SrcTy);

    MIRBuilder.buildConstant(SegmentNullReg, NullVal);
    MIRBuilder.buildConstant(FlatNullReg, 0);

    unsigned PtrLo32 = MRI.createGenericVirtualRegister(DstTy);

    // Extract low 32-bits of the pointer.
    MIRBuilder.buildExtract(PtrLo32, Src, 0);

    unsigned CmpRes = MRI.createGenericVirtualRegister(LLT::scalar(1));
    MIRBuilder.buildICmp(CmpInst::ICMP_NE, CmpRes, Src, FlatNullReg);
    MIRBuilder.buildSelect(Dst, CmpRes, PtrLo32, SegmentNullReg);

    MI.eraseFromParent();
    return true;
  }

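  // Segment -> flat: rebuild the 64-bit flat pointer from the 32-bit segment
  // offset and the aperture base, mapping the segment's null value to the
  // flat null pointer.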
  assert(SrcAS == AMDGPUAS::LOCAL_ADDRESS ||
         SrcAS == AMDGPUAS::PRIVATE_ADDRESS);

  unsigned FlatNullReg = MRI.createGenericVirtualRegister(DstTy);
  unsigned SegmentNullReg = MRI.createGenericVirtualRegister(SrcTy);
  MIRBuilder.buildConstant(SegmentNullReg, TM.getNullPointerValue(SrcAS));
  MIRBuilder.buildConstant(FlatNullReg, TM.getNullPointerValue(DestAS));

  unsigned ApertureReg = getSegmentAperture(DestAS, MRI, MIRBuilder);

  unsigned CmpRes = MRI.createGenericVirtualRegister(LLT::scalar(1));
  MIRBuilder.buildICmp(CmpInst::ICMP_NE, CmpRes, Src, SegmentNullReg);

  unsigned BuildPtr = MRI.createGenericVirtualRegister(DstTy);

  // Coerce the type of the low half of the result so we can use merge_values.
  unsigned SrcAsInt = MRI.createGenericVirtualRegister(LLT::scalar(32));
  MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
    .addDef(SrcAsInt)
    .addUse(Src);

  // TODO: Should we allow mismatched types but matching sizes in merges to
  // avoid the ptrtoint?
  MIRBuilder.buildMerge(BuildPtr, {SrcAsInt, ApertureReg});
  MIRBuilder.buildSelect(Dst, CmpRes, BuildPtr, FlatNullReg);

  MI.eraseFromParent();
  return true;
}